From 3e8ae9f73bff33cc64217170dc92c67cf677299b Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Mon, 5 Sep 2022 02:52:37 -0400
Subject: [PATCH 001/133] ARROW-17610: [C++] Support additional source types in SourceNode

---
 cpp/src/arrow/compute/exec/options.h      |  25 ++++
 cpp/src/arrow/compute/exec/plan_test.cc   | 142 ++++++++++++++++++++++
 cpp/src/arrow/compute/exec/source_node.cc | 141 +++++++++++++++++++++
 cpp/src/arrow/compute/exec/test_util.cc   |  32 +++++
 cpp/src/arrow/compute/exec/test_util.h    |  12 ++
 5 files changed, 352 insertions(+)

diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index a8e8c1ee230..6055a640349 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -77,6 +77,31 @@ class ARROW_EXPORT TableSourceNodeOptions : public ExecNodeOptions {
   int64_t max_batch_size;
 };
 
+/// \brief An extended Source node which accepts a schema
+///
+/// ItMaker is a maker of an iterator of tabular data.
+template <typename ItMaker>
+class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
+ public:
+  SchemaSourceNodeOptions(std::shared_ptr<Schema> schema, ItMaker it_maker)
+      : schema(schema), it_maker(it_maker) {}
+
+  // the schema of the record batches from the iterator
+  std::shared_ptr<Schema> schema;
+
+  // maker of an iterator which acts as the data source
+  ItMaker it_maker;
+};
+
+using ExecBatchIteratorMaker = std::function<Iterator<std::shared_ptr<ExecBatch>>()>;
+using ExecBatchSourceNodeOptions = SchemaSourceNodeOptions<ExecBatchIteratorMaker>;
+
+using RecordBatchIteratorMaker = std::function<Iterator<std::shared_ptr<RecordBatch>>()>;
+using RecordBatchSourceNodeOptions = SchemaSourceNodeOptions<RecordBatchIteratorMaker>;
+
+using ArrayVectorIteratorMaker = std::function<Iterator<std::shared_ptr<ArrayVector>>()>;
+using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions<ArrayVectorIteratorMaker>;
+
 /// \brief Make a node which excludes some rows from batches passed through it
 ///
 /// filter_expression will be evaluated against each batch which is pushed to
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index e06c41c7489..99d4e675841 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -296,6 +296,148 @@ TEST(ExecPlanExecution, TableSourceSinkError) {
               Raises(StatusCode::Invalid, HasSubstr("batch_size > 0")));
 }
 
+TEST(ExecPlanExecution, ArrayVectorSourceSink) {
+  for (int num_threads : {1, 4}) {
+    ASSERT_OK_AND_ASSIGN(auto io_executor,
+                         arrow::internal::ThreadPool::Make(num_threads));
+    ExecContext exec_context(default_memory_pool(), io_executor.get());
+    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
+    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+    auto exp_batches = MakeBasicBatches();
+    ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches));
+    auto arrayvec_it_maker = [&arrayvecs]() {
+      return MakeVectorIterator<std::shared_ptr<ArrayVector>>(arrayvecs);
+    };
+
+    ASSERT_OK(Declaration::Sequence(
+                  {
+                      {"array_source", ArrayVectorSourceNodeOptions{exp_batches.schema,
+                                                                    arrayvec_it_maker}},
+                      {"sink", SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+    ASSERT_EQ(res, exp_batches.batches);
+  }
+}
+
+TEST(ExecPlanExecution, ArrayVectorSourceSinkError) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  std::shared_ptr<Schema> no_schema;
+
+  auto exp_batches = MakeBasicBatches();
+  ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches));
+  auto arrayvec_it_maker = [&arrayvecs]() {
+    return MakeVectorIterator<std::shared_ptr<ArrayVector>>(arrayvecs);
+  };
+
+  auto null_executor_options =
+      ArrayVectorSourceNodeOptions{exp_batches.schema, arrayvec_it_maker};
+  ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_executor_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+
+  auto null_schema_options = ArrayVectorSourceNodeOptions{no_schema, arrayvec_it_maker};
+  ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_schema_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+}
+
+TEST(ExecPlanExecution, ExecBatchSourceSink) {
+  for (int num_threads : {1, 4}) {
+    ASSERT_OK_AND_ASSIGN(auto io_executor,
+                         arrow::internal::ThreadPool::Make(num_threads));
+    ExecContext exec_context(default_memory_pool(), io_executor.get());
+    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
+    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+    auto exp_batches = MakeBasicBatches();
+    ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches));
+    auto exec_batch_it_maker = [&exec_batches]() {
+      return MakeVectorIterator<std::shared_ptr<ExecBatch>>(exec_batches);
+    };
+
+    ASSERT_OK(Declaration::Sequence(
+                  {
+                      {"exec_source", ExecBatchSourceNodeOptions{exp_batches.schema,
+                                                                 exec_batch_it_maker}},
+                      {"sink", SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+    ASSERT_EQ(res, exp_batches.batches);
+  }
+}
+
+TEST(ExecPlanExecution, ExecBatchSourceSinkError) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  std::shared_ptr<Schema> no_schema;
+
+  auto exp_batches = MakeBasicBatches();
+  ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches));
+  auto exec_batch_it_maker = [&exec_batches]() {
+    return MakeVectorIterator<std::shared_ptr<ExecBatch>>(exec_batches);
+  };
+
+  auto null_executor_options =
+      ExecBatchSourceNodeOptions{exp_batches.schema, exec_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_executor_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+
+  auto null_schema_options = ExecBatchSourceNodeOptions{no_schema, exec_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_schema_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+}
+
+TEST(ExecPlanExecution, RecordBatchSourceSink) {
+  for (int num_threads : {1, 4}) {
+    ASSERT_OK_AND_ASSIGN(auto io_executor,
+                         arrow::internal::ThreadPool::Make(num_threads));
+    ExecContext exec_context(default_memory_pool(), io_executor.get());
+    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
+    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+    auto exp_batches = MakeBasicBatches();
+    ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches));
+    auto record_batch_it_maker = [&record_batches]() {
+      return MakeVectorIterator<std::shared_ptr<RecordBatch>>(record_batches);
+    };
+
+    ASSERT_OK(Declaration::Sequence({
+                                        {"record_source",
+                                         RecordBatchSourceNodeOptions{
+                                             exp_batches.schema, record_batch_it_maker}},
+                                        {"sink", SinkNodeOptions{&sink_gen}},
+                                    })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+    ASSERT_EQ(res, exp_batches.batches);
+  }
+}
+
+TEST(ExecPlanExecution, RecordBatchSourceSinkError) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  std::shared_ptr<Schema> no_schema;
+
+  auto exp_batches = MakeBasicBatches();
+  ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches));
+  auto record_batch_it_maker = [&record_batches]() {
+    return MakeVectorIterator<std::shared_ptr<RecordBatch>>(record_batches);
+  };
+
+  auto null_executor_options =
+      RecordBatchSourceNodeOptions{exp_batches.schema, record_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_executor_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+
+  auto null_schema_options =
+      RecordBatchSourceNodeOptions{no_schema, record_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_schema_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+}
+
 TEST(ExecPlanExecution, SinkNodeBackpressure) {
   util::optional<ExecBatch> batch =
       ExecBatchFromJSON({int32(), boolean()},
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index a640cf737ef..263334f9f1d 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -291,6 +291,144 @@ struct TableSourceNode : public SourceNode {
   }
 };
 
+template <typename This, typename Options>
+struct SchemaSourceNode : public SourceNode {
+  SchemaSourceNode(ExecPlan* plan, std::shared_ptr<Schema> schema,
+                   arrow::AsyncGenerator<util::optional<ExecBatch>> generator)
+      : SourceNode(plan, schema, generator) {}
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 0, This::kKindName));
+    const auto& cast_options = checked_cast<const Options&>(options);
+    auto& it_maker = cast_options.it_maker;
+    auto& schema = cast_options.schema;
+
+    auto io_executor = plan->exec_context()->executor();
+    auto it = it_maker();
+
+    RETURN_NOT_OK(ValidateSchemaSourceNodeInput(io_executor, schema, This::kKindName));
+    ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema));
+    return plan->EmplaceNode<This>(plan, schema, MakeOrderedGenerator(generator));
+  }
+
+  static arrow::Status ValidateSchemaSourceNodeInput(
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema,
+      const char* kKindName) {
+    if (schema == NULLPTR) {
+      return Status::Invalid(kKindName, " requires schema which is not null");
+    }
+    if (io_executor == NULLPTR) {
+      return Status::Invalid(kKindName, " requires IO-executor which is not null");
+    }
+
+    return Status::OK();
+  }
+
+  template <typename Item>
+  static arrow::AsyncGenerator<util::optional<Item>> MakeOrderedGenerator(
+      arrow::AsyncGenerator<util::optional<Item>>& unordered_gen) {
+    using Enum = Enumerated<util::optional<Item>>;
+    auto enum_gen = MakeEnumeratedGenerator(unordered_gen);
+    auto seq_gen = MakeSequencingGenerator(
+        enum_gen,
+        /*compare=*/[](const Enum& a, const Enum& b) { return a.index > b.index; },
+        /*is_next=*/[](const Enum& a, const Enum& b) { return a.index + 1 == b.index; },
+        /*initial_value=*/Enum{{}, 0, false});
+    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
+  }
+};
+
+struct RecordBatchSourceNode
+    : public SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions> {
+  using RecordBatchSchemaSourceNode =
+      SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions>;
+
+  using RecordBatchSchemaSourceNode::Make;
+  using RecordBatchSchemaSourceNode::RecordBatchSchemaSourceNode;
+
+  const char* kind_name() const override { return kKindName; }
+
+  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+      Iterator<std::shared_ptr<RecordBatch>>& batch_it,
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
+    auto to_exec_batch =
+        [schema](const std::shared_ptr<RecordBatch>& batch) -> util::optional<ExecBatch> {
+      if (batch == NULLPTR || *batch->schema() != *schema) {
+        return util::nullopt;
+      }
+      return util::optional<ExecBatch>(ExecBatch(*batch));
+    };
+    auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+  }
+
+  static const char kKindName[];
+};
+
+const char RecordBatchSourceNode::kKindName[] = "RecordBatchSourceNode";
+
+struct ExecBatchSourceNode
+    : public SchemaSourceNode<ExecBatchSourceNode, ExecBatchSourceNodeOptions> {
+  using ExecBatchSchemaSourceNode =
+      SchemaSourceNode<ExecBatchSourceNode, ExecBatchSourceNodeOptions>;
+
+  using ExecBatchSchemaSourceNode::ExecBatchSchemaSourceNode;
+  using ExecBatchSchemaSourceNode::Make;
+
+  const char* kind_name() const override { return kKindName; }
+
+  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+      Iterator<std::shared_ptr<ExecBatch>>& batch_it,
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
+    auto to_exec_batch =
+        [&schema](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
+      return batch == NULLPTR ? util::nullopt : util::optional<ExecBatch>(*batch);
+    };
+    auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+  }
+
+  static const char kKindName[];
+};
+
+const char ExecBatchSourceNode::kKindName[] = "ExecBatchSourceNode";
+
+struct ArrayVectorSourceNode
+    : public SchemaSourceNode<ArrayVectorSourceNode, ArrayVectorSourceNodeOptions> {
+  using ArrayVectorSchemaSourceNode =
+      SchemaSourceNode<ArrayVectorSourceNode, ArrayVectorSourceNodeOptions>;
+
+  using ArrayVectorSchemaSourceNode::ArrayVectorSchemaSourceNode;
+  using ArrayVectorSchemaSourceNode::Make;
+
+  const char* kind_name() const override { return kKindName; }
+
+  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+      Iterator<std::shared_ptr<ArrayVector>>& arrayvec_it,
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
+    auto to_exec_batch =
+        [&schema](
+            const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
+      if (arrayvec == NULLPTR || arrayvec->size() == 0) {
+        return util::nullopt;
+      }
+      std::vector<Datum> datumvec;
+      for (const auto& array : *arrayvec) {
+        datumvec.push_back(Datum(array));
+      }
+      return util::optional<ExecBatch>(
+          ExecBatch(std::move(datumvec), (*arrayvec)[0]->length()));
+    };
+    auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+  }
+
+  static const char kKindName[];
+};
+
+const char ArrayVectorSourceNode::kKindName[] = "ArrayVectorSourceNode";
+
 }  // namespace
 
 namespace internal {
@@ -298,6 +436,9 @@ namespace internal {
 void RegisterSourceNode(ExecFactoryRegistry* registry) {
   DCHECK_OK(registry->AddFactory("source", SourceNode::Make));
   DCHECK_OK(registry->AddFactory("table_source", TableSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("record_source", RecordBatchSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("exec_source", ExecBatchSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("array_source", ArrayVectorSourceNode::Make));
 }
 
 }  // namespace internal
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index cc26143179a..709a5e8358e 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -259,6 +259,38 @@ BatchesWithSchema MakeBatchesFromString(
   return out_batches;
 }
 
+Result<std::vector<std::shared_ptr<ArrayVector>>> ToArrayVectors(
+    const BatchesWithSchema& batches_with_schema) {
+  std::vector<std::shared_ptr<ArrayVector>> arrayvecs;
+  for (auto batch : batches_with_schema.batches) {
+    ARROW_ASSIGN_OR_RAISE(auto record_batch,
+                          batch.ToRecordBatch(batches_with_schema.schema));
+    arrayvecs.push_back(std::make_shared<ArrayVector>(record_batch->columns()));
+  }
+  return arrayvecs;
+}
+
+Result<std::vector<std::shared_ptr<ExecBatch>>> ToExecBatches(
+    const BatchesWithSchema& batches_with_schema) {
+  std::vector<std::shared_ptr<ExecBatch>> exec_batches;
+  for (auto batch : batches_with_schema.batches) {
+    auto exec_batch = std::make_shared<ExecBatch>(batch);
+    exec_batches.push_back(exec_batch);
+  }
+  return exec_batches;
+}
+
+Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
+    const BatchesWithSchema& batches_with_schema) {
+  std::vector<std::shared_ptr<RecordBatch>> record_batches;
+  for (auto batch : batches_with_schema.batches) {
+    ARROW_ASSIGN_OR_RAISE(auto record_batch,
+                          batch.ToRecordBatch(batches_with_schema.schema));
+    record_batches.push_back(record_batch);
+  }
+  return record_batches;
+}
+
 Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab) {
   std::vector<SortKey> sort_keys;
   for (auto&& f : tab->schema()->fields()) {
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index ac9a4ae4ced..da9d423856c 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -113,6 +113,18 @@ BatchesWithSchema MakeBatchesFromString(
     const std::shared_ptr<Schema>& schema,
     const std::vector<util::string_view>& json_strings, int multiplicity = 1);
 
+ARROW_TESTING_EXPORT
+Result<std::vector<std::shared_ptr<ArrayVector>>> ToArrayVectors(
+    const BatchesWithSchema& batches_with_schema);
+
+ARROW_TESTING_EXPORT
+Result<std::vector<std::shared_ptr<ExecBatch>>> ToExecBatches(
+    const BatchesWithSchema& batches);
+
+ARROW_TESTING_EXPORT
+Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
+    const BatchesWithSchema& batches);
+
 ARROW_TESTING_EXPORT
 Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab);
 

From 59427a1f7e225c609d785e5def38d8d2a4e7ecca Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Mon, 5 Sep 2022 10:19:16 -0400
Subject: [PATCH 002/133] fix source ordering

---
 cpp/src/arrow/compute/exec/options.h      |  2 +-
 cpp/src/arrow/compute/exec/source_node.cc | 73 +++++++++++++++++------
 cpp/src/arrow/util/async_generator.h      |  2 +-
 3 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 6055a640349..d5a394fba47 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -84,7 +84,7 @@ template <typename ItMaker>
 class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
  public:
   SchemaSourceNodeOptions(std::shared_ptr<Schema> schema, ItMaker it_maker)
-      : schema(schema), it_maker(it_maker) {}
+      : schema(schema), it_maker(std::move(it_maker)) {}
 
   // the schema of the record batches from the iterator
   std::shared_ptr<Schema> schema;
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 263334f9f1d..48ff4b4f9ed 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -309,7 +309,7 @@ struct SchemaSourceNode : public SourceNode {
 
     RETURN_NOT_OK(ValidateSchemaSourceNodeInput(io_executor, schema, This::kKindName));
     ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema));
-    return plan->EmplaceNode<This>(plan, schema, MakeOrderedGenerator(generator));
+    return plan->EmplaceNode<This>(plan, schema, generator);
   }
 
   static arrow::Status ValidateSchemaSourceNodeInput(
@@ -326,16 +326,33 @@ struct SchemaSourceNode : public SourceNode {
   }
 
   template <typename Item>
-  static arrow::AsyncGenerator<util::optional<Item>> MakeOrderedGenerator(
-      arrow::AsyncGenerator<util::optional<Item>>& unordered_gen) {
-    using Enum = Enumerated<util::optional<Item>>;
-    auto enum_gen = MakeEnumeratedGenerator(unordered_gen);
-    auto seq_gen = MakeSequencingGenerator(
-        enum_gen,
+  static Iterator<Enumerated<Item>> MakeEnumeratedIterator(Iterator<Item> it) {
+    struct {
+      int64_t index = 0;
+      Enumerated<Item> operator()(const Item& item) {
+        return Enumerated<Item>{item, index++, false};
+      }
+    } enumerator;
+    return MakeMapIterator(std::move(enumerator), std::move(it));
+  }
+
+  template <typename Item>
+  static arrow::AsyncGenerator<Item> MakeUnenumeratedGenerator(
+      const arrow::AsyncGenerator<Enumerated<Item>>& enum_gen) {
+    using Enum = Enumerated<Item>;
+    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
+  }
+
+  template <typename Item>
+  static arrow::AsyncGenerator<Item> MakeOrderedGenerator(
+      const arrow::AsyncGenerator<Enumerated<Item>>& unordered_gen) {
+    using Enum = Enumerated<Item>;
+    auto enum_gen = MakeSequencingGenerator(
+        unordered_gen,
         /*compare=*/[](const Enum& a, const Enum& b) { return a.index > b.index; },
         /*is_next=*/[](const Enum& a, const Enum& b) { return a.index + 1 == b.index; },
-        /*initial_value=*/Enum{{}, 0, false});
-    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
+        /*initial_value=*/Enum{{}, 0});
+    return MakeUnenumeratedGenerator(enum_gen);
   }
 };
 
@@ -344,9 +361,13 @@ struct RecordBatchSourceNode
   using RecordBatchSchemaSourceNode =
       SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions>;
 
-  using RecordBatchSchemaSourceNode::Make;
   using RecordBatchSchemaSourceNode::RecordBatchSchemaSourceNode;
 
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    return RecordBatchSchemaSourceNode::Make(plan, inputs, options);
+  }
+
   const char* kind_name() const override { return kKindName; }
 
   static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
@@ -360,7 +381,10 @@ struct RecordBatchSourceNode
       return util::optional<ExecBatch>(ExecBatch(*batch));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
-    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
+    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
+                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
+    return MakeUnenumeratedGenerator(std::move(enum_gen));
   }
 
   static const char kKindName[];
@@ -374,7 +398,11 @@ struct ExecBatchSourceNode
       SchemaSourceNode<ExecBatchSourceNode, ExecBatchSourceNodeOptions>;
 
   using ExecBatchSchemaSourceNode::ExecBatchSchemaSourceNode;
-  using ExecBatchSchemaSourceNode::Make;
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    return ExecBatchSchemaSourceNode::Make(plan, inputs, options);
+  }
 
   const char* kind_name() const override { return kKindName; }
 
@@ -382,11 +410,14 @@ struct ExecBatchSourceNode
       Iterator<std::shared_ptr<ExecBatch>>& batch_it,
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
-        [&schema](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
+        [](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
       return batch == NULLPTR ? util::nullopt : util::optional<ExecBatch>(*batch);
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
-    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
+    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
+                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
+    return MakeUnenumeratedGenerator(std::move(enum_gen));
   }
 
   static const char kKindName[];
@@ -400,7 +431,11 @@ struct ArrayVectorSourceNode
       SchemaSourceNode<ArrayVectorSourceNode, ArrayVectorSourceNodeOptions>;
 
   using ArrayVectorSchemaSourceNode::ArrayVectorSchemaSourceNode;
-  using ArrayVectorSchemaSourceNode::Make;
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    return ArrayVectorSchemaSourceNode::Make(plan, inputs, options);
+  }
 
   const char* kind_name() const override { return kKindName; }
 
@@ -408,8 +443,7 @@ struct ArrayVectorSourceNode
       Iterator<std::shared_ptr<ArrayVector>>& arrayvec_it,
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
-        [&schema](
-            const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
+        [](const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
       if (arrayvec == NULLPTR || arrayvec->size() == 0) {
         return util::nullopt;
       }
@@ -421,7 +455,10 @@ struct ArrayVectorSourceNode
           ExecBatch(std::move(datumvec), (*arrayvec)[0]->length()));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it));
-    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
+    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
+                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
+    return MakeUnenumeratedGenerator(std::move(enum_gen));
   }
 
   static const char kKindName[];
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index 9819b5ce923..172228f7cfd 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -1501,7 +1501,7 @@ AsyncGenerator<T> MakeConcatenatedGenerator(AsyncGenerator<AsyncGenerator<T>> so
 template <typename T>
 struct Enumerated {
   T value;
-  int index;
+  int64_t index;
   bool last;
 };
 

From 9c9c204076e24960a35f82db51ffca649d4fc777 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Mon, 5 Sep 2022 15:47:09 -0400
Subject: [PATCH 003/133] add doc strings

---
 cpp/src/arrow/compute/exec/options.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index
d5a394fba47..1842fe7019a 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -94,11 +94,14 @@ class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions { }; using ExecBatchIteratorMaker = std::function>()>; +/// \brief An extended Source node which accepts a schema and exec-batches using ExecBatchSourceNodeOptions = SchemaSourceNodeOptions; using RecordBatchIteratorMaker = std::function>()>; +/// \brief An extended Source node which accepts a schema and record-batches using RecordBatchSourceNodeOptions = SchemaSourceNodeOptions; +/// \brief An extended Source node which accepts a schema and array-vectors using ArrayVectorIteratorMaker = std::function>()>; using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions; From 7a0ba80702ef63e63e346d2a9ca3137d8baca8bb Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Tue, 6 Sep 2022 14:52:52 +0900 Subject: [PATCH 004/133] ARROW-17081: [Java][Datasets] Move JNI build configuration from cpp/ to java/ (#13911) Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/docker/java-jni-manylinux-201x.dockerfile | 3 +- ci/scripts/java_jni_build.sh | 35 +++++++- ci/scripts/java_jni_macos_build.sh | 32 ++++--- ci/scripts/java_jni_manylinux_build.sh | 39 ++++----- ci/vcpkg/vcpkg.json | 1 + cpp/Brewfile | 2 + cpp/CMakeLists.txt | 19 ++++- cpp/cmake_modules/FindAWSSDKAlt.cmake | 50 +++++++++++ cpp/cmake_modules/FindProtobufAlt.cmake | 32 +++++++ cpp/cmake_modules/ThirdpartyToolchain.cmake | 52 +----------- cpp/src/arrow/ArrowConfig.cmake.in | 10 +++ cpp/src/arrow/dataset/api.h | 6 ++ cpp/src/arrow/filesystem/s3_internal.h | 4 +- cpp/src/arrow/filesystem/s3_test_util.cc | 2 + cpp/src/arrow/util/config.h.cmake | 2 + cpp/src/gandiva/jni/CMakeLists.txt | 8 +- dev/tasks/java-jars/github.yml | 24 +++--- docker-compose.yml | 2 +- docs/source/developers/java/building.rst | 88 +++++++++----------- java/CMakeLists.txt | 13 +++ java/c/CMakeLists.txt | 12 +-- java/dataset/CMakeLists.txt | 37 ++++---- java/dataset/src/main/cpp/CMakeLists.txt | 65 --------------- 23 files changed, 288 insertions(+), 250 deletions(-) create mode 100644 cpp/cmake_modules/FindAWSSDKAlt.cmake create mode 100644 cpp/cmake_modules/FindProtobufAlt.cmake delete mode 100644 java/dataset/src/main/cpp/CMakeLists.txt diff --git a/ci/docker/java-jni-manylinux-201x.dockerfile b/ci/docker/java-jni-manylinux-201x.dockerfile index de953fd5ae0..c77ec63df74 100644 --- a/ci/docker/java-jni-manylinux-201x.dockerfile +++ b/ci/docker/java-jni-manylinux-201x.dockerfile @@ -24,6 +24,7 @@ RUN vcpkg install \ --clean-after-build \ --x-install-root=${VCPKG_ROOT}/installed \ --x-manifest-root=/arrow/ci/vcpkg \ + --x-feature=dev \ --x-feature=flight \ --x-feature=gcs \ --x-feature=json \ @@ -36,7 +37,7 @@ ARG java=1.8.0 RUN yum install -y java-$java-openjdk-devel rh-maven35 && yum clean all ENV JAVA_HOME=/usr/lib/jvm/java-$java-openjdk/ -# For ci/scripts/java_*.sh +# For ci/scripts/{cpp,java}_*.sh ENV ARROW_GANDIVA_JAVA=ON \ ARROW_HOME=/tmp/local \ ARROW_JAVA_CDATA=ON \ diff --git a/ci/scripts/java_jni_build.sh b/ci/scripts/java_jni_build.sh index 0f19e614133..c68b52d77ef 100755 --- a/ci/scripts/java_jni_build.sh +++ b/ci/scripts/java_jni_build.sh @@ -20,9 +20,10 @@ set -ex arrow_dir=${1} -build_dir=${2}/java_jni +arrow_install_dir=${2} +build_dir=${3}/java_jni # The directory where the final binaries will be stored when scripts finish -dist_dir=${3} +dist_dir=${4} echo "=== Clear output directories and leftovers ===" # Clear output directories and leftovers @@ 
-32,11 +33,37 @@ echo "=== Building Arrow Java C Data Interface native library ===" mkdir -p "${build_dir}" pushd "${build_dir}" +case "$(uname)" in + Linux) + n_jobs=$(nproc) + ;; + Darwin) + n_jobs=$(sysctl -n hw.ncpu) + ;; + *) + n_jobs=${NPROC:-1} + ;; +esac + +: ${ARROW_JAVA_BUILD_TESTS:=${ARROW_BUILD_TESTS:-OFF}} +: ${CMAKE_BUILD_TYPE:=release} cmake \ - -DCMAKE_BUILD_TYPE=${ARROW_BUILD_TYPE:-release} \ + -DARROW_JAVA_JNI_ENABLE_DATASET=${ARROW_DATASET:-ON} \ + -DBUILD_TESTING=${ARROW_JAVA_BUILD_TESTS} \ + -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ + -DCMAKE_PREFIX_PATH=${arrow_install_dir} \ -DCMAKE_INSTALL_PREFIX=${dist_dir} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD:-OFF} \ + -GNinja \ ${JAVA_JNI_CMAKE_ARGS:-} \ ${arrow_dir}/java -cmake --build . --target install --config ${ARROW_BUILD_TYPE:-release} +export CMAKE_BUILD_PARALLEL_LEVEL=${n_jobs} +cmake --build . --config ${CMAKE_BUILD_TYPE} +if [ "${ARROW_JAVA_BUILD_TESTS}" = "ON" ]; then + ctest \ + --output-on-failure \ + --parallel ${n_jobs} \ + --timeout 300 +fi +cmake --build . --config ${CMAKE_BUILD_TYPE} --target install popd diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index 5418daaf011..342bc2d1188 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -30,7 +30,7 @@ rm -rf ${build_dir} echo "=== Building Arrow C++ libraries ===" install_dir=${build_dir}/cpp-install -: ${ARROW_BUILD_TESTS:=OFF} +: ${ARROW_BUILD_TESTS:=ON} : ${ARROW_DATASET:=ON} : ${ARROW_FILESYSTEM:=ON} : ${ARROW_GANDIVA_JAVA:=ON} @@ -39,7 +39,6 @@ install_dir=${build_dir}/cpp-install : ${ARROW_PARQUET:=ON} : ${ARROW_PLASMA_JAVA_CLIENT:=ON} : ${ARROW_PLASMA:=ON} -: ${ARROW_PYTHON:=OFF} : ${ARROW_S3:=ON} : ${ARROW_USE_CCACHE:=OFF} : ${CMAKE_BUILD_TYPE:=Release} @@ -58,33 +57,23 @@ mkdir -p "${build_dir}/cpp" pushd "${build_dir}/cpp" cmake \ - -DARROW_BOOST_USE_SHARED=OFF \ - -DARROW_BROTLI_USE_SHARED=OFF \ + -DARROW_BUILD_SHARED=OFF \ -DARROW_BUILD_TESTS=${ARROW_BUILD_TESTS} \ -DARROW_BUILD_UTILITIES=OFF \ - -DARROW_BZ2_USE_SHARED=OFF \ + -DARROW_CSV=${ARROW_DATASET} \ -DARROW_DATASET=${ARROW_DATASET} \ + -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=${ARROW_FILESYSTEM} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ -DARROW_GANDIVA_JAVA=${ARROW_GANDIVA_JAVA} \ -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \ - -DARROW_GFLAGS_USE_SHARED=OFF \ - -DARROW_GRPC_USE_SHARED=OFF \ -DARROW_JNI=ON \ - -DARROW_LZ4_USE_SHARED=OFF \ - -DARROW_OPENSSL_USE_SHARED=OFF \ -DARROW_ORC=${ARROW_ORC} \ -DARROW_PARQUET=${ARROW_PARQUET} \ -DARROW_PLASMA=${ARROW_PLASMA} \ -DARROW_PLASMA_JAVA_CLIENT=${ARROW_PLASMA_JAVA_CLIENT} \ - -DARROW_PROTOBUF_USE_SHARED=OFF \ - -DARROW_PYTHON=${ARROW_PYTHON} \ -DARROW_S3=${ARROW_S3} \ - -DARROW_SNAPPY_USE_SHARED=OFF \ - -DARROW_THRIFT_USE_SHARED=OFF \ -DARROW_USE_CCACHE=${ARROW_USE_CCACHE} \ - -DARROW_UTF8PROC_USE_SHARED=OFF \ - -DARROW_ZSTD_USE_SHARED=OFF \ -DAWSSDK_SOURCE=BUNDLED \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_INSTALL_LIBDIR=lib \ @@ -99,7 +88,16 @@ cmake \ cmake --build . 
--target install if [ "${ARROW_BUILD_TESTS}" == "ON" ]; then - ctest + # MinIO is required + exclude_tests="arrow-s3fs-test" + # unstable + exclude_tests="${exclude_tests}|arrow-compute-hash-join-node-test" + ctest \ + --exclude-regex "${exclude_tests}" \ + --label-regex unittest \ + --output-on-failure \ + --parallel $(sysctl -n hw.ncpu) \ + --timeout 300 fi popd @@ -107,6 +105,7 @@ popd ${arrow_dir}/ci/scripts/java_jni_build.sh \ ${arrow_dir} \ + ${install_dir} \ ${build_dir} \ ${dist_dir} @@ -117,7 +116,6 @@ fi echo "=== Copying libraries to the distribution folder ===" mkdir -p "${dist_dir}" -cp -L ${install_dir}/lib/libarrow_dataset_jni.dylib ${dist_dir} cp -L ${install_dir}/lib/libarrow_orc_jni.dylib ${dist_dir} cp -L ${install_dir}/lib/libgandiva_jni.dylib ${dist_dir} cp -L ${build_dir}/cpp/*/libplasma_java.dylib ${dist_dir} diff --git a/ci/scripts/java_jni_manylinux_build.sh b/ci/scripts/java_jni_manylinux_build.sh index 331d74b34a1..6669c4fdaa6 100755 --- a/ci/scripts/java_jni_manylinux_build.sh +++ b/ci/scripts/java_jni_manylinux_build.sh @@ -32,7 +32,7 @@ echo "=== Building Arrow C++ libraries ===" devtoolset_version=$(rpm -qa "devtoolset-*-gcc" --queryformat %{VERSION} | \ grep -o "^[0-9]*") devtoolset_include_cpp="/opt/rh/devtoolset-${devtoolset_version}/root/usr/include/c++/${devtoolset_version}" -: ${ARROW_BUILD_TESTS:=OFF} +: ${ARROW_BUILD_TESTS:=ON} : ${ARROW_DATASET:=ON} : ${ARROW_GANDIVA:=ON} : ${ARROW_GANDIVA_JAVA:=ON} @@ -43,10 +43,9 @@ devtoolset_include_cpp="/opt/rh/devtoolset-${devtoolset_version}/root/usr/includ : ${ARROW_PARQUET:=ON} : ${ARROW_PLASMA:=ON} : ${ARROW_PLASMA_JAVA_CLIENT:=ON} -: ${ARROW_PYTHON:=OFF} : ${ARROW_S3:=ON} : ${ARROW_USE_CCACHE:=OFF} -: ${CMAKE_BUILD_TYPE:=Release} +: ${CMAKE_BUILD_TYPE:=release} : ${CMAKE_UNITY_BUILD:=ON} : ${VCPKG_ROOT:=/opt/vcpkg} : ${VCPKG_FEATURE_FLAGS:=-manifests} @@ -66,36 +65,26 @@ mkdir -p "${build_dir}/cpp" pushd "${build_dir}/cpp" cmake \ - -DARROW_BOOST_USE_SHARED=OFF \ - -DARROW_BROTLI_USE_SHARED=OFF \ - -DARROW_BUILD_SHARED=ON \ - -DARROW_BUILD_TESTS=${ARROW_BUILD_TESTS} \ + -DARROW_BUILD_SHARED=OFF \ + -DARROW_BUILD_TESTS=ON \ -DARROW_BUILD_UTILITIES=OFF \ - -DARROW_BZ2_USE_SHARED=OFF \ + -DARROW_CSV=${ARROW_DATASET} \ -DARROW_DATASET=${ARROW_DATASET} \ -DARROW_DEPENDENCY_SOURCE="VCPKG" \ + -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=${ARROW_FILESYSTEM} \ -DARROW_GANDIVA_JAVA=${ARROW_GANDIVA_JAVA} \ -DARROW_GANDIVA_PC_CXX_FLAGS=${GANDIVA_CXX_FLAGS} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ - -DARROW_GRPC_USE_SHARED=OFF \ -DARROW_JEMALLOC=${ARROW_JEMALLOC} \ -DARROW_JNI=ON \ - -DARROW_LZ4_USE_SHARED=OFF \ - -DARROW_OPENSSL_USE_SHARED=OFF \ -DARROW_ORC=${ARROW_ORC} \ -DARROW_PARQUET=${ARROW_PARQUET} \ -DARROW_PLASMA_JAVA_CLIENT=${ARROW_PLASMA_JAVA_CLIENT} \ -DARROW_PLASMA=${ARROW_PLASMA} \ - -DARROW_PROTOBUF_USE_SHARED=OFF \ - -DARROW_PYTHON=${ARROW_PYTHON} \ -DARROW_RPATH_ORIGIN=${ARROW_RPATH_ORIGIN} \ -DARROW_S3=${ARROW_S3} \ - -DARROW_SNAPPY_USE_SHARED=OFF \ - -DARROW_THRIFT_USE_SHARED=OFF \ -DARROW_USE_CCACHE=${ARROW_USE_CCACHE} \ - -DARROW_UTF8PROC_USE_SHARED=OFF \ - -DARROW_ZSTD_USE_SHARED=OFF \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${ARROW_HOME} \ @@ -105,16 +94,22 @@ cmake \ -DPARQUET_BUILD_EXAMPLES=OFF \ -DPARQUET_BUILD_EXECUTABLES=OFF \ -DPARQUET_REQUIRE_ENCRYPTION=OFF \ - -DPythonInterp_FIND_VERSION_MAJOR=3 \ - -DPythonInterp_FIND_VERSION=ON \ -DVCPKG_MANIFEST_MODE=OFF \ -DVCPKG_TARGET_TRIPLET=${VCPKG_TARGET_TRIPLET} \ -GNinja \ 
${arrow_dir}/cpp ninja install -if [ $ARROW_BUILD_TESTS = "ON" ]; then +if [ "${ARROW_BUILD_TESTS}" = "ON" ]; then + # MinIO is required + exclude_tests="arrow-s3fs-test" + # unstable + exclude_tests="${exclude_tests}|arrow-compute-hash-join-node-test" + exclude_tests="${exclude_tests}|arrow-dataset-scanner-test" + # strptime + exclude_tests="${exclude_tests}|arrow-utility-test" ctest \ + --exclude-regex "${exclude_tests}" \ --label-regex unittest \ --output-on-failure \ --parallel $(nproc) \ @@ -125,11 +120,12 @@ popd JAVA_JNI_CMAKE_ARGS="" -JAVA_JNI_CMAKE_ARGS="${JAVA_JNI_CMAKE_ARGS} -DVCPKG_MANIFEST_MODE=OFF" +JAVA_JNI_CMAKE_ARGS="${JAVA_JNI_CMAKE_ARGS} -DCMAKE_TOOLCHAIN_FILE=${VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake" JAVA_JNI_CMAKE_ARGS="${JAVA_JNI_CMAKE_ARGS} -DVCPKG_TARGET_TRIPLET=${VCPKG_TARGET_TRIPLET}" export JAVA_JNI_CMAKE_ARGS ${arrow_dir}/ci/scripts/java_jni_build.sh \ ${arrow_dir} \ + ${ARROW_HOME} \ ${build_dir} \ ${dist_dir} @@ -140,7 +136,6 @@ fi echo "=== Copying libraries to the distribution folder ===" -cp -L ${ARROW_HOME}/lib/libarrow_dataset_jni.so ${dist_dir} cp -L ${ARROW_HOME}/lib/libarrow_orc_jni.so ${dist_dir} cp -L ${ARROW_HOME}/lib/libgandiva_jni.so ${dist_dir} cp -L ${build_dir}/cpp/*/libplasma_java.so ${dist_dir} diff --git a/ci/vcpkg/vcpkg.json b/ci/vcpkg/vcpkg.json index d9d074e99b0..71c23165e61 100644 --- a/ci/vcpkg/vcpkg.json +++ b/ci/vcpkg/vcpkg.json @@ -43,6 +43,7 @@ "description": "Development dependencies", "dependencies": [ "benchmark", + "boost-process", "gtest" ] }, diff --git a/cpp/Brewfile b/cpp/Brewfile index 9cffd8e3a81..61fb619dc66 100644 --- a/cpp/Brewfile +++ b/cpp/Brewfile @@ -26,6 +26,7 @@ brew "cmake" brew "flatbuffers" brew "git" brew "glog" +brew "googletest" brew "grpc" brew "llvm" brew "llvm@12" @@ -39,4 +40,5 @@ brew "rapidjson" brew "snappy" brew "thrift" brew "wget" +brew "xsimd" brew "zstd" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4c0d8f1e91b..6a01f18e6bb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -786,6 +786,19 @@ endif() if(ARROW_S3) list(APPEND ARROW_SHARED_LINK_LIBS ${AWSSDK_LINK_LIBRARIES}) list(APPEND ARROW_STATIC_LINK_LIBS ${AWSSDK_LINK_LIBRARIES}) + if(AWSSDK_SOURCE STREQUAL "SYSTEM") + list(APPEND + ARROW_STATIC_INSTALL_INTERFACE_LIBS + aws-cpp-sdk-identity-management + aws-cpp-sdk-sts + aws-cpp-sdk-cognito-identity + aws-cpp-sdk-s3 + aws-cpp-sdk-core) + elseif(AWSSDK_SOURCE STREQUAL "BUNDLED") + if(UNIX) + list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS CURL::libcurl) + endif() + endif() endif() if(ARROW_WITH_OPENTELEMETRY) @@ -851,6 +864,9 @@ add_dependencies(arrow_test_dependencies toolchain-tests) if(ARROW_STATIC_LINK_LIBS) add_dependencies(arrow_dependencies ${ARROW_STATIC_LINK_LIBS}) if(ARROW_HDFS OR ARROW_ORC) + if(Protobuf_SOURCE STREQUAL "SYSTEM") + list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${ARROW_PROTOBUF_LIBPROTOBUF}) + endif() if(NOT MSVC_TOOLCHAIN) list(APPEND ARROW_STATIC_LINK_LIBS ${CMAKE_DL_LIBS}) list(APPEND ARROW_STATIC_INSTALL_INTERFACE_LIBS ${CMAKE_DL_LIBS}) @@ -977,9 +993,6 @@ if(ARROW_JNI) if(ARROW_ORC) add_subdirectory(../java/adapter/orc/src/main/cpp ./java/orc/jni) endif() - if(ARROW_DATASET) - add_subdirectory(../java/dataset/src/main/cpp ./java/dataset/jni) - endif() endif() if(ARROW_GANDIVA) diff --git a/cpp/cmake_modules/FindAWSSDKAlt.cmake b/cpp/cmake_modules/FindAWSSDKAlt.cmake new file mode 100644 index 00000000000..611184aa1d1 --- /dev/null +++ b/cpp/cmake_modules/FindAWSSDKAlt.cmake @@ -0,0 +1,50 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set(find_package_args) +if(AWSSDKAlt_FIND_VERSION) + list(APPEND find_package_args ${AWSSDKAlt_FIND_VERSION}) +endif() +if(AWSSDKAlt_FIND_QUIETLY) + list(APPEND find_package_args QUIET) +endif() +# See https://aws.amazon.com/blogs/developer/developer-experience-of-the-aws-sdk-for-c-now-simplified-by-cmake/ +# Workaround to force AWS CMake configuration to look for shared libraries +if(DEFINED ENV{CONDA_PREFIX}) + if(DEFINED BUILD_SHARED_LIBS) + set(BUILD_SHARED_LIBS_WAS_SET TRUE) + set(BUILD_SHARED_LIBS_KEEP ${BUILD_SHARED_LIBS}) + else() + set(BUILD_SHARED_LIBS_WAS_SET FALSE) + endif() + set(BUILD_SHARED_LIBS ON) +endif() +find_package(AWSSDK ${find_package_args} + COMPONENTS config + s3 + transfer + identity-management + sts) +# Restore previous value of BUILD_SHARED_LIBS +if(DEFINED ENV{CONDA_PREFIX}) + if(BUILD_SHARED_LIBS_WAS_SET) + set(BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS_KEEP}) + else() + unset(BUILD_SHARED_LIBS) + endif() +endif() +set(AWSSDKAlt_FOUND ${AWSSDK_FOUND}) diff --git a/cpp/cmake_modules/FindProtobufAlt.cmake b/cpp/cmake_modules/FindProtobufAlt.cmake new file mode 100644 index 00000000000..d29f757aeb6 --- /dev/null +++ b/cpp/cmake_modules/FindProtobufAlt.cmake @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +if(ARROW_PROTOBUF_USE_SHARED) + set(Protobuf_USE_STATIC_LIBS OFF) +else() + set(Protobuf_USE_STATIC_LIBS ON) +endif() + +set(find_package_args) +if(ProtobufAlt_FIND_VERSION) + list(APPEND find_package_args ${ProtobufAlt_FIND_VERSION}) +endif() +if(ProtobufAlt_FIND_QUIETLY) + list(APPEND find_package_args QUIET) +endif() +find_package(Protobuf ${find_package_args}) +set(ProtobufAlt_FOUND ${Protobuf_FOUND}) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 7c3e3a53322..515cdfe8ef4 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -40,12 +40,6 @@ set(ARROW_RE2_LINKAGE "static" CACHE STRING "How to link the re2 library. 
static|shared (default static)") -if(ARROW_PROTOBUF_USE_SHARED) - set(Protobuf_USE_STATIC_LIBS OFF) -else() - set(Protobuf_USE_STATIC_LIBS ON) -endif() - # ---------------------------------------------------------------------- # Resolve the dependencies @@ -1640,6 +1634,8 @@ if(ARROW_WITH_PROTOBUF) set(ARROW_PROTOBUF_REQUIRED_VERSION "2.6.1") endif() resolve_dependency(Protobuf + HAVE_ALT + TRUE REQUIRED_VERSION ${ARROW_PROTOBUF_REQUIRED_VERSION} PC_PACKAGE_NAMES @@ -4746,49 +4742,7 @@ macro(build_awssdk) endmacro() if(ARROW_S3) - # See https://aws.amazon.com/blogs/developer/developer-experience-of-the-aws-sdk-for-c-now-simplified-by-cmake/ - - # Workaround to force AWS CMake configuration to look for shared libraries - if(DEFINED ENV{CONDA_PREFIX}) - if(DEFINED BUILD_SHARED_LIBS) - set(BUILD_SHARED_LIBS_WAS_SET TRUE) - set(BUILD_SHARED_LIBS_VALUE ${BUILD_SHARED_LIBS}) - else() - set(BUILD_SHARED_LIBS_WAS_SET FALSE) - endif() - set(BUILD_SHARED_LIBS "ON") - endif() - - # Need to customize the find_package() call, so cannot call resolve_dependency() - if(AWSSDK_SOURCE STREQUAL "AUTO") - find_package(AWSSDK - COMPONENTS config - s3 - transfer - identity-management - sts) - if(NOT AWSSDK_FOUND) - build_awssdk() - endif() - elseif(AWSSDK_SOURCE STREQUAL "BUNDLED") - build_awssdk() - elseif(AWSSDK_SOURCE STREQUAL "SYSTEM") - find_package(AWSSDK REQUIRED - COMPONENTS config - s3 - transfer - identity-management - sts) - endif() - - # Restore previous value of BUILD_SHARED_LIBS - if(DEFINED ENV{CONDA_PREFIX}) - if(BUILD_SHARED_LIBS_WAS_SET) - set(BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS_VALUE}) - else() - unset(BUILD_SHARED_LIBS) - endif() - endif() + resolve_dependency(AWSSDK HAVE_ALT TRUE) message(STATUS "Found AWS SDK headers: ${AWSSDK_INCLUDE_DIR}") message(STATUS "Found AWS SDK libraries: ${AWSSDK_LINK_LIBRARIES}") diff --git a/cpp/src/arrow/ArrowConfig.cmake.in b/cpp/src/arrow/ArrowConfig.cmake.in index 0aa298b6658..f0aa1bc959b 100644 --- a/cpp/src/arrow/ArrowConfig.cmake.in +++ b/cpp/src/arrow/ArrowConfig.cmake.in @@ -102,6 +102,16 @@ if(TARGET Arrow::arrow_static AND NOT TARGET Arrow::arrow_bundled_dependencies) PROPERTIES IMPORTED_LOCATION "${arrow_lib_dir}/${CMAKE_STATIC_LIBRARY_PREFIX}arrow_bundled_dependencies${CMAKE_STATIC_LIBRARY_SUFFIX}" ) + + # CMP0057: Support new if() IN_LIST operator. 
+  # https://cmake.org/cmake/help/latest/policy/CMP0057.html
+  cmake_policy(PUSH)
+  cmake_policy(SET CMP0057 NEW)
+  if(APPLE AND "AWS::aws-c-common" IN_LIST ARROW_BUNDLED_STATIC_LIBS)
+    find_library(CORE_FOUNDATION CoreFoundation)
+    target_link_libraries(Arrow::arrow_bundled_dependencies INTERFACE ${CORE_FOUNDATION})
+  endif()
+  cmake_policy(POP)
 endif()
 
 macro(arrow_keep_backward_compatibility namespace target_base_name)
diff --git a/cpp/src/arrow/dataset/api.h b/cpp/src/arrow/dataset/api.h
index 8b81f4c15d1..6e8aab5e9ea 100644
--- a/cpp/src/arrow/dataset/api.h
+++ b/cpp/src/arrow/dataset/api.h
@@ -23,8 +23,14 @@
 #include "arrow/dataset/dataset.h"
 #include "arrow/dataset/discovery.h"
 #include "arrow/dataset/file_base.h"
+#ifdef ARROW_CSV
 #include "arrow/dataset/file_csv.h"
+#endif
 #include "arrow/dataset/file_ipc.h"
+#ifdef ARROW_ORC
 #include "arrow/dataset/file_orc.h"
+#endif
+#ifdef ARROW_PARQUET
 #include "arrow/dataset/file_parquet.h"
+#endif
 #include "arrow/dataset/scanner.h"
diff --git a/cpp/src/arrow/filesystem/s3_internal.h b/cpp/src/arrow/filesystem/s3_internal.h
index 0943037aef0..093fdc7ca45 100644
--- a/cpp/src/arrow/filesystem/s3_internal.h
+++ b/cpp/src/arrow/filesystem/s3_internal.h
@@ -43,7 +43,7 @@ namespace internal {
 enum class S3Backend { Amazon, Minio, Other };
 
 // Detect the S3 backend type from the S3 server's response headers
-S3Backend DetectS3Backend(const Aws::Http::HeaderValueCollection& headers) {
+inline S3Backend DetectS3Backend(const Aws::Http::HeaderValueCollection& headers) {
   const auto it = headers.find("server");
   if (it != headers.end()) {
     const auto& value = util::string_view(it->second);
@@ -58,7 +58,7 @@ S3Backend DetectS3Backend(const Aws::Http::HeaderValueCollection& headers) {
 }
 
 template <typename ErrorType>
-S3Backend DetectS3Backend(const Aws::Client::AWSError<ErrorType>& error) {
+inline S3Backend DetectS3Backend(const Aws::Client::AWSError<ErrorType>& error) {
   return DetectS3Backend(error.GetResponseHeaders());
 }
 
diff --git a/cpp/src/arrow/filesystem/s3_test_util.cc b/cpp/src/arrow/filesystem/s3_test_util.cc
index 1aafb5ec66c..f5a054a8efa 100644
--- a/cpp/src/arrow/filesystem/s3_test_util.cc
+++ b/cpp/src/arrow/filesystem/s3_test_util.cc
@@ -31,7 +31,9 @@
 // includes windows.h. boost/process/args.hpp is included before
 // boost/process/async.h that includes
 // boost/asio/detail/socket_types.hpp implicitly is included.
+#ifdef __MINGW32__
 #include <windows.h>
+#endif
 // We need BOOST_USE_WINDOWS_H definition with MinGW when we use
 // boost/process.hpp. See BOOST_USE_WINDOWS_H=1 in
 // cpp/cmake_modules/ThirdpartyToolchain.cmake for details.
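For readers following along: the `DetectS3Backend` helpers above are now header-inline, so any translation unit that includes `arrow/filesystem/s3_internal.h` (a private, in-tree header) can call them directly. A minimal sketch of how they might be exercised, assuming the declarations shown in this hunk and the AWS SDK's `HeaderValueCollection` map type; the header values below are made up:

#include <iostream>

#include <aws/core/http/HttpTypes.h>  // Aws::Http::HeaderValueCollection

#include "arrow/filesystem/s3_internal.h"

int main() {
  namespace fsi = arrow::fs::internal;

  // Hypothetical response headers; MinIO typically identifies itself
  // via the "server" header, which DetectS3Backend inspects.
  Aws::Http::HeaderValueCollection headers;
  headers.emplace("server", "MinIO");

  switch (fsi::DetectS3Backend(headers)) {
    case fsi::S3Backend::Amazon:
      std::cout << "Amazon S3 backend\n";
      break;
    case fsi::S3Backend::Minio:
      std::cout << "MinIO backend\n";
      break;
    case fsi::S3Backend::Other:
      std::cout << "other/unknown backend\n";
      break;
  }
  return 0;
}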
diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake index c987a0cae36..9948c1e3587 100644 --- a/cpp/src/arrow/util/config.h.cmake +++ b/cpp/src/arrow/util/config.h.cmake @@ -46,6 +46,8 @@ #cmakedefine ARROW_JEMALLOC #cmakedefine ARROW_JEMALLOC_VENDORED #cmakedefine ARROW_JSON +#cmakedefine ARROW_ORC +#cmakedefine ARROW_PARQUET #cmakedefine ARROW_GCS #cmakedefine ARROW_S3 diff --git a/cpp/src/gandiva/jni/CMakeLists.txt b/cpp/src/gandiva/jni/CMakeLists.txt index 046934141f6..b89356121dc 100644 --- a/cpp/src/gandiva/jni/CMakeLists.txt +++ b/cpp/src/gandiva/jni/CMakeLists.txt @@ -76,9 +76,11 @@ add_arrow_lib(gandiva_jni ${GANDIVA_JNI_SOURCES} OUTPUTS GANDIVA_JNI_LIBRARIES - SHARED_PRIVATE_LINK_LIBS - ${GANDIVA_LINK_LIBS} - STATIC_LINK_LIBS + BUILD_SHARED + ON + BUILD_STATIC + OFF + SHARED_LINK_LIBS ${GANDIVA_LINK_LIBS} DEPENDENCIES ${GANDIVA_LINK_LIBS} diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index 23b97087c39..f94a43a8b44 100644 --- a/dev/tasks/java-jars/github.yml +++ b/dev/tasks/java-jars/github.yml @@ -22,12 +22,12 @@ jobs: build-cpp-ubuntu: - name: Build C++ Libs Ubuntu + name: Build C++ libraries Ubuntu runs-on: ubuntu-latest steps: {{ macros.github_checkout_arrow()|indent }} {{ macros.github_install_archery()|indent }} - - name: Build C++ Libs + - name: Build C++ libraries run: | archery docker run \ -e ARROW_JAVA_BUILD=OFF \ @@ -35,27 +35,27 @@ jobs: java-jni-manylinux-2014 - name: Compress into single artifact to keep directory structure run: tar -cvzf arrow-shared-libs-linux.tar.gz arrow/java-dist/ - - name: Upload Artifacts + - name: Upload artifacts uses: actions/upload-artifact@v2 with: name: ubuntu-shared-lib path: arrow-shared-libs-linux.tar.gz {% if arrow.branch == 'master' %} {{ macros.github_login_dockerhub()|indent }} - - name: Push Docker Image + - name: Push Docker image shell: bash run: archery docker push java-jni-manylinux-2014 {% endif %} build-cpp-macos: - name: Build C++ Libs MacOS + name: Build C++ libraries macOS runs-on: macos-latest env: MACOSX_DEPLOYMENT_TARGET: "10.13" steps: {{ macros.github_checkout_arrow()|indent }} {{ macros.github_install_archery()|indent }} - - name: Install Dependencies + - name: Install dependencies run: | brew install --overwrite git brew bundle --file=arrow/cpp/Brewfile @@ -68,7 +68,7 @@ jobs: - name: Setup ccache run: | arrow/ci/scripts/ccache_setup.sh - - name: Build C++ Libs + - name: Build C++ libraries run: | set -e arrow/ci/scripts/java_jni_macos_build.sh \ @@ -77,14 +77,14 @@ jobs: $GITHUB_WORKSPACE/arrow/java-dist - name: Compress into single artifact to keep directory structure run: tar -cvzf arrow-shared-libs-macos.tar.gz arrow/java-dist/ - - name: Upload Artifacts + - name: Upload artifacts uses: actions/upload-artifact@v2 with: name: macos-shared-lib path: arrow-shared-libs-macos.tar.gz package-jars: - name: Build Jar Files + name: Build jar files runs-on: macos-latest needs: [build-cpp-macos, build-cpp-ubuntu] steps: @@ -93,7 +93,7 @@ jobs: uses: actions/download-artifact@v2 with: name: ubuntu-shared-lib - - name: Download MacOS C++ Library + - name: Download macOS C++ libraries uses: actions/download-artifact@v2 with: name: macos-shared-lib @@ -101,7 +101,7 @@ jobs: run: | tar -xvzf arrow-shared-libs-linux.tar.gz tar -xvzf arrow-shared-libs-macos.tar.gz - - name: Test that Shared Libraries Exist + - name: Test that shared libraries exist run: | set -x test -f arrow/java-dist/libarrow_cdata_jni.dylib @@ -114,7 +114,7 @@ jobs: test -f 
arrow/java-dist/libarrow_orc_jni.so test -f arrow/java-dist/libgandiva_jni.so test -f arrow/java-dist/libplasma_java.so - - name: Build Bundled Jar + - name: Build bundled jar run: | set -e pushd arrow/java diff --git a/docker-compose.yml b/docker-compose.yml index 751a81fa554..67dfd87512e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1167,7 +1167,7 @@ services: command: [ "/arrow/ci/scripts/cpp_build.sh /arrow /build && /arrow/ci/scripts/python_build.sh /arrow /build && - /arrow/ci/scripts/java_jni_build.sh /arrow /build /tmp/dist/java && + /arrow/ci/scripts/java_jni_build.sh /arrow $${ARROW_HOME} /build /tmp/dist/java && /arrow/ci/scripts/java_build.sh /arrow /build /tmp/dist/java && /arrow/ci/scripts/java_cdata_integration.sh /arrow /tmp/dist/java" ] diff --git a/docs/source/developers/java/building.rst b/docs/source/developers/java/building.rst index add2b11b278..b45afa70a9d 100644 --- a/docs/source/developers/java/building.rst +++ b/docs/source/developers/java/building.rst @@ -75,78 +75,70 @@ We can build these manually or we can use `Archery`_ to build them using a Docke |__ libarrow_dataset_jni.so |__ libarrow_orc_jni.so |__ libgandiva_jni.so + |__ libplasma_java.so Building JNI Libraries on MacOS ------------------------------- Note: If you are building on Apple Silicon, be sure to use a JDK version that was compiled for that architecture. See, for example, the `Azul JDK `_. -To build only the C Data Interface library: +First, you need to build Apache Arrow C++: .. code-block:: $ cd arrow $ brew bundle --file=cpp/Brewfile Homebrew Bundle complete! 25 Brewfile dependencies now installed. + $ brew uninstall aws-sdk-cpp + (We can't use aws-sdk-cpp installed by Homebrew because it has + an issue: https://github.com/aws/aws-sdk-cpp/issues/1809 ) $ export JAVA_HOME= - $ mkdir -p java-dist java-native-c - $ cd java-native-c + $ mkdir -p java-dist cpp-jni $ cmake \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=../java-dist/lib \ - ../java - $ cmake --build . --target install - $ ls -latr ../java-dist/lib - |__ libarrow_cdata_jni.dylib - -To build other JNI libraries: - -.. code-block:: - - $ cd arrow - $ brew bundle --file=cpp/Brewfile - Homebrew Bundle complete! 25 Brewfile dependencies now installed. 
- $ export JAVA_HOME= - $ mkdir -p java-dist java-native-cpp - $ cd java-native-cpp - $ cmake \ - -DARROW_BOOST_USE_SHARED=OFF \ - -DARROW_BROTLI_USE_SHARED=OFF \ - -DARROW_BZ2_USE_SHARED=OFF \ - -DARROW_GFLAGS_USE_SHARED=OFF \ - -DARROW_GRPC_USE_SHARED=OFF \ - -DARROW_LZ4_USE_SHARED=OFF \ - -DARROW_OPENSSL_USE_SHARED=OFF \ - -DARROW_PROTOBUF_USE_SHARED=OFF \ - -DARROW_SNAPPY_USE_SHARED=OFF \ - -DARROW_THRIFT_USE_SHARED=OFF \ - -DARROW_UTF8PROC_USE_SHARED=OFF \ - -DARROW_ZSTD_USE_SHARED=OFF \ - -DARROW_JNI=ON \ - -DARROW_PARQUET=ON \ - -DARROW_FILESYSTEM=ON \ + -S cpp \ + -B cpp-jni \ + -DARROW_CSV=ON \ -DARROW_DATASET=ON \ + -DARROW_DEPENDENCY_USE_SHARED=OFF \ + -DARROW_FILESYSTEM=ON \ + -DARROW_GANDIVA=ON \ -DARROW_GANDIVA_JAVA=ON \ -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \ - -DARROW_GANDIVA=ON \ + -DARROW_JNI=ON \ -DARROW_ORC=ON \ - -DARROW_PLASMA_JAVA_CLIENT=ON \ + -DARROW_PARQUET=ON \ -DARROW_PLASMA=ON \ + -DARROW_PLASMA_JAVA_CLIENT=ON \ + -DARROW_S3=ON \ + -DAWSSDK_SOURCE=BUNDLED \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_LIBDIR=lib \ - -DCMAKE_INSTALL_PREFIX=../java-dist \ + -DCMAKE_INSTALL_PREFIX=java-dist \ -DCMAKE_UNITY_BUILD=ON \ - -Dre2_SOURCE=BUNDLED \ - -DBoost_SOURCE=BUNDLED \ - -Dutf8proc_SOURCE=BUNDLED \ - -DSnappy_SOURCE=BUNDLED \ - -DORC_SOURCE=BUNDLED \ - -DZLIB_SOURCE=BUNDLED \ - ../cpp - $ cmake --build . --target install + -Dre2_SOURCE=BUNDLED + $ cmake --build cpp-jni --target install $ ls -latr ../java-dist/lib - |__ libarrow_dataset_jni.dylib |__ libarrow_orc_jni.dylib |__ libgandiva_jni.dylib + |__ libplasma_java.dylib + +Then, you can build JNI libraries: + +.. code-block:: + + $ mkdir -p java-jni + $ cmake \ + -S java \ + -B java-jni \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_INSTALL_PREFIX=java-dist/lib \ + -DCMAKE_PREFIX_PATH=java-dist + $ cmake --build java-jni --target install + $ ls -latr ../java-dist/lib + |__ libarrow_cdata_jni.dylib + |__ libarrow_dataset_jni.dylib + +To build other JNI libraries: + Building Arrow JNI Modules -------------------------- diff --git a/java/CMakeLists.txt b/java/CMakeLists.txt index 43818e7a9f3..f187cd943d1 100644 --- a/java/CMakeLists.txt +++ b/java/CMakeLists.txt @@ -28,6 +28,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) # Components option(ARROW_JAVA_JNI_ENABLE_DEFAULT "Whether enable components by default or not" ON) option(ARROW_JAVA_JNI_ENABLE_C "Enable C data interface" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) +option(ARROW_JAVA_JNI_ENABLE_DATASET "Enable dataset" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) # ccache option(ARROW_JAVA_JNI_USE_CCACHE "Use ccache when compiling (if available)" ON) @@ -54,6 +55,18 @@ include(UseJava) add_library(jni INTERFACE IMPORTED) set_target_properties(jni PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${JNI_INCLUDE_DIRS}") +include(CTest) +if(BUILD_TESTING) + find_package(ArrowTesting REQUIRED) + find_package(GTest REQUIRED) + add_library(arrow_java_test INTERFACE IMPORTED) + target_link_libraries(arrow_java_test INTERFACE ArrowTesting::arrow_testing_static + GTest::gtest_main) +endif() + if(ARROW_JAVA_JNI_ENABLE_C) add_subdirectory(c) endif() +if(ARROW_JAVA_JNI_ENABLE_DATASET) + add_subdirectory(dataset) +endif() diff --git a/java/c/CMakeLists.txt b/java/c/CMakeLists.txt index f3b3117eacf..7510ab233fe 100644 --- a/java/c/CMakeLists.txt +++ b/java/c/CMakeLists.txt @@ -18,16 +18,16 @@ include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR} ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR}) -add_jar(arrow_cdata_jar +add_jar(arrow_java_jni_cdata_jar 
src/main/java/org/apache/arrow/c/jni/CDataJniException.java src/main/java/org/apache/arrow/c/jni/JniLoader.java src/main/java/org/apache/arrow/c/jni/JniWrapper.java src/main/java/org/apache/arrow/c/jni/PrivateData.java GENERATE_NATIVE_HEADERS - arrow_cdata_jni_headers) + arrow_java_jni_cdata_headers) -set(ARROW_CDATA_JNI_SOURCES src/main/cpp/jni_wrapper.cc) -add_library(arrow_cdata_jni SHARED ${ARROW_CDATA_JNI_SOURCES}) -target_link_libraries(arrow_cdata_jni arrow_cdata_jni_headers jni) +add_library(arrow_java_jni_cdata SHARED src/main/cpp/jni_wrapper.cc) +set_property(TARGET arrow_java_jni_cdata PROPERTY OUTPUT_NAME "arrow_cdata_jni") +target_link_libraries(arrow_java_jni_cdata arrow_java_jni_cdata_headers jni) -install(TARGETS arrow_cdata_jni DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS arrow_java_jni_cdata DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt index 5b6e4a9ce24..3b76b4e03bc 100644 --- a/java/dataset/CMakeLists.txt +++ b/java/dataset/CMakeLists.txt @@ -15,28 +15,31 @@ # specific language governing permissions and limitations # under the License. -# -# arrow_dataset_java -# - -# Headers: top level - -project(arrow_dataset_java) +find_package(ArrowDataset REQUIRED) -# Find java/jni -include(FindJava) -include(UseJava) -include(FindJNI) +include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR} + ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR}) -message("generating headers to ${JNI_HEADERS_DIR}") - -add_jar(arrow_dataset_java +add_jar(arrow_java_jni_dataset_jar src/main/java/org/apache/arrow/dataset/jni/JniLoader.java src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java src/main/java/org/apache/arrow/dataset/file/JniWrapper.java src/main/java/org/apache/arrow/dataset/jni/NativeMemoryPool.java src/main/java/org/apache/arrow/dataset/jni/ReservationListener.java GENERATE_NATIVE_HEADERS - arrow_dataset_java-native - DESTINATION - ${JNI_HEADERS_DIR}) + arrow_java_jni_dataset_headers) + +add_library(arrow_java_jni_dataset SHARED src/main/cpp/jni_wrapper.cc + src/main/cpp/jni_util.cc) +set_property(TARGET arrow_java_jni_dataset PROPERTY OUTPUT_NAME "arrow_dataset_jni") +target_link_libraries(arrow_java_jni_dataset arrow_java_jni_dataset_headers jni + ArrowDataset::arrow_dataset_static) + +if(BUILD_TESTING) + add_executable(arrow-java-jni-dataset-test src/main/cpp/jni_util_test.cc + src/main/cpp/jni_util.cc) + target_link_libraries(arrow-java-jni-dataset-test arrow_java_test) + add_test(NAME arrow-java-jni-dataset-test COMMAND arrow-java-jni-dataset-test) +endif() + +install(TARGETS arrow_java_jni_dataset DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/java/dataset/src/main/cpp/CMakeLists.txt b/java/dataset/src/main/cpp/CMakeLists.txt deleted file mode 100644 index 6a0be9b7f58..00000000000 --- a/java/dataset/src/main/cpp/CMakeLists.txt +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitationsn
-# under the License.
-
-#
-# arrow_dataset_jni
-#
-
-project(arrow_dataset_jni)
-
-cmake_minimum_required(VERSION 3.11)
-
-find_package(JNI REQUIRED)
-
-add_custom_target(arrow_dataset_jni)
-
-set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated")
-
-add_subdirectory(../../../../dataset ./java)
-
-set(ARROW_BUILD_STATIC OFF)
-
-set(ARROW_DATASET_JNI_LIBS arrow_dataset_static)
-
-set(ARROW_DATASET_JNI_SOURCES jni_wrapper.cc jni_util.cc)
-
-add_arrow_lib(arrow_dataset_jni
-              BUILD_SHARED
-              SOURCES
-              ${ARROW_DATASET_JNI_SOURCES}
-              OUTPUTS
-              ARROW_DATASET_JNI_LIBRARIES
-              SHARED_PRIVATE_LINK_LIBS
-              ${ARROW_DATASET_JNI_LIBS}
-              STATIC_LINK_LIBS
-              ${ARROW_DATASET_JNI_LIBS}
-              EXTRA_INCLUDES
-              ${JNI_HEADERS_DIR}
-              PRIVATE_INCLUDES
-              ${JNI_INCLUDE_DIRS}
-              DEPENDENCIES
-              arrow_static
-              arrow_dataset_java)
-
-add_dependencies(arrow_dataset_jni ${ARROW_DATASET_JNI_LIBRARIES})
-
-add_arrow_test(dataset_jni_test
-               SOURCES
-               jni_util_test.cc
-               jni_util.cc
-               EXTRA_INCLUDES
-               ${JNI_INCLUDE_DIRS})

From abd05fb88cf29ffbf0ac55d38c8a822deac37997 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Tue, 6 Sep 2022 03:29:40 -0400
Subject: [PATCH 005/133] fix test failure

---
 cpp/src/arrow/compute/exec/source_node.cc | 5 ++++-
 cpp/src/arrow/util/async_generator.h      | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 48ff4b4f9ed..57ff991b7b4 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -327,8 +327,11 @@ struct SchemaSourceNode : public SourceNode {
 
   template <typename Item>
   static Iterator<Enumerated<Item>> MakeEnumeratedIterator(Iterator<Item> it) {
+    // TODO: Should Enumerated<>.index be changed to int64_t? Currently, this change
+    // causes dataset unit-test failures
+    using index_t = decltype(Enumerated<Item>{}.index);
     struct {
-      int64_t index = 0;
+      index_t index = 0;
       Enumerated<Item> operator()(const Item& item) {
         return Enumerated<Item>{item, index++, false};
       }
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index 172228f7cfd..9819b5ce923 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -1501,7 +1501,7 @@ AsyncGenerator<T> MakeConcatenatedGenerator(AsyncGenerator<AsyncGenerator<T>> so
 template <typename T>
 struct Enumerated {
   T value;
-  int64_t index;
+  int index;
   bool last;
 };
 

From 23c5ceea785a96cc01c0fa976db1ffc7ff3e839f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Wilhelm=20=C3=85gren?= <36638274+willeagren@users.noreply.github.com>
Date: Tue, 6 Sep 2022 10:26:12 +0200
Subject: [PATCH 006/133] MINOR: [Ruby] Fix wrong English (#14046)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Change "There are the [...]" to "Here are the [...]".
Incorrect wording in previous README.
Authored-by: Wilhelm Ågren <36638274+willeagren@users.noreply.github.com> Signed-off-by: Sutou Kouhei --- ruby/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ruby/README.md b/ruby/README.md index f0e380cdc85..785c4687014 100644 --- a/ruby/README.md +++ b/ruby/README.md @@ -19,7 +19,7 @@ # Apache Arrow Ruby -There are the official Ruby bindings for Apache Arrow. +Here are the official Ruby bindings for Apache Arrow. [Red Arrow](https://github.com/apache/arrow/tree/master/ruby/red-arrow) is the base Apache Arrow bindings. From 3e40cd3648a5b4f6ee7203a6c408f08c0abe4696 Mon Sep 17 00:00:00 2001 From: Kevin Gurney <5904145+kevingurney@users.noreply.github.com> Date: Tue, 6 Sep 2022 07:15:01 -0400 Subject: [PATCH 007/133] ARROW-15693: [Dev] Update crossbow templates to use master or main (#13975) # Overview This pull request: 1. Removes hard-coded dependencies on "master" as the default branch name in the crossbow infrastructure and CI template files. # Implementation 1. Removed comment/text references to "master" branch, including URLs to other repositories. 2. Modified `core.py` to add a new `default_branch` property and a new method `is_default_branch`, for checking whether on the default branch, to the `Target` class. 3. Modified CI template files to use the new `is_default_branch` function to check whether on the default branch. # Testing 1. Using [lafiona/crossbow](https://github.com/lafiona/crossbow) as a queue repository for qualification. 2. Ran modified template jobs. All failures appeared to be unrelated to the changes. 3. The branch names for all relevant qualification jobs are prefixed with `build-34-*`. 4. Example of a passing job: [https://github.com/lafiona/crossbow/actions/runs/2920227769](https://github.com/lafiona/crossbow/actions/runs/2920227769) 5. Example of a failing job: [https://github.com/lafiona/crossbow/runs/7998190113](https://github.com/lafiona/crossbow/runs/7998190113) - in this example, the *"Push Docker Image"* workflow step is not included, since we are not on the default branch. The failure appears to be related to issues fetching R package resources and not related to the default branch checking logic. There were a variety of other kinds of failures, but none of them appear related to the default branch checking logic. # Future Directions 1. Remove "master" from `default_branch` name property of `Target` class. 2. Remove all remaining uses of "master" terminology in crossbow. 3. [ARROW-17512](https://issues.apache.org/jira/browse/ARROW-17512): Address minor issues with crossbow documentation. # Notes 1. Thank you to @lafiona for her help with this pull request! 2. Due to unexpected technical issues, we opened this pull request as a follow up to https://github.com/apache/arrow/pull/13750. Please see https://github.com/apache/arrow/pull/13750 for more discussion regarding qualification efforts. 
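# Illustration

As a hedged sketch (not part of the diff below) of how the new helper is expected to behave; this assumes `Target` is importable from `archery.crossbow.core`, and all constructor arguments except `branch` are placeholder values:

```python
# Hypothetical illustration of Target.is_default_branch(); only `branch`
# matters for this check, the other arguments are made-up placeholders.
from archery.crossbow.core import Target

main = Target(head="0000000", branch="main",
              remote="https://github.com/apache/arrow", version="10.0.0.dev")
topic = Target(head="0000000", branch="my-feature",
               remote="https://github.com/apache/arrow", version="10.0.0.dev")

assert main.is_default_branch()       # "main" (or, for now, "master") matches
assert not topic.is_default_branch()  # any other branch does not
```

Templates can then write `{% if arrow.is_default_branch() %}` instead of hard-coding a branch name.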
Lead-authored-by: Kevin Gurney Co-authored-by: Fiona La Signed-off-by: Alessandro Molina --- dev/archery/archery/crossbow/core.py | 8 +++++++ dev/tasks/conda-recipes/README.md | 2 +- dev/tasks/conda-recipes/azure.clean.yml | 21 +++++++++++++++++-- dev/tasks/docker-tests/github.linux.yml | 2 +- .../homebrew-formulae/apache-arrow-glib.rb | 2 +- dev/tasks/homebrew-formulae/apache-arrow.rb | 2 +- .../autobrew/apache-arrow.rb | 3 +-- dev/tasks/java-jars/github.yml | 2 +- dev/tasks/macros.jinja | 10 ++++----- dev/tasks/nightlies.sample.yml | 2 +- dev/tasks/python-sdist/github.yml | 2 +- .../python-wheels/github.linux.amd64.yml | 2 +- dev/tasks/python-wheels/github.windows.yml | 2 +- .../python-wheels/travis.linux.arm64.yml | 2 +- dev/tasks/r/github.linux.rchk.yml | 2 +- dev/tasks/r/github.macos.brew.yml | 2 ++ .../verify-rc/github.linux.amd64.docker.yml | 2 +- 17 files changed, 47 insertions(+), 21 deletions(-) diff --git a/dev/archery/archery/crossbow/core.py b/dev/archery/archery/crossbow/core.py index 6aa53eded3a..5784b78a8a0 100644 --- a/dev/archery/archery/crossbow/core.py +++ b/dev/archery/archery/crossbow/core.py @@ -745,6 +745,9 @@ def __init__(self, head, branch, remote, version, email=None): self.github_repo = "/".join(_parse_github_user_repo(remote)) self.version = version self.no_rc_version = re.sub(r'-rc\d+\Z', '', version) + # TODO(ARROW-17552): Remove "master" from default_branch after + # migration to "main". + self.default_branch = ['main', 'master'] # Semantic Versioning 1.0.0: https://semver.org/spec/v1.0.0.html # # > A pre-release version number MAY be denoted by appending an @@ -782,6 +785,11 @@ def from_repo(cls, repo, head=None, branch=None, remote=None, version=None, return cls(head=head, email=email, branch=branch, remote=remote, version=version) + def is_default_branch(self): + # TODO(ARROW-17552): Switch the condition to "is" instead of "in" + # once "master" is removed from "default_branch". + return self.branch in self.default_branch + class Task(Serializable): """ diff --git a/dev/tasks/conda-recipes/README.md b/dev/tasks/conda-recipes/README.md index 39f82f1b01a..4cf0d3426d4 100644 --- a/dev/tasks/conda-recipes/README.md +++ b/dev/tasks/conda-recipes/README.md @@ -64,4 +64,4 @@ copied to the upstream feedstocks. [arrow-cpp-feedstock]: https://github.com/conda-forge/arrow-cpp-feedstock [parquet-cpp-feedstock]: https://github.com/conda-forge/parquet-cpp-feedstock -[matrix-definition]: https://github.com/conda-forge/arrow-cpp-feedstock/blob/master/.azure-pipelines/azure-pipelines-linux.yml#L12 +[matrix-definition]: https://github.com/conda-forge/arrow-cpp-feedstock/blob/main/.azure-pipelines/azure-pipelines-linux.yml#L12 diff --git a/dev/tasks/conda-recipes/azure.clean.yml b/dev/tasks/conda-recipes/azure.clean.yml index 84f167812b2..b68f3c93ef3 100644 --- a/dev/tasks/conda-recipes/azure.clean.yml +++ b/dev/tasks/conda-recipes/azure.clean.yml @@ -1,3 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + jobs: - job: linux pool: @@ -17,12 +34,12 @@ jobs: displayName: Install requirements - script: | - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} mkdir -p $HOME/.continuum/anaconda-client/tokens/ echo $(CROSSBOW_ANACONDA_TOKEN) > $HOME/.continuum/anaconda-client/tokens/https%3A%2F%2Fapi.anaconda.org.token {% endif %} eval "$(conda shell.bash hook)" conda activate base - python3 arrow/dev/tasks/conda-recipes/clean.py {% if arrow.branch == 'master' %}FORCE{% endif %} + python3 arrow/dev/tasks/conda-recipes/clean.py {% if arrow.is_default_branch() %}FORCE{% endif %} displayName: Delete outdated packages diff --git a/dev/tasks/docker-tests/github.linux.yml b/dev/tasks/docker-tests/github.linux.yml index f7fd6a0be6e..638d846e410 100644 --- a/dev/tasks/docker-tests/github.linux.yml +++ b/dev/tasks/docker-tests/github.linux.yml @@ -62,7 +62,7 @@ jobs: path: arrow/r/check/arrow.Rcheck/tests/testthat.Rout* if-no-files-found: ignore - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} {{ macros.github_login_dockerhub()|indent }} - name: Push Docker Image shell: bash diff --git a/dev/tasks/homebrew-formulae/apache-arrow-glib.rb b/dev/tasks/homebrew-formulae/apache-arrow-glib.rb index 520ff41aec4..2bbdf71ae74 100644 --- a/dev/tasks/homebrew-formulae/apache-arrow-glib.rb +++ b/dev/tasks/homebrew-formulae/apache-arrow-glib.rb @@ -24,7 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# https://github.com/Homebrew/homebrew-core/blob/master/Formula/apache-arrow-glib.rb +# https://github.com/Homebrew/homebrew-core/blob/-/Formula/apache-arrow-glib.rb class ApacheArrowGlib < Formula desc "GLib bindings for Apache Arrow" diff --git a/dev/tasks/homebrew-formulae/apache-arrow.rb b/dev/tasks/homebrew-formulae/apache-arrow.rb index cceaad05397..e68f939bcab 100644 --- a/dev/tasks/homebrew-formulae/apache-arrow.rb +++ b/dev/tasks/homebrew-formulae/apache-arrow.rb @@ -24,7 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# https://github.com/Homebrew/homebrew-core/blob/master/Formula/apache-arrow.rb +# https://github.com/Homebrew/homebrew-core/blob/-/Formula/apache-arrow.rb class ApacheArrow < Formula desc "Columnar in-memory analytics layer designed to accelerate big data" diff --git a/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb b/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb index de0c65dae40..c5e2ca557c9 100644 --- a/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb +++ b/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-# https://github.com/autobrew/homebrew-core/blob/master/Formula/apache-arrow.rb - +# https://github.com/autobrew/homebrew-core/blob/-/Formula/apache-arrow.rb class ApacheArrow < Formula desc "Columnar in-memory analytics layer designed to accelerate big data" homepage "https://arrow.apache.org/" diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index f94a43a8b44..abaddc028e9 100644 --- a/dev/tasks/java-jars/github.yml +++ b/dev/tasks/java-jars/github.yml @@ -40,7 +40,7 @@ jobs: with: name: ubuntu-shared-lib path: arrow-shared-libs-linux.tar.gz - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} {{ macros.github_login_dockerhub()|indent }} - name: Push Docker image shell: bash diff --git a/dev/tasks/macros.jinja b/dev/tasks/macros.jinja index 3e87d507e77..e7099a5163e 100644 --- a/dev/tasks/macros.jinja +++ b/dev/tasks/macros.jinja @@ -111,7 +111,7 @@ on: {% endmacro %} {%- macro github_upload_gemfury(pattern) -%} - {%- if arrow.branch == 'master' -%} + {%- if arrow.is_default_branch() -%} - name: Upload package to Gemfury shell: bash run: | @@ -157,7 +157,7 @@ on: {% endmacro %} {%- macro azure_upload_anaconda(pattern) -%} - {%- if arrow.branch == 'master' -%} + {%- if arrow.is_default_branch() -%} - task: CondaEnvironment@1 inputs: packageSpecs: 'anaconda-client' @@ -216,7 +216,7 @@ on: {% endmacro %} {%- macro travis_upload_gemfury(pattern) -%} - {%- if arrow.branch == 'master' -%} + {%- if arrow.is_default_branch() -%} - | WHEEL_PATH=$(echo arrow/python/repaired_wheels/*.whl) curl \ @@ -243,7 +243,7 @@ on: continue fi # Pin the current commit in the formula to test so that - # we're not always pulling from master + # we're not always pulling from the tip of the default branch sed -i '' -E \ -e 's@https://github.com/apache/arrow.git"$@{{ arrow.remote }}.git", revision: "{{ arrow.head }}"@' \ ${formula} @@ -284,7 +284,7 @@ on: {% endmacro %} {%- macro github_test_r_src_pkg() -%} - source("https://raw.githubusercontent.com/apache/arrow/master/ci/etc/rprofile") + source("https://raw.githubusercontent.com/apache/arrow/HEAD/ci/etc/rprofile") install.packages( "arrow", diff --git a/dev/tasks/nightlies.sample.yml b/dev/tasks/nightlies.sample.yml index 710f7c0ad37..a5e400abf1a 100644 --- a/dev/tasks/nightlies.sample.yml +++ b/dev/tasks/nightlies.sample.yml @@ -16,7 +16,7 @@ # under the License. # this travis configuration can be used to submit cron scheduled tasks -# 1. copy this file to one of crossbow's branch (master for example) with +# 1. copy this file to one of crossbow's branches with # filename .travis.yml # 2. 
setup daily cron jobs for that particular branch, see travis' # documentation https://docs.travis-ci.com/user/cron-jobs/ diff --git a/dev/tasks/python-sdist/github.yml b/dev/tasks/python-sdist/github.yml index 68371876ab8..ef36e358aa9 100644 --- a/dev/tasks/python-sdist/github.yml +++ b/dev/tasks/python-sdist/github.yml @@ -30,7 +30,7 @@ jobs: - name: Build sdist run: | archery docker run python-sdist - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} archery docker push python-sdist || : {% endif %} env: diff --git a/dev/tasks/python-wheels/github.linux.amd64.yml b/dev/tasks/python-wheels/github.linux.amd64.yml index dc2386482f1..3b307e0b561 100644 --- a/dev/tasks/python-wheels/github.linux.amd64.yml +++ b/dev/tasks/python-wheels/github.linux.amd64.yml @@ -47,7 +47,7 @@ jobs: {{ macros.github_upload_releases("arrow/python/repaired_wheels/*.whl")|indent }} {{ macros.github_upload_gemfury("arrow/python/repaired_wheels/*.whl")|indent }} - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} - name: Push Docker Image shell: bash run: | diff --git a/dev/tasks/python-wheels/github.windows.yml b/dev/tasks/python-wheels/github.windows.yml index 6694e9feca6..0db4047951d 100644 --- a/dev/tasks/python-wheels/github.windows.yml +++ b/dev/tasks/python-wheels/github.windows.yml @@ -67,7 +67,7 @@ jobs: {{ macros.github_upload_releases("arrow/python/dist/*.whl")|indent }} {{ macros.github_upload_gemfury("arrow/python/dist/*.whl")|indent }} - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} - name: Push Docker Image shell: cmd run: | diff --git a/dev/tasks/python-wheels/travis.linux.arm64.yml b/dev/tasks/python-wheels/travis.linux.arm64.yml index d32d89d8301..4557624856e 100644 --- a/dev/tasks/python-wheels/travis.linux.arm64.yml +++ b/dev/tasks/python-wheels/travis.linux.arm64.yml @@ -66,7 +66,7 @@ after_success: {{ macros.travis_upload_releases("arrow/python/repaired_wheels/*.whl") }} {{ macros.travis_upload_gemfury("arrow/python/repaired_wheels/*.whl") }} - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} # Push the docker image to dockerhub - archery docker push python-wheel-manylinux-{{ manylinux_version }} - archery docker push python-wheel-manylinux-test-unittests diff --git a/dev/tasks/r/github.linux.rchk.yml b/dev/tasks/r/github.linux.rchk.yml index 9854e885f7a..2e2c91061af 100644 --- a/dev/tasks/r/github.linux.rchk.yml +++ b/dev/tasks/r/github.linux.rchk.yml @@ -48,7 +48,7 @@ jobs: docker run -v `pwd`/packages:/rchk/packages kalibera/rchk:latest /rchk/packages/arrow_*.tar.gz |& tee rchk.out - name: Confirm that rchk has no errors # Suspicious call, [UP], and [PB] are all of the error types currently at - # https://github.com/kalibera/cran-checks/tree/master/rchk/results + # https://github.com/kalibera/cran-checks/tree/HEAD/rchk/results # though this might not be exhaustive, there does not appear to be a way to have rchk return an error code # CRAN also will remove some of the outputs (especially those related to Rcpp and strptime, e.g. # ERROR: too many states (abstraction error?)) diff --git a/dev/tasks/r/github.macos.brew.yml b/dev/tasks/r/github.macos.brew.yml index a403a655954..c2cbcc5e3b2 100644 --- a/dev/tasks/r/github.macos.brew.yml +++ b/dev/tasks/r/github.macos.brew.yml @@ -30,6 +30,8 @@ jobs: - name: Install apache-arrow run: | + # TODO: Update the TODO for ARROW-16907 below to refer to main instead of master + # after migrating the default branch to main. 
# TODO(ARROW-16907): apache/arrow@master seems to be installed already # so this does nothing on a branch/PR brew install -v --HEAD apache-arrow
diff --git a/dev/tasks/verify-rc/github.linux.amd64.docker.yml b/dev/tasks/verify-rc/github.linux.amd64.docker.yml index aa6b837e307..15e0e597a4d 100644
--- a/dev/tasks/verify-rc/github.linux.amd64.docker.yml +++ b/dev/tasks/verify-rc/github.linux.amd64.docker.yml
@@ -43,7 +43,7 @@ jobs: -e TEST_{{ target|upper }}=1 \ {{ distro }}-verify-rc - {% if arrow.branch == 'master' %} + {% if arrow.is_default_branch() %} {{ macros.github_login_dockerhub()|indent }} - name: Push Docker Image shell: bash
From a5ecb0ff0774805b0f912e231eaedf42e7194c36 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 6 Sep 2022 06:46:36 -0700 Subject: [PATCH 008/133] ARROW-17079: Show HTTP status code for unknown S3 errors (#14019)
This is the last change I propose to improve our S3 error message. For certain errors, the AWS SDK unfortunately does a poor job of propagating the error and just reports UNKNOWN (see https://github.com/aws/aws-sdk-cpp/blob/1614bce979a201ada1e3436358edb7bd1834b5d6/aws-cpp-sdk-core/source/client/AWSClient.cpp#L77); in these cases, the HTTP status code can be an important source for finding out what is going wrong (and is also reported by boto3). This has the downside of cluttering the error message a bit more, but in general this information will be very valuable for diagnosing the problem. Given that we now have both the API call and the HTTP status code, there is generally good documentation on the internet that helps diagnose the problem.
Before:
> When getting information for key 'test.csv' in bucket 'pcmoritz-test-bucket-arrow-errors': AWS Error UNKNOWN during HeadObject call: No response body.
After:
> When getting information for key 'test.csv' in bucket 'pcmoritz-test-bucket-arrow-errors': AWS Error UNKNOWN **(HTTP status 400)** during HeadObject call: No response body.
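As a hedged sketch of how the richer message can surface from Python (the region is arbitrary, the bucket is the made-up one from the example above, and the exact wording depends on what the SDK returns):

```python
# Hedged sketch: a request the SDK classifies as UNKNOWN (e.g. an HTTP 400)
# should now raise an OSError whose message includes the HTTP status code.
from pyarrow import fs

s3 = fs.S3FileSystem(region="us-east-1")  # assumed region
try:
    s3.get_file_info("pcmoritz-test-bucket-arrow-errors/test.csv")
except OSError as e:
    # e.g. "... AWS Error UNKNOWN (HTTP status 400) during HeadObject call ..."
    print(e)
```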
Lead-authored-by: Philipp Moritz Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou
--- cpp/src/arrow/filesystem/s3_internal.h | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-)
diff --git a/cpp/src/arrow/filesystem/s3_internal.h b/cpp/src/arrow/filesystem/s3_internal.h index 093fdc7ca45..c6e6349ba2c 100644
--- a/cpp/src/arrow/filesystem/s3_internal.h +++ b/cpp/src/arrow/filesystem/s3_internal.h
@@ -152,10 +152,14 @@ Status ErrorToStatus(const std::string& prefix, const std::string& operation, // XXX Handle fine-grained error types // See // https://sdk.amazonaws.com/cpp/api/LATEST/namespace_aws_1_1_s3.html#ae3f82f8132b619b6e91c88a9f1bde371 - return Status::IOError( - prefix, "AWS Error ", - S3ErrorToString(static_cast(error.GetErrorType())), " during ", - operation, " operation: ", error.GetMessage()); + auto error_type = static_cast(error.GetErrorType()); + std::stringstream ss; + ss << S3ErrorToString(error_type); + if (error_type == Aws::S3::S3Errors::UNKNOWN) { + ss << " (HTTP status " << static_cast(error.GetResponseCode()) << ")"; + } + return Status::IOError(prefix, "AWS Error ", ss.str(), " during ", operation, + " operation: ", error.GetMessage()); } template
From cbf0ec0d05fe6301988f3b8f02ea39fead788f6c Mon Sep 17 00:00:00 2001 From: Joost Hoozemans Date: Tue, 6 Sep 2022 22:31:59 +0200 Subject: [PATCH 009/133] ARROW-16000: [C++][Python] Dataset: Alternative implementation for adding transcoding function option to CSV scanner (#13820)
This is an alternative version of https://github.com/apache/arrow/pull/13709, to compare what the best approach is. Instead of extending the C++ ReadOptions struct with an `encoding` field, this implementation adds a Python version of the ReadOptions object to both `CsvFileFormat` and `CsvFragmentScanOptions`.
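To make the intended end state concrete, a hedged usage sketch (this mirrors the new Python tests in the diff below; it shows the behavior the patch aims for, not what current releases do):

```python
# Hedged sketch: an encoding set through ReadOptions should survive the
# round trip through CsvFileFormat's default_fragment_scan_options.
import pyarrow.csv as csv
import pyarrow.dataset as ds

ro = csv.ReadOptions(encoding="iso8859")
fmt = ds.CsvFileFormat(read_options=ro)
assert fmt.default_fragment_scan_options.read_options.encoding == "iso8859"
```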
The reason it is needed in both places is to prevent these kinds of inconsistencies:
```
>>> import pyarrow.dataset as ds
>>> import pyarrow.csv as csv
>>> ro = csv.ReadOptions(encoding='iso8859')
>>> fo = ds.CsvFileFormat(read_options=ro)
>>> fo.default_fragment_scan_options.read_options.encoding
'utf8'
```
Authored-by: Joost Hoozemans Signed-off-by: David Li
--- cpp/src/arrow/dataset/file_csv.cc | 17 ++++++- cpp/src/arrow/dataset/file_csv.h | 10 ++++ python/pyarrow/_dataset.pyx | 30 +++++++++++- python/pyarrow/includes/libarrow.pxd | 8 ++++ python/pyarrow/includes/libarrow_dataset.pxd | 1 + python/pyarrow/io.pxi | 29 +++++++++++- python/pyarrow/lib.pxd | 3 ++ python/pyarrow/src/io.cc | 10 ++++ python/pyarrow/src/io.h | 5 ++ python/pyarrow/tests/test_dataset.py | 49 ++++++++++++++++++++ 10 files changed, 157 insertions(+), 5 deletions(-)
diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc index d4e0af7808c..780f845429b 100644
--- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -183,9 +183,15 @@ static inline Future> OpenReaderAsync( auto tracer = arrow::internal::tracing::GetTracer(); auto span = tracer->StartSpan("arrow::dataset::CsvFileFormat::OpenReaderAsync"); #endif + ARROW_ASSIGN_OR_RAISE( + auto fragment_scan_options, + GetFragmentScanOptions( + kCsvTypeName, scan_options.get(), format.default_fragment_scan_options)); ARROW_ASSIGN_OR_RAISE(auto reader_options, GetReadOptions(format, scan_options)); - ARROW_ASSIGN_OR_RAISE(auto input, source.OpenCompressed()); + if (fragment_scan_options->stream_transform_func) { + ARROW_ASSIGN_OR_RAISE(input, fragment_scan_options->stream_transform_func(input)); + } const auto& path = source.path(); ARROW_ASSIGN_OR_RAISE( input, io::BufferedInputStream::Create(reader_options.block_size,
@@ -289,8 +295,15 @@ Future> CsvFileFormat::CountRows( return Future>::MakeFinished(util::nullopt); } auto self = checked_pointer_cast(shared_from_this()); - ARROW_ASSIGN_OR_RAISE(auto input, file->source().OpenCompressed()); + ARROW_ASSIGN_OR_RAISE( + auto fragment_scan_options, + GetFragmentScanOptions( + kCsvTypeName, options.get(), self->default_fragment_scan_options)); ARROW_ASSIGN_OR_RAISE(auto read_options, GetReadOptions(*self, options)); + ARROW_ASSIGN_OR_RAISE(auto input, file->source().OpenCompressed()); + if (fragment_scan_options->stream_transform_func) { + ARROW_ASSIGN_OR_RAISE(input, fragment_scan_options->stream_transform_func(input)); + } return csv::CountRowsAsync(options->io_context, std::move(input), ::arrow::internal::GetCpuThreadPool(), read_options, self->parse_options)
diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h index 83dbb88b85f..84bcf94abe3 100644
--- a/cpp/src/arrow/dataset/file_csv.h +++ b/cpp/src/arrow/dataset/file_csv.h
@@ -73,6 +73,9 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat { struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions { std::string type_name() const override { return kCsvTypeName; } + using StreamWrapFunc = std::function>( + std::shared_ptr)>; + /// CSV conversion options csv::ConvertOptions convert_options = csv::ConvertOptions::Defaults(); /// CSV reading options /// /// Note that use_threads is always ignored. csv::ReadOptions read_options = csv::ReadOptions::Defaults(); + + /// Optional stream wrapping function + /// + /// If defined, all open dataset file fragments will be passed + /// through this function.
One possible use case is to transparently + /// transcode all input files from a given character set to utf8. + StreamWrapFunc stream_transform_func{}; }; class ARROW_DS_EXPORT CsvFileWriteOptions : public FileWriteOptions { diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 2258e31beec..57029b8da5c 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -21,6 +21,7 @@ from cython.operator cimport dereference as deref +import codecs import collections import os import warnings @@ -831,8 +832,14 @@ cdef class FileFormat(_Weakrefable): @property def default_fragment_scan_options(self): - return FragmentScanOptions.wrap( + dfso = FragmentScanOptions.wrap( self.wrapped.get().default_fragment_scan_options) + # CsvFileFormat stores a Python-specific encoding field that needs + # to be restored because it does not exist in the C++ struct + if isinstance(self, CsvFileFormat): + if self._read_options_py is not None: + dfso.read_options = self._read_options_py + return dfso @default_fragment_scan_options.setter def default_fragment_scan_options(self, FragmentScanOptions options): @@ -1178,6 +1185,10 @@ cdef class CsvFileFormat(FileFormat): """ cdef: CCsvFileFormat* csv_format + # The encoding field in ReadOptions does not exist in the C++ struct. + # We need to store it here and override it when reading + # default_fragment_scan_options.read_options + public ReadOptions _read_options_py # Avoid mistakingly creating attributes __slots__ = () @@ -1205,6 +1216,8 @@ cdef class CsvFileFormat(FileFormat): raise TypeError('`default_fragment_scan_options` must be either ' 'a dictionary or an instance of ' 'CsvFragmentScanOptions') + if read_options is not None: + self._read_options_py = read_options cdef void init(self, const shared_ptr[CFileFormat]& sp): FileFormat.init(self, sp) @@ -1227,6 +1240,8 @@ cdef class CsvFileFormat(FileFormat): cdef _set_default_fragment_scan_options(self, FragmentScanOptions options): if options.type_name == 'csv': self.csv_format.default_fragment_scan_options = options.wrapped + self.default_fragment_scan_options.read_options = options.read_options + self._read_options_py = options.read_options else: super()._set_default_fragment_scan_options(options) @@ -1258,6 +1273,9 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions): cdef: CCsvFragmentScanOptions* csv_options + # The encoding field in ReadOptions does not exist in the C++ struct. 
+ # We need to store it here and override it when reading read_options + ReadOptions _read_options_py # Avoid mistakingly creating attributes __slots__ = () @@ -1270,6 +1288,7 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions): self.convert_options = convert_options if read_options is not None: self.read_options = read_options + self._read_options_py = read_options cdef void init(self, const shared_ptr[CFragmentScanOptions]& sp): FragmentScanOptions.init(self, sp) @@ -1285,11 +1304,18 @@ cdef class CsvFragmentScanOptions(FragmentScanOptions): @property def read_options(self): - return ReadOptions.wrap(self.csv_options.read_options) + read_options = ReadOptions.wrap(self.csv_options.read_options) + if self._read_options_py is not None: + read_options.encoding = self._read_options_py.encoding + return read_options @read_options.setter def read_options(self, ReadOptions read_options not None): self.csv_options.read_options = deref(read_options.options) + self._read_options_py = read_options + if codecs.lookup(read_options.encoding).name != 'utf-8': + self.csv_options.stream_transform_func = deref( + make_streamwrap_func(read_options.encoding, 'utf-8')) def equals(self, CsvFragmentScanOptions other): return ( diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 7911567d791..be273975f94 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -1229,6 +1229,9 @@ cdef extern from "arrow/builder.h" namespace "arrow" nogil: ctypedef void CallbackTransform(object, const shared_ptr[CBuffer]& src, shared_ptr[CBuffer]* dest) +ctypedef CResult[shared_ptr[CInputStream]] StreamWrapFunc( + shared_ptr[CInputStream]) + cdef extern from "arrow/util/cancel.h" namespace "arrow" nogil: cdef cppclass CStopToken "arrow::StopToken": @@ -1396,6 +1399,11 @@ cdef extern from "arrow/io/api.h" namespace "arrow::io" nogil: shared_ptr[CInputStream] wrapped, CTransformInputStreamVTable vtable, object method_arg) + shared_ptr[function[StreamWrapFunc]] MakeStreamTransformFunc \ + "arrow::py::MakeStreamTransformFunc"( + CTransformInputStreamVTable vtable, + object method_arg) + # ---------------------------------------------------------------------- # HDFS diff --git a/python/pyarrow/includes/libarrow_dataset.pxd b/python/pyarrow/includes/libarrow_dataset.pxd index d418830bc2f..2147483fcae 100644 --- a/python/pyarrow/includes/libarrow_dataset.pxd +++ b/python/pyarrow/includes/libarrow_dataset.pxd @@ -279,6 +279,7 @@ cdef extern from "arrow/dataset/api.h" namespace "arrow::dataset" nogil: "arrow::dataset::CsvFragmentScanOptions"(CFragmentScanOptions): CCSVConvertOptions convert_options CCSVReadOptions read_options + function[StreamWrapFunc] stream_transform_func cdef cppclass CPartitioning "arrow::dataset::Partitioning": c_string type_name() const diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi index 3dd60735c3c..bb1405be0a3 100644 --- a/python/pyarrow/io.pxi +++ b/python/pyarrow/io.pxi @@ -1607,6 +1607,33 @@ class Transcoder: return self._encoder.encode(self._decoder.decode(buf, final), final) +cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func( + src_encoding, dest_encoding) except *: + """ + Create a function that will add a transcoding transformation to a stream. + Data from that stream will be decoded according to ``src_encoding`` and + then re-encoded according to ``dest_encoding``. + The created function can be used to wrap streams. 
+ + Parameters + ---------- + src_encoding : str + The codec to use when reading data. + dest_encoding : str + The codec to use for emitted data. + """ + cdef: + shared_ptr[function[StreamWrapFunc]] empty_func + CTransformInputStreamVTable vtable + + vtable.transform = _cb_transform + src_codec = codecs.lookup(src_encoding) + dest_codec = codecs.lookup(dest_encoding) + return MakeStreamTransformFunc(move(vtable), + Transcoder(src_codec.incrementaldecoder(), + dest_codec.incrementalencoder())) + + def transcoding_input_stream(stream, src_encoding, dest_encoding): """ Add a transcoding transformation to the stream. @@ -1618,7 +1645,7 @@ def transcoding_input_stream(stream, src_encoding, dest_encoding): stream : NativeFile The stream to which the transformation should be applied. src_encoding : str - The codec to use when reading data data. + The codec to use when reading data. dest_encoding : str The codec to use for emitted data. """ diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 953b0e7b518..67db3d2ffb8 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -536,6 +536,9 @@ cdef shared_ptr[CInputStream] native_transcoding_input_stream( shared_ptr[CInputStream] stream, src_encoding, dest_encoding) except * +cdef shared_ptr[function[StreamWrapFunc]] make_streamwrap_func( + src_encoding, dest_encoding) except * + # Default is allow_none=False cpdef DataType ensure_type(object type, bint allow_none=*) diff --git a/python/pyarrow/src/io.cc b/python/pyarrow/src/io.cc index 173d84ff567..0aa2c85939f 100644 --- a/python/pyarrow/src/io.cc +++ b/python/pyarrow/src/io.cc @@ -370,5 +370,15 @@ std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream( return std::make_shared(std::move(wrapped), std::move(transform)); } +std::shared_ptr MakeStreamTransformFunc(TransformInputStreamVTable vtable, + PyObject* handler) { + TransformInputStream::TransformFunc transform( + TransformFunctionWrapper{std::move(vtable.transform), handler}); + StreamWrapFunc func = [transform](std::shared_ptr<::arrow::io::InputStream> wrapped) { + return std::make_shared(wrapped, transform); + }; + return std::make_shared(func); +} + } // namespace py } // namespace arrow diff --git a/python/pyarrow/src/io.h b/python/pyarrow/src/io.h index 53b15434ea6..9d79d566efe 100644 --- a/python/pyarrow/src/io.h +++ b/python/pyarrow/src/io.h @@ -112,5 +112,10 @@ std::shared_ptr<::arrow::io::InputStream> MakeTransformInputStream( std::shared_ptr<::arrow::io::InputStream> wrapped, TransformInputStreamVTable vtable, PyObject* arg); +using StreamWrapFunc = std::function>( + std::shared_ptr)>; +ARROW_PYTHON_EXPORT +std::shared_ptr MakeStreamTransformFunc(TransformInputStreamVTable vtable, + PyObject* handler); } // namespace py } // namespace arrow diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index c3c80a15bb1..ff1ac7e1065 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -3137,6 +3137,55 @@ def test_csv_fragment_options(tempdir, dataset_reader): pa.table({'col0': pa.array(['foo', 'spam', 'MYNULL'])})) +def test_encoding(tempdir, dataset_reader): + path = str(tempdir / 'test.csv') + + for encoding, input_rows in [ + ('latin-1', b"a,b\nun,\xe9l\xe9phant"), + ('utf16', b'\xff\xfea\x00,\x00b\x00\n\x00u\x00n\x00,' + b'\x00\xe9\x00l\x00\xe9\x00p\x00h\x00a\x00n\x00t\x00'), + ]: + + with open(path, 'wb') as sink: + sink.write(input_rows) + + # Interpret as utf8: + expected_schema = pa.schema([("a", pa.string()), ("b", 
pa.string())]) + expected_table = pa.table({'a': ["un"], + 'b': ["éléphant"]}, schema=expected_schema) + + read_options = pa.csv.ReadOptions(encoding=encoding) + file_format = ds.CsvFileFormat(read_options=read_options) + dataset_transcoded = ds.dataset(path, format=file_format) + assert dataset_transcoded.schema.equals(expected_schema) + assert dataset_transcoded.to_table().equals(expected_table) + + +# Test if a dataset with non-utf8 chars in the column names is properly handled +def test_column_names_encoding(tempdir, dataset_reader): + path = str(tempdir / 'test.csv') + + with open(path, 'wb') as sink: + sink.write(b"\xe9,b\nun,\xe9l\xe9phant") + + # Interpret as utf8: + expected_schema = pa.schema([("é", pa.string()), ("b", pa.string())]) + expected_table = pa.table({'é': ["un"], + 'b': ["éléphant"]}, schema=expected_schema) + + # Reading as string without specifying encoding should produce an error + dataset = ds.dataset(path, format='csv', schema=expected_schema) + with pytest.raises(pyarrow.lib.ArrowInvalid, match="invalid UTF8"): + dataset_reader.to_table(dataset) + + # Setting the encoding in the read_options should transcode the data + read_options = pa.csv.ReadOptions(encoding='latin-1') + file_format = ds.CsvFileFormat(read_options=read_options) + dataset_transcoded = ds.dataset(path, format=file_format) + assert dataset_transcoded.schema.equals(expected_schema) + assert dataset_transcoded.to_table().equals(expected_table) + + def test_feather_format(tempdir, dataset_reader): from pyarrow.feather import write_feather From ff3aa3b7bb31c679892d19ff74d67563b986828f Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 7 Sep 2022 11:18:18 -0400 Subject: [PATCH 010/133] ARROW-17638: [Go] Extend C Data API support for Union arrays and RecordReader interface (#14057) Lead-authored-by: Matt Topol Co-authored-by: Matthew Topol Signed-off-by: Matt Topol --- go/arrow/cdata/cdata.go | 175 +++++++++++++++++++++++++++++++++-- go/arrow/cdata/cdata_test.go | 1 - 2 files changed, 167 insertions(+), 9 deletions(-) diff --git a/go/arrow/cdata/cdata.go b/go/arrow/cdata/cdata.go index 9e1f0b2076d..a2b583f268e 100644 --- a/go/arrow/cdata/cdata.go +++ b/go/arrow/cdata/cdata.go @@ -243,6 +243,39 @@ func importSchema(schema *CArrowSchema) (ret arrow.Field, err error) { st := childFields[0].Type.(*arrow.StructType) dt = arrow.MapOf(st.Field(0).Type, st.Field(1).Type) dt.(*arrow.MapType).KeysSorted = (schema.flags & C.ARROW_FLAG_MAP_KEYS_SORTED) != 0 + case 'u': // union + var mode arrow.UnionMode + switch f[2] { + case 'd': + mode = arrow.DenseMode + case 's': + mode = arrow.SparseMode + default: + err = fmt.Errorf("%w: invalid union type", arrow.ErrInvalid) + return + } + + codes := strings.Split(strings.Split(f, ":")[1], ",") + typeCodes := make([]arrow.UnionTypeCode, 0, len(codes)) + for _, i := range codes { + v, e := strconv.ParseInt(i, 10, 8) + if e != nil { + err = fmt.Errorf("%w: invalid type code: %s", arrow.ErrInvalid, e) + return + } + if v < 0 { + err = fmt.Errorf("%w: negative type code in union: format string %s", arrow.ErrInvalid, f) + return + } + typeCodes = append(typeCodes, arrow.UnionTypeCode(v)) + } + + if len(childFields) != len(typeCodes) { + err = fmt.Errorf("%w: ArrowArray struct number of children incompatible with format string", arrow.ErrInvalid) + return + } + + dt = arrow.UnionOf(mode, childFields, typeCodes) } } @@ -311,6 +344,18 @@ func (imp *cimporter) doImportChildren() error { if err := imp.children[0].importChild(imp, children[0]); err != nil { return err } + case 
arrow.DENSE_UNION: + dt := imp.dt.(*arrow.DenseUnionType) + for i, c := range children { + imp.children[i].dt = dt.Fields()[i].Type + imp.children[i].importChild(imp, c) + } + case arrow.SPARSE_UNION: + dt := imp.dt.(*arrow.SparseUnionType) + for i, c := range children { + imp.children[i].dt = dt.Fields()[i].Type + imp.children[i].importChild(imp, c) + } } return nil @@ -407,6 +452,52 @@ func (imp *cimporter) doImport(src *CArrowArray) error { } imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nulls}, children, int(imp.arr.null_count), int(imp.arr.offset)) + case *arrow.DenseUnionType: + if err := imp.checkNoNulls(); err != nil { + return err + } + + bufs := []*memory.Buffer{nil, nil, nil} + if imp.arr.n_buffers == 3 { + // legacy format exported by older arrow c++ versions + bufs[1] = imp.importFixedSizeBuffer(1, 1) + bufs[2] = imp.importFixedSizeBuffer(2, int64(arrow.Int32SizeBytes)) + } else { + if err := imp.checkNumBuffers(2); err != nil { + return err + } + + bufs[1] = imp.importFixedSizeBuffer(0, 1) + bufs[2] = imp.importFixedSizeBuffer(1, int64(arrow.Int32SizeBytes)) + } + + children := make([]arrow.ArrayData, len(imp.children)) + for i := range imp.children { + children[i] = imp.children[i].data + } + imp.data = array.NewData(dt, int(imp.arr.length), bufs, children, 0, int(imp.arr.offset)) + case *arrow.SparseUnionType: + if err := imp.checkNoNulls(); err != nil { + return err + } + + var buf *memory.Buffer + if imp.arr.n_buffers == 2 { + // legacy format exported by older Arrow C++ versions + buf = imp.importFixedSizeBuffer(1, 1) + } else { + if err := imp.checkNumBuffers(1); err != nil { + return err + } + + buf = imp.importFixedSizeBuffer(0, 1) + } + + children := make([]arrow.ArrayData, len(imp.children)) + for i := range imp.children { + children[i] = imp.children[i].data + } + imp.data = array.NewData(dt, int(imp.arr.length), []*memory.Buffer{nil, buf}, children, 0, int(imp.arr.offset)) default: return fmt.Errorf("unimplemented type %s", dt) } @@ -494,6 +585,13 @@ func (imp *cimporter) importFixedSizePrimitive() error { func (imp *cimporter) checkNoChildren() error { return imp.checkNumChildren(0) } +func (imp *cimporter) checkNoNulls() error { + if imp.arr.null_count != 0 { + return fmt.Errorf("%w: unexpected non-zero null count for imported type %s", arrow.ErrInvalid, imp.dt) + } + return nil +} + func (imp *cimporter) checkNumChildren(n int64) error { if int64(imp.arr.n_children) != n { return fmt.Errorf("expected %d children, for imported type %s, ArrowArray has %d", n, imp.dt, imp.arr.n_children) @@ -558,6 +656,9 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { rdr.stream = C.get_stream() C.ArrowArrayStreamMove(stream, rdr.stream) runtime.SetFinalizer(rdr, func(r *nativeCRecordBatchReader) { + if r.cur != nil { + r.cur.Release() + } C.ArrowArrayStreamRelease(r.stream) C.free(unsafe.Pointer(r.stream)) }) @@ -567,40 +668,98 @@ func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { type nativeCRecordBatchReader struct { stream *CArrowArrayStream schema *arrow.Schema + + cur arrow.Record + err error } -func (n *nativeCRecordBatchReader) getError(errno int) error { - return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream))) +// No need to implement retain and release here as we used runtime.SetFinalizer when constructing +// the reader to free up the ArrowArrayStream memory when the garbage collector cleans it up. 
+func (n *nativeCRecordBatchReader) Retain() {} +func (n *nativeCRecordBatchReader) Release() {} + +func (n *nativeCRecordBatchReader) Record() arrow.Record { return n.cur } + +func (n *nativeCRecordBatchReader) Next() bool { + err := n.next() + switch { + case err == nil: + return true + case err == io.EOF: + return false + } + n.err = err + return false } -func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) { +func (n *nativeCRecordBatchReader) next() error { if n.schema == nil { var sc CArrowSchema errno := C.stream_get_schema(n.stream, &sc) if errno != 0 { - return nil, n.getError(int(errno)) + return n.getError(int(errno)) } defer C.ArrowSchemaRelease(&sc) s, err := ImportCArrowSchema((*CArrowSchema)(&sc)) if err != nil { - return nil, err + return err } n.schema = s } + if n.cur != nil { + n.cur.Release() + n.cur = nil + } + arr := C.get_arr() defer C.free(unsafe.Pointer(arr)) errno := C.stream_get_next(n.stream, arr) if errno != 0 { - return nil, n.getError(int(errno)) + return n.getError(int(errno)) } if C.ArrowArrayIsReleased(arr) == 1 { - return nil, io.EOF + return io.EOF + } + + rec, err := ImportCRecordBatchWithSchema(arr, n.schema) + if err != nil { + return err } - return ImportCRecordBatchWithSchema(arr, n.schema) + n.cur = rec + return nil +} + +func (n *nativeCRecordBatchReader) Schema() *arrow.Schema { + if n.schema == nil { + var sc CArrowSchema + errno := C.stream_get_schema(n.stream, &sc) + if errno != 0 { + panic(n.getError(int(errno))) + } + defer C.ArrowSchemaRelease(&sc) + s, err := ImportCArrowSchema((*CArrowSchema)(&sc)) + if err != nil { + panic(err) + } + + n.schema = s + } + return n.schema +} + +func (n *nativeCRecordBatchReader) getError(errno int) error { + return fmt.Errorf("%w: %s", syscall.Errno(errno), C.GoString(C.stream_get_last_error(n.stream))) +} + +func (n *nativeCRecordBatchReader) Read() (arrow.Record, error) { + if err := n.next(); err != nil { + return nil, err + } + return n.cur, nil } func releaseArr(arr *CArrowArray) { diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go index 03c01181c13..0b73a08d6b0 100644 --- a/go/arrow/cdata/cdata_test.go +++ b/go/arrow/cdata/cdata_test.go @@ -646,7 +646,6 @@ func TestRecordReaderStream(t *testing.T) { } assert.NoError(t, err) } - defer rec.Release() assert.EqualValues(t, 2, rec.NumCols()) assert.Equal(t, "a", rec.ColumnName(0)) From c586b9fe459ead3bf151de9a87e1ca51d49a5687 Mon Sep 17 00:00:00 2001 From: Dewey Dunnington Date: Wed, 7 Sep 2022 14:18:48 -0300 Subject: [PATCH 011/133] ARROW-17519: [R] RTools35 job is failing (#14035) This PR removes the RTools35 CI job as it's currently failing and we're about to drop support following the vote to move to C++17. 
Authored-by: Dewey Dunnington Signed-off-by: Dewey Dunnington --- .github/workflows/r.yml | 16 ---------- ci/scripts/PKGBUILD | 29 +++---------------- r/R/arrow-info.R | 6 ---- r/R/arrow-package.R | 8 ----- r/README.md | 8 ++--- r/src/safe-call-into-r-impl.cpp | 8 +---- r/src/safe-call-into-r.h | 6 ++-- r/tests/testthat/test-dataset-csv.R | 6 ++-- r/tests/testthat/test-dplyr-arrange.R | 2 -- r/tests/testthat/test-dplyr-collapse.R | 2 -- r/tests/testthat/test-dplyr-count.R | 2 -- r/tests/testthat/test-dplyr-distinct.R | 2 -- r/tests/testthat/test-dplyr-filter.R | 2 -- .../testthat/test-dplyr-funcs-conditional.R | 2 -- r/tests/testthat/test-dplyr-funcs-datetime.R | 1 - r/tests/testthat/test-dplyr-funcs-math.R | 2 -- r/tests/testthat/test-dplyr-funcs-string.R | 1 - r/tests/testthat/test-dplyr-funcs-type.R | 2 -- r/tests/testthat/test-dplyr-group-by.R | 2 -- r/tests/testthat/test-dplyr-join.R | 2 -- r/tests/testthat/test-dplyr-mutate.R | 2 -- r/tests/testthat/test-dplyr-query.R | 2 -- r/tests/testthat/test-dplyr-select.R | 2 -- r/tests/testthat/test-dplyr-summarize.R | 2 -- r/tests/testthat/test-dplyr-union.R | 3 -- r/vignettes/dataset.Rmd | 2 +- 26 files changed, 12 insertions(+), 110 deletions(-) diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index 4f706e3e5b1..f1bafbef3b7 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -170,8 +170,6 @@ jobs: fail-fast: false matrix: config: - - { rtools: 35, arch: 'mingw32' } - - { rtools: 35, arch: 'mingw64' } - { rtools: 40, arch: 'mingw32' } - { rtools: 40, arch: 'mingw64' } - { rtools: 40, arch: 'ucrt64' } @@ -199,19 +197,11 @@ jobs: restore-keys: | r-${{ matrix.config.rtools }}-ccache-mingw-${{ matrix.config.arch }}-${{ hashFiles('cpp/src/**/*.cc','cpp/src/**/*.h)') }}- r-${{ matrix.config.rtools }}-ccache-mingw-${{ matrix.config.arch }}- - # We use the makepkg-mingw setup that is included in rtools40 even when - # we use the rtools35 compilers, so we always install R 4.0/Rtools40 - uses: r-lib/actions/setup-r@v2 with: r-version: "4.1" rtools-version: 40 Ncpus: 2 - - uses: r-lib/actions/setup-r@v2 - if: ${{ matrix.config.rtools == 35 }} - with: - rtools-version: 35 - r-version: "3.6" - Ncpus: 2 - name: Build Arrow C++ shell: bash env: @@ -226,11 +216,6 @@ jobs: with: name: libarrow-rtools${{ matrix.config.rtools }}-${{ matrix.config.arch }}.zip path: libarrow-rtools${{ matrix.config.rtools }}-${{ matrix.config.arch }}.zip - # We can remove this when we drop support for Rtools 3.5. - - name: Ensure using system tar in actions/cache - run: | - Write-Output "${Env:windir}\System32" | ` - Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append windows-r: needs: [windows-cpp] @@ -242,7 +227,6 @@ jobs: fail-fast: false matrix: config: - - { rtools: 35, rversion: "3.6" } - { rtools: 40, rversion: "4.1" } - { rtools: 42, rversion: "4.2" } - { rtools: 42, rversion: "devel" } diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index f0a09bab7f0..81822cc4eb4 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -73,27 +73,6 @@ build() { # set the appropriate compiler definition. 
export CPPFLAGS="-DUTF8PROC_STATIC" - # This is the difference between rtools-packages and rtools-backports - # Remove this when submitting to rtools-packages - if [ "$RTOOLS_VERSION" = "35" ]; then - export CC="/C/Rtools${MINGW_PREFIX/mingw/mingw_}/bin/gcc" - export CXX="/C/Rtools${MINGW_PREFIX/mingw/mingw_}/bin/g++" - export PATH="/C/Rtools${MINGW_PREFIX/mingw/mingw_}/bin:$PATH" - export CPPFLAGS="${CPPFLAGS} -I${MINGW_PREFIX}/include" - export LIBS="-L${MINGW_PREFIX}/libs" - export ARROW_GCS=OFF - export ARROW_S3=OFF - export ARROW_WITH_RE2=OFF - # Without this, some dataset functionality segfaults - export CMAKE_UNITY_BUILD=ON - else - export ARROW_GCS=ON - export ARROW_S3=ON - export ARROW_WITH_RE2=ON - # Without this, some compute functionality segfaults in tests - export CMAKE_UNITY_BUILD=OFF - fi - MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ ${MINGW_PREFIX}/bin/cmake.exe \ ${ARROW_CPP_DIR} \ @@ -105,7 +84,7 @@ build() { -DARROW_CSV=ON \ -DARROW_DATASET=ON \ -DARROW_FILESYSTEM=ON \ - -DARROW_GCS="${ARROW_GCS}" \ + -DARROW_GCS=ON \ -DARROW_HDFS=OFF \ -DARROW_JEMALLOC=OFF \ -DARROW_JSON=ON \ @@ -113,13 +92,13 @@ build() { -DARROW_MIMALLOC=ON \ -DARROW_PACKAGE_PREFIX="${MINGW_PREFIX}" \ -DARROW_PARQUET=ON \ - -DARROW_S3="${ARROW_S3}" \ + -DARROW_S3=ON \ -DARROW_SNAPPY_USE_SHARED=OFF \ -DARROW_USE_GLOG=OFF \ -DARROW_UTF8PROC_USE_SHARED=OFF \ -DARROW_VERBOSE_THIRDPARTY_BUILD=ON \ -DARROW_WITH_LZ4=ON \ - -DARROW_WITH_RE2="${ARROW_WITH_RE2}" \ + -DARROW_WITH_RE2=ON \ -DARROW_WITH_SNAPPY=ON \ -DARROW_WITH_ZLIB=ON \ -DARROW_WITH_ZSTD=ON \ @@ -129,7 +108,7 @@ build() { -DARROW_CXXFLAGS="${CPPFLAGS}" \ -DCMAKE_BUILD_TYPE="release" \ -DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \ - -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ + -DCMAKE_UNITY_BUILD=OFF \ -DCMAKE_VERBOSE_MAKEFILE=ON make -j3 diff --git a/r/R/arrow-info.R b/r/R/arrow-info.R index 55d07b77cb4..52e0cf3009f 100644 --- a/r/R/arrow-info.R +++ b/r/R/arrow-info.R @@ -82,12 +82,6 @@ arrow_available <- function() { #' @rdname arrow_info #' @export arrow_with_dataset <- function() { - if (on_old_windows()) { - # 32-bit rtools 3.5 does not properly implement the std::thread expectations - # but we can't just disable ARROW_DATASET in that build, - # so report it as "off" here. - return(FALSE) - } tryCatch(.Call(`_dataset_available`), error = function(e) { return(FALSE) }) diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index e8aa93f9534..53fb0280a50 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -117,14 +117,6 @@ configure_tzdb <- function() { }) } -on_old_windows <- function() { - is_32bit <- .Machine$sizeof.pointer < 8 - is_old_r <- getRversion() < "4.0.0" - is_windows <- tolower(Sys.info()[["sysname"]]) == "windows" - - is_32bit && is_old_r && is_windows -} - # True when the OS is linux + and the R version is development # helpful for skipping on Valgrind, and the sanitizer checks (clang + gcc) on cran diff --git a/r/README.md b/r/README.md index 1509ae7793f..2a85a82aeb3 100644 --- a/r/README.md +++ b/r/README.md @@ -65,12 +65,8 @@ packages that contain the Arrow C++ library. On Linux, source package installation will also build necessary C++ dependencies. For a faster, more complete installation, set the environment variable `NOT_CRAN=true`. See `vignette("install", package = "arrow")` for -details. - -For Windows users of R 3.6 and earlier, note that support for AWS S3 is not -available, and the 32-bit version does not support Arrow Datasets. 
-These features are only supported by the `rtools40` toolchain on Windows -and thus are only available in R >= 4.0. +details. Note that version 9.0.0 was the last version to support +R 3.6 and lower on Windows. ### Installing a development version diff --git a/r/src/safe-call-into-r-impl.cpp b/r/src/safe-call-into-r-impl.cpp index 4eec3a85df8..6b3ebc9fccb 100644 --- a/r/src/safe-call-into-r-impl.cpp +++ b/r/src/safe-call-into-r-impl.cpp @@ -32,13 +32,7 @@ void InitializeMainRThread() { GetMainRThread().Initialize(); } // [[arrow::export]] bool CanRunWithCapturedR() { #if defined(HAS_UNWIND_PROTECT) - static int on_old_windows = -1; - if (on_old_windows == -1) { - cpp11::function on_old_windows_fun = cpp11::package("arrow")["on_old_windows"]; - on_old_windows = on_old_windows_fun(); - } - - return !on_old_windows && GetMainRThread().Executor() == nullptr; + return GetMainRThread().Executor() == nullptr; #else return false; #endif diff --git a/r/src/safe-call-into-r.h b/r/src/safe-call-into-r.h index 08e8a8c11b6..5e24a3892b1 100644 --- a/r/src/safe-call-into-r.h +++ b/r/src/safe-call-into-r.h @@ -28,8 +28,7 @@ #include // Unwind protection was added in R 3.5 and some calls here use it -// and crash R in older versions (ARROW-16201). Crashes also occur -// on 32-bit R builds on R 3.6 and lower. Implementation provided +// and crash R in older versions (ARROW-16201). Implementation provided // in safe-call-into-r-impl.cpp so that we can skip some tests // when this feature is not provided. This also checks that there // is not already an event loop registered (via MainRThread::Executor()), @@ -163,8 +162,7 @@ static inline arrow::Status SafeCallIntoRVoid(std::function fun, template arrow::Result RunWithCapturedR(std::function()> make_arrow_call) { if (!CanRunWithCapturedR()) { - return arrow::Status::NotImplemented( - "RunWithCapturedR() without UnwindProtect or on 32-bit Windows + R <= 3.6"); + return arrow::Status::NotImplemented("RunWithCapturedR() without UnwindProtect"); } if (GetMainRThread().Executor() != nullptr) { diff --git a/r/tests/testthat/test-dataset-csv.R b/r/tests/testthat/test-dataset-csv.R index b718bce2ffd..0718746624e 100644 --- a/r/tests/testthat/test-dataset-csv.R +++ b/r/tests/testthat/test-dataset-csv.R @@ -42,10 +42,8 @@ test_that("CSV dataset", { expect_r6_class(ds$format, "CsvFileFormat") expect_r6_class(ds$filesystem, "LocalFileSystem") expect_identical(names(ds), c(names(df1), "part")) - if (getRversion() >= "4.0.0") { - # CountRows segfaults on RTools35/R 3.6, so don't test it there - expect_identical(dim(ds), c(20L, 7L)) - } + expect_identical(dim(ds), c(20L, 7L)) + expect_equal( ds %>% select(string = chr, integer = int, part) %>% diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R index e6e361483a4..fee1475a44e 100644 --- a/r/tests/testthat/test-dplyr-arrange.R +++ b/r/tests/testthat/test-dplyr-arrange.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) # randomize order of rows in test data diff --git a/r/tests/testthat/test-dplyr-collapse.R b/r/tests/testthat/test-dplyr-collapse.R index f1b4f9cea3a..1809cb6e388 100644 --- a/r/tests/testthat/test-dplyr-collapse.R +++ b/r/tests/testthat/test-dplyr-collapse.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-skip_if(on_old_windows()) - withr::local_options(list(arrow.summarise.sort = TRUE)) library(dplyr, warn.conflicts = FALSE) diff --git a/r/tests/testthat/test-dplyr-count.R b/r/tests/testthat/test-dplyr-count.R index b94cc10753f..d263a7576f5 100644 --- a/r/tests/testthat/test-dplyr-count.R +++ b/r/tests/testthat/test-dplyr-count.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) tbl <- example_data diff --git a/r/tests/testthat/test-dplyr-distinct.R b/r/tests/testthat/test-dplyr-distinct.R index 8b42614084a..c679794d419 100644 --- a/r/tests/testthat/test-dplyr-distinct.R +++ b/r/tests/testthat/test-dplyr-distinct.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) tbl <- example_data diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index aed46d801ce..e019a91cac4 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) library(stringr) diff --git a/r/tests/testthat/test-dplyr-funcs-conditional.R b/r/tests/testthat/test-dplyr-funcs-conditional.R index 4898d1e9e3e..e1dcd7bb091 100644 --- a/r/tests/testthat/test-dplyr-funcs-conditional.R +++ b/r/tests/testthat/test-dplyr-funcs-conditional.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) suppressPackageStartupMessages(library(bit64)) diff --git a/r/tests/testthat/test-dplyr-funcs-datetime.R b/r/tests/testthat/test-dplyr-funcs-datetime.R index 25fe23a28db..1f13eda74fe 100644 --- a/r/tests/testthat/test-dplyr-funcs-datetime.R +++ b/r/tests/testthat/test-dplyr-funcs-datetime.R @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) # In 3.4 the lack of tzone attribute causes spurious failures skip_on_r_older_than("3.5") diff --git a/r/tests/testthat/test-dplyr-funcs-math.R b/r/tests/testthat/test-dplyr-funcs-math.R index 5f7da452395..b9a6a3707d4 100644 --- a/r/tests/testthat/test-dplyr-funcs-math.R +++ b/r/tests/testthat/test-dplyr-funcs-math.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) diff --git a/r/tests/testthat/test-dplyr-funcs-string.R b/r/tests/testthat/test-dplyr-funcs-string.R index 4574e33e748..229347372ae 100644 --- a/r/tests/testthat/test-dplyr-funcs-string.R +++ b/r/tests/testthat/test-dplyr-funcs-string.R @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) skip_if_not_available("utf8proc") library(dplyr, warn.conflicts = FALSE) diff --git a/r/tests/testthat/test-dplyr-funcs-type.R b/r/tests/testthat/test-dplyr-funcs-type.R index 3f274b97f7f..5770e6ff439 100644 --- a/r/tests/testthat/test-dplyr-funcs-type.R +++ b/r/tests/testthat/test-dplyr-funcs-type.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. 
-skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) suppressPackageStartupMessages(library(bit64)) suppressPackageStartupMessages(library(lubridate)) diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index 08d6a77d3d1..c7380e96ec3 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) library(stringr) diff --git a/r/tests/testthat/test-dplyr-join.R b/r/tests/testthat/test-dplyr-join.R index 9d8e22596a6..74ad5fa328e 100644 --- a/r/tests/testthat/test-dplyr-join.R +++ b/r/tests/testthat/test-dplyr-join.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) left <- example_data diff --git a/r/tests/testthat/test-dplyr-mutate.R b/r/tests/testthat/test-dplyr-mutate.R index f1de5c70454..a6f4e49be34 100644 --- a/r/tests/testthat/test-dplyr-mutate.R +++ b/r/tests/testthat/test-dplyr-mutate.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) library(stringr) diff --git a/r/tests/testthat/test-dplyr-query.R b/r/tests/testthat/test-dplyr-query.R index 1a5b6ec8a7c..b08016f2170 100644 --- a/r/tests/testthat/test-dplyr-query.R +++ b/r/tests/testthat/test-dplyr-query.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) library(stringr) diff --git a/r/tests/testthat/test-dplyr-select.R b/r/tests/testthat/test-dplyr-select.R index fa5af734cb1..98dcd6396d9 100644 --- a/r/tests/testthat/test-dplyr-select.R +++ b/r/tests/testthat/test-dplyr-select.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - library(dplyr, warn.conflicts = FALSE) library(stringr) diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R index 0ee0c5739db..283d5d77837 100644 --- a/r/tests/testthat/test-dplyr-summarize.R +++ b/r/tests/testthat/test-dplyr-summarize.R @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. -skip_if(on_old_windows()) - withr::local_options(list( arrow.summarise.sort = TRUE, rlib_warning_verbosity = "verbose", diff --git a/r/tests/testthat/test-dplyr-union.R b/r/tests/testthat/test-dplyr-union.R index 5cc6f8eea57..1bf8610c560 100644 --- a/r/tests/testthat/test-dplyr-union.R +++ b/r/tests/testthat/test-dplyr-union.R @@ -13,9 +13,6 @@ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations -# under the License. - -skip_if(on_old_windows()) library(dplyr, warn.conflicts = FALSE) diff --git a/r/vignettes/dataset.Rmd b/r/vignettes/dataset.Rmd index 0890d36ff42..e58922c23a0 100644 --- a/r/vignettes/dataset.Rmd +++ b/r/vignettes/dataset.Rmd @@ -24,7 +24,7 @@ The total file size is around 37 gigabytes, even in the efficient Parquet file format. That's bigger than memory on most people's computers, so you can't just read it all in and stack it into a single data frame. -In Windows (for R > 3.6) and macOS binary packages, S3 support is included. 
+In Windows and macOS binary packages, S3 support is included. On Linux, when
 installing from source, S3 support is not enabled by default, and it has
 additional system requirements. See `vignette("install", package = "arrow")`
 for details.

From 21491ec0fa5fad2eb20bfedb3a19873f08e7e895 Mon Sep 17 00:00:00 2001
From: Igor Suhorukov
Date: Wed, 7 Sep 2022 23:38:06 +0300
Subject: [PATCH 012/133] ARROW-17525: [Java] Read ORC files using
 NativeDatasetFactory (#13973)

Support ORC file format in Java Dataset API

Authored-by: igor.suhorukov
Signed-off-by: David Li
---
 java/dataset/pom.xml                          | 32 ++++++++++++++
 java/dataset/src/main/cpp/jni_wrapper.cc      |  2 +
 .../apache/arrow/dataset/file/FileFormat.java |  1 +
 .../apache/arrow/dataset/OrcWriteSupport.java | 42 +++++++++++++++++++
 .../dataset/file/TestFileSystemDataset.java   | 31 ++++++++++++++
 5 files changed, 108 insertions(+)
 create mode 100644 java/dataset/src/test/java/org/apache/arrow/dataset/OrcWriteSupport.java

diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml
index 9eadf896888..4d5df9c45e7 100644
--- a/java/dataset/pom.xml
+++ b/java/dataset/pom.xml
@@ -109,6 +109,38 @@
       <artifactId>jackson-databind</artifactId>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.arrow.orc</groupId>
+      <artifactId>arrow-orc</artifactId>
+      <version>${project.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.orc</groupId>
+      <artifactId>orc-core</artifactId>
+      <version>1.7.6</version>
+      <scope>test</scope>
+      <exclusions>
+        <exclusion>
+          <groupId>log4j</groupId>
+          <artifactId>log4j</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-log4j12</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>commons-logging</groupId>
+          <artifactId>commons-logging</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hive</groupId>
+      <artifactId>hive-storage-api</artifactId>
+      <version>2.8.1</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 
   <build>
diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index d0881639034..ef9178b1b5d 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -91,6 +91,8 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(
       return std::make_shared<arrow::dataset::ParquetFileFormat>();
     case 1:
       return std::make_shared<arrow::dataset::IpcFileFormat>();
+    case 2:
+      return std::make_shared<arrow::dataset::OrcFileFormat>();
     default:
       std::string error_message =
           "illegal file format id: " + std::to_string(file_format_id);
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java
index 343e458ce23..b428b254b10 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileFormat.java
@@ -23,6 +23,7 @@
 public enum FileFormat {
   PARQUET(0),
   ARROW_IPC(1),
+  ORC(2),
   NONE(-1);
 
   private final int id;
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/OrcWriteSupport.java b/java/dataset/src/test/java/org/apache/arrow/dataset/OrcWriteSupport.java
new file mode 100644
index 00000000000..c49612995ee
--- /dev/null
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/OrcWriteSupport.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.hive.ql.exec.vector.LongColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.VectorizedRowBatch;
+import org.apache.orc.OrcFile;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.Writer;
+
+public class OrcWriteSupport {
+  public static void writeTempFile(TypeDescription orcSchema, Path path, Integer[] values) throws IOException {
+    Writer writer = OrcFile.createWriter(path, OrcFile.writerOptions(new Configuration()).setSchema(orcSchema));
+    VectorizedRowBatch batch = orcSchema.createRowBatch();
+    LongColumnVector longColumnVector = (LongColumnVector) batch.cols[0];
+    for (int idx = 0; idx < values.length; idx++) {
+      longColumnVector.vector[idx] = values[idx];
+    }
+    batch.size = values.length;
+    writer.addRowBatch(batch);
+    writer.close();
+  }
+}
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java
index 2fd8a19bac1..b8d51a3edb1 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java
@@ -37,6 +37,7 @@
 import java.util.concurrent.Executors;
 import java.util.stream.Collectors;
 
+import org.apache.arrow.dataset.OrcWriteSupport;
 import org.apache.arrow.dataset.ParquetWriteSupport;
 import org.apache.arrow.dataset.jni.NativeDataset;
 import org.apache.arrow.dataset.jni.NativeInstanceReleasedException;
@@ -59,6 +60,8 @@
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.avro.generic.GenericRecord;
 import org.apache.avro.generic.GenericRecordBuilder;
+import org.apache.hadoop.fs.Path;
+import org.apache.orc.TypeDescription;
 import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
@@ -357,6 +360,34 @@ public void testBaseArrowIpcRead() throws Exception {
     AutoCloseables.close(factory);
   }
 
+  @Test
+  public void testBaseOrcRead() throws Exception {
+    String dataName = "test-orc";
+    String basePath = TMP.getRoot().getAbsolutePath();
+
+    TypeDescription orcSchema = TypeDescription.fromString("struct<ints:int>");
+    Path path = new Path(basePath, dataName);
+    OrcWriteSupport.writeTempFile(orcSchema, path, new Integer[]{Integer.MIN_VALUE, Integer.MAX_VALUE});
+
+    String orcDatasetUri = new File(basePath, dataName).toURI().toString();
+    FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(),
+        FileFormat.ORC, orcDatasetUri);
+    ScanOptions options = new ScanOptions(100);
+    Schema schema = inferResultSchemaFromFactory(factory, options);
+    List<ArrowRecordBatch> datum = collectResultFromFactory(factory, options);
+
+    assertSingleTaskProduced(factory, options);
+    assertEquals(1, datum.size());
+    assertEquals(1, schema.getFields().size());
+    assertEquals("ints", schema.getFields().get(0).getName());
+
+    String expectedJsonUnordered = "[[2147483647], [-2147483648]]";
+    checkParquetReadResult(schema, expectedJsonUnordered, datum);
+
+    AutoCloseables.close(datum);
+    AutoCloseables.close(factory);
+  }
+
   private void checkParquetReadResult(Schema schema, String expectedJson, List<ArrowRecordBatch> actual)
       throws IOException {
     final ObjectMapper json = new ObjectMapper();

From d123277bf0a261cc9fc479a376ac9420a9420eea Mon Sep 17 00:00:00 2001
From: Matt Topol
Date: Wed, 7 Sep 2022 17:31:58 -0400
Subject: [PATCH 013/133]
ARROW-17600: [Go] Implement Casting for Nested types (#14056) Authored-by: Matt Topol Signed-off-by: Matt Topol --- ci/scripts/go_test.sh | 15 +- go/arrow/array/array_test.go | 1 + go/arrow/array/struct.go | 34 ++ go/arrow/compute/cast.go | 202 ++++++++ go/arrow/compute/cast_test.go | 456 +++++++++++++++++- go/arrow/compute/datum.go | 3 + go/arrow/compute/executor.go | 3 + .../compute/internal/kernels/cast_nested.go | 17 + .../compute/internal/kernels/cast_numeric.go | 22 +- go/arrow/compute/internal/kernels/helpers.go | 4 +- go/arrow/datatype.go | 2 + go/arrow/datatype_fixedwidth.go | 16 +- go/arrow/datatype_numeric.gen.go | 12 + go/arrow/datatype_numeric.gen.go.tmpl | 1 + go/arrow/scalar/binary.go | 14 +- go/arrow/scalar/nested.go | 42 +- go/arrow/scalar/scalar.go | 14 +- 17 files changed, 821 insertions(+), 37 deletions(-) create mode 100644 go/arrow/compute/internal/kernels/cast_nested.go diff --git a/ci/scripts/go_test.sh b/ci/scripts/go_test.sh index 0c07e0fc6bf..e31fa555642 100755 --- a/ci/scripts/go_test.sh +++ b/ci/scripts/go_test.sh @@ -19,12 +19,20 @@ set -ex +# simplistic semver comparison +verlte() { + [ "$1" = "`echo -e "$1\n$2" | sort -V | head -n1`" ] +} +verlt() { + [ "$1" = "$2" ] && return 1 || verlte $1 $2 +} + ver=`go env GOVERSION` source_dir=${1}/go testargs="-race" -if [[ "${ver#go}" =~ ^1\.1[8-9] ]] && [ "$(go env GOOS)" != "darwin" ]; then +if verlte "1.18" "${ver#go}" && [ "$(go env GOOS)" != "darwin" ]; then # asan not supported on darwin/amd64 testargs="-asan" fi @@ -65,6 +73,11 @@ fi go test $testargs -tags $TAGS ./... +# only test compute when Go is >= 1.18 +if verlte "1.18" "${ver#go}"; then + go test $testargs -tags $TAGS ./compute/... +fi + popd export PARQUET_TEST_DATA=${1}/cpp/submodules/parquet-testing/data diff --git a/go/arrow/array/array_test.go b/go/arrow/array/array_test.go index f2cee669fa3..d93e8d0c048 100644 --- a/go/arrow/array/array_test.go +++ b/go/arrow/array/array_test.go @@ -34,6 +34,7 @@ type testDataType struct { func (d *testDataType) ID() arrow.Type { return d.id } func (d *testDataType) Name() string { panic("implement me") } func (d *testDataType) BitWidth() int { return 8 } +func (d *testDataType) Bytes() int { return 1 } func (d *testDataType) Fingerprint() string { return "" } func (testDataType) Layout() arrow.DataTypeLayout { return arrow.DataTypeLayout{} } func (testDataType) String() string { return "" } diff --git a/go/arrow/array/struct.go b/go/arrow/array/struct.go index 2adf17623c0..213febfa416 100644 --- a/go/arrow/array/struct.go +++ b/go/arrow/array/struct.go @@ -36,6 +36,40 @@ type Struct struct { fields []arrow.Array } +// NewStructArray constructs a new Struct Array out of the columns passed +// in and the field names. The length of all cols must be the same and +// there should be the same number of columns as names. +func NewStructArray(cols []arrow.Array, names []string) (*Struct, error) { + return NewStructArrayWithNulls(cols, names, nil, 0, 0) +} + +// NewStructArrayWithNulls is like NewStructArray as a convenience function, +// but also takes in a null bitmap, the number of nulls, and an optional offset +// to use for creating the Struct Array. 
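+//
+// A minimal usage sketch (the column contents, validity bitmap, and field
+// name below are illustrative only, not part of the API):
+//
+//	col, _, _ := FromJSON(mem, arrow.PrimitiveTypes.Int32, strings.NewReader(`[1, 2, 3]`))
+//	defer col.Release()
+//	validity := memory.NewBufferBytes([]byte{0x03}) // rows 0 and 1 valid, row 2 null
+//	st, err := NewStructArrayWithNulls([]arrow.Array{col}, []string{"a"}, validity, 1, 0)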
+func NewStructArrayWithNulls(cols []arrow.Array, names []string, nullBitmap *memory.Buffer, nullCount int, offset int) (*Struct, error) { + if len(cols) != len(names) { + return nil, fmt.Errorf("%w: mismatching number of fields and child arrays", arrow.ErrInvalid) + } + if len(cols) == 0 { + return nil, fmt.Errorf("%w: can't infer struct array length with 0 child arrays", arrow.ErrInvalid) + } + length := cols[0].Len() + children := make([]arrow.ArrayData, len(cols)) + fields := make([]arrow.Field, len(cols)) + for i, c := range cols { + if length != c.Len() { + return nil, fmt.Errorf("%w: mismatching child array lengths", arrow.ErrInvalid) + } + children[i] = c.Data() + fields[i].Name = names[i] + fields[i].Type = c.DataType() + fields[i].Nullable = true + } + data := NewData(arrow.StructOf(fields...), length, []*memory.Buffer{nullBitmap}, children, nullCount, offset) + defer data.Release() + return NewStructData(data), nil +} + // NewStructData returns a new Struct array value from data. func NewStructData(data arrow.ArrayData) *Struct { a := &Struct{} diff --git a/go/arrow/compute/cast.go b/go/arrow/compute/cast.go index 2480cbfb358..1ef8b77d5c3 100644 --- a/go/arrow/compute/cast.go +++ b/go/arrow/compute/cast.go @@ -23,6 +23,7 @@ import ( "github.com/apache/arrow/go/v10/arrow" "github.com/apache/arrow/go/v10/arrow/array" + "github.com/apache/arrow/go/v10/arrow/bitutil" "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" "github.com/apache/arrow/go/v10/arrow/compute/internal/kernels" ) @@ -150,6 +151,156 @@ func CastFromExtension(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.Exec return nil } +func CastList[SrcOffsetT, DestOffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + var ( + opts = ctx.State.(kernels.CastState) + childType = out.Type.(arrow.NestedType).Fields()[0].Type + input = &batch.Values[0].Array + offsets = exec.GetSpanOffsets[SrcOffsetT](input, 1) + isDowncast = kernels.SizeOf[SrcOffsetT]() > kernels.SizeOf[DestOffsetT]() + ) + + out.Buffers[0] = input.Buffers[0] + out.Buffers[1] = input.Buffers[1] + + if input.Offset != 0 && len(input.Buffers[0].Buf) > 0 { + out.Buffers[0].WrapBuffer(ctx.AllocateBitmap(input.Len)) + bitutil.CopyBitmap(input.Buffers[0].Buf, int(input.Offset), int(input.Len), + out.Buffers[0].Buf, 0) + } + + // Handle list offsets + // Several cases possible: + // - The source offset is non-zero, in which case we slice the + // underlying values and shift the list offsets (regardless of + // their respective types) + // - the source offset is zero but the source and destination types + // have different list offset types, in which case we cast the offsets + // - otherwise we simply keep the original offsets + if isDowncast { + if offsets[input.Len] > SrcOffsetT(kernels.MaxOf[DestOffsetT]()) { + return fmt.Errorf("%w: array of type %s too large to convert to %s", + arrow.ErrInvalid, input.Type, out.Type) + } + } + + values := input.Children[0].MakeArray() + defer values.Release() + + if input.Offset != 0 { + out.Buffers[1].WrapBuffer( + ctx.Allocate(out.Type.(arrow.OffsetsDataType). 
+ OffsetTypeTraits().BytesRequired(int(input.Len) + 1))) + + shiftedOffsets := exec.GetSpanOffsets[DestOffsetT](out, 1) + for i := 0; i < int(input.Len)+1; i++ { + shiftedOffsets[i] = DestOffsetT(offsets[i] - offsets[0]) + } + + values = array.NewSlice(values, int64(offsets[0]), int64(offsets[input.Len])) + defer values.Release() + } else if kernels.SizeOf[SrcOffsetT]() != kernels.SizeOf[DestOffsetT]() { + out.Buffers[1].WrapBuffer(ctx.Allocate(out.Type.(arrow.OffsetsDataType). + OffsetTypeTraits().BytesRequired(int(input.Len) + 1))) + + kernels.DoStaticCast(exec.GetSpanOffsets[SrcOffsetT](input, 1), + exec.GetSpanOffsets[DestOffsetT](out, 1)) + } + + // handle values + opts.ToType = childType + + castedValues, err := CastArray(ctx.Ctx, values, &opts) + if err != nil { + return err + } + defer castedValues.Release() + + out.Children = make([]exec.ArraySpan, 1) + out.Children[0].SetMembers(castedValues.Data()) + for i, b := range out.Children[0].Buffers { + if b.Owner != nil && b.Owner != values.Data().Buffers()[i] { + b.Owner.Retain() + b.SelfAlloc = true + } + } + return nil +} + +func CastStruct(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + var ( + opts = ctx.State.(kernels.CastState) + inType = batch.Values[0].Array.Type.(*arrow.StructType) + outType = out.Type.(*arrow.StructType) + inFieldCount = len(inType.Fields()) + outFieldCount = len(outType.Fields()) + ) + + fieldsToSelect := make([]int, outFieldCount) + for i := range fieldsToSelect { + fieldsToSelect[i] = -1 + } + + outFieldIndex := 0 + for inFieldIndex := 0; inFieldIndex < inFieldCount && outFieldIndex < outFieldCount; inFieldIndex++ { + inField := inType.Field(inFieldIndex) + outField := outType.Field(outFieldIndex) + if inField.Name == outField.Name { + if inField.Nullable && !outField.Nullable { + return fmt.Errorf("%w: cannot cast nullable field to non-nullable field: %s %s", + arrow.ErrType, inType, outType) + } + fieldsToSelect[outFieldIndex] = inFieldIndex + outFieldIndex++ + } + } + + if outFieldIndex < outFieldCount { + return fmt.Errorf("%w: struct fields don't match or are in the wrong order: Input: %s Output: %s", + arrow.ErrType, inType, outType) + } + + input := &batch.Values[0].Array + if len(input.Buffers[0].Buf) > 0 { + out.Buffers[0].WrapBuffer(ctx.AllocateBitmap(input.Len)) + bitutil.CopyBitmap(input.Buffers[0].Buf, int(input.Offset), int(input.Len), + out.Buffers[0].Buf, 0) + } + + out.Children = make([]exec.ArraySpan, outFieldCount) + for outFieldIndex, idx := range fieldsToSelect { + values := input.Children[idx].MakeArray() + defer values.Release() + values = array.NewSlice(values, input.Offset, input.Len) + defer values.Release() + + opts.ToType = outType.Field(outFieldIndex).Type + castedValues, err := CastArray(ctx.Ctx, values, &opts) + if err != nil { + return err + } + defer castedValues.Release() + + out.Children[outFieldIndex].TakeOwnership(castedValues.Data()) + } + return nil +} + +func addListCast[SrcOffsetT, DestOffsetT int32 | int64](fn *castFunction, inType arrow.Type) error { + kernel := exec.NewScalarKernel([]exec.InputType{exec.NewIDInput(inType)}, + kernels.OutputTargetType, CastList[SrcOffsetT, DestOffsetT], nil) + kernel.NullHandling = exec.NullComputedNoPrealloc + kernel.MemAlloc = exec.MemNoPrealloc + return fn.AddTypeCast(inType, kernel) +} + +func addStructToStructCast(fn *castFunction) error { + kernel := exec.NewScalarKernel([]exec.InputType{exec.NewIDInput(arrow.STRUCT)}, + kernels.OutputTargetType, CastStruct, nil) + kernel.NullHandling = 
exec.NullComputedNoPrealloc + return fn.AddTypeCast(arrow.STRUCT, kernel) +} + func addCastFuncs(fn []*castFunction) { for _, f := range fn { f.AddNewTypeCast(arrow.EXTENSION, []exec.InputType{exec.NewIDInput(arrow.EXTENSION)}, @@ -165,6 +316,12 @@ func initCastTable() { addCastFuncs(getNumericCasts()) addCastFuncs(getBinaryLikeCasts()) addCastFuncs(getTemporalCasts()) + addCastFuncs(getNestedCasts()) + + nullToExt := newCastFunction("cast_extension", arrow.EXTENSION) + nullToExt.AddNewTypeCast(arrow.NULL, []exec.InputType{exec.NewExactInput(arrow.Null)}, + kernels.OutputTargetType, kernels.CastFromNull, exec.NullComputedNoPrealloc, exec.MemNoPrealloc) + castTable[arrow.EXTENSION] = nullToExt } func getCastFunction(to arrow.DataType) (*castFunction, error) { @@ -178,6 +335,51 @@ func getCastFunction(to arrow.DataType) (*castFunction, error) { return nil, fmt.Errorf("%w: unsupported cast to %s", arrow.ErrNotImplemented, to) } +func getNestedCasts() []*castFunction { + out := make([]*castFunction, 0) + + addKernels := func(fn *castFunction, kernels []exec.ScalarKernel) { + for _, k := range kernels { + if err := fn.AddTypeCast(k.Signature.InputTypes[0].MatchID(), k); err != nil { + panic(err) + } + } + } + + castLists := newCastFunction("cast_list", arrow.LIST) + addKernels(castLists, kernels.GetCommonCastKernels(arrow.LIST, kernels.OutputTargetType)) + if err := addListCast[int32, int32](castLists, arrow.LIST); err != nil { + panic(err) + } + if err := addListCast[int64, int32](castLists, arrow.LARGE_LIST); err != nil { + panic(err) + } + out = append(out, castLists) + + castLargeLists := newCastFunction("cast_large_list", arrow.LARGE_LIST) + addKernels(castLargeLists, kernels.GetCommonCastKernels(arrow.LARGE_LIST, kernels.OutputTargetType)) + if err := addListCast[int32, int64](castLargeLists, arrow.LIST); err != nil { + panic(err) + } + if err := addListCast[int64, int64](castLargeLists, arrow.LARGE_LIST); err != nil { + panic(err) + } + out = append(out, castLargeLists) + + castFsl := newCastFunction("cast_fixed_size_list", arrow.FIXED_SIZE_LIST) + addKernels(castFsl, kernels.GetCommonCastKernels(arrow.FIXED_SIZE_LIST, kernels.OutputTargetType)) + out = append(out, castFsl) + + castStruct := newCastFunction("cast_struct", arrow.STRUCT) + addKernels(castStruct, kernels.GetCommonCastKernels(arrow.STRUCT, kernels.OutputTargetType)) + if err := addStructToStructCast(castStruct); err != nil { + panic(err) + } + out = append(out, castStruct) + + return out +} + func getBooleanCasts() []*castFunction { fn := newCastFunction("cast_boolean", arrow.BOOL) kns := kernels.GetBooleanCastKernels() diff --git a/go/arrow/compute/cast_test.go b/go/arrow/compute/cast_test.go index 807a5c281d9..e98ecd2b7d6 100644 --- a/go/arrow/compute/cast_test.go +++ b/go/arrow/compute/cast_test.go @@ -93,12 +93,18 @@ func checkScalarNonRecursive(t *testing.T, funcName string, inputs []compute.Dat func checkScalarWithScalars(t *testing.T, funcName string, inputs []scalar.Scalar, expected scalar.Scalar, opts compute.FunctionOptions) { datums := getDatums(inputs) defer func() { + for _, s := range inputs { + if r, ok := s.(scalar.Releasable); ok { + r.Release() + } + } for _, d := range datums { d.Release() } }() out, err := compute.CallFunction(context.Background(), funcName, opts, datums...) 
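+	// the returned datum may own allocated buffers, so it is released via the
+	// defer added below once the scalar comparison completes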
assert.NoError(t, err) + defer out.Release() if !scalar.Equals(out.(*compute.ScalarDatum).Value, expected) { var b strings.Builder b.WriteString(funcName + "(") @@ -145,6 +151,9 @@ func checkScalar(t *testing.T, funcName string, inputs []compute.Datum, expected for i := 0; i < exp.Len(); i++ { e, _ := scalar.GetScalar(exp, i) checkScalarWithScalars(t, funcName, getScalars(inputs, i), e, opts) + if r, ok := e.(scalar.Releasable); ok { + r.Release() + } } } @@ -174,6 +183,9 @@ func checkCastFails(t *testing.T, input arrow.Array, opt compute.CastOptions) { nfail := 0 for i := 0; i < input.Len(); i++ { sc, _ := scalar.GetScalar(input, i) + if r, ok := sc.(scalar.Releasable); ok { + defer r.Release() + } d := compute.NewDatum(sc) defer d.Release() out, err := compute.CastDatum(context.Background(), d, &opt) @@ -323,20 +335,18 @@ func (c *CastSuite) TestCanCast() { expectCanCast(from, toSet, false) } - // will uncomment lines as support for those casts is added - canCast(arrow.Null, []arrow.DataType{arrow.FixedWidthTypes.Boolean}) canCast(arrow.Null, numericTypes) canCast(arrow.Null, baseBinaryTypes) canCast(arrow.Null, []arrow.DataType{ arrow.FixedWidthTypes.Date32, arrow.FixedWidthTypes.Date64, arrow.FixedWidthTypes.Time32ms, arrow.FixedWidthTypes.Timestamp_s, }) - // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint16, ValueType: arrow.Null}, []arrow.DataType{arrow.Null}) + cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint16, ValueType: arrow.Null}, []arrow.DataType{arrow.Null}) canCast(arrow.FixedWidthTypes.Boolean, []arrow.DataType{arrow.FixedWidthTypes.Boolean}) canCast(arrow.FixedWidthTypes.Boolean, numericTypes) canCast(arrow.FixedWidthTypes.Boolean, []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString}) - // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: arrow.FixedWidthTypes.Boolean}, []arrow.DataType{arrow.FixedWidthTypes.Boolean}) + cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: arrow.FixedWidthTypes.Boolean}, []arrow.DataType{arrow.FixedWidthTypes.Boolean}) cannotCast(arrow.FixedWidthTypes.Boolean, []arrow.DataType{arrow.Null}) cannotCast(arrow.FixedWidthTypes.Boolean, []arrow.DataType{arrow.BinaryTypes.Binary, arrow.BinaryTypes.LargeBinary}) @@ -347,16 +357,16 @@ func (c *CastSuite) TestCanCast() { canCast(from, []arrow.DataType{arrow.FixedWidthTypes.Boolean}) canCast(from, numericTypes) canCast(from, []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString}) - // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: from}, []arrow.DataType{from}) + cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: from}, []arrow.DataType{from}) cannotCast(from, []arrow.DataType{arrow.Null}) } for _, from := range baseBinaryTypes { canCast(from, []arrow.DataType{arrow.FixedWidthTypes.Boolean}) - // canCast(from, numericTypes) + canCast(from, numericTypes) canCast(from, baseBinaryTypes) - // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: from}, []arrow.DataType{from}) + cannotCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, ValueType: from}, []arrow.DataType{from}) // any cast which is valid for the dictionary is valid for the dictionary array // canCast(&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Uint32, ValueType: from}, baseBinaryTypes) @@ -365,8 +375,8 @@ func (c *CastSuite) TestCanCast() { cannotCast(from, []arrow.DataType{arrow.Null}) } 
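+	// string -> timestamp parsing is now implemented, so these casts are
+	// expected to be reported as supported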
- // canCast(arrow.BinaryTypes.String, []arrow.DataType{arrow.FixedWidthTypes.Timestamp_ms}) - // canCast(arrow.BinaryTypes.LargeString, []arrow.DataType{arrow.FixedWidthTypes.Timestamp_ns}) + canCast(arrow.BinaryTypes.String, []arrow.DataType{arrow.FixedWidthTypes.Timestamp_ms}) + canCast(arrow.BinaryTypes.LargeString, []arrow.DataType{arrow.FixedWidthTypes.Timestamp_ns}) // no formatting supported cannotCast(arrow.FixedWidthTypes.Timestamp_us, []arrow.DataType{arrow.BinaryTypes.Binary, arrow.BinaryTypes.LargeBinary}) @@ -376,9 +386,9 @@ func (c *CastSuite) TestCanCast() { arrow.RegisterExtensionType(types.NewSmallintType()) defer arrow.UnregisterExtensionType("smallint") - // canCast(types.NewSmallintType(), []arrow.DataType{arrow.PrimitiveTypes.Int16}) - // canCast(types.NewSmallintType(), numericTypes) // any cast which is valid for storage is supported - // canCast(arrow.Null, []arrow.DataType{types.NewSmallintType()}) + canCast(types.NewSmallintType(), []arrow.DataType{arrow.PrimitiveTypes.Int16}) + canCast(types.NewSmallintType(), numericTypes) // any cast which is valid for storage is supported + canCast(arrow.Null, []arrow.DataType{types.NewSmallintType()}) canCast(arrow.FixedWidthTypes.Date32, []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString}) canCast(arrow.FixedWidthTypes.Date64, []arrow.DataType{arrow.BinaryTypes.String, arrow.BinaryTypes.LargeString}) @@ -1562,14 +1572,14 @@ func (c *CastSuite) TestUnsupportedInputType() { toType := arrow.ListOf(arrow.BinaryTypes.String) _, err := compute.CastToType(context.Background(), arr, toType) c.ErrorIs(err, arrow.ErrNotImplemented) - c.ErrorContains(err, "unsupported cast to list from int32") + c.ErrorContains(err, "function 'cast_list' has no kernel matching input types (int32)") // test calling through the generic kernel API datum := compute.NewDatum(arr) defer datum.Release() _, err = compute.CallFunction(context.Background(), "cast", compute.SafeCastOptions(toType), datum) c.ErrorIs(err, arrow.ErrNotImplemented) - c.ErrorContains(err, "unsupported cast to list from int32") + c.ErrorContains(err, "function 'cast_list' has no kernel matching input types (int32)") } func (c *CastSuite) TestUnsupportedTargetType() { @@ -2242,6 +2252,424 @@ func (c *CastSuite) TestIdentityCasts() { c.checkCastSelfZeroCopy(arrow.FixedWidthTypes.Timestamp_s, `[1, 2, 3, 4]`) } +func (c *CastSuite) TestListToPrimitive() { + arr, _, _ := array.FromJSON(c.mem, arrow.ListOf(arrow.PrimitiveTypes.Int8), strings.NewReader(`[[1, 2], [3, 4]]`)) + defer arr.Release() + + _, err := compute.CastToType(context.Background(), arr, arrow.PrimitiveTypes.Uint8) + c.ErrorIs(err, arrow.ErrNotImplemented) +} + +type makeList func(arrow.DataType) arrow.DataType + +var listFactories = []makeList{ + func(dt arrow.DataType) arrow.DataType { return arrow.ListOf(dt) }, + func(dt arrow.DataType) arrow.DataType { return arrow.LargeListOf(dt) }, +} + +func (c *CastSuite) checkListToList(valTypes []arrow.DataType, jsonData string) { + for _, makeSrc := range listFactories { + for _, makeDest := range listFactories { + for _, srcValueType := range valTypes { + for _, dstValueType := range valTypes { + srcType := makeSrc(srcValueType) + dstType := makeDest(dstValueType) + c.Run(fmt.Sprintf("from %s to %s", srcType, dstType), func() { + c.checkCast(srcType, dstType, jsonData, jsonData) + }) + } + } + } + } +} + +func (c *CastSuite) TestListToList() { + c.checkListToList([]arrow.DataType{arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float32, 
arrow.PrimitiveTypes.Int64}, + `[[0], [1], null, [2, 3, 4], [5, 6], null, [], [7], [8, 9]]`) +} + +func (c *CastSuite) TestListToListNoNulls() { + c.checkListToList([]arrow.DataType{arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float32, arrow.PrimitiveTypes.Int64}, + `[[0], [1], [2, 3, 4], [5, 6], [], [7], [8, 9]]`) +} + +func (c *CastSuite) TestListToListOptionsPassthru() { + for _, makeSrc := range listFactories { + for _, makeDest := range listFactories { + opts := compute.SafeCastOptions(makeDest(arrow.PrimitiveTypes.Int16)) + c.checkCastFails(makeSrc(arrow.PrimitiveTypes.Int32), `[[87654321]]`, opts) + + opts.AllowIntOverflow = true + c.checkCastOpts(makeSrc(arrow.PrimitiveTypes.Int32), makeDest(arrow.PrimitiveTypes.Int16), + `[[87654321]]`, `[[32689]]`, *opts) + } + } +} + +func (c *CastSuite) checkStructToStruct(types []arrow.DataType) { + for _, srcType := range types { + c.Run(srcType.String(), func() { + for _, destType := range types { + c.Run(destType.String(), func() { + fieldNames := []string{"a", "b"} + a1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[1, 2, 3, 4, null]`)) + b1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[null, 7, 8, 9, 0]`)) + a2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[1, 2, 3, 4, null]`)) + b2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[null, 7, 8, 9, 0]`)) + src, _ := array.NewStructArray([]arrow.Array{a1, b1}, fieldNames) + dest, _ := array.NewStructArray([]arrow.Array{a2, b2}, fieldNames) + defer func() { + a1.Release() + b1.Release() + a2.Release() + b2.Release() + src.Release() + dest.Release() + }() + + checkCast(c.T(), src, dest, *compute.DefaultCastOptions(true)) + c.Run("with nulls", func() { + nullBitmap := memory.NewBufferBytes([]byte{10}) + srcNullData := src.Data().(*array.Data).Copy() + srcNullData.Buffers()[0] = nullBitmap + srcNullData.SetNullN(3) + defer srcNullData.Release() + destNullData := dest.Data().(*array.Data).Copy() + destNullData.Buffers()[0] = nullBitmap + destNullData.SetNullN(3) + defer destNullData.Release() + + srcNulls := array.NewStructData(srcNullData) + destNulls := array.NewStructData(destNullData) + defer srcNulls.Release() + defer destNulls.Release() + + checkCast(c.T(), srcNulls, destNulls, *compute.DefaultCastOptions(true)) + }) + }) + } + }) + } +} + +func (c *CastSuite) checkStructToStructSubset(types []arrow.DataType) { + for _, srcType := range types { + c.Run(srcType.String(), func() { + for _, destType := range types { + c.Run(destType.String(), func() { + fieldNames := []string{"a", "b", "c", "d", "e"} + + a1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[1, 2, 5]`)) + defer a1.Release() + b1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[3, 4, 7]`)) + defer b1.Release() + c1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[9, 11, 44]`)) + defer c1.Release() + d1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[6, 51, 49]`)) + defer d1.Release() + e1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[19, 17, 74]`)) + defer e1.Release() + + a2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[1, 2, 5]`)) + defer a2.Release() + b2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[3, 4, 7]`)) + defer b2.Release() + c2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[9, 11, 44]`)) + defer c2.Release() + d2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[6, 51, 49]`)) + defer d2.Release() + e2, _, _ := array.FromJSON(c.mem, destType, 
strings.NewReader(`[19, 17, 74]`)) + defer e2.Release() + + src, _ := array.NewStructArray([]arrow.Array{a1, b1, c1, d1, e1}, fieldNames) + defer src.Release() + dest1, _ := array.NewStructArray([]arrow.Array{a2}, []string{"a"}) + defer dest1.Release() + + opts := *compute.DefaultCastOptions(true) + checkCast(c.T(), src, dest1, opts) + + dest2, _ := array.NewStructArray([]arrow.Array{b2, c2}, []string{"b", "c"}) + defer dest2.Release() + checkCast(c.T(), src, dest2, opts) + + dest3, _ := array.NewStructArray([]arrow.Array{c2, d2, e2}, []string{"c", "d", "e"}) + defer dest3.Release() + checkCast(c.T(), src, dest3, opts) + + dest4, _ := array.NewStructArray([]arrow.Array{a2, b2, c2, e2}, []string{"a", "b", "c", "e"}) + defer dest4.Release() + checkCast(c.T(), src, dest4, opts) + + dest5, _ := array.NewStructArray([]arrow.Array{a2, b2, c2, d2, e2}, []string{"a", "b", "c", "d", "e"}) + defer dest5.Release() + checkCast(c.T(), src, dest5, opts) + + // field does not exist + dest6 := arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + arrow.Field{Name: "d", Type: arrow.PrimitiveTypes.Int16, Nullable: true}, + arrow.Field{Name: "f", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + ) + options6 := compute.SafeCastOptions(dest6) + _, err := compute.CastArray(context.TODO(), src, options6) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "struct fields don't match or are in the wrong order") + + // fields in wrong order + dest7 := arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + arrow.Field{Name: "c", Type: arrow.PrimitiveTypes.Int16, Nullable: true}, + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + ) + options7 := compute.SafeCastOptions(dest7) + _, err = compute.CastArray(context.TODO(), src, options7) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "struct fields don't match or are in the wrong order") + }) + } + }) + } +} + +func (c *CastSuite) checkStructToStructSubsetWithNulls(types []arrow.DataType) { + for _, srcType := range types { + c.Run(srcType.String(), func() { + for _, destType := range types { + c.Run(destType.String(), func() { + fieldNames := []string{"a", "b", "c", "d", "e"} + + a1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[1, 2, 5]`)) + defer a1.Release() + b1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[3, null, 7]`)) + defer b1.Release() + c1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[9, 11, 44]`)) + defer c1.Release() + d1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[6, 51, null]`)) + defer d1.Release() + e1, _, _ := array.FromJSON(c.mem, srcType, strings.NewReader(`[null, 17, 74]`)) + defer e1.Release() + + a2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[1, 2, 5]`)) + defer a2.Release() + b2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[3, null, 7]`)) + defer b2.Release() + c2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[9, 11, 44]`)) + defer c2.Release() + d2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[6, 51, null]`)) + defer d2.Release() + e2, _, _ := array.FromJSON(c.mem, destType, strings.NewReader(`[null, 17, 74]`)) + defer e2.Release() + + // 0, 1, 0 + nullBitmap := memory.NewBufferBytes([]byte{2}) + srcNull, _ := array.NewStructArrayWithNulls([]arrow.Array{a1, b1, c1, d1, e1}, fieldNames, nullBitmap, 2, 0) + defer srcNull.Release() + + dest1Null, _ := array.NewStructArrayWithNulls([]arrow.Array{a2}, []string{"a"}, 
nullBitmap, -1, 0) + defer dest1Null.Release() + opts := compute.DefaultCastOptions(true) + checkCast(c.T(), srcNull, dest1Null, *opts) + + dest2Null, _ := array.NewStructArrayWithNulls([]arrow.Array{b2, c2}, []string{"b", "c"}, nullBitmap, -1, 0) + defer dest2Null.Release() + checkCast(c.T(), srcNull, dest2Null, *opts) + + dest3Null, _ := array.NewStructArrayWithNulls([]arrow.Array{a2, d2, e2}, []string{"a", "d", "e"}, nullBitmap, -1, 0) + defer dest3Null.Release() + checkCast(c.T(), srcNull, dest3Null, *opts) + + dest4Null, _ := array.NewStructArrayWithNulls([]arrow.Array{a2, b2, c2, e2}, []string{"a", "b", "c", "e"}, nullBitmap, -1, 0) + defer dest4Null.Release() + checkCast(c.T(), srcNull, dest4Null, *opts) + + dest5Null, _ := array.NewStructArrayWithNulls([]arrow.Array{a2, b2, c2, d2, e2}, []string{"a", "b", "c", "d", "e"}, nullBitmap, -1, 0) + defer dest5Null.Release() + checkCast(c.T(), srcNull, dest5Null, *opts) + + // field does not exist + dest6Null := arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + arrow.Field{Name: "d", Type: arrow.PrimitiveTypes.Int16, Nullable: true}, + arrow.Field{Name: "f", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + ) + options6Null := compute.SafeCastOptions(dest6Null) + _, err := compute.CastArray(context.TODO(), srcNull, options6Null) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "struct fields don't match or are in the wrong order") + + // fields in wrong order + dest7Null := arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + arrow.Field{Name: "c", Type: arrow.PrimitiveTypes.Int16, Nullable: true}, + arrow.Field{Name: "b", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + ) + options7Null := compute.SafeCastOptions(dest7Null) + _, err = compute.CastArray(context.TODO(), srcNull, options7Null) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "struct fields don't match or are in the wrong order") + }) + } + }) + } +} + +func (c *CastSuite) TestStructToSameSizedAndNamedStruct() { + c.checkStructToStruct(numericTypes) +} + +func (c *CastSuite) TestStructToStructSubset() { + c.checkStructToStructSubset(numericTypes) +} + +func (c *CastSuite) TestStructToStructSubsetWithNulls() { + c.checkStructToStructSubsetWithNulls(numericTypes) +} + +func (c *CastSuite) TestStructToSameSizedButDifferentNamedStruct() { + fieldNames := []string{"a", "b"} + a, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[1, 2]`)) + defer a.Release() + b, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[3, 4]`)) + defer b.Release() + + src, _ := array.NewStructArray([]arrow.Array{a, b}, fieldNames) + defer src.Release() + + dest := arrow.StructOf( + arrow.Field{Name: "c", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + arrow.Field{Name: "d", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + ) + opts := compute.SafeCastOptions(dest) + _, err := compute.CastArray(context.TODO(), src, opts) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "struct fields don't match or are in the wrong order") +} + +func (c *CastSuite) TestStructToBiggerStruct() { + fieldNames := []string{"a", "b"} + a, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[1, 2]`)) + defer a.Release() + b, _, _ := array.FromJSON(c.mem, arrow.PrimitiveTypes.Int8, strings.NewReader(`[3, 4]`)) + defer b.Release() + + src, _ := array.NewStructArray([]arrow.Array{a, b}, fieldNames) + defer src.Release() + + dest := 
arrow.StructOf(
+		arrow.Field{Name: "c", Type: arrow.PrimitiveTypes.Int8, Nullable: true},
+		arrow.Field{Name: "d", Type: arrow.PrimitiveTypes.Int8, Nullable: true},
+		arrow.Field{Name: "e", Type: arrow.PrimitiveTypes.Int8, Nullable: true},
+	)
+	opts := compute.SafeCastOptions(dest)
+	_, err := compute.CastArray(context.TODO(), src, opts)
+	c.ErrorIs(err, arrow.ErrType)
+	c.ErrorContains(err, "struct fields don't match or are in the wrong order")
+}
+
+func (c *CastSuite) TestStructToDifferentNullabilityStruct() {
+	c.Run("non-nullable to nullable", func() {
+		fieldsSrcNonNullable := []arrow.Field{
+			{Name: "a", Type: arrow.PrimitiveTypes.Int8},
+			{Name: "b", Type: arrow.PrimitiveTypes.Int8},
+			{Name: "c", Type: arrow.PrimitiveTypes.Int8},
+		}
+		srcNonNull, _, err := array.FromJSON(c.mem, arrow.StructOf(fieldsSrcNonNullable...),
+			strings.NewReader(`[
+				{"a": 11, "b": 32, "c": 95},
+				{"a": 23, "b": 46, "c": 11},
+				{"a": 56, "b": 37, "c": 44}
+			]`))
+		c.Require().NoError(err)
+		defer srcNonNull.Release()
+
+		fieldsDest1Nullable := []arrow.Field{
+			{Name: "a", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+			{Name: "b", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+			{Name: "c", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+		}
+		destNullable, _, err := array.FromJSON(c.mem, arrow.StructOf(fieldsDest1Nullable...),
+			strings.NewReader(`[
+				{"a": 11, "b": 32, "c": 95},
+				{"a": 23, "b": 46, "c": 11},
+				{"a": 56, "b": 37, "c": 44}
+			]`))
+		c.Require().NoError(err)
+		defer destNullable.Release()
+
+		checkCast(c.T(), srcNonNull, destNullable, *compute.DefaultCastOptions(true))
+
+		fieldsDest2Nullable := []arrow.Field{
+			{Name: "a", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+			{Name: "c", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+		}
+
+		data := array.NewData(arrow.StructOf(fieldsDest2Nullable...), destNullable.Len(), destNullable.Data().Buffers(),
+			[]arrow.ArrayData{destNullable.Data().Children()[0], destNullable.Data().Children()[2]},
+			destNullable.NullN(), 0)
+		defer data.Release()
+		dest2Nullable := array.NewStructData(data)
+		defer dest2Nullable.Release()
+		checkCast(c.T(), srcNonNull, dest2Nullable, *compute.DefaultCastOptions(true))
+
+		fieldsDest3Nullable := []arrow.Field{
+			{Name: "b", Type: arrow.PrimitiveTypes.Int64, Nullable: true},
+		}
+
+		data = array.NewData(arrow.StructOf(fieldsDest3Nullable...), destNullable.Len(), destNullable.Data().Buffers(),
+			[]arrow.ArrayData{destNullable.Data().Children()[1]}, destNullable.NullN(), 0)
+		defer data.Release()
+		dest3Nullable := array.NewStructData(data)
+		defer dest3Nullable.Release()
+		checkCast(c.T(), srcNonNull, dest3Nullable, *compute.DefaultCastOptions(true))
+	})
+	c.Run("nullable to non-nullable", func() {
+		fieldsSrcNullable := []arrow.Field{
+			{Name: "a", Type: arrow.PrimitiveTypes.Int8, Nullable: true},
+			{Name: "b", Type: arrow.PrimitiveTypes.Int8, Nullable: true},
+			{Name: "c", Type: arrow.PrimitiveTypes.Int8, Nullable: true},
+		}
+		srcNullable, _, err := array.FromJSON(c.mem, arrow.StructOf(fieldsSrcNullable...),
+			strings.NewReader(`[
+				{"a": 1, "b": 3, "c": 9},
+				{"a": null, "b": 4, "c": 11},
+				{"a": 5, "b": null, "c": 44}
+			]`))
+		c.Require().NoError(err)
+		defer srcNullable.Release()
+
+		fieldsDest1NonNullable := []arrow.Field{
+			{Name: "a", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
+			{Name: "b", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
+			{Name: "c", Type: arrow.PrimitiveTypes.Int64, Nullable: false},
+		}
+		dest1NonNullable := arrow.StructOf(fieldsDest1NonNullable...)
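+		// a nullable source field may contain nulls, so casting it to a
+		// non-nullable destination field must be rejected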
+ options1NoNullable := compute.SafeCastOptions(dest1NonNullable) + _, err = compute.CastArray(context.TODO(), srcNullable, options1NoNullable) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "cannot cast nullable field to non-nullable field") + + fieldsDest2NonNullable := []arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Int64, Nullable: false}, + {Name: "c", Type: arrow.PrimitiveTypes.Int64, Nullable: false}, + } + dest2NonNullable := arrow.StructOf(fieldsDest2NonNullable...) + options2NoNullable := compute.SafeCastOptions(dest2NonNullable) + _, err = compute.CastArray(context.TODO(), srcNullable, options2NoNullable) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "cannot cast nullable field to non-nullable field") + + fieldsDest3NonNullable := []arrow.Field{ + {Name: "c", Type: arrow.PrimitiveTypes.Int64, Nullable: false}, + } + dest3NonNullable := arrow.StructOf(fieldsDest3NonNullable...) + options3NoNullable := compute.SafeCastOptions(dest3NonNullable) + _, err = compute.CastArray(context.TODO(), srcNullable, options3NoNullable) + c.ErrorIs(err, arrow.ErrType) + c.ErrorContains(err, "cannot cast nullable field to non-nullable field") + }) +} + func (c *CastSuite) smallIntArrayFromJSON(data string) arrow.Array { arr, _, _ := array.FromJSON(c.mem, types.NewSmallintType(), strings.NewReader(data)) return arr diff --git a/go/arrow/compute/datum.go b/go/arrow/compute/datum.go index 005126ecad4..805ce929f8b 100644 --- a/go/arrow/compute/datum.go +++ b/go/arrow/compute/datum.go @@ -275,6 +275,9 @@ func NewDatum(value interface{}) Datum { v.Retain() return &TableDatum{v} case scalar.Scalar: + if ls, ok := v.(scalar.Releasable); ok { + ls.Retain() + } return &ScalarDatum{v} default: return &ScalarDatum{scalar.MakeScalar(value)} diff --git a/go/arrow/compute/executor.go b/go/arrow/compute/executor.go index 72f6cf4623b..e406ab92529 100644 --- a/go/arrow/compute/executor.go +++ b/go/arrow/compute/executor.go @@ -649,6 +649,9 @@ func (s *scalarExecutor) emitResult(resultData *exec.ArraySpan, data chan<- Datu if err != nil { return err } + if r, ok := sc.(scalar.Releasable); ok { + defer r.Release() + } output = NewDatum(sc) } else { d := resultData.MakeData() diff --git a/go/arrow/compute/internal/kernels/cast_nested.go b/go/arrow/compute/internal/kernels/cast_nested.go new file mode 100644 index 00000000000..0fc92fbb0d4 --- /dev/null +++ b/go/arrow/compute/internal/kernels/cast_nested.go @@ -0,0 +1,17 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
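+
+// This file currently only declares the package: the nested (list and struct)
+// casts in this change, CastList and CastStruct, live in the parent compute
+// package because they recursively invoke CastArray on their child arrays.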
+ +package kernels diff --git a/go/arrow/compute/internal/kernels/cast_numeric.go b/go/arrow/compute/internal/kernels/cast_numeric.go index c1ba809c879..d5d33d70947 100644 --- a/go/arrow/compute/internal/kernels/cast_numeric.go +++ b/go/arrow/compute/internal/kernels/cast_numeric.go @@ -24,7 +24,7 @@ import ( var castNumericUnsafe func(itype, otype arrow.Type, in, out []byte, len int) = castNumericGo -func doStaticCast[InT, OutT numeric](in []InT, out []OutT) { +func DoStaticCast[InT, OutT numeric](in []InT, out []OutT) { for i, v := range in { out[i] = OutT(v) } @@ -37,25 +37,25 @@ func reinterpret[T numeric](b []byte, len int) (res []T) { func castNumberToNumberUnsafeImpl[T numeric](outT arrow.Type, in []T, out []byte) { switch outT { case arrow.INT8: - doStaticCast(in, reinterpret[int8](out, len(in))) + DoStaticCast(in, reinterpret[int8](out, len(in))) case arrow.UINT8: - doStaticCast(in, reinterpret[uint8](out, len(in))) + DoStaticCast(in, reinterpret[uint8](out, len(in))) case arrow.INT16: - doStaticCast(in, reinterpret[int16](out, len(in))) + DoStaticCast(in, reinterpret[int16](out, len(in))) case arrow.UINT16: - doStaticCast(in, reinterpret[uint16](out, len(in))) + DoStaticCast(in, reinterpret[uint16](out, len(in))) case arrow.INT32: - doStaticCast(in, reinterpret[int32](out, len(in))) + DoStaticCast(in, reinterpret[int32](out, len(in))) case arrow.UINT32: - doStaticCast(in, reinterpret[uint32](out, len(in))) + DoStaticCast(in, reinterpret[uint32](out, len(in))) case arrow.INT64: - doStaticCast(in, reinterpret[int64](out, len(in))) + DoStaticCast(in, reinterpret[int64](out, len(in))) case arrow.UINT64: - doStaticCast(in, reinterpret[uint64](out, len(in))) + DoStaticCast(in, reinterpret[uint64](out, len(in))) case arrow.FLOAT32: - doStaticCast(in, reinterpret[float32](out, len(in))) + DoStaticCast(in, reinterpret[float32](out, len(in))) case arrow.FLOAT64: - doStaticCast(in, reinterpret[float64](out, len(in))) + DoStaticCast(in, reinterpret[float64](out, len(in))) } } diff --git a/go/arrow/compute/internal/kernels/helpers.go b/go/arrow/compute/internal/kernels/helpers.go index b09a9f44089..884fcd0a9ce 100644 --- a/go/arrow/compute/internal/kernels/helpers.go +++ b/go/arrow/compute/internal/kernels/helpers.go @@ -415,7 +415,9 @@ func castNumberToNumberUnsafe(in, out *exec.ArraySpan) { return } - castNumericUnsafe(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf, out.Buffers[1].Buf, int(in.Len)) + inputOffset := in.Type.(arrow.FixedWidthDataType).Bytes() * int(in.Offset) + outputOffset := out.Type.(arrow.FixedWidthDataType).Bytes() * int(out.Offset) + castNumericUnsafe(in.Type.ID(), out.Type.ID(), in.Buffers[1].Buf[inputOffset:], out.Buffers[1].Buf[outputOffset:], int(in.Len)) } func maxDecimalDigitsForInt(id arrow.Type) (int32, error) { diff --git a/go/arrow/datatype.go b/go/arrow/datatype.go index 2ffca317e7e..2cd27cf64d7 100644 --- a/go/arrow/datatype.go +++ b/go/arrow/datatype.go @@ -192,6 +192,8 @@ type FixedWidthDataType interface { DataType // BitWidth returns the number of bits required to store a single element of this data type in memory. BitWidth() int + // Bytes returns the number of bytes required to store a single element of this data type in memory. 
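+	// For bit-packed types the size is rounded up to a whole byte; for example
+	// BooleanType has a BitWidth of 1 but reports a Bytes of 1.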
+ Bytes() int } type BinaryDataType interface { diff --git a/go/arrow/datatype_fixedwidth.go b/go/arrow/datatype_fixedwidth.go index 374baaee89f..12f752d4a45 100644 --- a/go/arrow/datatype_fixedwidth.go +++ b/go/arrow/datatype_fixedwidth.go @@ -31,6 +31,7 @@ func (t *BooleanType) ID() Type { return BOOL } func (t *BooleanType) Name() string { return "bool" } func (t *BooleanType) String() string { return "bool" } func (t *BooleanType) Fingerprint() string { return typeFingerprint(t) } +func (BooleanType) Bytes() int { return 1 } // BitWidth returns the number of bits required to store a single element of this data type in memory. func (t *BooleanType) BitWidth() int { return 1 } @@ -46,6 +47,7 @@ type FixedSizeBinaryType struct { func (*FixedSizeBinaryType) ID() Type { return FIXED_SIZE_BINARY } func (*FixedSizeBinaryType) Name() string { return "fixed_size_binary" } func (t *FixedSizeBinaryType) BitWidth() int { return 8 * t.ByteWidth } +func (t *FixedSizeBinaryType) Bytes() int { return t.ByteWidth } func (t *FixedSizeBinaryType) Fingerprint() string { return typeFingerprint(t) } func (t *FixedSizeBinaryType) String() string { return "fixed_size_binary[" + strconv.Itoa(t.ByteWidth) + "]" @@ -354,6 +356,8 @@ func (t *TimestampType) Fingerprint() string { // BitWidth returns the number of bits required to store a single element of this data type in memory. func (*TimestampType) BitWidth() int { return 64 } +func (TimestampType) Bytes() int { return Int64SizeBytes } + func (TimestampType) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(TimestampSizeBytes)}} } @@ -445,6 +449,7 @@ type Time32Type struct { func (*Time32Type) ID() Type { return TIME32 } func (*Time32Type) Name() string { return "time32" } func (*Time32Type) BitWidth() int { return 32 } +func (*Time32Type) Bytes() int { return Int32SizeBytes } func (t *Time32Type) String() string { return "time32[" + t.Unit.String() + "]" } func (t *Time32Type) Fingerprint() string { return typeFingerprint(t) + string(timeUnitFingerprint(t.Unit)) @@ -464,6 +469,7 @@ type Time64Type struct { func (*Time64Type) ID() Type { return TIME64 } func (*Time64Type) Name() string { return "time64" } func (*Time64Type) BitWidth() int { return 64 } +func (*Time64Type) Bytes() int { return Int64SizeBytes } func (t *Time64Type) String() string { return "time64[" + t.Unit.String() + "]" } func (t *Time64Type) Fingerprint() string { return typeFingerprint(t) + string(timeUnitFingerprint(t.Unit)) @@ -484,6 +490,7 @@ type DurationType struct { func (*DurationType) ID() Type { return DURATION } func (*DurationType) Name() string { return "duration" } func (*DurationType) BitWidth() int { return 64 } +func (*DurationType) Bytes() int { return Int64SizeBytes } func (t *DurationType) String() string { return "duration[" + t.Unit.String() + "]" } func (t *DurationType) Fingerprint() string { return typeFingerprint(t) + string(timeUnitFingerprint(t.Unit)) @@ -506,6 +513,8 @@ func (t *Float16Type) Fingerprint() string { return typeFingerprint(t) } // BitWidth returns the number of bits required to store a single element of this data type in memory. 
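 // For half-precision floats this is 16; the Bytes method added below reports
 // Float16SizeBytes (2).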
func (t *Float16Type) BitWidth() int { return 16 } +func (Float16Type) Bytes() int { return Float16SizeBytes } + func (Float16Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(Float16SizeBytes)}} } @@ -519,6 +528,7 @@ type Decimal128Type struct { func (*Decimal128Type) ID() Type { return DECIMAL128 } func (*Decimal128Type) Name() string { return "decimal" } func (*Decimal128Type) BitWidth() int { return 128 } +func (*Decimal128Type) Bytes() int { return Decimal128SizeBytes } func (t *Decimal128Type) String() string { return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale) } @@ -539,6 +549,7 @@ type Decimal256Type struct { func (*Decimal256Type) ID() Type { return DECIMAL256 } func (*Decimal256Type) Name() string { return "decimal256" } func (*Decimal256Type) BitWidth() int { return 256 } +func (*Decimal256Type) Bytes() int { return Decimal256SizeBytes } func (t *Decimal256Type) String() string { return fmt.Sprintf("%s(%d, %d)", t.Name(), t.Precision, t.Scale) } @@ -583,6 +594,7 @@ func (*MonthIntervalType) Fingerprint() string { return typeIDFingerprint(INTERV // BitWidth returns the number of bits required to store a single element of this data type in memory. func (t *MonthIntervalType) BitWidth() int { return 32 } +func (MonthIntervalType) Bytes() int { return Int32SizeBytes } func (MonthIntervalType) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(MonthIntervalSizeBytes)}} } @@ -605,6 +617,7 @@ func (*DayTimeIntervalType) Fingerprint() string { return typeIDFingerprint(INTE // BitWidth returns the number of bits required to store a single element of this data type in memory. func (t *DayTimeIntervalType) BitWidth() int { return 64 } +func (DayTimeIntervalType) Bytes() int { return DayTimeIntervalSizeBytes } func (DayTimeIntervalType) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(DayTimeIntervalSizeBytes)}} } @@ -630,7 +643,7 @@ func (*MonthDayNanoIntervalType) Fingerprint() string { // BitWidth returns the number of bits required to store a single element of this data type in memory. 
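 // For a month-day-nano interval this is 128: a 4-byte month, 4-byte day, and
 // 8-byte nanosecond field per value.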
func (*MonthDayNanoIntervalType) BitWidth() int { return 128 }
-
+func (*MonthDayNanoIntervalType) Bytes() int { return MonthDayNanoIntervalSizeBytes }
 func (MonthDayNanoIntervalType) Layout() DataTypeLayout {
 	return DataTypeLayout{Buffers: []BufferSpec{SpecBitmap(), SpecFixedWidth(MonthDayNanoIntervalSizeBytes)}}
 }
@@ -701,6 +714,7 @@ type DictionaryType struct {
 func (*DictionaryType) ID() Type        { return DICTIONARY }
 func (*DictionaryType) Name() string    { return "dictionary" }
 func (d *DictionaryType) BitWidth() int { return d.IndexType.(FixedWidthDataType).BitWidth() }
+func (d *DictionaryType) Bytes() int    { return d.IndexType.(FixedWidthDataType).Bytes() }
 func (d *DictionaryType) String() string {
 	return fmt.Sprintf("%s<values=%s, indices=%s, ordered=%v>",
 		d.Name(), d.ValueType, d.IndexType, d.Ordered)
diff --git a/go/arrow/datatype_numeric.gen.go b/go/arrow/datatype_numeric.gen.go
index dfcdab5924f..62cbd90016f 100644
--- a/go/arrow/datatype_numeric.gen.go
+++ b/go/arrow/datatype_numeric.gen.go
@@ -24,6 +24,7 @@ func (t *Int8Type) ID() Type { return INT8 }
 func (t *Int8Type) Name() string { return "int8" }
 func (t *Int8Type) String() string { return "int8" }
 func (t *Int8Type) BitWidth() int { return 8 }
+func (t *Int8Type) Bytes() int { return Int8SizeBytes }
 func (t *Int8Type) Fingerprint() string { return typeFingerprint(t) }
 func (t *Int8Type) Layout() DataTypeLayout {
 	return DataTypeLayout{Buffers: []BufferSpec{
@@ -36,6 +37,7 @@ func (t *Int16Type) ID() Type { return INT16 }
 func (t *Int16Type) Name() string { return "int16" }
 func (t *Int16Type) String() string { return "int16" }
 func (t *Int16Type) BitWidth() int { return 16 }
+func (t *Int16Type) Bytes() int { return Int16SizeBytes }
 func (t *Int16Type) Fingerprint() string { return typeFingerprint(t) }
 func (t *Int16Type) Layout() DataTypeLayout {
 	return DataTypeLayout{Buffers: []BufferSpec{
@@ -48,6 +50,7 @@ func (t *Int32Type) ID() Type { return INT32 }
 func (t *Int32Type) Name() string { return "int32" }
 func (t *Int32Type) String() string { return "int32" }
 func (t *Int32Type) BitWidth() int { return 32 }
+func (t *Int32Type) Bytes() int { return Int32SizeBytes }
 func (t *Int32Type) Fingerprint() string { return typeFingerprint(t) }
 func (t *Int32Type) Layout() DataTypeLayout {
 	return DataTypeLayout{Buffers: []BufferSpec{
@@ -60,6 +63,7 @@ func (t *Int64Type) ID() Type { return INT64 }
 func (t *Int64Type) Name() string { return "int64" }
 func (t *Int64Type) String() string { return "int64" }
 func (t *Int64Type) BitWidth() int { return 64 }
+func (t *Int64Type) Bytes() int { return Int64SizeBytes }
 func (t *Int64Type) Fingerprint() string { return typeFingerprint(t) }
 func (t *Int64Type) Layout() DataTypeLayout {
 	return DataTypeLayout{Buffers: []BufferSpec{
@@ -72,6 +76,7 @@ func (t *Uint8Type) ID() Type { return UINT8 }
 func (t *Uint8Type) Name() string { return "uint8" }
 func (t *Uint8Type) String() string { return "uint8" }
 func (t *Uint8Type) BitWidth() int { return 8 }
+func (t *Uint8Type) Bytes() int { return Uint8SizeBytes }
 func (t *Uint8Type) Fingerprint() string { return typeFingerprint(t) }
 func (t *Uint8Type) Layout() DataTypeLayout {
 	return DataTypeLayout{Buffers: []BufferSpec{
@@ -84,6 +89,7 @@ func (t *Uint16Type) ID() Type { return UINT16 }
 func (t *Uint16Type) Name() string { return "uint16" }
 func (t *Uint16Type) String() string { return "uint16" }
 func (t *Uint16Type) BitWidth() int { return 16 }
+func (t *Uint16Type) Bytes() int { return Uint16SizeBytes }
 func (t *Uint16Type) Fingerprint() string { return typeFingerprint(t) }
 func (t
*Uint16Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ @@ -96,6 +102,7 @@ func (t *Uint32Type) ID() Type { return UINT32 } func (t *Uint32Type) Name() string { return "uint32" } func (t *Uint32Type) String() string { return "uint32" } func (t *Uint32Type) BitWidth() int { return 32 } +func (t *Uint32Type) Bytes() int { return Uint32SizeBytes } func (t *Uint32Type) Fingerprint() string { return typeFingerprint(t) } func (t *Uint32Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ @@ -108,6 +115,7 @@ func (t *Uint64Type) ID() Type { return UINT64 } func (t *Uint64Type) Name() string { return "uint64" } func (t *Uint64Type) String() string { return "uint64" } func (t *Uint64Type) BitWidth() int { return 64 } +func (t *Uint64Type) Bytes() int { return Uint64SizeBytes } func (t *Uint64Type) Fingerprint() string { return typeFingerprint(t) } func (t *Uint64Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ @@ -120,6 +128,7 @@ func (t *Float32Type) ID() Type { return FLOAT32 } func (t *Float32Type) Name() string { return "float32" } func (t *Float32Type) String() string { return "float32" } func (t *Float32Type) BitWidth() int { return 32 } +func (t *Float32Type) Bytes() int { return Float32SizeBytes } func (t *Float32Type) Fingerprint() string { return typeFingerprint(t) } func (t *Float32Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ @@ -132,6 +141,7 @@ func (t *Float64Type) ID() Type { return FLOAT64 } func (t *Float64Type) Name() string { return "float64" } func (t *Float64Type) String() string { return "float64" } func (t *Float64Type) BitWidth() int { return 64 } +func (t *Float64Type) Bytes() int { return Float64SizeBytes } func (t *Float64Type) Fingerprint() string { return typeFingerprint(t) } func (t *Float64Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ @@ -144,6 +154,7 @@ func (t *Date32Type) ID() Type { return DATE32 } func (t *Date32Type) Name() string { return "date32" } func (t *Date32Type) String() string { return "date32" } func (t *Date32Type) BitWidth() int { return 32 } +func (t *Date32Type) Bytes() int { return Date32SizeBytes } func (t *Date32Type) Fingerprint() string { return typeFingerprint(t) } func (t *Date32Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ @@ -156,6 +167,7 @@ func (t *Date64Type) ID() Type { return DATE64 } func (t *Date64Type) Name() string { return "date64" } func (t *Date64Type) String() string { return "date64" } func (t *Date64Type) BitWidth() int { return 64 } +func (t *Date64Type) Bytes() int { return Date64SizeBytes } func (t *Date64Type) Fingerprint() string { return typeFingerprint(t) } func (t *Date64Type) Layout() DataTypeLayout { return DataTypeLayout{Buffers: []BufferSpec{ diff --git a/go/arrow/datatype_numeric.gen.go.tmpl b/go/arrow/datatype_numeric.gen.go.tmpl index a784619bd15..611046afc42 100644 --- a/go/arrow/datatype_numeric.gen.go.tmpl +++ b/go/arrow/datatype_numeric.gen.go.tmpl @@ -23,6 +23,7 @@ func (t *{{.Name}}Type) ID() Type { return {{.Name|upper}} } func (t *{{.Name}}Type) Name() string { return "{{.Name|lower}}" } func (t *{{.Name}}Type) String() string { return "{{.Name|lower}}" } func (t *{{.Name}}Type) BitWidth() int { return {{.Size}} } +func (t *{{.Name}}Type) Bytes() int { return {{.Name}}SizeBytes } func (t *{{.Name}}Type) Fingerprint() string { return typeFingerprint(t) } func (t *{{.Name}}Type) Layout() DataTypeLayout { return 
DataTypeLayout{Buffers: []BufferSpec{ diff --git a/go/arrow/scalar/binary.go b/go/arrow/scalar/binary.go index 19ff6850475..5852d2afb03 100644 --- a/go/arrow/scalar/binary.go +++ b/go/arrow/scalar/binary.go @@ -40,8 +40,18 @@ type Binary struct { Value *memory.Buffer } -func (b *Binary) Retain() { b.Value.Retain() } -func (b *Binary) Release() { b.Value.Release() } +func (b *Binary) Retain() { + if b.Value != nil { + b.Value.Retain() + } +} + +func (b *Binary) Release() { + if b.Value != nil { + b.Value.Release() + } +} + func (b *Binary) value() interface{} { return b.Value } func (b *Binary) Data() []byte { return b.Value.Bytes() } func (b *Binary) equals(rhs Scalar) bool { diff --git a/go/arrow/scalar/nested.go b/go/arrow/scalar/nested.go index 4a2e99bcf38..c0124840c1f 100644 --- a/go/arrow/scalar/nested.go +++ b/go/arrow/scalar/nested.go @@ -40,8 +40,18 @@ type List struct { Value arrow.Array } -func (l *List) Release() { l.Value.Release() } -func (l *List) Retain() { l.Value.Retain() } +func (l *List) Release() { + if l.Value != nil { + l.Value.Release() + } +} + +func (l *List) Retain() { + if l.Value != nil { + l.Value.Retain() + } +} + func (l *List) value() interface{} { return l.Value } func (l *List) GetList() arrow.Array { return l.Value } func (l *List) equals(rhs Scalar) bool { @@ -258,8 +268,11 @@ func (s *Struct) Validate() (err error) { } if !s.Valid { - if len(s.Value) != 0 { - err = fmt.Errorf("%s scalar is marked null but has child values", s.Type) + for _, v := range s.Value { + if v.IsValid() { + err = fmt.Errorf("%s scalar is marked null but has child values", s.Type) + return + } } return } @@ -293,8 +306,11 @@ func (s *Struct) ValidateFull() (err error) { } if !s.Valid { - if len(s.Value) != 0 { - err = fmt.Errorf("%s scalar is marked null but has child values", s.Type) + for _, v := range s.Value { + if v.IsValid() { + err = fmt.Errorf("%s scalar is marked null but has child values", s.Type) + return + } } return } @@ -550,6 +566,14 @@ func (s *SparseUnion) String() string { return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + val.String() + "}" } +func (s *SparseUnion) Retain() { + for _, v := range s.Value { + if v, ok := v.(Releasable); ok { + v.Retain() + } + } +} + func (s *SparseUnion) Release() { for _, v := range s.Value { if v, ok := v.(Releasable); ok { @@ -662,6 +686,12 @@ func (s *DenseUnion) String() string { return "union{" + dt.Fields()[dt.ChildIDs()[s.TypeCode]].String() + " = " + s.Value.String() + "}" } +func (s *DenseUnion) Retain() { + if v, ok := s.Value.(Releasable); ok { + v.Retain() + } +} + func (s *DenseUnion) Release() { if v, ok := s.Value.(Releasable); ok { v.Release() diff --git a/go/arrow/scalar/scalar.go b/go/arrow/scalar/scalar.go index 14176443621..89ca6fb6610 100644 --- a/go/arrow/scalar/scalar.go +++ b/go/arrow/scalar/scalar.go @@ -519,7 +519,14 @@ func init() { arrow.INTERVAL_MONTH_DAY_NANO: func(dt arrow.DataType) Scalar { return &MonthDayNanoInterval{scalar: scalar{dt, false}} }, arrow.DECIMAL128: func(dt arrow.DataType) Scalar { return &Decimal128{scalar: scalar{dt, false}} }, arrow.LIST: func(dt arrow.DataType) Scalar { return &List{scalar: scalar{dt, false}} }, - arrow.STRUCT: func(dt arrow.DataType) Scalar { return &Struct{scalar: scalar{dt, false}} }, + arrow.STRUCT: func(dt arrow.DataType) Scalar { + typ := dt.(*arrow.StructType) + values := make([]Scalar, len(typ.Fields())) + for i, f := range typ.Fields() { + values[i] = MakeNullScalar(f.Type) + } + return &Struct{scalar: scalar{dt, false}, 
Value: values} + }, arrow.SPARSE_UNION: func(dt arrow.DataType) Scalar { typ := dt.(*arrow.SparseUnionType) if len(typ.Fields()) == 0 { @@ -631,6 +638,11 @@ func GetScalar(arr arrow.Array, idx int) (Scalar, error) { slice := array.NewSlice(arr.ListValues(), int64(offsets[idx]), int64(offsets[idx+1])) defer slice.Release() return NewListScalar(slice), nil + case *array.LargeList: + offsets := arr.Offsets() + slice := array.NewSlice(arr.ListValues(), int64(offsets[idx]), int64(offsets[idx+1])) + defer slice.Release() + return NewLargeListScalar(slice), nil case *array.Map: offsets := arr.Offsets() slice := array.NewSlice(arr.ListValues(), int64(offsets[idx]), int64(offsets[idx+1])) From 47314c3999d7b7a7f9167c6ed6793da756c411a1 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 7 Sep 2022 18:48:55 -0400 Subject: [PATCH 014/133] ARROW-17646: [Go][CI] Switch C Data to use cgo.Handle (bumps to Go1.17) (#14067) Authored-by: Matt Topol Signed-off-by: Matt Topol --- .env | 2 +- .github/workflows/go.yml | 24 ++++---- ci/docker/debian-10-go.dockerfile | 2 +- ci/docker/debian-11-go.dockerfile | 2 +- ci/scripts/go_build.sh | 2 +- dev/release/verify-release-candidate.sh | 2 +- dev/tasks/tasks.yml | 4 +- go/arrow/cdata/cdata.go | 18 ++++-- go/arrow/cdata/cdata_exports.go | 27 ++++++++- go/arrow/cdata/cdata_test.go | 35 ++++++++++++ go/arrow/cdata/cdata_test_framework.go | 10 +++- go/arrow/cdata/exports.go | 76 +++++++++++++++---------- go/arrow/cdata/interface.go | 9 +++ go/go.mod | 26 ++++++++- 14 files changed, 180 insertions(+), 59 deletions(-) diff --git a/.env b/.env index 4aa04daab04..8ae036b6b5b 100644 --- a/.env +++ b/.env @@ -58,7 +58,7 @@ CUDA=11.0.3 DASK=latest DOTNET=6.0 GCC_VERSION="" -GO=1.16 +GO=1.17 STATICCHECK=v0.2.2 HDFS=3.2.1 JDK=8 diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4112bf3bd4c..5fccebbca15 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -50,9 +50,9 @@ jobs: strategy: fail-fast: false matrix: - go: [1.16, 1.18] + go: [1.17, 1.18] include: - - go: 1.16 + - go: 1.17 staticcheck: v0.2.2 - go: 1.18 staticcheck: latest @@ -86,9 +86,9 @@ jobs: strategy: fail-fast: false matrix: - go: [1.16, 1.18] + go: [1.17, 1.18] include: - - go: 1.16 + - go: 1.17 staticcheck: v0.2.2 - go: 1.18 staticcheck: latest @@ -123,9 +123,9 @@ jobs: strategy: fail-fast: false matrix: - go: [1.16, 1.18] + go: [1.17, 1.18] include: - - go: 1.16 + - go: 1.17 staticcheck: v0.2.2 - go: 1.18 staticcheck: latest @@ -158,9 +158,9 @@ jobs: strategy: fail-fast: false matrix: - go: [1.16, 1.18] + go: [1.17, 1.18] include: - - go: 1.16 + - go: 1.17 staticcheck: v0.2.2 - go: 1.18 staticcheck: latest @@ -193,9 +193,9 @@ jobs: strategy: fail-fast: false matrix: - go: [1.16, 1.18] + go: [1.17, 1.18] include: - - go: 1.16 + - go: 1.17 staticcheck: v0.2.2 - go: 1.18 staticcheck: latest @@ -228,9 +228,9 @@ jobs: strategy: fail-fast: false matrix: - go: [1.16, 1.18] + go: [1.17, 1.18] include: - - go: 1.16 + - go: 1.17 staticcheck: v0.2.2 - go: 1.18 staticcheck: latest diff --git a/ci/docker/debian-10-go.dockerfile b/ci/docker/debian-10-go.dockerfile index dfe81f5a73c..8d964c76a66 100644 --- a/ci/docker/debian-10-go.dockerfile +++ b/ci/docker/debian-10-go.dockerfile @@ -16,7 +16,7 @@ # under the License. 
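The `runtime/cgo.Handle` type that motivates this Go 1.17 bump is what replaces the hand-rolled `sync.Map` handle registry removed from `exports.go` further down. A minimal standalone sketch of the pattern, independent of the Arrow code:

```go
package main

import (
	"fmt"
	"runtime/cgo"
)

type payload struct{ name string }

func main() {
	// NewHandle pins a Go value behind an opaque integer handle that can be
	// passed through C code without violating the cgo pointer-passing rules.
	h := cgo.NewHandle(&payload{name: "arrow"})

	// Value recovers the original Go value on the way back from C.
	p := h.Value().(*payload)
	fmt.Println(p.name)

	// Delete invalidates the handle so the value can be collected again.
	h.Delete()
}
```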
ARG arch=amd64 -ARG go=1.16 +ARG go=1.17 ARG staticcheck=v0.2.2 FROM ${arch}/golang:${go}-buster diff --git a/ci/docker/debian-11-go.dockerfile b/ci/docker/debian-11-go.dockerfile index 32d7b3af390..9f75bf23fdd 100644 --- a/ci/docker/debian-11-go.dockerfile +++ b/ci/docker/debian-11-go.dockerfile @@ -16,7 +16,7 @@ # under the License. ARG arch=amd64 -ARG go=1.16 +ARG go=1.17 ARG staticcheck=v0.2.2 FROM ${arch}/golang:${go}-bullseye diff --git a/ci/scripts/go_build.sh b/ci/scripts/go_build.sh index 43f348b1538..c113bbd320e 100755 --- a/ci/scripts/go_build.sh +++ b/ci/scripts/go_build.sh @@ -22,7 +22,7 @@ set -ex source_dir=${1}/go ARCH=`uname -m` -# Arm64 CI is triggered by travis and run in arm64v8/golang:1.16-bullseye +# Arm64 CI is triggered by travis and run in arm64v8/golang:1.17-bullseye if [ "aarch64" == "$ARCH" ]; then # Install `staticcheck` GO111MODULE=on go install honnef.co/go/tools/cmd/staticcheck@v0.2.2 diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index eb44e3e4fec..b016988ba91 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -399,7 +399,7 @@ install_go() { return 0 fi - local version=1.16.12 + local version=1.17.13 show_info "Installing go version ${version}..." local arch="$(uname -m)" diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 0816c24589e..ae3c613902b 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -1449,13 +1449,13 @@ tasks: ci: github template: r/github.linux.revdepcheck.yml - test-debian-11-go-1.16: + test-debian-11-go-1.17: ci: azure template: docker-tests/azure.linux.yml params: env: DEBIAN: 11 - GO: 1.16 + GO: 1.17 image: debian-go test-ubuntu-default-docs: diff --git a/go/arrow/cdata/cdata.go b/go/arrow/cdata/cdata.go index a2b583f268e..aec166a0110 100644 --- a/go/arrow/cdata/cdata.go +++ b/go/arrow/cdata/cdata.go @@ -27,7 +27,11 @@ package cdata // int stream_get_schema(struct ArrowArrayStream* st, struct ArrowSchema* out) { return st->get_schema(st, out); } // int stream_get_next(struct ArrowArrayStream* st, struct ArrowArray* out) { return st->get_next(st, out); } // const char* stream_get_last_error(struct ArrowArrayStream* st) { return st->get_last_error(st); } -// struct ArrowArray* get_arr() { return (struct ArrowArray*)(malloc(sizeof(struct ArrowArray))); } +// struct ArrowArray* get_arr() { +// struct ArrowArray* out = (struct ArrowArray*)(malloc(sizeof(struct ArrowArray))); +// memset(out, 0, sizeof(struct ArrowArray)); +// return out; +// } // struct ArrowArrayStream* get_stream() { return (struct ArrowArrayStream*)malloc(sizeof(struct ArrowArrayStream)); } // import "C" @@ -655,18 +659,22 @@ func importCArrayAsType(arr *CArrowArray, dt arrow.DataType) (imp *cimporter, er func initReader(rdr *nativeCRecordBatchReader, stream *CArrowArrayStream) { rdr.stream = C.get_stream() C.ArrowArrayStreamMove(stream, rdr.stream) + rdr.arr = C.get_arr() runtime.SetFinalizer(rdr, func(r *nativeCRecordBatchReader) { if r.cur != nil { r.cur.Release() } C.ArrowArrayStreamRelease(r.stream) + C.ArrowArrayRelease(r.arr) C.free(unsafe.Pointer(r.stream)) + C.free(unsafe.Pointer(r.arr)) }) } // Record Batch reader that conforms to arrio.Reader for the ArrowArrayStream interface type nativeCRecordBatchReader struct { stream *CArrowArrayStream + arr *CArrowArray schema *arrow.Schema cur arrow.Record @@ -713,18 +721,16 @@ func (n *nativeCRecordBatchReader) next() error { n.cur = nil } - arr := C.get_arr() - defer C.free(unsafe.Pointer(arr)) - errno 
:= C.stream_get_next(n.stream, arr) + errno := C.stream_get_next(n.stream, n.arr) if errno != 0 { return n.getError(int(errno)) } - if C.ArrowArrayIsReleased(arr) == 1 { + if C.ArrowArrayIsReleased(n.arr) == 1 { return io.EOF } - rec, err := ImportCRecordBatchWithSchema(arr, n.schema) + rec, err := ImportCRecordBatchWithSchema(n.arr, n.schema) if err != nil { return err } diff --git a/go/arrow/cdata/cdata_exports.go b/go/arrow/cdata/cdata_exports.go index a3da68447db..b69d44d9b50 100644 --- a/go/arrow/cdata/cdata_exports.go +++ b/go/arrow/cdata/cdata_exports.go @@ -36,6 +36,7 @@ import ( "encoding/binary" "fmt" "reflect" + "runtime/cgo" "strings" "unsafe" @@ -362,7 +363,9 @@ func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) { out.buffers = (*unsafe.Pointer)(unsafe.Pointer(&buffers[0])) } - out.private_data = unsafe.Pointer(storeData(arr.Data())) + arr.Data().Retain() + h := cgo.NewHandle(arr.Data()) + out.private_data = unsafe.Pointer(&h) out.release = (*[0]byte)(C.goReleaseArray) switch arr := arr.(type) { case *array.List: @@ -400,3 +403,25 @@ func exportArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema) { out.children = nil } } + +type cRecordReader struct { + rdr array.RecordReader +} + +func (rr cRecordReader) getSchema(out *CArrowSchema) int { + ExportArrowSchema(rr.rdr.Schema(), out) + return 0 +} + +func (rr cRecordReader) next(out *CArrowArray) int { + if rr.rdr.Next() { + ExportArrowRecordBatch(rr.rdr.Record(), out, nil) + return 0 + } + releaseArr(out) + return 0 +} + +func (rr cRecordReader) release() { + rr.rdr.Release() +} diff --git a/go/arrow/cdata/cdata_test.go b/go/arrow/cdata/cdata_test.go index 0b73a08d6b0..a143b5cb79a 100644 --- a/go/arrow/cdata/cdata_test.go +++ b/go/arrow/cdata/cdata_test.go @@ -27,6 +27,7 @@ import ( "errors" "io" "runtime" + "runtime/cgo" "testing" "time" "unsafe" @@ -34,6 +35,7 @@ import ( "github.com/apache/arrow/go/v10/arrow" "github.com/apache/arrow/go/v10/arrow/array" "github.com/apache/arrow/go/v10/arrow/decimal128" + "github.com/apache/arrow/go/v10/arrow/internal/arrdata" "github.com/apache/arrow/go/v10/arrow/memory" "github.com/stretchr/testify/assert" ) @@ -659,3 +661,36 @@ func TestRecordReaderStream(t *testing.T) { assert.Equal(t, "baz", rec.Column(1).(*array.String).Value(2)) } } + +func TestExportRecordReaderStream(t *testing.T) { + reclist := arrdata.Records["primitives"] + rdr, _ := array.NewRecordReader(reclist[0].Schema(), reclist) + + out := createTestStreamObj() + ExportRecordReader(rdr, out) + + assert.NotNil(t, out.get_schema) + assert.NotNil(t, out.get_next) + assert.NotNil(t, out.get_last_error) + assert.NotNil(t, out.release) + assert.NotNil(t, out.private_data) + + h := *(*cgo.Handle)(out.private_data) + assert.Same(t, rdr, h.Value().(cRecordReader).rdr) + + importedRdr := ImportCArrayStream(out, nil) + i := 0 + for { + rec, err := importedRdr.Read() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + assert.NoError(t, err) + } + + assert.Truef(t, array.RecordEqual(reclist[i], rec), "expected: %s\ngot: %s", reclist[i], rec) + i++ + } + assert.EqualValues(t, len(reclist), i) +} diff --git a/go/arrow/cdata/cdata_test_framework.go b/go/arrow/cdata/cdata_test_framework.go index bb4db1e339b..0274b01fb73 100644 --- a/go/arrow/cdata/cdata_test_framework.go +++ b/go/arrow/cdata/cdata_test_framework.go @@ -26,7 +26,11 @@ package cdata // // void setup_array_stream_test(const int n_batches, struct ArrowArrayStream* out); // struct ArrowArray* get_test_arr() { return (struct 
ArrowArray*)(malloc(sizeof(struct ArrowArray))); } -// struct ArrowArrayStream* get_test_stream() { return (struct ArrowArrayStream*)malloc(sizeof(struct ArrowArrayStream)); } +// struct ArrowArrayStream* get_test_stream() { +// struct ArrowArrayStream* out = (struct ArrowArrayStream*)malloc(sizeof(struct ArrowArrayStream)); +// memset(out, 0, sizeof(struct ArrowArrayStream)); +// return out; +// } // // void release_test_arr(struct ArrowArray* arr) { // for (int i = 0; i < arr->n_buffers; ++i) { @@ -251,6 +255,10 @@ func createCArr(arr arrow.Array) *CArrowArray { return carr } +func createTestStreamObj() *CArrowArrayStream { + return C.get_test_stream() +} + func arrayStreamTest() *CArrowArrayStream { st := C.get_test_stream() C.setup_array_stream_test(2, st) diff --git a/go/arrow/cdata/exports.go b/go/arrow/cdata/exports.go index 4ad4b7fac31..c7d77a52a72 100644 --- a/go/arrow/cdata/exports.go +++ b/go/arrow/cdata/exports.go @@ -18,42 +18,24 @@ package cdata import ( "reflect" - "sync" - "sync/atomic" + "runtime/cgo" "unsafe" "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/array" ) // #include // #include "arrow/c/helpers.h" +// +// typedef const char cchar_t; +// extern int streamGetSchema(struct ArrowArrayStream*, struct ArrowSchema*); +// extern int streamGetNext(struct ArrowArrayStream*, struct ArrowArray*); +// extern const char* streamGetError(struct ArrowArrayStream*); +// extern void streamRelease(struct ArrowArrayStream*); +// import "C" -var ( - handles = sync.Map{} - handleIdx uintptr -) - -type dataHandle uintptr - -func storeData(d arrow.ArrayData) dataHandle { - h := atomic.AddUintptr(&handleIdx, 1) - if h == 0 { - panic("cgo: ran out of space") - } - d.Retain() - handles.Store(h, d) - return dataHandle(h) -} - -func (d dataHandle) releaseData() { - arrd, ok := handles.LoadAndDelete(uintptr(d)) - if !ok { - panic("cgo: invalid datahandle") - } - arrd.(arrow.ArrayData).Release() -} - //export releaseExportedSchema func releaseExportedSchema(schema *CArrowSchema) { if C.ArrowSchemaIsReleased(schema) == 1 { @@ -108,6 +90,42 @@ func releaseExportedArray(arr *CArrowArray) { C.free(unsafe.Pointer(arr.children)) } - h := dataHandle(arr.private_data) - h.releaseData() + h := *(*cgo.Handle)(arr.private_data) + h.Value().(arrow.ArrayData).Release() + h.Delete() +} + +//export streamGetSchema +func streamGetSchema(handle *CArrowArrayStream, out *CArrowSchema) C.int { + h := *(*cgo.Handle)(handle.private_data) + rdr := h.Value().(cRecordReader) + return C.int(rdr.getSchema(out)) +} + +//export streamGetNext +func streamGetNext(handle *CArrowArrayStream, out *CArrowArray) C.int { + h := *(*cgo.Handle)(handle.private_data) + rdr := h.Value().(cRecordReader) + return C.int(rdr.next(out)) +} + +//export streamGetError +func streamGetError(*CArrowArrayStream) *C.cchar_t { return nil } + +//export streamRelease +func streamRelease(handle *CArrowArrayStream) { + h := *(*cgo.Handle)(handle.private_data) + h.Value().(cRecordReader).release() + h.Delete() + handle.release = nil + handle.private_data = nil +} + +func exportStream(rdr array.RecordReader, out *CArrowArrayStream) { + out.get_schema = (*[0]byte)(C.streamGetSchema) + out.get_next = (*[0]byte)(C.streamGetNext) + out.get_last_error = (*[0]byte)(C.streamGetError) + out.release = (*[0]byte)(C.streamRelease) + h := cgo.NewHandle(cRecordReader{rdr}) + out.private_data = unsafe.Pointer(&h) } diff --git a/go/arrow/cdata/interface.go b/go/arrow/cdata/interface.go index e567ce599a4..9b80b7c2f0d 100644 --- 
a/go/arrow/cdata/interface.go +++ b/go/arrow/cdata/interface.go @@ -225,6 +225,15 @@ func ExportArrowArray(arr arrow.Array, out *CArrowArray, outSchema *CArrowSchema exportArray(arr, out, outSchema) } +// ExportRecordReader populates the CArrowArrayStream that is passed in with the appropriate +// callbacks to be a working ArrowArrayStream utilizing the passed in RecordReader. The +// CArrowArrayStream takes ownership of the RecordReader until the consumer calls the release +// callback; as such, it is unnecessary to call Release on the passed in reader unless it has +// previously been retained. +func ExportRecordReader(reader array.RecordReader, out *CArrowArrayStream) { + exportStream(reader, out) +} + // ReleaseCArrowArray calls ArrowArrayRelease on the passed in cdata array func ReleaseCArrowArray(arr *CArrowArray) { releaseArr(arr) } diff --git a/go/go.mod b/go/go.mod index 9e7054b8d52..25ca1e084c7 100644 --- a/go/go.mod +++ b/go/go.mod @@ -16,7 +16,7 @@ module github.com/apache/arrow/go/v10 -go 1.16 +go 1.17 require ( github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c @@ -28,7 +28,6 @@ require ( github.com/google/flatbuffers v2.0.8+incompatible github.com/klauspost/asmfmt v1.3.2 github.com/klauspost/compress v1.15.9 - github.com/kr/pretty v0.3.0 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 github.com/pierrec/lz4/v4 v4.1.15 @@ -42,12 +41,33 @@ require ( gonum.org/v1/gonum v0.11.0 google.golang.org/grpc v1.49.0 google.golang.org/protobuf v1.28.1 + modernc.org/sqlite v1.18.1 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/google/uuid v1.3.0 // indirect + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/kr/pretty v0.3.0 // indirect + github.com/mattn/go-isatty v0.0.16 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0 // indirect + github.com/stretchr/objx v0.4.0 // indirect + golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect + golang.org/x/net v0.0.0-20220722155237-a158d28d115b // indirect + golang.org/x/text v0.3.7 // indirect + google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/uint128 v1.2.0 // indirect modernc.org/cc/v3 v3.36.3 // indirect modernc.org/ccgo/v3 v3.16.9 // indirect modernc.org/libc v1.17.1 // indirect + modernc.org/mathutil v1.5.0 // indirect + modernc.org/memory v1.2.1 // indirect modernc.org/opt v0.1.3 // indirect - modernc.org/sqlite v1.18.1 modernc.org/strutil v1.1.3 // indirect + modernc.org/token v1.0.0 // indirect ) From 6ff522433bc3a3739e50e4f323cf521be1ad83dd Mon Sep 17 00:00:00 2001 From: Kai Fricke Date: Thu, 8 Sep 2022 08:14:22 +0100 Subject: [PATCH 015/133] ARROW-17641: [Python] Fix ParseOptions deserialization of invalid_row_handler (#14061) Authored-by: Kai Fricke Signed-off-by: Joris Van den Bossche --- python/pyarrow/_csv.pyx | 4 ++-- python/pyarrow/tests/test_csv.py | 10 ++++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index d1db03c75f1..16bd0985e23 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -550,12 +550,12 @@ cdef class
ParseOptions(_Weakrefable): def __getstate__(self): return (self.delimiter, self.quote_char, self.double_quote, self.escape_char, self.newlines_in_values, - self.ignore_empty_lines, self._invalid_row_handler) + self.ignore_empty_lines, self.invalid_row_handler) def __setstate__(self, state): (self.delimiter, self.quote_char, self.double_quote, self.escape_char, self.newlines_in_values, - self.ignore_empty_lines, self._invalid_row_handler) = state + self.ignore_empty_lines, self.invalid_row_handler) = state def __eq__(self, other): try: diff --git a/python/pyarrow/tests/test_csv.py b/python/pyarrow/tests/test_csv.py index 3be6f07bfec..c8ae9ff15d1 100644 --- a/python/pyarrow/tests/test_csv.py +++ b/python/pyarrow/tests/test_csv.py @@ -654,6 +654,16 @@ def row_num(x): expected_rows = [InvalidRow(2, 1, row_num(2), "c")] assert parse_opts.invalid_row_handler.rows == expected_rows + + # Test ser/de + parse_opts.invalid_row_handler = InvalidRowHandler('skip') + parse_opts = pickle.loads(pickle.dumps(parse_opts)) + + table = self.read_bytes(rows, parse_options=parse_opts) + assert table.to_pydict() == { + 'a': ["d", "i"], + 'b': ["e", "j"], + } + class BaseCSVTableRead(BaseTestCSV): From 43670af02f0913580fd20e26006fd550d6fdf2da Mon Sep 17 00:00:00 2001 From: Joost Hoozemans Date: Thu, 8 Sep 2022 10:36:48 +0200 Subject: [PATCH 016/133] ARROW-17583: [C++][Python] Changed data width of WrittenFile.size to int64 to match C++ code (#14032) To fix an exception while writing large parquet files: ``` Traceback (most recent call last): File "pyarrow/_dataset_parquet.pyx", line 165, in pyarrow._dataset_parquet.ParquetFileFormat._finish_write File "pyarrow/_dataset.pyx", line 2695, in pyarrow._dataset.WrittenFile.__init__ OverflowError: value too large to convert to int Exception ignored in: 'pyarrow._dataset._filesystemdataset_write_visitor' ``` Authored-by: Joost Hoozemans Signed-off-by: Joris Van den Bossche --- python/pyarrow/_dataset.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd index 8e5501fa16f..a512477d501 100644 --- a/python/pyarrow/_dataset.pxd +++ b/python/pyarrow/_dataset.pxd @@ -161,4 +161,4 @@ cdef class WrittenFile(_Weakrefable): # the written file.
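For the `ParseOptions` fix in the previous patch, the user-visible effect is that pickling now round-trips the public `invalid_row_handler` property rather than the private attribute. A minimal sketch; the handler body is illustrative only:

```python
import pickle

import pyarrow.csv as csv


def skip_invalid_row(row):
    # Drop rows with an unexpected column count; returning 'error'
    # would re-raise the parse error instead.
    return 'skip'


opts = csv.ParseOptions(invalid_row_handler=skip_invalid_row)

# Before the fix the handler was restored through the wrong attribute;
# after it, the deserialized options carry a working handler.
restored = pickle.loads(pickle.dumps(opts))
assert restored.invalid_row_handler is not None
```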
cdef public object metadata # The size of the file in bytes - cdef public int size + cdef public int64_t size From df121b7feec92464a4e97fe535a864537a16be1b Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Thu, 8 Sep 2022 18:04:08 +0530 Subject: [PATCH 017/133] ARROW-16651 : [Python] Casting Table to new schema ignores nullability of fields (#14048) ```python table = pa.table({'a': [None, 1], 'b': [None, True]}) new_schema = pa.schema([pa.field("a", "int64", nullable=True), pa.field("b", "bool", nullable=False)]) casted = table.cast(new_schema) ``` Now leads to ``` RuntimeError: Casting field 'b' with null values to non-nullable ``` Authored-by: kshitij12345 Signed-off-by: Joris Van den Bossche --- python/pyarrow/table.pxi | 3 +++ python/pyarrow/tests/test_table.py | 12 ++++++++++++ 2 files changed, 15 insertions(+) diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 931677f9848..30352bf3950 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -3401,6 +3401,9 @@ cdef class Table(_PandasConvertible): .format(self.schema.names, target_schema.names)) for column, field in zip(self.itercolumns(), target_schema): + if not field.nullable and column.null_count > 0: + raise ValueError("Casting field {!r} with null values to non-nullable" + .format(field.name)) casted = column.cast(field.type, safe=safe, options=options) newcols.append(casted) diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index c0c60da6272..fad1c0acb24 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -2192,3 +2192,15 @@ def test_table_join_many_columns(): "col6": ["A", "B", None, "Z"], "col7": ["A", "B", None, "Z"], }) + + +def test_table_cast_invalid(): + # Casting a nullable field to non-nullable should be invalid! + table = pa.table({'a': [None, 1], 'b': [None, True]}) + new_schema = pa.schema([pa.field("a", "int64", nullable=True), + pa.field("b", "bool", nullable=False)]) + with pytest.raises(ValueError): + table.cast(new_schema) + + table = pa.table({'a': [None, 1], 'b': [False, True]}) + assert table.cast(new_schema).schema == new_schema From 2920d52dce3a3c06234c1a383ceb73e7f160f31f Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 8 Sep 2022 10:01:41 -0400 Subject: [PATCH 018/133] ARROW-16384: [Docs] Add Flight SQL to status page (#14053) Authored-by: David Li Signed-off-by: David Li --- docs/source/status.rst | 50 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/docs/source/status.rst b/docs/source/status.rst index 64e6b6923ff..bd33a8d3540 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -207,7 +207,7 @@ Supported features in the UCX transport: Notes: -* \(1) No support for handshake or DoExchange. +* \(1) No support for Handshake or DoExchange. * \(2) Support using AspNetCore authentication handlers. * \(3) Whether a single client can support multiple concurrent calls. * \(4) Only support for DoExchange, DoGet, DoPut, and GetFlightInfo. @@ -222,6 +222,54 @@ Notes: .. _gRPC: https://grpc.io/ .. _UCX: https://openucx.org/ +Flight SQL +========== + +.. note:: Flight SQL is still experimental. + +The feature support refers to the client/server libraries only; +databases which implement the Flight SQL protocol in turn will +support/not support individual features. 
+ ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | ++============================================+=======+=======+=======+============+=======+=======+=======+ +| ClosePreparedStatement | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| CreatePreparedStatement | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetCatalogs | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetCrossReference | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetDbSchemas | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetExportedKeys | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetImportedKeys | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetPrimaryKeys | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetSqlInfo | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetTables | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetTableTypes | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| GetXdbcTypeInfo | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| PreparedStatementQuery | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| PreparedStatementUpdate | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| StatementQuery | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| StatementUpdate | ✓ | ✓ | ✓ | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ + +.. seealso:: + The :doc:`./format/FlightSql` specification. + C Data Interface ================ From 8fe7e35388a8147527037711e4262981fa81644a Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 8 Sep 2022 10:02:20 -0400 Subject: [PATCH 019/133] MINOR: [C++] ARROW_TESTING implies ARROW_JSON (#14054) Otherwise a minimal build fails with errors like ``` CMake Error at cmake_modules/BuildUtils.cmake:272 (target_link_libraries): Target "arrow_testing_objlib" links to: rapidjson::rapidjson but the target was not found. Possible reasons include: * There is a typo in the target name. * A find_package call is missing for an IMPORTED target. * An ALIAS target is missing. 
Call Stack (most recent call first): src/arrow/CMakeLists.txt:653 (add_arrow_lib) ``` Authored-by: David Li Signed-off-by: David Li --- cpp/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6a01f18e6bb..4dbdd2353fd 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -328,7 +328,6 @@ if(ARROW_BUILD_BENCHMARKS OR ARROW_BUILD_TESTS OR ARROW_BUILD_INTEGRATION OR ARROW_FUZZING) - set(ARROW_JSON ON) set(ARROW_TESTING ON) endif() @@ -366,6 +365,10 @@ if(ARROW_SKYHOOK) set(ARROW_WITH_SNAPPY ON) endif() +if(ARROW_TESTING) + set(ARROW_JSON ON) +endif() + if(ARROW_DATASET) set(ARROW_COMPUTE ON) set(ARROW_FILESYSTEM ON) From 74756051c4f6a8b13a40057f586817d56198d4ba Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Thu, 8 Sep 2022 20:41:48 +0530 Subject: [PATCH 020/133] ARROW-16855: [C++] Adding Read Relation ToProto (#13401) This is the initial PR to set up the util functions and structure needed to include `ToProto` functionality for relations. The objective here is to create a Substrait relation by interpreting what is contained in an Acero declaration. In this PR, `ToProto` for the `read` relation is added. Authored-by: Vibhatha Abeykoon Signed-off-by: Weston Pace --- .../arrow/engine/substrait/extension_set.cc | 19 ++ .../arrow/engine/substrait/plan_internal.cc | 16 + .../arrow/engine/substrait/plan_internal.h | 14 + .../engine/substrait/relation_internal.cc | 201 ++++++++++-- .../engine/substrait/relation_internal.h | 10 + cpp/src/arrow/engine/substrait/serde.cc | 17 + cpp/src/arrow/engine/substrait/serde.h | 24 ++ cpp/src/arrow/engine/substrait/serde_test.cc | 291 +++++++++++++++++- cpp/src/arrow/filesystem/localfs_test.cc | 14 +- cpp/src/arrow/filesystem/util_internal.cc | 1 + cpp/src/arrow/util/io_util.cc | 4 +- cpp/src/arrow/util/uri.cc | 10 + cpp/src/arrow/util/uri.h | 4 + 13 files changed, 588 insertions(+), 37 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 6e8522897ee..0e1f5ebc664 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -698,6 +698,20 @@ ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessOverflowableArithmetic }; } +ExtensionIdRegistry::ArrowToSubstraitCall EncodeOptionlessComparison(Id substrait_fn_id) { + return + [substrait_fn_id](const compute::Expression::Call& call) -> Result<SubstraitCall> { + // nullable=true isn't quite correct but we don't know the nullability of + // the inputs + SubstraitCall substrait_call(substrait_fn_id, call.type.GetSharedPtr(), + /*nullable=*/true); + for (std::size_t i = 0; i < call.arguments.size(); i++) { + substrait_call.SetValueArg(static_cast<uint32_t>(i), call.arguments[i]); + } + return std::move(substrait_call); + }; +} + ExtensionIdRegistry::SubstraitCallToArrow DecodeOptionlessBasicMapping( const std::string& function_name, uint32_t max_args) { return [function_name, @@ -873,6 +887,11 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl { AddArrowToSubstraitCall(std::string(fn_name) + "_checked", EncodeOptionlessOverflowableArithmetic(fn_id))); } + // Comparison operators + for (const auto& fn_name : {"equal", "is_not_distinct_from"}) { + Id fn_id{kSubstraitComparisonFunctionsUri, fn_name}; + DCHECK_OK(AddArrowToSubstraitCall(fn_name, EncodeOptionlessComparison(fn_id))); + } } }; diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc index
b0fdb9bdc2f..1efd4e1a0a9 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.cc +++ b/cpp/src/arrow/engine/substrait/plan_internal.cc @@ -17,6 +17,8 @@ #include "arrow/engine/substrait/plan_internal.h" +#include "arrow/dataset/plan.h" +#include "arrow/engine/substrait/relation_internal.h" #include "arrow/result.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" @@ -133,5 +135,19 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan, registry); } +Result<std::unique_ptr<substrait::Plan>> PlanToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto subs_plan = internal::make_unique<substrait::Plan>(); + auto plan_rel = internal::make_unique<substrait::PlanRel>(); + auto rel_root = internal::make_unique<substrait::RelRoot>(); + ARROW_ASSIGN_OR_RAISE(auto rel, ToProto(declr, ext_set, conversion_options)); + rel_root->set_allocated_input(rel.release()); + plan_rel->set_allocated_root(rel_root.release()); + subs_plan->mutable_relations()->AddAllocated(plan_rel.release()); + RETURN_NOT_OK(AddExtensionSetToPlan(*ext_set, subs_plan.get())); + return std::move(subs_plan); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/plan_internal.h b/cpp/src/arrow/engine/substrait/plan_internal.h index dce23cdceba..e1ced549ce1 100644 --- a/cpp/src/arrow/engine/substrait/plan_internal.h +++ b/cpp/src/arrow/engine/substrait/plan_internal.h @@ -19,7 +19,9 @@ #pragma once +#include "arrow/compute/exec/exec_plan.h" #include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" @@ -51,5 +53,17 @@ Result<ExtensionSet> GetExtensionSetFromPlan( const substrait::Plan& plan, const ExtensionIdRegistry* registry = default_extension_id_registry()); +/// \brief Serialize a declaration into a substrait::Plan.
+/// +/// Note that, this is a part of a roundtripping test API and not +/// designed for use in production +/// \param[in] declr the sequence of declarations to be serialized +/// \param[in, out] ext_set the extension set to be updated +/// \param[in] conversion_options options to control serialization behavior +/// \return the serialized plan +ARROW_ENGINE_EXPORT Result> PlanToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index c5c02f51558..c5d212c8c2f 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -29,8 +29,16 @@ #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/util_internal.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/uri.h" namespace arrow { + +using ::arrow::internal::UriFromAbsolutePath; +using internal::checked_cast; +using internal::make_unique; + namespace engine { template @@ -162,36 +170,45 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } path = path.substr(7); - if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath) { - ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); - if (file.type() == fs::FileType::File) { - files.push_back(std::move(file)); - } else if (file.type() == fs::FileType::Directory) { + switch (item.path_type_case()) { + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath: { + ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); + if (file.type() == fs::FileType::File) { + files.push_back(std::move(file)); + } else if (file.type() == fs::FileType::Directory) { + fs::FileSelector selector; + selector.base_dir = path; + selector.recursive = true; + ARROW_ASSIGN_OR_RAISE(auto discovered_files, + filesystem->GetFileInfo(selector)); + std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); + } + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile: { + files.emplace_back(path, fs::FileType::File); + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder: { fs::FileSelector selector; selector.base_dir = path; selector.recursive = true; ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); - std::move(files.begin(), files.end(), std::back_inserter(discovered_files)); + std::move(discovered_files.begin(), discovered_files.end(), + std::back_inserter(files)); + break; + } + case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPathGlob: { + ARROW_ASSIGN_OR_RAISE(auto discovered_files, + fs::internal::GlobFiles(filesystem, path)); + std::move(discovered_files.begin(), discovered_files.end(), + std::back_inserter(files)); + break; + } + default: { + return Status::Invalid("Unrecognized file type in LocalFiles"); } - } - if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile) { - files.emplace_back(path, fs::FileType::File); - } else if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder) { - fs::FileSelector selector; - selector.base_dir = path; - selector.recursive = true; - ARROW_ASSIGN_OR_RAISE(auto discovered_files, filesystem->GetFileInfo(selector)); - std::move(discovered_files.begin(), discovered_files.end(), - 
std::back_inserter(files)); - } else { - ARROW_ASSIGN_OR_RAISE(auto discovered_files, - fs::internal::GlobFiles(filesystem, path)); - std::move(discovered_files.begin(), discovered_files.end(), - std::back_inserter(files)); } } @@ -421,5 +438,141 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& rel.DebugString()); } +namespace { + +Result> ExtractSchemaToBind(const compute::Declaration& declr) { + std::shared_ptr bind_schema; + if (declr.factory_name == "scan") { + const auto& opts = checked_cast(*(declr.options)); + bind_schema = opts.dataset->schema(); + } else if (declr.factory_name == "filter") { + auto input_declr = util::get(declr.inputs[0]); + ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); + } else if (declr.factory_name == "sink") { + // Note that the sink has no output_schema + return bind_schema; + } else { + return Status::Invalid("Schema extraction failed, unsupported factory ", + declr.factory_name); + } + return bind_schema; +} + +Result> ScanRelationConverter( + const std::shared_ptr& schema, const compute::Declaration& declaration, + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + auto read_rel = make_unique(); + const auto& scan_node_options = + checked_cast(*declaration.options); + auto dataset = + dynamic_cast(scan_node_options.dataset.get()); + if (dataset == nullptr) { + return Status::Invalid( + "Can only convert scan node with FileSystemDataset to a Substrait plan."); + } + // set schema + ARROW_ASSIGN_OR_RAISE(auto named_struct, + ToProto(*dataset->schema(), ext_set, conversion_options)); + read_rel->set_allocated_base_schema(named_struct.release()); + + // set local files + auto read_rel_lfs = make_unique(); + for (const auto& file : dataset->files()) { + auto read_rel_lfs_ffs = make_unique(); + read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); + // set file format + auto format_type_name = dataset->format()->type_name(); + if (format_type_name == "parquet") { + read_rel_lfs_ffs->set_allocated_parquet( + new substrait::ReadRel::LocalFiles::FileOrFiles::ParquetReadOptions()); + } else if (format_type_name == "ipc") { + read_rel_lfs_ffs->set_allocated_arrow( + new substrait::ReadRel::LocalFiles::FileOrFiles::ArrowReadOptions()); + } else if (format_type_name == "orc") { + read_rel_lfs_ffs->set_allocated_orc( + new substrait::ReadRel::LocalFiles::FileOrFiles::OrcReadOptions()); + } else { + return Status::NotImplemented("Unsupported file type: ", format_type_name); + } + read_rel_lfs->mutable_items()->AddAllocated(read_rel_lfs_ffs.release()); + } + read_rel->set_allocated_local_files(read_rel_lfs.release()); + return std::move(read_rel); +} + +Result> FilterRelationConverter( + const std::shared_ptr& schema, const compute::Declaration& declaration, + ExtensionSet* ext_set, const ConversionOptions& conversion_options) { + auto filter_rel = make_unique(); + const auto& filter_node_options = + checked_cast(*(declaration.options)); + + auto filter_expr = filter_node_options.filter_expression; + compute::Expression bound_expression; + if (!filter_expr.IsBound()) { + ARROW_ASSIGN_OR_RAISE(bound_expression, filter_expr.Bind(*schema)); + } + + if (declaration.inputs.size() == 0) { + return Status::Invalid("Filter node doesn't have an input."); + } + + // handling input + auto declr_input = declaration.inputs[0]; + ARROW_ASSIGN_OR_RAISE( + auto input_rel, + ToProto(util::get(declr_input), ext_set, conversion_options)); + filter_rel->set_allocated_input(input_rel.release()); + + ARROW_ASSIGN_OR_RAISE(auto 
subs_expr, + ToProto(bound_expression, ext_set, conversion_options)); + filter_rel->set_allocated_condition(subs_expr.release()); + return std::move(filter_rel); +} + +} // namespace + +Status SerializeAndCombineRelations(const compute::Declaration& declaration, + ExtensionSet* ext_set, + std::unique_ptr* rel, + const ConversionOptions& conversion_options) { + const auto& factory_name = declaration.factory_name; + ARROW_ASSIGN_OR_RAISE(auto schema, ExtractSchemaToBind(declaration)); + // Note that the sink declaration factory doesn't exist for serialization as + // Substrait doesn't deal with a sink node definition + + if (factory_name == "scan") { + ARROW_ASSIGN_OR_RAISE( + auto read_rel, + ScanRelationConverter(schema, declaration, ext_set, conversion_options)); + (*rel)->set_allocated_read(read_rel.release()); + } else if (factory_name == "filter") { + ARROW_ASSIGN_OR_RAISE( + auto filter_rel, + FilterRelationConverter(schema, declaration, ext_set, conversion_options)); + (*rel)->set_allocated_filter(filter_rel.release()); + } else if (factory_name == "sink") { + // Generally when a plan is deserialized the declaration will be a sink declaration. + // Since there is no Sink relation in substrait, this function would be recursively + // called on the input of the Sink declaration. + auto sink_input_decl = util::get(declaration.inputs[0]); + RETURN_NOT_OK( + SerializeAndCombineRelations(sink_input_decl, ext_set, rel, conversion_options)); + } else { + return Status::NotImplemented("Factory ", factory_name, + " not implemented for roundtripping."); + } + + return Status::OK(); +} + +Result> ToProto( + const compute::Declaration& declr, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + auto rel = make_unique(); + RETURN_NOT_OK(SerializeAndCombineRelations(declr, ext_set, &rel, conversion_options)); + return std::move(rel); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/relation_internal.h b/cpp/src/arrow/engine/substrait/relation_internal.h index 3699d1f6577..778d1e5bc01 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.h +++ b/cpp/src/arrow/engine/substrait/relation_internal.h @@ -40,9 +40,19 @@ struct DeclarationInfo { int num_columns; }; +/// \brief Convert a Substrait Rel object to an Acero declaration ARROW_ENGINE_EXPORT Result FromProto(const substrait::Rel&, const ExtensionSet&, const ConversionOptions&); +/// \brief Convert an Acero Declaration to a Substrait Rel +/// +/// Note that, in order to provide a generic interface for ToProto, +/// the ExecNode or ExecPlan are not used in this context as Declaration +/// is preferred in the Substrait space rather than internal components of +/// Acero execution engine. 
+ARROW_ENGINE_EXPORT Result> ToProto( + const compute::Declaration&, ExtensionSet*, const ConversionOptions&); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 9f7d979e2f0..c6297675492 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -52,6 +52,23 @@ Result ParseFromBuffer(const Buffer& buf) { return message; } +Result> SerializePlan( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto subs_plan, + PlanToProto(declaration, ext_set, conversion_options)); + std::string serialized = subs_plan->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + +Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options) { + ARROW_ASSIGN_OR_RAISE(auto relation, ToProto(declaration, ext_set, conversion_options)); + std::string serialized = relation->SerializeAsString(); + return Buffer::FromString(std::move(serialized)); +} + Result DeserializeRelation( const Buffer& buf, const ExtensionSet& ext_set, const ConversionOptions& conversion_options) { diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 6c2083fb56a..2a14ca67570 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -36,6 +36,19 @@ namespace arrow { namespace engine { +/// \brief Serialize an Acero Plan to a binary protobuf Substrait message +/// +/// \param[in] declaration the Acero declaration to serialize. +/// This declaration is the sink relation of the Acero plan. +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// \param[in] conversion_options options to control how the conversion is done +/// +/// \return a buffer containing the protobuf serialization of the Acero relation +ARROW_ENGINE_EXPORT +Result> SerializePlan( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); + /// Factory function type for generating the node that consumes the batches produced by /// each toplevel Substrait relation when deserializing a Substrait Plan. 
using ConsumerFactory = std::function()>; @@ -202,6 +215,17 @@ Result> SerializeExpression( const compute::Expression& expr, ExtensionSet* ext_set, const ConversionOptions& conversion_options = {}); +/// \brief Serialize an Acero Declaration to a binary protobuf Substrait message +/// +/// \param[in] declaration the Acero declaration to serialize +/// \param[in,out] ext_set the extension mapping to use; may be updated to add +/// \param[in] conversion_options options to control how the conversion is done +/// +/// \return a buffer containing the protobuf serialization of the Acero relation +ARROW_ENGINE_EXPORT Result> SerializeRelation( + const compute::Declaration& declaration, ExtensionSet* ext_set, + const ConversionOptions& conversion_options = {}); + /// \brief Deserializes a Substrait Rel (relation) message to an ExecNode declaration /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 04405b31680..9b6c3f715f7 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -23,17 +23,30 @@ #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_ipc.h" +#include "arrow/dataset/file_parquet.h" + #include "arrow/dataset/plan.h" #include "arrow/dataset/scanner.h" #include "arrow/engine/substrait/extension_types.h" #include "arrow/engine/substrait/serde.h" + #include "arrow/engine/substrait/util.h" + +#include "arrow/filesystem/localfs.h" #include "arrow/filesystem/mockfs.h" #include "arrow/filesystem/test_util.h" +#include "arrow/io/compressed.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/writer.h" #include "arrow/testing/gtest_util.h" #include "arrow/testing/matchers.h" #include "arrow/util/key_value_metadata.h" +#include "parquet/arrow/writer.h" + +#include "arrow/util/hash_util.h" +#include "arrow/util/hashing.h" + using testing::ElementsAre; using testing::Eq; using testing::HasSubstr; @@ -42,9 +55,46 @@ using testing::UnorderedElementsAre; namespace arrow { using internal::checked_cast; - +using internal::hash_combine; namespace engine { +Status WriteIpcData(const std::string& path, + const std::shared_ptr file_system, + const std::shared_ptr
input) { + EXPECT_OK_AND_ASSIGN(auto mmap, file_system->OpenOutputStream(path)); + ARROW_ASSIGN_OR_RAISE( + auto file_writer, + MakeFileWriter(mmap, input->schema(), ipc::IpcWriteOptions::Defaults())); + TableBatchReader reader(input); + std::shared_ptr batch; + while (true) { + RETURN_NOT_OK(reader.ReadNext(&batch)); + if (batch == nullptr) { + break; + } + RETURN_NOT_OK(file_writer->WriteRecordBatch(*batch)); + } + RETURN_NOT_OK(file_writer->Close()); + return Status::OK(); +} + +Result> GetTableFromPlan( + compute::Declaration& declarations, + arrow::AsyncGenerator>& sink_gen, + compute::ExecContext& exec_context, std::shared_ptr& output_schema) { + ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); + ARROW_ASSIGN_OR_RAISE(auto decl, declarations.AddToPlan(plan.get())); + + RETURN_NOT_OK(decl->Validate()); + + std::shared_ptr sink_reader = compute::MakeGeneratorReader( + output_schema, std::move(sink_gen), exec_context.memory_pool()); + + RETURN_NOT_OK(plan->Validate()); + RETURN_NOT_OK(plan->StartProducing()); + return arrow::Table::FromRecordBatchReader(sink_reader.get()); +} + class NullSinkNodeConsumer : public compute::SinkNodeConsumer { public: Status Init(const std::shared_ptr&, compute::BackpressureControl*) override { @@ -866,6 +916,7 @@ Result GetSubstraitJSON() { auto file_name = arrow::internal::PlatformFilename::FromString(dir_string)->Join("binary.parquet"); auto file_path = file_name->ToString(); + std::string substrait_json = R"({ "relations": [ {"rel": { @@ -1814,5 +1865,243 @@ TEST(Substrait, AggregateBadPhase) { ASSERT_RAISES(NotImplemented, DeserializePlans(*buf, [] { return kNullConsumer; })); } +TEST(Substrait, BasicPlanRoundTripping) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + arrow::dataset::internal::Initialize(); + + auto dummy_schema = schema( + {field("key", int32()), field("shared", int32()), field("distinct", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 20] + ])", + R"([ + [0, 2, 1], + [1, 3, 2], + [4, 1, 3], + [3, 1, 3], + [1, 2, 5] + ])", + R"([ + [2, 2, 12], + [5, 3, 12], + [1, 3, 12] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + const std::string file_name = "serde_test.arrow"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait-tempdir-")); + std::cout << "file_path_str " << tempdir->path().ToString() << std::endl; + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + + std::vector files; + const std::vector f_paths = {file_path_str}; + + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + } + + ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto scan_options = std::make_shared(); + scan_options->projection = compute::project({}, {}); + const std::string filter_col_left = "shared"; + const std::string filter_col_right = "distinct"; + auto comp_left_value = compute::field_ref(filter_col_left); + auto comp_right_value = compute::field_ref(filter_col_right); + auto filter = 
compute::equal(comp_left_value, comp_right_value); + + arrow::AsyncGenerator> sink_gen; + + auto declarations = compute::Declaration::Sequence( + {compute::Declaration( + {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), + compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); + } +} + +TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + compute::ExecContext exec_context; + arrow::dataset::internal::Initialize(); + + auto dummy_schema = schema( + {field("key", int32()), field("shared", int32()), field("distinct", int32())}); + + // creating a dummy dataset using a dummy table + auto table = TableFromJSON(dummy_schema, {R"([ + [1, 1, 10], + [3, 4, 4] + ])", + R"([ + [0, 2, 1], + [1, 3, 2], + [4, 1, 1], + [3, 1, 3], + [1, 2, 2] + ])", + R"([ + [2, 2, 12], + [5, 3, 12], + [1, 3, 3] + ])"}); + + auto format = std::make_shared(); + auto filesystem = std::make_shared(); + const std::string file_name = "serde_test.arrow"; + + ASSERT_OK_AND_ASSIGN(auto tempdir, + arrow::internal::TemporaryDir::Make("substrait-tempdir-")); + ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); + std::string file_path_str = file_path.ToString(); + + ARROW_EXPECT_OK(WriteIpcData(file_path_str, filesystem, table)); + + std::vector files; + const std::vector f_paths = {file_path_str}; + + for (const auto& f_path : f_paths) { + ASSERT_OK_AND_ASSIGN(auto f_file, filesystem->GetFileInfo(f_path)); + files.push_back(std::move(f_file)); + } + + 
ASSERT_OK_AND_ASSIGN(auto ds_factory, dataset::FileSystemDatasetFactory::Make( + filesystem, std::move(files), format, {})); + ASSERT_OK_AND_ASSIGN(auto dataset, ds_factory->Finish(dummy_schema)); + + auto scan_options = std::make_shared(); + scan_options->projection = compute::project({}, {}); + const std::string filter_col_left = "shared"; + const std::string filter_col_right = "distinct"; + auto comp_left_value = compute::field_ref(filter_col_left); + auto comp_right_value = compute::field_ref(filter_col_right); + auto filter = compute::equal(comp_left_value, comp_right_value); + + arrow::AsyncGenerator> sink_gen; + + auto declarations = compute::Declaration::Sequence( + {compute::Declaration( + {"scan", dataset::ScanNodeOptions{dataset, scan_options}, "s"}), + compute::Declaration({"filter", compute::FilterNodeOptions{filter}, "f"}), + compute::Declaration({"sink", compute::SinkNodeOptions{&sink_gen}, "e"})}); + + ASSERT_OK_AND_ASSIGN(auto expected_table, GetTableFromPlan(declarations, sink_gen, + exec_context, dummy_schema)); + + std::shared_ptr sp_ext_id_reg = MakeExtensionIdRegistry(); + ExtensionIdRegistry* ext_id_reg = sp_ext_id_reg.get(); + ExtensionSet ext_set(ext_id_reg); + + ASSERT_OK_AND_ASSIGN(auto serialized_plan, SerializePlan(declarations, &ext_set)); + + ASSERT_OK_AND_ASSIGN( + auto sink_decls, + DeserializePlans( + *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); + // filter declaration + auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& filter_opts = + checked_cast(*(roundtripped_filter->options)); + auto roundtripped_expr = filter_opts.filter_expression; + + if (auto* call = roundtripped_expr.call()) { + EXPECT_EQ(call->function_name, "equal"); + auto args = call->arguments; + auto left_index = args[0].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[left_index], filter_col_left); + auto right_index = args[1].field_ref()->field_path()->indices()[0]; + EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); + } + // scan declaration + auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& dataset_opts = + checked_cast(*(roundtripped_scan->options)); + const auto& roundripped_ds = dataset_opts.dataset; + EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); + ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); + ASSERT_OK_AND_ASSIGN(auto expected_frgs, dataset->GetFragments()); + + auto roundtrip_frg_vec = IteratorToVector(std::move(roundtripped_frgs)); + auto expected_frg_vec = IteratorToVector(std::move(expected_frgs)); + EXPECT_EQ(expected_frg_vec.size(), roundtrip_frg_vec.size()); + int64_t idx = 0; + for (auto fragment : expected_frg_vec) { + const auto* l_frag = checked_cast(fragment.get()); + const auto* r_frag = + checked_cast(roundtrip_frg_vec[idx++].get()); + EXPECT_TRUE(l_frag->Equals(*r_frag)); + } + arrow::AsyncGenerator> rnd_trp_sink_gen; + auto rnd_trp_sink_node_options = compute::SinkNodeOptions{&rnd_trp_sink_gen}; + auto rnd_trp_sink_declaration = + compute::Declaration({"sink", rnd_trp_sink_node_options, "e"}); + auto rnd_trp_declarations = + compute::Declaration::Sequence({*roundtripped_filter, rnd_trp_sink_declaration}); + ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, + GetTableFromPlan(rnd_trp_declarations, rnd_trp_sink_gen, + exec_context, dummy_schema)); + EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); +} + } // namespace engine } // namespace arrow diff --git 
a/cpp/src/arrow/filesystem/localfs_test.cc b/cpp/src/arrow/filesystem/localfs_test.cc index 0078a593938..fd36faf30fa 100644 --- a/cpp/src/arrow/filesystem/localfs_test.cc +++ b/cpp/src/arrow/filesystem/localfs_test.cc @@ -32,6 +32,7 @@ #include "arrow/filesystem/util_internal.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/io_util.h" +#include "arrow/util/uri.h" namespace arrow { namespace fs { @@ -40,6 +41,7 @@ namespace internal { using ::arrow::internal::FileDescriptor; using ::arrow::internal::PlatformFilename; using ::arrow::internal::TemporaryDir; +using ::arrow::internal::UriFromAbsolutePath; class LocalFSTestMixin : public ::testing::Test { public: @@ -173,16 +175,6 @@ class TestLocalFS : public LocalFSTestMixin { fs_ = std::make_shared(local_path_, local_fs_); } - std::string UriFromAbsolutePath(const std::string& path) { -#ifdef _WIN32 - // Path is supposed to start with "X:/..." - return "file:///" + path; -#else - // Path is supposed to start with "/..." - return "file://" + path; -#endif - } - template void CheckFileSystemFromUriFunc(const std::string& uri, FileSystemFromUriFunc&& fs_from_uri) { @@ -307,7 +299,7 @@ TYPED_TEST(TestLocalFS, NormalizePathThroughSubtreeFS) { TYPED_TEST(TestLocalFS, FileSystemFromUriFile) { // Concrete test with actual file - const auto uri_string = this->UriFromAbsolutePath(this->local_path_); + const auto uri_string = UriFromAbsolutePath(this->local_path_); this->TestFileSystemFromUri(uri_string); this->TestFileSystemFromUriOrPath(uri_string); diff --git a/cpp/src/arrow/filesystem/util_internal.cc b/cpp/src/arrow/filesystem/util_internal.cc index 0d2ad709026..e6f301bdbf1 100644 --- a/cpp/src/arrow/filesystem/util_internal.cc +++ b/cpp/src/arrow/filesystem/util_internal.cc @@ -78,6 +78,7 @@ Status InvalidDeleteDirContents(util::string_view path) { Result GlobFiles(const std::shared_ptr& filesystem, const std::string& glob) { + // TODO: ARROW-17640 // The candidate entries at the current depth level. // We start with the filesystem root. FileInfoVector results{FileInfo("", FileType::Directory)}; diff --git a/cpp/src/arrow/util/io_util.cc b/cpp/src/arrow/util/io_util.cc index 11ae80d03e2..a62040f3a70 100644 --- a/cpp/src/arrow/util/io_util.cc +++ b/cpp/src/arrow/util/io_util.cc @@ -1867,7 +1867,9 @@ Result> TemporaryDir::Make(const std::string& pref [&](const NativePathString& base_dir) -> Result> { Status st; for (int attempt = 0; attempt < 3; ++attempt) { - PlatformFilename fn(base_dir + kNativeSep + base_name + kNativeSep); + PlatformFilename fn_base_dir(base_dir); + PlatformFilename fn_base_name(base_name + kNativeSep); + PlatformFilename fn = fn_base_dir.Join(fn_base_name); auto result = CreateDir(fn); if (!result.ok()) { // Probably a permissions error or a non-existing base_dir diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc index 7a8484ce51a..abfc9de8b49 100644 --- a/cpp/src/arrow/util/uri.cc +++ b/cpp/src/arrow/util/uri.cc @@ -304,5 +304,15 @@ Status Uri::Parse(const std::string& uri_string) { return Status::OK(); } +std::string UriFromAbsolutePath(const std::string& path) { +#ifdef _WIN32 + // Path is supposed to start with "X:/..." + return "file:///" + path; +#else + // Path is supposed to start with "/..." 
+ return "file://" + path; +#endif +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h index eae1956eafc..50d9eccf82f 100644 --- a/cpp/src/arrow/util/uri.h +++ b/cpp/src/arrow/util/uri.h @@ -104,5 +104,9 @@ std::string UriEncodeHost(const std::string& host); ARROW_EXPORT bool IsValidUriScheme(const arrow::util::string_view s); +/// Create a file uri from a given absolute path +ARROW_EXPORT +std::string UriFromAbsolutePath(const std::string& path); + } // namespace internal } // namespace arrow From 99b57e84277f24e8ec1ddadbb11ef8b4f43c8c89 Mon Sep 17 00:00:00 2001 From: rtpsw Date: Thu, 8 Sep 2022 23:05:40 +0300 Subject: [PATCH 021/133] ARROW-17412: [C++] AsofJoin multiple keys and types (#13880) See https://issues.apache.org/jira/browse/ARROW-17412 Lead-authored-by: Yaron Gvili Co-authored-by: rtpsw Signed-off-by: Weston Pace --- cpp/src/arrow/array/data.h | 15 +- .../arrow/compute/exec/asof_join_benchmark.cc | 2 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 675 ++++++++--- .../arrow/compute/exec/asof_join_node_test.cc | 1079 ++++++++++++++--- cpp/src/arrow/compute/exec/hash_join.cc | 1 - cpp/src/arrow/compute/exec/options.h | 20 +- cpp/src/arrow/compute/light_array.cc | 6 + cpp/src/arrow/compute/light_array.h | 19 +- cpp/src/arrow/type_traits.h | 7 + 9 files changed, 1482 insertions(+), 342 deletions(-) diff --git a/cpp/src/arrow/array/data.h b/cpp/src/arrow/array/data.h index dde66ac79c4..e024483f665 100644 --- a/cpp/src/arrow/array/data.h +++ b/cpp/src/arrow/array/data.h @@ -167,6 +167,11 @@ struct ARROW_EXPORT ArrayData { std::shared_ptr Copy() const { return std::make_shared(*this); } + bool IsNull(int64_t i) const { + return ((buffers[0] != NULLPTR) ? !bit_util::GetBit(buffers[0]->data(), i + offset) + : null_count.load() == length); + } + // Access a buffer's data as a typed C pointer template inline const T* GetValues(int i, int64_t absolute_offset) const { @@ -324,18 +329,14 @@ struct ARROW_EXPORT ArraySpan { return GetValues(i, this->offset); } - bool IsNull(int64_t i) const { - return ((this->buffers[0].data != NULLPTR) - ? !bit_util::GetBit(this->buffers[0].data, i + this->offset) - : this->null_count == this->length); - } - - bool IsValid(int64_t i) const { + inline bool IsValid(int64_t i) const { return ((this->buffers[0].data != NULLPTR) ? 
bit_util::GetBit(this->buffers[0].data, i + this->offset) : this->null_count != this->length); } + inline bool IsNull(int64_t i) const { return !IsValid(i); } + std::shared_ptr ToArrayData() const; std::shared_ptr ToArray() const; diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index 543a4ece575..7d8abc0ba4c 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -109,7 +109,7 @@ static void TableJoinOverhead(benchmark::State& state, static void AsOfJoinOverhead(benchmark::State& state) { int64_t tolerance = 0; - AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, kKeyCol, tolerance); + AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance); TableJoinOverhead( state, TableGenerationProperties{int(state.range(0)), int(state.range(1)), diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 3da612aa03e..869456a5775 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -17,34 +17,63 @@ #include #include -#include #include #include +#include +#include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" #include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/schema_util.h" #include "arrow/compute/exec/util.h" +#include "arrow/compute/light_array.h" #include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/status.h" +#include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/make_unique.h" #include "arrow/util/optional.h" +#include "arrow/util/string_view.h" namespace arrow { namespace compute { -// Remove this when multiple keys and/or types is supported -typedef int32_t KeyType; +template +inline typename T::const_iterator std_find(const T& container, const V& val) { + return std::find(container.begin(), container.end(), val); +} + +template +inline bool std_has(const T& container, const V& val) { + return container.end() != std_find(container, val); +} + +typedef uint64_t ByType; +typedef uint64_t OnType; +typedef uint64_t HashType; // Maximum number of tables that can be joined #define MAX_JOIN_TABLES 64 typedef uint64_t row_index_t; typedef int col_index_t; +// normalize the value to 64-bits while preserving ordering of values +template ::value, bool> = true> +static inline uint64_t time_value(T t) { + uint64_t bias = std::is_signed::value ? (uint64_t)1 << (8 * sizeof(T) - 1) : 0; + return t < 0 ? 
static_cast(t + bias) : static_cast(t); +} + +// indicates normalization of a key value +template ::value, bool> = true> +static inline uint64_t key_value(T t) { + return static_cast(t); +} + /** * Simple implementation for an unbound concurrent queue */ @@ -65,6 +94,11 @@ class ConcurrentQueue { cond_.notify_one(); } + void Clear() { + std::unique_lock lock(mutex_); + queue_ = std::queue(); + } + util::optional TryPop() { // Try to pop the oldest value from the queue (or return nullopt if none) std::unique_lock lock(mutex_); @@ -99,7 +133,7 @@ struct MemoStore { struct Entry { // Timestamp associated with the entry - int64_t time; + OnType time; // Batch associated with the entry (perf is probably OK for this; batches change // rarely) @@ -109,10 +143,10 @@ struct MemoStore { row_index_t row; }; - std::unordered_map entries_; + std::unordered_map entries_; - void Store(const std::shared_ptr& batch, row_index_t row, int64_t time, - KeyType key) { + void Store(const std::shared_ptr& batch, row_index_t row, OnType time, + ByType key) { auto& e = entries_[key]; // that we can do this assignment optionally, is why we // can get array with using shared_ptr above (the batch @@ -122,13 +156,13 @@ struct MemoStore { e.time = time; } - util::optional GetEntryForKey(KeyType key) const { + util::optional GetEntryForKey(ByType key) const { auto e = entries_.find(key); if (entries_.end() == e) return util::nullopt; return util::optional(&e->second); } - void RemoveEntriesWithLesserTime(int64_t ts) { + void RemoveEntriesWithLesserTime(OnType ts) { for (auto e = entries_.begin(); e != entries_.end();) if (e->second.time < ts) e = entries_.erase(e); @@ -137,18 +171,89 @@ struct MemoStore { } }; +// a specialized higher-performance variation of Hashing64 logic from hash_join_node +// the code here avoids recreating objects that are independent of each batch processed +class KeyHasher { + static constexpr int kMiniBatchLength = util::MiniBatch::kMiniBatchLength; + + public: + explicit KeyHasher(const std::vector& indices) + : indices_(indices), + metadata_(indices.size()), + batch_(NULLPTR), + hashes_(), + ctx_(), + column_arrays_(), + stack_() { + ctx_.stack = &stack_; + column_arrays_.resize(indices.size()); + } + + Status Init(ExecContext* exec_context, const std::shared_ptr& schema) { + ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); + const auto& fields = schema->fields(); + for (size_t k = 0; k < metadata_.size(); k++) { + ARROW_ASSIGN_OR_RAISE(metadata_[k], + ColumnMetadataFromDataType(fields[indices_[k]]->type())); + } + return stack_.Init(exec_context->memory_pool(), + 4 * kMiniBatchLength * sizeof(uint32_t)); + } + + const std::vector& HashesFor(const RecordBatch* batch) { + if (batch_ == batch) { + return hashes_; + } + batch_ = NULLPTR; // invalidate cached hashes for batch + size_t batch_length = batch->num_rows(); + hashes_.resize(batch_length); + for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { + int64_t length = std::min(static_cast(batch_length - i), + static_cast(kMiniBatchLength)); + for (size_t k = 0; k < indices_.size(); k++) { + auto array_data = batch->column_data(indices_[k]); + column_arrays_[k] = + ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); + } + Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); + } + batch_ = batch; + return hashes_; + } + + private: + std::vector indices_; + std::vector metadata_; + const RecordBatch* batch_; + std::vector hashes_; + LightContext ctx_; + std::vector 
column_arrays_; + util::TempVectorStack stack_; +}; + class InputState { // InputState correponds to an input // Input record batches are queued up in InputState until processed and // turned into output record batches. public: - InputState(const std::shared_ptr& schema, - const std::string& time_col_name, const std::string& key_col_name) + InputState(bool must_hash, bool may_rehash, KeyHasher* key_hasher, + const std::shared_ptr& schema, + const col_index_t time_col_index, + const std::vector& key_col_index) : queue_(), schema_(schema), - time_col_index_(schema->GetFieldIndex(time_col_name)), - key_col_index_(schema->GetFieldIndex(key_col_name)) {} + time_col_index_(time_col_index), + key_col_index_(key_col_index), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()), + key_type_id_(key_col_index.size()), + key_hasher_(key_hasher), + must_hash_(must_hash), + may_rehash_(may_rehash) { + for (size_t k = 0; k < key_col_index_.size(); k++) { + key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); + } + } col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { src_to_dst_.resize(schema_->num_fields()); @@ -164,7 +269,7 @@ class InputState { bool IsTimeOrKeyColumn(col_index_t i) const { DCHECK_LT(i, schema_->num_fields()); - return (i == time_col_index_) || (i == key_col_index_); + return (i == time_col_index_) || std_has(key_col_index_, i); } // Gets the latest row index, assuming the queue isn't empty @@ -184,27 +289,87 @@ class InputState { return queue_.UnsyncFront(); } - KeyType GetLatestKey() const { - return queue_.UnsyncFront() - ->column_data(key_col_index_) - ->GetValues(1)[latest_ref_row_]; +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[row]); \ + } + + inline ByType GetLatestKey() const { + return GetLatestKey(queue_.UnsyncFront().get(), latest_ref_row_); } - int64_t GetLatestTime() const { - return queue_.UnsyncFront() - ->column_data(time_col_index_) - ->GetValues(1)[latest_ref_row_]; + inline ByType GetLatestKey(const RecordBatch* batch, row_index_t row) const { + if (must_hash_) { + return key_hasher_->HashesFor(batch)[row]; + } + if (key_col_index_.size() == 0) { + return 0; + } + auto data = batch->column_data(key_col_index_[0]); + switch (key_type_id_[0]) { + LATEST_VAL_CASE(INT8, key_value) + LATEST_VAL_CASE(INT16, key_value) + LATEST_VAL_CASE(INT32, key_value) + LATEST_VAL_CASE(INT64, key_value) + LATEST_VAL_CASE(UINT8, key_value) + LATEST_VAL_CASE(UINT16, key_value) + LATEST_VAL_CASE(UINT32, key_value) + LATEST_VAL_CASE(UINT64, key_value) + LATEST_VAL_CASE(DATE32, key_value) + LATEST_VAL_CASE(DATE64, key_value) + LATEST_VAL_CASE(TIME32, key_value) + LATEST_VAL_CASE(TIME64, key_value) + LATEST_VAL_CASE(TIMESTAMP, key_value) + default: + DCHECK(false); + return 0; // cannot happen + } } + inline OnType GetLatestTime() const { + return GetLatestTime(queue_.UnsyncFront().get(), latest_ref_row_); + } + + inline ByType GetLatestTime(const RecordBatch* batch, row_index_t row) const { + auto data = batch->column_data(time_col_index_); + switch (time_type_id_) { + LATEST_VAL_CASE(INT8, time_value) + LATEST_VAL_CASE(INT16, time_value) + LATEST_VAL_CASE(INT32, time_value) + LATEST_VAL_CASE(INT64, time_value) + LATEST_VAL_CASE(UINT8, time_value) + LATEST_VAL_CASE(UINT16, time_value) + LATEST_VAL_CASE(UINT32, time_value) + LATEST_VAL_CASE(UINT64, time_value) + LATEST_VAL_CASE(DATE32, time_value) 
+ LATEST_VAL_CASE(DATE64, time_value) + LATEST_VAL_CASE(TIME32, time_value) + LATEST_VAL_CASE(TIME64, time_value) + LATEST_VAL_CASE(TIMESTAMP, time_value) + default: + DCHECK(false); + return 0; // cannot happen + } + } + +#undef LATEST_VAL_CASE + bool Finished() const { return batches_processed_ == total_batches_; } - bool Advance() { + Result Advance() { // Try advancing to the next row and update latest_ref_row_ // Returns true if able to advance, false if not. bool have_active_batch = (latest_ref_row_ > 0 /*short circuit the lock on the queue*/) || !queue_.Empty(); if (have_active_batch) { + OnType next_time = GetLatestTime(); + if (latest_time_ > next_time) { + return Status::Invalid("AsofJoin does not allow out-of-order on-key values"); + } + latest_time_ = next_time; // If we have an active batch if (++latest_ref_row_ >= (row_index_t)queue_.UnsyncFront()->num_rows()) { // hit the end of the batch, need to get the next batch if possible. @@ -222,46 +387,60 @@ class InputState { // latest_time and latest_ref_row to the value that immediately pass the // specified timestamp. // Returns true if updates were made, false if not. - bool AdvanceAndMemoize(int64_t ts) { + Result AdvanceAndMemoize(OnType ts) { // Advance the right side row index until we reach the latest right row (for each key) // for the given left timestamp. // Check if already updated for TS (or if there is no latest) if (Empty()) return false; // can't advance if empty - auto latest_time = GetLatestTime(); - if (latest_time > ts) return false; // already advanced // Not updated. Try to update and possibly advance. - bool updated = false; + bool advanced, updated = false; do { - latest_time = GetLatestTime(); + auto latest_time = GetLatestTime(); // if Advance() returns true, then the latest_ts must also be valid // Keep advancing right table until we hit the latest row that has // timestamp <= ts. This is because we only need the latest row for the // match given a left ts. 
- if (latest_time <= ts) { - memo_.Store(GetLatestBatch(), latest_ref_row_, latest_time, GetLatestKey()); - } else { + if (latest_time > ts) { break; // hit a future timestamp -- done updating for now } + auto rb = GetLatestBatch(); + if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + must_hash_ = true; + may_rehash_ = false; + Rehash(); + } + memo_.Store(rb, latest_ref_row_, latest_time, GetLatestKey()); updated = true; - } while (Advance()); + ARROW_ASSIGN_OR_RAISE(advanced, Advance()); + } while (advanced); return updated; } - void Push(const std::shared_ptr& rb) { + void Rehash() { + MemoStore new_memo; + for (const auto& entry : memo_.entries_) { + const auto& e = entry.second; + new_memo.Store(e.batch, e.row, e.time, GetLatestKey(e.batch.get(), e.row)); + } + memo_ = new_memo; + } + + Status Push(const std::shared_ptr& rb) { if (rb->num_rows() > 0) { queue_.Push(rb); } else { ++batches_processed_; // don't enqueue empty batches, just record as processed } + return Status::OK(); } - util::optional GetMemoEntryForKey(KeyType key) { + util::optional GetMemoEntryForKey(ByType key) { return memo_.GetEntryForKey(key); } - util::optional GetMemoTimeForKey(KeyType key) { + util::optional GetMemoTimeForKey(ByType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { return (*r)->time; @@ -270,7 +449,7 @@ class InputState { } } - void RemoveMemoEntriesWithLesserTime(int64_t ts) { + void RemoveMemoEntriesWithLesserTime(OnType ts) { memo_.RemoveEntriesWithLesserTime(ts); } @@ -294,10 +473,22 @@ class InputState { // Index of the time col col_index_t time_col_index_; // Index of the key col - col_index_t key_col_index_; + std::vector key_col_index_; + // Type id of the time column + Type::type time_type_id_; + // Type id of the key column + std::vector key_type_id_; + // Hasher for key elements + mutable KeyHasher* key_hasher_; + // True if hashing is mandatory + bool must_hash_; + // True if by-key values may be rehashed + bool may_rehash_; // Index of the latest row reference within; if >0 then queue_ cannot be empty // Must be < queue_.front()->num_rows() if queue_ is non-empty row_index_t latest_ref_row_ = 0; + // Time of latest row + OnType latest_time_ = std::numeric_limits::lowest(); // Stores latest known values for the various keys MemoStore memo_; // Mapping of source columns to destination columns @@ -336,18 +527,18 @@ class CompositeReferenceTable { // Adds the latest row from the input state as a new composite reference row // - LHS must have a valid key,timestep,and latest rows // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, int64_t tolerance) { + void Emplace(std::vector>& in, OnType tolerance) { DCHECK_EQ(in.size(), n_tables_); // Get the LHS key - KeyType key = in[0]->GetLatestKey(); + ByType key = in[0]->GetLatestKey(); // Add row and setup LHS // (the LHS state comes just from the latest row of the LHS table) DCHECK(!in[0]->Empty()); const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); row_index_t lhs_latest_row = in[0]->GetLatestRow(); - int64_t lhs_latest_time = in[0]->GetLatestTime(); + OnType lhs_latest_time = in[0]->GetLatestTime(); if (0 == lhs_latest_row) { // On the first row of the batch, we resize the destination. // The destination size is dictated by the size of the LHS batch. 
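Aside: the members above (AdvanceAndMemoize plus the memo pruning) implement the core as-of rule: for each left-side timestamp ts, the match for a key is the latest memoized right row with time <= ts, accepted only when it lies within the tolerance window. Below is a minimal standalone sketch of that rule, using hypothetical Row/AsofMatch names and assuming both inputs are already sorted by time (the node rejects out-of-order on-key values); it is an illustration, not the node's actual code.

// Sketch of the as-of matching rule (hypothetical names, not AsofJoinNode's API).
#include <cstdint>
#include <optional>
#include <unordered_map>
#include <vector>

struct Row {
  uint64_t time;   // normalized on-key value
  uint64_t key;    // normalized/hashed by-key value
  double value;
};

// For each left row, returns the value of the latest right row with the same
// key and right.time <= left.time, provided left.time - right.time <= tolerance.
std::vector<std::optional<double>> AsofMatch(const std::vector<Row>& left,
                                             const std::vector<Row>& right,
                                             uint64_t tolerance) {
  std::unordered_map<uint64_t, Row> memo;  // latest right row seen per key
  std::vector<std::optional<double>> out;
  out.reserve(left.size());
  size_t r = 0;
  for (const Row& l : left) {
    // Advance the right side up to the left timestamp, memoizing the latest
    // row per key (this mirrors AdvanceAndMemoize above).
    while (r < right.size() && right[r].time <= l.time) {
      memo[right[r].key] = right[r];
      ++r;
    }
    auto it = memo.find(l.key);
    if (it != memo.end() && l.time - it->second.time <= tolerance) {
      out.push_back(it->second.value);
    } else {
      out.push_back(std::nullopt);  // no match within tolerance
    }
  }
  return out;
}

The node additionally bounds memory by discarding memo entries older than ts - tolerance (RemoveMemoEntriesWithLesserTime); the sketch instead keeps every key's latest row.
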
@@ -407,29 +598,42 @@ class CompositeReferenceTable { DCHECK_EQ(src_field->name(), dst_field->name()); const auto& field_type = src_field->type(); - if (field_type->Equals(arrow::int32())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::int64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float32())) { - ARROW_ASSIGN_OR_RAISE(arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else if (field_type->Equals(arrow::float64())) { - ARROW_ASSIGN_OR_RAISE( - arrays.at(i_dst_col), - (MaterializePrimitiveColumn( - memory_pool, i_table, i_src_col))); - } else { - ARROW_RETURN_NOT_OK( - Status::Invalid("Unsupported data type: ", src_field->name())); +#define ASOFJOIN_MATERIALIZE_CASE(id) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + ARROW_ASSIGN_OR_RAISE( \ + arrays.at(i_dst_col), \ + MaterializeColumn(memory_pool, field_type, i_table, i_src_col)); \ + break; \ + } + + switch (field_type->id()) { + ASOFJOIN_MATERIALIZE_CASE(INT8) + ASOFJOIN_MATERIALIZE_CASE(INT16) + ASOFJOIN_MATERIALIZE_CASE(INT32) + ASOFJOIN_MATERIALIZE_CASE(INT64) + ASOFJOIN_MATERIALIZE_CASE(UINT8) + ASOFJOIN_MATERIALIZE_CASE(UINT16) + ASOFJOIN_MATERIALIZE_CASE(UINT32) + ASOFJOIN_MATERIALIZE_CASE(UINT64) + ASOFJOIN_MATERIALIZE_CASE(FLOAT) + ASOFJOIN_MATERIALIZE_CASE(DOUBLE) + ASOFJOIN_MATERIALIZE_CASE(DATE32) + ASOFJOIN_MATERIALIZE_CASE(DATE64) + ASOFJOIN_MATERIALIZE_CASE(TIME32) + ASOFJOIN_MATERIALIZE_CASE(TIME64) + ASOFJOIN_MATERIALIZE_CASE(TIMESTAMP) + ASOFJOIN_MATERIALIZE_CASE(STRING) + ASOFJOIN_MATERIALIZE_CASE(LARGE_STRING) + ASOFJOIN_MATERIALIZE_CASE(BINARY) + ASOFJOIN_MATERIALIZE_CASE(LARGE_BINARY) + default: + return Status::Invalid("Unsupported data type ", + src_field->type()->ToString(), " for field ", + src_field->name()); } + +#undef ASOFJOIN_MATERIALIZE_CASE } } } @@ -459,17 +663,45 @@ class CompositeReferenceTable { if (!_ptr2ref.count((uintptr_t)ref.get())) _ptr2ref[(uintptr_t)ref.get()] = ref; } - template - Result> MaterializePrimitiveColumn(MemoryPool* memory_pool, - size_t i_table, - col_index_t i_col) { - Builder builder(memory_pool); + template ::BuilderType> + enable_if_fixed_width_type static BuilderAppend( + Builder& builder, const std::shared_ptr& source, row_index_t row) { + if (source->IsNull(row)) { + builder.UnsafeAppendNull(); + return Status::OK(); + } + using CType = typename TypeTraits::CType; + builder.UnsafeAppend(source->template GetValues(1)[row]); + return Status::OK(); + } + + template ::BuilderType> + enable_if_base_binary static BuilderAppend( + Builder& builder, const std::shared_ptr& source, row_index_t row) { + if (source->IsNull(row)) { + return builder.AppendNull(); + } + using offset_type = typename Type::offset_type; + const uint8_t* data = source->buffers[2]->data(); + const offset_type* offsets = source->GetValues(1); + const offset_type offset0 = offsets[row]; + const offset_type offset1 = offsets[row + 1]; + return builder.Append(data + offset0, offset1 - offset0); + } + + template ::BuilderType> + Result> MaterializeColumn(MemoryPool* memory_pool, + const std::shared_ptr& type, + size_t i_table, col_index_t i_col) { + ARROW_ASSIGN_OR_RAISE(auto a_builder, MakeBuilder(type, memory_pool)); + Builder& builder = *checked_cast(a_builder.get()); ARROW_RETURN_NOT_OK(builder.Reserve(rows_.size())); 
for (row_index_t i_row = 0; i_row < rows_.size(); ++i_row) { const auto& ref = rows_[i_row].refs[i_table]; if (ref.batch) { - builder.UnsafeAppend( - ref.batch->column_data(i_col)->template GetValues(1)[ref.row]); + Status st = + BuilderAppend(builder, ref.batch->column_data(i_col), ref.row); + ARROW_RETURN_NOT_OK(st); } else { builder.UnsafeAppendNull(); } @@ -480,14 +712,21 @@ class CompositeReferenceTable { } }; +// TODO: Currently, AsofJoinNode uses 64-bit hashing which leads to a non-negligible +// probability of collision, which can cause incorrect results when many different by-key +// values are processed. Thus, AsofJoinNode is currently limited to about 100k by-keys for +// guaranteeing this probability is below 1 in a billion. The fix is 128-bit hashing. +// See ARROW-17653 class AsofJoinNode : public ExecNode { // Advances the RHS as far as possible to be up to date for the current LHS timestamp - bool UpdateRhs() { + Result UpdateRhs() { auto& lhs = *state_.at(0); auto lhs_latest_time = lhs.GetLatestTime(); bool any_updated = false; - for (size_t i = 1; i < state_.size(); ++i) - any_updated |= state_[i]->AdvanceAndMemoize(lhs_latest_time); + for (size_t i = 1; i < state_.size(); ++i) { + ARROW_ASSIGN_OR_RAISE(bool advanced, state_[i]->AdvanceAndMemoize(lhs_latest_time)); + any_updated |= advanced; + } return any_updated; } @@ -495,7 +734,7 @@ class AsofJoinNode : public ExecNode { bool IsUpToDateWithLhsRow() const { auto& lhs = *state_[0]; if (lhs.Empty()) return false; // can't proceed if nothing on the LHS - int64_t lhs_ts = lhs.GetLatestTime(); + OnType lhs_ts = lhs.GetLatestTime(); for (size_t i = 1; i < state_.size(); ++i) { auto& rhs = *state_[i]; if (!rhs.Finished()) { @@ -523,7 +762,7 @@ class AsofJoinNode : public ExecNode { if (lhs.Finished() || lhs.Empty()) break; // Advance each of the RHS as far as possible to be up to date for the LHS timestamp - bool any_rhs_advanced = UpdateRhs(); + ARROW_ASSIGN_OR_RAISE(bool any_rhs_advanced, UpdateRhs()); // If we have received enough inputs to produce the next output batch // (decided by IsUpToDateWithLhsRow), we will perform the join and @@ -531,8 +770,9 @@ class AsofJoinNode : public ExecNode { // the LHS and adding joined row to rows_ (done by Emplace). Finally, // input batches that are no longer needed are removed to free up memory. 
if (IsUpToDateWithLhsRow()) { - dst.Emplace(state_, options_.tolerance); - if (!lhs.Advance()) break; // if we can't advance LHS, we're done for this batch + dst.Emplace(state_, tolerance_); + ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance()); + if (!advanced) break; // if we can't advance LHS, we're done for this batch } else { if (!any_rhs_advanced) break; // need to wait for new data } @@ -541,8 +781,7 @@ class AsofJoinNode : public ExecNode { // Prune memo entries that have expired (to bound memory consumption) if (!lhs.Empty()) { for (size_t i = 1; i < state_.size(); ++i) { - state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - - options_.tolerance); + state_[i]->RemoveMemoEntriesWithLesserTime(lhs.GetLatestTime() - tolerance_); } } @@ -572,7 +811,6 @@ class AsofJoinNode : public ExecNode { ExecBatch out_b(*out_rb); outputs_[0]->InputReceived(this, std::move(out_b)); } else { - StopProducing(); ErrorIfNotOk(result.status()); return; } @@ -584,8 +822,8 @@ class AsofJoinNode : public ExecNode { // It may happen here in cases where InputFinished was called before we were finished // producing results (so we didn't know the output size at that time) if (state_.at(0)->Finished()) { - StopProducing(); outputs_[0]->InputFinished(this, batches_produced_); + finished_.MarkFinished(); } } @@ -602,54 +840,172 @@ class AsofJoinNode : public ExecNode { public: AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema); + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key, + OnType tolerance, std::shared_ptr output_schema, + std::vector> key_hashers, bool must_hash, + bool may_rehash); + + Status Init() override { + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + RETURN_NOT_OK(key_hashers_[i]->Init(plan()->exec_context(), output_schema())); + state_.push_back(::arrow::internal::make_unique( + must_hash_, may_rehash_, key_hashers_[i].get(), inputs[i]->output_schema(), + indices_of_on_key_[i], indices_of_by_key_[i])); + } + + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); + } virtual ~AsofJoinNode() { process_.Push(false); // poison pill process_thread_.join(); } + const std::vector& indices_of_on_key() { return indices_of_on_key_; } + const std::vector>& indices_of_by_key() { + return indices_of_by_key_; + } + + static Status is_valid_on_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); + } + } + + static Status is_valid_by_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return 
Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); + } + } + + static Status is_valid_data_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for data field ", field->name(), " : ", + field->type()->ToString()); + } + } + static arrow::Result> MakeOutputSchema( - const std::vector& inputs, const AsofJoinNodeOptions& options) { + const std::vector& inputs, + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key) { std::vector> fields; - const auto& on_field_name = *options.on_key.name(); - const auto& by_field_name = *options.by_key.name(); - + size_t n_by = indices_of_by_key[0].size(); + const DataType* on_key_type = NULLPTR; + std::vector by_key_type(n_by, NULLPTR); // Take all non-key, non-time RHS fields for (size_t j = 0; j < inputs.size(); ++j) { const auto& input_schema = inputs[j]->output_schema(); - const auto& on_field_ix = input_schema->GetFieldIndex(on_field_name); - const auto& by_field_ix = input_schema->GetFieldIndex(by_field_name); + const auto& on_field_ix = indices_of_on_key[j]; + const auto& by_field_ix = indices_of_by_key[j]; - if ((on_field_ix == -1) | (by_field_ix == -1)) { + if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { return Status::Invalid("Missing join key on table ", j); } + const auto& on_field = input_schema->fields()[on_field_ix]; + std::vector by_field(n_by); + for (size_t k = 0; k < n_by; k++) { + by_field[k] = input_schema->fields()[by_field_ix[k]].get(); + } + + if (on_key_type == NULLPTR) { + on_key_type = on_field->type().get(); + } else if (*on_key_type != *on_field->type()) { + return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", + *on_field->type(), " for field ", on_field->name(), + " in input ", j); + } + for (size_t k = 0; k < n_by; k++) { + if (by_key_type[k] == NULLPTR) { + by_key_type[k] = by_field[k]->type().get(); + } else if (*by_key_type[k] != *by_field[k]->type()) { + return Status::Invalid("Expected on-key type ", *by_key_type[k], " but got ", + *by_field[k]->type(), " for field ", by_field[k]->name(), + " in input ", j); + } + } + for (int i = 0; i < input_schema->num_fields(); ++i) { const auto field = input_schema->field(i); - if (field->name() == on_field_name) { - if (kSupportedOnTypes_.find(field->type()) == kSupportedOnTypes_.end()) { - return Status::Invalid("Unsupported type for on key: ", field->name()); - } + if (i == on_field_ix) { + ARROW_RETURN_NOT_OK(is_valid_on_field(field)); // Only add on field from the left table if (j == 0) { fields.push_back(field); } - } else if (field->name() == by_field_name) { - if (kSupportedByTypes_.find(field->type()) == kSupportedByTypes_.end()) { - return Status::Invalid("Unsupported type for by key: ", field->name()); - } + } else if (std_has(by_field_ix, i)) { + ARROW_RETURN_NOT_OK(is_valid_by_field(field)); // Only add by field from the left table if (j == 0) { fields.push_back(field); } } else { - if (kSupportedDataTypes_.find(field->type()) == kSupportedDataTypes_.end()) { - return 
Status::Invalid("Unsupported data type: ", field->name()); - } - + ARROW_RETURN_NOT_OK(is_valid_data_field(field)); fields.push_back(field); } } @@ -657,45 +1013,91 @@ class AsofJoinNode : public ExecNode { return std::make_shared(fields); } + static inline Result FindColIndex(const Schema& schema, + const FieldRef& field_ref, + util::string_view key_kind) { + auto match_res = field_ref.FindOne(schema); + if (!match_res.ok()) { + return Status::Invalid("Bad join key on table : ", match_res.status().message()); + } + ARROW_ASSIGN_OR_RAISE(auto match, match_res); + if (match.indices().size() != 1) { + return Status::Invalid("AsOfJoinNode does not support a nested ", + to_string(key_kind), "-key ", field_ref.ToString()); + } + return match.indices()[0]; + } + static arrow::Result Make(ExecPlan* plan, std::vector inputs, const ExecNodeOptions& options) { DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; const auto& join_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, - MakeOutputSchema(inputs, join_options)); + if (join_options.tolerance < 0) { + return Status::Invalid("AsOfJoin tolerance must be non-negative but is ", + join_options.tolerance); + } - std::vector input_labels(inputs.size()); - input_labels[0] = "left"; - for (size_t i = 1; i < inputs.size(); ++i) { - input_labels[i] = "right_" + std::to_string(i); + size_t n_input = inputs.size(), n_by = join_options.by_key.size(); + std::vector input_labels(n_input); + std::vector indices_of_on_key(n_input); + std::vector> indices_of_by_key( + n_input, std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i); + const Schema& input_schema = *inputs[i]->output_schema(); + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(input_schema, join_options.on_key, "on")); + for (size_t k = 0; k < n_by; k++) { + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(input_schema, join_options.by_key[k], "by")); + } } - return plan->EmplaceNode(plan, inputs, std::move(input_labels), - join_options, std::move(output_schema)); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key)); + + std::vector> key_hashers; + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back( + ::arrow::internal::make_unique(indices_of_by_key[i])); + } + bool must_hash = + n_by > 1 || + (n_by == 1 && + !is_primitive( + inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id())); + bool may_rehash = n_by == 1 && !must_hash; + return plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), time_value(join_options.tolerance), + std::move(output_schema), std::move(key_hashers), must_hash, may_rehash); } const char* kind_name() const override { return "AsofJoinNode"; } void InputReceived(ExecNode* input, ExecBatch batch) override { // Get the input - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); // Put into the queue auto rb = *batch.ToRecordBatch(input->output_schema()); - state_.at(k)->Push(rb); + Status st = state_.at(k)->Push(rb); + if (!st.ok()) { + ErrorReceived(input, st); + return; + } process_.Push(true); } void ErrorReceived(ExecNode* input, Status error) override { 
outputs_[0]->ErrorReceived(this, std::move(error)); - StopProducing(); } void InputFinished(ExecNode* input, int total_batches) override { { std::lock_guard guard(gate_); - ARROW_DCHECK(std::find(inputs_.begin(), inputs_.end(), input) != inputs_.end()); - size_t k = std::find(inputs_.begin(), inputs_.end(), input) - inputs_.begin(); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); state_.at(k)->set_total_batches(total_batches); } // Trigger a process call @@ -714,20 +1116,24 @@ class AsofJoinNode : public ExecNode { DCHECK_EQ(output, outputs_[0]); StopProducing(); } - void StopProducing() override { finished_.MarkFinished(); } + void StopProducing() override { + process_.Clear(); + process_.Push(false); + } arrow::Future<> finished() override { return finished_; } private: - static const std::set> kSupportedOnTypes_; - static const std::set> kSupportedByTypes_; - static const std::set> kSupportedDataTypes_; - arrow::Future<> finished_; + std::vector indices_of_on_key_; + std::vector> indices_of_by_key_; + std::vector> key_hashers_; + bool must_hash_; + bool may_rehash_; // InputStates // Each input state correponds to an input table std::vector> state_; std::mutex gate_; - AsofJoinNodeOptions options_; + OnType tolerance_; // Queue for triggering processing of a given input // (a false value is a poison pill) @@ -741,30 +1147,25 @@ class AsofJoinNode : public ExecNode { AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, - const AsofJoinNodeOptions& join_options, - std::shared_ptr output_schema) + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key, + OnType tolerance, std::shared_ptr output_schema, + std::vector> key_hashers, + bool must_hash, bool may_rehash) : ExecNode(plan, inputs, input_labels, /*output_schema=*/std::move(output_schema), /*num_outputs=*/1), - options_(join_options), + indices_of_on_key_(std::move(indices_of_on_key)), + indices_of_by_key_(std::move(indices_of_by_key)), + key_hashers_(std::move(key_hashers)), + must_hash_(must_hash), + may_rehash_(may_rehash), + tolerance_(tolerance), process_(), process_thread_(&AsofJoinNode::ProcessThreadWrapper, this) { - for (size_t i = 0; i < inputs.size(); ++i) - state_.push_back(::arrow::internal::make_unique( - inputs[i]->output_schema(), *options_.on_key.name(), *options_.by_key.name())); - col_index_t dst_offset = 0; - for (auto& state : state_) - dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); - finished_ = arrow::Future<>::MakeFinished(); } -// Currently supported types -const std::set> AsofJoinNode::kSupportedOnTypes_ = {int64()}; -const std::set> AsofJoinNode::kSupportedByTypes_ = {int32()}; -const std::set> AsofJoinNode::kSupportedDataTypes_ = { - int32(), int64(), float32(), float64()}; - namespace internal { void RegisterAsofJoinNode(ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("asofjoin", AsofJoinNode::Make)); diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 8b993764abe..48d1ae6410b 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -17,11 +17,13 @@ #include +#include #include #include #include #include "arrow/api.h" +#include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/util.h" @@ -32,23 +34,185 @@ #include "arrow/testing/random.h" 
#include "arrow/util/checked_cast.h" #include "arrow/util/make_unique.h" +#include "arrow/util/string_view.h" #include "arrow/util/thread_pool.h" +#define TRACED_TEST(t_class, t_name, t_body) \ + TEST(t_class, t_name) { \ + ARROW_SCOPED_TRACE(#t_class "_" #t_name); \ + t_body; \ + } + +#define TRACED_TEST_P(t_class, t_name, t_body) \ + TEST_P(t_class, t_name) { \ + ARROW_SCOPED_TRACE(#t_class "_" #t_name "_" + std::get<1>(GetParam())); \ + t_body; \ + } + using testing::UnorderedElementsAreArray; namespace arrow { namespace compute { +bool is_temporal_primitive(Type::type type_id) { + switch (type_id) { + case Type::TIME32: + case Type::TIME64: + case Type::DATE32: + case Type::DATE64: + case Type::TIMESTAMP: + return true; + default: + return false; + } +} + +Result MakeBatchesFromNumString( + const std::shared_ptr& schema, + const std::vector& json_strings, int multiplicity = 1) { + FieldVector num_fields; + for (auto field : schema->fields()) { + num_fields.push_back( + is_base_binary_like(field->type()->id()) ? field->WithType(int64()) : field); + } + auto num_schema = + std::make_shared(num_fields, schema->endianness(), schema->metadata()); + BatchesWithSchema num_batches = + MakeBatchesFromString(num_schema, json_strings, multiplicity); + BatchesWithSchema batches; + batches.schema = schema; + int n_fields = schema->num_fields(); + for (auto num_batch : num_batches.batches) { + std::vector values; + for (int i = 0; i < n_fields; i++) { + auto type = schema->field(i)->type(); + if (is_base_binary_like(type->id())) { + // casting to string first enables casting to binary + ARROW_ASSIGN_OR_RAISE(Datum as_string, Cast(num_batch.values[i], utf8())); + ARROW_ASSIGN_OR_RAISE(Datum as_type, Cast(as_string, type)); + values.push_back(as_type); + } else { + values.push_back(num_batch.values[i]); + } + } + ExecBatch batch(values, num_batch.length); + batches.batches.push_back(batch); + } + return batches; +} + +void BuildNullArray(std::shared_ptr& empty, const std::shared_ptr& type, + int64_t length) { + ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); + ASSERT_OK(builder->Reserve(length)); + ASSERT_OK(builder->AppendNulls(length)); + ASSERT_OK(builder->Finish(&empty)); +} + +void BuildZeroPrimitiveArray(std::shared_ptr& empty, + const std::shared_ptr& type, int64_t length) { + ASSERT_OK_AND_ASSIGN(auto builder, MakeBuilder(type, default_memory_pool())); + ASSERT_OK(builder->Reserve(length)); + ASSERT_OK_AND_ASSIGN(auto scalar, MakeScalar(type, 0)); + ASSERT_OK(builder->AppendScalar(*scalar, length)); + ASSERT_OK(builder->Finish(&empty)); +} + +template +void BuildZeroBaseBinaryArray(std::shared_ptr& empty, int64_t length) { + Builder builder(default_memory_pool()); + ASSERT_OK(builder.Reserve(length)); + for (int64_t i = 0; i < length; i++) { + ASSERT_OK(builder.Append("0", /*length=*/1)); + } + ASSERT_OK(builder.Finish(&empty)); +} + +// mutates by copying from_key into to_key and changing from_key to zero +Result MutateByKey(BatchesWithSchema& batches, std::string from_key, + std::string to_key, bool replace_key = false, + bool null_key = false, bool remove_key = false) { + int from_index = batches.schema->GetFieldIndex(from_key); + int n_fields = batches.schema->num_fields(); + auto fields = batches.schema->fields(); + BatchesWithSchema new_batches; + if (remove_key) { + ARROW_ASSIGN_OR_RAISE(new_batches.schema, batches.schema->RemoveField(from_index)); + } else { + auto new_field = batches.schema->field(from_index)->WithName(to_key); + 
ARROW_ASSIGN_OR_RAISE(new_batches.schema, + replace_key ? batches.schema->SetField(from_index, new_field) + : batches.schema->AddField(from_index, new_field)); + } + for (const ExecBatch& batch : batches.batches) { + std::vector new_values; + for (int i = 0; i < n_fields; i++) { + const Datum& value = batch.values[i]; + if (i == from_index) { + if (remove_key) { + continue; + } + auto type = fields[i]->type(); + if (null_key) { + std::shared_ptr empty; + BuildNullArray(empty, type, batch.length); + new_values.push_back(empty); + } else if (is_primitive(type->id())) { + std::shared_ptr empty; + BuildZeroPrimitiveArray(empty, type, batch.length); + new_values.push_back(empty); + } else if (is_base_binary_like(type->id())) { + std::shared_ptr empty; + switch (type->id()) { + case Type::STRING: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::LARGE_STRING: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::BINARY: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + case Type::LARGE_BINARY: + BuildZeroBaseBinaryArray(empty, batch.length); + break; + default: + DCHECK(false); + break; + } + new_values.push_back(empty); + } else { + ARROW_ASSIGN_OR_RAISE(auto sub, Subtract(value, value)); + new_values.push_back(sub); + } + if (replace_key) { + continue; + } + } + new_values.push_back(value); + } + new_batches.batches.emplace_back(new_values, batch.length); + } + return new_batches; +} + +// code generation for the by_key types supported by AsofJoinNodeOptions constructors +// which cannot be directly done using templates because of failure to deduce the template +// argument for an invocation with a string- or initializer_list-typed keys-argument +#define EXPAND_BY_KEY_TYPE(macro) \ + macro(const FieldRef); \ + macro(std::vector); \ + macro(std::initializer_list); + void CheckRunOutput(const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, const BatchesWithSchema& r1_batches, - const BatchesWithSchema& exp_batches, const FieldRef time, - const FieldRef keys, const int64_t tolerance) { + const BatchesWithSchema& exp_batches, + const AsofJoinNodeOptions join_options) { auto exec_ctx = arrow::internal::make_unique(default_memory_pool(), nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsofJoinNodeOptions join_options(time, keys, tolerance); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ @@ -64,6 +228,9 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, .AddToPlan(plan.get())); ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); + for (auto batch : res) { + ASSERT_EQ(exp_batches.schema->num_fields(), batch.values.size()); + } ASSERT_OK_AND_ASSIGN(auto exp_table, TableFromExecBatches(exp_batches.schema, exp_batches.batches)); @@ -74,237 +241,783 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, /*same_chunk_layout=*/true, /*flatten=*/true); } -void DoRunBasicTest(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_data, int64_t tolerance) { - auto l_schema = - schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}); - auto r0_schema = - schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())}); - auto r1_schema = - schema({field("time", int64()), field("key", int32()), field("r1_v0", float32())}); - - auto exp_schema = schema({ - field("time", int64()), - field("key", int32()), - field("l_v0", float64()), - 
field("r0_v0", float64()), - field("r1_v0", float32()), - }); - - // Test three table join - BatchesWithSchema l_batches, r0_batches, r1_batches, exp_batches; - l_batches = MakeBatchesFromString(l_schema, l_data); - r0_batches = MakeBatchesFromString(r0_schema, r0_data); - r1_batches = MakeBatchesFromString(r1_schema, r1_data); - exp_batches = MakeBatchesFromString(exp_schema, exp_data); - CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key", - tolerance); -} - -void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, - const std::shared_ptr& r_schema) { - BatchesWithSchema l_batches = MakeBatchesFromString(l_schema, {R"([])"}); - BatchesWithSchema r_batches = MakeBatchesFromString(r_schema, {R"([])"}); - +#define CHECK_RUN_OUTPUT(by_key_type) \ + void CheckRunOutput( \ + const BatchesWithSchema& l_batches, const BatchesWithSchema& r0_batches, \ + const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \ + const FieldRef time, by_key_type key, const int64_t tolerance) { \ + CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \ + AsofJoinNodeOptions(time, {key}, tolerance)); \ + } + +EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT) + +void DoInvalidPlanTest(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str, + bool fail_on_plan_creation = false) { ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsofJoinNodeOptions join_options("time", "key", 0); Declaration join{"asofjoin", join_options}; join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{l_batches.schema, l_batches.gen(false, false)}}); join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); - ASSERT_RAISES(Invalid, join.AddToPlan(plan.get())); + if (fail_on_plan_creation) { + AsyncGenerator> sink_gen; + ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr(expected_error_str), + StartAndCollect(plan.get(), sink_gen)); + } else { + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, ::testing::HasSubstr(expected_error_str), + join.AddToPlan(plan.get())); + } +} + +void DoRunInvalidPlanTest(const BatchesWithSchema& l_batches, + const BatchesWithSchema& r_batches, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, + const AsofJoinNodeOptions& join_options, + const std::string& expected_error_str) { + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {R"([])"})); + ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {R"([])"})); + + return DoRunInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str); +} + +void DoRunInvalidPlanTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema, int64_t tolerance, + const std::string& expected_error_str) { + DoRunInvalidPlanTest(l_schema, r_schema, + AsofJoinNodeOptions("time", {"key"}, tolerance), + expected_error_str); +} + +void DoRunInvalidTypeTest(const std::shared_ptr& l_schema, + const std::shared_ptr& r_schema) { + DoRunInvalidPlanTest(l_schema, r_schema, 0, "Unsupported type for "); +} + +void DoRunInvalidToleranceTest(const std::shared_ptr& 
l_schema,
+                               const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema, -1,
+                       "AsOfJoin tolerance must be non-negative but is ");
+}
+
+void DoRunMissingKeysTest(const std::shared_ptr& l_schema,
+                          const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : No match");
+}
+
+void DoRunMissingOnKeyTest(const std::shared_ptr& l_schema,
+                           const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema,
+                       AsofJoinNodeOptions("invalid_time", {"key"}, 0),
+                       "Bad join key on table : No match");
+}
+
+void DoRunMissingByKeyTest(const std::shared_ptr& l_schema,
+                           const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema,
+                       AsofJoinNodeOptions("time", {"invalid_key"}, 0),
+                       "Bad join key on table : No match");
+}
+
+void DoRunNestedOnKeyTest(const std::shared_ptr& l_schema,
+                          const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0),
+                       "Bad join key on table : No match");
+}
+
+void DoRunNestedByKeyTest(const std::shared_ptr& l_schema,
+                          const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema,
+                       AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0),
+                       "Bad join key on table : No match");
+}
+
+void DoRunAmbiguousOnKeyTest(const std::shared_ptr& l_schema,
+                             const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches");
+}
+
+void DoRunAmbiguousByKeyTest(const std::shared_ptr& l_schema,
+                             const std::shared_ptr& r_schema) {
+  DoRunInvalidPlanTest(l_schema, r_schema, 0, "Bad join key on table : Multiple matches");
+}
+
+// Gets a batch for testing as a JSON string.
+// The batch will have n_rows rows and n_cols columns, the first column being the
+// on-field. If unordered is true then the first column will be out-of-order.
+std::string GetTestBatchAsJsonString(int n_rows, int n_cols, bool unordered = false) {
+  int order_mask = unordered ? 1 : 0;
+  std::stringstream s;
+  s << '[';
+  for (int i = 0; i < n_rows; i++) {
+    if (i > 0) {
+      s << ", ";
+    }
+    s << '[';
+    for (int j = 0; j < n_cols; j++) {
+      if (j > 0) {
+        s << ", " << j;
+      } else if (j < 2) {
+        s << (i ^ order_mask);
+      } else {
+        s << i;
+      }
+    }
+    s << ']';
+  }
+  s << ']';
+  return s.str();
+}
+
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+                            const std::shared_ptr& l_schema,
+                            const std::shared_ptr& r_schema,
+                            const AsofJoinNodeOptions& join_options,
+                            const std::string& expected_error_str) {
+  ASSERT_TRUE(l_unordered || r_unordered);
+  int n_rows = 5;
+  auto l_str = GetTestBatchAsJsonString(n_rows, l_schema->num_fields(), l_unordered);
+  auto r_str = GetTestBatchAsJsonString(n_rows, r_schema->num_fields(), r_unordered);
+  ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, {l_str}));
+  ASSERT_OK_AND_ASSIGN(auto r_batches, MakeBatchesFromNumString(r_schema, {r_str}));
+
+  return DoInvalidPlanTest(l_batches, r_batches, join_options, expected_error_str,
+                           /*fail_on_plan_creation=*/true);
 }
 
+void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
+                            const std::shared_ptr& l_schema,
+                            const std::shared_ptr& r_schema) {
+  DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema,
+                         AsofJoinNodeOptions("time", {"key"}, 1000),
+                         "out-of-order on-key values");
+}
+
+struct BasicTestTypes {
+  std::shared_ptr time, key, l_val, r0_val, r1_val;
+};
+
+struct BasicTest {
+  BasicTest(const std::vector& l_data,
+            const std::vector& r0_data,
+            const std::vector& r1_data,
+            const std::vector& exp_nokey_data,
+            const std::vector& exp_emptykey_data,
+            const std::vector& exp_data, int64_t tolerance)
+      : l_data(std::move(l_data)),
+        r0_data(std::move(r0_data)),
+        r1_data(std::move(r1_data)),
+        exp_nokey_data(std::move(exp_nokey_data)),
+        exp_emptykey_data(std::move(exp_emptykey_data)),
+        exp_data(std::move(exp_data)),
+        tolerance(tolerance) {}
+
+  static inline void check_init(const std::vector>& types) {
+    ASSERT_NE(0, types.size());
+  }
+
+  template
+  static inline std::vector> init_types(
+      const std::vector>& all_types, TypeCond type_cond) {
+    std::vector> types;
+    for (auto type : all_types) {
+      if (type_cond(type)) {
+        types.push_back(type);
+      }
+    }
+    check_init(types);
+    return types;
+  }
+
+  void RunSingleByKey() {
+    using B = BatchesWithSchema;
+    RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches,
+                      B exp_emptykey_batches, B exp_batches) {
+      CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time", "key",
+                     tolerance);
+    });
+  }
+  static void DoSingleByKey(BasicTest& basic_tests) { basic_tests.RunSingleByKey(); }
+  void RunDoubleByKey() {
+    using B = BatchesWithSchema;
+    RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches,
+                      B exp_emptykey_batches, B exp_batches) {
+      CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, "time",
+                     {"key", "key"}, tolerance);
+    });
+  }
+  static void DoDoubleByKey(BasicTest& basic_tests) { basic_tests.RunDoubleByKey(); }
+  void RunMutateByKey() {
+    using B = BatchesWithSchema;
+    RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches,
+                      B exp_emptykey_batches, B exp_batches) {
+      ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2"));
+      ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2"));
+      ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2"));
+      ASSERT_OK_AND_ASSIGN(exp_batches, MutateByKey(exp_batches, "key", "key2"));
+      CheckRunOutput(l_batches,
r0_batches, r1_batches, exp_batches, "time",
+                     {"key", "key2"}, tolerance);
+    });
+  }
+  static void DoMutateByKey(BasicTest& basic_tests) { basic_tests.RunMutateByKey(); }
+  void RunMutateNoKey() {
+    using B = BatchesWithSchema;
+    RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches,
+                      B exp_emptykey_batches, B exp_batches) {
+      ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true));
+      ASSERT_OK_AND_ASSIGN(r0_batches, MutateByKey(r0_batches, "key", "key2", true));
+      ASSERT_OK_AND_ASSIGN(r1_batches, MutateByKey(r1_batches, "key", "key2", true));
+      ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
+                           MutateByKey(exp_nokey_batches, "key", "key2", true));
+      CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches, "time", "key2",
+                     tolerance);
+    });
+  }
+  static void DoMutateNoKey(BasicTest& basic_tests) { basic_tests.RunMutateNoKey(); }
+  void RunMutateNullKey() {
+    using B = BatchesWithSchema;
+    RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches,
+                      B exp_emptykey_batches, B exp_batches) {
+      ASSERT_OK_AND_ASSIGN(l_batches, MutateByKey(l_batches, "key", "key2", true, true));
+      ASSERT_OK_AND_ASSIGN(r0_batches,
+                           MutateByKey(r0_batches, "key", "key2", true, true));
+      ASSERT_OK_AND_ASSIGN(r1_batches,
+                           MutateByKey(r1_batches, "key", "key2", true, true));
+      ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
+                           MutateByKey(exp_nokey_batches, "key", "key2", true, true));
+      CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
+                     AsofJoinNodeOptions("time", {"key2"}, tolerance));
+    });
+  }
+  static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); }
+  void RunMutateEmptyKey() {
+    using B = BatchesWithSchema;
+    RunBatches([this](B l_batches, B r0_batches, B r1_batches, B exp_nokey_batches,
+                      B exp_emptykey_batches, B exp_batches) {
+      ASSERT_OK_AND_ASSIGN(r0_batches,
+                           MutateByKey(r0_batches, "key", "key", false, false, true));
+      ASSERT_OK_AND_ASSIGN(r1_batches,
+                           MutateByKey(r1_batches, "key", "key", false, false, true));
+      CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches,
+                     AsofJoinNodeOptions("time", {}, tolerance));
+    });
+  }
+  static void DoMutateEmptyKey(BasicTest& basic_tests) {
+    basic_tests.RunMutateEmptyKey();
+  }
+  template
+  void RunBatches(BatchesRunner batches_runner) {
+    std::vector> all_types = {
+        utf8(),
+        large_utf8(),
+        binary(),
+        large_binary(),
+        int8(),
+        int16(),
+        int32(),
+        int64(),
+        uint8(),
+        uint16(),
+        uint32(),
+        uint64(),
+        date32(),
+        date64(),
+        time32(TimeUnit::MILLI),
+        time32(TimeUnit::SECOND),
+        time64(TimeUnit::NANO),
+        time64(TimeUnit::MICRO),
+        timestamp(TimeUnit::NANO, "UTC"),
+        timestamp(TimeUnit::MICRO, "UTC"),
+        timestamp(TimeUnit::MILLI, "UTC"),
+        timestamp(TimeUnit::SECOND, "UTC"),
+        float32(),
+        float64()};
+    using T = const std::shared_ptr;
+    // byte_width > 1 below ensures that the tested data fits in the type
+    auto time_types = init_types(
+        all_types, [](T& t) { return t->byte_width() > 1 && !is_floating(t->id()); });
+    auto key_types = init_types(all_types, [](T& t) { return !is_floating(t->id()); });
+    auto l_types = init_types(all_types, [](T& t) { return true; });
+    auto r0_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; });
+    auto r1_types = init_types(all_types, [](T& t) { return t->byte_width() > 1; });
+
+    // sample a limited number of type combinations to keep the running time reasonable;
+    // the scoped traces below help reproduce a test failure, should one happen
+    auto start_time =
std::chrono::system_clock::now(); + auto seed = start_time.time_since_epoch().count(); + ARROW_SCOPED_TRACE("Types seed: ", seed); + std::default_random_engine engine(static_cast(seed)); + std::uniform_int_distribution time_distribution(0, time_types.size() - 1); + std::uniform_int_distribution key_distribution(0, key_types.size() - 1); + std::uniform_int_distribution l_distribution(0, l_types.size() - 1); + std::uniform_int_distribution r0_distribution(0, r0_types.size() - 1); + std::uniform_int_distribution r1_distribution(0, r1_types.size() - 1); + + for (int i = 0; i < 1000; i++) { + auto time_type = time_types[time_distribution(engine)]; + ARROW_SCOPED_TRACE("Time type: ", *time_type); + auto key_type = key_types[key_distribution(engine)]; + ARROW_SCOPED_TRACE("Key type: ", *key_type); + auto l_type = l_types[l_distribution(engine)]; + ARROW_SCOPED_TRACE("Left type: ", *l_type); + auto r0_type = r0_types[r0_distribution(engine)]; + ARROW_SCOPED_TRACE("Right-0 type: ", *r0_type); + auto r1_type = r1_types[r1_distribution(engine)]; + ARROW_SCOPED_TRACE("Right-1 type: ", *r1_type); + + RunTypes({time_type, key_type, l_type, r0_type, r1_type}, batches_runner); + + auto end_time = std::chrono::system_clock::now(); + std::chrono::duration diff = end_time - start_time; + if (diff.count() > 2) { + // this normally happens on slow CI systems, but is fine + break; + } + } + } + template + void RunTypes(BasicTestTypes basic_test_types, BatchesRunner batches_runner) { + const BasicTestTypes& b = basic_test_types; + auto l_schema = + schema({field("time", b.time), field("key", b.key), field("l_v0", b.l_val)}); + auto r0_schema = + schema({field("time", b.time), field("key", b.key), field("r0_v0", b.r0_val)}); + auto r1_schema = + schema({field("time", b.time), field("key", b.key), field("r1_v0", b.r1_val)}); + + auto exp_schema = schema({ + field("time", b.time), + field("key", b.key), + field("l_v0", b.l_val), + field("r0_v0", b.r0_val), + field("r1_v0", b.r1_val), + }); + + // Test three table join + ASSERT_OK_AND_ASSIGN(auto l_batches, MakeBatchesFromNumString(l_schema, l_data)); + ASSERT_OK_AND_ASSIGN(auto r0_batches, MakeBatchesFromNumString(r0_schema, r0_data)); + ASSERT_OK_AND_ASSIGN(auto r1_batches, MakeBatchesFromNumString(r1_schema, r1_data)); + ASSERT_OK_AND_ASSIGN(auto exp_nokey_batches, + MakeBatchesFromNumString(exp_schema, exp_nokey_data)); + ASSERT_OK_AND_ASSIGN(auto exp_emptykey_batches, + MakeBatchesFromNumString(exp_schema, exp_emptykey_data)); + ASSERT_OK_AND_ASSIGN(auto exp_batches, + MakeBatchesFromNumString(exp_schema, exp_data)); + batches_runner(l_batches, r0_batches, r1_batches, exp_nokey_batches, + exp_emptykey_batches, exp_batches); + } + + std::vector l_data; + std::vector r0_data; + std::vector r1_data; + std::vector exp_nokey_data; + std::vector exp_emptykey_data; + std::vector exp_data; + int64_t tolerance; +}; + +using AsofJoinBasicParams = std::tuple, std::string>; + +struct AsofJoinBasicTest : public testing::TestWithParam {}; + class AsofJoinTest : public testing::Test {}; -TEST(AsofJoinTest, TestBasic1) { +BasicTest GetBasicTest1() { // Single key, single batch - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0], [1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])"}, - /*r1*/ {R"([[1000, 1, 101.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, null], [1000, 1, 2.0, 11.0, 101.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1], [1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])"}, + /*r1*/ {R"([[1000, 1, 101]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, null], [1000, 0, 2, 11, 101]])"}, 
+ /*exp_emptykey*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, + /*exp*/ {R"([[0, 1, 1, 11, null], [1000, 1, 2, 11, 101]])"}, 1000); } -TEST(AsofJoinTest, TestBasic2) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic1, { + BasicTest basic_test = GetBasicTest1(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest2() { // Single key, multiple batches - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0]])", R"([[1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0]])", R"([[1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11]])", R"([[1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101]])", R"([[1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } -TEST(AsofJoinTest, TestBasic3) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic2, { + BasicTest basic_test = GetBasicTest2(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest3() { // Single key, multiple left batches, single right batches - DoRunBasicTest( - /*l*/ {R"([[0, 1, 1.0]])", R"([[1000, 1, 2.0]])"}, - /*r0*/ {R"([[0, 1, 11.0], [1000, 1, 12.0]])"}, - /*r1*/ {R"([[0, 1, 101.0], [1000, 1, 102.0]])"}, - /*exp*/ {R"([[0, 1, 1.0, 11.0, 101.0], [1000, 1, 2.0, 12.0, 102.0]])"}, 1000); + return BasicTest( + /*l*/ {R"([[0, 1, 1]])", R"([[1000, 1, 2]])"}, + /*r0*/ {R"([[0, 1, 11], [1000, 1, 12]])"}, + /*r1*/ {R"([[0, 1, 101], [1000, 1, 102]])"}, + /*exp_nokey*/ {R"([[0, 0, 1, 11, 101], [1000, 0, 2, 12, 102]])"}, + /*exp_emptykey*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, + /*exp*/ {R"([[0, 1, 1, 11, 101], [1000, 1, 2, 12, 102]])"}, 1000); } -TEST(AsofJoinTest, TestBasic4) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic3, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic3_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest3(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest4() { // Multi key, multiple batches, misaligned batches - DoRunBasicTest( + return BasicTest( /*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], 
[0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, 1001.0], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, 1001], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); } -TEST(AsofJoinTest, TestBasic5) { +TRACED_TEST_P(AsofJoinBasicTest, TestBasic4, { + BasicTest basic_test = GetBasicTest4(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest5() { // Multi key, multiple batches, misaligned batches, smaller tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, 11.0, 101.0], [1000, 2, 22.0, 31.0, null], [1500, 1, 3.0, 12.0, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 500); -} - -TEST(AsofJoinTest, TestBasic6) { + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, 11, 101], [1000, 2, 22, 31, null], [1500, 1, 3, 12, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 500); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic5, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic5_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest5(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetBasicTest6() { // Multi key, multiple batches, misaligned batches, zero tolerance - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, 11.0, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, null], [1500, 1, 3.0, null, null], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, null, null]])"}, - 0); -} - -TEST(AsofJoinTest, TestEmpty1) { + 
return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, 11, 1001], [0, 0, 21, 11, 1001], [500, 0, 2, 31, 101], [1000, 0, 22, 12, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, 11, 1001], [0, 2, 21, 11, 1001], [500, 1, 2, 31, 101], [1000, 2, 22, 12, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, 11, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, null], [1500, 1, 3, null, null], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, null, null]])"}, + 0); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestBasic6, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestBasic6_" + std::get<1>(GetParam())); + BasicTest basic_test = GetBasicTest6(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest1() { // Empty left batch - DoRunBasicTest(/*l*/ - {R"([])", R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty2) { + return BasicTest(/*l*/ + {R"([])", R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, 1000); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty1, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty1_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest1(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest2() { // Empty left input - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([[0, 1, 11.0], [500, 2, 31.0], [1000, 1, 12.0]])", - R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([])"}, 1000); -} - -TEST(AsofJoinTest, TestEmpty3) { + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([[0, 1, 11], [500, 2, 31], [1000, 1, 12]])", + R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp_emptykey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); +} + +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty2, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty2_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest2(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) 
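// ---------------------------------------------------------------------------
// Editor's note (an illustrative sketch, not part of this patch): the expected
// rows in the tests above follow the as-of matching rule documented on
// AsofJoinNodeOptions, i.e. a right-table row matches a left-table row iff
//   left_on - tolerance <= right_on <= left_on.
// A minimal, self-contained restatement of that predicate:
#include <cstdint>
constexpr bool AsofMatches(int64_t left_on, int64_t right_on, int64_t tolerance) {
  return left_on - tolerance <= right_on && right_on <= left_on;
}
// For example, with tolerance 500 (as in GetBasicTest5), a left row at time 1000
// picks up a right row at time 500 but not one at time 0, and right rows in the
// future never match:
static_assert(AsofMatches(1000, 500, 500), "right row exactly at tolerance matches");
static_assert(!AsofMatches(1000, 0, 500), "right row older than tolerance is skipped");
static_assert(!AsofMatches(1000, 1500, 500), "future right rows never match");
// ---------------------------------------------------------------------------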
+ +BasicTest GetEmptyTest3() { // Empty right batch - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([])", R"([[1500, 2, 32.0], [2000, 1, 13.0], [2500, 2, 33.0]])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, 32.0, 1002.0]])", - R"([[2000, 1, 4.0, 13.0, 103.0], [2000, 2, 24.0, 32.0, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty4) { - // Empty right input - DoRunBasicTest(/*l*/ - {R"([[0, 1, 1.0], [0, 2, 21.0], [500, 1, 2.0], [1000, 2, 22.0], [1500, 1, 3.0], [1500, 2, 23.0]])", - R"([[2000, 1, 4.0], [2000, 2, 24.0]])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([[0, 2, 1001.0], [500, 1, 101.0]])", - R"([[1000, 1, 102.0], [1500, 2, 1002.0], [2000, 1, 103.0]])"}, - /*exp*/ - {R"([[0, 1, 1.0, null, null], [0, 2, 21.0, null, 1001.0], [500, 1, 2.0, null, 101.0], [1000, 2, 22.0, null, 1001.0], [1500, 1, 3.0, null, 102.0], [1500, 2, 23.0, null, 1002.0]])", - R"([[2000, 1, 4.0, null, 103.0], [2000, 2, 24.0, null, 1002.0]])"}, - 1000); -} - -TEST(AsofJoinTest, TestEmpty5) { - // All empty - DoRunBasicTest(/*l*/ - {R"([])"}, - /*r0*/ - {R"([])"}, - /*r1*/ - {R"([])"}, - /*exp*/ - {R"([])"}, 1000); + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])", R"([[1500, 2, 32], [2000, 1, 13], [2500, 2, 33]])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, 32, 1002], [1500, 0, 23, 32, 1002]])", + R"([[2000, 0, 4, 13, 103], [2000, 0, 24, 13, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, 32, 1002], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 13, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, 32, 1002]])", + R"([[2000, 1, 4, 13, 103], [2000, 2, 24, 32, 1002]])"}, + 1000); } -TEST(AsofJoinTest, TestUnsupportedOntype) { - DoRunInvalidTypeTest( - schema({field("time", utf8()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", utf8()), field("key", int32()), field("r0_v0", float32())})); +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty3, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty3_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest3(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest4() { + // Empty right input + return BasicTest(/*l*/ + {R"([[0, 1, 1], [0, 2, 21], [500, 1, 2], [1000, 2, 22], [1500, 1, 3], [1500, 2, 23]])", + R"([[2000, 1, 4], [2000, 2, 24]])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([[0, 2, 1001], [500, 1, 101]])", + R"([[1000, 1, 102], [1500, 2, 1002], [2000, 1, 103]])"}, + /*exp_nokey*/ + {R"([[0, 0, 1, null, 1001], [0, 0, 21, null, 1001], [500, 0, 2, null, 101], [1000, 0, 22, null, 102], [1500, 0, 3, null, 1002], [1500, 0, 23, null, 1002]])", + R"([[2000, 0, 4, null, 103], [2000, 
0, 24, null, 103]])"}, + /*exp_emptykey*/ + {R"([[0, 1, 1, null, 1001], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 102], [1500, 1, 3, null, 1002], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 103]])"}, + /*exp*/ + {R"([[0, 1, 1, null, null], [0, 2, 21, null, 1001], [500, 1, 2, null, 101], [1000, 2, 22, null, 1001], [1500, 1, 3, null, 102], [1500, 2, 23, null, 1002]])", + R"([[2000, 1, 4, null, 103], [2000, 2, 24, null, 1002]])"}, + 1000); } -TEST(AsofJoinTest, TestUnsupportedBytype) { - DoRunInvalidTypeTest( - schema({field("time", int64()), field("key", utf8()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", utf8()), field("r0_v0", float32())})); +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty4, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty4_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest4(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +BasicTest GetEmptyTest5() { + // All empty + return BasicTest(/*l*/ + {R"([])"}, + /*r0*/ + {R"([])"}, + /*r1*/ + {R"([])"}, + /*exp_nokey*/ + {R"([])"}, + /*exp_emptykey*/ + {R"([])"}, + /*exp*/ + {R"([])"}, 1000); } -TEST(AsofJoinTest, TestUnsupportedDatatype) { - // Utf8 is unsupported +TRACED_TEST_P(AsofJoinBasicTest, TestEmpty5, { + ARROW_SCOPED_TRACE("AsofJoinBasicTest_TestEmpty5_" + std::get<1>(GetParam())); + BasicTest basic_test = GetEmptyTest5(); + auto runner = std::get<0>(GetParam()); + runner(basic_test); +}) + +INSTANTIATE_TEST_SUITE_P( + AsofJoinNodeTest, AsofJoinBasicTest, + testing::Values(AsofJoinBasicParams(BasicTest::DoSingleByKey, "SingleByKey"), + AsofJoinBasicParams(BasicTest::DoDoubleByKey, "DoubleByKey"), + AsofJoinBasicParams(BasicTest::DoMutateByKey, "MutateByKey"), + AsofJoinBasicParams(BasicTest::DoMutateNoKey, "MutateNoKey"), + AsofJoinBasicParams(BasicTest::DoMutateNullKey, "MutateNullKey"), + AsofJoinBasicParams(BasicTest::DoMutateEmptyKey, "MutateEmptyKey"))); + +TRACED_TEST(AsofJoinTest, TestUnsupportedOntype, { + DoRunInvalidTypeTest(schema({field("time", list(int32())), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", list(int32())), field("key", int32()), + field("r0_v0", float32())})); +}) + +TRACED_TEST(AsofJoinTest, TestUnsupportedBytype, { + DoRunInvalidTypeTest(schema({field("time", int64()), field("key", list(int32())), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", list(int32())), + field("r0_v0", float32())})); +}) + +TRACED_TEST(AsofJoinTest, TestUnsupportedDatatype, { + // List is unsupported DoRunInvalidTypeTest( schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), - schema({field("time", int64()), field("key", int32()), field("r0_v0", utf8())})); -} + schema({field("time", int64()), field("key", int32()), + field("r0_v0", list(int32()))})); +}) -TEST(AsofJoinTest, TestMissingKeys) { - DoRunInvalidTypeTest( +TRACED_TEST(AsofJoinTest, TestMissingKeys, { + DoRunMissingKeysTest( schema({field("time1", int64()), field("key", int32()), field("l_v0", float64())}), schema( {field("time1", int64()), field("key", int32()), field("r0_v0", float64())})); - DoRunInvalidTypeTest( + DoRunMissingKeysTest( schema({field("time", int64()), field("key1", int32()), field("l_v0", float64())}), schema( {field("time", int64()), field("key1", int32()), field("r0_v0", float64())})); -} +}) + +TRACED_TEST(AsofJoinTest, TestUnsupportedTolerance, { + // Utf8 is unsupported + DoRunInvalidToleranceTest( + 
schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestMissingOnKey, { + DoRunMissingOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestMissingByKey, { + DoRunMissingByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestNestedOnKey, { + DoRunNestedOnKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestNestedByKey, { + DoRunNestedByKeyTest( + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestAmbiguousOnKey, { + DoRunAmbiguousOnKeyTest( + schema({field("time", int64()), field("time", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestAmbiguousByKey, { + DoRunAmbiguousByKeyTest( + schema({field("time", int64()), field("key", int64()), field("key", int32()), + field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestLeftUnorderedOnKey, { + DoRunUnorderedPlanTest( + /*l_unordered=*/true, /*r_unordered=*/false, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestRightUnorderedOnKey, { + DoRunUnorderedPlanTest( + /*l_unordered=*/false, /*r_unordered=*/true, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) + +TRACED_TEST(AsofJoinTest, TestUnorderedOnKey, { + DoRunUnorderedPlanTest( + /*l_unordered=*/true, /*r_unordered=*/true, + schema({field("time", int64()), field("key", int32()), field("l_v0", float64())}), + schema({field("time", int64()), field("key", int32()), field("r0_v0", float64())})); +}) } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/hash_join.cc b/cpp/src/arrow/compute/exec/hash_join.cc index 5cf66b3d09e..da1710fe08d 100644 --- a/cpp/src/arrow/compute/exec/hash_join.cc +++ b/cpp/src/arrow/compute/exec/hash_join.cc @@ -26,7 +26,6 @@ #include #include "arrow/compute/exec/hash_join_dict.h" -#include "arrow/compute/exec/key_hash.h" #include "arrow/compute/exec/task_util.h" #include "arrow/compute/kernels/row_encoder.h" #include "arrow/compute/row/encode_internal.h" diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index a8e8c1ee230..e0172bff7f7 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -397,23 +397,25 @@ class ARROW_EXPORT HashJoinNodeOptions : public ExecNodeOptions { /// This node will output one row for each row in the left table. 
 class ARROW_EXPORT AsofJoinNodeOptions : public ExecNodeOptions {
  public:
-  AsofJoinNodeOptions(FieldRef on_key, FieldRef by_key, int64_t tolerance)
-      : on_key(std::move(on_key)), by_key(std::move(by_key)), tolerance(tolerance) {}
+  AsofJoinNodeOptions(FieldRef on_key, std::vector by_key, int64_t tolerance)
+      : on_key(std::move(on_key)), by_key(by_key), tolerance(tolerance) {}
 
-  /// \brief "on" key for the join. Each
+  /// \brief "on" key for the join.
   ///
-  /// All inputs tables must be sorted by the "on" key. Inexact
-  /// match is used on the "on" key. i.e., a row is considiered match iff
+  /// All input tables must be sorted by the "on" key. Must be a single field of a common
+  /// type. Inexact match is used on the "on" key, i.e., a row is considered a match iff
   /// left_on - tolerance <= right_on <= left_on.
-  /// Currently, "on" key must be an int64 field
+  /// Currently, the "on" key must be of an integer, date, or timestamp type.
   FieldRef on_key;
 
   /// \brief "by" key for the join.
   ///
   /// All input tables must have the "by" key. Exact equality
   /// is used for the "by" key.
-  /// Currently, the "by" key must be an int32 field
-  FieldRef by_key;
-  /// Tolerance for inexact "on" key matching
+  /// Currently, the "by" key must be of an integer, date, timestamp, or base-binary type.
+  std::vector by_key;
+  /// \brief Tolerance for inexact "on" key matching. Must be non-negative.
+  ///
+  /// The tolerance is interpreted in the same units as the "on" key.
   int64_t tolerance;
 };
 
diff --git a/cpp/src/arrow/compute/light_array.cc b/cpp/src/arrow/compute/light_array.cc
index a337d4f999e..caa392319b7 100644
--- a/cpp/src/arrow/compute/light_array.cc
+++ b/cpp/src/arrow/compute/light_array.cc
@@ -147,6 +147,12 @@ Result ColumnArrayFromArrayData(
     const std::shared_ptr& array_data, int64_t start_row, int64_t num_rows) {
   ARROW_ASSIGN_OR_RAISE(KeyColumnMetadata metadata,
                         ColumnMetadataFromDataType(array_data->type));
+  return ColumnArrayFromArrayDataAndMetadata(array_data, metadata, start_row, num_rows);
+}
+
+KeyColumnArray ColumnArrayFromArrayDataAndMetadata(
+    const std::shared_ptr& array_data, const KeyColumnMetadata& metadata,
+    int64_t start_row, int64_t num_rows) {
   KeyColumnArray column_array = KeyColumnArray(
       metadata, array_data->offset + start_row + num_rows,
       array_data->buffers[0] != NULLPTR ?
array_data->buffers[0]->data() : nullptr, diff --git a/cpp/src/arrow/compute/light_array.h b/cpp/src/arrow/compute/light_array.h index 0620f6d3eb1..389b63cca41 100644 --- a/cpp/src/arrow/compute/light_array.h +++ b/cpp/src/arrow/compute/light_array.h @@ -135,7 +135,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type uint32_t* mutable_offsets() { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint32_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t)); return reinterpret_cast(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the offsets buffer @@ -143,7 +143,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a varbinary type const uint32_t* offsets() const { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint32_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint32_t)); return reinterpret_cast(data(kFixedLengthBuffer)); } /// \brief Return a mutable version of the large-offsets buffer @@ -151,7 +151,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a large varbinary type uint64_t* mutable_large_offsets() { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t)); return reinterpret_cast(mutable_data(kFixedLengthBuffer)); } /// \brief Return a read-only version of the large-offsets buffer @@ -159,7 +159,7 @@ class ARROW_EXPORT KeyColumnArray { /// Only valid if this is a view into a large varbinary type const uint64_t* large_offsets() const { DCHECK(!metadata_.is_fixed_length); - DCHECK(metadata_.fixed_length == sizeof(uint64_t)); + DCHECK_EQ(metadata_.fixed_length, sizeof(uint64_t)); return reinterpret_cast(data(kFixedLengthBuffer)); } /// \brief Return the type metadata @@ -205,6 +205,17 @@ ARROW_EXPORT Result ColumnMetadataFromDataType( ARROW_EXPORT Result ColumnArrayFromArrayData( const std::shared_ptr& array_data, int64_t start_row, int64_t num_rows); +/// \brief Create KeyColumnArray from ArrayData and KeyColumnMetadata +/// +/// If `type` is a dictionary type then this will return the KeyColumnArray for +/// the indices array +/// +/// The caller should ensure this is only called on "key" columns. 
+/// \see ColumnMetadataFromDataType for details +ARROW_EXPORT KeyColumnArray ColumnArrayFromArrayDataAndMetadata( + const std::shared_ptr& array_data, const KeyColumnMetadata& metadata, + int64_t start_row, int64_t num_rows); + /// \brief Create KeyColumnMetadata instances from an ExecBatch /// /// column_metadatas will be resized to fit diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 66da3cadcb5..e2b74e865fd 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -622,6 +622,13 @@ using is_fixed_size_binary_type = std::is_base_of; template using enable_if_fixed_size_binary = enable_if_t::value, R>; +// This includes primitive, dictionary, and fixed-size-binary types +template +using is_fixed_width_type = std::is_base_of; + +template +using enable_if_fixed_width_type = enable_if_t::value, R>; + template using is_binary_like_type = std::integral_constant::value && From f184255cbb9bf911ea2a04910f711e1a924b12b8 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 9 Sep 2022 17:18:18 -0400 Subject: [PATCH 022/133] ARROW-17627: [Go][Parquet] Forward schema metadata to file without StoreSchema (#14087) Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/parquet/pqarrow/file_reader_test.go | 25 +++++++++++++++++++++++++ go/parquet/pqarrow/file_writer.go | 8 ++++---- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/go/parquet/pqarrow/file_reader_test.go b/go/parquet/pqarrow/file_reader_test.go index 416bf8169b0..4011f5b7093 100644 --- a/go/parquet/pqarrow/file_reader_test.go +++ b/go/parquet/pqarrow/file_reader_test.go @@ -191,3 +191,28 @@ func TestRecordReaderSerial(t *testing.T) { assert.Same(t, io.EOF, err) assert.Nil(t, rec) } + +func TestFileReaderWriterMetadata(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + tbl := makeDateTimeTypesTable(mem, true, true) + defer tbl.Release() + + meta := arrow.NewMetadata([]string{"foo", "bar"}, []string{"bar", "baz"}) + sc := arrow.NewSchema(tbl.Schema().Fields(), &meta) + + var buf bytes.Buffer + writer, err := pqarrow.NewFileWriter(sc, &buf, nil, pqarrow.NewArrowWriterProperties(pqarrow.WithAllocator(mem))) + require.NoError(t, err) + require.NoError(t, writer.WriteTable(tbl, tbl.NumRows())) + require.NoError(t, writer.Close()) + + pf, err := file.NewParquetReader(bytes.NewReader(buf.Bytes()), file.WithReadProps(parquet.NewReaderProperties(mem))) + require.NoError(t, err) + defer pf.Close() + + kvMeta := pf.MetaData().KeyValueMetadata() + assert.Equal(t, []string{"foo", "bar"}, kvMeta.Keys()) + assert.Equal(t, []string{"bar", "baz"}, kvMeta.Values()) +} diff --git a/go/parquet/pqarrow/file_writer.go b/go/parquet/pqarrow/file_writer.go index 9a44b7f08f7..f24a5968aa4 100644 --- a/go/parquet/pqarrow/file_writer.go +++ b/go/parquet/pqarrow/file_writer.go @@ -73,11 +73,11 @@ func NewFileWriter(arrschema *arrow.Schema, w io.Writer, props *parquet.WriterPr } meta := make(metadata.KeyValueMetadata, 0) - if arrprops.storeSchema { - for i := 0; i < arrschema.Metadata().Len(); i++ { - meta.Append(arrschema.Metadata().Keys()[i], arrschema.Metadata().Values()[i]) - } + for i := 0; i < arrschema.Metadata().Len(); i++ { + meta.Append(arrschema.Metadata().Keys()[i], arrschema.Metadata().Values()[i]) + } + if arrprops.storeSchema { serializedSchema := flight.SerializeSchema(arrschema, props.Allocator()) meta.Append("ARROW:schema", base64.StdEncoding.EncodeToString(serializedSchema)) } From b8fac31ba1b38924de05146a9f128aad624d8e0b Mon Sep 17 
00:00:00 2001 From: Yaron Gvili Date: Sat, 10 Sep 2022 17:09:30 -0400 Subject: [PATCH 023/133] remove ordering --- cpp/src/arrow/compute/exec/plan_test.cc | 12 +++--- cpp/src/arrow/compute/exec/source_node.cc | 48 ++--------------------- 2 files changed, 9 insertions(+), 51 deletions(-) diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 99d4e675841..2fa4cde9d20 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -318,8 +318,8 @@ TEST(ExecPlanExecution, ArrayVectorSourceSink) { }) .AddToPlan(plan.get())); - ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); - ASSERT_EQ(res, exp_batches.batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); } } @@ -365,8 +365,8 @@ TEST(ExecPlanExecution, ExecBatchSourceSink) { }) .AddToPlan(plan.get())); - ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); - ASSERT_EQ(res, exp_batches.batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); } } @@ -412,8 +412,8 @@ TEST(ExecPlanExecution, RecordBatchSourceSink) { }) .AddToPlan(plan.get())); - ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); - ASSERT_EQ(res, exp_batches.batches); + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); } } diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 57ff991b7b4..73fc7e5ebda 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -324,39 +324,6 @@ struct SchemaSourceNode : public SourceNode { return Status::OK(); } - - template - static Iterator> MakeEnumeratedIterator(Iterator it) { - // TODO: Should Enumerated<>.index be changed to int64_t? 
Currently, this change - // causes dataset unit-test failures - using index_t = decltype(Enumerated{}.index); - struct { - index_t index = 0; - Enumerated operator()(const Item& item) { - return Enumerated{item, index++, false}; - } - } enumerator; - return MakeMapIterator(std::move(enumerator), std::move(it)); - } - - template - static arrow::AsyncGenerator MakeUnenumeratedGenerator( - const arrow::AsyncGenerator>& enum_gen) { - using Enum = Enumerated; - return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; }); - } - - template - static arrow::AsyncGenerator MakeOrderedGenerator( - const arrow::AsyncGenerator>& unordered_gen) { - using Enum = Enumerated; - auto enum_gen = MakeSequencingGenerator( - unordered_gen, - /*compare=*/[](const Enum& a, const Enum& b) { return a.index > b.index; }, - /*is_next=*/[](const Enum& a, const Enum& b) { return a.index + 1 == b.index; }, - /*initial_value=*/Enum{{}, 0}); - return MakeUnenumeratedGenerator(enum_gen); - } }; struct RecordBatchSourceNode @@ -384,10 +351,7 @@ struct RecordBatchSourceNode return util::optional(ExecBatch(*batch)); }; auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it)); - auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it)); - ARROW_ASSIGN_OR_RAISE(auto enum_gen, - MakeBackgroundGenerator(std::move(enum_it), io_executor)); - return MakeUnenumeratedGenerator(std::move(enum_gen)); + return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor); } static const char kKindName[]; @@ -417,10 +381,7 @@ struct ExecBatchSourceNode return batch == NULLPTR ? util::nullopt : util::optional(*batch); }; auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it)); - auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it)); - ARROW_ASSIGN_OR_RAISE(auto enum_gen, - MakeBackgroundGenerator(std::move(enum_it), io_executor)); - return MakeUnenumeratedGenerator(std::move(enum_gen)); + return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor); } static const char kKindName[]; @@ -458,10 +419,7 @@ struct ArrayVectorSourceNode ExecBatch(std::move(datumvec), (*arrayvec)[0]->length())); }; auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it)); - auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it)); - ARROW_ASSIGN_OR_RAISE(auto enum_gen, - MakeBackgroundGenerator(std::move(enum_it), io_executor)); - return MakeUnenumeratedGenerator(std::move(enum_gen)); + return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor); } static const char kKindName[]; From 8a9037d4cd459daef43a083e2e29a7833fb4c62b Mon Sep 17 00:00:00 2001 From: David Li Date: Sat, 10 Sep 2022 18:19:51 -0400 Subject: [PATCH 024/133] MINOR: [C++] Fix unity build error (#14083) ``` /arrow/cpp/src/arrow/compute/exec/bloom_filter.h:252:25: error: type attributes ignored after type is already defined [-Werror=attributes] 252 | enum class ARROW_EXPORT BloomFilterBuildStrategy { | ^~~~~~~~~~~~~~~~~~~~~~~~ ``` Authored-by: David Li Signed-off-by: David Li --- cpp/src/arrow/compute/exec/bloom_filter.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/exec/bloom_filter.h b/cpp/src/arrow/compute/exec/bloom_filter.h index 06920c6c14f..b0227e720d8 100644 --- a/cpp/src/arrow/compute/exec/bloom_filter.h +++ b/cpp/src/arrow/compute/exec/bloom_filter.h @@ -249,7 +249,7 @@ class ARROW_EXPORT BlockedBloomFilter { // b) It is preferred for small and medium size Bloom filters, because it skips extra // synchronization related steps from parallel 
variant (partitioning and taking locks). // -enum class ARROW_EXPORT BloomFilterBuildStrategy { +enum class BloomFilterBuildStrategy { SINGLE_THREADED = 0, PARALLEL = 1, }; From a63e60bad89b41266d155bc496eb383765702492 Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Sat, 10 Sep 2022 12:50:06 -1000 Subject: [PATCH 025/133] ARROW-17675: [C++] Modified the FileSource::Equals method to handle the case where buffer_ is null (#14085) Authored-by: Weston Pace Signed-off-by: David Li --- cpp/src/arrow/dataset/file_base.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index ff5d1e43eb3..81bf10abe30 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -93,8 +93,11 @@ bool FileSource::Equals(const FileSource& other) const { bool match_file_system = (filesystem_ == nullptr && other.filesystem_ == nullptr) || (filesystem_ && other.filesystem_ && filesystem_->Equals(other.filesystem_)); - return match_file_system && file_info_.Equals(other.file_info_) && - buffer_->Equals(*other.buffer_) && compression_ == other.compression_; + bool match_buffer = (buffer_ == nullptr && other.buffer_ == nullptr) || + ((buffer_ != nullptr && other.buffer_ != nullptr) && + (buffer_->address() == other.buffer_->address())); + return match_file_system && match_buffer && file_info_.Equals(other.file_info_) && + compression_ == other.compression_; } Future> FileFormat::CountRows( From e558ac7fc9443b6758dc02f29a3ce4c976516edb Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Sun, 11 Sep 2022 02:50:10 -0400 Subject: [PATCH 026/133] support RecordBatchReader maker --- cpp/src/arrow/compute/exec/options.h | 9 +++++---- cpp/src/arrow/record_batch.h | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index 1842fe7019a..65c020d5128 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -26,6 +26,7 @@ #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" #include "arrow/compute/exec/expression.h" +#include "arrow/record_batch.h" #include "arrow/result.h" #include "arrow/util/async_generator.h" #include "arrow/util/async_util.h" @@ -93,6 +94,10 @@ class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions { ItMaker it_maker; }; +/// \brief An extended Source node which accepts a schema and array-vectors +using ArrayVectorIteratorMaker = std::function>()>; +using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions; + using ExecBatchIteratorMaker = std::function>()>; /// \brief An extended Source node which accepts a schema and exec-batches using ExecBatchSourceNodeOptions = SchemaSourceNodeOptions; @@ -101,10 +106,6 @@ using RecordBatchIteratorMaker = std::function; -/// \brief An extended Source node which accepts a schema and array-vectors -using ArrayVectorIteratorMaker = std::function>()>; -using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions; - /// \brief Make a node which excludes some rows from batches passed through it /// /// filter_expression will be evaluated against each batch which is pushed to diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h index 8bc70322560..32c8e5fa795 100644 --- a/cpp/src/arrow/record_batch.h +++ b/cpp/src/arrow/record_batch.h @@ -217,7 +217,7 @@ struct ARROW_EXPORT RecordBatchWithMetadata { }; /// \brief Abstract interface for reading stream of record batches -class 
ARROW_EXPORT RecordBatchReader { +class ARROW_EXPORT RecordBatchReader : public Iterator> { public: using ValueType = std::shared_ptr; From f42f3df080bee157a4de4912a29a918082d03e7e Mon Sep 17 00:00:00 2001 From: david dali susanibar arce Date: Sun, 11 Sep 2022 22:04:41 -0500 Subject: [PATCH 027/133] ARROW-17616: [CI][Java] Solving regex to support last Arrow Java versions >= 10.0.0 (#14076) Solving regex to support last Arrow Java versions >= 10.0.0 Authored-by: david dali susanibar arce Signed-off-by: Sutou Kouhei --- .github/workflows/java_nightly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/java_nightly.yml b/.github/workflows/java_nightly.yml index 5e7de690570..015e28984fb 100644 --- a/.github/workflows/java_nightly.yml +++ b/.github/workflows/java_nightly.yml @@ -100,7 +100,7 @@ jobs: if [ -z $PREFIX ]; then PREFIX=nightly-packaging-$(date +%Y-%m-%d)-0 fi - PATTERN_TO_GET_LIB_AND_VERSION='([a-z].+)-([0-9].[0-9].[0-9].dev[0-9]+)' + PATTERN_TO_GET_LIB_AND_VERSION='([a-z].+)-([0-9]+.[0-9]+.[0-9]+.dev[0-9]+)' mkdir -p repo/org/apache/arrow/ for LIBRARY in $(ls binaries/$PREFIX/java-jars | grep -E '.jar|.pom' | grep dev); do [[ $LIBRARY =~ $PATTERN_TO_GET_LIB_AND_VERSION ]] From 3aee08b41b4f748bc5086f9728c4314a70067d8a Mon Sep 17 00:00:00 2001 From: Dhruv Vats Date: Mon, 12 Sep 2022 10:51:12 +0530 Subject: [PATCH 028/133] ARROW-17632: [Python][C++] Add details of where libarrow is being found during build (#14059) This PR aims to add back `message(STATUS ...)` statements that printed some details about Arrow being found, its version, and the paths. These were refactored away as part of #13892. Authored-by: Dhruv Vats Signed-off-by: Sutou Kouhei --- cpp/src/arrow/ArrowConfig.cmake.in | 15 +++++++++++++++ cpp/src/arrow/ArrowTestingConfig.cmake.in | 2 ++ cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in | 2 ++ .../arrow/engine/ArrowSubstraitConfig.cmake.in | 2 ++ cpp/src/arrow/flight/ArrowFlightConfig.cmake.in | 2 ++ .../flight/ArrowFlightTestingConfig.cmake.in | 2 ++ .../flight/sql/ArrowFlightSqlConfig.cmake.in | 2 ++ cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in | 2 ++ cpp/src/gandiva/GandivaConfig.cmake.in | 2 ++ cpp/src/parquet/ParquetConfig.cmake.in | 2 ++ cpp/src/plasma/PlasmaConfig.cmake.in | 2 ++ python/pyarrow/src/ArrowPythonConfig.cmake.in | 2 ++ .../pyarrow/src/ArrowPythonFlightConfig.cmake.in | 2 ++ 13 files changed, 39 insertions(+) diff --git a/cpp/src/arrow/ArrowConfig.cmake.in b/cpp/src/arrow/ArrowConfig.cmake.in index f0aa1bc959b..8386bcd7280 100644 --- a/cpp/src/arrow/ArrowConfig.cmake.in +++ b/cpp/src/arrow/ArrowConfig.cmake.in @@ -172,3 +172,18 @@ endmacro() arrow_keep_backward_compatibility(Arrow arrow) check_required_components(Arrow) + +macro(arrow_show_details package_name variable_prefix) + if(NOT ${package_name}_FIND_QUIETLY AND NOT ${package_name}_SHOWED_DETAILS) + message(STATUS "${package_name} version: ${${package_name}_VERSION}") + message(STATUS "Found the ${package_name} shared library: ${${variable_prefix}_SHARED_LIB}" + ) + message(STATUS "Found the ${package_name} import library: ${${variable_prefix}_IMPORT_LIB}" + ) + message(STATUS "Found the ${package_name} static library: ${${variable_prefix}_STATIC_LIB}" + ) + set(${package_name}_SHOWED_DETAILS TRUE) + endif() +endmacro() + +arrow_show_details(Arrow ARROW) diff --git a/cpp/src/arrow/ArrowTestingConfig.cmake.in b/cpp/src/arrow/ArrowTestingConfig.cmake.in index 03775b043ed..87ee9e755e1 100644 --- a/cpp/src/arrow/ArrowTestingConfig.cmake.in +++ 
b/cpp/src/arrow/ArrowTestingConfig.cmake.in @@ -34,3 +34,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowTestingTargets.cmake") arrow_keep_backward_compatibility(ArrowTesting arrow_testing) check_required_components(ArrowTesting) + +arrow_show_details(ArrowTesting ARROW_TESTING) diff --git a/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in b/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in index 38b9baf4089..6816f2c837d 100644 --- a/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in +++ b/cpp/src/arrow/dataset/ArrowDatasetConfig.cmake.in @@ -35,3 +35,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowDatasetTargets.cmake") arrow_keep_backward_compatibility(ArrowDataset arrow_dataset) check_required_components(ArrowDataset) + +arrow_show_details(ArrowDataset ARROW_DATASET) diff --git a/cpp/src/arrow/engine/ArrowSubstraitConfig.cmake.in b/cpp/src/arrow/engine/ArrowSubstraitConfig.cmake.in index 2263c735d26..2e96d372ad7 100644 --- a/cpp/src/arrow/engine/ArrowSubstraitConfig.cmake.in +++ b/cpp/src/arrow/engine/ArrowSubstraitConfig.cmake.in @@ -36,3 +36,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowSubstraitTargets.cmake") arrow_keep_backward_compatibility(ArrowSubstrait arrow_substrait) check_required_components(ArrowSubstrait) + +arrow_show_details(ArrowSubstrait ARROW_SUBSTRAIT) diff --git a/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in b/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in index 04560f91e08..70beb901c85 100644 --- a/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in +++ b/cpp/src/arrow/flight/ArrowFlightConfig.cmake.in @@ -34,3 +34,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightTargets.cmake") arrow_keep_backward_compatibility(ArrowFlight arrow_flight) check_required_components(ArrowFlight) + +arrow_show_details(ArrowFlight ARROW_FLIGHT) diff --git a/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in b/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in index 0c42c5c1ff8..f072b2603e3 100644 --- a/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in +++ b/cpp/src/arrow/flight/ArrowFlightTestingConfig.cmake.in @@ -35,3 +35,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightTestingTargets.cmake") arrow_keep_backward_compatibility(ArrowFlightTetsing arrow_flight_testing) check_required_components(ArrowFlightTesting) + +arrow_show_details(ArrowFlightTesting ARROW_FLIGHT_TESTING) diff --git a/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in b/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in index 9d0e9ea2dac..3a70dbdeda6 100644 --- a/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in +++ b/cpp/src/arrow/flight/sql/ArrowFlightSqlConfig.cmake.in @@ -34,3 +34,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowFlightSqlTargets.cmake") arrow_keep_backward_compatibility(ArrowFlightSql arrow_flight_sql) check_required_components(ArrowFlightSql) + +arrow_show_details(ArrowFlightSql ARROW_FLIGHT_SQL) diff --git a/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in b/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in index e987d82a3a1..b251b86f43e 100644 --- a/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in +++ b/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in @@ -34,3 +34,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowCUDATargets.cmake") arrow_keep_backward_compatibility(ArrowCUDA arrow_cuda) check_required_components(ArrowCUDA) + +arrow_show_details(ArrowCUDA ARROW_CUDA) diff --git a/cpp/src/gandiva/GandivaConfig.cmake.in b/cpp/src/gandiva/GandivaConfig.cmake.in index 861166dc3d9..c6d7cef73d7 100644 --- a/cpp/src/gandiva/GandivaConfig.cmake.in +++ b/cpp/src/gandiva/GandivaConfig.cmake.in @@ -35,3 +35,5 @@ 
include("${CMAKE_CURRENT_LIST_DIR}/GandivaTargets.cmake") arrow_keep_backward_compatibility(Gandiva gandiva) check_required_components(Gandiva) + +arrow_show_details(Gandiva GANDIVA) diff --git a/cpp/src/parquet/ParquetConfig.cmake.in b/cpp/src/parquet/ParquetConfig.cmake.in index 19f1b4b6395..10305301388 100644 --- a/cpp/src/parquet/ParquetConfig.cmake.in +++ b/cpp/src/parquet/ParquetConfig.cmake.in @@ -41,3 +41,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ParquetTargets.cmake") arrow_keep_backward_compatibility(Parquet parquet) check_required_components(Parquet) + +arrow_show_details(Parquet PARQUET) diff --git a/cpp/src/plasma/PlasmaConfig.cmake.in b/cpp/src/plasma/PlasmaConfig.cmake.in index cdd312d04cb..ec3c51ec281 100644 --- a/cpp/src/plasma/PlasmaConfig.cmake.in +++ b/cpp/src/plasma/PlasmaConfig.cmake.in @@ -46,3 +46,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/PlasmaTargets.cmake") arrow_keep_backward_compatibility(Plasma plasma) check_required_components(Plasma) + +arrow_show_details(Plasma PLASMA) diff --git a/python/pyarrow/src/ArrowPythonConfig.cmake.in b/python/pyarrow/src/ArrowPythonConfig.cmake.in index cab92cddfd3..874c5cc09d1 100644 --- a/python/pyarrow/src/ArrowPythonConfig.cmake.in +++ b/python/pyarrow/src/ArrowPythonConfig.cmake.in @@ -37,3 +37,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowPythonTargets.cmake") arrow_keep_backward_compatibility(ArrowPython arrow_python) check_required_components(ArrowPython) + +arrow_show_details(ArrowPython ARROW_PYTHON) diff --git a/python/pyarrow/src/ArrowPythonFlightConfig.cmake.in b/python/pyarrow/src/ArrowPythonFlightConfig.cmake.in index fb2ad918fc8..1aacb9f212d 100644 --- a/python/pyarrow/src/ArrowPythonFlightConfig.cmake.in +++ b/python/pyarrow/src/ArrowPythonFlightConfig.cmake.in @@ -35,3 +35,5 @@ include("${CMAKE_CURRENT_LIST_DIR}/ArrowPythonFlightTargets.cmake") arrow_keep_backward_compatibility(ArrowPythonFlight arrow_python_flight) check_required_components(ArrowPythonFlight) + +arrow_show_details(ArrowPythonFlight ARROW_PYTHON_FLIGHT) From d8571a45d17b5b0224f4c3ae5e403cc0c5373b9e Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Mon, 12 Sep 2022 11:09:44 +0100 Subject: [PATCH 029/133] ARROW-17639: [R] infer_type() fails for lists where the first element is NULL (#14062) This PR updates the internals of type inference on lists; instead of looking at the type of the first element, we instead iterate through all elements until we find a non-null element. 
Authored-by: Nic Crane
Signed-off-by: Nic Crane
---
 r/src/type_infer.cpp         | 11 +++++++++--
 r/tests/testthat/test-type.R | 16 ++++++++++++++++
 2 files changed, 25 insertions(+), 2 deletions(-)

diff --git a/r/src/type_infer.cpp b/r/src/type_infer.cpp
index 616be0467f9..e30d0e12887 100644
--- a/r/src/type_infer.cpp
+++ b/r/src/type_infer.cpp
@@ -165,8 +165,13 @@ std::shared_ptr<arrow::DataType> InferArrowTypeFromVector<VECSXP>(SEXP x) {
     cpp11::stop(
         "Requires at least one element to infer the values' type of a list vector");
   }
-
-  ptype = VECTOR_ELT(x, 0);
+  // Iterate through the vector until we get a non-null result
+  for (R_xlen_t i = 0; i < XLENGTH(x); i++) {
+    ptype = VECTOR_ELT(x, i);
+    if (!Rf_isNull(ptype)) {
+      break;
+    }
+  }
 }
 
 return arrow::list(InferArrowType(ptype));
@@ -198,6 +203,8 @@ std::shared_ptr<arrow::DataType> InferArrowType(SEXP x) {
       return InferArrowTypeFromVector<STRSXP>(x);
     case VECSXP:
       return InferArrowTypeFromVector<VECSXP>(x);
+    case NILSXP:
+      return null();
     default:
       cpp11::stop("Cannot infer type from vector");
   }
diff --git a/r/tests/testthat/test-type.R b/r/tests/testthat/test-type.R
index d7c6da0792c..14f0ea7a8d5 100644
--- a/r/tests/testthat/test-type.R
+++ b/r/tests/testthat/test-type.R
@@ -293,3 +293,19 @@ test_that("type() is deprecated", {
   )
   expect_equal(a_type, a$type)
 })
+
+test_that("infer_type() infers type for lists starting with NULL - ARROW-17639", {
+  null_start_list <- list(NULL, c(2, 3), c(4, 5))
+
+  expect_equal(
+    infer_type(null_start_list),
+    list_of(float64())
+  )
+
+  totally_null_list <- list(NULL, NULL, NULL)
+
+  expect_equal(
+    infer_type(totally_null_list),
+    list_of(null())
+  )
+})

From 75d2bfd33e78e7769bac9fc7aafe3b7c724bf2bd Mon Sep 17 00:00:00 2001
From: Nic Crane
Date: Mon, 12 Sep 2022 11:24:42 +0100
Subject: [PATCH 030/133] ARROW-17362: [R] Implement dplyr::across() inside
 summarise() (#14042)

This PR implements `across()` within `summarise()`. It also adds the `.groups`
argument explicitly to the `summarise()` function signature instead of being
passed in via `...` (this was necessary to prevent test failures after the
addition of `expand_across()` to `summarise()`).
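As a quick illustration of the new capability (a minimal sketch against a built-in data frame; any numeric columns would do):

``` r
library(arrow, warn.conflicts = FALSE)
library(dplyr, warn.conflicts = FALSE)

# across() is now expanded inside summarise() on Arrow queries, here
# summing every column whose name starts with "d", per group
mtcars %>%
  arrow_table() %>%
  group_by(cyl) %>%
  summarise(across(starts_with("d"), sum, .names = "sum_{.col}")) %>%
  collect()
```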
Authored-by: Nic Crane
Signed-off-by: Nic Crane
---
 r/R/dplyr-summarize.R                   |  6 +++---
 r/tests/testthat/test-dplyr-summarize.R | 21 +++++++++++++++++++++
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/r/R/dplyr-summarize.R b/r/R/dplyr-summarize.R
index 92587f6c685..3181cee1378 100644
--- a/r/R/dplyr-summarize.R
+++ b/r/R/dplyr-summarize.R
@@ -179,10 +179,10 @@ agg_funcs[["::"]] <- function(lhs, rhs) {
 
 # The following S3 methods are registered on load if dplyr is present
 
-summarise.arrow_dplyr_query <- function(.data, ...) {
+summarise.arrow_dplyr_query <- function(.data, ..., .groups = NULL) {
   call <- match.call()
   .data <- as_adq(.data)
-  exprs <- quos(...)
+  exprs <- expand_across(.data, quos(...))
   # Only retain the columns we need to do our aggregations
   vars_to_keep <- unique(c(
     unlist(lapply(exprs, all.vars)), # vars referenced in summarise
@@ -198,7 +198,7 @@ summarise.arrow_dplyr_query <- function(.data, ...) {
   .data <- dplyr::select(.data, intersect(vars_to_keep, names(.data)))
 
   # Try stuff, if successful return()
-  out <- try(do_arrow_summarize(.data, ...), silent = TRUE)
+  out <- try(do_arrow_summarize(.data, !!!exprs, .groups = .groups), silent = TRUE)
   if (inherits(out, "try-error")) {
     return(abandon_ship(call, .data, format(out)))
   } else {
diff --git a/r/tests/testthat/test-dplyr-summarize.R b/r/tests/testthat/test-dplyr-summarize.R
index 283d5d77837..7eda0431b68 100644
--- a/r/tests/testthat/test-dplyr-summarize.R
+++ b/r/tests/testthat/test-dplyr-summarize.R
@@ -1124,3 +1124,24 @@ test_that("We don't add unnecessary ProjectNodes when aggregating", {
     2
   )
 })
+
+test_that("Can use across() within summarise()", {
+  compare_dplyr_binding(
+    .input %>%
+      group_by(lgl) %>%
+      summarise(across(starts_with("dbl"), sum, .names = "sum_{.col}")) %>%
+      arrange(lgl) %>%
+      collect(),
+    example_data
+  )
+
+  # across() doesn't work in summarise when input expressions evaluate to bare field references
+  expect_warning(
+    example_data %>%
+      arrow_table() %>%
+      group_by(lgl) %>%
+      summarise(across(everything())) %>%
+      collect(),
+    regexp = "Expression int is not an aggregate expression or is not supported in Arrow; pulling data into R"
+  )
+})

From efb7fb08072710c251f1be2c69777c3086da4952 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Percy=20Camilo=20Trive=C3=B1o=20Aucahuasi?=
Date: Mon, 12 Sep 2022 08:12:09 -0500
Subject: [PATCH 031/133] ARROW-16870: [C++] Fix link issues with ld and clang
 for flight examples (#14077)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

It seems the macOS linker (`ld`) doesn't support the `--no-as-needed` flag, so
the arrow flight example cannot build on macOS with clang. This minor PR tries
to fix that.

I was able to build the arrow flight example and run it without issues:
`./flight-grpc-example -port 8086`

Notes:

```bash
clang++ --version
Apple clang version 14.0.0 (clang-1400.0.29.102)
```

Authored-by: Percy Camilo Triveño Aucahuasi
Signed-off-by: David Li
---
 cpp/examples/arrow/CMakeLists.txt | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/cpp/examples/arrow/CMakeLists.txt b/cpp/examples/arrow/CMakeLists.txt
index 88b760e3978..38a8005fe3c 100644
--- a/cpp/examples/arrow/CMakeLists.txt
+++ b/cpp/examples/arrow/CMakeLists.txt
@@ -36,10 +36,14 @@ if(ARROW_FLIGHT)
   # we'll violate ODR for gRPC symbols
   if(ARROW_GRPC_USE_SHARED)
     set(FLIGHT_EXAMPLES_LINK_LIBS arrow_flight_shared)
-    # We don't directly use symbols from the reflection library, so
-    # ensure the linker still links to it
-    set(GRPC_REFLECTION_LINK_LIBS -Wl,--no-as-needed gRPC::grpc++_reflection
-        -Wl,--as-needed)
+    if(APPLE)
+      set(GRPC_REFLECTION_LINK_LIBS gRPC::grpc++_reflection)
+    else()
+      # We don't directly use symbols from the reflection library, so
+      # ensure the linker still links to it
+      set(GRPC_REFLECTION_LINK_LIBS -Wl,--no-as-needed gRPC::grpc++_reflection
+          -Wl,--as-needed)
+    endif()
   elseif(NOT ARROW_BUILD_STATIC)
     message(FATAL_ERROR "Statically built gRPC requires ARROW_BUILD_STATIC=ON")
   else()

From 8c76855b17471ac7194c1bba630091eaaaa0d181 Mon Sep 17 00:00:00 2001
From: Sam Albers
Date: Mon, 12 Sep 2022 06:29:55 -0700
Subject: [PATCH 032/133] ARROW-17448: [R] Fix cloud storage paths in some
 documentation (#14070)

This PR fixes some paths that weren't working in the vignettes for both GCS
and AWS. The list of files is a bit long, so it is included below.
I think using this naming (`part-0.parquet` files under Hive-style `year=`/`month=` directories) makes sense, rather than modifying the bucket, since that is the default naming produced by `write_dataset()`; see the short check after the file listing below. I've also tweaked the manual download portion of the vignette.
``` r
library(arrow, warn.conflicts = FALSE)

## aws
aws <- s3_bucket("voltrondata-labs-datasets")
aws_files <- aws$ls(recursive = TRUE)
aws_files
#>   [1] "arrow-project"
#>   [2] "arrow-project/2022-07-26"
#>   [3] "arrow-project/2022-07-26/apache-arrow-jira.json"
#>   [4] "arrow-project/2022-07-26/arrow-commits.csv"
#>   [5] "arrow-project/2022-07-26/arrow-jira-issues.csv"
#>   [6] "nyc-taxi"
#>   [7] "nyc-taxi/year=2009"
#>   [8] "nyc-taxi/year=2009/month=1"
#>   [9] "nyc-taxi/year=2009/month=1/part-0.parquet"
#>   ... (327 further entries following the same pattern: a year=YYYY
#>   directory, a month=M directory, and a month=M/part-0.parquet file for
#>   every month of 2009-2021 and for months 1-2 of 2022; 336 entries total)

## gcs
gcs <- gs_bucket("voltrondata-labs-datasets")
gcs_files <- gcs$ls(recursive = TRUE)
gcs_files
#>   [1] "nyc-taxi/year=2009/month=1/part-0.parquet"
#>   [2] "nyc-taxi/year=2009/month=10/part-0.parquet"
#>   ... (156 further entries: one nyc-taxi/year=YYYY/month=M/part-0.parquet
#>   file for every month of 2009-2021 and for months 1-2 of 2022; 158
#>   entries total)
```
Authored-by: Sam Albers Signed-off-by: Nic Crane --- r/vignettes/dataset.Rmd | 27 ++++++++++++++------------- r/vignettes/fs.Rmd | 18 +++++++++--------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/r/vignettes/dataset.Rmd b/r/vignettes/dataset.Rmd index e58922c23a0..0dd42b4e806 100644 --- a/r/vignettes/dataset.Rmd +++ b/r/vignettes/dataset.Rmd @@ -56,18 +56,19 @@ you may need to increase R's download timeout from the default of 60 seconds, e. ```{r, eval = FALSE} bucket <- "https://voltrondata-labs-datasets.s3.us-east-2.amazonaws.com" -for (year in 2009:2019) { - if (year == 2019) { - # We only have through June 2019 there - months <- 1:6 +for (year in 2009:2022) { + if (year == 2022) { + # We only have through Feb 2022 there + months <- 1:2 } else { months <- 1:12 } - for (month in sprintf("%02d", months)) { - dir.create(file.path("nyc-taxi", year, month), recursive = TRUE) + for (month in months) { + dataset_path <- file.path("nyc-taxi", paste0("year=", year), paste0("month=", month)) + dir.create(dataset_path, recursive = TRUE) try(download.file( - paste(bucket, "nyc-taxi", paste0("year=", year), paste0("month=", month), "data.parquet", sep = "/"), - file.path("nyc-taxi", paste0("year=", year), paste0("month=", month), "data.parquet"), + paste(bucket, dataset_path, "part-0.parquet", sep = "/"), + file.path(dataset_path, "part-0.parquet"), mode = "wb" ), silent = TRUE) } @@ -129,8 +130,8 @@ For more information on the usage of these parameters, see `?read_delim_arrow()` [Hive](https://hive.apache.org/)-style partitioning structure is self-describing, with file paths like ``` -year=2009/month=1/data.parquet -year=2009/month=2/data.parquet +year=2009/month=1/part-0.parquet +year=2009/month=2/part-0.parquet ... ``` @@ -138,15 +139,15 @@ But sometimes the directory partitioning isn't self describing; that is, it does contain field names. For example, if instead we had file paths like ``` -2009/01/data.parquet -2009/02/data.parquet +2009/01/part-0.parquet +2009/02/part-0.parquet ... ``` then `open_dataset()` would need some hints as to how to use the file paths. In this case, you could provide `c("year", "month")` to the `partitioning` argument, saying that the first path segment gives the value for `year`, and the second -segment is `month`. Every row in `2009/01/data.parquet` has a value of 2009 for `year` +segment is `month`. Every row in `2009/01/part-0.parquet` has a value of 2009 for `year` and 1 for `month`, even though those columns may not be present in the file. 
 In either case, when you look at the dataset, you can see that in addition to the columns present
diff --git a/r/vignettes/fs.Rmd b/r/vignettes/fs.Rmd
index 6fb7e2d1af9..10bb1e30e44 100644
--- a/r/vignettes/fs.Rmd
+++ b/r/vignettes/fs.Rmd
@@ -37,7 +37,7 @@ For example, to read a parquet file from the example NYC taxi data
 bucket <- s3_bucket("voltrondata-labs-datasets")
 # Or in GCS (anonymous = TRUE is required if credentials are not configured):
 bucket <- gs_bucket("voltrondata-labs-datasets", anonymous = TRUE)
-df <- read_parquet(bucket$path("nyc-taxi/year=2019/month=6/data.parquet"))
+df <- read_parquet(bucket$path("nyc-taxi/year=2019/month=6/part-0.parquet"))
 ```
 
 Note that this will be slower to read than if the file were local,
@@ -68,14 +68,14 @@ useful for holding a reference to a subdirectory somewhere (on S3, GCS, or elsewhere)
 One way to get a subtree is to call the `$cd()` method on a `FileSystem`
 
 ```r
-june2019 <- bucket$cd("2019/06")
-df <- read_parquet(june2019$path("data.parquet"))
+june2019 <- bucket$cd("nyc-taxi/year=2019/month=6")
+df <- read_parquet(june2019$path("part-0.parquet"))
 ```
 
 `SubTreeFileSystem` can also be made from a URI:
 
 ```r
-june2019 <- SubTreeFileSystem$create("s3://voltrondata-labs-datasets/nyc-taxi/2019/06")
+june2019 <- SubTreeFileSystem$create("s3://voltrondata-labs-datasets/nyc-taxi/year=2019/month=6")
 ```
 
 ## URIs
@@ -98,17 +98,17 @@ gs://anonymous@bucket/path
 For example, one of the NYC taxi data files used in `vignette("dataset", package = "arrow")` is found at
 
 ```
-s3://voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/data.parquet
+s3://voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/part-0.parquet
 # Or in GCS (anonymous required on public buckets):
-gs://anonymous@voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/data.parquet
+gs://anonymous@voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/part-0.parquet
 ```
 
 Given this URI, you can pass it to `read_parquet()` just as if it were a local file path:
 
 ```r
-df <- read_parquet("s3://voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/data.parquet")
+df <- read_parquet("s3://voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/part-0.parquet")
 # Or in GCS:
-df <- read_parquet("gs://anonymous@voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/data.parquet")
+df <- read_parquet("gs://anonymous@voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/part-0.parquet")
 ```
 
 ### URI options
@@ -190,7 +190,7 @@ must pass `anonymous = TRUE` or `anonymous` as the user in a URI:
 
 ```r
 bucket <- gs_bucket("voltrondata-labs-datasets", anonymous = TRUE)
 fs <- GcsFileSystem$create(anonymous = TRUE)
-df <- read_parquet("gs://anonymous@voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/data.parquet")
+df <- read_parquet("gs://anonymous@voltrondata-labs-datasets/nyc-taxi/year=2019/month=6/part-0.parquet")
 ```

[A patch boundary was lost here in extraction: the mail header and commit message of the next patch, which adds the new Arrow Flight SQL JDBC driver module, are missing, and the XML markup of the module's first two files was stripped. Placeholders below summarize what survives.]

diff --git a/java/flight/flight-sql-jdbc-driver/jdbc-spotbugs-exclude.xml b/java/flight/flight-sql-jdbc-driver/jdbc-spotbugs-exclude.xml
new file mode 100644
index 00000000000..af75d70425c
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/jdbc-spotbugs-exclude.xml
@@ -0,0 +1,40 @@
[40 lines: a SpotBugs exclusion filter for the JDBC driver; only empty, tag-stripped lines survive.]
diff --git a/java/flight/flight-sql-jdbc-driver/pom.xml b/java/flight/flight-sql-jdbc-driver/pom.xml
new file mode 100644
index 00000000000..b8a49165adb
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/pom.xml
@@ -0,0 +1,375 @@
[375 lines: the module's Maven POM, with XML tags stripped. The surviving element text shows: parent arrow-flight (org.apache.arrow, 10.0.0-SNAPSHOT); artifact flight-sql-jdbc-driver, named "Arrow Flight SQL JDBC Driver" and described as "(Contrib/Experimental) A JDBC driver based on Arrow Flight SQL."; dependencies on flight-core (excluding netty-transport-native-kqueue and netty-transport-native-epoll), arrow-memory-core, arrow-memory-netty (runtime), arrow-vector, guava, slf4j-api (runtime), protobuf-java, netty-common, flight-sql, avatica 1.18.0, bcpkix-jdk15on 1.61 and joda-time 2.10.14, plus test dependencies hamcrest-core 1.3, free-port-finder 1.1.1, commons-io 2.6 and mockito-core/mockito-inline 3.12.4; a maven-shade-plugin 3.2.4 execution that relocates com., org. and io. packages under a cfjd. prefix (leaving org.apache.arrow.driver.jdbc.** and org.slf4j.** in place), renames the bundled netty native libraries, keeps the META-INF/services/java.sql.Driver entry only from the avatica artifact, and excludes signature files and gRPC netty natives; a properties-maven-plugin execution writing project properties to src/main/resources/properties/flight.properties; jacoco-maven-plugin executions enforcing a 0.80 branch-coverage ratio (excluding ArrowFlightConnectionConfigImpl and UrlParser); and jdk8/jdk9+ surefire profiles, the latter adding --add-opens=java.base/java.nio=ALL-UNNAMED.]
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java
new file mode 100644
index 00000000000..da2b0b00eda
--- /dev/null
+++
b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadata.java @@ -0,0 +1,1218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.sql.Types.BIGINT; +import static java.sql.Types.BINARY; +import static java.sql.Types.BIT; +import static java.sql.Types.CHAR; +import static java.sql.Types.DATE; +import static java.sql.Types.DECIMAL; +import static java.sql.Types.FLOAT; +import static java.sql.Types.INTEGER; +import static java.sql.Types.LONGNVARCHAR; +import static java.sql.Types.LONGVARBINARY; +import static java.sql.Types.NUMERIC; +import static java.sql.Types.REAL; +import static java.sql.Types.SMALLINT; +import static java.sql.Types.TIMESTAMP; +import static java.sql.Types.TINYINT; +import static java.sql.Types.VARCHAR; +import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.doesBitmaskTranslateToEnum; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.nio.channels.Channels; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.EnumMap; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +import org.apache.arrow.driver.jdbc.utils.SqlTypes; +import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlOuterJoinsSupportLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedElementActions; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedGroupBy; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedPositionedCommands; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedResultSetType; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedSubqueries; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedUnions; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlTransactionIsolationLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SupportedAnsi92SqlGrammarLevel; +import org.apache.arrow.flight.sql.impl.FlightSql.SupportedSqlGrammar; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.IntVector; +import 
org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.ipc.ReadChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaDatabaseMetaData; + +import com.google.protobuf.ProtocolMessageEnum; + +/** + * Arrow Flight JDBC's implementation of {@link DatabaseMetaData}. + */ +public class ArrowDatabaseMetadata extends AvaticaDatabaseMetaData { + private static final String JAVA_REGEX_SPECIALS = "[]()|^-+*?{}$\\."; + private static final Charset CHARSET = StandardCharsets.UTF_8; + private static final byte[] EMPTY_BYTE_ARRAY = new byte[0]; + static final int NO_DECIMAL_DIGITS = 0; + private static final int BASE10_RADIX = 10; + static final int COLUMN_SIZE_BYTE = (int) Math.ceil((Byte.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_SHORT = + (int) Math.ceil((Short.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_INT = + (int) Math.ceil((Integer.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_LONG = (int) Math.ceil((Long.SIZE - 1) * Math.log(2) / Math.log(10)); + static final int COLUMN_SIZE_VARCHAR_AND_BINARY = 65536; + static final int COLUMN_SIZE_DATE = "YYYY-MM-DD".length(); + static final int COLUMN_SIZE_TIME = "HH:MM:ss".length(); + static final int COLUMN_SIZE_TIME_MILLISECONDS = "HH:MM:ss.SSS".length(); + static final int COLUMN_SIZE_TIME_MICROSECONDS = "HH:MM:ss.SSSSSS".length(); + static final int COLUMN_SIZE_TIME_NANOSECONDS = "HH:MM:ss.SSSSSSSSS".length(); + static final int COLUMN_SIZE_TIMESTAMP_SECONDS = COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME; + static final int COLUMN_SIZE_TIMESTAMP_MILLISECONDS = + COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME_MILLISECONDS; + static final int COLUMN_SIZE_TIMESTAMP_MICROSECONDS = + COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME_MICROSECONDS; + static final int COLUMN_SIZE_TIMESTAMP_NANOSECONDS = + COLUMN_SIZE_DATE + 1 + COLUMN_SIZE_TIME_NANOSECONDS; + static final int DECIMAL_DIGITS_TIME_MILLISECONDS = 3; + static final int DECIMAL_DIGITS_TIME_MICROSECONDS = 6; + static final int DECIMAL_DIGITS_TIME_NANOSECONDS = 9; + private static final Schema GET_COLUMNS_SCHEMA = new Schema( + Arrays.asList( + Field.nullable("TABLE_CAT", Types.MinorType.VARCHAR.getType()), + Field.nullable("TABLE_SCHEM", Types.MinorType.VARCHAR.getType()), + Field.notNullable("TABLE_NAME", Types.MinorType.VARCHAR.getType()), + Field.notNullable("COLUMN_NAME", Types.MinorType.VARCHAR.getType()), + Field.nullable("DATA_TYPE", Types.MinorType.INT.getType()), + Field.nullable("TYPE_NAME", Types.MinorType.VARCHAR.getType()), + Field.nullable("COLUMN_SIZE", Types.MinorType.INT.getType()), + Field.nullable("BUFFER_LENGTH", Types.MinorType.INT.getType()), + Field.nullable("DECIMAL_DIGITS", Types.MinorType.INT.getType()), + Field.nullable("NUM_PREC_RADIX", Types.MinorType.INT.getType()), + Field.notNullable("NULLABLE", Types.MinorType.INT.getType()), + Field.nullable("REMARKS", Types.MinorType.VARCHAR.getType()), + Field.nullable("COLUMN_DEF", Types.MinorType.VARCHAR.getType()), + Field.nullable("SQL_DATA_TYPE", Types.MinorType.INT.getType()), + 
Field.nullable("SQL_DATETIME_SUB", Types.MinorType.INT.getType()), + Field.notNullable("CHAR_OCTET_LENGTH", Types.MinorType.INT.getType()), + Field.notNullable("ORDINAL_POSITION", Types.MinorType.INT.getType()), + Field.notNullable("IS_NULLABLE", Types.MinorType.VARCHAR.getType()), + Field.nullable("SCOPE_CATALOG", Types.MinorType.VARCHAR.getType()), + Field.nullable("SCOPE_SCHEMA", Types.MinorType.VARCHAR.getType()), + Field.nullable("SCOPE_TABLE", Types.MinorType.VARCHAR.getType()), + Field.nullable("SOURCE_DATA_TYPE", Types.MinorType.SMALLINT.getType()), + Field.notNullable("IS_AUTOINCREMENT", Types.MinorType.VARCHAR.getType()), + Field.notNullable("IS_GENERATEDCOLUMN", Types.MinorType.VARCHAR.getType()) + )); + private final Map cachedSqlInfo = + Collections.synchronizedMap(new EnumMap<>(SqlInfo.class)); + private static final Map sqlTypesToFlightEnumConvertTypes = new HashMap<>(); + + static { + sqlTypesToFlightEnumConvertTypes.put(BIT, SqlSupportsConvert.SQL_CONVERT_BIT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(INTEGER, SqlSupportsConvert.SQL_CONVERT_INTEGER_VALUE); + sqlTypesToFlightEnumConvertTypes.put(NUMERIC, SqlSupportsConvert.SQL_CONVERT_NUMERIC_VALUE); + sqlTypesToFlightEnumConvertTypes.put(SMALLINT, SqlSupportsConvert.SQL_CONVERT_SMALLINT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(TINYINT, SqlSupportsConvert.SQL_CONVERT_TINYINT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(FLOAT, SqlSupportsConvert.SQL_CONVERT_FLOAT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(BIGINT, SqlSupportsConvert.SQL_CONVERT_BIGINT_VALUE); + sqlTypesToFlightEnumConvertTypes.put(REAL, SqlSupportsConvert.SQL_CONVERT_REAL_VALUE); + sqlTypesToFlightEnumConvertTypes.put(DECIMAL, SqlSupportsConvert.SQL_CONVERT_DECIMAL_VALUE); + sqlTypesToFlightEnumConvertTypes.put(BINARY, SqlSupportsConvert.SQL_CONVERT_BINARY_VALUE); + sqlTypesToFlightEnumConvertTypes.put(LONGVARBINARY, + SqlSupportsConvert.SQL_CONVERT_LONGVARBINARY_VALUE); + sqlTypesToFlightEnumConvertTypes.put(CHAR, SqlSupportsConvert.SQL_CONVERT_CHAR_VALUE); + sqlTypesToFlightEnumConvertTypes.put(VARCHAR, SqlSupportsConvert.SQL_CONVERT_VARCHAR_VALUE); + sqlTypesToFlightEnumConvertTypes.put(LONGNVARCHAR, + SqlSupportsConvert.SQL_CONVERT_LONGVARCHAR_VALUE); + sqlTypesToFlightEnumConvertTypes.put(DATE, SqlSupportsConvert.SQL_CONVERT_DATE_VALUE); + sqlTypesToFlightEnumConvertTypes.put(TIMESTAMP, SqlSupportsConvert.SQL_CONVERT_TIMESTAMP_VALUE); + } + + ArrowDatabaseMetadata(final AvaticaConnection connection) { + super(connection); + } + + @Override + public String getDatabaseProductName() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.FLIGHT_SQL_SERVER_NAME, String.class); + } + + @Override + public String getDatabaseProductVersion() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.FLIGHT_SQL_SERVER_VERSION, String.class); + } + + @Override + public String getIdentifierQuoteString() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR, String.class); + } + + @Override + public boolean isReadOnly() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY, Boolean.class); + } + + @Override + public String getSQLKeywords() throws SQLException { + return convertListSqlInfoToString( + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_KEYWORDS, List.class)); + } + + @Override + public String getNumericFunctions() throws SQLException { + return convertListSqlInfoToString( + 
+        getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_NUMERIC_FUNCTIONS, List.class));
+  }
+
+  @Override
+  public String getStringFunctions() throws SQLException {
+    return convertListSqlInfoToString(
+        getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_STRING_FUNCTIONS, List.class));
+  }
+
+  @Override
+  public String getSystemFunctions() throws SQLException {
+    return convertListSqlInfoToString(
+        getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SYSTEM_FUNCTIONS, List.class));
+  }
+
+  @Override
+  public String getTimeDateFunctions() throws SQLException {
+    return convertListSqlInfoToString(
+        getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DATETIME_FUNCTIONS, List.class));
+  }
+
+  @Override
+  public String getSearchStringEscape() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SEARCH_STRING_ESCAPE, String.class);
+  }
+
+  @Override
+  public String getExtraNameCharacters() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_EXTRA_NAME_CHARACTERS, String.class);
+  }
+
+  @Override
+  public boolean supportsColumnAliasing() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_COLUMN_ALIASING, Boolean.class);
+  }
+
+  @Override
+  public boolean nullPlusNonNullIsNull() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_NULL_PLUS_NULL_IS_NULL, Boolean.class);
+  }
+
+  @Override
+  public boolean supportsConvert() throws SQLException {
+    return !getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_CONVERT, Map.class).isEmpty();
+  }
+
+  @Override
+  public boolean supportsConvert(final int fromType, final int toType) throws SQLException {
+    final Map<Integer, List<Integer>> sqlSupportsConvert =
+        getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_CONVERT, Map.class);
+
+    if (!sqlTypesToFlightEnumConvertTypes.containsKey(fromType)) {
+      return false;
+    }
+
+    final List<Integer> list =
+        sqlSupportsConvert.get(sqlTypesToFlightEnumConvertTypes.get(fromType));
+
+    return list != null && list.contains(sqlTypesToFlightEnumConvertTypes.get(toType));
+  }
+
+  @Override
+  public boolean supportsTableCorrelationNames() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_TABLE_CORRELATION_NAMES,
+        Boolean.class);
+  }
+
+  @Override
+  public boolean supportsDifferentTableCorrelationNames() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES,
+        Boolean.class);
+  }
+
+  @Override
+  public boolean supportsExpressionsInOrderBy() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_EXPRESSIONS_IN_ORDER_BY,
+        Boolean.class);
+  }
+
+  @Override
+  public boolean supportsOrderByUnrelated() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_ORDER_BY_UNRELATED, Boolean.class);
+  }
+
+  @Override
+  public boolean supportsGroupBy() throws SQLException {
+    final int bitmask =
+        getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GROUP_BY, Integer.class);
+    return bitmask != 0;
+  }
+
+  @Override
+  public boolean supportsGroupByUnrelated() throws SQLException {
+    return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GROUP_BY,
+        SqlSupportedGroupBy.SQL_GROUP_BY_UNRELATED);
+  }
+
+  @Override
+  public boolean supportsLikeEscapeClause() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_LIKE_ESCAPE_CLAUSE, Boolean.class);
+  }
+
+  @Override
+  public boolean supportsNonNullableColumns() throws SQLException {
+    return
getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_NON_NULLABLE_COLUMNS, + Boolean.class); + } + + @Override + public boolean supportsMinimumSQLGrammar() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_EXTENDED_GRAMMAR), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_CORE_GRAMMAR), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_MINIMUM_GRAMMAR))); + } + + @Override + public boolean supportsCoreSQLGrammar() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_EXTENDED_GRAMMAR), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_CORE_GRAMMAR))); + } + + @Override + public boolean supportsExtendedSQLGrammar() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_GRAMMAR, + SupportedSqlGrammar.SQL_EXTENDED_GRAMMAR); + } + + @Override + public boolean supportsANSI92EntryLevelSQL() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_ENTRY_SQL), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_INTERMEDIATE_SQL), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_FULL_SQL))); + } + + @Override + public boolean supportsANSI92IntermediateSQL() throws SQLException { + return checkEnumLevel( + Arrays.asList(getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_ENTRY_SQL), + getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_INTERMEDIATE_SQL))); + } + + @Override + public boolean supportsANSI92FullSQL() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_ANSI92_SUPPORTED_LEVEL, + SupportedAnsi92SqlGrammarLevel.ANSI92_FULL_SQL); + } + + @Override + public boolean supportsIntegrityEnhancementFacility() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY, + Boolean.class); + } + + @Override + public boolean supportsOuterJoins() throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL, Integer.class); + return bitmask != 0; + } + + @Override + public boolean supportsFullOuterJoins() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL, + SqlOuterJoinsSupportLevel.SQL_FULL_OUTER_JOINS); + } + + @Override + public boolean supportsLimitedOuterJoins() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_OUTER_JOINS_SUPPORT_LEVEL, + SqlOuterJoinsSupportLevel.SQL_LIMITED_OUTER_JOINS); + } + + @Override + public String getSchemaTerm() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMA_TERM, String.class); + } + + @Override + public String getProcedureTerm() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_PROCEDURE_TERM, String.class); + } + + @Override + public String getCatalogTerm() 
throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOG_TERM, String.class); + } + + @Override + public boolean isCatalogAtStart() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOG_AT_START, Boolean.class); + } + + @Override + public boolean supportsSchemasInProcedureCalls() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_PROCEDURE_CALLS); + } + + @Override + public boolean supportsSchemasInIndexDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS); + } + + @Override + public boolean supportsSchemasInPrivilegeDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SCHEMAS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS); + } + + @Override + public boolean supportsCatalogsInIndexDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOGS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS); + } + + @Override + public boolean supportsCatalogsInPrivilegeDefinitions() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_CATALOGS_SUPPORTED_ACTIONS, + SqlSupportedElementActions.SQL_ELEMENT_IN_PRIVILEGE_DEFINITIONS); + } + + @Override + public boolean supportsPositionedDelete() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_POSITIONED_COMMANDS, + SqlSupportedPositionedCommands.SQL_POSITIONED_DELETE); + } + + @Override + public boolean supportsPositionedUpdate() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_POSITIONED_COMMANDS, + SqlSupportedPositionedCommands.SQL_POSITIONED_UPDATE); + } + + @Override + public boolean supportsResultSetType(final int type) throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_RESULT_SET_TYPES, Integer.class); + + switch (type) { + case ResultSet.TYPE_FORWARD_ONLY: + return doesBitmaskTranslateToEnum(SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY, + bitmask); + case ResultSet.TYPE_SCROLL_INSENSITIVE: + return doesBitmaskTranslateToEnum(SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE, + bitmask); + case ResultSet.TYPE_SCROLL_SENSITIVE: + return doesBitmaskTranslateToEnum(SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_SENSITIVE, + bitmask); + default: + throw new SQLException( + "Invalid result set type argument. 
The informed type is not defined in java.sql.ResultSet."); + } + } + + @Override + public boolean supportsSelectForUpdate() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SELECT_FOR_UPDATE_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsStoredProcedures() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_STORED_PROCEDURES_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsSubqueriesInComparisons() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_COMPARISONS); + } + + @Override + public boolean supportsSubqueriesInExists() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_EXISTS); + } + + @Override + public boolean supportsSubqueriesInIns() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_INS); + } + + @Override + public boolean supportsSubqueriesInQuantifieds() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_SUBQUERIES, + SqlSupportedSubqueries.SQL_SUBQUERIES_IN_QUANTIFIEDS); + } + + @Override + public boolean supportsCorrelatedSubqueries() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_CORRELATED_SUBQUERIES_SUPPORTED, + Boolean.class); + } + + @Override + public boolean supportsUnion() throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_UNIONS, Integer.class); + return bitmask != 0; + } + + @Override + public boolean supportsUnionAll() throws SQLException { + return getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_UNIONS, + SqlSupportedUnions.SQL_UNION_ALL); + } + + @Override + public int getMaxBinaryLiteralLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_BINARY_LITERAL_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxCharLiteralLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CHAR_LITERAL_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxColumnNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMN_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInGroupBy() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_GROUP_BY, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInIndex() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_INDEX, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInOrderBy() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_ORDER_BY, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInSelect() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_SELECT, + Long.class).intValue(); + } + + @Override + public int getMaxColumnsInTable() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_COLUMNS_IN_TABLE, + Long.class).intValue(); + } + + @Override + public int getMaxConnections() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CONNECTIONS, Long.class).intValue(); + } + 
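+  // The server reports these limits as 64-bit SqlInfo values, while the JDBC API expects int,
+  // so the getters in this section narrow with Long.intValue(); a value above Integer.MAX_VALUE
+  // would wrap. A hypothetical defensive variant (not part of this patch) could clamp instead:
+  //   (int) Math.min(getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_ROW_SIZE, Long.class),
+  //       Integer.MAX_VALUE)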
+ @Override + public int getMaxCursorNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CURSOR_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxIndexLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_INDEX_LENGTH, Long.class).intValue(); + } + + @Override + public int getMaxSchemaNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DB_SCHEMA_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxProcedureNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_PROCEDURE_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxCatalogNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_CATALOG_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxRowSize() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_ROW_SIZE, Long.class).intValue(); + } + + @Override + public boolean doesMaxRowSizeIncludeBlobs() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_ROW_SIZE_INCLUDES_BLOBS, Boolean.class); + } + + @Override + public int getMaxStatementLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_STATEMENT_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxStatements() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_STATEMENTS, Long.class).intValue(); + } + + @Override + public int getMaxTableNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_TABLE_NAME_LENGTH, + Long.class).intValue(); + } + + @Override + public int getMaxTablesInSelect() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_TABLES_IN_SELECT, + Long.class).intValue(); + } + + @Override + public int getMaxUserNameLength() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_MAX_USERNAME_LENGTH, Long.class).intValue(); + } + + @Override + public int getDefaultTransactionIsolation() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DEFAULT_TRANSACTION_ISOLATION, + Long.class).intValue(); + } + + @Override + public boolean supportsTransactions() throws SQLException { + return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_TRANSACTIONS_SUPPORTED, Boolean.class); + } + + @Override + public boolean supportsTransactionIsolationLevel(final int level) throws SQLException { + final int bitmask = + getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SUPPORTED_TRANSACTIONS_ISOLATION_LEVELS, + Integer.class); + + switch (level) { + case Connection.TRANSACTION_NONE: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_NONE, bitmask); + case Connection.TRANSACTION_READ_COMMITTED: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_READ_COMMITTED, + bitmask); + case Connection.TRANSACTION_READ_UNCOMMITTED: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_READ_UNCOMMITTED, + bitmask); + case Connection.TRANSACTION_REPEATABLE_READ: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_REPEATABLE_READ, + bitmask); + case Connection.TRANSACTION_SERIALIZABLE: + return doesBitmaskTranslateToEnum(SqlTransactionIsolationLevel.SQL_TRANSACTION_SERIALIZABLE, + bitmask); + default: + throw new SQLException( + "Invalid 
transaction isolation level argument. The informed level is not defined in java.sql.Connection.");
+    }
+  }
+
+  @Override
+  public boolean dataDefinitionCausesTransactionCommit() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT,
+        Boolean.class);
+  }
+
+  @Override
+  public boolean dataDefinitionIgnoredInTransactions() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED,
+        Boolean.class);
+  }
+
+  @Override
+  public boolean supportsBatchUpdates() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_BATCH_UPDATES_SUPPORTED, Boolean.class);
+  }
+
+  @Override
+  public boolean supportsSavepoints() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_SAVEPOINTS_SUPPORTED, Boolean.class);
+  }
+
+  @Override
+  public boolean supportsNamedParameters() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_NAMED_PARAMETERS_SUPPORTED, Boolean.class);
+  }
+
+  @Override
+  public boolean locatorsUpdateCopy() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(SqlInfo.SQL_LOCATORS_UPDATE_COPY, Boolean.class);
+  }
+
+  @Override
+  public boolean supportsStoredFunctionsUsingCallSyntax() throws SQLException {
+    return getSqlInfoAndCacheIfCacheIsEmpty(
+        SqlInfo.SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED, Boolean.class);
+  }
+
+  @Override
+  public ArrowFlightConnection getConnection() throws SQLException {
+    return (ArrowFlightConnection) super.getConnection();
+  }
+
+  private <T> T getSqlInfoAndCacheIfCacheIsEmpty(final SqlInfo sqlInfoCommand,
+                                                 final Class<T> desiredType)
+      throws SQLException {
+    final ArrowFlightConnection connection = getConnection();
+    if (cachedSqlInfo.isEmpty()) {
+      final FlightInfo sqlInfo = connection.getClientHandler().getSqlInfo();
+      synchronized (cachedSqlInfo) {
+        if (cachedSqlInfo.isEmpty()) {
+          try (final ResultSet resultSet =
+                   ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(
+                       connection, sqlInfo, null)) {
+            while (resultSet.next()) {
+              cachedSqlInfo.put(SqlInfo.forNumber((Integer) resultSet.getObject("info_name")),
+                  resultSet.getObject("value"));
+            }
+          }
+        }
+      }
+    }
+    return desiredType.cast(cachedSqlInfo.get(sqlInfoCommand));
+  }
+
+  private String convertListSqlInfoToString(final List<?> sqlInfoList) {
+    return sqlInfoList.stream().map(Object::toString).collect(Collectors.joining(", "));
+  }
+
+  private boolean getSqlInfoEnumOptionAndCacheIfCacheIsEmpty(
+      final SqlInfo sqlInfoCommand,
+      final ProtocolMessageEnum enumInstance
+  ) throws SQLException {
+    final int bitmask = getSqlInfoAndCacheIfCacheIsEmpty(sqlInfoCommand, Integer.class);
+    return doesBitmaskTranslateToEnum(enumInstance, bitmask);
+  }
+
+  private boolean checkEnumLevel(final List<Boolean> toCheck) {
+    return toCheck.stream().anyMatch(e -> e);
+  }
+
+  @Override
+  public ResultSet getCatalogs() throws SQLException {
+    final ArrowFlightConnection connection = getConnection();
+    final FlightInfo flightInfoCatalogs = connection.getClientHandler().getCatalogs();
+
+    final BufferAllocator allocator = connection.getBufferAllocator();
+    final VectorSchemaRootTransformer transformer =
+        new VectorSchemaRootTransformer.Builder(Schemas.GET_CATALOGS_SCHEMA, allocator)
+            .renameFieldVector("catalog_name", "TABLE_CAT")
+            .build();
+    return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoCatalogs,
+        transformer);
+  }
+
+  @Override
+  public ResultSet getImportedKeys(final String catalog, final String
schema, final String table) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoImportedKeys = + connection.getClientHandler().getImportedKeys(catalog, schema, table); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = getForeignKeysTransformer(allocator); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoImportedKeys, + transformer); + } + + @Override + public ResultSet getExportedKeys(final String catalog, final String schema, final String table) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoExportedKeys = + connection.getClientHandler().getExportedKeys(catalog, schema, table); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = getForeignKeysTransformer(allocator); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoExportedKeys, + transformer); + } + + @Override + public ResultSet getCrossReference(final String parentCatalog, final String parentSchema, + final String parentTable, + final String foreignCatalog, final String foreignSchema, + final String foreignTable) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoCrossReference = connection.getClientHandler().getCrossReference( + parentCatalog, parentSchema, parentTable, foreignCatalog, foreignSchema, foreignTable); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = getForeignKeysTransformer(allocator); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoCrossReference, + transformer); + } + + /** + * Transformer used on getImportedKeys, getExportedKeys and getCrossReference methods, since + * all three share the same schema. 
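+ * Each Flight SQL column (for example {@code pk_catalog_name}) is renamed to its JDBC
+ * counterpart (for example {@code PKTABLE_CAT}), and an empty DEFERRABILITY column is
+ * appended, since Flight SQL does not provide one.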
+ */ + private VectorSchemaRootTransformer getForeignKeysTransformer(final BufferAllocator allocator) { + return new VectorSchemaRootTransformer.Builder(Schemas.GET_IMPORTED_KEYS_SCHEMA, + allocator) + .renameFieldVector("pk_catalog_name", "PKTABLE_CAT") + .renameFieldVector("pk_db_schema_name", "PKTABLE_SCHEM") + .renameFieldVector("pk_table_name", "PKTABLE_NAME") + .renameFieldVector("pk_column_name", "PKCOLUMN_NAME") + .renameFieldVector("fk_catalog_name", "FKTABLE_CAT") + .renameFieldVector("fk_db_schema_name", "FKTABLE_SCHEM") + .renameFieldVector("fk_table_name", "FKTABLE_NAME") + .renameFieldVector("fk_column_name", "FKCOLUMN_NAME") + .renameFieldVector("key_sequence", "KEY_SEQ") + .renameFieldVector("fk_key_name", "FK_NAME") + .renameFieldVector("pk_key_name", "PK_NAME") + .renameFieldVector("update_rule", "UPDATE_RULE") + .renameFieldVector("delete_rule", "DELETE_RULE") + .addEmptyField("DEFERRABILITY", new ArrowType.Int(Byte.SIZE, false)) + .build(); + } + + @Override + public ResultSet getSchemas(final String catalog, final String schemaPattern) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoSchemas = + connection.getClientHandler().getSchemas(catalog, schemaPattern); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_SCHEMAS_SCHEMA, allocator) + .renameFieldVector("db_schema_name", "TABLE_SCHEM") + .renameFieldVector("catalog_name", "TABLE_CATALOG") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoSchemas, + transformer); + } + + @Override + public ResultSet getTableTypes() throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoTableTypes = connection.getClientHandler().getTableTypes(); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_TABLE_TYPES_SCHEMA, allocator) + .renameFieldVector("table_type", "TABLE_TYPE") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoTableTypes, + transformer); + } + + @Override + public ResultSet getTables(final String catalog, final String schemaPattern, + final String tableNamePattern, + final String[] types) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final List typesList = types == null ? 
null : Arrays.asList(types); + final FlightInfo flightInfoTables = + connection.getClientHandler() + .getTables(catalog, schemaPattern, tableNamePattern, typesList, false); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator) + .renameFieldVector("catalog_name", "TABLE_CAT") + .renameFieldVector("db_schema_name", "TABLE_SCHEM") + .renameFieldVector("table_name", "TABLE_NAME") + .renameFieldVector("table_type", "TABLE_TYPE") + .addEmptyField("REMARKS", Types.MinorType.VARBINARY) + .addEmptyField("TYPE_CAT", Types.MinorType.VARBINARY) + .addEmptyField("TYPE_SCHEM", Types.MinorType.VARBINARY) + .addEmptyField("TYPE_NAME", Types.MinorType.VARBINARY) + .addEmptyField("SELF_REFERENCING_COL_NAME", Types.MinorType.VARBINARY) + .addEmptyField("REF_GENERATION", Types.MinorType.VARBINARY) + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoTables, + transformer); + } + + @Override + public ResultSet getPrimaryKeys(final String catalog, final String schema, final String table) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoPrimaryKeys = + connection.getClientHandler().getPrimaryKeys(catalog, schema, table); + + final BufferAllocator allocator = connection.getBufferAllocator(); + final VectorSchemaRootTransformer transformer = + new VectorSchemaRootTransformer.Builder(Schemas.GET_PRIMARY_KEYS_SCHEMA, allocator) + .renameFieldVector("catalog_name", "TABLE_CAT") + .renameFieldVector("db_schema_name", "TABLE_SCHEM") + .renameFieldVector("table_name", "TABLE_NAME") + .renameFieldVector("column_name", "COLUMN_NAME") + .renameFieldVector("key_sequence", "KEY_SEQ") + .renameFieldVector("key_name", "PK_NAME") + .build(); + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoPrimaryKeys, + transformer); + } + + @Override + public ResultSet getColumns(final String catalog, final String schemaPattern, + final String tableNamePattern, + final String columnNamePattern) + throws SQLException { + final ArrowFlightConnection connection = getConnection(); + final FlightInfo flightInfoTables = + connection.getClientHandler() + .getTables(catalog, schemaPattern, tableNamePattern, null, true); + + final BufferAllocator allocator = connection.getBufferAllocator(); + + final Pattern columnNamePat = + columnNamePattern != null ? 
Pattern.compile(sqlToRegexLike(columnNamePattern)) : null; + + return ArrowFlightJdbcFlightStreamResultSet.fromFlightInfo(connection, flightInfoTables, + (originalRoot, transformedRoot) -> { + int columnCounter = 0; + if (transformedRoot == null) { + transformedRoot = VectorSchemaRoot.create(GET_COLUMNS_SCHEMA, allocator); + } + + final int originalRootRowCount = originalRoot.getRowCount(); + + final VarCharVector catalogNameVector = + (VarCharVector) originalRoot.getVector("catalog_name"); + final VarCharVector tableNameVector = + (VarCharVector) originalRoot.getVector("table_name"); + final VarCharVector schemaNameVector = + (VarCharVector) originalRoot.getVector("db_schema_name"); + + final VarBinaryVector schemaVector = + (VarBinaryVector) originalRoot.getVector("table_schema"); + + for (int i = 0; i < originalRootRowCount; i++) { + final Text catalogName = catalogNameVector.getObject(i); + final Text tableName = tableNameVector.getObject(i); + final Text schemaName = schemaNameVector.getObject(i); + + final Schema currentSchema; + try { + currentSchema = MessageSerializer.deserializeSchema( + new ReadChannel(Channels.newChannel( + new ByteArrayInputStream(schemaVector.get(i))))); + } catch (final IOException e) { + throw new IOException( + String.format("Failed to deserialize schema for table %s", tableName), e); + } + final List tableColumns = currentSchema.getFields(); + + columnCounter = setGetColumnsVectorSchemaRootFromFields(transformedRoot, columnCounter, + tableColumns, + catalogName, tableName, schemaName, columnNamePat); + } + + transformedRoot.setRowCount(columnCounter); + + originalRoot.clear(); + return transformedRoot; + }); + } + + private int setGetColumnsVectorSchemaRootFromFields(final VectorSchemaRoot currentRoot, + int insertIndex, + final List tableColumns, + final Text catalogName, + final Text tableName, final Text schemaName, + final Pattern columnNamePattern) { + int ordinalIndex = 1; + final int tableColumnsSize = tableColumns.size(); + + final VarCharVector tableCatVector = (VarCharVector) currentRoot.getVector("TABLE_CAT"); + final VarCharVector tableSchemVector = (VarCharVector) currentRoot.getVector("TABLE_SCHEM"); + final VarCharVector tableNameVector = (VarCharVector) currentRoot.getVector("TABLE_NAME"); + final VarCharVector columnNameVector = (VarCharVector) currentRoot.getVector("COLUMN_NAME"); + final IntVector dataTypeVector = (IntVector) currentRoot.getVector("DATA_TYPE"); + final VarCharVector typeNameVector = (VarCharVector) currentRoot.getVector("TYPE_NAME"); + final IntVector columnSizeVector = (IntVector) currentRoot.getVector("COLUMN_SIZE"); + final IntVector decimalDigitsVector = (IntVector) currentRoot.getVector("DECIMAL_DIGITS"); + final IntVector numPrecRadixVector = (IntVector) currentRoot.getVector("NUM_PREC_RADIX"); + final IntVector nullableVector = (IntVector) currentRoot.getVector("NULLABLE"); + final IntVector ordinalPositionVector = (IntVector) currentRoot.getVector("ORDINAL_POSITION"); + final VarCharVector isNullableVector = (VarCharVector) currentRoot.getVector("IS_NULLABLE"); + final VarCharVector isAutoincrementVector = (VarCharVector) currentRoot.getVector("IS_AUTOINCREMENT"); + final VarCharVector isGeneratedColumnVector = (VarCharVector) currentRoot.getVector("IS_GENERATEDCOLUMN"); + + for (int i = 0; i < tableColumnsSize; i++, ordinalIndex++) { + final Field field = tableColumns.get(i); + final FlightSqlColumnMetadata columnMetadata = new FlightSqlColumnMetadata(field.getMetadata()); + final String columnName = 
field.getName(); + + if (columnNamePattern != null && !columnNamePattern.matcher(columnName).matches()) { + continue; + } + final ArrowType fieldType = field.getType(); + + if (catalogName != null) { + tableCatVector.setSafe(insertIndex, catalogName); + } + + if (schemaName != null) { + tableSchemVector.setSafe(insertIndex, schemaName); + } + + if (tableName != null) { + tableNameVector.setSafe(insertIndex, tableName); + } + + if (columnName != null) { + columnNameVector.setSafe(insertIndex, columnName.getBytes(CHARSET)); + } + + dataTypeVector.setSafe(insertIndex, SqlTypes.getSqlTypeIdFromArrowType(fieldType)); + byte[] typeName = columnMetadata.getTypeName() != null ? + columnMetadata.getTypeName().getBytes(CHARSET) : + SqlTypes.getSqlTypeNameFromArrowType(fieldType).getBytes(CHARSET); + typeNameVector.setSafe(insertIndex, typeName); + + // We're not setting COLUMN_SIZE for ROWID SQL Types, as there's no such Arrow type. + // We're not setting COLUMN_SIZE nor DECIMAL_DIGITS for Float/Double as their precision and scale are variable. + if (fieldType instanceof ArrowType.Decimal) { + numPrecRadixVector.setSafe(insertIndex, BASE10_RADIX); + } else if (fieldType instanceof ArrowType.Int) { + numPrecRadixVector.setSafe(insertIndex, BASE10_RADIX); + } else if (fieldType instanceof ArrowType.FloatingPoint) { + numPrecRadixVector.setSafe(insertIndex, BASE10_RADIX); + } + + Integer decimalDigits = columnMetadata.getScale(); + if (decimalDigits == null) { + decimalDigits = getDecimalDigits(fieldType); + } + if (decimalDigits != null) { + decimalDigitsVector.setSafe(insertIndex, decimalDigits); + } + + Integer columnSize = columnMetadata.getPrecision(); + if (columnSize == null) { + columnSize = getColumnSize(fieldType); + } + if (columnSize != null) { + columnSizeVector.setSafe(insertIndex, columnSize); + } + + nullableVector.setSafe(insertIndex, field.isNullable() ? 1 : 0); + + isNullableVector.setSafe(insertIndex, booleanToYesOrNo(field.isNullable())); + + Boolean autoIncrement = columnMetadata.isAutoIncrement(); + if (autoIncrement != null) { + isAutoincrementVector.setSafe(insertIndex, booleanToYesOrNo(autoIncrement)); + } else { + isAutoincrementVector.setSafe(insertIndex, EMPTY_BYTE_ARRAY); + } + + // Fields also don't hold information about IS_AUTOINCREMENT and IS_GENERATEDCOLUMN, + // so we're setting an empty string (as bytes), which means it couldn't be determined. + isGeneratedColumnVector.setSafe(insertIndex, EMPTY_BYTE_ARRAY); + + ordinalPositionVector.setSafe(insertIndex, ordinalIndex); + + insertIndex++; + } + return insertIndex; + } + + private static byte[] booleanToYesOrNo(boolean autoIncrement) { + return autoIncrement ? "YES".getBytes(CHARSET) : "NO".getBytes(CHARSET); + } + + static Integer getDecimalDigits(final ArrowType fieldType) { + // We're not setting DECIMAL_DIGITS for Float/Double as their precision and scale are variable. 
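+    // Fractional-second precision doubles as "decimal digits" here: SECOND-based units carry
+    // none, while MILLISECOND/MICROSECOND/NANOSECOND units map to the matching
+    // DECIMAL_DIGITS_TIME_* constant; integral and date types report NO_DECIMAL_DIGITS.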
+ if (fieldType instanceof ArrowType.Decimal) { + final ArrowType.Decimal thisDecimal = (ArrowType.Decimal) fieldType; + return thisDecimal.getScale(); + } else if (fieldType instanceof ArrowType.Int) { + return NO_DECIMAL_DIGITS; + } else if (fieldType instanceof ArrowType.Timestamp) { + switch (((ArrowType.Timestamp) fieldType).getUnit()) { + case SECOND: + return NO_DECIMAL_DIGITS; + case MILLISECOND: + return DECIMAL_DIGITS_TIME_MILLISECONDS; + case MICROSECOND: + return DECIMAL_DIGITS_TIME_MICROSECONDS; + case NANOSECOND: + return DECIMAL_DIGITS_TIME_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Time) { + switch (((ArrowType.Time) fieldType).getUnit()) { + case SECOND: + return NO_DECIMAL_DIGITS; + case MILLISECOND: + return DECIMAL_DIGITS_TIME_MILLISECONDS; + case MICROSECOND: + return DECIMAL_DIGITS_TIME_MICROSECONDS; + case NANOSECOND: + return DECIMAL_DIGITS_TIME_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Date) { + return NO_DECIMAL_DIGITS; + } + + return null; + } + + static Integer getColumnSize(final ArrowType fieldType) { + // We're not setting COLUMN_SIZE for ROWID SQL Types, as there's no such Arrow type. + // We're not setting COLUMN_SIZE nor DECIMAL_DIGITS for Float/Double as their precision and scale are variable. + if (fieldType instanceof ArrowType.Decimal) { + final ArrowType.Decimal thisDecimal = (ArrowType.Decimal) fieldType; + return thisDecimal.getPrecision(); + } else if (fieldType instanceof ArrowType.Int) { + final ArrowType.Int thisInt = (ArrowType.Int) fieldType; + switch (thisInt.getBitWidth()) { + case Byte.SIZE: + return COLUMN_SIZE_BYTE; + case Short.SIZE: + return COLUMN_SIZE_SHORT; + case Integer.SIZE: + return COLUMN_SIZE_INT; + case Long.SIZE: + return COLUMN_SIZE_LONG; + default: + break; + } + } else if (fieldType instanceof ArrowType.Utf8 || fieldType instanceof ArrowType.Binary) { + return COLUMN_SIZE_VARCHAR_AND_BINARY; + } else if (fieldType instanceof ArrowType.Timestamp) { + switch (((ArrowType.Timestamp) fieldType).getUnit()) { + case SECOND: + return COLUMN_SIZE_TIMESTAMP_SECONDS; + case MILLISECOND: + return COLUMN_SIZE_TIMESTAMP_MILLISECONDS; + case MICROSECOND: + return COLUMN_SIZE_TIMESTAMP_MICROSECONDS; + case NANOSECOND: + return COLUMN_SIZE_TIMESTAMP_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Time) { + switch (((ArrowType.Time) fieldType).getUnit()) { + case SECOND: + return COLUMN_SIZE_TIME; + case MILLISECOND: + return COLUMN_SIZE_TIME_MILLISECONDS; + case MICROSECOND: + return COLUMN_SIZE_TIME_MICROSECONDS; + case NANOSECOND: + return COLUMN_SIZE_TIME_NANOSECONDS; + default: + break; + } + } else if (fieldType instanceof ArrowType.Date) { + return COLUMN_SIZE_DATE; + } + + return null; + } + + static String sqlToRegexLike(final String sqlPattern) { + final int len = sqlPattern.length(); + final StringBuilder javaPattern = new StringBuilder(len + len); + + for (int i = 0; i < len; i++) { + final char currentChar = sqlPattern.charAt(i); + + if (JAVA_REGEX_SPECIALS.indexOf(currentChar) >= 0) { + javaPattern.append('\\'); + } + + switch (currentChar) { + case '_': + javaPattern.append('.'); + break; + case '%': + javaPattern.append("."); + javaPattern.append('*'); + break; + default: + javaPattern.append(currentChar); + break; + } + } + return javaPattern.toString(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java 
b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java new file mode 100644 index 00000000000..d2b6e89e3fb --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; + +import java.sql.SQLException; +import java.util.Properties; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaFactory; + +import io.netty.util.concurrent.DefaultThreadFactory; + +/** + * Connection to the Arrow Flight server. + */ +public final class ArrowFlightConnection extends AvaticaConnection { + + private final BufferAllocator allocator; + private final ArrowFlightSqlClientHandler clientHandler; + private final ArrowFlightConnectionConfigImpl config; + private ExecutorService executorService; + + /** + * Creates a new {@link ArrowFlightConnection}. + * + * @param driver the {@link ArrowFlightJdbcDriver} to use. + * @param factory the {@link AvaticaFactory} to use. + * @param url the URL to use. + * @param properties the {@link Properties} to use. + * @param config the {@link ArrowFlightConnectionConfigImpl} to use. + * @param allocator the {@link BufferAllocator} to use. + * @param clientHandler the {@link ArrowFlightSqlClientHandler} to use. + */ + private ArrowFlightConnection(final ArrowFlightJdbcDriver driver, final AvaticaFactory factory, + final String url, final Properties properties, + final ArrowFlightConnectionConfigImpl config, + final BufferAllocator allocator, + final ArrowFlightSqlClientHandler clientHandler) { + super(driver, factory, url, properties); + this.config = Preconditions.checkNotNull(config, "Config cannot be null."); + this.allocator = Preconditions.checkNotNull(allocator, "Allocator cannot be null."); + this.clientHandler = Preconditions.checkNotNull(clientHandler, "Handler cannot be null."); + } + + /** + * Creates a new {@link ArrowFlightConnection} to a {@link FlightClient}. + * + * @param driver the {@link ArrowFlightJdbcDriver} to use. + * @param factory the {@link AvaticaFactory} to use. 
+ * @param url the URL to establish the connection to. + * @param properties the {@link Properties} to use for this session. + * @param allocator the {@link BufferAllocator} to use. + * @return a new {@link ArrowFlightConnection}. + * @throws SQLException on error. + */ + static ArrowFlightConnection createNewConnection(final ArrowFlightJdbcDriver driver, + final AvaticaFactory factory, + String url, final Properties properties, + final BufferAllocator allocator) + throws SQLException { + url = replaceSemiColons(url); + final ArrowFlightConnectionConfigImpl config = new ArrowFlightConnectionConfigImpl(properties); + final ArrowFlightSqlClientHandler clientHandler = createNewClientHandler(config, allocator); + return new ArrowFlightConnection(driver, factory, url, properties, config, allocator, clientHandler); + } + + private static ArrowFlightSqlClientHandler createNewClientHandler( + final ArrowFlightConnectionConfigImpl config, + final BufferAllocator allocator) throws SQLException { + try { + return new ArrowFlightSqlClientHandler.Builder() + .withHost(config.getHost()) + .withPort(config.getPort()) + .withUsername(config.getUser()) + .withPassword(config.getPassword()) + .withTrustStorePath(config.getTrustStorePath()) + .withTrustStorePassword(config.getTrustStorePassword()) + .withSystemTrustStore(config.useSystemTrustStore()) + .withBufferAllocator(allocator) + .withEncryption(config.useEncryption()) + .withDisableCertificateVerification(config.getDisableCertificateVerification()) + .withToken(config.getToken()) + .withCallOptions(config.toCallOption()) + .build(); + } catch (final SQLException e) { + try { + allocator.close(); + } catch (final Exception allocatorCloseEx) { + e.addSuppressed(allocatorCloseEx); + } + throw e; + } + } + + void reset() throws SQLException { + // Clean up any open Statements + try { + AutoCloseables.close(statementMap.values()); + } catch (final Exception e) { + throw AvaticaConnection.HELPER.createException(e.getMessage(), e); + } + + statementMap.clear(); + + // Reset Holdability + this.setHoldability(this.metaData.getResultSetHoldability()); + + // Reset Meta + ((ArrowFlightMetaImpl) this.meta).setDefaultConnectionProperties(); + } + + /** + * Gets the client {@link #clientHandler} backing this connection. + * + * @return the handler. + */ + ArrowFlightSqlClientHandler getClientHandler() throws SQLException { + return clientHandler; + } + + /** + * Gets the {@link ExecutorService} of this connection. + * + * @return the {@link #executorService}. + */ + synchronized ExecutorService getExecutorService() { + return executorService = executorService == null ? 
+ Executors.newFixedThreadPool(config.threadPoolSize(), + new DefaultThreadFactory(getClass().getSimpleName())) : + executorService; + } + + @Override + public Properties getClientInfo() { + final Properties copy = new Properties(); + copy.putAll(info); + return copy; + } + + @Override + public void close() throws SQLException { + if (executorService != null) { + executorService.shutdown(); + } + + try { + AutoCloseables.close(clientHandler); + allocator.getChildAllocators().forEach(AutoCloseables::closeNoChecked); + AutoCloseables.close(allocator); + + super.close(); + } catch (final Exception e) { + throw AvaticaConnection.HELPER.createException(e.getMessage(), e); + } + } + + BufferAllocator getBufferAllocator() { + return allocator; + } + + public ArrowFlightMetaImpl getMeta() { + return (ArrowFlightMetaImpl) this.meta; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java new file mode 100644 index 00000000000..8365c7bb57a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightInfoStatement.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.sql.Statement; + +import org.apache.arrow.flight.FlightInfo; + +/** + * A {@link Statement} that deals with {@link FlightInfo}. + */ +public interface ArrowFlightInfoStatement extends Statement { + + @Override + ArrowFlightConnection getConnection() throws SQLException; + + /** + * Executes the query this {@link Statement} is holding. + * + * @return the {@link FlightInfo} for the results. + * @throws SQLException on error. + */ + FlightInfo executeFlightInfoQuery() throws SQLException; +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArray.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArray.java new file mode 100644 index 00000000000..ed67c97cf69 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArray.java @@ -0,0 +1,178 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Array; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Arrays; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.accessor.impl.complex.AbstractArrowFlightJdbcListVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.SqlTypes; +import org.apache.arrow.memory.util.LargeMemoryUtil; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.apache.arrow.vector.util.TransferPair; + +/** + * Implementation of {@link Array} using an underlying {@link FieldVector}. + * + * @see AbstractArrowFlightJdbcListVectorAccessor + */ +public class ArrowFlightJdbcArray implements Array { + + private final FieldVector dataVector; + private final long startOffset; + private final long valuesCount; + + /** + * Instantiate an {@link Array} backed up by given {@link FieldVector}, limited by a start offset and values count. + * + * @param dataVector underlying FieldVector, containing the Array items. + * @param startOffset offset from FieldVector pointing to this Array's first value. + * @param valuesCount how many items this Array contains. 
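+ *
+ * <p>Illustrative usage sketch ({@code listVector} and the bounds are assumed for the example,
+ * not part of this class):
+ * <pre>{@code
+ * java.sql.Array array = new ArrowFlightJdbcArray(listVector.getDataVector(), 0, 10);
+ * Object[] values = (Object[]) array.getArray();
+ * }</pre>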
+ */
+  public ArrowFlightJdbcArray(FieldVector dataVector, long startOffset, long valuesCount) {
+    this.dataVector = dataVector;
+    this.startOffset = startOffset;
+    this.valuesCount = valuesCount;
+  }
+
+  @Override
+  public String getBaseTypeName() {
+    final ArrowType arrowType = this.dataVector.getField().getType();
+    return SqlTypes.getSqlTypeNameFromArrowType(arrowType);
+  }
+
+  @Override
+  public int getBaseType() {
+    final ArrowType arrowType = this.dataVector.getField().getType();
+    return SqlTypes.getSqlTypeIdFromArrowType(arrowType);
+  }
+
+  @Override
+  public Object getArray() throws SQLException {
+    return getArray(null);
+  }
+
+  @Override
+  public Object getArray(Map<String, Class<?>> map) throws SQLException {
+    if (map != null) {
+      throw new SQLFeatureNotSupportedException();
+    }
+
+    return getArrayNoBoundCheck(this.dataVector, this.startOffset, this.valuesCount);
+  }
+
+  @Override
+  public Object getArray(long index, int count) throws SQLException {
+    return getArray(index, count, null);
+  }
+
+  private void checkBoundaries(long index, int count) {
+    if (index < 0 || index + count > this.startOffset + this.valuesCount) {
+      throw new ArrayIndexOutOfBoundsException();
+    }
+  }
+
+  private static Object getArrayNoBoundCheck(ValueVector dataVector, long start, long count) {
+    Object[] result = new Object[LargeMemoryUtil.checkedCastToInt(count)];
+    for (int i = 0; i < count; i++) {
+      result[i] = dataVector.getObject(LargeMemoryUtil.checkedCastToInt(start + i));
+    }
+
+    return result;
+  }
+
+  @Override
+  public Object getArray(long index, int count, Map<String, Class<?>> map) throws SQLException {
+    if (map != null) {
+      throw new SQLFeatureNotSupportedException();
+    }
+
+    checkBoundaries(index, count);
+    return getArrayNoBoundCheck(this.dataVector,
+        LargeMemoryUtil.checkedCastToInt(this.startOffset + index), count);
+  }
+
+  @Override
+  public ResultSet getResultSet() throws SQLException {
+    return this.getResultSet(null);
+  }
+
+  @Override
+  public ResultSet getResultSet(Map<String, Class<?>> map) throws SQLException {
+    if (map != null) {
+      throw new SQLFeatureNotSupportedException();
+    }
+
+    return getResultSetNoBoundariesCheck(this.dataVector, this.startOffset, this.valuesCount);
+  }
+
+  @Override
+  public ResultSet getResultSet(long index, int count) throws SQLException {
+    return getResultSet(index, count, null);
+  }
+
+  private static ResultSet getResultSetNoBoundariesCheck(ValueVector dataVector, long start,
+                                                         long count)
+      throws SQLException {
+    TransferPair transferPair = dataVector.getTransferPair(dataVector.getAllocator());
+    transferPair.splitAndTransfer(LargeMemoryUtil.checkedCastToInt(start),
+        LargeMemoryUtil.checkedCastToInt(count));
+    FieldVector vectorSlice = (FieldVector) transferPair.getTo();
+
+    VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(vectorSlice);
+    return ArrowFlightJdbcVectorSchemaRootResultSet.fromVectorSchemaRoot(vectorSchemaRoot);
+  }
+
+  @Override
+  public ResultSet getResultSet(long index, int count, Map<String, Class<?>> map)
+      throws SQLException {
+    if (map != null) {
+      throw new SQLFeatureNotSupportedException();
+    }
+
+    checkBoundaries(index, count);
+    return getResultSetNoBoundariesCheck(this.dataVector,
+        LargeMemoryUtil.checkedCastToInt(this.startOffset + index), count);
+  }
+
+  @Override
+  public void free() {
+
+  }
+
+  @Override
+  public String toString() {
+    JsonStringArrayList<Object> array = new JsonStringArrayList<>((int) this.valuesCount);
+
+    try {
+      array.addAll(Arrays.asList((Object[]) getArray()));
+    } catch (SQLException e) {
+      throw new RuntimeException(e);
+    }
+
+    return array.toString();
+ } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSource.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSource.java new file mode 100644 index 00000000000..46a1d3ff87c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSource.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; +import java.util.Map; +import java.util.Properties; +import java.util.Queue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.sql.ConnectionEvent; +import javax.sql.ConnectionEventListener; +import javax.sql.ConnectionPoolDataSource; +import javax.sql.PooledConnection; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; + +/** + * {@link ConnectionPoolDataSource} implementation for Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcConnectionPoolDataSource extends ArrowFlightJdbcDataSource + implements ConnectionPoolDataSource, ConnectionEventListener, AutoCloseable { + private final Map> pool = + new ConcurrentHashMap<>(); + + /** + * Instantiates a new DataSource. + * + * @param properties the properties + * @param config the config. + */ + protected ArrowFlightJdbcConnectionPoolDataSource(final Properties properties, + final ArrowFlightConnectionConfigImpl config) { + super(properties, config); + } + + /** + * Creates a new {@link ArrowFlightJdbcConnectionPoolDataSource}. + * + * @param properties the properties. + * @return a new data source. 
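+ *
+ * <p>Hypothetical usage sketch (the connection property keys shown are illustrative):
+ * <pre>{@code
+ * Properties properties = new Properties();
+ * properties.put("host", "localhost");
+ * properties.put("port", "32010");
+ * ArrowFlightJdbcConnectionPoolDataSource dataSource =
+ *     ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties);
+ * PooledConnection pooledConnection = dataSource.getPooledConnection();
+ * }</pre>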
+ */ + public static ArrowFlightJdbcConnectionPoolDataSource createNewDataSource( + final Properties properties) { + return new ArrowFlightJdbcConnectionPoolDataSource(properties, + new ArrowFlightConnectionConfigImpl(properties)); + } + + @Override + public PooledConnection getPooledConnection() throws SQLException { + final ArrowFlightConnectionConfigImpl config = getConfig(); + return this.getPooledConnection(config.getUser(), config.getPassword()); + } + + @Override + public PooledConnection getPooledConnection(final String username, final String password) + throws SQLException { + final Properties properties = getProperties(username, password); + Queue objectPool = + pool.computeIfAbsent(properties, s -> new ConcurrentLinkedQueue<>()); + ArrowFlightJdbcPooledConnection pooledConnection = objectPool.poll(); + if (pooledConnection == null) { + pooledConnection = createPooledConnection(new ArrowFlightConnectionConfigImpl(properties)); + } else { + pooledConnection.reset(); + } + return pooledConnection; + } + + private ArrowFlightJdbcPooledConnection createPooledConnection( + final ArrowFlightConnectionConfigImpl config) + throws SQLException { + ArrowFlightJdbcPooledConnection pooledConnection = + new ArrowFlightJdbcPooledConnection(getConnection(config.getUser(), config.getPassword())); + pooledConnection.addConnectionEventListener(this); + return pooledConnection; + } + + @Override + public void connectionClosed(ConnectionEvent connectionEvent) { + final ArrowFlightJdbcPooledConnection pooledConnection = + (ArrowFlightJdbcPooledConnection) connectionEvent.getSource(); + Queue connectionQueue = + pool.get(pooledConnection.getProperties()); + connectionQueue.add(pooledConnection); + } + + @Override + public void connectionErrorOccurred(ConnectionEvent connectionEvent) { + + } + + @Override + public void close() throws Exception { + SQLException lastException = null; + for (Queue connections : this.pool.values()) { + while (!connections.isEmpty()) { + PooledConnection pooledConnection = connections.poll(); + try { + pooledConnection.close(); + } catch (SQLException e) { + lastException = e; + } + } + } + + if (lastException != null) { + throw lastException; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursor.java new file mode 100644 index 00000000000..45c23e4d529 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursor.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+
+import java.util.ArrayList;
+import java.util.Calendar;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.calcite.avatica.ColumnMetaData;
+import org.apache.calcite.avatica.util.AbstractCursor;
+import org.apache.calcite.avatica.util.ArrayImpl;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Arrow Flight Jdbc's Cursor class.
+ */
+public class ArrowFlightJdbcCursor extends AbstractCursor {
+
+  private static final Logger LOGGER;
+  private final VectorSchemaRoot root;
+  private final int rowCount;
+  private int currentRow = -1;
+
+  static {
+    LOGGER = LoggerFactory.getLogger(ArrowFlightJdbcCursor.class);
+  }
+
+  public ArrowFlightJdbcCursor(VectorSchemaRoot root) {
+    this.root = root;
+    rowCount = root.getRowCount();
+  }
+
+  @Override
+  public List<Accessor> createAccessors(List<ColumnMetaData> columns,
+                                        Calendar localCalendar,
+                                        ArrayImpl.Factory factory) {
+    final List<FieldVector> fieldVectors = root.getFieldVectors();
+
+    return IntStream.range(0, fieldVectors.size()).mapToObj(root::getVector)
+        .map(this::createAccessor)
+        .collect(Collectors.toCollection(() -> new ArrayList<>(fieldVectors.size())));
+  }
+
+  private Accessor createAccessor(FieldVector vector) {
+    return ArrowFlightJdbcAccessorFactory.createAccessor(vector, this::getCurrentRow,
+        (boolean wasNull) -> {
+          // AbstractCursor creates a boolean array of length 1 to hold the wasNull value
+          this.wasNull[0] = wasNull;
+        });
+  }
+
+  /**
+   * ArrowFlightJdbcAccessors do not use {@link AbstractCursor.Getter}, as it would box primitive types and cause
+   * performance issues. Each Accessor implementation works directly on Arrow Vectors.
+   */
+  @Override
+  protected Getter createGetter(int column) {
+    throw new UnsupportedOperationException("Not allowed.");
+  }
+
+  @Override
+  public boolean next() {
+    currentRow++;
+    return currentRow < rowCount;
+  }
+
+  @Override
+  public void close() {
+    try {
+      AutoCloseables.close(root);
+    } catch (Exception e) {
+      LOGGER.error(e.getMessage(), e);
+    }
+  }
+
+  private int getCurrentRow() {
+    return currentRow;
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDataSource.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDataSource.java
new file mode 100644
index 00000000000..a57eeaa8304
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDataSource.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; + +import java.io.PrintWriter; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.util.Properties; +import java.util.logging.Logger; + +import javax.sql.DataSource; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.util.Preconditions; + +/** + * {@link DataSource} implementation for Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcDataSource implements DataSource { + private final Properties properties; + private final ArrowFlightConnectionConfigImpl config; + private PrintWriter logWriter; + + /** + * Instantiates a new DataSource. + */ + protected ArrowFlightJdbcDataSource(final Properties properties, + final ArrowFlightConnectionConfigImpl config) { + this.properties = Preconditions.checkNotNull(properties); + this.config = Preconditions.checkNotNull(config); + } + + /** + * Gets the {@link #config} for this {@link ArrowFlightJdbcDataSource}. + * + * @return the {@link ArrowFlightConnectionConfigImpl}. + */ + protected final ArrowFlightConnectionConfigImpl getConfig() { + return config; + } + + /** + * Gets a copy of the {@link #properties} for this {@link ArrowFlightJdbcDataSource} with + * the provided {@code username} and {@code password}. + * + * @return the {@link Properties} for this data source. + */ + protected final Properties getProperties(final String username, final String password) { + final Properties newProperties = new Properties(); + newProperties.putAll(this.properties); + if (username != null) { + newProperties.replace(ArrowFlightConnectionProperty.USER.camelName(), username); + } + if (password != null) { + newProperties.replace(ArrowFlightConnectionProperty.PASSWORD.camelName(), password); + } + return ArrowFlightJdbcDriver.lowerCasePropertyKeys(newProperties); + } + + /** + * Creates a new {@link ArrowFlightJdbcDataSource}. + * + * @param properties the properties. + * @return a new data source. 
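+ *
+ * <p>Unlike {@link ArrowFlightJdbcConnectionPoolDataSource}, this data source opens a fresh
+ * connection on every {@code getConnection} call rather than reusing pooled ones.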
+ */ + public static ArrowFlightJdbcDataSource createNewDataSource(final Properties properties) { + return new ArrowFlightJdbcDataSource(properties, + new ArrowFlightConnectionConfigImpl(properties)); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return getConnection(config.getUser(), config.getPassword()); + } + + @Override + public ArrowFlightConnection getConnection(final String username, final String password) + throws SQLException { + final Properties properties = getProperties(username, password); + return new ArrowFlightJdbcDriver().connect(config.url(), properties); + } + + @Override + public T unwrap(Class aClass) throws SQLException { + throw new SQLException("ArrowFlightJdbcDataSource is not a wrapper."); + } + + @Override + public boolean isWrapperFor(Class aClass) { + return false; + } + + @Override + public PrintWriter getLogWriter() { + return this.logWriter; + } + + @Override + public void setLogWriter(PrintWriter logWriter) { + this.logWriter = logWriter; + } + + @Override + public void setLoginTimeout(int timeout) throws SQLException { + throw new SQLFeatureNotSupportedException("Setting Login timeout is not supported."); + } + + @Override + public int getLoginTimeout() { + return 0; + } + + @Override + public Logger getParentLogger() { + return Logger.getLogger("ArrowFlightJdbc"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java new file mode 100644 index 00000000000..a72fbd3a4d5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriver.java @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
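
A hedged usage sketch for the data source above. The exact property keys come from ArrowFlightConnectionConfigImpl, which is not part of this excerpt, so "host" and "port" below are assumptions:

    import java.sql.Connection;
    import java.util.Properties;

    Properties properties = new Properties();
    properties.put("host", "localhost"); // key name assumed
    properties.put("port", 32010);       // key name assumed

    ArrowFlightJdbcDataSource dataSource =
        ArrowFlightJdbcDataSource.createNewDataSource(properties);
    try (Connection connection = dataSource.getConnection("user", "password")) {
      // the credentials are merged into a copy of the properties via getProperties(...)
    }
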
+ */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.replaceSemiColons; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.sql.SQLException; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.UrlParser; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.util.VisibleForTesting; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.DriverVersion; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.UnregisteredDriver; + +/** + * JDBC driver for querying data from an Apache Arrow Flight server. + */ +public class ArrowFlightJdbcDriver extends UnregisteredDriver { + private static final String CONNECT_STRING_PREFIX = "jdbc:arrow-flight-sql://"; + private static final String CONNECT_STRING_PREFIX_DEPRECATED = "jdbc:arrow-flight://"; + private static final String CONNECTION_STRING_EXPECTED = "jdbc:arrow-flight-sql://[host][:port][?param1=value&...]"; + private static DriverVersion version; + + static { + // Special code for supporting Java9 and higher. + // Netty requires some extra properties to unlock some native memory management api + // Setting this property if not already set externally + // This has to be done before any netty class is being loaded + final String key = "cfjd.io.netty.tryReflectionSetAccessible"; + final String tryReflectionSetAccessible = System.getProperty(key); + if (tryReflectionSetAccessible == null) { + System.setProperty(key, Boolean.TRUE.toString()); + } + + new ArrowFlightJdbcDriver().register(); + } + + @Override + public ArrowFlightConnection connect(final String url, final Properties info) + throws SQLException { + final Properties properties = new Properties(info); + properties.putAll(info); + + if (url != null) { + final Map propertiesFromUrl = getUrlsArgs(url); + properties.putAll(propertiesFromUrl); + } + + try { + return ArrowFlightConnection.createNewConnection( + this, + factory, + url, + lowerCasePropertyKeys(properties), + new RootAllocator(Long.MAX_VALUE)); + } catch (final FlightRuntimeException e) { + throw new SQLException("Failed to connect.", e); + } + } + + @Override + protected String getFactoryClassName(final JdbcVersion jdbcVersion) { + return ArrowFlightJdbcFactory.class.getName(); + } + + @Override + protected DriverVersion createDriverVersion() { + if (version == null) { + final InputStream flightProperties = this.getClass().getResourceAsStream("/properties/flight.properties"); + if (flightProperties == null) { + throw new RuntimeException("Flight Properties not found. 
Ensure the JAR was built properly."); + } + try (final Reader reader = new BufferedReader(new InputStreamReader(flightProperties, StandardCharsets.UTF_8))) { + final Properties properties = new Properties(); + properties.load(reader); + + final String parentName = properties.getProperty("org.apache.arrow.flight.name"); + final String parentVersion = properties.getProperty("org.apache.arrow.flight.version"); + final String[] pVersion = parentVersion.split("\\."); + + final int parentMajorVersion = Integer.parseInt(pVersion[0]); + final int parentMinorVersion = Integer.parseInt(pVersion[1]); + + final String childName = properties.getProperty("org.apache.arrow.flight.jdbc-driver.name"); + final String childVersion = properties.getProperty("org.apache.arrow.flight.jdbc-driver.version"); + final String[] cVersion = childVersion.split("\\."); + + final int childMajorVersion = Integer.parseInt(cVersion[0]); + final int childMinorVersion = Integer.parseInt(cVersion[1]); + + version = new DriverVersion( + childName, + childVersion, + parentName, + parentVersion, + true, + childMajorVersion, + childMinorVersion, + parentMajorVersion, + parentMinorVersion); + } catch (final IOException e) { + throw new RuntimeException("Failed to load driver version.", e); + } + } + + return version; + } + + @Override + public Meta createMeta(final AvaticaConnection connection) { + return new ArrowFlightMetaImpl(connection); + } + + @Override + protected String getConnectStringPrefix() { + return CONNECT_STRING_PREFIX; + } + + @Override + public boolean acceptsURL(final String url) { + Preconditions.checkNotNull(url); + return url.startsWith(CONNECT_STRING_PREFIX) || url.startsWith(CONNECT_STRING_PREFIX_DEPRECATED); + } + + /** + * Parses the provided url based on the format this driver accepts, retrieving + * arguments after the {@link #CONNECT_STRING_PREFIX}. + *

+ * This method gets the args if the provided URL follows this pattern:
+ * {@code jdbc:arrow-flight-sql://<host>:<port>[/?key1=val1&key2=val2&(...)]}
+ *
+ * <table border="1">
+ *    <tr>
+ *        <td>Group</td>
+ *        <td>Definition</td>
+ *        <td>Value</td>
+ *    </tr>
+ *    <tr>
+ *        <td>? — inaccessible</td>
+ *        <td>{@link #getConnectStringPrefix}</td>
+ *        <td>the URL prefix accepted by this driver, i.e.,
+ *            {@code "jdbc:arrow-flight-sql://"}</td>
+ *    </tr>
+ *    <tr>
+ *        <td>1</td>
+ *        <td>IPv4 host name</td>
+ *        <td>first word after previous group and before "{@code :}"</td>
+ *    </tr>
+ *    <tr>
+ *        <td>2</td>
+ *        <td>IPv4 port number</td>
+ *        <td>first number after previous group and before "{@code /?}"</td>
+ *    </tr>
+ *    <tr>
+ *        <td>3</td>
+ *        <td>custom call parameters</td>
+ *        <td>all parameters provided after "{@code /?}" — must follow the
+ *            pattern: "{@code key=value}" with "{@code &}" separating a
+ *            parameter from another</td>
+ *    </tr>
+ * </table>
+ * + * @param url The url to parse. + * @return the parsed arguments. + * @throws SQLException If an error occurs while trying to parse the URL. + */ + @VisibleForTesting // ArrowFlightJdbcDriverTest + Map getUrlsArgs(String url) + throws SQLException { + + /* + * + * Perhaps this logic should be inside a utility class, separated from this + * one, so as to better delegate responsibilities and concerns throughout + * the code and increase maintainability. + * + * ===== + * + * Keep in mind that the URL must ALWAYS follow the pattern: + * "jdbc:arrow-flight-sql://:[/?param1=value1¶m2=value2&(...)]." + * + */ + + final Properties resultMap = new Properties(); + url = replaceSemiColons(url); + + if (!url.startsWith("jdbc:")) { + throw new SQLException("Connection string must start with 'jdbc:'. Expected format: " + + CONNECTION_STRING_EXPECTED); + } + + // It's necessary to use a string without "jdbc:" at the beginning to be parsed as a valid URL. + url = url.substring(5); + + final URI uri; + + try { + uri = URI.create(url); + } catch (final IllegalArgumentException e) { + throw new SQLException("Malformed/invalid URL!", e); + } + + if (!Objects.equals(uri.getScheme(), "arrow-flight") && + !Objects.equals(uri.getScheme(), "arrow-flight-sql")) { + throw new SQLException("URL Scheme must be 'arrow-flight'. Expected format: " + + CONNECTION_STRING_EXPECTED); + } + + if (uri.getHost() == null) { + throw new SQLException("URL must have a host. Expected format: " + CONNECTION_STRING_EXPECTED); + } else if (uri.getPort() < 0) { + throw new SQLException("URL must have a port. Expected format: " + CONNECTION_STRING_EXPECTED); + } + resultMap.put(ArrowFlightConnectionProperty.HOST.camelName(), uri.getHost()); // host + resultMap.put(ArrowFlightConnectionProperty.PORT.camelName(), uri.getPort()); // port + + final String extraParams = uri.getRawQuery(); // optional params + if (extraParams != null) { + final Map keyValuePairs = UrlParser.parse(extraParams, "&"); + resultMap.putAll(keyValuePairs); + } + + return resultMap; + } + + static Properties lowerCasePropertyKeys(final Properties properties) { + final Properties resultProperty = new Properties(); + properties.forEach((k, v) -> resultProperty.put(k.toString().toLowerCase(), v)); + return resultProperty; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java new file mode 100644 index 00000000000..a54fbb9511b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactory.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
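
To make the parsing contract concrete, this sketch mirrors what getUrlsArgs does with java.net.URI for a well-formed URL; the query parameters are illustrative values, not a list of supported options:

    import java.net.URI;

    String url = "jdbc:arrow-flight-sql://localhost:32010/?useEncryption=false&threadPoolSize=8";
    URI uri = URI.create(url.substring(5)); // drop "jdbc:" so URI accepts the scheme

    assert "arrow-flight-sql".equals(uri.getScheme());
    assert "localhost".equals(uri.getHost()); // becomes the HOST property
    assert uri.getPort() == 32010;            // becomes the PORT property
    assert "useEncryption=false&threadPoolSize=8".equals(uri.getRawQuery()); // extra params
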
+ */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Properties; +import java.util.TimeZone; + +import org.apache.arrow.memory.RootAllocator; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaFactory; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaSpecificDatabaseMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.UnregisteredDriver; + +/** + * Factory for the Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcFactory implements AvaticaFactory { + private final int major; + private final int minor; + + // This need to be public so Avatica can call this constructor + public ArrowFlightJdbcFactory() { + this(4, 1); + } + + private ArrowFlightJdbcFactory(final int major, final int minor) { + this.major = major; + this.minor = minor; + } + + @Override + public AvaticaConnection newConnection(final UnregisteredDriver driver, + final AvaticaFactory factory, + final String url, + final Properties info) throws SQLException { + return ArrowFlightConnection.createNewConnection( + (ArrowFlightJdbcDriver) driver, + factory, + url, + info, + new RootAllocator(Long.MAX_VALUE)); + } + + @Override + public AvaticaStatement newStatement( + final AvaticaConnection connection, + final Meta.StatementHandle handle, + final int resultType, + final int resultSetConcurrency, + final int resultSetHoldability) { + return new ArrowFlightStatement((ArrowFlightConnection) connection, + handle, resultType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightPreparedStatement newPreparedStatement( + final AvaticaConnection connection, + final Meta.StatementHandle statementHandle, + final Meta.Signature signature, + final int resultType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + return ArrowFlightPreparedStatement.createNewPreparedStatement( + (ArrowFlightConnection) connection, statementHandle, signature, + resultType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightJdbcVectorSchemaRootResultSet newResultSet(final AvaticaStatement statement, + final QueryState state, + final Meta.Signature signature, + final TimeZone timeZone, + final Meta.Frame frame) + throws SQLException { + final ResultSetMetaData metaData = newResultSetMetaData(statement, signature); + + return new ArrowFlightJdbcFlightStreamResultSet(statement, state, signature, metaData, timeZone, + frame); + } + + @Override + public AvaticaSpecificDatabaseMetaData newDatabaseMetaData(final AvaticaConnection connection) { + return new ArrowDatabaseMetadata(connection); + } + + @Override + public ResultSetMetaData newResultSetMetaData( + final AvaticaStatement avaticaStatement, + final Meta.Signature signature) { + return new AvaticaResultSetMetaData(avaticaStatement, + null, signature); + } + + @Override + public int getJdbcMajorVersion() { + return major; + } + + @Override + public int getJdbcMinorVersion() { + return minor; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java new file mode 100644 index 
00000000000..4c01cb6e581 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -0,0 +1,250 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue; + +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Optional; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue; +import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.QueryState; + +/** + * {@link ResultSet} implementation for Arrow Flight used to access the results of multiple {@link FlightStream} + * objects. + */ +public final class ArrowFlightJdbcFlightStreamResultSet + extends ArrowFlightJdbcVectorSchemaRootResultSet { + + private final ArrowFlightConnection connection; + private FlightStream currentFlightStream; + private FlightStreamQueue flightStreamQueue; + + private VectorSchemaRootTransformer transformer; + private VectorSchemaRoot currentVectorSchemaRoot; + + private Schema schema; + + ArrowFlightJdbcFlightStreamResultSet(final AvaticaStatement statement, + final QueryState state, + final Meta.Signature signature, + final ResultSetMetaData resultSetMetaData, + final TimeZone timeZone, + final Meta.Frame firstFrame) throws SQLException { + super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); + this.connection = (ArrowFlightConnection) statement.connection; + } + + ArrowFlightJdbcFlightStreamResultSet(final ArrowFlightConnection connection, + final QueryState state, + final Meta.Signature signature, + final ResultSetMetaData resultSetMetaData, + final TimeZone timeZone, + final Meta.Frame firstFrame) throws SQLException { + super(null, state, signature, resultSetMetaData, timeZone, firstFrame); + this.connection = connection; + } + + /** + * Create a {@link ResultSet} which pulls data from given {@link FlightInfo}. + * + * @param connection The connection linked to the returned ResultSet. + * @param flightInfo The FlightInfo from which data will be iterated by the returned ResultSet. 
+ * @param transformer Optional transformer for processing VectorSchemaRoot before access from ResultSet + * @return A ResultSet which pulls data from given FlightInfo. + */ + static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo( + final ArrowFlightConnection connection, + final FlightInfo flightInfo, + final VectorSchemaRootTransformer transformer) throws SQLException { + // Similar to how org.apache.calcite.avatica.util.ArrayFactoryImpl does + + final TimeZone timeZone = TimeZone.getDefault(); + final QueryState state = new QueryState(); + + final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null); + + final AvaticaResultSetMetaData resultSetMetaData = + new AvaticaResultSetMetaData(null, null, signature); + final ArrowFlightJdbcFlightStreamResultSet resultSet = + new ArrowFlightJdbcFlightStreamResultSet(connection, state, signature, resultSetMetaData, + timeZone, null); + + resultSet.transformer = transformer; + + resultSet.execute(flightInfo); + return resultSet; + } + + private void loadNewQueue() { + Optional.ofNullable(flightStreamQueue).ifPresent(AutoCloseables::closeNoChecked); + flightStreamQueue = createNewQueue(connection.getExecutorService()); + } + + private void loadNewFlightStream() throws SQLException { + if (currentFlightStream != null) { + AutoCloseables.closeNoChecked(currentFlightStream); + } + this.currentFlightStream = getNextFlightStream(true); + } + + @Override + protected AvaticaResultSet execute() throws SQLException { + final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); + + if (flightInfo != null) { + schema = flightInfo.getSchema(); + execute(flightInfo); + } + return this; + } + + private void execute(final FlightInfo flightInfo) throws SQLException { + loadNewQueue(); + flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); + loadNewFlightStream(); + + // Ownership of the root will be passed onto the cursor. + if (currentFlightStream != null) { + executeForCurrentFlightStream(); + } + } + + private void executeForCurrentFlightStream() throws SQLException { + final VectorSchemaRoot originalRoot = currentFlightStream.getRoot(); + + if (transformer != null) { + try { + currentVectorSchemaRoot = transformer.transform(originalRoot, currentVectorSchemaRoot); + } catch (final Exception e) { + throw new SQLException("Failed to transform VectorSchemaRoot.", e); + } + } else { + currentVectorSchemaRoot = originalRoot; + } + + if (schema != null) { + execute(currentVectorSchemaRoot, schema); + } else { + execute(currentVectorSchemaRoot); + } + } + + @Override + public boolean next() throws SQLException { + if (currentVectorSchemaRoot == null) { + return false; + } + while (true) { + final boolean hasNext = super.next(); + final int maxRows = statement != null ? 
statement.getMaxRows() : 0; + if (maxRows != 0 && this.getRow() > maxRows) { + if (statement.isCloseOnCompletion()) { + statement.close(); + } + return false; + } + + if (hasNext) { + return true; + } + + if (currentFlightStream != null) { + currentFlightStream.getRoot().clear(); + if (currentFlightStream.next()) { + executeForCurrentFlightStream(); + continue; + } + + flightStreamQueue.enqueue(currentFlightStream); + } + + currentFlightStream = getNextFlightStream(false); + + if (currentFlightStream != null) { + executeForCurrentFlightStream(); + continue; + } + + if (statement != null && statement.isCloseOnCompletion()) { + statement.close(); + } + + return false; + } + } + + @Override + protected void cancel() { + super.cancel(); + final FlightStream currentFlightStream = this.currentFlightStream; + if (currentFlightStream != null) { + currentFlightStream.cancel("Cancel", null); + } + + if (flightStreamQueue != null) { + try { + flightStreamQueue.close(); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + } + + @Override + public synchronized void close() { + try { + if (flightStreamQueue != null) { + // flightStreamQueue should close currentFlightStream internally + flightStreamQueue.close(); + } else if (currentFlightStream != null) { + // close is only called for currentFlightStream if there's no queue + currentFlightStream.close(); + } + } catch (final Exception e) { + throw new RuntimeException(e); + } finally { + super.close(); + } + } + + private FlightStream getNextFlightStream(final boolean isExecution) throws SQLException { + if (isExecution) { + final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0; + return statementTimeout != 0 ? + flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next(); + } else { + return flightStreamQueue.next(); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcPooledConnection.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcPooledConnection.java new file mode 100644 index 00000000000..96a2d9dda1d --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcPooledConnection.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
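
Because next() above hands off between FlightStreams internally, a plain JDBC consumer never observes stream boundaries. A minimal sketch, assuming a reachable Flight SQL endpoint and an existing table:

    import java.sql.Connection;
    import java.sql.DriverManager;
    import java.sql.ResultSet;
    import java.sql.Statement;

    try (Connection connection = DriverManager.getConnection(
             "jdbc:arrow-flight-sql://localhost:32010/?useEncryption=false");
         Statement statement = connection.createStatement();
         ResultSet resultSet = statement.executeQuery("SELECT * FROM some_table")) {
      while (resultSet.next()) { // advances across FlightStreams transparently
        System.out.println(resultSet.getString(1));
      }
    }
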
+ */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; + +import javax.sql.ConnectionEvent; +import javax.sql.ConnectionEventListener; +import javax.sql.PooledConnection; +import javax.sql.StatementEventListener; + +import org.apache.arrow.driver.jdbc.utils.ConnectionWrapper; + +/** + * {@link PooledConnection} implementation for Arrow Flight JDBC Driver. + */ +public class ArrowFlightJdbcPooledConnection implements PooledConnection { + + private final ArrowFlightConnection connection; + private final Set eventListeners; + private final Set statementEventListeners; + + private final class ConnectionHandle extends ConnectionWrapper { + private boolean closed = false; + + public ConnectionHandle() { + super(connection); + } + + @Override + public void close() throws SQLException { + if (!closed) { + closed = true; + onConnectionClosed(); + } + } + + @Override + public boolean isClosed() throws SQLException { + return this.closed || super.isClosed(); + } + } + + ArrowFlightJdbcPooledConnection(ArrowFlightConnection connection) { + this.connection = connection; + this.eventListeners = Collections.synchronizedSet(new HashSet<>()); + this.statementEventListeners = Collections.synchronizedSet(new HashSet<>()); + } + + public Properties getProperties() { + return connection.getClientInfo(); + } + + @Override + public Connection getConnection() throws SQLException { + return new ConnectionHandle(); + } + + @Override + public void close() throws SQLException { + this.connection.close(); + } + + void reset() throws SQLException { + this.connection.reset(); + } + + @Override + public void addConnectionEventListener(ConnectionEventListener listener) { + eventListeners.add(listener); + } + + @Override + public void removeConnectionEventListener(ConnectionEventListener listener) { + this.eventListeners.remove(listener); + } + + @Override + public void addStatementEventListener(StatementEventListener listener) { + statementEventListeners.add(listener); + } + + @Override + public void removeStatementEventListener(StatementEventListener listener) { + this.statementEventListeners.remove(listener); + } + + private void onConnectionClosed() { + ConnectionEvent connectionEvent = new ConnectionEvent(this); + eventListeners.forEach(listener -> listener.connectionClosed(connectionEvent)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTime.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTime.java new file mode 100644 index 00000000000..109048bc05c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTime.java @@ -0,0 +1,106 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; + +import java.sql.Time; +import java.time.LocalTime; +import java.time.temporal.ChronoField; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.TimeUnit; + +import org.apache.arrow.util.VisibleForTesting; + +import com.google.common.collect.ImmutableList; + +/** + * Wrapper class for Time objects to include the milliseconds part in ISO 8601 format in this#toString. + */ +public class ArrowFlightJdbcTime extends Time { + private static final List LEADING_ZEROES = ImmutableList.of("", "0", "00"); + + // Desired length of the millisecond portion should be 3 + private static final int DESIRED_MILLIS_LENGTH = 3; + + // Millis of the date time object. + private final int millisReprValue; + + /** + * Constructs this object based on epoch millis. + * + * @param milliseconds milliseconds representing Time. + */ + public ArrowFlightJdbcTime(final long milliseconds) { + super(milliseconds); + millisReprValue = getMillisReprValue(milliseconds); + } + + @VisibleForTesting + ArrowFlightJdbcTime(final LocalTime time) { + // Although the constructor is deprecated, this is the exact same code as Time#valueOf(LocalTime) + super(time.getHour(), time.getMinute(), time.getSecond()); + millisReprValue = time.get(ChronoField.MILLI_OF_SECOND); + } + + private int getMillisReprValue(long milliseconds) { + // Extract the millisecond part from epoch nano day + if (milliseconds >= MILLIS_PER_DAY) { + // Convert to Epoch Day + milliseconds %= MILLIS_PER_DAY; + } else if (milliseconds < 0) { + // LocalTime#ofNanoDay only accepts positive values + milliseconds -= ((milliseconds / MILLIS_PER_DAY) - 1) * MILLIS_PER_DAY; + } + return LocalTime.ofNanoOfDay(TimeUnit.MILLISECONDS.toNanos(milliseconds)) + .get(ChronoField.MILLI_OF_SECOND); + } + + @Override + public String toString() { + final StringBuilder time = new StringBuilder().append(super.toString()); + + if (millisReprValue > 0) { + final String millisString = Integer.toString(millisReprValue); + + // dot to separate the fractional seconds + time.append("."); + + final int millisLength = millisString.length(); + if (millisLength < DESIRED_MILLIS_LENGTH) { + // add necessary leading zeroes + time.append(LEADING_ZEROES.get(DESIRED_MILLIS_LENGTH - millisLength)); + } + time.append(millisString); + } + + return time.toString(); + } + + // Spotbugs requires these methods to be overridden + @Override + public boolean equals(Object obj) { + return super.equals(obj); + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), this.millisReprValue); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java new file mode 100644 index 00000000000..9e377e51dec --- /dev/null +++ 
b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcVectorSchemaRootResultSet.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.util.Objects.isNull; + +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.TimeZone; + +import org.apache.arrow.driver.jdbc.utils.ConvertUtils; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaResultSet; +import org.apache.calcite.avatica.AvaticaResultSetMetaData; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.Frame; +import org.apache.calcite.avatica.Meta.Signature; +import org.apache.calcite.avatica.QueryState; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * {@link ResultSet} implementation used to access a {@link VectorSchemaRoot}. + */ +public class ArrowFlightJdbcVectorSchemaRootResultSet extends AvaticaResultSet { + + private static final Logger LOGGER = + LoggerFactory.getLogger(ArrowFlightJdbcVectorSchemaRootResultSet.class); + VectorSchemaRoot vectorSchemaRoot; + + ArrowFlightJdbcVectorSchemaRootResultSet(final AvaticaStatement statement, final QueryState state, + final Signature signature, + final ResultSetMetaData resultSetMetaData, + final TimeZone timeZone, final Frame firstFrame) + throws SQLException { + super(statement, state, signature, resultSetMetaData, timeZone, firstFrame); + } + + /** + * Instantiate a ResultSet backed up by given VectorSchemaRoot. + * + * @param vectorSchemaRoot root from which the ResultSet will access. 
+ * @return a ResultSet which accesses the given VectorSchemaRoot + */ + public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot( + final VectorSchemaRoot vectorSchemaRoot) + throws SQLException { + // Similar to how org.apache.calcite.avatica.util.ArrayFactoryImpl does + + final TimeZone timeZone = TimeZone.getDefault(); + final QueryState state = new QueryState(); + + final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null); + + final AvaticaResultSetMetaData resultSetMetaData = + new AvaticaResultSetMetaData(null, null, signature); + final ArrowFlightJdbcVectorSchemaRootResultSet + resultSet = + new ArrowFlightJdbcVectorSchemaRootResultSet(null, state, signature, resultSetMetaData, + timeZone, null); + + resultSet.execute(vectorSchemaRoot); + return resultSet; + } + + @Override + protected AvaticaResultSet execute() throws SQLException { + throw new RuntimeException("Can only execute with execute(VectorSchemaRoot)"); + } + + void execute(final VectorSchemaRoot vectorSchemaRoot) { + final List fields = vectorSchemaRoot.getSchema().getFields(); + final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields); + signature.columns.clear(); + signature.columns.addAll(columns); + + this.vectorSchemaRoot = vectorSchemaRoot; + execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns); + } + + void execute(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) { + final List columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields()); + signature.columns.clear(); + signature.columns.addAll(columns); + + this.vectorSchemaRoot = vectorSchemaRoot; + execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns); + } + + @Override + protected void cancel() { + signature.columns.clear(); + super.cancel(); + try { + AutoCloseables.close(vectorSchemaRoot); + } catch (final Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public void close() { + final Set exceptions = new HashSet<>(); + try { + if (isClosed()) { + return; + } + } catch (final SQLException e) { + exceptions.add(e); + } + try { + AutoCloseables.close(vectorSchemaRoot); + } catch (final Exception e) { + exceptions.add(e); + } + if (!isNull(statement)) { + try { + super.close(); + } catch (final Exception e) { + exceptions.add(e); + } + } + exceptions.parallelStream().forEach(e -> LOGGER.error(e.getMessage(), e)); + exceptions.stream().findAny().ifPresent(e -> { + throw new RuntimeException(e); + }); + } + +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java new file mode 100644 index 00000000000..cc7addc3a74 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.lang.String.format; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.AvaticaConnection; +import org.apache.calcite.avatica.AvaticaParameter; +import org.apache.calcite.avatica.ColumnMetaData; +import org.apache.calcite.avatica.MetaImpl; +import org.apache.calcite.avatica.NoSuchStatementException; +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.remote.TypedValue; + +/** + * Metadata handler for Arrow Flight. + */ +public class ArrowFlightMetaImpl extends MetaImpl { + private final Map statementHandlePreparedStatementMap; + + /** + * Constructs a {@link MetaImpl} object specific for Arrow Flight. + * @param connection A {@link AvaticaConnection}. + */ + public ArrowFlightMetaImpl(final AvaticaConnection connection) { + super(connection); + this.statementHandlePreparedStatementMap = new ConcurrentHashMap<>(); + setDefaultConnectionProperties(); + } + + static Signature newSignature(final String sql) { + return new Signature( + new ArrayList(), + sql, + Collections.emptyList(), + Collections.emptyMap(), + null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor + StatementType.SELECT + ); + } + + @Override + public void closeStatement(final StatementHandle statementHandle) { + PreparedStatement preparedStatement = statementHandlePreparedStatementMap.remove(statementHandle); + // Testing if the prepared statement was created because the statement can be not created until this moment + if (preparedStatement != null) { + preparedStatement.close(); + } + } + + @Override + public void commit(final ConnectionHandle connectionHandle) { + // TODO Fill this stub. + } + + @Override + public ExecuteResult execute(final StatementHandle statementHandle, + final List typedValues, final long maxRowCount) { + // TODO Why is maxRowCount ignored? + Preconditions.checkNotNull(statementHandle.signature, "Signature not found."); + return new ExecuteResult( + Collections.singletonList(MetaResultSet.create( + statementHandle.connectionId, statementHandle.id, + true, statementHandle.signature, null))); + } + + @Override + public ExecuteResult execute(final StatementHandle statementHandle, + final List typedValues, final int maxRowsInFirstFrame) { + return execute(statementHandle, typedValues, (long) maxRowsInFirstFrame); + } + + @Override + public ExecuteBatchResult executeBatch(final StatementHandle statementHandle, + final List> parameterValuesList) + throws IllegalStateException { + throw new IllegalStateException("executeBatch not implemented."); + } + + @Override + public Frame fetch(final StatementHandle statementHandle, final long offset, + final int fetchMaxRowCount) { + /* + * ArrowFlightMetaImpl does not use frames. 
+ * Instead, we have accessors that contain a VectorSchemaRoot with + * the results. + */ + throw AvaticaConnection.HELPER.wrap( + format("%s does not use frames.", this), + AvaticaConnection.HELPER.unsupported()); + } + + @Override + public StatementHandle prepare(final ConnectionHandle connectionHandle, + final String query, final long maxRowCount) { + final StatementHandle handle = super.createStatement(connectionHandle); + handle.signature = newSignature(query); + return handle; + } + + @Override + public ExecuteResult prepareAndExecute(final StatementHandle statementHandle, + final String query, final long maxRowCount, + final PrepareCallback prepareCallback) + throws NoSuchStatementException { + return prepareAndExecute( + statementHandle, query, maxRowCount, -1 /* Not used */, prepareCallback); + } + + @Override + public ExecuteResult prepareAndExecute(final StatementHandle handle, + final String query, final long maxRowCount, + final int maxRowsInFirstFrame, + final PrepareCallback callback) + throws NoSuchStatementException { + try { + final PreparedStatement preparedStatement = + ((ArrowFlightConnection) connection).getClientHandler().prepare(query); + final StatementType statementType = preparedStatement.getType(); + statementHandlePreparedStatementMap.put(handle, preparedStatement); + final Signature signature = newSignature(query); + final long updateCount = + statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1; + synchronized (callback.getMonitor()) { + callback.clear(); + callback.assign(signature, null, updateCount); + } + callback.execute(); + final MetaResultSet metaResultSet = MetaResultSet.create(handle.connectionId, handle.id, + false, signature, null); + return new ExecuteResult(Collections.singletonList(metaResultSet)); + } catch (SQLTimeoutException e) { + // So far AvaticaStatement(executeInternal) only handles NoSuchStatement and Runtime Exceptions. + throw new RuntimeException(e); + } catch (SQLException e) { + throw new NoSuchStatementException(handle); + } + } + + @Override + public ExecuteBatchResult prepareAndExecuteBatch( + final StatementHandle statementHandle, final List queries) + throws NoSuchStatementException { + // TODO Fill this stub. + return null; + } + + @Override + public void rollback(final ConnectionHandle connectionHandle) { + // TODO Fill this stub. + } + + @Override + public boolean syncResults(final StatementHandle statementHandle, + final QueryState queryState, final long offset) + throws NoSuchStatementException { + // TODO Fill this stub. + return false; + } + + void setDefaultConnectionProperties() { + // TODO Double-check this. + connProps.setDirty(false) + .setAutoCommit(true) + .setReadOnly(true) + .setCatalog(null) + .setSchema(null) + .setTransactionIsolation(Connection.TRANSACTION_NONE); + } + + PreparedStatement getPreparedStatement(StatementHandle statementHandle) { + return statementHandlePreparedStatementMap.get(statementHandle); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java new file mode 100644 index 00000000000..80029f38f09 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ConvertUtils; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.Meta.Signature; +import org.apache.calcite.avatica.Meta.StatementHandle; + + +/** + * Arrow Flight JBCS's implementation {@link PreparedStatement}. + */ +public class ArrowFlightPreparedStatement extends AvaticaPreparedStatement + implements ArrowFlightInfoStatement { + + private final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement; + + private ArrowFlightPreparedStatement(final ArrowFlightConnection connection, + final ArrowFlightSqlClientHandler.PreparedStatement preparedStatement, + final StatementHandle handle, + final Signature signature, final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) + throws SQLException { + super(connection, handle, signature, resultSetType, resultSetConcurrency, resultSetHoldability); + this.preparedStatement = Preconditions.checkNotNull(preparedStatement); + } + + /** + * Creates a new {@link ArrowFlightPreparedStatement} from the provided information. + * + * @param connection the {@link Connection} to use. + * @param statementHandle the {@link StatementHandle} to use. + * @param signature the {@link Signature} to use. + * @param resultSetType the ResultSet type. + * @param resultSetConcurrency the ResultSet concurrency. + * @param resultSetHoldability the ResultSet holdability. + * @return a new {@link PreparedStatement}. + * @throws SQLException on error. 
+ */ + static ArrowFlightPreparedStatement createNewPreparedStatement( + final ArrowFlightConnection connection, + final StatementHandle statementHandle, + final Signature signature, + final int resultSetType, + final int resultSetConcurrency, + final int resultSetHoldability) throws SQLException { + + final ArrowFlightSqlClientHandler.PreparedStatement prepare = connection.getClientHandler().prepare(signature.sql); + final Schema resultSetSchema = prepare.getDataSetSchema(); + + signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + + return new ArrowFlightPreparedStatement( + connection, prepare, statementHandle, + signature, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return (ArrowFlightConnection) super.getConnection(); + } + + @Override + public synchronized void close() throws SQLException { + this.preparedStatement.close(); + super.close(); + } + + @Override + public FlightInfo executeFlightInfoQuery() throws SQLException { + return preparedStatement.executeQuery(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java new file mode 100644 index 00000000000..5bc7c2ab9b4 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightStatement.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement; +import org.apache.arrow.driver.jdbc.utils.ConvertUtils; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaStatement; +import org.apache.calcite.avatica.Meta; +import org.apache.calcite.avatica.Meta.StatementHandle; + +/** + * A SQL statement for querying data from an Arrow Flight server. 
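
A hedged usage sketch for the prepared statement above, reusing the connection from the earlier sketch: the server-side prepared statement is created eagerly in createNewPreparedStatement, so closing the JDBC statement releases it. The table name is an assumption and parameter binding is not exercised here:

    import java.sql.PreparedStatement;
    import java.sql.ResultSet;

    try (PreparedStatement statement =
             connection.prepareStatement("SELECT * FROM some_table");
         ResultSet resultSet = statement.executeQuery()) {
      while (resultSet.next()) {
        // columns are served by the Arrow-backed accessors
      }
    }
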
+ */ +public class ArrowFlightStatement extends AvaticaStatement implements ArrowFlightInfoStatement { + + ArrowFlightStatement(final ArrowFlightConnection connection, + final StatementHandle handle, final int resultSetType, + final int resultSetConcurrency, final int resultSetHoldability) { + super(connection, handle, resultSetType, resultSetConcurrency, resultSetHoldability); + } + + @Override + public ArrowFlightConnection getConnection() throws SQLException { + return (ArrowFlightConnection) super.getConnection(); + } + + @Override + public FlightInfo executeFlightInfoQuery() throws SQLException { + final PreparedStatement preparedStatement = getConnection().getMeta().getPreparedStatement(handle); + final Meta.Signature signature = getSignature(); + if (signature == null) { + return null; + } + + final Schema resultSetSchema = preparedStatement.getDataSetSchema(); + signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + setSignature(signature); + + return preparedStatement.executeQuery(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessor.java new file mode 100644 index 00000000000..3821ee1dc87 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessor.java @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor; + +import static org.apache.calcite.avatica.util.Cursor.Accessor; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.SQLException; +import java.sql.SQLXML; +import java.sql.Struct; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.Map; +import java.util.function.IntSupplier; + +/** + * Base Jdbc Accessor. 
+ */ +public abstract class ArrowFlightJdbcAccessor implements Accessor { + private final IntSupplier currentRowSupplier; + + // All the derived accessor classes should alter this as they encounter null Values + protected boolean wasNull; + protected ArrowFlightJdbcAccessorFactory.WasNullConsumer wasNullConsumer; + + protected ArrowFlightJdbcAccessor(final IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer wasNullConsumer) { + this.currentRowSupplier = currentRowSupplier; + this.wasNullConsumer = wasNullConsumer; + } + + protected int getCurrentRow() { + return currentRowSupplier.getAsInt(); + } + + // It needs to be public so this method can be accessed when creating the complex types. + public abstract Class getObjectClass(); + + @Override + public boolean wasNull() { + return wasNull; + } + + @Override + public String getString() throws SQLException { + final Object object = getObject(); + if (object == null) { + return null; + } + + return object.toString(); + } + + @Override + public boolean getBoolean() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public byte getByte() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public short getShort() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public int getInt() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public long getLong() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public float getFloat() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public double getDouble() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public BigDecimal getBigDecimal(final int i) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public byte[] getBytes() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public InputStream getAsciiStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public InputStream getUnicodeStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public InputStream getBinaryStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Object getObject() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Reader getCharacterStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Object getObject(final Map> map) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Ref getRef() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Blob getBlob() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Clob getClob() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Array getArray() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Struct getStruct() throws SQLException { + throw 
getOperationNotSupported(this.getClass()); + } + + @Override + public Date getDate(final Calendar calendar) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Time getTime(final Calendar calendar) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Timestamp getTimestamp(final Calendar calendar) throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public URL getURL() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public NClob getNClob() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public SQLXML getSQLXML() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public String getNString() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public Reader getNCharacterStream() throws SQLException { + throw getOperationNotSupported(this.getClass()); + } + + @Override + public T getObject(final Class type) throws SQLException { + final Object value; + if (type == Byte.class) { + value = getByte(); + } else if (type == Short.class) { + value = getShort(); + } else if (type == Integer.class) { + value = getInt(); + } else if (type == Long.class) { + value = getLong(); + } else if (type == Float.class) { + value = getFloat(); + } else if (type == Double.class) { + value = getDouble(); + } else if (type == Boolean.class) { + value = getBoolean(); + } else if (type == BigDecimal.class) { + value = getBigDecimal(); + } else if (type == String.class) { + value = getString(); + } else if (type == byte[].class) { + value = getBytes(); + } else { + value = getObject(); + } + return !type.isPrimitive() && wasNull ? null : type.cast(value); + } + + private static SQLException getOperationNotSupported(final Class type) { + return new SQLException(String.format("Operation not supported for type: %s.", type.getName())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java new file mode 100644 index 00000000000..813b40a8070 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactory.java @@ -0,0 +1,214 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
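
To illustrate the contract the base accessor defines (a current-row supplier plus wasNull bookkeeping), here is a minimal hypothetical subclass. It is not part of the patch; real accessors also notify the WasNullConsumer, whose method name is not shown in this excerpt, so that call is omitted:

    import java.sql.SQLException;
    import java.util.function.IntSupplier;

    import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
    import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
    import org.apache.arrow.vector.IntVector;

    class ExampleIntAccessor extends ArrowFlightJdbcAccessor {
      private final IntVector vector;

      ExampleIntAccessor(IntVector vector, IntSupplier currentRowSupplier,
                         ArrowFlightJdbcAccessorFactory.WasNullConsumer wasNullConsumer) {
        super(currentRowSupplier, wasNullConsumer);
        this.vector = vector;
      }

      @Override
      public Class<?> getObjectClass() {
        return Integer.class;
      }

      @Override
      public int getInt() throws SQLException {
        final int row = getCurrentRow();
        this.wasNull = vector.isNull(row); // derived classes maintain wasNull
        return this.wasNull ? 0 : vector.get(row);
      }

      @Override
      public Object getObject() throws SQLException {
        final int value = getInt();
        return this.wasNull ? null : value;
      }
    }
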
+ */ + +package org.apache.arrow.driver.jdbc.accessor; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcBinaryVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDurationVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcIntervalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcDenseUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcFixedSizeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcLargeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcMapVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcStructVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBaseIntVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBitVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcDecimalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat4VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat8VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.DenseUnionVector; +import 
org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Factory to instantiate the accessors. + */ +public class ArrowFlightJdbcAccessorFactory { + + /** + * Create an accessor according to its type. + * + * @param vector an instance of an arrow vector. + * @param getCurrentRow a supplier to check which row is being accessed. + * @return an instance of one of the accessors. + */ + public static ArrowFlightJdbcAccessor createAccessor(ValueVector vector, + IntSupplier getCurrentRow, + WasNullConsumer setCursorWasNull) { + if (vector instanceof UInt1Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt1Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof UInt2Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt2Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof UInt4Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof UInt8Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt8Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TinyIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof SmallIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof IntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((IntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof BigIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((BigIntVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof Float4Vector) { + return new ArrowFlightJdbcFloat4VectorAccessor((Float4Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof Float8Vector) { + return new ArrowFlightJdbcFloat8VectorAccessor((Float8Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof BitVector) { + return new ArrowFlightJdbcBitVectorAccessor((BitVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DecimalVector) { + return new ArrowFlightJdbcDecimalVectorAccessor((DecimalVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof Decimal256Vector) { + return new ArrowFlightJdbcDecimalVectorAccessor((Decimal256Vector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof VarBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor((VarBinaryVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof LargeVarBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor((LargeVarBinaryVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof FixedSizeBinaryVector) { + return new ArrowFlightJdbcBinaryVectorAccessor((FixedSizeBinaryVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeStampVector) { + return new ArrowFlightJdbcTimeStampVectorAccessor((TimeStampVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeNanoVector) { + return new 
ArrowFlightJdbcTimeVectorAccessor((TimeNanoVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeMicroVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMicroVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeMilliVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMilliVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof TimeSecVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeSecVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DateDayVector) { + return new ArrowFlightJdbcDateVectorAccessor(((DateDayVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DateMilliVector) { + return new ArrowFlightJdbcDateVectorAccessor(((DateMilliVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof VarCharVector) { + return new ArrowFlightJdbcVarCharVectorAccessor((VarCharVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof LargeVarCharVector) { + return new ArrowFlightJdbcVarCharVectorAccessor((LargeVarCharVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DurationVector) { + return new ArrowFlightJdbcDurationVectorAccessor((DurationVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof IntervalDayVector) { + return new ArrowFlightJdbcIntervalVectorAccessor(((IntervalDayVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof IntervalYearVector) { + return new ArrowFlightJdbcIntervalVectorAccessor(((IntervalYearVector) vector), getCurrentRow, + setCursorWasNull); + } else if (vector instanceof StructVector) { + return new ArrowFlightJdbcStructVectorAccessor((StructVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof MapVector) { + return new ArrowFlightJdbcMapVectorAccessor((MapVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof ListVector) { + return new ArrowFlightJdbcListVectorAccessor((ListVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof LargeListVector) { + return new ArrowFlightJdbcLargeListVectorAccessor((LargeListVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof FixedSizeListVector) { + return new ArrowFlightJdbcFixedSizeListVectorAccessor((FixedSizeListVector) vector, + getCurrentRow, setCursorWasNull); + } else if (vector instanceof UnionVector) { + return new ArrowFlightJdbcUnionVectorAccessor((UnionVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof DenseUnionVector) { + return new ArrowFlightJdbcDenseUnionVectorAccessor((DenseUnionVector) vector, getCurrentRow, + setCursorWasNull); + } else if (vector instanceof NullVector || vector == null) { + return new ArrowFlightJdbcNullVectorAccessor(setCursorWasNull); + } + + throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass().getName()); + } + + /** + * Functional interface used to propagate that the value accessed was null or not. 
+ */ + @FunctionalInterface + public interface WasNullConsumer { + void setWasNull(boolean wasNull); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessor.java new file mode 100644 index 00000000000..f40a5797293 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessor.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.NullVector; + +/** + * Accessor for the Arrow type {@link NullVector}. + */ +public class ArrowFlightJdbcNullVectorAccessor extends ArrowFlightJdbcAccessor { + public ArrowFlightJdbcNullVectorAccessor( + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(null, setCursorWasNull); + } + + @Override + public Class getObjectClass() { + return Object.class; + } + + @Override + public boolean wasNull() { + return true; + } + + @Override + public Object getObject() { + this.wasNullConsumer.setWasNull(true); + return null; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessor.java new file mode 100644 index 00000000000..c50d7349721 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessor.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
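[A minimal wiring sketch for the factory above; `intVector` and `currentRow` are illustrative names, not identifiers from this patch. The IntSupplier decouples the accessor from whatever cursor advances the row, and the WasNullConsumer pushes nullability back to that cursor:

    AtomicInteger currentRow = new AtomicInteger();
    ArrowFlightJdbcAccessor accessor = ArrowFlightJdbcAccessorFactory.createAccessor(
        intVector,            // any supported ValueVector, here assumed to be an IntVector
        currentRow::get,      // IntSupplier consulted on every getter call
        wasNull -> { });      // WasNullConsumer propagating nullability to the enclosing cursor
    currentRow.set(3);        // subsequent getter calls now read row 3 of the vector
]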
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.binary; + +import java.io.ByteArrayInputStream; +import java.io.CharArrayReader; +import java.io.InputStream; +import java.io.Reader; +import java.nio.charset.StandardCharsets; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.VarBinaryVector; + +/** + * Accessor for the Arrow types: {@link FixedSizeBinaryVector}, {@link VarBinaryVector} + * and {@link LargeVarBinaryVector}. + */ +public class ArrowFlightJdbcBinaryVectorAccessor extends ArrowFlightJdbcAccessor { + + private interface ByteArrayGetter { + byte[] get(int index); + } + + private final ByteArrayGetter getter; + + public ArrowFlightJdbcBinaryVectorAccessor(FixedSizeBinaryVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + public ArrowFlightJdbcBinaryVectorAccessor(VarBinaryVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + public ArrowFlightJdbcBinaryVectorAccessor(LargeVarBinaryVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector::get, currentRowSupplier, setCursorWasNull); + } + + private ArrowFlightJdbcBinaryVectorAccessor(ByteArrayGetter getter, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.getter = getter; + } + + @Override + public byte[] getBytes() { + byte[] bytes = getter.get(getCurrentRow()); + this.wasNull = bytes == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return bytes; + } + + @Override + public Object getObject() { + return this.getBytes(); + } + + @Override + public Class getObjectClass() { + return byte[].class; + } + + @Override + public String getString() { + byte[] bytes = this.getBytes(); + if (bytes == null) { + return null; + } + + return new String(bytes, StandardCharsets.UTF_8); + } + + @Override + public InputStream getAsciiStream() { + byte[] bytes = getBytes(); + if (bytes == null) { + return null; + } + + return new ByteArrayInputStream(bytes); + } + + @Override + public InputStream getUnicodeStream() { + byte[] bytes = getBytes(); + if (bytes == null) { + return null; + } + + return new ByteArrayInputStream(bytes); + } + + @Override + public InputStream getBinaryStream() { + byte[] bytes = getBytes(); + if (bytes == null) { + return null; + } + + return new ByteArrayInputStream(bytes); + } + + @Override + public Reader getCharacterStream() { + String string = getString(); + if (string == null) { + return null; + } + + return new CharArrayReader(string.toCharArray()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessor.java new file mode 100644 index 00000000000..f6c14a47f52 --- /dev/null +++ 
b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessor.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorGetter.Holder; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorGetter.createGetter; +import static org.apache.arrow.driver.jdbc.utils.DateTimeUtils.getTimestampValue; +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; +import static org.apache.calcite.avatica.util.DateTimeUtils.unixDateToString; + +import java.sql.Date; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.DateTimeUtils; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Accessor for the Arrow types: {@link DateDayVector} and {@link DateMilliVector}. + */ +public class ArrowFlightJdbcDateVectorAccessor extends ArrowFlightJdbcAccessor { + + private final Getter getter; + private final TimeUnit timeUnit; + private final Holder holder; + + /** + * Instantiate an accessor for a {@link DateDayVector}. + * + * @param vector an instance of a DateDayVector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcDateVectorAccessor(DateDayVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link DateMilliVector}. + * + * @param vector an instance of a DateMilliVector. + * @param currentRowSupplier the supplier to track the lines. 
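[Stepping back to the binary accessor above (illustrative sketch; `binaryAccessor` is a hypothetical variable): all of its stream and reader getters are cheap views over the same getBytes() result, so null and wasNull propagate uniformly:

    byte[] raw = binaryAccessor.getBytes();             // single source of truth per row
    InputStream in = binaryAccessor.getBinaryStream();  // ByteArrayInputStream over the same bytes
    Reader chars = binaryAccessor.getCharacterStream(); // reader over the UTF-8 decoded string
]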
+ */ + public ArrowFlightJdbcDateVectorAccessor(DateMilliVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + @Override + public Class getObjectClass() { + return Date.class; + } + + @Override + public Object getObject() { + return this.getDate(null); + } + + @Override + public Date getDate(Calendar calendar) { + fillHolder(); + if (this.wasNull) { + return null; + } + + long value = holder.value; + long milliseconds = this.timeUnit.toMillis(value); + + long millisWithCalendar = DateTimeUtils.applyCalendarOffset(milliseconds, calendar); + + return new Date(getTimestampValue(millisWithCalendar).getTime()); + } + + private void fillHolder() { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + Date date = getDate(calendar); + if (date == null) { + return null; + } + return new Timestamp(date.getTime()); + } + + @Override + public String getString() { + fillHolder(); + if (wasNull) { + return null; + } + long milliseconds = timeUnit.toMillis(holder.value); + return unixDateToString((int) (milliseconds / MILLIS_PER_DAY)); + } + + protected static TimeUnit getTimeUnitForVector(ValueVector vector) { + if (vector instanceof DateDayVector) { + return TimeUnit.DAYS; + } else if (vector instanceof DateMilliVector) { + return TimeUnit.MILLISECONDS; + } + + throw new IllegalArgumentException("Invalid Arrow vector"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorGetter.java new file mode 100644 index 00000000000..ea545851a3a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorGetter.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.holders.NullableDateDayHolder; +import org.apache.arrow.vector.holders.NullableDateMilliHolder; + +/** + * Auxiliary class used to unify data access on TimeStampVectors. + */ +final class ArrowFlightJdbcDateVectorGetter { + + private ArrowFlightJdbcDateVectorGetter() { + // Prevent instantiation. 
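[Aside on the Holder/Getter pattern this getter class introduces (and that the Time and TimeStamp variants below repeat), shown as a sketch of its assumed intent: one mutable Holder per accessor, refilled in place on every row, so reads avoid allocating a value object per call:

    Holder holder = new Holder();
    Getter getter = createGetter(dateDayVector); // lambda copying isSet/value out of the vector's holder
    getter.get(rowIndex, holder);
    boolean isNull = holder.isSet == 0;          // 0 = not set, 1 = set
]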
+ } + + /** + * Auxiliary class meant to unify Date*Vector#get implementations with different classes of ValueHolders. + */ + static class Holder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value in its respective timeunit + } + + /** + * Functional interface used to unify Date*Vector#get implementations. + */ + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(DateDayVector vector) { + NullableDateDayHolder auxHolder = new NullableDateDayHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(DateMilliVector vector) { + NullableDateMilliHolder auxHolder = new NullableDateMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessor.java new file mode 100644 index 00000000000..22a0e6f8923 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessor.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import java.time.Duration; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.DurationVector; + +/** + * Accessor for the Arrow type {@link DurationVector}. 
+ */ +public class ArrowFlightJdbcDurationVectorAccessor extends ArrowFlightJdbcAccessor { + + private final DurationVector vector; + + public ArrowFlightJdbcDurationVectorAccessor(DurationVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Duration.class; + } + + @Override + public Object getObject() { + Duration duration = vector.getObject(getCurrentRow()); + this.wasNull = duration == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return duration; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessor.java new file mode 100644 index 00000000000..283dc9160a9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessor.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalDay; +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalYear; +import static org.apache.arrow.vector.util.DateUtility.yearsToMonths; + +import java.sql.SQLException; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.holders.NullableIntervalDayHolder; +import org.apache.arrow.vector.holders.NullableIntervalYearHolder; +import org.joda.time.Period; + +/** + * Accessor for the Arrow type {@link IntervalDayVector}. + */ +public class ArrowFlightJdbcIntervalVectorAccessor extends ArrowFlightJdbcAccessor { + + private final BaseFixedWidthVector vector; + private final StringGetter stringGetter; + private final Class objectClass; + + /** + * Instantiate an accessor for a {@link IntervalDayVector}. + * + * @param vector an instance of a IntervalDayVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. 
+ */ + public ArrowFlightJdbcIntervalVectorAccessor(IntervalDayVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + stringGetter = (index) -> { + final NullableIntervalDayHolder holder = new NullableIntervalDayHolder(); + vector.get(index, holder); + if (holder.isSet == 0) { + return null; + } else { + final int days = holder.days; + final int millis = holder.milliseconds; + return formatIntervalDay(new Period().plusDays(days).plusMillis(millis)); + } + }; + objectClass = java.time.Duration.class; + } + + /** + * Instantiate an accessor for a {@link IntervalYearVector}. + * + * @param vector an instance of a IntervalYearVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcIntervalVectorAccessor(IntervalYearVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + stringGetter = (index) -> { + final NullableIntervalYearHolder holder = new NullableIntervalYearHolder(); + vector.get(index, holder); + if (holder.isSet == 0) { + return null; + } else { + final int interval = holder.value; + final int years = (interval / yearsToMonths); + final int months = (interval % yearsToMonths); + return formatIntervalYear(new Period().plusYears(years).plusMonths(months)); + } + }; + objectClass = java.time.Period.class; + } + + @Override + public Class getObjectClass() { + return objectClass; + } + + @Override + public String getString() throws SQLException { + String result = stringGetter.get(getCurrentRow()); + wasNull = result == null; + wasNullConsumer.setWasNull(wasNull); + return result; + } + + @Override + public Object getObject() { + Object object = vector.getObject(getCurrentRow()); + wasNull = object == null; + wasNullConsumer.setWasNull(wasNull); + return object; + } + + /** + * Functional interface used to unify Interval*Vector#getAsStringBuilder implementations. + */ + @FunctionalInterface + interface StringGetter { + String get(int index); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessor.java new file mode 100644 index 00000000000..a23883baf1e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessor.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
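[A worked example of the year-interval split in the IntervalYearVector constructor above (DateUtility.yearsToMonths is 12):

    int totalMonths = 27;          // raw month count stored by the vector
    int years = totalMonths / 12;  // 2
    int months = totalMonths % 12; // 3, then rendered via formatIntervalYear(new Period()...)
]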
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorGetter.Holder; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorGetter.createGetter; + +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.time.temporal.ChronoUnit; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.util.DateUtility; + +/** + * Accessor for the Arrow types extending from {@link TimeStampVector}. + */ +public class ArrowFlightJdbcTimeStampVectorAccessor extends ArrowFlightJdbcAccessor { + + private final TimeZone timeZone; + private final Getter getter; + private final TimeUnit timeUnit; + private final LongToLocalDateTime longToLocalDateTime; + private final Holder holder; + + /** + * Functional interface used to convert a number (in any time resolution) to LocalDateTime. + */ + interface LongToLocalDateTime { + LocalDateTime fromLong(long value); + } + + /** + * Instantiate a ArrowFlightJdbcTimeStampVectorAccessor for given vector. + */ + public ArrowFlightJdbcTimeStampVectorAccessor(TimeStampVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + + this.timeZone = getTimeZoneForVector(vector); + this.timeUnit = getTimeUnitForVector(vector); + this.longToLocalDateTime = getLongToLocalDateTimeForVector(vector, this.timeZone); + } + + @Override + public Class getObjectClass() { + return Timestamp.class; + } + + @Override + public Object getObject() { + return this.getTimestamp(null); + } + + private LocalDateTime getLocalDateTime(Calendar calendar) { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + long value = holder.value; + + LocalDateTime localDateTime = this.longToLocalDateTime.fromLong(value); + + if (calendar != null) { + TimeZone timeZone = calendar.getTimeZone(); + long millis = this.timeUnit.toMillis(value); + localDateTime = localDateTime + .minus(timeZone.getOffset(millis) - this.timeZone.getOffset(millis), ChronoUnit.MILLIS); + } + return localDateTime; + } + + @Override + public Date getDate(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return new Date(Timestamp.valueOf(localDateTime).getTime()); + } + + @Override + public Time getTime(Calendar calendar) { + LocalDateTime localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return new Time(Timestamp.valueOf(localDateTime).getTime()); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + LocalDateTime 
localDateTime = getLocalDateTime(calendar); + if (localDateTime == null) { + return null; + } + + return Timestamp.valueOf(localDateTime); + } + + protected static TimeUnit getTimeUnitForVector(TimeStampVector vector) { + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + switch (arrowType.getUnit()) { + case NANOSECOND: + return TimeUnit.NANOSECONDS; + case MICROSECOND: + return TimeUnit.MICROSECONDS; + case MILLISECOND: + return TimeUnit.MILLISECONDS; + case SECOND: + return TimeUnit.SECONDS; + default: + throw new UnsupportedOperationException("Invalid Arrow time unit"); + } + } + + protected static LongToLocalDateTime getLongToLocalDateTimeForVector(TimeStampVector vector, + TimeZone timeZone) { + String timeZoneID = timeZone.getID(); + + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + switch (arrowType.getUnit()) { + case NANOSECOND: + return nanoseconds -> DateUtility.getLocalDateTimeFromEpochNano(nanoseconds, timeZoneID); + case MICROSECOND: + return microseconds -> DateUtility.getLocalDateTimeFromEpochMicro(microseconds, timeZoneID); + case MILLISECOND: + return milliseconds -> DateUtility.getLocalDateTimeFromEpochMilli(milliseconds, timeZoneID); + case SECOND: + return seconds -> DateUtility.getLocalDateTimeFromEpochMilli( + TimeUnit.SECONDS.toMillis(seconds), timeZoneID); + default: + throw new UnsupportedOperationException("Invalid Arrow time unit"); + } + } + + protected static TimeZone getTimeZoneForVector(TimeStampVector vector) { + ArrowType.Timestamp arrowType = + (ArrowType.Timestamp) vector.getField().getFieldType().getType(); + + String timezoneName = arrowType.getTimezone(); + if (timezoneName == null) { + return TimeZone.getTimeZone("UTC"); + } + + return TimeZone.getTimeZone(timezoneName); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorGetter.java new file mode 100644 index 00000000000..03fb35face7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorGetter.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
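[Restating the calendar adjustment in getLocalDateTime above as a standalone sketch (`calendarZone` and `vectorZone` are illustrative java.util.TimeZone variables): the raw value is first rendered in the vector's zone, then shifted by the offset difference so its fields read as wall-clock time in the caller's Calendar zone:

    long millis = timeUnit.toMillis(value);
    long delta = calendarZone.getOffset(millis) - vectorZone.getOffset(millis);
    LocalDateTime adjusted = localDateTime.minus(delta, ChronoUnit.MILLIS);
]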
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.holders.NullableTimeStampMicroHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMicroTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampNanoHolder; +import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; +import org.apache.arrow.vector.holders.NullableTimeStampSecHolder; +import org.apache.arrow.vector.holders.NullableTimeStampSecTZHolder; + +/** + * Auxiliary class used to unify data access on TimeStampVectors. + */ +final class ArrowFlightJdbcTimeStampVectorGetter { + + private ArrowFlightJdbcTimeStampVectorGetter() { + // Prevent instantiation. + } + + /** + * Auxiliary class meant to unify TimeStamp*Vector#get implementations with different classes of ValueHolders. + */ + static class Holder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value in its respective timeunit + } + + /** + * Functional interface used to unify TimeStamp*Vector#get implementations. + */ + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(TimeStampVector vector) { + if (vector instanceof TimeStampNanoVector) { + return createGetter((TimeStampNanoVector) vector); + } else if (vector instanceof TimeStampNanoTZVector) { + return createGetter((TimeStampNanoTZVector) vector); + } else if (vector instanceof TimeStampMicroVector) { + return createGetter((TimeStampMicroVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + return createGetter((TimeStampMicroTZVector) vector); + } else if (vector instanceof TimeStampMilliVector) { + return createGetter((TimeStampMilliVector) vector); + } else if (vector instanceof TimeStampMilliTZVector) { + return createGetter((TimeStampMilliTZVector) vector); + } else if (vector instanceof TimeStampSecVector) { + return createGetter((TimeStampSecVector) vector); + } else if (vector instanceof TimeStampSecTZVector) { + return createGetter((TimeStampSecTZVector) vector); + } + + throw new UnsupportedOperationException("Unsupported Timestamp vector type"); + } + + private static Getter createGetter(TimeStampNanoVector vector) { + NullableTimeStampNanoHolder auxHolder = new NullableTimeStampNanoHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampNanoTZVector vector) { + NullableTimeStampNanoTZHolder auxHolder = new NullableTimeStampNanoTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMicroVector vector) { + NullableTimeStampMicroHolder auxHolder = new NullableTimeStampMicroHolder(); + return (index, holder) -> { 
+ vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMicroTZVector vector) { + NullableTimeStampMicroTZHolder auxHolder = new NullableTimeStampMicroTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMilliVector vector) { + NullableTimeStampMilliHolder auxHolder = new NullableTimeStampMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampMilliTZVector vector) { + NullableTimeStampMilliTZHolder auxHolder = new NullableTimeStampMilliTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampSecVector vector) { + NullableTimeStampSecHolder auxHolder = new NullableTimeStampSecHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + private static Getter createGetter(TimeStampSecTZVector vector) { + NullableTimeStampSecTZHolder auxHolder = new NullableTimeStampSecTZHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessor.java new file mode 100644 index 00000000000..6c2173d5e56 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessor.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
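[Note on cost: the instanceof chain in createGetter above runs once per column, when the accessor is constructed; the per-row hot path is only the returned lambda. As a sketch:

    Getter getter = createGetter(timeStampVector); // concrete overload resolved once
    getter.get(row, holder);                       // no further type checks per row
]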
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorGetter.Holder; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorGetter.createGetter; + +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.concurrent.TimeUnit; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.ArrowFlightJdbcTime; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.DateTimeUtils; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.ValueVector; + +/** + * Accessor for the Arrow types: {@link TimeNanoVector}, {@link TimeMicroVector}, {@link TimeMilliVector} + * and {@link TimeSecVector}. + */ +public class ArrowFlightJdbcTimeVectorAccessor extends ArrowFlightJdbcAccessor { + + private final Getter getter; + private final TimeUnit timeUnit; + private final Holder holder; + + /** + * Instantiate an accessor for a {@link TimeNanoVector}. + * + * @param vector an instance of a TimeNanoVector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeNanoVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link TimeMicroVector}. + * + * @param vector an instance of a TimeMicroVector. + * @param currentRowSupplier the supplier to track the lines. + * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeMicroVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link TimeMilliVector}. + * + * @param vector an instance of a TimeMilliVector. + * @param currentRowSupplier the supplier to track the lines. + */ + public ArrowFlightJdbcTimeVectorAccessor(TimeMilliVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + /** + * Instantiate an accessor for a {@link TimeSecVector}. + * + * @param vector an instance of a TimeSecVector. + * @param currentRowSupplier the supplier to track the lines. 
+ */ + public ArrowFlightJdbcTimeVectorAccessor(TimeSecVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.holder = new Holder(); + this.getter = createGetter(vector); + this.timeUnit = getTimeUnitForVector(vector); + } + + @Override + public Class getObjectClass() { + return Time.class; + } + + @Override + public Object getObject() { + return this.getTime(null); + } + + @Override + public Time getTime(Calendar calendar) { + fillHolder(); + if (this.wasNull) { + return null; + } + + long value = holder.value; + long milliseconds = this.timeUnit.toMillis(value); + + return new ArrowFlightJdbcTime(DateTimeUtils.applyCalendarOffset(milliseconds, calendar)); + } + + private void fillHolder() { + getter.get(getCurrentRow(), holder); + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) { + Time time = getTime(calendar); + if (time == null) { + return null; + } + return new Timestamp(time.getTime()); + } + + protected static TimeUnit getTimeUnitForVector(ValueVector vector) { + if (vector instanceof TimeNanoVector) { + return TimeUnit.NANOSECONDS; + } else if (vector instanceof TimeMicroVector) { + return TimeUnit.MICROSECONDS; + } else if (vector instanceof TimeMilliVector) { + return TimeUnit.MILLISECONDS; + } else if (vector instanceof TimeSecVector) { + return TimeUnit.SECONDS; + } + + throw new IllegalArgumentException("Invalid Arrow vector"); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorGetter.java new file mode 100644 index 00000000000..fb254c69401 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorGetter.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.holders.NullableTimeMicroHolder; +import org.apache.arrow.vector.holders.NullableTimeMilliHolder; +import org.apache.arrow.vector.holders.NullableTimeNanoHolder; +import org.apache.arrow.vector.holders.NullableTimeSecHolder; + +/** + * Auxiliary class used to unify data access on Time*Vectors. 
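[Unit normalization behind getTime above, with illustrative values:

    long nanos = 34_200_000_000_000L;                   // raw TimeNanoVector value, i.e. 09:30:00
    long millis = TimeUnit.NANOSECONDS.toMillis(nanos); // 34_200_000 ms since midnight
    // then wrapped as new ArrowFlightJdbcTime(DateTimeUtils.applyCalendarOffset(millis, calendar))
]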
+ */ +final class ArrowFlightJdbcTimeVectorGetter { + + private ArrowFlightJdbcTimeVectorGetter() { + // Prevent instantiation. + } + + /** + * Auxiliary class meant to unify TimeStamp*Vector#get implementations with different classes of ValueHolders. + */ + static class Holder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value in its respective timeunit + } + + /** + * Functional interface used to unify TimeStamp*Vector#get implementations. + */ + @FunctionalInterface + interface Getter { + void get(int index, Holder holder); + } + + static Getter createGetter(TimeNanoVector vector) { + NullableTimeNanoHolder auxHolder = new NullableTimeNanoHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeMicroVector vector) { + NullableTimeMicroHolder auxHolder = new NullableTimeMicroHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeMilliVector vector) { + NullableTimeMilliHolder auxHolder = new NullableTimeMilliHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } + + static Getter createGetter(TimeSecVector vector) { + NullableTimeSecHolder auxHolder = new NullableTimeSecHolder(); + return (index, holder) -> { + vector.get(index, auxHolder); + holder.isSet = auxHolder.isSet; + holder.value = auxHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListVectorAccessor.java new file mode 100644 index 00000000000..d3338608f83 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListVectorAccessor.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.sql.Array; +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.ArrowFlightJdbcArray; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; + +/** + * Base Accessor for the Arrow types {@link ListVector}, {@link LargeListVector} and {@link FixedSizeListVector}. + */ +public abstract class AbstractArrowFlightJdbcListVectorAccessor extends ArrowFlightJdbcAccessor { + + protected AbstractArrowFlightJdbcListVectorAccessor(IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + } + + @Override + public Class getObjectClass() { + return List.class; + } + + protected abstract long getStartOffset(int index); + + protected abstract long getEndOffset(int index); + + protected abstract FieldVector getDataVector(); + + protected abstract boolean isNull(int index); + + @Override + public final Array getArray() { + int index = getCurrentRow(); + FieldVector dataVector = getDataVector(); + + this.wasNull = isNull(index); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + long startOffset = getStartOffset(index); + long endOffset = getEndOffset(index); + + long valuesCount = endOffset - startOffset; + return new ArrowFlightJdbcArray(dataVector, startOffset, valuesCount); + } +} + diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessor.java new file mode 100644 index 00000000000..0465765f183 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessor.java @@ -0,0 +1,259 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
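[The offset arithmetic behind getArray() in the list accessor above, as a sketch: a list row is the half-open slice [startOffset, endOffset) of the shared child data vector, so no data is copied when the Array is materialized:

    long start = getStartOffset(row);
    long end = getEndOffset(row);
    Array slice = new ArrowFlightJdbcArray(dataVector, start, end - start);
]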
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.net.URL; +import java.sql.Array; +import java.sql.Blob; +import java.sql.Clob; +import java.sql.Date; +import java.sql.NClob; +import java.sql.Ref; +import java.sql.SQLException; +import java.sql.SQLXML; +import java.sql.Struct; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.Calendar; +import java.util.Map; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Base accessor for {@link UnionVector} and {@link DenseUnionVector}. + */ +public abstract class AbstractArrowFlightJdbcUnionVectorAccessor extends ArrowFlightJdbcAccessor { + + /** + * Array of accessors for each type contained in UnionVector. + * Index corresponds to UnionVector and DenseUnionVector typeIds which are both limited to 128. + */ + private final ArrowFlightJdbcAccessor[] accessors = new ArrowFlightJdbcAccessor[128]; + + private final ArrowFlightJdbcNullVectorAccessor nullAccessor = + new ArrowFlightJdbcNullVectorAccessor((boolean wasNull) -> { + }); + + protected AbstractArrowFlightJdbcUnionVectorAccessor(IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + } + + protected abstract ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector); + + protected abstract byte getCurrentTypeId(); + + protected abstract ValueVector getVectorByTypeId(byte typeId); + + /** + * Returns an accessor for UnionVector child vector on current row. + * + * @return ArrowFlightJdbcAccessor for child vector on current row. + */ + protected ArrowFlightJdbcAccessor getAccessor() { + // Get the typeId and child vector for the current row being accessed. + byte typeId = this.getCurrentTypeId(); + ValueVector vector = this.getVectorByTypeId(typeId); + + if (typeId < 0) { + // typeId may be negative if the current row has no type defined. 
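+      // Fall back to the no-op null accessor so the JDBC getters report SQL NULL for this row.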
+ return this.nullAccessor; + } + + // Ensure there is an accessor for given typeId + if (this.accessors[typeId] == null) { + this.accessors[typeId] = this.createAccessorForVector(vector); + } + + return this.accessors[typeId]; + } + + @Override + public Class getObjectClass() { + return getAccessor().getObjectClass(); + } + + @Override + public boolean wasNull() { + return getAccessor().wasNull(); + } + + @Override + public String getString() throws SQLException { + return getAccessor().getString(); + } + + @Override + public boolean getBoolean() throws SQLException { + return getAccessor().getBoolean(); + } + + @Override + public byte getByte() throws SQLException { + return getAccessor().getByte(); + } + + @Override + public short getShort() throws SQLException { + return getAccessor().getShort(); + } + + @Override + public int getInt() throws SQLException { + return getAccessor().getInt(); + } + + @Override + public long getLong() throws SQLException { + return getAccessor().getLong(); + } + + @Override + public float getFloat() throws SQLException { + return getAccessor().getFloat(); + } + + @Override + public double getDouble() throws SQLException { + return getAccessor().getDouble(); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + return getAccessor().getBigDecimal(); + } + + @Override + public BigDecimal getBigDecimal(int i) throws SQLException { + return getAccessor().getBigDecimal(i); + } + + @Override + public byte[] getBytes() throws SQLException { + return getAccessor().getBytes(); + } + + @Override + public InputStream getAsciiStream() throws SQLException { + return getAccessor().getAsciiStream(); + } + + @Override + public InputStream getUnicodeStream() throws SQLException { + return getAccessor().getUnicodeStream(); + } + + @Override + public InputStream getBinaryStream() throws SQLException { + return getAccessor().getBinaryStream(); + } + + @Override + public Object getObject() throws SQLException { + return getAccessor().getObject(); + } + + @Override + public Reader getCharacterStream() throws SQLException { + return getAccessor().getCharacterStream(); + } + + @Override + public Object getObject(Map> map) throws SQLException { + return getAccessor().getObject(map); + } + + @Override + public Ref getRef() throws SQLException { + return getAccessor().getRef(); + } + + @Override + public Blob getBlob() throws SQLException { + return getAccessor().getBlob(); + } + + @Override + public Clob getClob() throws SQLException { + return getAccessor().getClob(); + } + + @Override + public Array getArray() throws SQLException { + return getAccessor().getArray(); + } + + @Override + public Struct getStruct() throws SQLException { + return getAccessor().getStruct(); + } + + @Override + public Date getDate(Calendar calendar) throws SQLException { + return getAccessor().getDate(calendar); + } + + @Override + public Time getTime(Calendar calendar) throws SQLException { + return getAccessor().getTime(calendar); + } + + @Override + public Timestamp getTimestamp(Calendar calendar) throws SQLException { + return getAccessor().getTimestamp(calendar); + } + + @Override + public URL getURL() throws SQLException { + return getAccessor().getURL(); + } + + @Override + public NClob getNClob() throws SQLException { + return getAccessor().getNClob(); + } + + @Override + public SQLXML getSQLXML() throws SQLException { + return getAccessor().getSQLXML(); + } + + @Override + public String getNString() throws SQLException { + return getAccessor().getNString(); + } + + 
@Override + public Reader getNCharacterStream() throws SQLException { + return getAccessor().getNCharacterStream(); + } + + @Override + public T getObject(Class type) throws SQLException { + return getAccessor().getObject(type); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessor.java new file mode 100644 index 00000000000..ba5b83ade63 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessor.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.DenseUnionVector; + +/** + * Accessor for the Arrow type {@link DenseUnionVector}. + */ +public class ArrowFlightJdbcDenseUnionVectorAccessor + extends AbstractArrowFlightJdbcUnionVectorAccessor { + + private final DenseUnionVector vector; + + /** + * Instantiate an accessor for a {@link DenseUnionVector}. + * + * @param vector an instance of a DenseUnionVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. 
+ */ + public ArrowFlightJdbcDenseUnionVectorAccessor(DenseUnionVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector) { + return ArrowFlightJdbcAccessorFactory.createAccessor(vector, + () -> this.vector.getOffset(this.getCurrentRow()), (boolean wasNull) -> { + }); + } + + @Override + protected byte getCurrentTypeId() { + int index = getCurrentRow(); + return this.vector.getTypeId(index); + } + + @Override + protected ValueVector getVectorByTypeId(byte typeId) { + return this.vector.getVectorByType(typeId); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcFixedSizeListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcFixedSizeListVectorAccessor.java new file mode 100644 index 00000000000..7bdd3abfd0c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcFixedSizeListVectorAccessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; + +/** + * Accessor for the Arrow type {@link FixedSizeListVector}. 
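+ * Start and end offsets are derived from the fixed list size rather than from an offset buffer.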
+ */ +public class ArrowFlightJdbcFixedSizeListVectorAccessor + extends AbstractArrowFlightJdbcListVectorAccessor { + + private final FixedSizeListVector vector; + + public ArrowFlightJdbcFixedSizeListVectorAccessor(FixedSizeListVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected long getStartOffset(int index) { + return (long) vector.getListSize() * index; + } + + @Override + protected long getEndOffset(int index) { + return (long) vector.getListSize() * (index + 1); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + public Object getObject() { + List object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcLargeListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcLargeListVectorAccessor.java new file mode 100644 index 00000000000..f7608bb06e5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcLargeListVectorAccessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.LargeListVector; + +/** + * Accessor for the Arrow type {@link LargeListVector}. 
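+ * Offsets are read as 64-bit values from the vector's offset buffer.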
+ */ +public class ArrowFlightJdbcLargeListVectorAccessor + extends AbstractArrowFlightJdbcListVectorAccessor { + + private final LargeListVector vector; + + public ArrowFlightJdbcLargeListVectorAccessor(LargeListVector vector, + IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected long getStartOffset(int index) { + return vector.getOffsetBuffer().getLong((long) index * LargeListVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + return vector.getOffsetBuffer().getLong((long) (index + 1) * LargeListVector.OFFSET_WIDTH); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + public Object getObject() { + List object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcListVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcListVectorAccessor.java new file mode 100644 index 00000000000..a329a344073 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcListVectorAccessor.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.List; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.ListVector; + +/** + * Accessor for the Arrow type {@link ListVector}. 
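+ * Offsets are read as 32-bit values from the vector's offset buffer.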
+ */ +public class ArrowFlightJdbcListVectorAccessor extends AbstractArrowFlightJdbcListVectorAccessor { + + private final ListVector vector; + + public ArrowFlightJdbcListVectorAccessor(ListVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected long getStartOffset(int index) { + return vector.getOffsetBuffer().getInt((long) index * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + return vector.getOffsetBuffer() + .getInt((long) (index + 1) * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + public Object getObject() { + List object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessor.java new file mode 100644 index 00000000000..bf1225b33de --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessor.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.Map; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.impl.UnionMapReader; +import org.apache.arrow.vector.util.JsonStringHashMap; + +/** + * Accessor for the Arrow type {@link MapVector}. 
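+ * Entries for the current row are materialized into a Map by iterating a {@link UnionMapReader}.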
+ */ +public class ArrowFlightJdbcMapVectorAccessor extends AbstractArrowFlightJdbcListVectorAccessor { + + private final MapVector vector; + + public ArrowFlightJdbcMapVectorAccessor(MapVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Map.class; + } + + @Override + public Object getObject() { + int index = getCurrentRow(); + + this.wasNull = vector.isNull(index); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + Map result = new JsonStringHashMap<>(); + UnionMapReader reader = vector.getReader(); + + reader.setPosition(index); + while (reader.next()) { + Object key = reader.key().readObject(); + Object value = reader.value().readObject(); + + result.put(key, value); + } + + return result; + } + + @Override + protected long getStartOffset(int index) { + return vector.getOffsetBuffer().getInt((long) index * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected long getEndOffset(int index) { + return vector.getOffsetBuffer() + .getInt((long) (index + 1) * BaseRepeatedValueVector.OFFSET_WIDTH); + } + + @Override + protected boolean isNull(int index) { + return vector.isNull(index); + } + + @Override + protected FieldVector getDataVector() { + return vector.getDataVector(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessor.java new file mode 100644 index 00000000000..8a7ac117113 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessor.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.sql.Struct; +import java.util.List; +import java.util.Map; +import java.util.function.IntSupplier; +import java.util.stream.Collectors; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.calcite.avatica.util.StructImpl; + +/** + * Accessor for the Arrow type {@link StructVector}. 
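+ * {@link #getObject()} returns the struct as a Map, while {@link #getStruct()} adapts it to a JDBC Struct.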
+ */ +public class ArrowFlightJdbcStructVectorAccessor extends ArrowFlightJdbcAccessor { + + private final StructVector vector; + + public ArrowFlightJdbcStructVectorAccessor(StructVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + public Class getObjectClass() { + return Map.class; + } + + @Override + public Object getObject() { + Map object = vector.getObject(getCurrentRow()); + this.wasNull = object == null; + this.wasNullConsumer.setWasNull(this.wasNull); + + return object; + } + + @Override + public Struct getStruct() { + int currentRow = getCurrentRow(); + + this.wasNull = vector.isNull(currentRow); + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return null; + } + + List attributes = vector.getChildrenFromFields() + .stream() + .map(vector -> vector.getObject(currentRow)) + .collect(Collectors.toList()); + + return new StructImpl(attributes); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessor.java new file mode 100644 index 00000000000..5b5a0a472d5 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessor.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.UnionVector; + +/** + * Accessor for the Arrow type {@link UnionVector}. + */ +public class ArrowFlightJdbcUnionVectorAccessor extends AbstractArrowFlightJdbcUnionVectorAccessor { + + private final UnionVector vector; + + /** + * Instantiate an accessor for a {@link UnionVector}. + * + * @param vector an instance of a UnionVector. + * @param currentRowSupplier the supplier to track the rows. + * @param setCursorWasNull the consumer to set if value was null. 
+ */ + public ArrowFlightJdbcUnionVectorAccessor(UnionVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + } + + @Override + protected ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector) { + return ArrowFlightJdbcAccessorFactory.createAccessor(vector, this::getCurrentRow, + (boolean wasNull) -> { + }); + } + + @Override + protected byte getCurrentTypeId() { + int index = getCurrentRow(); + return (byte) this.vector.getTypeValue(index); + } + + @Override + protected ValueVector getVectorByTypeId(byte typeId) { + return this.vector.getVectorByType(typeId); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java new file mode 100644 index 00000000000..aea9b75fa6c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessor.java @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcNumericGetter.Getter; +import static org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcNumericGetter.createGetter; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcNumericGetter.NumericHolder; +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.types.Types.MinorType; + +/** + * Accessor for the arrow types: TinyIntVector, SmallIntVector, IntVector, BigIntVector, + * UInt1Vector, UInt2Vector, UInt4Vector and UInt8Vector. 
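+ *
+ * <p>A minimal usage sketch (illustrative names, not part of this patch):
+ * <pre>{@code
+ *   IntVector vector = ...; // populated elsewhere
+ *   ArrowFlightJdbcBaseIntVectorAccessor accessor =
+ *       new ArrowFlightJdbcBaseIntVectorAccessor(vector, () -> 0, wasNull -> { });
+ *   long value = accessor.getLong(); // yields 0 and sets the wasNull flag for null slots
+ * }</pre>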
+ */ +public class ArrowFlightJdbcBaseIntVectorAccessor extends ArrowFlightJdbcAccessor { + + private final MinorType type; + private final boolean isUnsigned; + private final int bytesToAllocate; + private final Getter getter; + private final NumericHolder holder; + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt1Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt1Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt2Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt2Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt4Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt4Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(UInt8Vector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, true, UInt8Vector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(TinyIntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, TinyIntVector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(SmallIntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, SmallIntVector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(IntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, IntVector.TYPE_WIDTH, setCursorWasNull); + } + + public ArrowFlightJdbcBaseIntVectorAccessor(BigIntVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + this(vector, currentRowSupplier, false, BigIntVector.TYPE_WIDTH, setCursorWasNull); + } + + private ArrowFlightJdbcBaseIntVectorAccessor(BaseIntVector vector, IntSupplier currentRowSupplier, + boolean isUnsigned, int bytesToAllocate, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.type = vector.getMinorType(); + this.holder = new NumericHolder(); + this.getter = createGetter(vector); + this.isUnsigned = isUnsigned; + this.bytesToAllocate = bytesToAllocate; + } + + @Override + public long getLong() { + getter.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public Class getObjectClass() { + return Long.class; + } + + @Override + public String getString() { + final long number = getLong(); + + if (this.wasNull) { + return null; + } else { + return isUnsigned ? 
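+          // Unsigned types need the unsigned conversion to avoid rendering negative values.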
+        Long.toUnsignedString(number) : Long.toString(number);
+    }
+  }
+
+  @Override
+  public byte getByte() {
+    return (byte) getLong();
+  }
+
+  @Override
+  public short getShort() {
+    return (short) getLong();
+  }
+
+  @Override
+  public int getInt() {
+    return (int) getLong();
+  }
+
+  @Override
+  public float getFloat() {
+    return (float) getLong();
+  }
+
+  @Override
+  public double getDouble() {
+    return (double) getLong();
+  }
+
+  @Override
+  public BigDecimal getBigDecimal() {
+    final BigDecimal value = BigDecimal.valueOf(getLong());
+    return this.wasNull ? null : value;
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(int scale) {
+    final BigDecimal value =
+        BigDecimal.valueOf(this.getDouble()).setScale(scale, RoundingMode.HALF_UP);
+    return this.wasNull ? null : value;
+  }
+
+  @Override
+  public Number getObject() {
+    final Number number;
+    switch (type) {
+      case TINYINT:
+      case UINT1:
+        number = getByte();
+        break;
+      case SMALLINT:
+      case UINT2:
+        number = getShort();
+        break;
+      case INT:
+      case UINT4:
+        number = getInt();
+        break;
+      case BIGINT:
+      case UINT8:
+        number = getLong();
+        break;
+      default:
+        throw new IllegalStateException("No valid MinorType was provided.");
+    }
+    return wasNull ? null : number;
+  }
+
+  @Override
+  public boolean getBoolean() {
+    final long value = getLong();
+
+    return value != 0;
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java
new file mode 100644
index 00000000000..f55fd12f9a5
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessor.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.numeric;
+
+import java.math.BigDecimal;
+import java.util.function.IntSupplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.holders.NullableBitHolder;
+
+/**
+ * Accessor for the Arrow {@link BitVector}.
+ */
+public class ArrowFlightJdbcBitVectorAccessor extends ArrowFlightJdbcAccessor {
+
+  private final BitVector vector;
+  private final NullableBitHolder holder;
+  private static final int BYTES_TO_ALLOCATE = 1;
+
+  /**
+   * Constructor for the BitVectorAccessor.
+   *
+   * @param vector an instance of a {@link BitVector}.
+   * @param currentRowSupplier a supplier to check which row is being accessed.
+ * @param setCursorWasNull the consumer to set if value was null. + */ + public ArrowFlightJdbcBitVectorAccessor(BitVector vector, IntSupplier currentRowSupplier, + ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) { + super(currentRowSupplier, setCursorWasNull); + this.vector = vector; + this.holder = new NullableBitHolder(); + } + + @Override + public Class getObjectClass() { + return Boolean.class; + } + + @Override + public String getString() { + final boolean value = getBoolean(); + return wasNull ? null : Boolean.toString(value); + } + + @Override + public boolean getBoolean() { + return this.getLong() != 0; + } + + @Override + public byte getByte() { + return (byte) this.getLong(); + } + + @Override + public short getShort() { + return (short) this.getLong(); + } + + @Override + public int getInt() { + return (int) this.getLong(); + } + + @Override + public long getLong() { + vector.get(getCurrentRow(), holder); + + this.wasNull = holder.isSet == 0; + this.wasNullConsumer.setWasNull(this.wasNull); + if (this.wasNull) { + return 0; + } + + return holder.value; + } + + @Override + public float getFloat() { + return this.getLong(); + } + + @Override + public double getDouble() { + return this.getLong(); + } + + @Override + public BigDecimal getBigDecimal() { + final long value = this.getLong(); + + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public Object getObject() { + final boolean value = this.getBoolean(); + return this.wasNull ? null : value; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessor.java new file mode 100644 index 00000000000..0f7d618a609 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessor.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; + +/** + * Accessor for {@link DecimalVector} and {@link Decimal256Vector}. 
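+ * Both vector classes expose a getObject(int) that returns a BigDecimal, so a single Getter lambda covers them.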
+ */
+public class ArrowFlightJdbcDecimalVectorAccessor extends ArrowFlightJdbcAccessor {
+
+  private final Getter getter;
+
+  /**
+   * Functional interface used to unify Decimal*Vector#getObject implementations.
+   */
+  @FunctionalInterface
+  interface Getter {
+    BigDecimal getObject(int index);
+  }
+
+  public ArrowFlightJdbcDecimalVectorAccessor(DecimalVector vector, IntSupplier currentRowSupplier,
+      ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    super(currentRowSupplier, setCursorWasNull);
+    this.getter = vector::getObject;
+  }
+
+  public ArrowFlightJdbcDecimalVectorAccessor(Decimal256Vector vector,
+      IntSupplier currentRowSupplier,
+      ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    super(currentRowSupplier, setCursorWasNull);
+    this.getter = vector::getObject;
+  }
+
+  @Override
+  public Class<?> getObjectClass() {
+    return BigDecimal.class;
+  }
+
+  @Override
+  public BigDecimal getBigDecimal() {
+    final BigDecimal value = getter.getObject(getCurrentRow());
+    this.wasNull = value == null;
+    this.wasNullConsumer.setWasNull(this.wasNull);
+    return value;
+  }
+
+  @Override
+  public String getString() {
+    final BigDecimal value = this.getBigDecimal();
+    return this.wasNull ? null : value.toString();
+  }
+
+  @Override
+  public boolean getBoolean() {
+    final BigDecimal value = this.getBigDecimal();
+
+    // Use compareTo so that zero values compare equal regardless of scale (e.g. 0.00).
+    return !this.wasNull && value.compareTo(BigDecimal.ZERO) != 0;
+  }
+
+  @Override
+  public byte getByte() {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? 0 : value.byteValue();
+  }
+
+  @Override
+  public short getShort() {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? 0 : value.shortValue();
+  }
+
+  @Override
+  public int getInt() {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? 0 : value.intValue();
+  }
+
+  @Override
+  public long getLong() {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? 0 : value.longValue();
+  }
+
+  @Override
+  public float getFloat() {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? 0 : value.floatValue();
+  }
+
+  @Override
+  public double getDouble() {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? 0 : value.doubleValue();
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(int scale) {
+    final BigDecimal value = this.getBigDecimal();
+
+    return this.wasNull ? null : value.setScale(scale, RoundingMode.HALF_UP);
+  }
+
+  @Override
+  public Object getObject() {
+    return this.getBigDecimal();
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessor.java
new file mode 100644
index 00000000000..cbf2d36ff80
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessor.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.numeric;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.sql.SQLException;
+import java.util.function.IntSupplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.holders.NullableFloat4Holder;
+
+/**
+ * Accessor for the Float4Vector.
+ */
+public class ArrowFlightJdbcFloat4VectorAccessor extends ArrowFlightJdbcAccessor {
+
+  private final Float4Vector vector;
+  private final NullableFloat4Holder holder;
+
+  /**
+   * Instantiate an accessor for the {@link Float4Vector}.
+   *
+   * @param vector an instance of a Float4Vector.
+   * @param currentRowSupplier the supplier to track the rows.
+   * @param setCursorWasNull the consumer to set if value was null.
+   */
+  public ArrowFlightJdbcFloat4VectorAccessor(Float4Vector vector,
+      IntSupplier currentRowSupplier,
+      ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    super(currentRowSupplier, setCursorWasNull);
+    this.holder = new NullableFloat4Holder();
+    this.vector = vector;
+  }
+
+  @Override
+  public Class<?> getObjectClass() {
+    return Float.class;
+  }
+
+  @Override
+  public String getString() {
+    final float value = this.getFloat();
+
+    return this.wasNull ? null : Float.toString(value);
+  }
+
+  @Override
+  public boolean getBoolean() {
+    return this.getFloat() != 0.0;
+  }
+
+  @Override
+  public byte getByte() {
+    return (byte) this.getFloat();
+  }
+
+  @Override
+  public short getShort() {
+    return (short) this.getFloat();
+  }
+
+  @Override
+  public int getInt() {
+    return (int) this.getFloat();
+  }
+
+  @Override
+  public long getLong() {
+    return (long) this.getFloat();
+  }
+
+  @Override
+  public float getFloat() {
+    vector.get(getCurrentRow(), holder);
+
+    this.wasNull = holder.isSet == 0;
+    this.wasNullConsumer.setWasNull(this.wasNull);
+    if (this.wasNull) {
+      return 0;
+    }
+
+    return holder.value;
+  }
+
+  @Override
+  public double getDouble() {
+    return this.getFloat();
+  }
+
+  @Override
+  public BigDecimal getBigDecimal() throws SQLException {
+    final float value = this.getFloat();
+
+    if (Float.isInfinite(value) || Float.isNaN(value)) {
+      throw new SQLException("BigDecimal doesn't support Infinite/NaN.");
+    }
+
+    return this.wasNull ? null : BigDecimal.valueOf(value);
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(int scale) throws SQLException {
+    final float value = this.getFloat();
+    if (Float.isInfinite(value) || Float.isNaN(value)) {
+      throw new SQLException("BigDecimal doesn't support Infinite/NaN.");
+    }
+    return this.wasNull ? null : BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP);
+  }
+
+  @Override
+  public Object getObject() {
+    final float value = this.getFloat();
+    return this.wasNull ? null : value;
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessor.java
new file mode 100644
index 00000000000..dc5542ffc58
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessor.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.numeric;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+import java.sql.SQLException;
+import java.util.function.IntSupplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.holders.NullableFloat8Holder;
+
+/**
+ * Accessor for the Float8Vector.
+ */
+public class ArrowFlightJdbcFloat8VectorAccessor extends ArrowFlightJdbcAccessor {
+
+  private final Float8Vector vector;
+  private final NullableFloat8Holder holder;
+
+  /**
+   * Instantiate an accessor for the {@link Float8Vector}.
+   *
+   * @param vector an instance of a Float8Vector.
+   * @param currentRowSupplier the supplier to track the rows.
+   * @param setCursorWasNull the consumer to set if value was null.
+   */
+  public ArrowFlightJdbcFloat8VectorAccessor(Float8Vector vector,
+      IntSupplier currentRowSupplier,
+      ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    super(currentRowSupplier, setCursorWasNull);
+    this.holder = new NullableFloat8Holder();
+    this.vector = vector;
+  }
+
+  @Override
+  public Class<?> getObjectClass() {
+    return Double.class;
+  }
+
+  @Override
+  public double getDouble() {
+    vector.get(getCurrentRow(), holder);
+
+    this.wasNull = holder.isSet == 0;
+    this.wasNullConsumer.setWasNull(this.wasNull);
+    if (this.wasNull) {
+      return 0;
+    }
+
+    return holder.value;
+  }
+
+  @Override
+  public Object getObject() {
+    final double value = this.getDouble();
+
+    return this.wasNull ? null : value;
+  }
+
+  @Override
+  public String getString() {
+    final double value = this.getDouble();
+    return this.wasNull ?
null : Double.toString(value); + } + + @Override + public boolean getBoolean() { + return this.getDouble() != 0.0; + } + + @Override + public byte getByte() { + return (byte) this.getDouble(); + } + + @Override + public short getShort() { + return (short) this.getDouble(); + } + + @Override + public int getInt() { + return (int) this.getDouble(); + } + + @Override + public long getLong() { + return (long) this.getDouble(); + } + + @Override + public float getFloat() { + return (float) this.getDouble(); + } + + @Override + public BigDecimal getBigDecimal() throws SQLException { + final double value = this.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + throw new SQLException("BigDecimal doesn't support Infinite/NaN."); + } + return this.wasNull ? null : BigDecimal.valueOf(value); + } + + @Override + public BigDecimal getBigDecimal(int scale) throws SQLException { + final double value = this.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + throw new SQLException("BigDecimal doesn't support Infinite/NaN."); + } + return this.wasNull ? null : BigDecimal.valueOf(value).setScale(scale, RoundingMode.HALF_UP); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcNumericGetter.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcNumericGetter.java new file mode 100644 index 00000000000..cc802a0089d --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcNumericGetter.java @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import org.apache.arrow.vector.BaseIntVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.holders.NullableBigIntHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableSmallIntHolder; +import org.apache.arrow.vector.holders.NullableTinyIntHolder; +import org.apache.arrow.vector.holders.NullableUInt1Holder; +import org.apache.arrow.vector.holders.NullableUInt2Holder; +import org.apache.arrow.vector.holders.NullableUInt4Holder; +import org.apache.arrow.vector.holders.NullableUInt8Holder; + +/** + * A custom getter for values from the {@link BaseIntVector}. 
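+ * It wraps each vector's type-specific nullable holder behind a common (index, holder) getter.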
+ */ +class ArrowFlightJdbcNumericGetter { + /** + * A holder for values from the {@link BaseIntVector}. + */ + static class NumericHolder { + int isSet; // Tells if value is set; 0 = not set, 1 = set + long value; // Holds actual value + } + + /** + * Functional interface for a getter to baseInt values. + */ + @FunctionalInterface + interface Getter { + void get(int index, NumericHolder holder); + } + + /** + * Main class that will check the type of the vector to create + * a specific getter. + * + * @param vector an instance of the {@link BaseIntVector} + * @return a getter. + */ + static Getter createGetter(BaseIntVector vector) { + if (vector instanceof UInt1Vector) { + return createGetter((UInt1Vector) vector); + } else if (vector instanceof UInt2Vector) { + return createGetter((UInt2Vector) vector); + } else if (vector instanceof UInt4Vector) { + return createGetter((UInt4Vector) vector); + } else if (vector instanceof UInt8Vector) { + return createGetter((UInt8Vector) vector); + } else if (vector instanceof TinyIntVector) { + return createGetter((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + return createGetter((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + return createGetter((IntVector) vector); + } else if (vector instanceof BigIntVector) { + return createGetter((BigIntVector) vector); + } + + throw new UnsupportedOperationException("No valid IntVector was provided."); + } + + /** + * Create a specific getter for {@link UInt1Vector}. + * + * @param vector an instance of the {@link UInt1Vector} + * @return a getter. + */ + private static Getter createGetter(UInt1Vector vector) { + NullableUInt1Holder nullableUInt1Holder = new NullableUInt1Holder(); + + return (index, holder) -> { + vector.get(index, nullableUInt1Holder); + + holder.isSet = nullableUInt1Holder.isSet; + holder.value = nullableUInt1Holder.value; + }; + } + + /** + * Create a specific getter for {@link UInt2Vector}. + * + * @param vector an instance of the {@link UInt2Vector} + * @return a getter. + */ + private static Getter createGetter(UInt2Vector vector) { + NullableUInt2Holder nullableUInt2Holder = new NullableUInt2Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt2Holder); + + holder.isSet = nullableUInt2Holder.isSet; + holder.value = nullableUInt2Holder.value; + }; + } + + /** + * Create a specific getter for {@link UInt4Vector}. + * + * @param vector an instance of the {@link UInt4Vector} + * @return a getter. + */ + private static Getter createGetter(UInt4Vector vector) { + NullableUInt4Holder nullableUInt4Holder = new NullableUInt4Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt4Holder); + + holder.isSet = nullableUInt4Holder.isSet; + holder.value = nullableUInt4Holder.value; + }; + } + + /** + * Create a specific getter for {@link UInt8Vector}. + * + * @param vector an instance of the {@link UInt8Vector} + * @return a getter. + */ + private static Getter createGetter(UInt8Vector vector) { + NullableUInt8Holder nullableUInt8Holder = new NullableUInt8Holder(); + return (index, holder) -> { + vector.get(index, nullableUInt8Holder); + + holder.isSet = nullableUInt8Holder.isSet; + holder.value = nullableUInt8Holder.value; + }; + } + + /** + * Create a specific getter for {@link TinyIntVector}. + * + * @param vector an instance of the {@link TinyIntVector} + * @return a getter. 
+ */ + private static Getter createGetter(TinyIntVector vector) { + NullableTinyIntHolder nullableTinyIntHolder = new NullableTinyIntHolder(); + return (index, holder) -> { + vector.get(index, nullableTinyIntHolder); + + holder.isSet = nullableTinyIntHolder.isSet; + holder.value = nullableTinyIntHolder.value; + }; + } + + /** + * Create a specific getter for {@link SmallIntVector}. + * + * @param vector an instance of the {@link SmallIntVector} + * @return a getter. + */ + private static Getter createGetter(SmallIntVector vector) { + NullableSmallIntHolder nullableSmallIntHolder = new NullableSmallIntHolder(); + return (index, holder) -> { + vector.get(index, nullableSmallIntHolder); + + holder.isSet = nullableSmallIntHolder.isSet; + holder.value = nullableSmallIntHolder.value; + }; + } + + /** + * Create a specific getter for {@link IntVector}. + * + * @param vector an instance of the {@link IntVector} + * @return a getter. + */ + private static Getter createGetter(IntVector vector) { + NullableIntHolder nullableIntHolder = new NullableIntHolder(); + return (index, holder) -> { + vector.get(index, nullableIntHolder); + + holder.isSet = nullableIntHolder.isSet; + holder.value = nullableIntHolder.value; + }; + } + + /** + * Create a specific getter for {@link BigIntVector}. + * + * @param vector an instance of the {@link BigIntVector} + * @return a getter. + */ + private static Getter createGetter(BigIntVector vector) { + NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder(); + return (index, holder) -> { + vector.get(index, nullableBigIntHolder); + + holder.isSet = nullableBigIntHolder.isSet; + holder.value = nullableBigIntHolder.value; + }; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessor.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessor.java new file mode 100644 index 00000000000..aad8d9094c9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessor.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.text;
+
+import static java.nio.charset.StandardCharsets.US_ASCII;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import java.io.ByteArrayInputStream;
+import java.io.CharArrayReader;
+import java.io.InputStream;
+import java.io.Reader;
+import java.math.BigDecimal;
+import java.sql.Date;
+import java.sql.SQLException;
+import java.sql.Time;
+import java.sql.Timestamp;
+import java.util.Calendar;
+import java.util.function.IntSupplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.driver.jdbc.utils.DateTimeUtils;
+import org.apache.arrow.vector.LargeVarCharVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.util.Text;
+
+/**
+ * Accessor for the Arrow types: {@link VarCharVector} and {@link LargeVarCharVector}.
+ */
+public class ArrowFlightJdbcVarCharVectorAccessor extends ArrowFlightJdbcAccessor {
+
+  /**
+   * Functional interface used to unify value access for {@link VarCharVector} and
+   * {@link LargeVarCharVector}.
+   */
+  @FunctionalInterface
+  interface Getter {
+    byte[] get(int index);
+  }
+
+  private final Getter getter;
+
+  public ArrowFlightJdbcVarCharVectorAccessor(VarCharVector vector,
+                                              IntSupplier currentRowSupplier,
+                                              ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    this(vector::get, currentRowSupplier, setCursorWasNull);
+  }
+
+  public ArrowFlightJdbcVarCharVectorAccessor(LargeVarCharVector vector,
+                                              IntSupplier currentRowSupplier,
+                                              ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    this(vector::get, currentRowSupplier, setCursorWasNull);
+  }
+
+  ArrowFlightJdbcVarCharVectorAccessor(Getter getter,
+                                       IntSupplier currentRowSupplier,
+                                       ArrowFlightJdbcAccessorFactory.WasNullConsumer setCursorWasNull) {
+    super(currentRowSupplier, setCursorWasNull);
+    this.getter = getter;
+  }
+
+  @Override
+  public Class<?> getObjectClass() {
+    return String.class;
+  }
+
+  @Override
+  public String getObject() {
+    final byte[] bytes = getBytes();
+    return bytes == null ? 
null : new String(bytes, UTF_8);
+  }
+
+  @Override
+  public String getString() {
+    return getObject();
+  }
+
+  @Override
+  public byte[] getBytes() {
+    final byte[] bytes = this.getter.get(getCurrentRow());
+    this.wasNull = bytes == null;
+    this.wasNullConsumer.setWasNull(this.wasNull);
+    return bytes;
+  }
+
+  @Override
+  public boolean getBoolean() throws SQLException {
+    String value = getString();
+    if (value == null || value.equalsIgnoreCase("false") || value.equals("0")) {
+      return false;
+    } else if (value.equalsIgnoreCase("true") || value.equals("1")) {
+      return true;
+    } else {
+      throw new SQLException("It is not possible to convert this value to boolean: " + value);
+    }
+  }
+
+  @Override
+  public byte getByte() throws SQLException {
+    try {
+      return Byte.parseByte(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public short getShort() throws SQLException {
+    try {
+      return Short.parseShort(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public int getInt() throws SQLException {
+    try {
+      return Integer.parseInt(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public long getLong() throws SQLException {
+    try {
+      return Long.parseLong(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public float getFloat() throws SQLException {
+    try {
+      return Float.parseFloat(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public double getDouble() throws SQLException {
+    try {
+      return Double.parseDouble(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public BigDecimal getBigDecimal() throws SQLException {
+    try {
+      return new BigDecimal(this.getString());
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public BigDecimal getBigDecimal(int i) throws SQLException {
+    try {
+      return BigDecimal.valueOf(this.getLong(), i);
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public InputStream getAsciiStream() {
+    final String textValue = getString();
+    if (textValue == null) {
+      return null;
+    }
+    // Re-encode the text as US-ASCII, as required for an ASCII stream
+    return new ByteArrayInputStream(textValue.getBytes(US_ASCII));
+  }
+
+  @Override
+  public InputStream getUnicodeStream() {
+    final byte[] value = getBytes();
+    if (value == null) {
+      return null;
+    }
+
+    // Already in UTF-8
+    final Text textValue = new Text(value);
+    return new ByteArrayInputStream(textValue.getBytes(), 0, textValue.getLength());
+  }
+
+  @Override
+  public Reader getCharacterStream() {
+    // Return null for SQL NULL, per the JDBC contract
+    final String textValue = getString();
+    return textValue == null ? null : new CharArrayReader(textValue.toCharArray());
+  }
+
+  @Override
+  public Date getDate(Calendar calendar) throws SQLException {
+    try {
+      Date date = Date.valueOf(getString());
+      if (calendar == null) {
+        return date;
+      }
+
+      // Use Calendar to apply time zone's offset
+      long milliseconds = date.getTime();
+      return new Date(DateTimeUtils.applyCalendarOffset(milliseconds, calendar));
+    } catch (Exception e) {
+      throw new SQLException(e);
+    }
+  }
+
+  @Override
+  public Time getTime(Calendar calendar) throws SQLException {
+    try {
+      Time time = Time.valueOf(getString());
+      if (calendar == null) {
+        return time;
+      }
+
+      // Use Calendar to apply time zone's offset
+      long milliseconds = time.getTime();
+      return new Time(DateTimeUtils.applyCalendarOffset(milliseconds, calendar));
+    } catch (Exception e) {
+      throw new SQLException(e);
+    
} + } + + @Override + public Timestamp getTimestamp(Calendar calendar) throws SQLException { + try { + Timestamp timestamp = Timestamp.valueOf(getString()); + if (calendar == null) { + return timestamp; + } + + // Use Calendar to apply time zone's offset + long milliseconds = timestamp.getTime(); + return new Timestamp(DateTimeUtils.applyCalendarOffset(milliseconds, calendar)); + } catch (Exception e) { + throw new SQLException(e); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java new file mode 100644 index 00000000000..afac6c16470 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -0,0 +1,582 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.client; + +import java.io.IOException; +import java.security.GeneralSecurityException; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightClientMiddleware; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.auth2.BearerCredentialWriter; +import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; +import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; +import org.apache.arrow.flight.client.ClientCookieMiddleware; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; +import org.apache.arrow.flight.sql.util.TableRef; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.Meta.StatementType; + +/** + * A {@link FlightSqlClient} handler. 
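+ *
+ * <p>A minimal usage sketch, assuming a Flight SQL server at a hypothetical host and
+ * port; the query string is illustrative:
+ *
+ * <pre>{@code
+ * try (ArrowFlightSqlClientHandler handler = new ArrowFlightSqlClientHandler.Builder()
+ *     .withHost("localhost")
+ *     .withPort(32010)
+ *     .withBufferAllocator(new RootAllocator(Long.MAX_VALUE))
+ *     .build()) {
+ *   FlightInfo info = handler.getInfo("SELECT 1");
+ * }
+ * }</pre>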
+ */
+public final class ArrowFlightSqlClientHandler implements AutoCloseable {
+
+  private final FlightSqlClient sqlClient;
+  private final Set<CallOption> options = new HashSet<>();
+
+  ArrowFlightSqlClientHandler(final FlightSqlClient sqlClient,
+                              final Collection<CallOption> options) {
+    this.options.addAll(options);
+    this.sqlClient = Preconditions.checkNotNull(sqlClient);
+  }
+
+  /**
+   * Creates a new {@link ArrowFlightSqlClientHandler} from the provided {@code client} and {@code options}.
+   *
+   * @param client the {@link FlightClient} to manage under a {@link FlightSqlClient} wrapper.
+   * @param options the {@link CallOption}s to persist in between subsequent client calls.
+   * @return a new {@link ArrowFlightSqlClientHandler}.
+   */
+  public static ArrowFlightSqlClientHandler createNewHandler(final FlightClient client,
+                                                             final Collection<CallOption> options) {
+    return new ArrowFlightSqlClientHandler(new FlightSqlClient(client), options);
+  }
+
+  /**
+   * Gets the {@link #options} for the subsequent calls from this handler.
+   *
+   * @return the {@link CallOption}s.
+   */
+  private CallOption[] getOptions() {
+    return options.toArray(new CallOption[0]);
+  }
+
+  /**
+   * Makes an RPC "getStream" request for each endpoint in the provided {@link FlightInfo}
+   * object, retrieving the results of the query previously described by "getInfo".
+   *
+   * @param flightInfo The {@link FlightInfo} instance from which to fetch results.
+   * @return a list of {@link FlightStream}s with the query results.
+   */
+  public List<FlightStream> getStreams(final FlightInfo flightInfo) {
+    return flightInfo.getEndpoints().stream()
+        .map(FlightEndpoint::getTicket)
+        .map(ticket -> sqlClient.getStream(ticket, getOptions()))
+        .collect(Collectors.toList());
+  }
+
+  /**
+   * Makes an RPC "getInfo" request based on the provided {@code query}.
+   *
+   * @param query The query.
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getInfo(final String query) {
+    return sqlClient.execute(query, getOptions());
+  }
+
+  @Override
+  public void close() throws SQLException {
+    try {
+      AutoCloseables.close(sqlClient);
+    } catch (final Exception e) {
+      throw new SQLException("Failed to clean up client resources.", e);
+    }
+  }
+
+  /**
+   * A prepared statement handler.
+   */
+  public interface PreparedStatement extends AutoCloseable {
+    /**
+     * Executes this {@link PreparedStatement}.
+     *
+     * @return the {@link FlightInfo} representing the outcome of this query execution.
+     * @throws SQLException on error.
+     */
+    FlightInfo executeQuery() throws SQLException;
+
+    /**
+     * Executes a {@link StatementType#UPDATE} query.
+     *
+     * @return the number of rows affected.
+     */
+    long executeUpdate();
+
+    /**
+     * Gets the {@link StatementType} of this {@link PreparedStatement}.
+     *
+     * @return the Statement Type.
+     */
+    StatementType getType();
+
+    /**
+     * Gets the {@link Schema} of this {@link PreparedStatement}.
+     *
+     * @return the {@link Schema} of the result set.
+     */
+    Schema getDataSetSchema();
+
+    @Override
+    void close();
+  }
+
+  /**
+   * Creates a new {@link PreparedStatement} for the given {@code query}.
+   *
+   * @param query the SQL query.
+   * @return a new prepared statement.
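+   *
+   * <p>A minimal sketch; the query string is illustrative:
+   * <pre>{@code
+   * try (PreparedStatement statement = handler.prepare("SELECT * FROM intTable")) {
+   *   FlightInfo info = statement.executeQuery();
+   * }
+   * }</pre>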
+   */
+  public PreparedStatement prepare(final String query) {
+    final FlightSqlClient.PreparedStatement preparedStatement =
+        sqlClient.prepare(query, getOptions());
+    return new PreparedStatement() {
+      @Override
+      public FlightInfo executeQuery() throws SQLException {
+        return preparedStatement.execute(getOptions());
+      }
+
+      @Override
+      public long executeUpdate() {
+        return preparedStatement.executeUpdate(getOptions());
+      }
+
+      @Override
+      public StatementType getType() {
+        final Schema schema = preparedStatement.getResultSetSchema();
+        return schema.getFields().isEmpty() ? StatementType.UPDATE : StatementType.SELECT;
+      }
+
+      @Override
+      public Schema getDataSetSchema() {
+        return preparedStatement.getResultSetSchema();
+      }
+
+      @Override
+      public void close() {
+        preparedStatement.close(getOptions());
+      }
+    };
+  }
+
+  /**
+   * Makes an RPC "getCatalogs" request.
+   *
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getCatalogs() {
+    return sqlClient.getCatalogs(getOptions());
+  }
+
+  /**
+   * Makes an RPC "getImportedKeys" request based on the provided info.
+   *
+   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
+   *                "" retrieves those without a catalog. Null means that the catalog name should not be used to
+   *                narrow the search.
+   * @param schema The schema name. Must match the schema name as it is stored in the database.
+   *               "" retrieves those without a schema. Null means that the schema name should not be used to narrow
+   *               the search.
+   * @param table The table name. Must match the table name as it is stored in the database.
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getImportedKeys(final String catalog, final String schema, final String table) {
+    return sqlClient.getImportedKeys(TableRef.of(catalog, schema, table), getOptions());
+  }
+
+  /**
+   * Makes an RPC "getExportedKeys" request based on the provided info.
+   *
+   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
+   *                "" retrieves those without a catalog. Null means that the catalog name should not be used to
+   *                narrow the search.
+   * @param schema The schema name. Must match the schema name as it is stored in the database.
+   *               "" retrieves those without a schema. Null means that the schema name should not be used to narrow
+   *               the search.
+   * @param table The table name. Must match the table name as it is stored in the database.
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getExportedKeys(final String catalog, final String schema, final String table) {
+    return sqlClient.getExportedKeys(TableRef.of(catalog, schema, table), getOptions());
+  }
+
+  /**
+   * Makes an RPC "getSchemas" request based on the provided info.
+   *
+   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
+   *                "" retrieves those without a catalog. Null means that the catalog name should not be used to
+   *                narrow the search.
+   * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the database.
+   *                      Null means that the schema name should not be used to narrow down the search.
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getSchemas(final String catalog, final String schemaPattern) {
+    return sqlClient.getSchemas(catalog, schemaPattern, getOptions());
+  }
+
+  /**
+   * Makes an RPC "getTableTypes" request.
+   *
+   * @return a {@link FlightInfo} describing where to fetch the results.
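+   *
+   * <p>For example (a hedged sketch; the streams would still need to be consumed and closed):
+   * <pre>{@code
+   * FlightInfo info = handler.getTableTypes();
+   * List<FlightStream> streams = handler.getStreams(info);
+   * }</pre>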
+   */
+  public FlightInfo getTableTypes() {
+    return sqlClient.getTableTypes(getOptions());
+  }
+
+  /**
+   * Makes an RPC "getTables" request based on the provided info.
+   *
+   * @param catalog The catalog name. Must match the catalog name as it is stored in the database.
+   *                "" retrieves those without a catalog. Null means that the catalog name should not be used to
+   *                narrow the search.
+   * @param schemaPattern The schema name pattern. Must match the schema name as it is stored in the database.
+   *                      "" retrieves those without a schema. Null means that the schema name should not be used to
+   *                      narrow the search.
+   * @param tableNamePattern The table name pattern. Must match the table name as it is stored in the database.
+   * @param types The list of table types, which must be from the list of table types to include.
+   *              Null returns all types.
+   * @param includeSchema Whether to include the schema.
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getTables(final String catalog, final String schemaPattern,
+                              final String tableNamePattern,
+                              final List<String> types, final boolean includeSchema) {
+
+    return sqlClient.getTables(catalog, schemaPattern, tableNamePattern, types, includeSchema,
+        getOptions());
+  }
+
+  /**
+   * Gets SQL info.
+   *
+   * @param info the pieces of SQL info to fetch.
+   * @return the SQL info.
+   */
+  public FlightInfo getSqlInfo(SqlInfo... info) {
+    return sqlClient.getSqlInfo(info, getOptions());
+  }
+
+  /**
+   * Makes an RPC "getPrimaryKeys" request based on the provided info.
+   *
+   * @param catalog The catalog name; must match the catalog name as it is stored in the database.
+   *                "" retrieves those without a catalog.
+   *                Null means that the catalog name should not be used to narrow the search.
+   * @param schema The schema name; must match the schema name as it is stored in the database.
+   *               "" retrieves those without a schema. Null means that the schema name should not be used to narrow
+   *               the search.
+   * @param table The table name. Must match the table name as it is stored in the database.
+   * @return a {@link FlightInfo} describing where to fetch the results.
+   */
+  public FlightInfo getPrimaryKeys(final String catalog, final String schema, final String table) {
+    return sqlClient.getPrimaryKeys(TableRef.of(catalog, schema, table), getOptions());
+  }
+
+  /**
+   * Makes an RPC "getCrossReference" request based on the provided info.
+   *
+   * @param pkCatalog The catalog name of the parent table. Must match the catalog name as it is stored in the
+   *                  database. "" retrieves those without a catalog. Null means that the catalog name should not
+   *                  be used to narrow the search.
+   * @param pkSchema The schema name of the parent table. Must match the schema name as it is stored in the
+   *                 database. "" retrieves those without a schema. Null means that the schema name should not
+   *                 be used to narrow the search.
+   * @param pkTable The parent table name. Must match the table name as it is stored in the database.
+   * @param fkCatalog The catalog name of the foreign table. Must match the catalog name as it is stored in the
+   *                  database. "" retrieves those without a catalog. Null means that the catalog name should not
+   *                  be used to narrow the search.
+   * @param fkSchema The schema name of the foreign table. Must match the schema name as it is stored in the
+   *                 database. "" retrieves those without a schema. Null means that the schema name should not
+   *                 be used to narrow the search.
+   * @param fkTable The foreign table name. Must match the table name as it is stored in the database.
+   * @return a {@link FlightInfo} describing where to fetch the results.
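+   *
+   * <p>A hedged sketch; the table names are illustrative:
+   * <pre>{@code
+   * FlightInfo info = handler.getCrossReference(null, null, "parentTable",
+   *     null, null, "childTable");
+   * }</pre>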
+   */
+  public FlightInfo getCrossReference(String pkCatalog, String pkSchema, String pkTable,
+                                      String fkCatalog, String fkSchema, String fkTable) {
+    return sqlClient.getCrossReference(TableRef.of(pkCatalog, pkSchema, pkTable),
+        TableRef.of(fkCatalog, fkSchema, fkTable),
+        getOptions());
+  }
+
+  /**
+   * Builder for {@link ArrowFlightSqlClientHandler}.
+   */
+  public static final class Builder {
+    private final Set<FlightClientMiddleware.Factory> middlewareFactories = new HashSet<>();
+    private final Set<CallOption> options = new HashSet<>();
+    private String host;
+    private int port;
+    private String username;
+    private String password;
+    private String trustStorePath;
+    private String trustStorePassword;
+    private String token;
+    private boolean useEncryption;
+    private boolean disableCertificateVerification;
+    private boolean useSystemTrustStore;
+    private BufferAllocator allocator;
+
+    /**
+     * Sets the host for this handler.
+     *
+     * @param host the host.
+     * @return this instance.
+     */
+    public Builder withHost(final String host) {
+      this.host = host;
+      return this;
+    }
+
+    /**
+     * Sets the port for this handler.
+     *
+     * @param port the port.
+     * @return this instance.
+     */
+    public Builder withPort(final int port) {
+      this.port = port;
+      return this;
+    }
+
+    /**
+     * Sets the username for this handler.
+     *
+     * @param username the username.
+     * @return this instance.
+     */
+    public Builder withUsername(final String username) {
+      this.username = username;
+      return this;
+    }
+
+    /**
+     * Sets the password for this handler.
+     *
+     * @param password the password.
+     * @return this instance.
+     */
+    public Builder withPassword(final String password) {
+      this.password = password;
+      return this;
+    }
+
+    /**
+     * Sets the trust store path for this handler.
+     *
+     * @param trustStorePath the trust store path.
+     * @return this instance.
+     */
+    public Builder withTrustStorePath(final String trustStorePath) {
+      this.trustStorePath = trustStorePath;
+      return this;
+    }
+
+    /**
+     * Sets the trust store password for this handler.
+     *
+     * @param trustStorePassword the trust store password.
+     * @return this instance.
+     */
+    public Builder withTrustStorePassword(final String trustStorePassword) {
+      this.trustStorePassword = trustStorePassword;
+      return this;
+    }
+
+    /**
+     * Sets whether to use TLS encryption in this handler.
+     *
+     * @param useEncryption whether to use TLS encryption.
+     * @return this instance.
+     */
+    public Builder withEncryption(final boolean useEncryption) {
+      this.useEncryption = useEncryption;
+      return this;
+    }
+
+    /**
+     * Sets whether to disable the certificate verification in this handler.
+     *
+     * @param disableCertificateVerification whether to disable certificate verification.
+     * @return this instance.
+     */
+    public Builder withDisableCertificateVerification(final boolean disableCertificateVerification) {
+      this.disableCertificateVerification = disableCertificateVerification;
+      return this;
+    }
+
+    /**
+     * Sets whether to use the certificates from the operating system's trust store.
+     *
+     * @param useSystemTrustStore whether to use the operating system's certificates.
+     * @return this instance.
+     */
+    public Builder withSystemTrustStore(final boolean useSystemTrustStore) {
+      this.useSystemTrustStore = useSystemTrustStore;
+      return this;
+    }
+
+    /**
+     * Sets the token used for token authentication.
+     *
+     * @param token the token value.
+     * @return this instance.
+     */
+    public Builder withToken(final String token) {
+      this.token = token;
+      return this;
+    }
+
+    /**
+     * Sets the {@link BufferAllocator} to use in this handler.
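+     * A child allocator is created from the given allocator, so the caller retains
+     * ownership of the root allocator. A minimal sketch:
+     * <pre>{@code
+     * builder.withBufferAllocator(new RootAllocator(Long.MAX_VALUE));
+     * }</pre>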
+ * + * @param allocator the allocator. + * @return this instance. + */ + public Builder withBufferAllocator(final BufferAllocator allocator) { + this.allocator = allocator + .newChildAllocator("ArrowFlightSqlClientHandler", 0, allocator.getLimit()); + return this; + } + + /** + * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler. + * + * @param factories the factories to add. + * @return this instance. + */ + public Builder withMiddlewareFactories(final FlightClientMiddleware.Factory... factories) { + return withMiddlewareFactories(Arrays.asList(factories)); + } + + /** + * Adds the provided {@code factories} to the list of {@link #middlewareFactories} of this handler. + * + * @param factories the factories to add. + * @return this instance. + */ + public Builder withMiddlewareFactories( + final Collection factories) { + this.middlewareFactories.addAll(factories); + return this; + } + + /** + * Adds the provided {@link CallOption}s to this handler. + * + * @param options the options + * @return this instance. + */ + public Builder withCallOptions(final CallOption... options) { + return withCallOptions(Arrays.asList(options)); + } + + /** + * Adds the provided {@link CallOption}s to this handler. + * + * @param options the options + * @return this instance. + */ + public Builder withCallOptions(final Collection options) { + this.options.addAll(options); + return this; + } + + /** + * Builds a new {@link ArrowFlightSqlClientHandler} from the provided fields. + * + * @return a new client handler. + * @throws SQLException on error. + */ + public ArrowFlightSqlClientHandler build() throws SQLException { + FlightClient client = null; + try { + ClientIncomingAuthHeaderMiddleware.Factory authFactory = null; + if (username != null) { + authFactory = + new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + withMiddlewareFactories(authFactory); + } + final FlightClient.Builder clientBuilder = FlightClient.builder().allocator(allocator); + withMiddlewareFactories(new ClientCookieMiddleware.Factory()); + middlewareFactories.forEach(clientBuilder::intercept); + Location location; + if (useEncryption) { + location = Location.forGrpcTls(host, port); + clientBuilder.useTls(); + } else { + location = Location.forGrpcInsecure(host, port); + } + clientBuilder.location(location); + + if (useEncryption) { + if (disableCertificateVerification) { + clientBuilder.verifyServer(false); + } else { + if (useSystemTrustStore) { + clientBuilder.trustedCertificates( + ClientAuthenticationUtils.getCertificateInputStreamFromSystem(trustStorePassword)); + } else if (trustStorePath != null) { + clientBuilder.trustedCertificates( + ClientAuthenticationUtils.getCertificateStream(trustStorePath, trustStorePassword)); + } + } + } + + client = clientBuilder.build(); + if (authFactory != null) { + options.add( + ClientAuthenticationUtils.getAuthenticate(client, username, password, authFactory)); + } else if (token != null) { + options.add( + ClientAuthenticationUtils.getAuthenticate( + client, new CredentialCallOption(new BearerCredentialWriter(token)))); + } + return ArrowFlightSqlClientHandler.createNewHandler(client, options); + + } catch (final IllegalArgumentException | GeneralSecurityException | IOException | FlightRuntimeException e) { + final SQLException originalException = new SQLException(e); + if (client != null) { + try { + client.close(); + } catch (final InterruptedException interruptedException) { + 
originalException.addSuppressed(interruptedException);
+          }
+        }
+        throw originalException;
+      }
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtils.java
new file mode 100644
index 00000000000..854b036ae6b
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtils.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.client.utils;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.StringWriter;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.security.GeneralSecurityException;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.NoSuchAlgorithmException;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Enumeration;
+import java.util.List;
+
+import org.apache.arrow.flight.CallOption;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter;
+import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware;
+import org.apache.arrow.flight.grpc.CredentialCallOption;
+import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.util.VisibleForTesting;
+import org.bouncycastle.openssl.jcajce.JcaPEMWriter;
+
+/**
+ * Utils for {@link FlightClient} authentication.
+ */
+public final class ClientAuthenticationUtils {
+
+  private ClientAuthenticationUtils() {
+    // Prevent instantiation.
+  }
+
+  /**
+   * Gets the {@link CredentialCallOption} for the provided authentication info.
+   *
+   * @param client the client.
+   * @param credential the {@link CredentialCallOption} to authenticate with.
+   * @param options the {@link CallOption}s to use.
+   * @return the credential call option.
+   */
+  public static CredentialCallOption getAuthenticate(final FlightClient client,
+                                                     final CredentialCallOption credential,
+                                                     final CallOption... options) {
+
+    final List<CallOption> theseOptions = new ArrayList<>();
+    theseOptions.add(credential);
+    theseOptions.addAll(Arrays.asList(options));
+    client.handshake(theseOptions.toArray(new CallOption[0]));
+
+    return (CredentialCallOption) theseOptions.get(0);
+  }
+
+  /**
+   * Gets the {@link CredentialCallOption} for the provided authentication info.
+   *
+   * @param client the client.
+   * @param username the username.
+   * @param password the password.
+   * @param factory the {@link ClientIncomingAuthHeaderMiddleware.Factory} to use.
+   * @param options the {@link CallOption}s to use.
+   * @return the credential call option.
+   */
+  public static CredentialCallOption getAuthenticate(final FlightClient client,
+                                                     final String username, final String password,
+                                                     final ClientIncomingAuthHeaderMiddleware.Factory factory,
+                                                     final CallOption... options) {
+
+    return getAuthenticate(client,
+        new CredentialCallOption(new BasicAuthCredentialWriter(username, password)),
+        factory, options);
+  }
+
+  private static CredentialCallOption getAuthenticate(final FlightClient client,
+                                                      final CredentialCallOption token,
+                                                      final ClientIncomingAuthHeaderMiddleware.Factory factory,
+                                                      final CallOption... options) {
+    final List<CallOption> theseOptions = new ArrayList<>();
+    theseOptions.add(token);
+    theseOptions.addAll(Arrays.asList(options));
+    client.handshake(theseOptions.toArray(new CallOption[0]));
+    return factory.getCredentialCallOption();
+  }
+
+  @VisibleForTesting
+  static KeyStore getKeyStoreInstance(String instance)
+      throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException {
+    KeyStore keyStore = KeyStore.getInstance(instance);
+    keyStore.load(null, null);
+
+    return keyStore;
+  }
+
+  static String getOperatingSystem() {
+    return System.getProperty("os.name");
+  }
+
+  /**
+   * Checks whether the operating system running the software is Windows.
+   *
+   * @return whether the current operating system is Windows.
+   */
+  public static boolean isWindows() {
+    return getOperatingSystem().contains("Windows");
+  }
+
+  /**
+   * Checks whether the operating system running the software is Mac.
+   *
+   * @return whether the current operating system is Mac.
+   */
+  public static boolean isMac() {
+    return getOperatingSystem().contains("Mac");
+  }
+
+  /**
+   * Gets the trusted certificates based on the operating system and loads them all into an
+   * {@link InputStream}.
+   *
+   * @param password the password used to unlock the default (cacerts) keystore; not used on
+   *                 Windows and Mac.
+   * @return an input stream with all the certificates.
+   *
+   * @throws KeyStoreException if a key store could not be loaded.
+   * @throws CertificateException if a certificate could not be found.
+   * @throws IOException if it fails reading the file.
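+   * @throws NoSuchAlgorithmException if the keystore integrity algorithm is not available.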
+   */
+  public static InputStream getCertificateInputStreamFromSystem(String password) throws KeyStoreException,
+      CertificateException, IOException, NoSuchAlgorithmException {
+
+    List<KeyStore> keyStoreList = new ArrayList<>();
+    if (isWindows()) {
+      keyStoreList.add(getKeyStoreInstance("Windows-ROOT"));
+      keyStoreList.add(getKeyStoreInstance("Windows-MY"));
+    } else if (isMac()) {
+      keyStoreList.add(getKeyStoreInstance("KeychainStore"));
+    } else {
+      Path path = Paths.get(System.getProperty("java.home"), "lib", "security", "cacerts");
+      try (InputStream fileInputStream = Files.newInputStream(path)) {
+        KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType());
+        keyStore.load(fileInputStream, password.toCharArray());
+        keyStoreList.add(keyStore);
+      }
+    }
+
+    return getCertificatesInputStream(keyStoreList);
+  }
+
+  @VisibleForTesting
+  static void getCertificatesInputStream(KeyStore keyStore, JcaPEMWriter pemWriter)
+      throws IOException, KeyStoreException {
+    Enumeration<String> aliases = keyStore.aliases();
+    while (aliases.hasMoreElements()) {
+      String alias = aliases.nextElement();
+      if (keyStore.isCertificateEntry(alias)) {
+        pemWriter.writeObject(keyStore.getCertificate(alias));
+      }
+    }
+    pemWriter.flush();
+  }
+
+  @VisibleForTesting
+  static InputStream getCertificatesInputStream(Collection<KeyStore> keyStores)
+      throws IOException, KeyStoreException {
+    try (final StringWriter writer = new StringWriter();
+         final JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) {
+
+      for (KeyStore keyStore : keyStores) {
+        getCertificatesInputStream(keyStore, pemWriter);
+      }
+
+      return new ByteArrayInputStream(
+          writer.toString().getBytes(StandardCharsets.UTF_8));
+    }
+  }
+
+  /**
+   * Generates an {@link InputStream} that contains certificates for a private
+   * key.
+   *
+   * @param keyStorePath The path of the KeyStore.
+   * @param keyStorePass The password of the KeyStore.
+   * @return a new {@code InputStream} containing the certificates.
+   * @throws GeneralSecurityException on error.
+   * @throws IOException on error.
+ */ + public static InputStream getCertificateStream(final String keyStorePath, + final String keyStorePass) + throws GeneralSecurityException, IOException { + Preconditions.checkNotNull(keyStorePath, "KeyStore path cannot be null!"); + Preconditions.checkNotNull(keyStorePass, "KeyStorePass cannot be null!"); + final KeyStore keyStore = KeyStore.getInstance(KeyStore.getDefaultType()); + + try (final InputStream keyStoreStream = Files + .newInputStream(Paths.get(Preconditions.checkNotNull(keyStorePath)))) { + keyStore.load(keyStoreStream, + Preconditions.checkNotNull(keyStorePass).toCharArray()); + } + + return getSingleCertificateInputStream(keyStore); + } + + private static InputStream getSingleCertificateInputStream(KeyStore keyStore) + throws KeyStoreException, IOException, CertificateException { + final Enumeration aliases = keyStore.aliases(); + + while (aliases.hasMoreElements()) { + final String alias = aliases.nextElement(); + if (keyStore.isCertificateEntry(alias)) { + return toInputStream(keyStore.getCertificate(alias)); + } + } + + throw new CertificateException("Keystore did not have a certificate."); + } + + private static InputStream toInputStream(final Certificate certificate) + throws IOException { + + try (final StringWriter writer = new StringWriter(); + final JcaPEMWriter pemWriter = new JcaPEMWriter(writer)) { + + pemWriter.writeObject(certificate); + pemWriter.flush(); + return new ByteArrayInputStream( + writer.toString().getBytes(StandardCharsets.UTF_8)); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java new file mode 100644 index 00000000000..ac338a85d62 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImpl.java @@ -0,0 +1,292 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.ArrowFlightConnection; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightCallHeaders; +import org.apache.arrow.flight.HeaderCallOption; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.ConnectionConfig; +import org.apache.calcite.avatica.ConnectionConfigImpl; +import org.apache.calcite.avatica.ConnectionProperty; + +/** + * A {@link ConnectionConfig} for the {@link ArrowFlightConnection}. 
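+ *
+ * <p>A minimal usage sketch; the property values are illustrative:
+ * <pre>{@code
+ * Properties properties = new Properties();
+ * properties.put("host", "localhost");
+ * properties.put("port", "32010");
+ * ArrowFlightConnectionConfigImpl config = new ArrowFlightConnectionConfigImpl(properties);
+ * String host = config.getHost(); // "localhost"
+ * int port = config.getPort(); // 32010
+ * }</pre>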
+ */
+public final class ArrowFlightConnectionConfigImpl extends ConnectionConfigImpl {
+  public ArrowFlightConnectionConfigImpl(final Properties properties) {
+    super(properties);
+  }
+
+  /**
+   * Gets the host.
+   *
+   * @return the host.
+   */
+  public String getHost() {
+    return ArrowFlightConnectionProperty.HOST.getString(properties);
+  }
+
+  /**
+   * Gets the port.
+   *
+   * @return the port.
+   */
+  public int getPort() {
+    return ArrowFlightConnectionProperty.PORT.getInteger(properties);
+  }
+
+  /**
+   * Gets the user.
+   *
+   * @return the user.
+   */
+  public String getUser() {
+    return ArrowFlightConnectionProperty.USER.getString(properties);
+  }
+
+  /**
+   * Gets the password.
+   *
+   * @return the password.
+   */
+  public String getPassword() {
+    return ArrowFlightConnectionProperty.PASSWORD.getString(properties);
+  }
+
+  /**
+   * Gets the token.
+   *
+   * @return the token.
+   */
+  public String getToken() {
+    return ArrowFlightConnectionProperty.TOKEN.getString(properties);
+  }
+
+  /**
+   * Gets the trust store path.
+   *
+   * @return the path.
+   */
+  public String getTrustStorePath() {
+    return ArrowFlightConnectionProperty.TRUST_STORE.getString(properties);
+  }
+
+  /**
+   * Gets the trust store password.
+   *
+   * @return the password.
+   */
+  public String getTrustStorePassword() {
+    return ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.getString(properties);
+  }
+
+  /**
+   * Checks whether the JDBC driver should use the trusted store files from the operating system.
+   *
+   * @return whether to use system trusted store certificates.
+   */
+  public boolean useSystemTrustStore() {
+    return ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.getBoolean(properties);
+  }
+
+  /**
+   * Whether to use TLS encryption.
+   *
+   * @return whether to use TLS encryption.
+   */
+  public boolean useEncryption() {
+    return ArrowFlightConnectionProperty.USE_ENCRYPTION.getBoolean(properties);
+  }
+
+  /**
+   * Whether to disable certificate verification.
+   *
+   * @return whether certificate verification is disabled.
+   */
+  public boolean getDisableCertificateVerification() {
+    return ArrowFlightConnectionProperty.CERTIFICATE_VERIFICATION.getBoolean(properties);
+  }
+
+  /**
+   * Gets the thread pool size.
+   *
+   * @return the thread pool size.
+   */
+  public int threadPoolSize() {
+    return ArrowFlightConnectionProperty.THREAD_POOL_SIZE.getInteger(properties);
+  }
+
+  /**
+   * Gets the {@link CallOption}s from this {@link ConnectionConfig}.
+   *
+   * @return the call options.
+   */
+  public CallOption toCallOption() {
+    final CallHeaders headers = new FlightCallHeaders();
+    Map<String, String> headerAttributes = getHeaderAttributes();
+    headerAttributes.forEach(headers::insert);
+    return new HeaderCallOption(headers);
+  }
+
+  /**
+   * Gets which properties should be added as headers.
+   *
+   * @return a {@link Map} of the header name-value pairs.
+   */
+  public Map<String, String> getHeaderAttributes() {
+    Map<String, String> headers = new HashMap<>();
+    ArrowFlightConnectionProperty[] builtInProperties = ArrowFlightConnectionProperty.values();
+    properties.forEach(
+        (key, val) -> {
+          // Skip the built-in properties; only forward custom properties as headers
+          if (Arrays.stream(builtInProperties)
+              .noneMatch(builtInProperty -> builtInProperty.camelName.equalsIgnoreCase(key.toString()))) {
+            headers.put(key.toString(), val.toString());
+          }
+        });
+    return headers;
+  }
+
+  /**
+   * Custom {@link ConnectionProperty} for the {@link ArrowFlightConnectionConfigImpl}.
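+   *
+   * <p>Each constant resolves its value from a {@link Properties} object, falling back to its
+   * default when an optional property is absent. A short sketch:
+   * <pre>{@code
+   * Properties properties = new Properties();
+   * properties.put("useEncryption", "false");
+   * boolean encrypt = ArrowFlightConnectionProperty.USE_ENCRYPTION.getBoolean(properties); // false
+   * }</pre>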
+ */ + public enum ArrowFlightConnectionProperty implements ConnectionProperty { + HOST("host", null, Type.STRING, true), + PORT("port", null, Type.NUMBER, true), + USER("user", null, Type.STRING, false), + PASSWORD("password", null, Type.STRING, false), + USE_ENCRYPTION("useEncryption", true, Type.BOOLEAN, false), + CERTIFICATE_VERIFICATION("disableCertificateVerification", false, Type.BOOLEAN, false), + TRUST_STORE("trustStore", null, Type.STRING, false), + TRUST_STORE_PASSWORD("trustStorePassword", null, Type.STRING, false), + USE_SYSTEM_TRUST_STORE("useSystemTrustStore", true, Type.BOOLEAN, false), + THREAD_POOL_SIZE("threadPoolSize", 1, Type.NUMBER, false), + TOKEN("token", null, Type.STRING, false); + + private final String camelName; + private final Object defaultValue; + private final Type type; + private final boolean required; + + ArrowFlightConnectionProperty(final String camelName, final Object defaultValue, + final Type type, final boolean required) { + this.camelName = Preconditions.checkNotNull(camelName); + this.defaultValue = defaultValue; + this.type = Preconditions.checkNotNull(type); + this.required = required; + } + + /** + * Gets the property. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public Object get(final Properties properties) { + Preconditions.checkNotNull(properties, "Properties cannot be null."); + Object value = properties.get(camelName); + if (value == null) { + value = properties.get(camelName.toLowerCase()); + } + if (required) { + if (value == null) { + throw new IllegalStateException(String.format("Required property not provided: <%s>.", this)); + } + return value; + } else { + return value != null ? value : defaultValue; + } + } + + /** + * Gets the property as Boolean. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public Boolean getBoolean(final Properties properties) { + final String valueFromProperties = String.valueOf(get(properties)); + return valueFromProperties.equals("1") || valueFromProperties.equals("true"); + } + + /** + * Gets the property as Integer. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public Integer getInteger(final Properties properties) { + final String valueFromProperties = String.valueOf(get(properties)); + return valueFromProperties.equals("null") ? null : Integer.parseInt(valueFromProperties); + } + + /** + * Gets the property as String. + * + * @param properties the properties from which to fetch this property. + * @return the property. + */ + public String getString(final Properties properties) { + return Objects.toString(get(properties), null); + } + + @Override + public String camelName() { + return camelName; + } + + @Override + public Object defaultValue() { + return defaultValue; + } + + @Override + public Type type() { + return type; + } + + @Override + public PropEnv wrap(final Properties properties) { + throw new UnsupportedOperationException("Operation unsupported."); + } + + @Override + public boolean required() { + return required; + } + + @Override + public Class valueClass() { + return type.defaultValueClass(); + } + + /** + * Replaces the semicolons in the URL to the proper format. 
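+     * The first semicolon becomes {@code ?} and every following one becomes {@code &}, e.g.
+     * (an illustrative URL):
+     * {@code jdbc:arrow-flight://localhost:32010;useEncryption=false;token=abc} becomes
+     * {@code jdbc:arrow-flight://localhost:32010?useEncryption=false&token=abc}.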
+ * + * @param url the current connection string + * @return the formatted url + */ + public static String replaceSemiColons(String url) { + if (url != null) { + url = url.replaceFirst(";", "?"); + url = url.replaceAll(";", "&"); + } + return url; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java new file mode 100644 index 00000000000..5ee43ce012e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapper.java @@ -0,0 +1,344 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static com.google.common.base.Preconditions.checkNotNull; + +import java.sql.Array; +import java.sql.Blob; +import java.sql.CallableStatement; +import java.sql.Clob; +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.NClob; +import java.sql.PreparedStatement; +import java.sql.SQLClientInfoException; +import java.sql.SQLException; +import java.sql.SQLWarning; +import java.sql.SQLXML; +import java.sql.Savepoint; +import java.sql.Statement; +import java.sql.Struct; +import java.util.Map; +import java.util.Properties; +import java.util.concurrent.Executor; + +import org.apache.arrow.driver.jdbc.ArrowFlightJdbcPooledConnection; + +/** + * Auxiliary wrapper class for {@link Connection}, used on {@link ArrowFlightJdbcPooledConnection}. 
+ */
+public class ConnectionWrapper implements Connection {
+  private final Connection realConnection;
+
+  public ConnectionWrapper(final Connection connection) {
+    realConnection = checkNotNull(connection);
+  }
+
+  @Override
+  public <T> T unwrap(final Class<T> type) {
+    return type.cast(realConnection);
+  }
+
+  @Override
+  public boolean isWrapperFor(final Class<?> type) {
+    // Per the JDBC contract: true if the wrapped connection is an instance of the given type.
+    return type.isInstance(realConnection);
+  }
+
+  @Override
+  public Statement createStatement() throws SQLException {
+    return realConnection.createStatement();
+  }
+
+  @Override
+  public PreparedStatement prepareStatement(final String sqlQuery) throws SQLException {
+    return realConnection.prepareStatement(sqlQuery);
+  }
+
+  @Override
+  public CallableStatement prepareCall(final String sqlQuery) throws SQLException {
+    return realConnection.prepareCall(sqlQuery);
+  }
+
+  @Override
+  public String nativeSQL(final String sqlStatement) throws SQLException {
+    return realConnection.nativeSQL(sqlStatement);
+  }
+
+  @Override
+  public void setAutoCommit(boolean autoCommit) throws SQLException {
+    realConnection.setAutoCommit(autoCommit);
+  }
+
+  @Override
+  public boolean getAutoCommit() throws SQLException {
+    return realConnection.getAutoCommit();
+  }
+
+  @Override
+  public void commit() throws SQLException {
+    realConnection.commit();
+  }
+
+  @Override
+  public void rollback() throws SQLException {
+    realConnection.rollback();
+  }
+
+  @Override
+  public void close() throws SQLException {
+    realConnection.close();
+  }
+
+  @Override
+  public boolean isClosed() throws SQLException {
+    return realConnection.isClosed();
+  }
+
+  @Override
+  public DatabaseMetaData getMetaData() throws SQLException {
+    return realConnection.getMetaData();
+  }
+
+  @Override
+  public void setReadOnly(final boolean readOnly) throws SQLException {
+    realConnection.setReadOnly(readOnly);
+  }
+
+  @Override
+  public boolean isReadOnly() throws SQLException {
+    return realConnection.isReadOnly();
+  }
+
+  @Override
+  public void setCatalog(final String catalogName) throws SQLException {
+    realConnection.setCatalog(catalogName);
+  }
+
+  @Override
+  public String getCatalog() throws SQLException {
+    return realConnection.getCatalog();
+  }
+
+  @Override
+  public void setTransactionIsolation(final int transactionIsolationId) throws SQLException {
+    realConnection.setTransactionIsolation(transactionIsolationId);
+  }
+
+  @Override
+  public int getTransactionIsolation() throws SQLException {
+    return realConnection.getTransactionIsolation();
+  }
+
+  @Override
+  public SQLWarning getWarnings() throws SQLException {
+    return realConnection.getWarnings();
+  }
+
+  @Override
+  public void clearWarnings() throws SQLException {
+    realConnection.clearWarnings();
+  }
+
+  @Override
+  public Statement createStatement(final int resultSetTypeId, final int resultSetConcurrencyId)
+      throws SQLException {
+    return realConnection.createStatement(resultSetTypeId, resultSetConcurrencyId);
+  }
+
+  @Override
+  public PreparedStatement prepareStatement(final String sqlQuery, final int resultSetTypeId,
+                                            final int resultSetConcurrencyId)
+      throws SQLException {
+    return realConnection.prepareStatement(sqlQuery, resultSetTypeId, resultSetConcurrencyId);
+  }
+
+  @Override
+  public CallableStatement prepareCall(final String query, final int resultSetTypeId,
+                                       final int resultSetConcurrencyId)
+      throws SQLException {
+    return realConnection.prepareCall(query, resultSetTypeId, resultSetConcurrencyId);
+  }
+
+  @Override
+  public Map<String, Class<?>> getTypeMap() throws SQLException {
+    return 
realConnection.getTypeMap();
+  }
+
+  @Override
+  public void setTypeMap(final Map<String, Class<?>> typeNameToClass) throws SQLException {
+    realConnection.setTypeMap(typeNameToClass);
+  }
+
+  @Override
+  public void setHoldability(final int holdabilityId) throws SQLException {
+    realConnection.setHoldability(holdabilityId);
+  }
+
+  @Override
+  public int getHoldability() throws SQLException {
+    return realConnection.getHoldability();
+  }
+
+  @Override
+  public Savepoint setSavepoint() throws SQLException {
+    return realConnection.setSavepoint();
+  }
+
+  @Override
+  public Savepoint setSavepoint(final String savepointName) throws SQLException {
+    return realConnection.setSavepoint(savepointName);
+  }
+
+  @Override
+  public void rollback(final Savepoint savepoint) throws SQLException {
+    realConnection.rollback(savepoint);
+  }
+
+  @Override
+  public void releaseSavepoint(final Savepoint savepoint) throws SQLException {
+    realConnection.releaseSavepoint(savepoint);
+  }
+
+  @Override
+  public Statement createStatement(final int resultSetType,
+                                   final int resultSetConcurrency,
+                                   final int resultSetHoldability) throws SQLException {
+    return realConnection.createStatement(resultSetType, resultSetConcurrency,
+        resultSetHoldability);
+  }
+
+  @Override
+  public PreparedStatement prepareStatement(final String sqlQuery,
+                                            final int resultSetType,
+                                            final int resultSetConcurrency,
+                                            final int resultSetHoldability) throws SQLException {
+    return realConnection.prepareStatement(sqlQuery, resultSetType, resultSetConcurrency,
+        resultSetHoldability);
+  }
+
+  @Override
+  public CallableStatement prepareCall(final String sqlQuery,
+                                       final int resultSetType,
+                                       final int resultSetConcurrency,
+                                       final int resultSetHoldability) throws SQLException {
+    return realConnection.prepareCall(sqlQuery, resultSetType, resultSetConcurrency,
+        resultSetHoldability);
+  }
+
+  @Override
+  public PreparedStatement prepareStatement(final String sqlQuery, final int autoGeneratedKeysId)
+      throws SQLException {
+    return realConnection.prepareStatement(sqlQuery, autoGeneratedKeysId);
+  }
+
+  @Override
+  public PreparedStatement prepareStatement(final String sqlQuery, final int... columnIndices)
+      throws SQLException {
+    return realConnection.prepareStatement(sqlQuery, columnIndices);
+  }
+
+  @Override
+  public PreparedStatement prepareStatement(final String sqlQuery, final String... 
columnNames) + throws SQLException { + return realConnection.prepareStatement(sqlQuery, columnNames); + } + + @Override + public Clob createClob() throws SQLException { + return realConnection.createClob(); + } + + @Override + public Blob createBlob() throws SQLException { + return realConnection.createBlob(); + } + + @Override + public NClob createNClob() throws SQLException { + return realConnection.createNClob(); + } + + @Override + public SQLXML createSQLXML() throws SQLException { + return realConnection.createSQLXML(); + } + + @Override + public boolean isValid(final int timeout) throws SQLException { + return realConnection.isValid(timeout); + } + + @Override + public void setClientInfo(final String propertyName, final String propertyValue) + throws SQLClientInfoException { + realConnection.setClientInfo(propertyName, propertyValue); + } + + @Override + public void setClientInfo(final Properties properties) throws SQLClientInfoException { + realConnection.setClientInfo(properties); + } + + @Override + public String getClientInfo(final String propertyName) throws SQLException { + return realConnection.getClientInfo(propertyName); + } + + @Override + public Properties getClientInfo() throws SQLException { + return realConnection.getClientInfo(); + } + + @Override + public Array createArrayOf(final String typeName, final Object... elements) throws SQLException { + return realConnection.createArrayOf(typeName, elements); + } + + @Override + public Struct createStruct(final String typeName, final Object... attributes) + throws SQLException { + return realConnection.createStruct(typeName, attributes); + } + + @Override + public void setSchema(final String schemaName) throws SQLException { + realConnection.setSchema(schemaName); + } + + @Override + public String getSchema() throws SQLException { + return realConnection.getSchema(); + } + + @Override + public void abort(final Executor executor) throws SQLException { + realConnection.abort(executor); + } + + @Override + public void setNetworkTimeout(final Executor executor, final int timeoutInMillis) + throws SQLException { + realConnection.setNetworkTimeout(executor, timeoutInMillis); + } + + @Override + public int getNetworkTimeout() throws SQLException { + return realConnection.getNetworkTimeout(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java new file mode 100644 index 00000000000..324f991ef09 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import org.apache.arrow.flight.sql.FlightSqlColumnMetadata;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.calcite.avatica.ColumnMetaData;
+import org.apache.calcite.avatica.proto.Common;
+import org.apache.calcite.avatica.proto.Common.ColumnMetaData.Builder;
+
+/**
+ * Functions to convert Arrow {@link Field}s into Avatica {@link ColumnMetaData}.
+ */
+public final class ConvertUtils {
+
+  private ConvertUtils() {
+  }
+
+  /**
+   * Converts a list of Arrow {@link Field}s into a list of {@link ColumnMetaData}.
+   *
+   * @param fields list of {@link Field}.
+   * @return list of {@link ColumnMetaData}.
+   */
+  public static List<ColumnMetaData> convertArrowFieldsToColumnMetaDataList(final List<Field> fields) {
+    return Stream.iterate(0, Math::incrementExact).limit(fields.size())
+        .map(index -> {
+          final Field field = fields.get(index);
+          final ArrowType fieldType = field.getType();
+
+          final Builder builder = Common.ColumnMetaData.newBuilder()
+              .setOrdinal(index)
+              .setColumnName(field.getName())
+              .setLabel(field.getName());
+
+          setOnColumnMetaDataBuilder(builder, field.getMetadata());
+
+          builder.setType(Common.AvaticaType.newBuilder()
+              .setId(SqlTypes.getSqlTypeIdFromArrowType(fieldType))
+              .setName(SqlTypes.getSqlTypeNameFromArrowType(fieldType))
+              .build());
+
+          return ColumnMetaData.fromProto(builder.build());
+        }).collect(Collectors.toList());
+  }
+
+  /**
+   * Sets the Flight SQL column metadata from the given metadata map on the {@link Builder}.
+   *
+   * @param builder the {@link Builder} to set the metadata on.
+   * @param metadataMap the {@link Map} with the Flight SQL column metadata.
+   */
+  public static void setOnColumnMetaDataBuilder(final Builder builder,
+                                                final Map<String, String> metadataMap) {
+    final FlightSqlColumnMetadata columnMetadata = new FlightSqlColumnMetadata(metadataMap);
+    final String catalogName = columnMetadata.getCatalogName();
+    if (catalogName != null) {
+      builder.setCatalogName(catalogName);
+    }
+    final String schemaName = columnMetadata.getSchemaName();
+    if (schemaName != null) {
+      builder.setSchemaName(schemaName);
+    }
+    final String tableName = columnMetadata.getTableName();
+    if (tableName != null) {
+      builder.setTableName(tableName);
+    }
+
+    final Integer precision = columnMetadata.getPrecision();
+    if (precision != null) {
+      builder.setPrecision(precision);
+    }
+    final Integer scale = columnMetadata.getScale();
+    if (scale != null) {
+      builder.setScale(scale);
+    }
+
+    final Boolean isAutoIncrement = columnMetadata.isAutoIncrement();
+    if (isAutoIncrement != null) {
+      builder.setAutoIncrement(isAutoIncrement);
+    }
+    final Boolean caseSensitive = columnMetadata.isCaseSensitive();
+    if (caseSensitive != null) {
+      builder.setCaseSensitive(caseSensitive);
+    }
+    final Boolean readOnly = columnMetadata.isReadOnly();
+    if (readOnly != null) {
+      builder.setReadOnly(readOnly);
+    }
+    final Boolean searchable = columnMetadata.isSearchable();
+    if (searchable != null) {
+      builder.setSearchable(searchable);
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtils.java
new file mode 100644
index 00000000000..dd94a09256d
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtils.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY; + +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; + +/** + * Datetime utility functions. + */ +public class DateTimeUtils { + private DateTimeUtils() { + // Prevent instantiation. + } + + /** + * Subtracts given Calendar's TimeZone offset from epoch milliseconds. + */ + public static long applyCalendarOffset(long milliseconds, Calendar calendar) { + if (calendar == null) { + calendar = Calendar.getInstance(); + } + + final TimeZone tz = calendar.getTimeZone(); + final TimeZone defaultTz = TimeZone.getDefault(); + + if (tz != defaultTz) { + milliseconds -= tz.getOffset(milliseconds) - defaultTz.getOffset(milliseconds); + } + + return milliseconds; + } + + + /** + * Converts Epoch millis to a {@link Timestamp} object. + * + * @param millisWithCalendar the Timestamp in Epoch millis + * @return a {@link Timestamp} object representing the given Epoch millis + */ + public static Timestamp getTimestampValue(long millisWithCalendar) { + long milliseconds = millisWithCalendar; + if (milliseconds < 0) { + // LocalTime#ofNanoDay only accepts positive values + milliseconds -= ((milliseconds / MILLIS_PER_DAY) - 1) * MILLIS_PER_DAY; + } + + return Timestamp.valueOf( + LocalDateTime.of( + LocalDate.ofEpochDay(millisWithCalendar / MILLIS_PER_DAY), + LocalTime.ofNanoOfDay(TimeUnit.MILLISECONDS.toNanos(milliseconds % MILLIS_PER_DAY))) + ); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java new file mode 100644 index 00000000000..e1d770800e4 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java @@ -0,0 +1,237 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.lang.String.format; +import static java.util.Collections.synchronizedSet; +import static org.apache.arrow.util.Preconditions.checkNotNull; +import static org.apache.arrow.util.Preconditions.checkState; + +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletionService; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorCompletionService; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.calcite.avatica.AvaticaConnection; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Auxiliary class used to handle consuming of multiple {@link FlightStream}. + *

+ * <p>The usage follows this routine:
+ * <ol>
+ *   <li>Create a FlightStreamQueue;</li>
+ *   <li>Call enqueue(FlightStream) for all streams to be consumed;</li>
+ *   <li>Call next() to get a FlightStream that is ready to consume;</li>
+ *   <li>Consume the given FlightStream and add it back to the queue - call enqueue(FlightStream);</li>
+ *   <li>Repeat from (3) until next() returns null.</li>
+ * </ol>
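+ * <p>A minimal sketch of that routine; the executor and the initial streams are
+ * placeholders here:
+ * <pre>{@code
+ * try (FlightStreamQueue queue = FlightStreamQueue.createNewQueue(executorService)) {
+ *   queue.enqueue(flightStreams);
+ *   FlightStream stream;
+ *   while ((stream = queue.next()) != null) {
+ *     // consume stream.getRoot() here, then wait for the stream's next batch
+ *     queue.enqueue(stream);
+ *   }
+ * }
+ * }</pre>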
+ */ +public class FlightStreamQueue implements AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(FlightStreamQueue.class); + private final CompletionService completionService; + private final Set> futures = synchronizedSet(new HashSet<>()); + private final Set allStreams = synchronizedSet(new HashSet<>()); + private final AtomicBoolean closed = new AtomicBoolean(); + + /** + * Instantiate a new FlightStreamQueue. + */ + protected FlightStreamQueue(final CompletionService executorService) { + completionService = checkNotNull(executorService); + } + + /** + * Creates a new {@link FlightStreamQueue} from the provided {@link ExecutorService}. + * + * @param service the service from which to create a new queue. + * @return a new queue. + */ + public static FlightStreamQueue createNewQueue(final ExecutorService service) { + return new FlightStreamQueue(new ExecutorCompletionService<>(service)); + } + + /** + * Gets whether this queue is closed. + * + * @return a boolean indicating whether this resource is closed. + */ + public boolean isClosed() { + return closed.get(); + } + + /** + * Auxiliary functional interface for getting ready-to-consume FlightStreams. + */ + @FunctionalInterface + interface FlightStreamSupplier { + Future get() throws SQLException; + } + + private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throws SQLException { + checkOpen(); + while (!futures.isEmpty()) { + final Future future = flightStreamSupplier.get(); + futures.remove(future); + try { + final FlightStream stream = future.get(); + if (stream.getRoot().getRowCount() > 0) { + return stream; + } + } catch (final ExecutionException | InterruptedException | CancellationException e) { + throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); + } + } + return null; + } + + /** + * Blocking request with timeout to get the next ready FlightStream in queue. + * + * @param timeoutValue the amount of time to be waited + * @param timeoutUnit the timeoutValue time unit + * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. + */ + public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) + throws SQLException { + return next(() -> { + try { + final Future future = completionService.poll(timeoutValue, timeoutUnit); + if (future != null) { + return future; + } + } catch (final InterruptedException e) { + throw new SQLTimeoutException("Query was interrupted", e); + } + + throw new SQLTimeoutException( + String.format("Query timed out after %d %s", timeoutValue, timeoutUnit)); + }); + } + + /** + * Blocking request to get the next ready FlightStream in queue. + * + * @return a FlightStream that is ready to consume or null if all FlightStreams are ended. + */ + public FlightStream next() throws SQLException { + return next(() -> { + try { + return completionService.take(); + } catch (final InterruptedException e) { + throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); + } + }); + } + + /** + * Checks if this queue is open. + */ + public synchronized void checkOpen() { + checkState(!isClosed(), format("%s closed", this.getClass().getSimpleName())); + } + + /** + * Readily adds given {@link FlightStream}s to the queue. + */ + public void enqueue(final Collection flightStreams) { + flightStreams.forEach(this::enqueue); + } + + /** + * Adds given {@link FlightStream} to the queue. 
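+   * <p>The stream is handed to the completion service, which blocks on
+   * {@code FlightStream#next()} until the stream yields data or ends.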
+ */ + public synchronized void enqueue(final FlightStream flightStream) { + checkNotNull(flightStream); + checkOpen(); + allStreams.add(flightStream); + futures.add(completionService.submit(() -> { + // `FlightStream#next` will block until new data can be read or stream is over. + flightStream.next(); + return flightStream; + })); + } + + private static boolean isCallStatusCancelled(final Exception e) { + return e.getCause() instanceof FlightRuntimeException && + ((FlightRuntimeException) e.getCause()).status().code() == CallStatus.CANCELLED.code(); + } + + @Override + public synchronized void close() throws SQLException { + final Set exceptions = new HashSet<>(); + if (isClosed()) { + return; + } + try { + for (final FlightStream flightStream : allStreams) { + try { + flightStream.cancel("Cancelling this FlightStream.", null); + } catch (final Exception e) { + final String errorMsg = "Failed to cancel a FlightStream."; + LOGGER.error(errorMsg, e); + exceptions.add(new SQLException(errorMsg, e)); + } + } + futures.forEach(future -> { + try { + // TODO: Consider adding a hardcoded timeout? + future.get(); + } catch (final InterruptedException | ExecutionException e) { + // Ignore if future is already cancelled + if (!isCallStatusCancelled(e)) { + final String errorMsg = "Failed consuming a future during close."; + LOGGER.error(errorMsg, e); + exceptions.add(new SQLException(errorMsg, e)); + } + } + }); + for (final FlightStream flightStream : allStreams) { + try { + flightStream.close(); + } catch (final Exception e) { + final String errorMsg = "Failed to close a FlightStream."; + LOGGER.error(errorMsg, e); + exceptions.add(new SQLException(errorMsg, e)); + } + } + } finally { + allStreams.clear(); + futures.clear(); + closed.set(true); + } + if (!exceptions.isEmpty()) { + final SQLException sqlException = new SQLException("Failed to close streams."); + exceptions.forEach(sqlException::setNextException); + throw sqlException; + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/IntervalStringUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/IntervalStringUtils.java new file mode 100644 index 00000000000..05643274ac3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/IntervalStringUtils.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import org.apache.arrow.vector.util.DateUtility; +import org.joda.time.Period; + +/** + * Utility class to format periods similar to Oracle's representation + * of "INTERVAL * to *" data type. + */ +public final class IntervalStringUtils { + + /** + * Constructor Method of class. 
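+   * Private to prevent instantiation.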
+   */
+  private IntervalStringUtils() {}
+
+  /**
+   * Formats a period similar to Oracle INTERVAL YEAR TO MONTH data type.
+   * For example, the string "+21-02" defines an interval of 21 years and 2 months.
+   */
+  public static String formatIntervalYear(final Period p) {
+    long months = p.getYears() * (long) DateUtility.yearsToMonths + p.getMonths();
+    boolean neg = false;
+    if (months < 0) {
+      months = -months;
+      neg = true;
+    }
+    final int years = (int) (months / DateUtility.yearsToMonths);
+    months = months % DateUtility.yearsToMonths;
+
+    return String.format("%c%03d-%02d", neg ? '-' : '+', years, months);
+  }
+
+  /**
+   * Formats a period similar to Oracle INTERVAL DAY TO SECOND data type
. + * For example, the string "-001 18:25:16.766" defines an interval of + * - 1 day 18 hours 25 minutes 16 seconds and 766 milliseconds. + */ + public static String formatIntervalDay(final Period p) { + long millis = p.getDays() * (long) DateUtility.daysToStandardMillis + millisFromPeriod(p); + + boolean neg = false; + if (millis < 0) { + millis = -millis; + neg = true; + } + + final int days = (int) (millis / DateUtility.daysToStandardMillis); + millis = millis % DateUtility.daysToStandardMillis; + + final int hours = (int) (millis / DateUtility.hoursToMillis); + millis = millis % DateUtility.hoursToMillis; + + final int minutes = (int) (millis / DateUtility.minutesToMillis); + millis = millis % DateUtility.minutesToMillis; + + final int seconds = (int) (millis / DateUtility.secondsToMillis); + millis = millis % DateUtility.secondsToMillis; + + return String.format("%c%03d %02d:%02d:%02d.%03d", neg ? '-' : '+', days, hours, minutes, seconds, millis); + } + + public static int millisFromPeriod(Period period) { + return period.getHours() * DateUtility.hoursToMillis + period.getMinutes() * DateUtility.minutesToMillis + + period.getSeconds() * DateUtility.secondsToMillis + period.getMillis(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java new file mode 100644 index 00000000000..85c3964303c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/SqlTypes.java @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.sql.Types; +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.pojo.ArrowType; + +/** + * SQL Types utility functions. 
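+ * <p>For example, {@code getSqlTypeIdFromArrowType(new ArrowType.Int(32, true))}
+ * returns {@link java.sql.Types#INTEGER}.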
+ */
+public class SqlTypes {
+  private static final Map<Integer, String> typeIdToName = new HashMap<>();
+
+  static {
+    typeIdToName.put(Types.BIT, "BIT");
+    typeIdToName.put(Types.TINYINT, "TINYINT");
+    typeIdToName.put(Types.SMALLINT, "SMALLINT");
+    typeIdToName.put(Types.INTEGER, "INTEGER");
+    typeIdToName.put(Types.BIGINT, "BIGINT");
+    typeIdToName.put(Types.FLOAT, "FLOAT");
+    typeIdToName.put(Types.REAL, "REAL");
+    typeIdToName.put(Types.DOUBLE, "DOUBLE");
+    typeIdToName.put(Types.NUMERIC, "NUMERIC");
+    typeIdToName.put(Types.DECIMAL, "DECIMAL");
+    typeIdToName.put(Types.CHAR, "CHAR");
+    typeIdToName.put(Types.VARCHAR, "VARCHAR");
+    typeIdToName.put(Types.LONGVARCHAR, "LONGVARCHAR");
+    typeIdToName.put(Types.DATE, "DATE");
+    typeIdToName.put(Types.TIME, "TIME");
+    typeIdToName.put(Types.TIMESTAMP, "TIMESTAMP");
+    typeIdToName.put(Types.BINARY, "BINARY");
+    typeIdToName.put(Types.VARBINARY, "VARBINARY");
+    typeIdToName.put(Types.LONGVARBINARY, "LONGVARBINARY");
+    typeIdToName.put(Types.NULL, "NULL");
+    typeIdToName.put(Types.OTHER, "OTHER");
+    typeIdToName.put(Types.JAVA_OBJECT, "JAVA_OBJECT");
+    typeIdToName.put(Types.DISTINCT, "DISTINCT");
+    typeIdToName.put(Types.STRUCT, "STRUCT");
+    typeIdToName.put(Types.ARRAY, "ARRAY");
+    typeIdToName.put(Types.BLOB, "BLOB");
+    typeIdToName.put(Types.CLOB, "CLOB");
+    typeIdToName.put(Types.REF, "REF");
+    typeIdToName.put(Types.DATALINK, "DATALINK");
+    typeIdToName.put(Types.BOOLEAN, "BOOLEAN");
+    typeIdToName.put(Types.ROWID, "ROWID");
+    typeIdToName.put(Types.NCHAR, "NCHAR");
+    typeIdToName.put(Types.NVARCHAR, "NVARCHAR");
+    typeIdToName.put(Types.LONGNVARCHAR, "LONGNVARCHAR");
+    typeIdToName.put(Types.NCLOB, "NCLOB");
+    typeIdToName.put(Types.SQLXML, "SQLXML");
+    typeIdToName.put(Types.REF_CURSOR, "REF_CURSOR");
+    typeIdToName.put(Types.TIME_WITH_TIMEZONE, "TIME_WITH_TIMEZONE");
+    typeIdToName.put(Types.TIMESTAMP_WITH_TIMEZONE, "TIMESTAMP_WITH_TIMEZONE");
+  }
+
+  /**
+   * Convert given {@link ArrowType} to its corresponding SQL type name.
+   *
+   * @param arrowType type to convert from
+   * @return corresponding SQL type name.
+   * @see java.sql.Types
+   */
+  public static String getSqlTypeNameFromArrowType(ArrowType arrowType) {
+    final int typeId = getSqlTypeIdFromArrowType(arrowType);
+    return typeIdToName.get(typeId);
+  }
+
+  /**
+   * Convert given {@link ArrowType} to its corresponding SQL type ID.
+   *
+   * @param arrowType type to convert from
+   * @return corresponding SQL type ID.
+ * @see java.sql.Types + */ + public static int getSqlTypeIdFromArrowType(ArrowType arrowType) { + final ArrowType.ArrowTypeID typeID = arrowType.getTypeID(); + switch (typeID) { + case Int: + final int bitWidth = ((ArrowType.Int) arrowType).getBitWidth(); + switch (bitWidth) { + case 8: + return Types.TINYINT; + case 16: + return Types.SMALLINT; + case 32: + return Types.INTEGER; + case 64: + return Types.BIGINT; + default: + break; + } + break; + case Binary: + return Types.VARBINARY; + case FixedSizeBinary: + return Types.BINARY; + case LargeBinary: + return Types.LONGVARBINARY; + case Utf8: + return Types.VARCHAR; + case LargeUtf8: + return Types.LONGVARCHAR; + case Date: + return Types.DATE; + case Time: + return Types.TIME; + case Timestamp: + return Types.TIMESTAMP; + case Bool: + return Types.BOOLEAN; + case Decimal: + return Types.DECIMAL; + case FloatingPoint: + final FloatingPointPrecision floatingPointPrecision = + ((ArrowType.FloatingPoint) arrowType).getPrecision(); + switch (floatingPointPrecision) { + case DOUBLE: + return Types.DOUBLE; + case SINGLE: + return Types.FLOAT; + default: + break; + } + break; + case List: + case FixedSizeList: + case LargeList: + return Types.ARRAY; + case Struct: + case Duration: + case Interval: + case Map: + case Union: + return Types.JAVA_OBJECT; + case NONE: + case Null: + return Types.NULL; + default: + break; + } + + throw new IllegalArgumentException("Unsupported ArrowType " + arrowType); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java new file mode 100644 index 00000000000..e52251f5391 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/UrlParser.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.HashMap; +import java.util.Map; + +/** + * URL Parser for extracting key values from a connection string. + */ +public final class UrlParser { + private UrlParser() { + } + + /** + * Parse URL key value parameters. + * + *
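+   * <p>A minimal illustration, with hypothetical inputs:
+   * {@code parse("k1=v1;k2=v%20x", ";")} yields a map holding
+   * {@code k1 -> "v1"} and {@code k2 -> "v x"}.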
+   * <p>
URL-decodes keys and values. + * + * @param url {@link String} + * @return {@link Map} + */ + public static Map parse(String url, String separator) { + Map resultMap = new HashMap<>(); + if (url != null) { + String[] keyValues = url.split(separator); + + for (String keyValue : keyValues) { + try { + int separatorKey = keyValue.indexOf("="); // Find the first equal sign to split key and value. + if (separatorKey != -1) { // Avoid crashes when not finding an equal sign in the property value. + String key = keyValue.substring(0, separatorKey); + key = URLDecoder.decode(key, "UTF-8"); + String value = ""; + if (!keyValue.endsWith("=")) { // Avoid crashes for empty values. + value = keyValue.substring(separatorKey + 1); + } + value = URLDecoder.decode(value, "UTF-8"); + resultMap.put(key, value); + } + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + } + return resultMap; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformer.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformer.java new file mode 100644 index 00000000000..3bab918c83a --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformer.java @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Converts Arrow's {@link VectorSchemaRoot} format to one JDBC would expect. + */ +@FunctionalInterface +public interface VectorSchemaRootTransformer { + VectorSchemaRoot transform(VectorSchemaRoot originalRoot, VectorSchemaRoot transformedRoot) + throws Exception; + + /** + * Transformer's helper class; builds a new {@link VectorSchemaRoot}. 
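+   * <p>A minimal usage sketch; the schema, allocator, and field names are
+   * placeholders:
+   * <pre>{@code
+   * VectorSchemaRootTransformer transformer =
+   *     new VectorSchemaRootTransformer.Builder(schema, allocator)
+   *         .renameFieldVector("catalog_name", "TABLE_CAT")
+   *         .addEmptyField("REMARKS", Types.MinorType.VARCHAR)
+   *         .build();
+   * }</pre>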
+   */
+  class Builder {
+
+    private final Schema schema;
+    private final BufferAllocator bufferAllocator;
+    private final List<Field> newFields = new ArrayList<>();
+    private final Collection<Task> tasks = new ArrayList<>();
+
+    public Builder(final Schema schema, final BufferAllocator bufferAllocator) {
+      this.schema = schema;
+      this.bufferAllocator = bufferAllocator
+          .newChildAllocator("VectorSchemaRootTransformer", 0, bufferAllocator.getLimit());
+    }
+
+    /**
+     * Adds a task that transfers a vector to a new, renamed vector.
+     * This also adds transformedVectorName to the transformed {@link VectorSchemaRoot} schema.
+     *
+     * @param originalVectorName Name of the original vector to be transformed.
+     * @param transformedVectorName Name of the vector that is the result of the transformation.
+     * @return this Builder, with the rename task added.
+     */
+    public Builder renameFieldVector(final String originalVectorName,
+                                     final String transformedVectorName) {
+      tasks.add((originalRoot, transformedRoot) -> {
+        final FieldVector originalVector = originalRoot.getVector(originalVectorName);
+        final FieldVector transformedVector = transformedRoot.getVector(transformedVectorName);
+
+        final ArrowType originalType = originalVector.getField().getType();
+        final ArrowType transformedType = transformedVector.getField().getType();
+        if (!originalType.equals(transformedType)) {
+          throw new IllegalArgumentException(String.format(
+              "Cannot transfer vector with field type %s to %s", originalType, transformedType));
+        }
+
+        if (originalVector instanceof BaseVariableWidthVector) {
+          ((BaseVariableWidthVector) originalVector).transferTo(
+              ((BaseVariableWidthVector) transformedVector));
+        } else if (originalVector instanceof BaseFixedWidthVector) {
+          ((BaseFixedWidthVector) originalVector).transferTo(
+              ((BaseFixedWidthVector) transformedVector));
+        } else {
+          throw new IllegalStateException(String.format(
+              "Cannot transfer vector of type %s", originalVector.getClass()));
+        }
+      });
+
+      final Field originalField = schema.findField(originalVectorName);
+      newFields.add(new Field(
+          transformedVectorName,
+          new FieldType(originalField.isNullable(), originalField.getType(),
+              originalField.getDictionary(), originalField.getMetadata()),
+          originalField.getChildren())
+      );
+
+      return this;
+    }
+
+    /**
+     * Adds an empty field to the transformed {@link VectorSchemaRoot} schema.
+     *
+     * @param fieldName Name of the field to be added.
+     * @param fieldType Type of the field to be added.
+     * @return this Builder, with the empty field added.
+     */
+    public Builder addEmptyField(final String fieldName, final Types.MinorType fieldType) {
+      newFields.add(Field.nullable(fieldName, fieldType.getType()));
+
+      return this;
+    }
+
+    /**
+     * Adds an empty field to the transformed {@link VectorSchemaRoot} schema.
+     *
+     * @param fieldName Name of the field to be added.
+     * @param fieldType Type of the field to be added.
+     * @return this Builder, with the empty field added.
+     */
+    public Builder addEmptyField(final String fieldName, final ArrowType fieldType) {
+      newFields.add(Field.nullable(fieldName, fieldType));
+
+      return this;
+    }
+
+    public VectorSchemaRootTransformer build() {
+      return (originalRoot, transformedRoot) -> {
+        if (transformedRoot == null) {
+          transformedRoot = VectorSchemaRoot.create(new Schema(newFields), bufferAllocator);
+        }
+
+        for (final Task task : tasks) {
+          task.run(originalRoot, transformedRoot);
+        }
+
+        transformedRoot.setRowCount(originalRoot.getRowCount());
+
+        originalRoot.clear();
+        return transformedRoot;
+      };
+    }
+
+    /**
+     * Functional interface for a task that transforms a VectorSchemaRoot into a new
+     * VectorSchemaRoot.
+     */
+    @FunctionalInterface
+    interface Task {
+      void run(VectorSchemaRoot originalRoot, VectorSchemaRoot transformedRoot);
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/resources/META-INF/services/java.sql.Driver b/java/flight/flight-sql-jdbc-driver/src/main/resources/META-INF/services/java.sql.Driver
new file mode 100644
index 00000000000..83cfb23427f
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/main/resources/META-INF/services/java.sql.Driver
@@ -0,0 +1,15 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+org.apache.arrow.driver.jdbc.ArrowFlightJdbcDriver
\ No newline at end of file
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadataTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadataTest.java
new file mode 100644
index 00000000000..0d930f4c44e
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowDatabaseMetadataTest.java
@@ -0,0 +1,1423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc; + +import static com.google.protobuf.ByteString.copyFrom; +import static java.lang.String.format; +import static java.sql.Types.BIGINT; +import static java.sql.Types.BIT; +import static java.sql.Types.INTEGER; +import static java.sql.Types.JAVA_OBJECT; +import static java.util.Collections.singletonList; +import static java.util.stream.Collectors.toList; +import static java.util.stream.IntStream.range; +import static org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer.serializeSchema; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert.SQL_CONVERT_BIGINT_VALUE; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert.SQL_CONVERT_BIT_VALUE; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportsConvert.SQL_CONVERT_INTEGER_VALUE; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Connection; +import java.sql.DatabaseMetaData; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Consumer; + +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.driver.jdbc.utils.ResultSetTestUtils; +import org.apache.arrow.driver.jdbc.utils.ThrowableAssertionUtils; +import org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import org.apache.arrow.flight.sql.FlightSqlProducer.Schemas; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedSubqueries; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Message; + +/** + * Class containing the tests from the {@link ArrowDatabaseMetadata}. 
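+ * <p>Expected results are served by a {@link MockFlightSqlProducer} registered
+ * through the class-rule {@link FlightServerTestRule}.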
+ */ +@SuppressWarnings("DoubleBraceInitialization") +public class ArrowDatabaseMetadataTest { + public static final boolean EXPECTED_MAX_ROW_SIZE_INCLUDES_BLOBS = false; + private static final MockFlightSqlProducer FLIGHT_SQL_PRODUCER = new MockFlightSqlProducer(); + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(FLIGHT_SQL_PRODUCER); + private static final int ROW_COUNT = 10; + private static final List> EXPECTED_GET_CATALOGS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> format("catalog #%d", i)) + .map(Object.class::cast) + .map(Collections::singletonList) + .collect(toList()); + private static final List> EXPECTED_GET_TABLE_TYPES_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> format("table_type #%d", i)) + .map(Object.class::cast) + .map(Collections::singletonList) + .collect(toList()); + private static final List> EXPECTED_GET_TABLES_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("catalog_name #%d", i), + format("db_schema_name #%d", i), + format("table_name #%d", i), + format("table_type #%d", i), + // TODO Add these fields to FlightSQL, as it's currently not possible to fetch them. + null, null, null, null, null, null}) + .map(Arrays::asList) + .collect(toList()); + private static final List> EXPECTED_GET_SCHEMAS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("db_schema_name #%d", i), + format("catalog_name #%d", i)}) + .map(Arrays::asList) + .collect(toList()); + private static final List> EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("pk_catalog_name #%d", i), + format("pk_db_schema_name #%d", i), + format("pk_table_name #%d", i), + format("pk_column_name #%d", i), + format("fk_catalog_name #%d", i), + format("fk_db_schema_name #%d", i), + format("fk_table_name #%d", i), + format("fk_column_name #%d", i), + i, + format("fk_key_name #%d", i), + format("pk_key_name #%d", i), + (byte) i, + (byte) i, + // TODO Add this field to FlightSQL, as it's currently not possible to fetch it. 
+ null}) + .map(Arrays::asList) + .collect(toList()); + private static final List> EXPECTED_CROSS_REFERENCE_RESULTS = + EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS; + private static final List> EXPECTED_PRIMARY_KEYS_RESULTS = + range(0, ROW_COUNT) + .mapToObj(i -> new Object[] { + format("catalog_name #%d", i), + format("db_schema_name #%d", i), + format("table_name #%d", i), + format("column_name #%d", i), + i, + format("key_name #%d", i)}) + .map(Arrays::asList) + .collect(toList()); + private static final List FIELDS_GET_IMPORTED_EXPORTED_KEYS = ImmutableList.of( + "PKTABLE_CAT", "PKTABLE_SCHEM", "PKTABLE_NAME", + "PKCOLUMN_NAME", "FKTABLE_CAT", "FKTABLE_SCHEM", + "FKTABLE_NAME", "FKCOLUMN_NAME", "KEY_SEQ", + "FK_NAME", "PK_NAME", "UPDATE_RULE", "DELETE_RULE", + "DEFERRABILITY"); + private static final List FIELDS_GET_CROSS_REFERENCE = FIELDS_GET_IMPORTED_EXPORTED_KEYS; + private static final String TARGET_TABLE = "TARGET_TABLE"; + private static final String TARGET_FOREIGN_TABLE = "FOREIGN_TABLE"; + private static final String EXPECTED_DATABASE_PRODUCT_NAME = "Test Server Name"; + private static final String EXPECTED_DATABASE_PRODUCT_VERSION = "v0.0.1-alpha"; + private static final String EXPECTED_IDENTIFIER_QUOTE_STRING = "\""; + private static final boolean EXPECTED_IS_READ_ONLY = true; + private static final String EXPECTED_SQL_KEYWORDS = + "ADD, ADD CONSTRAINT, ALTER, ALTER TABLE, ANY, USER, TABLE"; + private static final String EXPECTED_NUMERIC_FUNCTIONS = + "ABS(), ACOS(), ASIN(), ATAN(), CEIL(), CEILING(), COT()"; + private static final String EXPECTED_STRING_FUNCTIONS = + "ASCII, CHAR, CHARINDEX, CONCAT, CONCAT_WS, FORMAT, LEFT"; + private static final String EXPECTED_SYSTEM_FUNCTIONS = + "CAST, CONVERT, CHOOSE, ISNULL, IS_NUMERIC, IIF, TRY_CAST"; + private static final String EXPECTED_TIME_DATE_FUNCTIONS = + "GETDATE(), DATEPART(), DATEADD(), DATEDIFF()"; + private static final String EXPECTED_SEARCH_STRING_ESCAPE = "\\"; + private static final String EXPECTED_EXTRA_NAME_CHARACTERS = ""; + private static final boolean EXPECTED_SUPPORTS_COLUMN_ALIASING = true; + private static final boolean EXPECTED_NULL_PLUS_NULL_IS_NULL = true; + private static final boolean EXPECTED_SQL_SUPPORTS_CONVERT = true; + private static final boolean EXPECTED_INVALID_SQL_SUPPORTS_CONVERT = false; + private static final boolean EXPECTED_SUPPORTS_TABLE_CORRELATION_NAMES = true; + private static final boolean EXPECTED_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES = false; + private static final boolean EXPECTED_EXPRESSIONS_IN_ORDER_BY = true; + private static final boolean EXPECTED_SUPPORTS_ORDER_BY_UNRELATED = true; + private static final boolean EXPECTED_SUPPORTS_GROUP_BY = true; + private static final boolean EXPECTED_SUPPORTS_GROUP_BY_UNRELATED = true; + private static final boolean EXPECTED_SUPPORTS_LIKE_ESCAPE_CLAUSE = true; + private static final boolean EXPECTED_NON_NULLABLE_COLUMNS = true; + private static final boolean EXPECTED_MINIMUM_SQL_GRAMMAR = true; + private static final boolean EXPECTED_CORE_SQL_GRAMMAR = true; + private static final boolean EXPECTED_EXTEND_SQL_GRAMMAR = false; + private static final boolean EXPECTED_ANSI92_ENTRY_LEVEL_SQL = true; + private static final boolean EXPECTED_ANSI92_INTERMEDIATE_SQL = true; + private static final boolean EXPECTED_ANSI92_FULL_SQL = false; + private static final String EXPECTED_SCHEMA_TERM = "schema"; + private static final String EXPECTED_PROCEDURE_TERM = "procedure"; + private static final String EXPECTED_CATALOG_TERM = "catalog"; + private 
static final boolean EXPECTED_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY = true; + private static final boolean EXPECTED_SUPPORTS_OUTER_JOINS = true; + private static final boolean EXPECTED_SUPPORTS_FULL_OUTER_JOINS = true; + private static final boolean EXPECTED_SUPPORTS_LIMITED_JOINS = false; + private static final boolean EXPECTED_CATALOG_AT_START = true; + private static final boolean EXPECTED_SCHEMAS_IN_PROCEDURE_CALLS = true; + private static final boolean EXPECTED_SCHEMAS_IN_INDEX_DEFINITIONS = true; + private static final boolean EXPECTED_SCHEMAS_IN_PRIVILEGE_DEFINITIONS = false; + private static final boolean EXPECTED_CATALOGS_IN_INDEX_DEFINITIONS = true; + private static final boolean EXPECTED_CATALOGS_IN_PRIVILEGE_DEFINITIONS = false; + private static final boolean EXPECTED_POSITIONED_DELETE = true; + private static final boolean EXPECTED_POSITIONED_UPDATE = false; + private static final boolean EXPECTED_TYPE_FORWARD_ONLY = true; + private static final boolean EXPECTED_TYPE_SCROLL_INSENSITIVE = true; + private static final boolean EXPECTED_TYPE_SCROLL_SENSITIVE = false; + private static final boolean EXPECTED_SELECT_FOR_UPDATE_SUPPORTED = false; + private static final boolean EXPECTED_STORED_PROCEDURES_SUPPORTED = false; + private static final boolean EXPECTED_SUBQUERIES_IN_COMPARISON = true; + private static final boolean EXPECTED_SUBQUERIES_IN_EXISTS = false; + private static final boolean EXPECTED_SUBQUERIES_IN_INS = false; + private static final boolean EXPECTED_SUBQUERIES_IN_QUANTIFIEDS = false; + private static final SqlSupportedSubqueries[] EXPECTED_SUPPORTED_SUBQUERIES = new SqlSupportedSubqueries[] + {SqlSupportedSubqueries.SQL_SUBQUERIES_IN_COMPARISONS}; + private static final boolean EXPECTED_CORRELATED_SUBQUERIES_SUPPORTED = true; + private static final boolean EXPECTED_SUPPORTS_UNION = true; + private static final boolean EXPECTED_SUPPORTS_UNION_ALL = true; + private static final int EXPECTED_MAX_BINARY_LITERAL_LENGTH = 0; + private static final int EXPECTED_MAX_CHAR_LITERAL_LENGTH = 0; + private static final int EXPECTED_MAX_COLUMN_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_COLUMNS_IN_GROUP_BY = 0; + private static final int EXPECTED_MAX_COLUMNS_IN_INDEX = 0; + private static final int EXPECTED_MAX_COLUMNS_IN_ORDER_BY = 0; + private static final int EXPECTED_MAX_COLUMNS_IN_SELECT = 0; + private static final int EXPECTED_MAX_CONNECTIONS = 0; + private static final int EXPECTED_MAX_CURSOR_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_INDEX_LENGTH = 0; + private static final int EXPECTED_SCHEMA_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_PROCEDURE_NAME_LENGTH = 0; + private static final int EXPECTED_MAX_CATALOG_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_ROW_SIZE = 0; + private static final int EXPECTED_MAX_STATEMENT_LENGTH = 0; + private static final int EXPECTED_MAX_STATEMENTS = 0; + private static final int EXPECTED_MAX_TABLE_NAME_LENGTH = 1024; + private static final int EXPECTED_MAX_TABLES_IN_SELECT = 0; + private static final int EXPECTED_MAX_USERNAME_LENGTH = 1024; + private static final int EXPECTED_DEFAULT_TRANSACTION_ISOLATION = 0; + private static final boolean EXPECTED_TRANSACTIONS_SUPPORTED = false; + private static final boolean EXPECTED_TRANSACTION_NONE = false; + private static final boolean EXPECTED_TRANSACTION_READ_UNCOMMITTED = false; + private static final boolean EXPECTED_TRANSACTION_READ_COMMITTED = true; + private static final boolean EXPECTED_TRANSACTION_REPEATABLE_READ = false; + 
private static final boolean EXPECTED_TRANSACTION_SERIALIZABLE = true; + private static final boolean EXPECTED_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT = true; + private static final boolean EXPECTED_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED = false; + private static final boolean EXPECTED_BATCH_UPDATES_SUPPORTED = true; + private static final boolean EXPECTED_SAVEPOINTS_SUPPORTED = false; + private static final boolean EXPECTED_NAMED_PARAMETERS_SUPPORTED = false; + private static final boolean EXPECTED_LOCATORS_UPDATE_COPY = true; + private static final boolean EXPECTED_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = false; + private static final List> EXPECTED_GET_COLUMNS_RESULTS; + private static Connection connection; + + static { + List expectedGetColumnsDataTypes = Arrays.asList(3, 93, 4); + List expectedGetColumnsTypeName = Arrays.asList("DECIMAL", "TIMESTAMP", "INTEGER"); + List expectedGetColumnsRadix = Arrays.asList(10, null, 10); + List expectedGetColumnsColumnSize = Arrays.asList(5, 29, 10); + List expectedGetColumnsDecimalDigits = Arrays.asList(2, 9, 0); + List expectedGetColumnsIsNullable = Arrays.asList("YES", "YES", "NO"); + EXPECTED_GET_COLUMNS_RESULTS = range(0, ROW_COUNT * 3) + .mapToObj(i -> new Object[] { + format("catalog_name #%d", i / 3), + format("db_schema_name #%d", i / 3), + format("table_name%d", i / 3), + format("column_%d", (i % 3) + 1), + expectedGetColumnsDataTypes.get(i % 3), + expectedGetColumnsTypeName.get(i % 3), + expectedGetColumnsColumnSize.get(i % 3), + null, + expectedGetColumnsDecimalDigits.get(i % 3), + expectedGetColumnsRadix.get(i % 3), + !Objects.equals(expectedGetColumnsIsNullable.get(i % 3), "NO") ? 1 : 0, + null, null, null, null, null, + (i % 3) + 1, + expectedGetColumnsIsNullable.get(i % 3), + null, null, null, null, + "", ""}) + .map(Arrays::asList) + .collect(toList()); + } + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + public final ResultSetTestUtils resultSetTestUtils = new ResultSetTestUtils(collector); + + @BeforeClass + public static void setUpBeforeClass() throws SQLException { + connection = FLIGHT_SERVER_TEST_RULE.getConnection(false); + + final Message commandGetCatalogs = CommandGetCatalogs.getDefaultInstance(); + final Consumer commandGetCatalogsResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_CATALOGS_SCHEMA, + allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + range(0, ROW_COUNT).forEach( + i -> catalogName.setSafe(i, new Text(format("catalog #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetCatalogs, commandGetCatalogsResultProducer); + + final Message commandGetTableTypes = CommandGetTableTypes.getDefaultInstance(); + final Consumer commandGetTableTypesResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TABLE_TYPES_SCHEMA, + allocator)) { + final VarCharVector tableType = (VarCharVector) root.getVector("table_type"); + range(0, ROW_COUNT).forEach( + i -> tableType.setSafe(i, new Text(format("table_type #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { 
+ listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetTableTypes, commandGetTableTypesResultProducer); + + final Message commandGetTables = CommandGetTables.getDefaultInstance(); + final Consumer commandGetTablesResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create( + Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + final VarCharVector tableName = (VarCharVector) root.getVector("table_name"); + final VarCharVector tableType = (VarCharVector) root.getVector("table_type"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .peek(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))) + .peek(i -> tableName.setSafe(i, new Text(format("table_name #%d", i)))) + .forEach(i -> tableType.setSafe(i, new Text(format("table_type #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetTables, commandGetTablesResultProducer); + + final Message commandGetTablesWithSchema = CommandGetTables.newBuilder() + .setIncludeSchema(true) + .build(); + final Consumer commandGetTablesWithSchemaResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TABLES_SCHEMA, + allocator)) { + final byte[] filledTableSchemaBytes = + copyFrom( + serializeSchema(new Schema(Arrays.asList( + Field.nullable("column_1", ArrowType.Decimal.createDecimal(5, 2, 128)), + Field.nullable("column_2", new ArrowType.Timestamp(TimeUnit.NANOSECOND, "UTC")), + Field.notNullable("column_3", Types.MinorType.INT.getType()))))) + .toByteArray(); + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + final VarCharVector tableName = (VarCharVector) root.getVector("table_name"); + final VarCharVector tableType = (VarCharVector) root.getVector("table_type"); + final VarBinaryVector tableSchema = (VarBinaryVector) root.getVector("table_schema"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .peek(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))) + .peek(i -> tableName.setSafe(i, new Text(format("table_name%d", i)))) + .peek(i -> tableType.setSafe(i, new Text(format("table_type #%d", i)))) + .forEach(i -> tableSchema.setSafe(i, filledTableSchemaBytes)); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetTablesWithSchema, + commandGetTablesWithSchemaResultProducer); + + final Message commandGetDbSchemas = CommandGetDbSchemas.getDefaultInstance(); + final Consumer commandGetSchemasResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_SCHEMAS_SCHEMA, + allocator)) { + final 
VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .forEach(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetDbSchemas, commandGetSchemasResultProducer); + + final Message commandGetExportedKeys = + CommandGetExportedKeys.newBuilder().setTable(TARGET_TABLE).build(); + final Message commandGetImportedKeys = + CommandGetImportedKeys.newBuilder().setTable(TARGET_TABLE).build(); + final Message commandGetCrossReference = CommandGetCrossReference.newBuilder() + .setPkTable(TARGET_TABLE) + .setFkTable(TARGET_FOREIGN_TABLE) + .build(); + final Consumer commandGetExportedAndImportedKeysResultProducer = + listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create( + Schemas.GET_IMPORTED_KEYS_SCHEMA, + allocator)) { + final VarCharVector pkCatalogName = (VarCharVector) root.getVector("pk_catalog_name"); + final VarCharVector pkSchemaName = (VarCharVector) root.getVector("pk_db_schema_name"); + final VarCharVector pkTableName = (VarCharVector) root.getVector("pk_table_name"); + final VarCharVector pkColumnName = (VarCharVector) root.getVector("pk_column_name"); + final VarCharVector fkCatalogName = (VarCharVector) root.getVector("fk_catalog_name"); + final VarCharVector fkSchemaName = (VarCharVector) root.getVector("fk_db_schema_name"); + final VarCharVector fkTableName = (VarCharVector) root.getVector("fk_table_name"); + final VarCharVector fkColumnName = (VarCharVector) root.getVector("fk_column_name"); + final IntVector keySequence = (IntVector) root.getVector("key_sequence"); + final VarCharVector fkKeyName = (VarCharVector) root.getVector("fk_key_name"); + final VarCharVector pkKeyName = (VarCharVector) root.getVector("pk_key_name"); + final UInt1Vector updateRule = (UInt1Vector) root.getVector("update_rule"); + final UInt1Vector deleteRule = (UInt1Vector) root.getVector("delete_rule"); + range(0, ROW_COUNT) + .peek(i -> pkCatalogName.setSafe(i, new Text(format("pk_catalog_name #%d", i)))) + .peek(i -> pkSchemaName.setSafe(i, new Text(format("pk_db_schema_name #%d", i)))) + .peek(i -> pkTableName.setSafe(i, new Text(format("pk_table_name #%d", i)))) + .peek(i -> pkColumnName.setSafe(i, new Text(format("pk_column_name #%d", i)))) + .peek(i -> fkCatalogName.setSafe(i, new Text(format("fk_catalog_name #%d", i)))) + .peek(i -> fkSchemaName.setSafe(i, new Text(format("fk_db_schema_name #%d", i)))) + .peek(i -> fkTableName.setSafe(i, new Text(format("fk_table_name #%d", i)))) + .peek(i -> fkColumnName.setSafe(i, new Text(format("fk_column_name #%d", i)))) + .peek(i -> keySequence.setSafe(i, i)) + .peek(i -> fkKeyName.setSafe(i, new Text(format("fk_key_name #%d", i)))) + .peek(i -> pkKeyName.setSafe(i, new Text(format("pk_key_name #%d", i)))) + .peek(i -> updateRule.setSafe(i, i)) + .forEach(i -> deleteRule.setSafe(i, i)); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + 
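+    // The same mock result producer serves the exported-keys, imported-keys,
+    // and cross-reference catalog queries registered below.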
FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetExportedKeys, + commandGetExportedAndImportedKeysResultProducer); + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetImportedKeys, + commandGetExportedAndImportedKeysResultProducer); + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetCrossReference, + commandGetExportedAndImportedKeysResultProducer); + + final Message commandGetPrimaryKeys = + CommandGetPrimaryKeys.newBuilder().setTable(TARGET_TABLE).build(); + final Consumer commandGetPrimaryKeysResultProducer = listener -> { + try (final BufferAllocator allocator = new RootAllocator(); + final VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_PRIMARY_KEYS_SCHEMA, + allocator)) { + final VarCharVector catalogName = (VarCharVector) root.getVector("catalog_name"); + final VarCharVector schemaName = (VarCharVector) root.getVector("db_schema_name"); + final VarCharVector tableName = (VarCharVector) root.getVector("table_name"); + final VarCharVector columnName = (VarCharVector) root.getVector("column_name"); + final IntVector keySequence = (IntVector) root.getVector("key_sequence"); + final VarCharVector keyName = (VarCharVector) root.getVector("key_name"); + range(0, ROW_COUNT) + .peek(i -> catalogName.setSafe(i, new Text(format("catalog_name #%d", i)))) + .peek(i -> schemaName.setSafe(i, new Text(format("db_schema_name #%d", i)))) + .peek(i -> tableName.setSafe(i, new Text(format("table_name #%d", i)))) + .peek(i -> columnName.setSafe(i, new Text(format("column_name #%d", i)))) + .peek(i -> keySequence.setSafe(i, i)) + .forEach(i -> keyName.setSafe(i, new Text(format("key_name #%d", i)))); + root.setRowCount(ROW_COUNT); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + }; + FLIGHT_SQL_PRODUCER.addCatalogQuery(commandGetPrimaryKeys, commandGetPrimaryKeysResultProducer); + + FLIGHT_SQL_PRODUCER.getSqlInfoBuilder() + .withSqlOuterJoinSupportLevel(FlightSql.SqlOuterJoinsSupportLevel.SQL_FULL_OUTER_JOINS) + .withFlightSqlServerName(EXPECTED_DATABASE_PRODUCT_NAME) + .withFlightSqlServerVersion(EXPECTED_DATABASE_PRODUCT_VERSION) + .withSqlIdentifierQuoteChar(EXPECTED_IDENTIFIER_QUOTE_STRING) + .withFlightSqlServerReadOnly(EXPECTED_IS_READ_ONLY) + .withSqlKeywords(EXPECTED_SQL_KEYWORDS.split("\\s*,\\s*")) + .withSqlNumericFunctions(EXPECTED_NUMERIC_FUNCTIONS.split("\\s*,\\s*")) + .withSqlStringFunctions(EXPECTED_STRING_FUNCTIONS.split("\\s*,\\s*")) + .withSqlSystemFunctions(EXPECTED_SYSTEM_FUNCTIONS.split("\\s*,\\s*")) + .withSqlDatetimeFunctions(EXPECTED_TIME_DATE_FUNCTIONS.split("\\s*,\\s*")) + .withSqlSearchStringEscape(EXPECTED_SEARCH_STRING_ESCAPE) + .withSqlExtraNameCharacters(EXPECTED_EXTRA_NAME_CHARACTERS) + .withSqlSupportsColumnAliasing(EXPECTED_SUPPORTS_COLUMN_ALIASING) + .withSqlNullPlusNullIsNull(EXPECTED_NULL_PLUS_NULL_IS_NULL) + .withSqlSupportsConvert(ImmutableMap.of(SQL_CONVERT_BIT_VALUE, + Arrays.asList(SQL_CONVERT_INTEGER_VALUE, SQL_CONVERT_BIGINT_VALUE))) + .withSqlSupportsTableCorrelationNames(EXPECTED_SUPPORTS_TABLE_CORRELATION_NAMES) + .withSqlSupportsDifferentTableCorrelationNames( + EXPECTED_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES) + .withSqlSupportsExpressionsInOrderBy(EXPECTED_EXPRESSIONS_IN_ORDER_BY) + .withSqlSupportsOrderByUnrelated(EXPECTED_SUPPORTS_ORDER_BY_UNRELATED) + .withSqlSupportedGroupBy(FlightSql.SqlSupportedGroupBy.SQL_GROUP_BY_UNRELATED) + .withSqlSupportsLikeEscapeClause(EXPECTED_SUPPORTS_LIKE_ESCAPE_CLAUSE) + 
.withSqlSupportsNonNullableColumns(EXPECTED_NON_NULLABLE_COLUMNS) + .withSqlSupportedGrammar(FlightSql.SupportedSqlGrammar.SQL_CORE_GRAMMAR, + FlightSql.SupportedSqlGrammar.SQL_MINIMUM_GRAMMAR) + .withSqlAnsi92SupportedLevel(FlightSql.SupportedAnsi92SqlGrammarLevel.ANSI92_ENTRY_SQL, + FlightSql.SupportedAnsi92SqlGrammarLevel.ANSI92_INTERMEDIATE_SQL) + .withSqlSupportsIntegrityEnhancementFacility( + EXPECTED_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY) + .withSqlSchemaTerm(EXPECTED_SCHEMA_TERM) + .withSqlCatalogTerm(EXPECTED_CATALOG_TERM) + .withSqlProcedureTerm(EXPECTED_PROCEDURE_TERM) + .withSqlCatalogAtStart(EXPECTED_CATALOG_AT_START) + .withSqlSchemasSupportedActions( + FlightSql.SqlSupportedElementActions.SQL_ELEMENT_IN_PROCEDURE_CALLS, + FlightSql.SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS) + .withSqlCatalogsSupportedActions( + FlightSql.SqlSupportedElementActions.SQL_ELEMENT_IN_INDEX_DEFINITIONS) + .withSqlSupportedPositionedCommands( + FlightSql.SqlSupportedPositionedCommands.SQL_POSITIONED_DELETE) + .withSqlSelectForUpdateSupported(EXPECTED_SELECT_FOR_UPDATE_SUPPORTED) + .withSqlStoredProceduresSupported(EXPECTED_STORED_PROCEDURES_SUPPORTED) + .withSqlSubQueriesSupported(EXPECTED_SUPPORTED_SUBQUERIES) + .withSqlCorrelatedSubqueriesSupported(EXPECTED_CORRELATED_SUBQUERIES_SUPPORTED) + .withSqlSupportedUnions(FlightSql.SqlSupportedUnions.SQL_UNION_ALL) + .withSqlMaxBinaryLiteralLength(EXPECTED_MAX_BINARY_LITERAL_LENGTH) + .withSqlMaxCharLiteralLength(EXPECTED_MAX_CHAR_LITERAL_LENGTH) + .withSqlMaxColumnNameLength(EXPECTED_MAX_COLUMN_NAME_LENGTH) + .withSqlMaxColumnsInGroupBy(EXPECTED_MAX_COLUMNS_IN_GROUP_BY) + .withSqlMaxColumnsInIndex(EXPECTED_MAX_COLUMNS_IN_INDEX) + .withSqlMaxColumnsInOrderBy(EXPECTED_MAX_COLUMNS_IN_ORDER_BY) + .withSqlMaxColumnsInSelect(EXPECTED_MAX_COLUMNS_IN_SELECT) + .withSqlMaxConnections(EXPECTED_MAX_CONNECTIONS) + .withSqlMaxCursorNameLength(EXPECTED_MAX_CURSOR_NAME_LENGTH) + .withSqlMaxIndexLength(EXPECTED_MAX_INDEX_LENGTH) + .withSqlDbSchemaNameLength(EXPECTED_SCHEMA_NAME_LENGTH) + .withSqlMaxProcedureNameLength(EXPECTED_MAX_PROCEDURE_NAME_LENGTH) + .withSqlMaxCatalogNameLength(EXPECTED_MAX_CATALOG_NAME_LENGTH) + .withSqlMaxRowSize(EXPECTED_MAX_ROW_SIZE) + .withSqlMaxRowSizeIncludesBlobs(EXPECTED_MAX_ROW_SIZE_INCLUDES_BLOBS) + .withSqlMaxStatementLength(EXPECTED_MAX_STATEMENT_LENGTH) + .withSqlMaxStatements(EXPECTED_MAX_STATEMENTS) + .withSqlMaxTableNameLength(EXPECTED_MAX_TABLE_NAME_LENGTH) + .withSqlMaxTablesInSelect(EXPECTED_MAX_TABLES_IN_SELECT) + .withSqlMaxUsernameLength(EXPECTED_MAX_USERNAME_LENGTH) + .withSqlDefaultTransactionIsolation(EXPECTED_DEFAULT_TRANSACTION_ISOLATION) + .withSqlTransactionsSupported(EXPECTED_TRANSACTIONS_SUPPORTED) + .withSqlSupportedTransactionsIsolationLevels( + FlightSql.SqlTransactionIsolationLevel.SQL_TRANSACTION_SERIALIZABLE, + FlightSql.SqlTransactionIsolationLevel.SQL_TRANSACTION_READ_COMMITTED) + .withSqlDataDefinitionCausesTransactionCommit( + EXPECTED_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT) + .withSqlDataDefinitionsInTransactionsIgnored( + EXPECTED_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED) + .withSqlSupportedResultSetTypes( + FlightSql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_FORWARD_ONLY, + FlightSql.SqlSupportedResultSetType.SQL_RESULT_SET_TYPE_SCROLL_INSENSITIVE) + .withSqlBatchUpdatesSupported(EXPECTED_BATCH_UPDATES_SUPPORTED) + .withSqlSavepointsSupported(EXPECTED_SAVEPOINTS_SUPPORTED) + .withSqlNamedParametersSupported(EXPECTED_NAMED_PARAMETERS_SUPPORTED) + 
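+        // The final entries describe LOB locator semantics and whether stored
+        // functions can be invoked with the JDBC CALL escape syntax.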
.withSqlLocatorsUpdateCopy(EXPECTED_LOCATORS_UPDATE_COPY) + .withSqlStoredFunctionsUsingCallSyntaxSupported( + EXPECTED_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED); + } + + @AfterClass + public static void tearDown() throws Exception { + AutoCloseables.close(connection, FLIGHT_SQL_PRODUCER); + } + + + @Test + public void testGetCatalogsCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getCatalogs()) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_CATALOGS_RESULTS); + } + } + + @Test + public void testGetCatalogsCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getCatalogs()) { + resultSetTestUtils.testData(resultSet, singletonList("TABLE_CAT"), + EXPECTED_GET_CATALOGS_RESULTS); + } + } + + @Test + public void testTableTypesCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTableTypes()) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_TABLE_TYPES_RESULTS); + } + } + + @Test + public void testTableTypesCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTableTypes()) { + resultSetTestUtils.testData(resultSet, singletonList("TABLE_TYPE"), + EXPECTED_GET_TABLE_TYPES_RESULTS); + } + } + + @Test + public void testGetTablesCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTables(null, null, null, null)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_TABLES_RESULTS); + } + } + + @Test + public void testGetTablesCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getTables(null, null, null, null)) { + resultSetTestUtils.testData( + resultSet, + ImmutableList.of( + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "TABLE_TYPE", + "REMARKS", + "TYPE_CAT", + "TYPE_SCHEM", + "TYPE_NAME", + "SELF_REFERENCING_COL_NAME", + "REF_GENERATION"), + EXPECTED_GET_TABLES_RESULTS + ); + } + } + + @Test + public void testGetSchemasCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getSchemas()) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_SCHEMAS_RESULTS); + } + } + + @Test + public void testGetSchemasCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getSchemas()) { + resultSetTestUtils.testData(resultSet, ImmutableList.of("TABLE_SCHEM", "TABLE_CATALOG"), + EXPECTED_GET_SCHEMAS_RESULTS); + } + } + + @Test + public void testGetExportedKeysCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getExportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void testGetExportedKeysCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getExportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData( + resultSet, FIELDS_GET_IMPORTED_EXPORTED_KEYS, + EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void testGetImportedKeysCanBeAccessedByIndices() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData() + .getImportedKeys(null, null, TARGET_TABLE)) { + resultSetTestUtils.testData(resultSet, EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS); + } + } + + @Test + public void 
testGetImportedKeysCanBeAccessedByNames() throws SQLException {
+    try (final ResultSet resultSet = connection.getMetaData()
+        .getImportedKeys(null, null, TARGET_TABLE)) {
+      resultSetTestUtils.testData(
+          resultSet, FIELDS_GET_IMPORTED_EXPORTED_KEYS,
+          EXPECTED_GET_EXPORTED_AND_IMPORTED_KEYS_RESULTS);
+    }
+  }
+
+  @Test
+  public void testGetCrossReferenceCanBeAccessedByIndices() throws SQLException {
+    try (final ResultSet resultSet = connection.getMetaData().getCrossReference(null, null,
+        TARGET_TABLE, null, null, TARGET_FOREIGN_TABLE)) {
+      resultSetTestUtils.testData(resultSet, EXPECTED_CROSS_REFERENCE_RESULTS);
+    }
+  }
+
+  @Test
+  public void testGetCrossReferenceCanBeAccessedByNames() throws SQLException {
+    try (final ResultSet resultSet = connection.getMetaData().getCrossReference(null, null,
+        TARGET_TABLE, null, null, TARGET_FOREIGN_TABLE)) {
+      resultSetTestUtils.testData(
+          resultSet, FIELDS_GET_CROSS_REFERENCE, EXPECTED_CROSS_REFERENCE_RESULTS);
+    }
+  }
+
+  @Test
+  public void testPrimaryKeysCanBeAccessedByIndices() throws SQLException {
+    try (final ResultSet resultSet = connection.getMetaData()
+        .getPrimaryKeys(null, null, TARGET_TABLE)) {
+      resultSetTestUtils.testData(resultSet, EXPECTED_PRIMARY_KEYS_RESULTS);
+    }
+  }
+
+  @Test
+  public void testPrimaryKeysCanBeAccessedByNames() throws SQLException {
+    try (final ResultSet resultSet = connection.getMetaData()
+        .getPrimaryKeys(null, null, TARGET_TABLE)) {
+      resultSetTestUtils.testData(
+          resultSet,
+          ImmutableList.of(
+              "TABLE_CAT",
+              "TABLE_SCHEM",
+              "TABLE_NAME",
+              "COLUMN_NAME",
+              "KEY_SEQ",
+              "PK_NAME"),
+          EXPECTED_PRIMARY_KEYS_RESULTS
+      );
+    }
+  }
+
+  @Test
+  public void testGetColumnsCanBeAccessedByIndices() throws SQLException {
+    try (final ResultSet resultSet = connection.getMetaData().getColumns(null, null, null, null)) {
+      resultSetTestUtils.testData(resultSet, EXPECTED_GET_COLUMNS_RESULTS);
+    }
+  }
+
+  @Test
+  public void testGetColumnsCanBeAccessedByIndicesFilteringColumnNames() throws SQLException {
+    try (
+        final ResultSet resultSet = connection.getMetaData()
+            .getColumns(null, null, null, "column_1")) {
+      resultSetTestUtils.testData(resultSet, EXPECTED_GET_COLUMNS_RESULTS
+          .stream()
+          .filter(insideList -> Objects.equals(insideList.get(3), "column_1"))
+          .collect(toList())
+      );
+    }
+  }
+
+  @Test
+  public void testGetSqlInfo() throws SQLException {
+    final DatabaseMetaData metaData = connection.getMetaData();
+    collector.checkThat(metaData.getDatabaseProductName(), is(EXPECTED_DATABASE_PRODUCT_NAME));
+    collector.checkThat(metaData.getDatabaseProductVersion(),
+        is(EXPECTED_DATABASE_PRODUCT_VERSION));
+    collector.checkThat(metaData.getIdentifierQuoteString(), is(EXPECTED_IDENTIFIER_QUOTE_STRING));
+    collector.checkThat(metaData.isReadOnly(), is(EXPECTED_IS_READ_ONLY));
+    collector.checkThat(metaData.getSQLKeywords(), is(EXPECTED_SQL_KEYWORDS));
+    collector.checkThat(metaData.getNumericFunctions(), is(EXPECTED_NUMERIC_FUNCTIONS));
+    collector.checkThat(metaData.getStringFunctions(), is(EXPECTED_STRING_FUNCTIONS));
+    collector.checkThat(metaData.getSystemFunctions(), is(EXPECTED_SYSTEM_FUNCTIONS));
+    collector.checkThat(metaData.getTimeDateFunctions(), is(EXPECTED_TIME_DATE_FUNCTIONS));
+    collector.checkThat(metaData.getSearchStringEscape(), is(EXPECTED_SEARCH_STRING_ESCAPE));
+    collector.checkThat(metaData.getExtraNameCharacters(), is(EXPECTED_EXTRA_NAME_CHARACTERS));
+    collector.checkThat(metaData.supportsConvert(), is(EXPECTED_SQL_SUPPORTS_CONVERT));
+    collector.checkThat(metaData.supportsConvert(BIT, INTEGER),
is(EXPECTED_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(BIT, BIGINT), is(EXPECTED_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(BIGINT, INTEGER), + is(EXPECTED_INVALID_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsConvert(JAVA_OBJECT, INTEGER), + is(EXPECTED_INVALID_SQL_SUPPORTS_CONVERT)); + collector.checkThat(metaData.supportsTableCorrelationNames(), + is(EXPECTED_SUPPORTS_TABLE_CORRELATION_NAMES)); + collector.checkThat(metaData.supportsExpressionsInOrderBy(), + is(EXPECTED_EXPRESSIONS_IN_ORDER_BY)); + collector.checkThat(metaData.supportsOrderByUnrelated(), + is(EXPECTED_SUPPORTS_ORDER_BY_UNRELATED)); + collector.checkThat(metaData.supportsGroupBy(), is(EXPECTED_SUPPORTS_GROUP_BY)); + collector.checkThat(metaData.supportsGroupByUnrelated(), + is(EXPECTED_SUPPORTS_GROUP_BY_UNRELATED)); + collector.checkThat(metaData.supportsLikeEscapeClause(), + is(EXPECTED_SUPPORTS_LIKE_ESCAPE_CLAUSE)); + collector.checkThat(metaData.supportsNonNullableColumns(), is(EXPECTED_NON_NULLABLE_COLUMNS)); + collector.checkThat(metaData.supportsMinimumSQLGrammar(), is(EXPECTED_MINIMUM_SQL_GRAMMAR)); + collector.checkThat(metaData.supportsCoreSQLGrammar(), is(EXPECTED_CORE_SQL_GRAMMAR)); + collector.checkThat(metaData.supportsExtendedSQLGrammar(), is(EXPECTED_EXTEND_SQL_GRAMMAR)); + collector.checkThat(metaData.supportsANSI92EntryLevelSQL(), + is(EXPECTED_ANSI92_ENTRY_LEVEL_SQL)); + collector.checkThat(metaData.supportsANSI92IntermediateSQL(), + is(EXPECTED_ANSI92_INTERMEDIATE_SQL)); + collector.checkThat(metaData.supportsANSI92FullSQL(), is(EXPECTED_ANSI92_FULL_SQL)); + collector.checkThat(metaData.supportsOuterJoins(), is(EXPECTED_SUPPORTS_OUTER_JOINS)); + collector.checkThat(metaData.supportsFullOuterJoins(), is(EXPECTED_SUPPORTS_FULL_OUTER_JOINS)); + collector.checkThat(metaData.supportsLimitedOuterJoins(), is(EXPECTED_SUPPORTS_LIMITED_JOINS)); + collector.checkThat(metaData.getSchemaTerm(), is(EXPECTED_SCHEMA_TERM)); + collector.checkThat(metaData.getProcedureTerm(), is(EXPECTED_PROCEDURE_TERM)); + collector.checkThat(metaData.getCatalogTerm(), is(EXPECTED_CATALOG_TERM)); + collector.checkThat(metaData.isCatalogAtStart(), is(EXPECTED_CATALOG_AT_START)); + collector.checkThat(metaData.supportsSchemasInProcedureCalls(), + is(EXPECTED_SCHEMAS_IN_PROCEDURE_CALLS)); + collector.checkThat(metaData.supportsSchemasInIndexDefinitions(), + is(EXPECTED_SCHEMAS_IN_INDEX_DEFINITIONS)); + collector.checkThat(metaData.supportsCatalogsInIndexDefinitions(), + is(EXPECTED_CATALOGS_IN_INDEX_DEFINITIONS)); + collector.checkThat(metaData.supportsPositionedDelete(), is(EXPECTED_POSITIONED_DELETE)); + collector.checkThat(metaData.supportsPositionedUpdate(), is(EXPECTED_POSITIONED_UPDATE)); + collector.checkThat(metaData.supportsResultSetType(ResultSet.TYPE_FORWARD_ONLY), + is(EXPECTED_TYPE_FORWARD_ONLY)); + collector.checkThat(metaData.supportsSelectForUpdate(), + is(EXPECTED_SELECT_FOR_UPDATE_SUPPORTED)); + collector.checkThat(metaData.supportsStoredProcedures(), + is(EXPECTED_STORED_PROCEDURES_SUPPORTED)); + collector.checkThat(metaData.supportsSubqueriesInComparisons(), + is(EXPECTED_SUBQUERIES_IN_COMPARISON)); + collector.checkThat(metaData.supportsSubqueriesInExists(), is(EXPECTED_SUBQUERIES_IN_EXISTS)); + collector.checkThat(metaData.supportsSubqueriesInIns(), is(EXPECTED_SUBQUERIES_IN_INS)); + collector.checkThat(metaData.supportsSubqueriesInQuantifieds(), + is(EXPECTED_SUBQUERIES_IN_QUANTIFIEDS)); + 
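+    // The remaining assertions cover correlated subqueries, unions, and the
+    // size limits the driver reports from the server's SqlInfo values.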
collector.checkThat(metaData.supportsCorrelatedSubqueries(), + is(EXPECTED_CORRELATED_SUBQUERIES_SUPPORTED)); + collector.checkThat(metaData.supportsUnion(), is(EXPECTED_SUPPORTS_UNION)); + collector.checkThat(metaData.supportsUnionAll(), is(EXPECTED_SUPPORTS_UNION_ALL)); + collector.checkThat(metaData.getMaxBinaryLiteralLength(), + is(EXPECTED_MAX_BINARY_LITERAL_LENGTH)); + collector.checkThat(metaData.getMaxCharLiteralLength(), is(EXPECTED_MAX_CHAR_LITERAL_LENGTH)); + collector.checkThat(metaData.getMaxColumnsInGroupBy(), is(EXPECTED_MAX_COLUMNS_IN_GROUP_BY)); + collector.checkThat(metaData.getMaxColumnsInIndex(), is(EXPECTED_MAX_COLUMNS_IN_INDEX)); + collector.checkThat(metaData.getMaxColumnsInOrderBy(), is(EXPECTED_MAX_COLUMNS_IN_ORDER_BY)); + collector.checkThat(metaData.getMaxColumnsInSelect(), is(EXPECTED_MAX_COLUMNS_IN_SELECT)); + collector.checkThat(metaData.getMaxConnections(), is(EXPECTED_MAX_CONNECTIONS)); + collector.checkThat(metaData.getMaxCursorNameLength(), is(EXPECTED_MAX_CURSOR_NAME_LENGTH)); + collector.checkThat(metaData.getMaxIndexLength(), is(EXPECTED_MAX_INDEX_LENGTH)); + collector.checkThat(metaData.getMaxSchemaNameLength(), is(EXPECTED_SCHEMA_NAME_LENGTH)); + collector.checkThat(metaData.getMaxProcedureNameLength(), + is(EXPECTED_MAX_PROCEDURE_NAME_LENGTH)); + collector.checkThat(metaData.getMaxCatalogNameLength(), is(EXPECTED_MAX_CATALOG_NAME_LENGTH)); + collector.checkThat(metaData.getMaxRowSize(), is(EXPECTED_MAX_ROW_SIZE)); + collector.checkThat(metaData.doesMaxRowSizeIncludeBlobs(), + is(EXPECTED_MAX_ROW_SIZE_INCLUDES_BLOBS)); + collector.checkThat(metaData.getMaxStatementLength(), is(EXPECTED_MAX_STATEMENT_LENGTH)); + collector.checkThat(metaData.getMaxStatements(), is(EXPECTED_MAX_STATEMENTS)); + collector.checkThat(metaData.getMaxTableNameLength(), is(EXPECTED_MAX_TABLE_NAME_LENGTH)); + collector.checkThat(metaData.getMaxTablesInSelect(), is(EXPECTED_MAX_TABLES_IN_SELECT)); + collector.checkThat(metaData.getMaxUserNameLength(), is(EXPECTED_MAX_USERNAME_LENGTH)); + collector.checkThat(metaData.getDefaultTransactionIsolation(), + is(EXPECTED_DEFAULT_TRANSACTION_ISOLATION)); + collector.checkThat(metaData.supportsTransactions(), is(EXPECTED_TRANSACTIONS_SUPPORTED)); + collector.checkThat(metaData.supportsBatchUpdates(), is(EXPECTED_BATCH_UPDATES_SUPPORTED)); + collector.checkThat(metaData.supportsSavepoints(), is(EXPECTED_SAVEPOINTS_SUPPORTED)); + collector.checkThat(metaData.supportsNamedParameters(), + is(EXPECTED_NAMED_PARAMETERS_SUPPORTED)); + collector.checkThat(metaData.locatorsUpdateCopy(), is(EXPECTED_LOCATORS_UPDATE_COPY)); + + collector.checkThat(metaData.supportsResultSetType(ResultSet.TYPE_SCROLL_INSENSITIVE), + is(EXPECTED_TYPE_SCROLL_INSENSITIVE)); + collector.checkThat(metaData.supportsResultSetType(ResultSet.TYPE_SCROLL_SENSITIVE), + is(EXPECTED_TYPE_SCROLL_SENSITIVE)); + collector.checkThat(metaData.supportsSchemasInPrivilegeDefinitions(), + is(EXPECTED_SCHEMAS_IN_PRIVILEGE_DEFINITIONS)); + collector.checkThat(metaData.supportsCatalogsInPrivilegeDefinitions(), + is(EXPECTED_CATALOGS_IN_PRIVILEGE_DEFINITIONS)); + collector.checkThat(metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_NONE), + is(EXPECTED_TRANSACTION_NONE)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_COMMITTED), + is(EXPECTED_TRANSACTION_READ_COMMITTED)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_READ_UNCOMMITTED), + is(EXPECTED_TRANSACTION_READ_UNCOMMITTED)); + 
collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_REPEATABLE_READ), + is(EXPECTED_TRANSACTION_REPEATABLE_READ)); + collector.checkThat( + metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_SERIALIZABLE), + is(EXPECTED_TRANSACTION_SERIALIZABLE)); + collector.checkThat(metaData.dataDefinitionCausesTransactionCommit(), + is(EXPECTED_DATA_DEFINITION_CAUSES_TRANSACTION_COMMIT)); + collector.checkThat(metaData.dataDefinitionIgnoredInTransactions(), + is(EXPECTED_DATA_DEFINITIONS_IN_TRANSACTIONS_IGNORED)); + collector.checkThat(metaData.supportsStoredFunctionsUsingCallSyntax(), + is(EXPECTED_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED)); + collector.checkThat(metaData.supportsIntegrityEnhancementFacility(), + is(EXPECTED_SUPPORTS_INTEGRITY_ENHANCEMENT_FACILITY)); + collector.checkThat(metaData.supportsDifferentTableCorrelationNames(), + is(EXPECTED_SUPPORTS_DIFFERENT_TABLE_CORRELATION_NAMES)); + + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, + () -> metaData.supportsTransactionIsolationLevel(Connection.TRANSACTION_SERIALIZABLE + 1)); + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, + () -> metaData.supportsResultSetType(ResultSet.HOLD_CURSORS_OVER_COMMIT)); + } + + @Test + public void testGetColumnsCanBeAccessedByNames() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getColumns(null, null, null, null)) { + resultSetTestUtils.testData(resultSet, + ImmutableList.of( + "TABLE_CAT", + "TABLE_SCHEM", + "TABLE_NAME", + "COLUMN_NAME", + "DATA_TYPE", + "TYPE_NAME", + "COLUMN_SIZE", + "BUFFER_LENGTH", + "DECIMAL_DIGITS", + "NUM_PREC_RADIX", + "NULLABLE", + "REMARKS", + "COLUMN_DEF", + "SQL_DATA_TYPE", + "SQL_DATETIME_SUB", + "CHAR_OCTET_LENGTH", + "ORDINAL_POSITION", + "IS_NULLABLE", + "SCOPE_CATALOG", + "SCOPE_SCHEMA", + "SCOPE_TABLE", + "SOURCE_DATA_TYPE", + "IS_AUTOINCREMENT", + "IS_GENERATEDCOLUMN"), + EXPECTED_GET_COLUMNS_RESULTS + ); + } + } + + @Test + public void testGetProcedures() throws SQLException { + try (final ResultSet resultSet = connection.getMetaData().getProcedures(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetProceduresSchema = new HashMap() { + { + put(1, "PROCEDURE_CAT"); + put(2, "PROCEDURE_SCHEM"); + put(3, "PROCEDURE_NAME"); + put(4, "FUTURE_USE1"); + put(5, "FUTURE_USE2"); + put(6, "FUTURE_USE3"); + put(7, "REMARKS"); + put(8, "PROCEDURE_TYPE"); + put(9, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetProceduresSchema); + } + } + + @Test + public void testGetProcedureColumns() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getProcedureColumns(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetProcedureColumnsSchema = + new HashMap() { + { + put(1, "PROCEDURE_CAT"); + put(2, "PROCEDURE_SCHEM"); + put(3, "PROCEDURE_NAME"); + put(4, "COLUMN_NAME"); + put(5, "COLUMN_TYPE"); + put(6, "DATA_TYPE"); + put(7, "TYPE_NAME"); + put(8, "PRECISION"); + put(9, "LENGTH"); + put(10, "SCALE"); + put(11, "RADIX"); + put(12, "NULLABLE"); + put(13, "REMARKS"); + put(14, "COLUMN_DEF"); + put(15, "SQL_DATA_TYPE"); + put(16, "SQL_DATETIME_SUB"); + put(17, "CHAR_OCTET_LENGTH"); + put(18, "ORDINAL_POSITION"); + put(19, "IS_NULLABLE"); + put(20, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetProcedureColumnsSchema); + } + } + + @Test + public void 
testGetColumnPrivileges() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getColumnPrivileges(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetColumnPrivilegesSchema = + new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "COLUMN_NAME"); + put(5, "GRANTOR"); + put(6, "GRANTEE"); + put(7, "PRIVILEGE"); + put(8, "IS_GRANTABLE"); + } + }; + testEmptyResultSet(resultSet, expectedGetColumnPrivilegesSchema); + } + } + + @Test + public void testGetTablePrivileges() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getTablePrivileges(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetTablePrivilegesSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "GRANTOR"); + put(5, "GRANTEE"); + put(6, "PRIVILEGE"); + put(7, "IS_GRANTABLE"); + } + }; + testEmptyResultSet(resultSet, expectedGetTablePrivilegesSchema); + } + } + + @Test + public void testGetBestRowIdentifier() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getBestRowIdentifier(null, null, null, 0, true)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetBestRowIdentifierSchema = + new HashMap() { + { + put(1, "SCOPE"); + put(2, "COLUMN_NAME"); + put(3, "DATA_TYPE"); + put(4, "TYPE_NAME"); + put(5, "COLUMN_SIZE"); + put(6, "BUFFER_LENGTH"); + put(7, "DECIMAL_DIGITS"); + put(8, "PSEUDO_COLUMN"); + } + }; + testEmptyResultSet(resultSet, expectedGetBestRowIdentifierSchema); + } + } + + @Test + public void testGetVersionColumns() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getVersionColumns(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetVersionColumnsSchema = new HashMap() { + { + put(1, "SCOPE"); + put(2, "COLUMN_NAME"); + put(3, "DATA_TYPE"); + put(4, "TYPE_NAME"); + put(5, "COLUMN_SIZE"); + put(6, "BUFFER_LENGTH"); + put(7, "DECIMAL_DIGITS"); + put(8, "PSEUDO_COLUMN"); + } + }; + testEmptyResultSet(resultSet, expectedGetVersionColumnsSchema); + } + } + + @Test + public void testGetTypeInfo() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getTypeInfo()) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetTypeInfoSchema = new HashMap() { + { + put(1, "TYPE_NAME"); + put(2, "DATA_TYPE"); + put(3, "PRECISION"); + put(4, "LITERAL_PREFIX"); + put(5, "LITERAL_SUFFIX"); + put(6, "CREATE_PARAMS"); + put(7, "NULLABLE"); + put(8, "CASE_SENSITIVE"); + put(9, "SEARCHABLE"); + put(10, "UNSIGNED_ATTRIBUTE"); + put(11, "FIXED_PREC_SCALE"); + put(12, "AUTO_INCREMENT"); + put(13, "LOCAL_TYPE_NAME"); + put(14, "MINIMUM_SCALE"); + put(15, "MAXIMUM_SCALE"); + put(16, "SQL_DATA_TYPE"); + put(17, "SQL_DATETIME_SUB"); + put(18, "NUM_PREC_RADIX"); + } + }; + testEmptyResultSet(resultSet, expectedGetTypeInfoSchema); + } + } + + @Test + public void testGetIndexInfo() throws SQLException { + try (ResultSet resultSet = connection.getMetaData() + .getIndexInfo(null, null, null, false, true)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetIndexInfoSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "NON_UNIQUE"); + put(5, 
"INDEX_QUALIFIER"); + put(6, "INDEX_NAME"); + put(7, "TYPE"); + put(8, "ORDINAL_POSITION"); + put(9, "COLUMN_NAME"); + put(10, "ASC_OR_DESC"); + put(11, "CARDINALITY"); + put(12, "PAGES"); + put(13, "FILTER_CONDITION"); + } + }; + testEmptyResultSet(resultSet, expectedGetIndexInfoSchema); + } + } + + @Test + public void testGetUDTs() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getUDTs(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetUDTsSchema = new HashMap() { + { + put(1, "TYPE_CAT"); + put(2, "TYPE_SCHEM"); + put(3, "TYPE_NAME"); + put(4, "CLASS_NAME"); + put(5, "DATA_TYPE"); + put(6, "REMARKS"); + put(7, "BASE_TYPE"); + } + }; + testEmptyResultSet(resultSet, expectedGetUDTsSchema); + } + } + + @Test + public void testGetSuperTypes() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getSuperTypes(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetSuperTypesSchema = new HashMap() { + { + put(1, "TYPE_CAT"); + put(2, "TYPE_SCHEM"); + put(3, "TYPE_NAME"); + put(4, "SUPERTYPE_CAT"); + put(5, "SUPERTYPE_SCHEM"); + put(6, "SUPERTYPE_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetSuperTypesSchema); + } + } + + @Test + public void testGetSuperTables() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getSuperTables(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetSuperTablesSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "SUPERTABLE_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetSuperTablesSchema); + } + } + + @Test + public void testGetAttributes() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getAttributes(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetAttributesSchema = new HashMap() { + { + put(1, "TYPE_CAT"); + put(2, "TYPE_SCHEM"); + put(3, "TYPE_NAME"); + put(4, "ATTR_NAME"); + put(5, "DATA_TYPE"); + put(6, "ATTR_TYPE_NAME"); + put(7, "ATTR_SIZE"); + put(8, "DECIMAL_DIGITS"); + put(9, "NUM_PREC_RADIX"); + put(10, "NULLABLE"); + put(11, "REMARKS"); + put(12, "ATTR_DEF"); + put(13, "SQL_DATA_TYPE"); + put(14, "SQL_DATETIME_SUB"); + put(15, "CHAR_OCTET_LENGTH"); + put(16, "ORDINAL_POSITION"); + put(17, "IS_NULLABLE"); + put(18, "SCOPE_CATALOG"); + put(19, "SCOPE_SCHEMA"); + put(20, "SCOPE_TABLE"); + put(21, "SOURCE_DATA_TYPE"); + } + }; + testEmptyResultSet(resultSet, expectedGetAttributesSchema); + } + } + + @Test + public void testGetClientInfoProperties() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getClientInfoProperties()) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetClientInfoPropertiesSchema = + new HashMap() { + { + put(1, "NAME"); + put(2, "MAX_LEN"); + put(3, "DEFAULT_VALUE"); + put(4, "DESCRIPTION"); + } + }; + testEmptyResultSet(resultSet, expectedGetClientInfoPropertiesSchema); + } + } + + @Test + public void testGetFunctions() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getFunctions(null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetFunctionsSchema = new HashMap() { + { + put(1, "FUNCTION_CAT"); + put(2, "FUNCTION_SCHEM"); + put(3, 
"FUNCTION_NAME"); + put(4, "REMARKS"); + put(5, "FUNCTION_TYPE"); + put(6, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetFunctionsSchema); + } + } + + @Test + public void testGetFunctionColumns() throws SQLException { + try ( + ResultSet resultSet = connection.getMetaData().getFunctionColumns(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetFunctionColumnsSchema = new HashMap() { + { + put(1, "FUNCTION_CAT"); + put(2, "FUNCTION_SCHEM"); + put(3, "FUNCTION_NAME"); + put(4, "COLUMN_NAME"); + put(5, "COLUMN_TYPE"); + put(6, "DATA_TYPE"); + put(7, "TYPE_NAME"); + put(8, "PRECISION"); + put(9, "LENGTH"); + put(10, "SCALE"); + put(11, "RADIX"); + put(12, "NULLABLE"); + put(13, "REMARKS"); + put(14, "CHAR_OCTET_LENGTH"); + put(15, "ORDINAL_POSITION"); + put(16, "IS_NULLABLE"); + put(17, "SPECIFIC_NAME"); + } + }; + testEmptyResultSet(resultSet, expectedGetFunctionColumnsSchema); + } + } + + @Test + public void testGetPseudoColumns() throws SQLException { + try (ResultSet resultSet = connection.getMetaData().getPseudoColumns(null, null, null, null)) { + // Maps ordinal index to column name according to JDBC documentation + final Map expectedGetPseudoColumnsSchema = new HashMap() { + { + put(1, "TABLE_CAT"); + put(2, "TABLE_SCHEM"); + put(3, "TABLE_NAME"); + put(4, "COLUMN_NAME"); + put(5, "DATA_TYPE"); + put(6, "COLUMN_SIZE"); + put(7, "DECIMAL_DIGITS"); + put(8, "NUM_PREC_RADIX"); + put(9, "COLUMN_USAGE"); + put(10, "REMARKS"); + put(11, "CHAR_OCTET_LENGTH"); + put(12, "IS_NULLABLE"); + } + }; + testEmptyResultSet(resultSet, expectedGetPseudoColumnsSchema); + } + } + + private void testEmptyResultSet(final ResultSet resultSet, + final Map expectedResultSetSchema) + throws SQLException { + Assert.assertFalse(resultSet.next()); + final ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + for (final Map.Entry entry : expectedResultSetSchema.entrySet()) { + Assert.assertEquals(entry.getValue(), resultSetMetaData.getColumnLabel(entry.getKey())); + } + } + + @Test + public void testGetColumnSize() { + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_BYTE), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Byte.SIZE, true))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_SHORT), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Short.SIZE, true))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_INT), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Integer.SIZE, true))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_LONG), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Int(Long.SIZE, true))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_VARCHAR_AND_BINARY), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Utf8())); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_VARCHAR_AND_BINARY), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Binary())); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_SECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.SECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_MILLISECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_MICROSECONDS), + 
ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIMESTAMP_NANOSECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Time(TimeUnit.SECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME_MILLISECONDS), + ArrowDatabaseMetadata.getColumnSize( + new ArrowType.Time(TimeUnit.MILLISECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME_MICROSECONDS), + ArrowDatabaseMetadata.getColumnSize( + new ArrowType.Time(TimeUnit.MICROSECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_TIME_NANOSECONDS), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Time(TimeUnit.NANOSECOND, Integer.SIZE))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.COLUMN_SIZE_DATE), + ArrowDatabaseMetadata.getColumnSize(new ArrowType.Date(DateUnit.DAY))); + + Assert.assertNull(ArrowDatabaseMetadata.getColumnSize(new ArrowType.FloatingPoint( + FloatingPointPrecision.DOUBLE))); + } + + @Test + public void testGetDecimalDigits() { + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Int(Byte.SIZE, true))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Timestamp(TimeUnit.SECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MILLISECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Timestamp(TimeUnit.MILLISECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MICROSECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_NANOSECONDS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Timestamp(TimeUnit.NANOSECOND, null))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Time(TimeUnit.SECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MILLISECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Time(TimeUnit.MILLISECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_MICROSECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Time(TimeUnit.MICROSECOND, Integer.SIZE))); + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.DECIMAL_DIGITS_TIME_NANOSECONDS), + ArrowDatabaseMetadata.getDecimalDigits( + new ArrowType.Time(TimeUnit.NANOSECOND, Integer.SIZE))); + + Assert.assertEquals(Integer.valueOf(ArrowDatabaseMetadata.NO_DECIMAL_DIGITS), + ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Date(DateUnit.DAY))); + + Assert.assertNull(ArrowDatabaseMetadata.getDecimalDigits(new ArrowType.Utf8())); + } + + @Test + public void testSqlToRegexLike() { + Assert.assertEquals(".*", ArrowDatabaseMetadata.sqlToRegexLike("%")); + Assert.assertEquals(".", ArrowDatabaseMetadata.sqlToRegexLike("_")); + Assert.assertEquals("\\*", ArrowDatabaseMetadata.sqlToRegexLike("*")); + 
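+    // '%' and '_' are the SQL LIKE wildcards; any other regex metacharacter is
+    // escaped. Under these rules, for example, "%foo_bar%" should map to
+    // ".*foo.bar.*". The case below mixes all three behaviors in one pattern.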
Assert.assertEquals("T\\*E.S.*T", ArrowDatabaseMetadata.sqlToRegexLike("T*E_S%T")); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java new file mode 100644 index 00000000000..90c926612f1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcArrayTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.Types; +import java.util.Arrays; +import java.util.HashMap; + +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcArrayTest { + + @Rule + public RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + IntVector dataVector; + + @Before + public void setup() { + dataVector = rootAllocatorTestRule.createIntVector(); + } + + @After + public void tearDown() { + this.dataVector.close(); + } + + @Test + public void testShouldGetBaseTypeNameReturnCorrectTypeName() { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Assert.assertEquals("INTEGER", arrowFlightJdbcArray.getBaseTypeName()); + } + + @Test + public void testShouldGetBaseTypeReturnCorrectType() { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Assert.assertEquals(Types.INTEGER, arrowFlightJdbcArray.getBaseType()); + } + + @Test + public void testShouldGetArrayReturnValidArray() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Object[] array = (Object[]) arrowFlightJdbcArray.getArray(); + + Object[] expected = new Object[dataVector.getValueCount()]; + for (int i = 0; i < expected.length; i++) { + expected[i] = dataVector.getObject(i); + } + Assert.assertArrayEquals(array, expected); + } + + @Test + public void testShouldGetArrayReturnValidArrayWithOffsets() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + Object[] array = (Object[]) 
arrowFlightJdbcArray.getArray(1, 5); + + Object[] expected = new Object[5]; + for (int i = 0; i < expected.length; i++) { + expected[i] = dataVector.getObject(i + 1); + } + Assert.assertArrayEquals(array, expected); + } + + @Test(expected = ArrayIndexOutOfBoundsException.class) + public void testShouldGetArrayWithOffsetsThrowArrayIndexOutOfBoundsException() + throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + arrowFlightJdbcArray.getArray(0, dataVector.getValueCount() + 1); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetArrayWithMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getArray(map); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetArrayWithOffsetsAndMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getArray(0, 5, map); + } + + @Test + public void testShouldGetResultSetReturnValidResultSet() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + try (ResultSet resultSet = arrowFlightJdbcArray.getResultSet()) { + int count = 0; + while (resultSet.next()) { + Assert.assertEquals((Object) resultSet.getInt(1), dataVector.getObject(count)); + count++; + } + } + } + + @Test + public void testShouldGetResultSetReturnValidResultSetWithOffsets() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + try (ResultSet resultSet = arrowFlightJdbcArray.getResultSet(3, 5)) { + int count = 0; + while (resultSet.next()) { + Assert.assertEquals((Object) resultSet.getInt(1), dataVector.getObject(count + 3)); + count++; + } + Assert.assertEquals(count, 5); + } + } + + @Test + public void testToString() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + + JsonStringArrayList array = new JsonStringArrayList<>(); + array.addAll(Arrays.asList((Object[]) arrowFlightJdbcArray.getArray())); + + Assert.assertEquals(array.toString(), arrowFlightJdbcArray.toString()); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetResultSetWithMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getResultSet(map); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + public void testShouldGetResultSetWithOffsetsAndMapNotBeSupported() throws SQLException { + ArrowFlightJdbcArray arrowFlightJdbcArray = + new ArrowFlightJdbcArray(dataVector, 0, dataVector.getValueCount()); + HashMap> map = new HashMap<>(); + arrowFlightJdbcArray.getResultSet(0, 5, map); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionCookieTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionCookieTest.java new file mode 100644 index 00000000000..c7268e0594e 
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionCookieTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
+import org.junit.Assert;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+public class ArrowFlightJdbcConnectionCookieTest {
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  @ClassRule
+  public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE =
+      FlightServerTestRule.createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer());
+
+  @Test
+  public void testCookies() throws SQLException {
+    try (Connection connection = FLIGHT_SERVER_TEST_RULE.getConnection(false);
+         Statement statement = connection.createStatement()) {
+
+      // The client should not have received any cookie before the first operation.
+      Assert.assertNull(FLIGHT_SERVER_TEST_RULE.getMiddlewareCookieFactory().getCookie());
+
+      // Run a statement and verify that the cookie set by the server was stored.
+      statement.execute(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD);
+      Assert.assertEquals("k=v", FLIGHT_SERVER_TEST_RULE.getMiddlewareCookieFactory().getCookie());
+    }
+  }
+}
+
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSourceTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSourceTest.java
new file mode 100644
index 00000000000..f4a5c87a23c
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcConnectionPoolDataSourceTest.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import java.sql.Connection;
+
+import javax.sql.PooledConnection;
+
+import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
+import org.apache.arrow.driver.jdbc.utils.ConnectionWrapper;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+public class ArrowFlightJdbcConnectionPoolDataSourceTest {
+  @ClassRule
+  public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+
+  private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();
+
+  static {
+    UserPasswordAuthentication authentication =
+        new UserPasswordAuthentication.Builder()
+            .user("user1", "pass1")
+            .user("user2", "pass2")
+            .build();
+
+    FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder()
+        .host("localhost")
+        .randomPort()
+        .authentication(authentication)
+        .producer(PRODUCER)
+        .build();
+  }
+
+  private ArrowFlightJdbcConnectionPoolDataSource dataSource;
+
+  @Before
+  public void setUp() {
+    dataSource = FLIGHT_SERVER_TEST_RULE.createConnectionPoolDataSource(false);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    dataSource.close();
+  }
+
+  @Test
+  public void testShouldInnerConnectionIsClosedReturnCorrectly() throws Exception {
+    PooledConnection pooledConnection = dataSource.getPooledConnection();
+    Connection connection = pooledConnection.getConnection();
+    Assert.assertFalse(connection.isClosed());
+    connection.close();
+    Assert.assertTrue(connection.isClosed());
+  }
+
+  @Test
+  public void testShouldInnerConnectionIgnoreDoubleClose() throws Exception {
+    PooledConnection pooledConnection = dataSource.getPooledConnection();
+    Connection connection = pooledConnection.getConnection();
+    Assert.assertFalse(connection.isClosed());
+    connection.close();
+    // Closing an already-closed connection must be a no-op.
+    connection.close();
+    Assert.assertTrue(connection.isClosed());
+  }
+
+  @Test
+  public void testShouldInnerConnectionIsClosedReturnTrueIfPooledConnectionCloses()
+      throws Exception {
+    PooledConnection pooledConnection = dataSource.getPooledConnection();
+    Connection connection = pooledConnection.getConnection();
+    Assert.assertFalse(connection.isClosed());
+    pooledConnection.close();
+    Assert.assertTrue(connection.isClosed());
+  }
+
+  @Test
+  public void testShouldReuseConnectionsOnPool() throws Exception {
+    PooledConnection pooledConnection = dataSource.getPooledConnection("user1", "pass1");
+    ConnectionWrapper connection = ((ConnectionWrapper) pooledConnection.getConnection());
+    Assert.assertFalse(connection.isClosed());
+    connection.close();
+    Assert.assertTrue(connection.isClosed());
+    Assert.assertFalse(connection.unwrap(ArrowFlightConnection.class).isClosed());
+
+    PooledConnection pooledConnection2 = dataSource.getPooledConnection("user1", "pass1");
+    ConnectionWrapper connection2 = ((ConnectionWrapper) pooledConnection2.getConnection());
+    Assert.assertFalse(connection2.isClosed());
+    connection2.close();
+    Assert.assertTrue(connection2.isClosed());
+    Assert.assertFalse(connection2.unwrap(ArrowFlightConnection.class).isClosed());
+
+    Assert.assertSame(pooledConnection, pooledConnection2);
+    Assert.assertNotSame(connection, connection2);
+    Assert.assertSame(connection.unwrap(ArrowFlightConnection.class),
+        connection2.unwrap(ArrowFlightConnection.class));
+  }
+
+  @Test
+  public void testShouldNotMixConnectionsForDifferentUsers() throws Exception {
+    PooledConnection pooledConnection =
dataSource.getPooledConnection("user1", "pass1"); + ConnectionWrapper connection = ((ConnectionWrapper) pooledConnection.getConnection()); + Assert.assertFalse(connection.isClosed()); + connection.close(); + Assert.assertTrue(connection.isClosed()); + Assert.assertFalse(connection.unwrap(ArrowFlightConnection.class).isClosed()); + + PooledConnection pooledConnection2 = dataSource.getPooledConnection("user2", "pass2"); + ConnectionWrapper connection2 = ((ConnectionWrapper) pooledConnection2.getConnection()); + Assert.assertFalse(connection2.isClosed()); + connection2.close(); + Assert.assertTrue(connection2.isClosed()); + Assert.assertFalse(connection2.unwrap(ArrowFlightConnection.class).isClosed()); + + Assert.assertNotSame(pooledConnection, pooledConnection2); + Assert.assertNotSame(connection, connection2); + Assert.assertNotSame(connection.unwrap(ArrowFlightConnection.class), + connection2.unwrap(ArrowFlightConnection.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursorTest.java new file mode 100644 index 00000000000..b818f7115b7 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcCursorTest.java @@ -0,0 +1,251 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc; + +import static org.junit.Assert.assertTrue; + +import java.sql.SQLException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.util.Cursor; +import org.junit.After; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +/** + * Tests for {@link ArrowFlightJdbcCursor}. 
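+ * The tests null out the first slot of a vector of each supported Arrow type
+ * and verify that the cursor reports {@code wasNull()} after the value is read.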
+ */ +public class ArrowFlightJdbcCursorTest { + + ArrowFlightJdbcCursor cursor; + BufferAllocator allocator; + + @After + public void cleanUp() { + allocator.close(); + cursor.close(); + } + + @Test + public void testBinaryVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Binary", new ArrowType.Binary(), null); + ((VarBinaryVector) root.getVector("Binary")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDateVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = + getVectorSchemaRoot("Date", new ArrowType.Date(DateUnit.DAY), null); + ((DateDayVector) root.getVector("Date")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDurationVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Duration", + new ArrowType.Duration(TimeUnit.MILLISECOND), null); + ((DurationVector) root.getVector("Duration")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testDateInternalNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Interval", + new ArrowType.Interval(IntervalUnit.DAY_TIME), null); + ((IntervalDayVector) root.getVector("Interval")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeStampVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("TimeStamp", + new ArrowType.Timestamp(TimeUnit.MILLISECOND, null), null); + ((TimeStampMilliVector) root.getVector("TimeStamp")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testTimeVectorNullTrue() throws SQLException { + final VectorSchemaRoot root = getVectorSchemaRoot("Time", + new ArrowType.Time(TimeUnit.MILLISECOND, 32), null); + ((TimeMilliVector) root.getVector("Time")).setNull(0); + testCursorWasNull(root); + + } + + @Test + public void testFixedSizeListVectorNullTrue() throws SQLException { + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Null", new FieldType(true, new ArrowType.Null(), null), + null)); + final VectorSchemaRoot root = getVectorSchemaRoot("FixedSizeList", + new ArrowType.FixedSizeList(10), fieldList); + ((FixedSizeListVector) root.getVector("FixedSizeList")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testLargeListVectorNullTrue() throws SQLException { + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Null", new FieldType(true, new ArrowType.Null(), null), + null)); + final VectorSchemaRoot root = + getVectorSchemaRoot("LargeList", new ArrowType.LargeList(), fieldList); + ((LargeListVector) root.getVector("LargeList")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testListVectorNullTrue() throws SQLException { + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Null", new FieldType(true, new ArrowType.Null(), null), + null)); + final VectorSchemaRoot root = getVectorSchemaRoot("List", new ArrowType.List(), fieldList); + ((ListVector) root.getVector("List")).setNull(0); + testCursorWasNull(root); + } + + @Test + public void testMapVectorNullTrue() throws SQLException { + List structChildren = new ArrayList<>(); + structChildren.add(new Field("Key", new FieldType(false, new ArrowType.Utf8(), null), + null)); + structChildren.add(new Field("Value", new FieldType(false, new ArrowType.Utf8(), null), + null)); + List fieldList = new ArrayList<>(); + fieldList.add(new Field("Struct", new FieldType(false, new ArrowType.Struct(), null), + structChildren)); + final VectorSchemaRoot root 
= getVectorSchemaRoot("Map", new ArrowType.Map(false), fieldList);
+    ((MapVector) root.getVector("Map")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testStructVectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("Struct", new ArrowType.Struct(), null);
+    ((StructVector) root.getVector("Struct")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testBaseIntVectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("BaseInt",
+        new ArrowType.Int(32, false), null);
+    ((UInt4Vector) root.getVector("BaseInt")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testBitVectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("Bit", new ArrowType.Bool(), null);
+    ((BitVector) root.getVector("Bit")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testDecimalVectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("Decimal",
+        new ArrowType.Decimal(2, 2, 128), null);
+    ((DecimalVector) root.getVector("Decimal")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testFloat4VectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("Float4",
+        new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE), null);
+    ((Float4Vector) root.getVector("Float4")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testFloat8VectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("Float8",
+        new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE), null);
+    ((Float8Vector) root.getVector("Float8")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testVarCharVectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("VarChar", new ArrowType.Utf8(), null);
+    ((VarCharVector) root.getVector("VarChar")).setNull(0);
+    testCursorWasNull(root);
+  }
+
+  @Test
+  public void testNullVectorNullTrue() throws SQLException {
+    final VectorSchemaRoot root = getVectorSchemaRoot("Null", new ArrowType.Null(), null);
+    testCursorWasNull(root);
+  }
+
+  private VectorSchemaRoot getVectorSchemaRoot(String name, ArrowType arrowType,
+                                               List<Field> children) {
+    final Schema schema = new Schema(ImmutableList.of(
+        new Field(
+            name,
+            new FieldType(true, arrowType,
+                null),
+            children)));
+    allocator = new RootAllocator(Long.MAX_VALUE);
+    final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+    root.allocateNew();
+    return root;
+  }
+
+  private void testCursorWasNull(VectorSchemaRoot root) throws SQLException {
+    root.setRowCount(1);
+    cursor = new ArrowFlightJdbcCursor(root);
+    cursor.next();
+    List<Accessor> accessorList = cursor.createAccessors(null, null, null);
+    accessorList.get(0).getObject();
+    assertTrue(cursor.wasNull());
+    root.close();
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java
new file mode 100644
index 00000000000..682c20c696a
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java
@@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.sql.Connection;
+import java.sql.Driver;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.util.Collection;
+import java.util.Map;
+import java.util.Properties;
+
+import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
+import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ArrowFlightJdbcDriver}.
+ */
+public class ArrowFlightJdbcDriverTest {
+
+  @ClassRule
+  public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+  private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();
+
+  static {
+    UserPasswordAuthentication authentication =
+        new UserPasswordAuthentication.Builder().user("user1", "pass1").user("user2", "pass2")
+            .build();
+
+    FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder().host("localhost").randomPort()
+        .authentication(authentication).producer(PRODUCER).build();
+  }
+
+  private BufferAllocator allocator;
+  private ArrowFlightJdbcConnectionPoolDataSource dataSource;
+
+  @Before
+  public void setUp() throws Exception {
+    allocator = new RootAllocator(Long.MAX_VALUE);
+    dataSource = FLIGHT_SERVER_TEST_RULE.createConnectionPoolDataSource();
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    Collection<BufferAllocator> childAllocators = allocator.getChildAllocators();
+    AutoCloseables.close(childAllocators.toArray(new AutoCloseable[0]));
+    AutoCloseables.close(dataSource, allocator);
+  }
+
+  /**
+   * Tests whether the {@link ArrowFlightJdbcDriver} is registered in the
+   * {@link DriverManager}.
+   *
+   * @throws SQLException If an error occurs. (This is not supposed to happen.)
+   */
+  @Test
+  public void testDriverIsRegisteredInDriverManager() throws Exception {
+    assertTrue(DriverManager.getDriver("jdbc:arrow-flight://localhost:32010") instanceof
+        ArrowFlightJdbcDriver);
+    assertTrue(DriverManager.getDriver("jdbc:arrow-flight-sql://localhost:32010") instanceof
+        ArrowFlightJdbcDriver);
+  }
+
+  /**
+   * Tests whether the {@link ArrowFlightJdbcDriver} fails when provided with an
+   * unsupported URL prefix.
+   *
+   * @throws SQLException if the URL prefix is unsupported; throwing is the
+   *                      expected, passing outcome of this test.
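+   *
+   * <p>For illustration, the rejected call looks like this (it mirrors the test
+   * body below; {@code jdbc:mysql} stands in for any prefix this driver does not
+   * understand):
+   * <pre>{@code
+   * // Expected to throw SQLException: not an arrow-flight(-sql) URL.
+   * new ArrowFlightJdbcDriver()
+   *     .connect("jdbc:mysql://localhost:32010", properties)
+   *     .close();
+   * }</pre>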
+ */ + @Test(expected = SQLException.class) + public void testShouldDeclineUrlWithUnsupportedPrefix() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + + driver.connect("jdbc:mysql://localhost:32010", dataSource.getProperties("flight", "flight123")) + .close(); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} can establish a successful + * connection to the Arrow Flight client. + * + * @throws Exception If the connection fails to be established. + */ + @Test + public void testShouldConnectWhenProvidedWithValidUrl() throws Exception { + // Get the Arrow Flight JDBC driver by providing a URL with a valid prefix. + final Driver driver = new ArrowFlightJdbcDriver(); + + try (Connection connection = + driver.connect("jdbc:arrow-flight://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "useEncryption=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + try (Connection connection = + driver.connect("jdbc:arrow-flight-sql://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "useEncryption=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + } + + @Test + public void testConnectWithInsensitiveCasePropertyKeys() throws Exception { + // Get the Arrow Flight JDBC driver by providing a URL with insensitive case property keys. + final Driver driver = new ArrowFlightJdbcDriver(); + + try (Connection connection = + driver.connect("jdbc:arrow-flight://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "UseEncryptiOn=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + try (Connection connection = + driver.connect("jdbc:arrow-flight-sql://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort() + "?" + + "UseEncryptiOn=false", + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()))) { + assertTrue(connection.isValid(300)); + } + } + + @Test + public void testConnectWithInsensitiveCasePropertyKeys2() throws Exception { + // Get the Arrow Flight JDBC driver by providing a property object with insensitive case keys. + final Driver driver = new ArrowFlightJdbcDriver(); + Properties properties = + dataSource.getProperties(dataSource.getConfig().getUser(), dataSource.getConfig().getPassword()); + properties.put("UseEncryptiOn", "false"); + + try (Connection connection = + driver.connect("jdbc:arrow-flight://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort(), properties)) { + assertTrue(connection.isValid(300)); + } + try (Connection connection = + driver.connect("jdbc:arrow-flight-sql://" + + dataSource.getConfig().getHost() + ":" + + dataSource.getConfig().getPort(), properties)) { + assertTrue(connection.isValid(300)); + } + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. 
+ */ + @Test(expected = SQLException.class) + public void testShouldThrowExceptionWhenAttemptingToConnectToMalformedUrl() throws SQLException { + final Driver driver = new ArrowFlightJdbcDriver(); + final String malformedUri = "yes:??/chainsaw.i=T333"; + + driver.connect(malformedUri, dataSource.getProperties("flight", "flight123")); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + * + * @throws Exception If an error occurs. + */ + @Test(expected = SQLException.class) + public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoPrefix() throws SQLException { + final Driver driver = new ArrowFlightJdbcDriver(); + final String malformedUri = "localhost:32010"; + + driver.connect(malformedUri, dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword())); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + */ + @Test + public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoPort() { + final Driver driver = new ArrowFlightJdbcDriver(); + SQLException e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight://localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a port")); + e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight-sql://localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a port")); + } + + /** + * Tests whether an exception is thrown upon attempting to connect to a + * malformed URI. + */ + @Test + public void testShouldThrowExceptionWhenAttemptingToConnectToUrlNoHost() { + final Driver driver = new ArrowFlightJdbcDriver(); + SQLException e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight://32010:localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a host")); + + e = assertThrows(SQLException.class, () -> { + Properties properties = dataSource.getProperties(dataSource.getConfig().getUser(), + dataSource.getConfig().getPassword()); + Connection conn = driver.connect("jdbc:arrow-flight-sql://32010:localhost", properties); + conn.close(); + }); + assertTrue(e.getMessage().contains("URL must have a host")); + } + + /** + * Tests whether {@link ArrowFlightJdbcDriver#getUrlsArgs} returns the + * correct URL parameters. + * + * @throws Exception If an error occurs. 
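+   *
+   * <p>For illustration, the URL exercised below is expected to parse into a map
+   * holding the host, the port (keyed by {@code ArrowFlightConnectionProperty}
+   * names) and each query argument; a sketch, not exact driver output:
+   * <pre>{@code
+   * driver.getUrlsArgs("jdbc:arrow-flight-sql://localhost:2222/?key1=value1&key2=value2&a=b");
+   * // host=localhost, port=2222, key1=value1, key2=value2, a=b
+   * }</pre>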
+ */
+  @Test
+  public void testDriverUrlParsingMechanismShouldReturnTheDesiredArgsFromUrl() throws Exception {
+    final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver();
+
+    final Map<Object, Object> parsedArgs = driver.getUrlsArgs(
+        "jdbc:arrow-flight-sql://localhost:2222/?key1=value1&key2=value2&a=b");
+
+    // Check size == the amount of args provided (scheme not included)
+    assertEquals(5, parsedArgs.size());
+
+    // Check host == the provided host
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "localhost");
+
+    // Check port == the provided port
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222);
+
+    // Check all other non-default arguments
+    assertEquals(parsedArgs.get("key1"), "value1");
+    assertEquals(parsedArgs.get("key2"), "value2");
+    assertEquals(parsedArgs.get("a"), "b");
+  }
+
+  @Test
+  public void testDriverUrlParsingMechanismShouldReturnTheDesiredArgsFromUrlWithSemicolon() throws Exception {
+    final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver();
+    final Map<Object, Object> parsedArgs = driver.getUrlsArgs(
+        "jdbc:arrow-flight-sql://localhost:2222/;key1=value1;key2=value2;a=b");
+
+    // Check size == the amount of args provided (scheme not included)
+    assertEquals(5, parsedArgs.size());
+
+    // Check host == the provided host
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "localhost");
+
+    // Check port == the provided port
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222);
+
+    // Check all other non-default arguments
+    assertEquals(parsedArgs.get("key1"), "value1");
+    assertEquals(parsedArgs.get("key2"), "value2");
+    assertEquals(parsedArgs.get("a"), "b");
+  }
+
+  @Test
+  public void testDriverUrlParsingMechanismShouldReturnTheDesiredArgsFromUrlWithOneSemicolon() throws Exception {
+    final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver();
+    final Map<Object, Object> parsedArgs = driver.getUrlsArgs(
+        "jdbc:arrow-flight-sql://localhost:2222/;key1=value1");
+
+    // Check size == the amount of args provided (scheme not included)
+    assertEquals(3, parsedArgs.size());
+
+    // Check host == the provided host
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "localhost");
+
+    // Check port == the provided port
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222);
+
+    // Check all other non-default arguments
+    assertEquals(parsedArgs.get("key1"), "value1");
+  }
+
+  /**
+   * Tests whether {@code ArrowFlightJdbcDriver#getUrlsArgs} rejects a malformed
+   * URL by throwing an exception.
+   */
+  @Test
+  public void testDriverUrlParsingMechanismShouldThrowExceptionUponProvidedWithMalformedUrl() {
+    final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver();
+    assertThrows(SQLException.class, () -> driver.getUrlsArgs(
+        "jdbc:malformed-url-flight://localhost:2222"));
+  }
+
+  /**
+   * Tests whether {@code ArrowFlightJdbcDriver#getUrlsArgs} returns the
+   * correct URL parameters when the host is an IP Address.
+   *
+   * @throws Exception If an error occurs.
+ */
+  @Test
+  public void testDriverUrlParsingMechanismShouldWorkWithIPAddress() throws Exception {
+    final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver();
+    final Map<Object, Object> parsedArgs = driver.getUrlsArgs("jdbc:arrow-flight-sql://0.0.0.0:2222");
+
+    // Check size == the amount of args provided (scheme not included)
+    assertEquals(2, parsedArgs.size());
+
+    // Check host == the provided host
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "0.0.0.0");
+
+    // Check port == the provided port
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222);
+  }
+
+  /**
+   * Tests whether {@code ArrowFlightJdbcDriver#getUrlsArgs} escapes special characters
+   * and returns the correct URL parameters when the special character '&' is embedded
+   * in the query parameter values.
+   *
+   * @throws Exception If an error occurs.
+   */
+  @Test
+  public void testDriverUrlParsingMechanismShouldWorkWithEmbeddedEspecialCharacter()
+      throws Exception {
+    final ArrowFlightJdbcDriver driver = new ArrowFlightJdbcDriver();
+    final Map<Object, Object> parsedArgs = driver.getUrlsArgs(
+        "jdbc:arrow-flight-sql://0.0.0.0:2222?test1=test1value&test2%26continue=test2value&test3=test3value");
+
+    // Check size == the amount of args provided (scheme not included)
+    assertEquals(5, parsedArgs.size());
+
+    // Check host == the provided host
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.HOST.camelName()), "0.0.0.0");
+
+    // Check port == the provided port
+    assertEquals(parsedArgs.get(ArrowFlightConnectionProperty.PORT.camelName()), 2222);
+
+    // Check all other non-default arguments
+    assertEquals(parsedArgs.get("test1"), "test1value");
+    assertEquals(parsedArgs.get("test2&continue"), "test2value");
+    assertEquals(parsedArgs.get("test3"), "test3value");
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactoryTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactoryTest.java
new file mode 100644
index 00000000000..c482169852e
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFactoryTest.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import java.lang.reflect.Constructor;
+import java.sql.Connection;
+import java.util.Properties;
+
+import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
+import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.apache.calcite.avatica.UnregisteredDriver;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+import com.google.common.collect.ImmutableMap;
+
+/**
+ * Tests for {@link ArrowFlightJdbcFactory}.
+ */
+public class ArrowFlightJdbcFactoryTest {
+
+  @ClassRule
+  public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+  private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();
+
+  static {
+    UserPasswordAuthentication authentication =
+        new UserPasswordAuthentication.Builder().user("user1", "pass1").user("user2", "pass2")
+            .build();
+
+    FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder().host("localhost").randomPort()
+        .authentication(authentication).producer(PRODUCER).build();
+  }
+
+  private BufferAllocator allocator;
+  private ArrowFlightJdbcConnectionPoolDataSource dataSource;
+
+  @Before
+  public void setUp() throws Exception {
+    allocator = new RootAllocator(Long.MAX_VALUE);
+    dataSource = FLIGHT_SERVER_TEST_RULE.createConnectionPoolDataSource();
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    AutoCloseables.close(dataSource, allocator);
+  }
+
+  @Test
+  public void testShouldBeAbleToEstablishAConnectionSuccessfully() throws Exception {
+    UnregisteredDriver driver = new ArrowFlightJdbcDriver();
+    Constructor<ArrowFlightJdbcFactory> constructor = ArrowFlightJdbcFactory.class.getConstructor();
+    constructor.setAccessible(true);
+    ArrowFlightJdbcFactory factory = constructor.newInstance();
+
+    final Properties properties = new Properties();
+    properties.putAll(ImmutableMap.of(
+        ArrowFlightConnectionProperty.HOST.camelName(), "localhost",
+        ArrowFlightConnectionProperty.PORT.camelName(), 32010,
+        ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), false));
+
+    try (Connection connection = factory.newConnection(driver, constructor.newInstance(),
+        "jdbc:arrow-flight-sql://localhost:32010", properties)) {
+      assert connection.isValid(300);
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java
new file mode 100644
index 00000000000..104794b3ad1
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcTimeTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.endsWith; +import static org.hamcrest.CoreMatchers.is; + +import java.time.LocalTime; +import java.util.concurrent.TimeUnit; + +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcTimeTest { + + @ClassRule + public static final ErrorCollector collector = new ErrorCollector(); + final int hour = 5; + final int minute = 6; + final int second = 7; + + @Test + public void testPrintingMillisNoLeadingZeroes() { + // testing the regular case where the precision of the millisecond is 3 + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(999)); + ArrowFlightJdbcTime time = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time.toString(), endsWith(".999")); + collector.checkThat(time.getHours(), is(hour)); + collector.checkThat(time.getMinutes(), is(minute)); + collector.checkThat(time.getSeconds(), is(second)); + } + + @Test + public void testPrintingMillisOneLeadingZeroes() { + // test case where one leading zero needs to be added + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(99)); + ArrowFlightJdbcTime time = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time.toString(), endsWith(".099")); + collector.checkThat(time.getHours(), is(hour)); + collector.checkThat(time.getMinutes(), is(minute)); + collector.checkThat(time.getSeconds(), is(second)); + } + + @Test + public void testPrintingMillisTwoLeadingZeroes() { + // test case where two leading zeroes needs to be added + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(1)); + ArrowFlightJdbcTime time = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time.toString(), endsWith(".001")); + collector.checkThat(time.getHours(), is(hour)); + collector.checkThat(time.getMinutes(), is(minute)); + collector.checkThat(time.getSeconds(), is(second)); + } + + @Test + public void testEquality() { + // tests #equals and #hashCode for coverage checks + LocalTime dateTime = LocalTime.of(hour, minute, second, (int) TimeUnit.MILLISECONDS.toNanos(1)); + ArrowFlightJdbcTime time1 = new ArrowFlightJdbcTime(dateTime); + ArrowFlightJdbcTime time2 = new ArrowFlightJdbcTime(dateTime); + collector.checkThat(time1, is(time2)); + collector.checkThat(time1.hashCode(), is(time2.hashCode())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java new file mode 100644 index 00000000000..51c491be288 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.equalTo; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightPreparedStatementTest { + + @ClassRule + public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + + private static Connection connection; + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @BeforeClass + public static void setup() throws SQLException { + connection = FLIGHT_SERVER_TEST_RULE.getConnection(false); + } + + @AfterClass + public static void tearDown() throws SQLException { + connection.close(); + } + + @Test + public void testSimpleQueryNoParameterBinding() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + try (final PreparedStatement preparedStatement = connection.prepareStatement(query); + final ResultSet resultSet = preparedStatement.executeQuery()) { + CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector); + } + } + + @Test + public void testReturnColumnCount() throws SQLException { + final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD; + try (final PreparedStatement psmt = connection.prepareStatement(query)) { + collector.checkThat("ID", equalTo(psmt.getMetaData().getColumnName(1))); + collector.checkThat("Name", equalTo(psmt.getMetaData().getColumnName(2))); + collector.checkThat("Age", equalTo(psmt.getMetaData().getColumnName(3))); + collector.checkThat("Salary", equalTo(psmt.getMetaData().getColumnName(4))); + collector.checkThat("Hire Date", equalTo(psmt.getMetaData().getColumnName(5))); + collector.checkThat("Last Sale", equalTo(psmt.getMetaData().getColumnName(6))); + collector.checkThat(6, equalTo(psmt.getMetaData().getColumnCount())); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java new file mode 100644 index 00000000000..155fcc50827 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteTest.java @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +/** + * Tests for {@link ArrowFlightStatement#execute}. 
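+ *
+ * <p>A short reminder of the JDBC contract these tests lean on (a generic JDBC
+ * sketch, not driver-specific code):
+ * <pre>{@code
+ * if (statement.execute(sql)) {
+ *   // true: a SELECT-style result; consume it via getResultSet().
+ *   try (ResultSet resultSet = statement.getResultSet()) {
+ *     while (resultSet.next()) {
+ *       // read rows
+ *     }
+ *   }
+ * } else {
+ *   // false: an update count; getUpdateCount()/getLargeUpdateCount() report it,
+ *   // and getResultSet() returns null.
+ *   long affected = statement.getLargeUpdateCount();
+ * }
+ * }</pre>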
+ */ +public class ArrowFlightStatementExecuteTest { + private static final String SAMPLE_QUERY_CMD = "SELECT * FROM this_test"; + private static final int SAMPLE_QUERY_ROWS = Byte.MAX_VALUE; + private static final String VECTOR_NAME = "Unsigned Byte"; + private static final Schema SAMPLE_QUERY_SCHEMA = + new Schema(Collections.singletonList(Field.nullable(VECTOR_NAME, MinorType.UINT1.getType()))); + private static final String SAMPLE_UPDATE_QUERY = + "UPDATE this_table SET this_field = that_field FROM this_test WHERE this_condition"; + private static final long SAMPLE_UPDATE_COUNT = 100L; + private static final String SAMPLE_LARGE_UPDATE_QUERY = + "UPDATE this_large_table SET this_large_field = that_large_field FROM this_large_test WHERE this_large_condition"; + private static final long SAMPLE_LARGE_UPDATE_COUNT = Long.MAX_VALUE; + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule.createStandardTestRule(PRODUCER); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + private Connection connection; + private Statement statement; + + @BeforeClass + public static void setUpBeforeClass() { + PRODUCER.addSelectQuery( + SAMPLE_QUERY_CMD, + SAMPLE_QUERY_SCHEMA, + Collections.singletonList(listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(SAMPLE_QUERY_SCHEMA, + allocator)) { + final UInt1Vector vector = (UInt1Vector) root.getVector(VECTOR_NAME); + IntStream.range(0, SAMPLE_QUERY_ROWS).forEach(index -> vector.setSafe(index, index)); + vector.setValueCount(SAMPLE_QUERY_ROWS); + root.setRowCount(SAMPLE_QUERY_ROWS); + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + PRODUCER.addUpdateQuery(SAMPLE_UPDATE_QUERY, SAMPLE_UPDATE_COUNT); + PRODUCER.addUpdateQuery(SAMPLE_LARGE_UPDATE_QUERY, SAMPLE_LARGE_UPDATE_COUNT); + } + + @Before + public void setUp() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + statement = connection.createStatement(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterClass + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testExecuteShouldRunSelectQuery() throws SQLException { + collector.checkThat(statement.execute(SAMPLE_QUERY_CMD), + is(true)); // Means this is a SELECT query. 
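+    // JDBC contract: execute() returns true exactly when the first result is a
+    // ResultSet; for such results getUpdateCount() must report -1, which the
+    // checks at the end of this test confirm.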
+    final Set<Byte> numbers =
+        IntStream.range(0, SAMPLE_QUERY_ROWS).boxed()
+            .map(Integer::byteValue)
+            .collect(Collectors.toCollection(HashSet::new));
+    try (final ResultSet resultSet = statement.getResultSet()) {
+      final int columnCount = resultSet.getMetaData().getColumnCount();
+      collector.checkThat(columnCount, is(1));
+      int rowCount = 0;
+      for (; resultSet.next(); rowCount++) {
+        collector.checkThat(numbers.remove(resultSet.getByte(1)), is(true));
+      }
+      collector.checkThat(rowCount, is(equalTo(SAMPLE_QUERY_ROWS)));
+    }
+    collector.checkThat(numbers, is(Collections.emptySet()));
+    collector.checkThat(
+        (long) statement.getUpdateCount(),
+        is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(-1L))));
+  }
+
+  @Test
+  public void testExecuteShouldRunUpdateQueryForSmallUpdate() throws SQLException {
+    collector.checkThat(statement.execute(SAMPLE_UPDATE_QUERY),
+        is(false)); // Means this is an UPDATE query.
+    collector.checkThat(
+        (long) statement.getUpdateCount(),
+        is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(SAMPLE_UPDATE_COUNT))));
+    collector.checkThat(statement.getResultSet(), is(nullValue()));
+  }
+
+  @Test
+  public void testExecuteShouldRunUpdateQueryForLargeUpdate() throws SQLException {
+    collector.checkThat(statement.execute(SAMPLE_LARGE_UPDATE_QUERY), is(false)); // UPDATE query.
+    final long updateCountSmall = statement.getUpdateCount();
+    final long updateCountLarge = statement.getLargeUpdateCount();
+    collector.checkThat(updateCountLarge, is(equalTo(SAMPLE_LARGE_UPDATE_COUNT)));
+    collector.checkThat(
+        updateCountSmall,
+        is(allOf(equalTo((long) AvaticaUtils.toSaturatedInt(updateCountLarge)),
+            not(equalTo(updateCountLarge)))));
+    collector.checkThat(statement.getResultSet(), is(nullValue()));
+  }
+
+  @Test
+  public void testUpdateCountShouldStartOnZero() throws SQLException {
+    collector.checkThat(
+        (long) statement.getUpdateCount(),
+        is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(0L))));
+    collector.checkThat(statement.getResultSet(), is(nullValue()));
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java
new file mode 100644
index 00000000000..43209d8913e
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightStatementExecuteUpdateTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc; + +import static java.lang.String.format; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.SQLFeatureNotSupportedException; +import java.sql.Statement; +import java.util.Collections; + +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.AvaticaUtils; +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +/** + * Tests for {@link ArrowFlightStatement#executeUpdate}. + */ +public class ArrowFlightStatementExecuteUpdateTest { + private static final String UPDATE_SAMPLE_QUERY = + "UPDATE sample_table SET sample_col = sample_val WHERE sample_condition"; + private static final int UPDATE_SAMPLE_QUERY_AFFECTED_COLS = 10; + private static final String LARGE_UPDATE_SAMPLE_QUERY = + "UPDATE large_sample_table SET large_sample_col = large_sample_val WHERE large_sample_condition"; + private static final long LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS = (long) Integer.MAX_VALUE + 1; + private static final String REGULAR_QUERY_SAMPLE = "SELECT * FROM NOT_UPDATE_QUERY"; + private static final Schema REGULAR_QUERY_SCHEMA = + new Schema( + Collections.singletonList(Field.nullable("placeholder", MinorType.VARCHAR.getType()))); + private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer(); + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule.createStandardTestRule(PRODUCER); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + public Connection connection; + public Statement statement; + + @BeforeClass + public static void setUpBeforeClass() { + PRODUCER.addUpdateQuery(UPDATE_SAMPLE_QUERY, UPDATE_SAMPLE_QUERY_AFFECTED_COLS); + PRODUCER.addUpdateQuery(LARGE_UPDATE_SAMPLE_QUERY, LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS); + PRODUCER.addSelectQuery( + REGULAR_QUERY_SAMPLE, + REGULAR_QUERY_SCHEMA, + Collections.singletonList(listener -> { + try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE); + final VectorSchemaRoot root = VectorSchemaRoot.create(REGULAR_QUERY_SCHEMA, + allocator)) { + listener.start(root); + listener.putNext(); + } catch (final Throwable throwable) { + listener.error(throwable); + } finally { + listener.completed(); + } + })); + } + + @Before + public void setUp() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + statement = connection.createStatement(); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(statement, connection); + } + + @AfterClass + public static void tearDownAfterClass() throws Exception { + AutoCloseables.close(PRODUCER); + } + + @Test + public void testExecuteUpdateShouldReturnNumColsAffectedForNumRowsFittingInt() + throws SQLException { + 
collector.checkThat(statement.executeUpdate(UPDATE_SAMPLE_QUERY), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test + public void testExecuteUpdateShouldReturnSaturatedNumColsAffectedIfDoesNotFitInInt() + throws SQLException { + final long result = statement.executeUpdate(LARGE_UPDATE_SAMPLE_QUERY); + final long expectedRowCountRaw = LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS; + collector.checkThat( + result, + is(allOf( + not(equalTo(expectedRowCountRaw)), + equalTo((long) AvaticaUtils.toSaturatedInt( + expectedRowCountRaw))))); // Because of long-to-integer overflow. + } + + @Test + public void testExecuteLargeUpdateShouldReturnNumColsAffected() throws SQLException { + collector.checkThat( + statement.executeLargeUpdate(LARGE_UPDATE_SAMPLE_QUERY), + is(LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + // TODO Implement `Statement#executeUpdate(String, int)` + public void testExecuteUpdateUnsupportedWithDriverFlag() throws SQLException { + collector.checkThat( + statement.executeUpdate(UPDATE_SAMPLE_QUERY, Statement.RETURN_GENERATED_KEYS), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + // TODO Implement `Statement#executeUpdate(String, int[])` + public void testExecuteUpdateUnsupportedWithArrayOfInts() throws SQLException { + collector.checkThat( + statement.executeUpdate(UPDATE_SAMPLE_QUERY, new int[0]), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test(expected = SQLFeatureNotSupportedException.class) + // TODO Implement `Statement#executeUpdate(String, String[])` + public void testExecuteUpdateUnsupportedWithArraysOfStrings() throws SQLException { + collector.checkThat( + statement.executeUpdate(UPDATE_SAMPLE_QUERY, new String[0]), + is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS)); + } + + @Test + public void testExecuteShouldExecuteUpdateQueryAutomatically() throws SQLException { + collector.checkThat(statement.execute(UPDATE_SAMPLE_QUERY), + is(false)); // Meaning there was an update query. + collector.checkThat(statement.execute(REGULAR_QUERY_SAMPLE), + is(true)); // Meaning there was a select query. + } + + @Test + public void testShouldFailToPrepareStatementForNullQuery() { + int count = 0; + try { + collector.checkThat(statement.execute(null), is(false)); + } catch (final SQLException e) { + count++; + collector.checkThat(e.getCause(), is(instanceOf(NullPointerException.class))); + } + collector.checkThat(count, is(1)); + } + + @Test + public void testShouldFailToPrepareStatementForClosedStatement() throws SQLException { + statement.close(); + collector.checkThat(statement.isClosed(), is(true)); + int count = 0; + try { + statement.execute(UPDATE_SAMPLE_QUERY); + } catch (final SQLException e) { + count++; + collector.checkThat(e.getMessage(), is("Statement closed")); + } + collector.checkThat(count, is(1)); + } + + @Test + public void testShouldFailToPrepareStatementForBadStatement() { + final String badQuery = "BAD INVALID STATEMENT"; + int count = 0; + try { + statement.execute(badQuery); + } catch (final SQLException e) { + count++; + /* + * The error message is up to whatever implementation of `FlightSqlProducer` + * the driver is communicating with. However, for the purpose of this test, + * we simply throw an `IllegalArgumentException` for queries not registered + * in our `MockFlightSqlProducer`. 
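+       * (For reference, the queries this suite does accept were registered up
+       * front in setUpBeforeClass(), e.g. via PRODUCER.addUpdateQuery(
+       * UPDATE_SAMPLE_QUERY, UPDATE_SAMPLE_QUERY_AFFECTED_COLS).)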
+       */
+      collector.checkThat(
+          e.getMessage(),
+          is(format("Error while executing SQL \"%s\": Query not found", badQuery)));
+    }
+    collector.checkThat(count, is(1));
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java
new file mode 100644
index 00000000000..6fe7ba71298
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTest.java
@@ -0,0 +1,552 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import static org.junit.Assert.assertNotNull;
+
+import java.net.URISyntaxException;
+import java.sql.Connection;
+import java.sql.Driver;
+import java.sql.DriverManager;
+import java.sql.SQLException;
+import java.util.Properties;
+
+import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication;
+import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
+import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+/**
+ * Tests for {@link Connection}.
+ */
+public class ConnectionTest {
+
+  @ClassRule
+  public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+  private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();
+  private static final String userTest = "user1";
+  private static final String passTest = "pass1";
+
+  static {
+    UserPasswordAuthentication authentication =
+        new UserPasswordAuthentication.Builder()
+            .user(userTest, passTest)
+            .build();
+
+    FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder().host("localhost").randomPort()
+        .authentication(authentication).producer(PRODUCER).build();
+  }
+
+  private BufferAllocator allocator;
+
+  @Before
+  public void setUp() throws Exception {
+    allocator = new RootAllocator(Long.MAX_VALUE);
+  }
+
+  @After
+  public void tearDown() throws Exception {
+    allocator.getChildAllocators().forEach(BufferAllocator::close);
+    AutoCloseables.close(allocator);
+  }
+
+  /**
+   * Checks if an unencrypted connection can be established successfully when
+   * provided with valid credentials.
+   *
+   * @throws SQLException on error.
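+   *
+   * <p>Shape of the connection URL used throughout these tests (illustrative;
+   * the actual host and port come from the server rule):
+   * <pre>{@code
+   * jdbc:arrow-flight-sql://localhost:32010?useEncryption=false
+   * }</pre>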
+ */
+  @Test
+  public void testUnencryptedConnectionShouldOpenSuccessfullyWhenProvidedValidCredentials()
+      throws Exception {
+    final Properties properties = new Properties();
+
+    properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost");
+    properties.put(ArrowFlightConnectionProperty.PORT.camelName(),
+        FLIGHT_SERVER_TEST_RULE.getPort());
+    properties.put(ArrowFlightConnectionProperty.USER.camelName(),
+        userTest);
+    properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+        passTest);
+    properties.put("useEncryption", false);
+
+    try (Connection connection = DriverManager.getConnection(
+        "jdbc:arrow-flight-sql://" + FLIGHT_SERVER_TEST_RULE.getHost() + ":" +
+            FLIGHT_SERVER_TEST_RULE.getPort(), properties)) {
+      assert connection.isValid(300);
+    }
+  }
+
+  /**
+   * Checks if the exception SQLException is thrown when trying to establish a connection without a host.
+   *
+   * @throws SQLException on error.
+   */
+  @Test(expected = SQLException.class)
+  public void testUnencryptedConnectionWithEmptyHost()
+      throws Exception {
+    final Properties properties = new Properties();
+
+    properties.put("user", userTest);
+    properties.put("password", passTest);
+    final String invalidUrl = "jdbc:arrow-flight-sql://";
+
+    DriverManager.getConnection(invalidUrl, properties);
+  }
+
+  /**
+   * Try to instantiate a basic FlightClient.
+   *
+   * @throws URISyntaxException on error.
+   */
+  @Test
+  public void testGetBasicClientAuthenticatedShouldOpenConnection()
+      throws Exception {
+
+    try (ArrowFlightSqlClientHandler client =
+             new ArrowFlightSqlClientHandler.Builder()
+                 .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+                 .withPort(FLIGHT_SERVER_TEST_RULE.getPort())
+                 .withUsername(userTest)
+                 .withPassword(passTest)
+                 .withBufferAllocator(allocator)
+                 .build()) {
+      assertNotNull(client);
+    }
+  }
+
+  /**
+   * Checks if an SQLException is thrown when trying to establish an unencrypted
+   * connection providing an invalid port.
+   *
+   * @throws SQLException on error.
+   */
+  @Test(expected = SQLException.class)
+  public void testUnencryptedConnectionProvidingInvalidPort()
+      throws Exception {
+    final Properties properties = new Properties();
+
+    properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost");
+    properties.put(ArrowFlightConnectionProperty.USER.camelName(),
+        userTest);
+    properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+        passTest);
+    properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(),
+        false);
+    final String invalidUrl = "jdbc:arrow-flight-sql://" + FLIGHT_SERVER_TEST_RULE.getHost() +
+        ":" + 65537;
+
+    DriverManager.getConnection(invalidUrl, properties);
+  }
+
+  /**
+   * Try to instantiate a basic FlightClient.
+   *
+   * @throws URISyntaxException on error.
+   */
+  @Test
+  public void testGetBasicClientNoAuthShouldOpenConnection() throws Exception {
+
+    try (ArrowFlightSqlClientHandler client =
+             new ArrowFlightSqlClientHandler.Builder()
+                 .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+                 .withBufferAllocator(allocator)
+                 .build()) {
+      assertNotNull(client);
+    }
+  }
+
+  /**
+   * Checks if an unencrypted connection can be established successfully when
+   * not providing credentials.
+   *
+   * @throws SQLException on error.
+ */ + @Test + public void testUnencryptedConnectionShouldOpenSuccessfullyWithoutAuthentication() + throws Exception { + final Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), + false); + try (Connection connection = DriverManager + .getConnection("jdbc:arrow-flight-sql://localhost:32010", properties)) { + assert connection.isValid(300); + } + } + + /** + * Check if an unencrypted connection throws an exception when provided with + * invalid credentials. + * + * @throws SQLException The exception expected to be thrown. + */ + @Test(expected = SQLException.class) + public void testUnencryptedConnectionShouldThrowExceptionWhenProvidedWithInvalidCredentials() + throws Exception { + + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost"); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + "invalidUser"); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), + false); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + "invalidPassword"); + + try (Connection ignored = DriverManager.getConnection("jdbc:arrow-flight-sql://localhost:32010", + properties)) { + Assert.fail(); + } + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using just a connection url. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseCorrectCastUrlWithDriverManager() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&useEncryption=false", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest)); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with String K-V pairs. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "false"); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with Object K-V pairs. + * + * @throws Exception on error. 
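+   *
+   * <p>The "Object K-V" variants exist because {@code Properties.put} accepts
+   * non-String values, which the driver must coerce just like their String forms:
+   * <pre>{@code
+   * properties.setProperty("useEncryption", "false"); // String form
+   * properties.put("useEncryption", false);           // Object form, same effect
+   * }</pre>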
+ */ + @Test + public void testTLSConnectionPropertyFalseCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using just a connection url and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseIntegerCorrectCastUrlWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&useEncryption=0", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest)); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with String K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyFalseIntegerCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "0"); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an non-encrypted connection can be established successfully when connecting through + * the DriverManager using a connection url and properties with Object K-V pairs and using + * 0 and 1 as ssl values. + * + * @throws Exception on error. 
+ */
+ @Test
+ public void testTLSConnectionPropertyFalseIntegerCorrectCastUrlAndPropertiesUsingPutWithDriverManager()
+ throws Exception {
+ final Driver driver = new ArrowFlightJdbcDriver();
+ DriverManager.registerDriver(driver);
+
+ Properties properties = new Properties();
+ properties.put(ArrowFlightConnectionProperty.USER.camelName(),
+ userTest);
+ properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+ passTest);
+ properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), 0);
+
+ Connection connection = DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://localhost:%s",
+ FLIGHT_SERVER_TEST_RULE.getPort()),
+ properties);
+ Assert.assertTrue(connection.isValid(0));
+ connection.close();
+ }
+
+ /**
+ * Check if a non-encrypted connection can be established successfully when connecting through
+ * the DriverManager using just a connection url that also sets the thread pool size.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testThreadPoolSizeConnectionPropertyCorrectCastUrlWithDriverManager()
+ throws Exception {
+ final Driver driver = new ArrowFlightJdbcDriver();
+ DriverManager.registerDriver(driver);
+
+ Connection connection = DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&threadPoolSize=1&useEncryption=%s",
+ FLIGHT_SERVER_TEST_RULE.getPort(),
+ userTest,
+ passTest,
+ false));
+ Assert.assertTrue(connection.isValid(0));
+ connection.close();
+ }
+
+ /**
+ * Check if a non-encrypted connection can be established successfully when connecting through
+ * the DriverManager using a connection url and properties with String K-V pairs, including
+ * the thread pool size.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testThreadPoolSizeConnectionPropertyCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager()
+ throws Exception {
+ final Driver driver = new ArrowFlightJdbcDriver();
+ DriverManager.registerDriver(driver);
+ Properties properties = new Properties();
+
+ properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(),
+ userTest);
+ properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+ passTest);
+ properties.setProperty(ArrowFlightConnectionProperty.THREAD_POOL_SIZE.camelName(), "1");
+ properties.put("useEncryption", false);
+
+ Connection connection = DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://localhost:%s",
+ FLIGHT_SERVER_TEST_RULE.getPort()),
+ properties);
+ Assert.assertTrue(connection.isValid(0));
+ connection.close();
+ }
+
+ /**
+ * Check if a non-encrypted connection can be established successfully when connecting through
+ * the DriverManager using a connection url and properties with Object K-V pairs, including
+ * the thread pool size.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testThreadPoolSizeConnectionPropertyCorrectCastUrlAndPropertiesUsingPutWithDriverManager()
+ throws Exception {
+ final Driver driver = new ArrowFlightJdbcDriver();
+ DriverManager.registerDriver(driver);
+
+ Properties properties = new Properties();
+ properties.put(ArrowFlightConnectionProperty.USER.camelName(),
+ userTest);
+ properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+ passTest);
+ properties.put(ArrowFlightConnectionProperty.THREAD_POOL_SIZE.camelName(), 1);
+ properties.put("useEncryption", false);
+
+ Connection connection = DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://localhost:%s",
+ FLIGHT_SERVER_TEST_RULE.getPort()),
+ properties);
+ Assert.assertTrue(connection.isValid(0));
+ connection.close();
+ }
+
+ /**
+ * Check if a non-encrypted connection can be established successfully when connecting through
+ * the DriverManager using just a connection url.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testPasswordConnectionPropertyIntegerCorrectCastUrlWithDriverManager()
+ throws Exception {
+ final Driver driver = new ArrowFlightJdbcDriver();
+ DriverManager.registerDriver(driver);
+
+ Connection connection = DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s&useEncryption=%s",
+ FLIGHT_SERVER_TEST_RULE.getPort(),
+ userTest,
+ passTest,
+ false));
+ Assert.assertTrue(connection.isValid(0));
+ connection.close();
+ }
+
+ /**
+ * Check if a non-encrypted connection can be established successfully when connecting through
+ * the DriverManager using a connection url and properties with String K-V pairs.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testPasswordConnectionPropertyIntegerCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager()
+ throws Exception {
+ final Driver driver = new ArrowFlightJdbcDriver();
+ DriverManager.registerDriver(driver);
+ Properties properties = new Properties();
+
+ properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(),
+ userTest);
+ properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+ passTest);
+ properties.put("useEncryption", false);
+
+ Connection connection = DriverManager.getConnection(
+ String.format(
+ "jdbc:arrow-flight-sql://localhost:%s",
+ FLIGHT_SERVER_TEST_RULE.getPort()),
+ properties);
+ Assert.assertTrue(connection.isValid(0));
+ connection.close();
+ }
+
+ /**
+ * Check if a non-encrypted connection can be established successfully when connecting through
+ * the DriverManager using a connection url and properties with Object K-V pairs.
+ *
+ * @throws Exception on error.
+ */ + @Test + public void testPasswordConnectionPropertyIntegerCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put("useEncryption", false); + + Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java new file mode 100644 index 00000000000..a5f9938f04b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ConnectionTlsTest.java @@ -0,0 +1,454 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.junit.Assert.assertNotNull; + +import java.net.URLEncoder; +import java.nio.file.Paths; +import java.sql.Connection; +import java.sql.Driver; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates; +import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.apache.calcite.avatica.org.apache.http.auth.UsernamePasswordCredentials; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Test; + +/** + * Tests encrypted connections. 
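+ *
+ * <p>A typical encrypted connection in these tests is configured through connection
+ * properties (a sketch only; the keys shown are assumed to be the {@code camelName()}
+ * values of {@code ArrowFlightConnectionProperty}):</p>
+ * <pre>{@code
+ * Properties properties = new Properties();
+ * properties.put("useEncryption", true);
+ * properties.put("trustStore", "/path/to/keyStore.jks");
+ * properties.put("trustStorePassword", "flight");
+ * }</pre>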
+ */
+public class ConnectionTlsTest {
+
+ @ClassRule
+ public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+ private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();
+ private static final String userTest = "user1";
+ private static final String passTest = "pass1";
+
+ static {
+ final FlightSqlTestCertificates.CertKeyPair
+ certKey = FlightSqlTestCertificates.exampleTlsCerts().get(0);
+
+ UserPasswordAuthentication authentication = new UserPasswordAuthentication.Builder()
+ .user(userTest, passTest)
+ .build();
+
+ FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder()
+ .host("localhost")
+ .randomPort()
+ .authentication(authentication)
+ .useEncryption(certKey.cert, certKey.key)
+ .producer(PRODUCER)
+ .build();
+ }
+
+ private String trustStorePath;
+ private String noCertificateKeyStorePath;
+ private final String trustStorePass = "flight";
+ private BufferAllocator allocator;
+
+ @Before
+ public void setUp() throws Exception {
+ trustStorePath = Paths.get(
+ Preconditions.checkNotNull(getClass().getResource("/keys/keyStore.jks")).toURI()).toString();
+ noCertificateKeyStorePath = Paths.get(
+ Preconditions.checkNotNull(getClass().getResource("/keys/noCertificate.jks")).toURI()).toString();
+ allocator = new RootAllocator(Long.MAX_VALUE);
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ allocator.getChildAllocators().forEach(BufferAllocator::close);
+ AutoCloseables.close(allocator);
+ }
+
+ /**
+ * Try to instantiate an encrypted FlightClient with certificate verification disabled.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testGetEncryptedClientAuthenticatedWithDisableCertVerification() throws Exception {
+ final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(
+ userTest, passTest);
+
+ try (ArrowFlightSqlClientHandler client =
+ new ArrowFlightSqlClientHandler.Builder()
+ .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+ .withPort(FLIGHT_SERVER_TEST_RULE.getPort())
+ .withUsername(credentials.getUserName())
+ .withPassword(credentials.getPassword())
+ .withDisableCertificateVerification(true)
+ .withBufferAllocator(allocator)
+ .withEncryption(true)
+ .build()) {
+ assertNotNull(client);
+ }
+ }
+
+ /**
+ * Try to instantiate an encrypted FlightClient using a trust store.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testGetEncryptedClientAuthenticated() throws Exception {
+ final UsernamePasswordCredentials credentials = new UsernamePasswordCredentials(
+ userTest, passTest);
+
+ try (ArrowFlightSqlClientHandler client =
+ new ArrowFlightSqlClientHandler.Builder()
+ .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+ .withPort(FLIGHT_SERVER_TEST_RULE.getPort())
+ .withUsername(credentials.getUserName())
+ .withPassword(credentials.getPassword())
+ .withTrustStorePath(trustStorePath)
+ .withTrustStorePassword(trustStorePass)
+ .withBufferAllocator(allocator)
+ .withEncryption(true)
+ .build()) {
+ assertNotNull(client);
+ }
+ }
+
+ /**
+ * Try to instantiate an encrypted FlightClient providing a keystore without a certificate.
+ * An SQLException is expected.
+ *
+ * @throws Exception on error.
+ */
+ @Test(expected = SQLException.class)
+ public void testGetEncryptedClientWithNoCertificateOnKeyStore() throws Exception {
+ final String noCertificateKeyStorePassword = "flight1";
+
+ try (ArrowFlightSqlClientHandler ignored =
+ new ArrowFlightSqlClientHandler.Builder()
+ .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+ .withTrustStorePath(noCertificateKeyStorePath)
+ .withTrustStorePassword(noCertificateKeyStorePassword)
+ .withBufferAllocator(allocator)
+ .withEncryption(true)
+ .build()) {
+ Assert.fail();
+ }
+ }
+
+ /**
+ * Try to instantiate an encrypted FlightClient without credentials.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testGetNonAuthenticatedEncryptedClientNoAuth() throws Exception {
+ try (ArrowFlightSqlClientHandler client =
+ new ArrowFlightSqlClientHandler.Builder()
+ .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+ .withTrustStorePath(trustStorePath)
+ .withTrustStorePassword(trustStorePass)
+ .withBufferAllocator(allocator)
+ .withEncryption(true)
+ .build()) {
+ assertNotNull(client);
+ }
+ }
+
+ /**
+ * Try to instantiate an encrypted FlightClient with an invalid password for the keystore file.
+ * An SQLException is expected.
+ *
+ * @throws Exception on error.
+ */
+ @Test(expected = SQLException.class)
+ public void testGetEncryptedClientWithKeyStoreBadPasswordAndNoAuth() throws Exception {
+ String keyStoreBadPassword = "badPassword";
+
+ try (ArrowFlightSqlClientHandler ignored =
+ new ArrowFlightSqlClientHandler.Builder()
+ .withHost(FLIGHT_SERVER_TEST_RULE.getHost())
+ .withTrustStorePath(trustStorePath)
+ .withTrustStorePassword(keyStoreBadPassword)
+ .withBufferAllocator(allocator)
+ .withEncryption(true)
+ .build()) {
+ Assert.fail();
+ }
+ }
+
+ /**
+ * Check if an encrypted connection can be established successfully when provided with
+ * valid credentials and a valid KeyStore.
+ *
+ * @throws Exception on error.
+ */
+ @Test
+ public void testGetEncryptedConnectionWithValidCredentialsAndKeyStore() throws Exception {
+ final Properties properties = new Properties();
+
+ properties.put(ArrowFlightConnectionProperty.HOST.camelName(), "localhost");
+ properties.put(ArrowFlightConnectionProperty.PORT.camelName(),
+ FLIGHT_SERVER_TEST_RULE.getPort());
+ properties.put(ArrowFlightConnectionProperty.USER.camelName(),
+ userTest);
+ properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(),
+ passTest);
+ properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath);
+ properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), false);
+ properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass);
+
+ final ArrowFlightJdbcDataSource dataSource =
+ ArrowFlightJdbcDataSource.createNewDataSource(properties);
+ try (final Connection connection = dataSource.getConnection()) {
+ assert connection.isValid(300);
+ }
+ }
+
+ /**
+ * Check if the SQLException is thrown when trying to establish an encrypted connection
+ * providing valid credentials but an invalid password for the KeyStore.
+ *
+ * @throws SQLException on error.
+ */ + @Test(expected = SQLException.class) + public void testGetAuthenticatedEncryptedConnectionWithKeyStoreBadPassword() throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), + FLIGHT_SERVER_TEST_RULE.getHost()); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), + FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USER.camelName(), + userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), + passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), true); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), "badpassword"); + + final ArrowFlightJdbcDataSource dataSource = + ArrowFlightJdbcDataSource.createNewDataSource(properties); + try (final Connection ignored = dataSource.getConnection()) { + Assert.fail(); + } + } + + /** + * Check if an encrypted connection can be established successfully when not providing authentication. + * + * @throws Exception on error. + */ + @Test + public void testGetNonAuthenticatedEncryptedConnection() throws Exception { + final Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.HOST.camelName(), FLIGHT_SERVER_TEST_RULE.getHost()); + properties.put(ArrowFlightConnectionProperty.PORT.camelName(), FLIGHT_SERVER_TEST_RULE.getPort()); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), true); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), false); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final ArrowFlightJdbcDataSource dataSource = ArrowFlightJdbcDataSource.createNewDataSource(properties); + try (final Connection connection = dataSource.getConnection()) { + assert connection.isValid(300); + } + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * just a connection url. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueCorrectCastUrlWithDriverManager() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s" + + "&useEncryption=true&useSystemTrustStore=false&%s=%s&%s=%s", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest, + ArrowFlightConnectionProperty.TRUST_STORE.camelName(), + URLEncoder.encode(trustStorePath, "UTF-8"), + ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), + URLEncoder.encode(trustStorePass, "UTF-8"))); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with String K-V pairs. + * + * @throws Exception on error. 
+ */ + @Test + public void testTLSConnectionPropertyTrueCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "true"); + properties.setProperty(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), "false"); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with Object K-V pairs. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), true); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), false); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * just a connection url and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueIntegerCorrectCastUrlWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + final Connection connection = DriverManager.getConnection( + String.format( + "jdbc:arrow-flight-sql://localhost:%s?user=%s&password=%s" + + "&useEncryption=1&useSystemTrustStore=0&%s=%s&%s=%s", + FLIGHT_SERVER_TEST_RULE.getPort(), + userTest, + passTest, + ArrowFlightConnectionProperty.TRUST_STORE.camelName(), + URLEncoder.encode(trustStorePath, "UTF-8"), + ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), + URLEncoder.encode(trustStorePass, "UTF-8"))); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with String K-V pairs and using 0 and 1 as ssl values. 
+ * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueIntegerCorrectCastUrlAndPropertiesUsingSetPropertyWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.setProperty(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.setProperty(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.setProperty(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + properties.setProperty(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), "1"); + properties.setProperty(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), "0"); + + final Connection connection = DriverManager.getConnection( + String.format("jdbc:arrow-flight-sql://localhost:%s", FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } + + /** + * Check if an encrypted connection can be established successfully when connecting through the DriverManager using + * a connection url and properties with Object K-V pairs and using 0 and 1 as ssl values. + * + * @throws Exception on error. + */ + @Test + public void testTLSConnectionPropertyTrueIntegerCorrectCastUrlAndPropertiesUsingPutWithDriverManager() + throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + DriverManager.registerDriver(driver); + + Properties properties = new Properties(); + + properties.put(ArrowFlightConnectionProperty.USER.camelName(), userTest); + properties.put(ArrowFlightConnectionProperty.PASSWORD.camelName(), passTest); + properties.put(ArrowFlightConnectionProperty.USE_ENCRYPTION.camelName(), 1); + properties.put(ArrowFlightConnectionProperty.USE_SYSTEM_TRUST_STORE.camelName(), 0); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE.camelName(), trustStorePath); + properties.put(ArrowFlightConnectionProperty.TRUST_STORE_PASSWORD.camelName(), trustStorePass); + + final Connection connection = DriverManager.getConnection( + String.format("jdbc:arrow-flight-sql://localhost:%s", + FLIGHT_SERVER_TEST_RULE.getPort()), + properties); + Assert.assertTrue(connection.isValid(0)); + connection.close(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java new file mode 100644 index 00000000000..b251b7df164 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/FlightServerTestRule.java @@ -0,0 +1,365 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static org.apache.arrow.driver.jdbc.utils.FlightSqlTestCertificates.CertKeyPair; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.Method; +import java.sql.Connection; +import java.sql.SQLException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.authentication.Authentication; +import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication; +import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallInfo; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightServerMiddleware; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.RequestContext; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.util.Preconditions; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import me.alexpanov.net.FreePortFinder; + +/** + * Utility class for unit tests that need to instantiate a {@link FlightServer} + * and interact with it. + */ +public class FlightServerTestRule implements TestRule, AutoCloseable { + private static final Logger LOGGER = LoggerFactory.getLogger(FlightServerTestRule.class); + + private final Properties properties; + private final ArrowFlightConnectionConfigImpl config; + private final BufferAllocator allocator; + private final FlightSqlProducer producer; + private final Authentication authentication; + private final CertKeyPair certKeyPair; + + private final MiddlewareCookie.Factory middlewareCookieFactory = new MiddlewareCookie.Factory(); + + private FlightServerTestRule(final Properties properties, + final ArrowFlightConnectionConfigImpl config, + final BufferAllocator allocator, + final FlightSqlProducer producer, + final Authentication authentication, + final CertKeyPair certKeyPair) { + this.properties = Preconditions.checkNotNull(properties); + this.config = Preconditions.checkNotNull(config); + this.allocator = Preconditions.checkNotNull(allocator); + this.producer = Preconditions.checkNotNull(producer); + this.authentication = authentication; + this.certKeyPair = certKeyPair; + } + + /** + * Create a {@link FlightServerTestRule} with standard values such as: user, password, localhost. + * + * @param producer the producer used to create the FlightServerTestRule. + * @return the FlightServerTestRule. 
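+ *
+ * <p>A minimal usage sketch, mirroring the test classes in this patch:</p>
+ * <pre>{@code
+ * // registered on the test class as a @ClassRule
+ * FlightServerTestRule rule = FlightServerTestRule.createStandardTestRule(producer);
+ * Connection connection = rule.getConnection(false);
+ * }</pre>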
+ */
+ public static FlightServerTestRule createStandardTestRule(final FlightSqlProducer producer) {
+ UserPasswordAuthentication authentication =
+ new UserPasswordAuthentication.Builder()
+ .user("flight-test-user", "flight-test-password")
+ .build();
+
+ return new Builder()
+ .host("localhost")
+ .randomPort()
+ .authentication(authentication)
+ .producer(producer)
+ .build();
+ }
+
+ ArrowFlightJdbcDataSource createDataSource() {
+ return ArrowFlightJdbcDataSource.createNewDataSource(properties);
+ }
+
+ ArrowFlightJdbcDataSource createDataSource(String token) {
+ properties.put("token", token);
+ return ArrowFlightJdbcDataSource.createNewDataSource(properties);
+ }
+
+ public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource() {
+ return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties);
+ }
+
+ public ArrowFlightJdbcConnectionPoolDataSource createConnectionPoolDataSource(boolean useEncryption) {
+ setUseEncryption(useEncryption);
+ return ArrowFlightJdbcConnectionPoolDataSource.createNewDataSource(properties);
+ }
+
+ public Connection getConnection(boolean useEncryption, String token) throws SQLException {
+ properties.put("token", token);
+
+ return getConnection(useEncryption);
+ }
+
+ public Connection getConnection(boolean useEncryption) throws SQLException {
+ setUseEncryption(useEncryption);
+ return this.createDataSource().getConnection();
+ }
+
+ private void setUseEncryption(boolean useEncryption) {
+ properties.put("useEncryption", useEncryption);
+ }
+
+ public MiddlewareCookie.Factory getMiddlewareCookieFactory() {
+ return middlewareCookieFactory;
+ }
+
+ @FunctionalInterface
+ public interface CheckedFunction<T, R> {
+ R apply(T t) throws IOException;
+ }
+
+ private FlightServer initiateServer(Location location) throws IOException {
+ FlightServer.Builder builder = FlightServer.builder(allocator, location, producer)
+ .headerAuthenticator(authentication.authenticate())
+ .middleware(FlightServerMiddleware.Key.of("KEY"), middlewareCookieFactory);
+ if (certKeyPair != null) {
+ builder.useTls(certKeyPair.cert, certKeyPair.key);
+ }
+ return builder.build();
+ }
+
+ @Override
+ public Statement apply(Statement base, Description description) {
+ return new Statement() {
+ @Override
+ public void evaluate() throws Throwable {
+ try (FlightServer flightServer =
+ getStartServer(location ->
+ initiateServer(location), 3)) {
+ LOGGER.info("Started " + FlightServer.class.getName() + " as " + flightServer);
+ base.evaluate();
+ } finally {
+ close();
+ }
+ }
+ };
+ }
+
+ private FlightServer getStartServer(CheckedFunction<Location, FlightServer> newServerFromLocation,
+ int retries)
+ throws IOException {
+
+ final Deque<Exception> exceptions = new ArrayDeque<>();
+
+ for (; retries > 0; retries--) {
+ final Location location = Location.forGrpcInsecure(config.getHost(), config.getPort());
+ final FlightServer server = newServerFromLocation.apply(location);
+ try {
+ // Start the server reflectively so a failed attempt surfaces as a
+ // ReflectiveOperationException that can be recorded and retried below.
+ Method start = server.getClass().getMethod("start");
+ start.setAccessible(true);
+ start.invoke(server);
+ return server;
+ } catch (ReflectiveOperationException e) {
+ exceptions.add(e);
+ }
+ }
+
+ exceptions.forEach(
+ e -> LOGGER.error("Failed to start a new " + FlightServer.class.getName() + ".", e));
+ throw new IOException(exceptions.pop().getCause());
+ }
+
+ /**
+ * Gets the port used by the server rule.
+ *
+ * @return the port value.
+ */
+ public int getPort() {
+ return config.getPort();
+ }
+
+ /**
+ * Gets the host used by the server rule.
+ *
+ * @return the host value.
+ */
+ public String getHost() {
+ return config.getHost();
+ }
+
+ @Override
+ public void close() throws Exception {
+ allocator.getChildAllocators().forEach(BufferAllocator::close);
+ AutoCloseables.close(allocator);
+ }
+
+ /**
+ * Builder for {@link FlightServerTestRule}.
+ */
+ public static final class Builder {
+ private final Properties properties = new Properties();
+ private FlightSqlProducer producer;
+ private Authentication authentication;
+ private CertKeyPair certKeyPair;
+
+ /**
+ * Sets the host for the server rule.
+ *
+ * @param host the host value.
+ * @return the Builder.
+ */
+ public Builder host(final String host) {
+ properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST.camelName(),
+ host);
+ return this;
+ }
+
+ /**
+ * Sets a random port to be used by the server rule.
+ *
+ * @return the Builder.
+ */
+ public Builder randomPort() {
+ properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT.camelName(),
+ FreePortFinder.findFreeLocalPort());
+ return this;
+ }
+
+ /**
+ * Sets a specific port to be used by the server rule.
+ *
+ * @param port the port value.
+ * @return the Builder.
+ */
+ public Builder port(final int port) {
+ properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT.camelName(),
+ port);
+ return this;
+ }
+
+ /**
+ * Sets the producer that will be used in the server rule.
+ *
+ * @param producer the flight sql producer.
+ * @return the Builder.
+ */
+ public Builder producer(final FlightSqlProducer producer) {
+ this.producer = producer;
+ return this;
+ }
+
+ /**
+ * Sets the type of authentication that will be used in the server rule.
+ * There are two types of authentication: {@link UserPasswordAuthentication} and
+ * {@link TokenAuthentication}.
+ *
+ * @param authentication the type of authentication.
+ * @return the Builder.
+ */
+ public Builder authentication(final Authentication authentication) {
+ this.authentication = authentication;
+ return this;
+ }
+
+ /**
+ * Enable TLS on the server.
+ *
+ * @param certChain The certificate chain to use.
+ * @param key The private key to use.
+ * @return the Builder.
+ */
+ public Builder useEncryption(final File certChain, final File key) {
+ certKeyPair = new CertKeyPair(certChain, key);
+ return this;
+ }
+
+ /**
+ * Builds the {@link FlightServerTestRule} using the provided values.
+ *
+ * @return a {@link FlightServerTestRule}.
+ */
+ public FlightServerTestRule build() {
+ authentication.populateProperties(properties);
+ return new FlightServerTestRule(properties, new ArrowFlightConnectionConfigImpl(properties),
+ new RootAllocator(Long.MAX_VALUE), producer, authentication, certKeyPair);
+ }
+ }
+
+ /**
+ * A middleware to handle cookies on the server. It is used to test that cookies are
+ * being sent properly.
+ */
+ static class MiddlewareCookie implements FlightServerMiddleware {
+
+ private final Factory factory;
+
+ public MiddlewareCookie(Factory factory) {
+ this.factory = factory;
+ }
+
+ @Override
+ public void onBeforeSendingHeaders(CallHeaders callHeaders) {
+ if (!factory.receivedCookieHeader) {
+ callHeaders.insert("Set-Cookie", "k=v");
+ }
+ }
+
+ @Override
+ public void onCallCompleted(CallStatus callStatus) {
+
+ }
+
+ @Override
+ public void onCallErrored(Throwable throwable) {
+
+ }
+
+ /**
+ * A factory for the MiddlewareCookie.
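+ * It records the last {@code Cookie} header received by the server, so tests can
+ * assert the cookie round-trip through {@code getMiddlewareCookieFactory().getCookie()}.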
+ */
+ static class Factory implements FlightServerMiddleware.Factory<MiddlewareCookie> {
+
+ private boolean receivedCookieHeader = false;
+ private String cookie;
+
+ @Override
+ public MiddlewareCookie onCallStarted(CallInfo callInfo, CallHeaders callHeaders,
+ RequestContext requestContext) {
+ cookie = callHeaders.get("Cookie");
+ receivedCookieHeader = null != cookie;
+ return new MiddlewareCookie(this);
+ }
+
+ public String getCookie() {
+ return cookie;
+ }
+ }
+ }
+
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetMetadataTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetMetadataTest.java
new file mode 100644
index 00000000000..64ec7f7d9e1
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetMetadataTest.java
@@ -0,0 +1,236 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.notNullValue;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.sql.Statement;
+import java.sql.Types;
+
+import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
+import org.hamcrest.CoreMatchers;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+public class ResultSetMetadataTest {
+ private static ResultSetMetaData metadata;
+
+ private static Connection connection;
+
+ @Rule
+ public ErrorCollector collector = new ErrorCollector();
+
+ @ClassRule
+ public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule
+ .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer());
+
+ @BeforeClass
+ public static void setup() throws SQLException {
+ connection = SERVER_TEST_RULE.getConnection(false);
+
+ try (Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_METADATA_SQL_CMD)) {
+ metadata = resultSet.getMetaData();
+ }
+ }
+
+ @AfterClass
+ public static void teardown() throws SQLException {
+ connection.close();
+ }
+
+ /**
+ * Test if {@link ResultSetMetaData} object is not null.
+ */
+ @Test
+ public void testShouldGetResultSetMetadata() {
+ collector.checkThat(metadata, CoreMatchers.is(notNullValue()));
+ }
+
+ /**
+ * Test if {@link ResultSetMetaData#getColumnCount()} returns the correct value.
+ *
+ * @throws SQLException in case of error.
+ */
+ @Test
+ public void testShouldGetColumnCount() throws SQLException {
+ final int columnCount = metadata.getColumnCount();
+
+ assert columnCount == 3;
+ }
+
+ /**
+ * Test if {@link ResultSetMetaData#getColumnTypeName(int)} returns the correct type name for each
+ * column.
+ *
+ * @throws SQLException in case of error.
+ */
+ @Test
+ public void testShouldGetColumnTypesName() throws SQLException {
+ final String firstColumn = metadata.getColumnTypeName(1);
+ final String secondColumn = metadata.getColumnTypeName(2);
+ final String thirdColumn = metadata.getColumnTypeName(3);
+
+ collector.checkThat(firstColumn, equalTo("BIGINT"));
+ collector.checkThat(secondColumn, equalTo("VARCHAR"));
+ collector.checkThat(thirdColumn, equalTo("FLOAT"));
+ }
+
+ /**
+ * Test {@link ResultSetMetaData#getColumnTypeName(int)} when passed a column index that does not exist.
+ *
+ * @throws SQLException in case of error.
+ */
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testShouldGetColumnTypesNameFromOutOfBoundIndex() throws SQLException {
+ metadata.getColumnTypeName(4);
+
+ Assert.fail();
+ }
+
+ /**
+ * Test if {@link ResultSetMetaData#getColumnName(int)} returns the correct name for each column.
+ *
+ * @throws SQLException in case of error.
+ */
+ @Test
+ public void testShouldGetColumnNames() throws SQLException {
+ final String firstColumn = metadata.getColumnName(1);
+ final String secondColumn = metadata.getColumnName(2);
+ final String thirdColumn = metadata.getColumnName(3);
+
+ collector.checkThat(firstColumn, equalTo("integer0"));
+ collector.checkThat(secondColumn, equalTo("string1"));
+ collector.checkThat(thirdColumn, equalTo("float2"));
+ }
+
+
+ /**
+ * Test {@link ResultSetMetaData#getColumnName(int)} when passed a column index that does not exist.
+ *
+ * @throws SQLException in case of error.
+ */
+ @Test(expected = IndexOutOfBoundsException.class)
+ public void testShouldGetColumnNameFromOutOfBoundIndex() throws SQLException {
+ metadata.getColumnName(4);
+
+ Assert.fail();
+ }
+
+ /**
+ * Test if {@link ResultSetMetaData#getColumnType(int)} returns the correct values.
+ *
+ * @throws SQLException in case of error.
+ */
+ @Test
+ public void testShouldGetColumnType() throws SQLException {
+ final int firstColumn = metadata.getColumnType(1);
+ final int secondColumn = metadata.getColumnType(2);
+ final int thirdColumn = metadata.getColumnType(3);
+
+ collector.checkThat(firstColumn, equalTo(Types.BIGINT));
+ collector.checkThat(secondColumn, equalTo(Types.VARCHAR));
+ collector.checkThat(thirdColumn, equalTo(Types.FLOAT));
+ }
+
+ @Test
+ public void testShouldGetPrecision() throws SQLException {
+ collector.checkThat(metadata.getPrecision(1), equalTo(10));
+ collector.checkThat(metadata.getPrecision(2), equalTo(65535));
+ collector.checkThat(metadata.getPrecision(3), equalTo(15));
+ }
+
+ @Test
+ public void testShouldGetScale() throws SQLException {
+ collector.checkThat(metadata.getScale(1), equalTo(0));
+ collector.checkThat(metadata.getScale(2), equalTo(0));
+ collector.checkThat(metadata.getScale(3), equalTo(20));
+ }
+
+ @Test
+ public void testShouldGetCatalogName() throws SQLException {
+ collector.checkThat(metadata.getCatalogName(1), equalTo("CATALOG_NAME_1"));
+ collector.checkThat(metadata.getCatalogName(2), equalTo("CATALOG_NAME_2"));
+ collector.checkThat(metadata.getCatalogName(3), equalTo("CATALOG_NAME_3"));
+ }
+
+ @Test
+ public void testShouldGetSchemaName() throws SQLException {
+ collector.checkThat(metadata.getSchemaName(1), equalTo("SCHEMA_NAME_1"));
+ collector.checkThat(metadata.getSchemaName(2), equalTo("SCHEMA_NAME_2"));
+ collector.checkThat(metadata.getSchemaName(3), equalTo("SCHEMA_NAME_3"));
+ }
+
+ @Test
+ public void testShouldGetTableName() throws SQLException {
+ collector.checkThat(metadata.getTableName(1), equalTo("TABLE_NAME_1"));
+ collector.checkThat(metadata.getTableName(2), equalTo("TABLE_NAME_2"));
+ collector.checkThat(metadata.getTableName(3), equalTo("TABLE_NAME_3"));
+ }
+
+ @Test
+ public void testShouldIsAutoIncrement() throws SQLException {
+ collector.checkThat(metadata.isAutoIncrement(1), equalTo(true));
+ collector.checkThat(metadata.isAutoIncrement(2), equalTo(false));
+ collector.checkThat(metadata.isAutoIncrement(3), equalTo(false));
+ }
+
+ @Test
+ public void testShouldIsCaseSensitive() throws SQLException {
+ collector.checkThat(metadata.isCaseSensitive(1), equalTo(false));
+ collector.checkThat(metadata.isCaseSensitive(2), equalTo(true));
+ collector.checkThat(metadata.isCaseSensitive(3), equalTo(false));
+ }
+
+ @Test
+ public void testShouldIsReadonly() throws SQLException {
+ collector.checkThat(metadata.isReadOnly(1), equalTo(true));
+ collector.checkThat(metadata.isReadOnly(2), equalTo(false));
+ collector.checkThat(metadata.isReadOnly(3), equalTo(false));
+ }
+
+ @Test
+ public void testShouldIsSearchable() throws SQLException {
+ collector.checkThat(metadata.isSearchable(1), equalTo(true));
+ collector.checkThat(metadata.isSearchable(2), equalTo(true));
+ collector.checkThat(metadata.isSearchable(3), equalTo(true));
+ }
+
+ /**
+ * Test {@link ResultSetMetaData#getColumnType(int)} when passed a column index that does not exist.
+ *
+ * @throws SQLException in case of error.
+ */ + @Test(expected = IndexOutOfBoundsException.class) + public void testShouldGetColumnTypesFromOutOfBoundIndex() throws SQLException { + metadata.getColumnType(4); + + Assert.fail(); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java new file mode 100644 index 00000000000..33473b6fe2b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ResultSetTest.java @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc; + +import static java.lang.String.format; +import static java.util.Collections.synchronizedSet; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.SQLTimeoutException; +import java.sql.Statement; +import java.util.HashSet; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; + +import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +import com.google.common.collect.ImmutableSet; + +public class ResultSetTest { + private static final Random RANDOM = new Random(10); + @ClassRule + public static final FlightServerTestRule SERVER_TEST_RULE = FlightServerTestRule + .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer()); + private static Connection connection; + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @BeforeClass + public static void setup() throws SQLException { + connection = SERVER_TEST_RULE.getConnection(false); + } + + @AfterClass + public static void tearDown() throws SQLException { + connection.close(); + } + + private static void resultSetNextUntilDone(ResultSet resultSet) throws SQLException { + while (resultSet.next()) { + // TODO: implement resultSet.last() + // Pass to the next until resultSet is done + } + } + + private static void setMaxRowsLimit(int maxRowsLimit, Statement statement) throws SQLException { + statement.setLargeMaxRows(maxRowsLimit); + } + + /** + * Tests whether the {@link ArrowFlightJdbcDriver} can run a query successfully. 
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test
+ public void testShouldRunSelectQuery() throws Exception {
+ try (Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+ CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector);
+ }
+ }
+
+ @Test
+ public void testShouldExecuteQueryNotBlockIfClosedBeforeEnd() throws Exception {
+ try (Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+
+ for (int i = 0; i < 7500; i++) {
+ assertTrue(resultSet.next());
+ }
+ }
+ }
+
+ /**
+ * Tests whether the {@link ArrowFlightJdbcDriver} query returns only the
+ * number of rows set by {@link org.apache.calcite.avatica.AvaticaStatement#setMaxRows(int)}.
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test
+ public void testShouldRunSelectQuerySettingMaxRowLimit() throws Exception {
+ try (Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+
+ final int maxRowsLimit = 3;
+ statement.setMaxRows(maxRowsLimit);
+
+ collector.checkThat(statement.getMaxRows(), is(maxRowsLimit));
+
+ int count = 0;
+ int columns = 6;
+ for (; resultSet.next(); count++) {
+ for (int column = 1; column <= columns; column++) {
+ resultSet.getObject(column);
+ }
+ collector.checkThat("Test Name #" + count, is(resultSet.getString(2)));
+ }
+
+ collector.checkThat(maxRowsLimit, is(count));
+ }
+ }
+
+ /**
+ * Tests whether the {@link ArrowFlightJdbcDriver} fails upon attempting
+ * to run an invalid query.
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test(expected = SQLException.class)
+ public void testShouldThrowExceptionUponAttemptingToExecuteAnInvalidSelectQuery()
+ throws Exception {
+ Statement statement = connection.createStatement();
+ statement.executeQuery("SELECT * FROM SHOULD-FAIL");
+ fail();
+ }
+
+ /**
+ * Tests whether the {@link ArrowFlightJdbcDriver} query returns only the
+ * number of rows set by {@link org.apache.calcite.avatica.AvaticaStatement#setLargeMaxRows(long)}.
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test
+ public void testShouldRunSelectQuerySettingLargeMaxRowLimit() throws Exception {
+ try (Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+
+ final long maxRowsLimit = 3;
+ statement.setLargeMaxRows(maxRowsLimit);
+
+ collector.checkThat(statement.getLargeMaxRows(), is(maxRowsLimit));
+
+ int count = 0;
+ int columns = resultSet.getMetaData().getColumnCount();
+ for (; resultSet.next(); count++) {
+ for (int column = 1; column <= columns; column++) {
+ resultSet.getObject(column);
+ }
+ assertEquals("Test Name #" + count, resultSet.getString(2));
+ }
+
+ assertEquals(maxRowsLimit, count);
+ }
+ }
+
+ @Test
+ public void testColumnCountShouldRemainConsistentForResultSetThroughoutEntireDuration()
+ throws SQLException {
+ final Set<Integer> counts = new HashSet<>();
+ try (final Statement statement = connection.createStatement();
+ final ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+ while (resultSet.next()) {
+ counts.add(resultSet.getMetaData().getColumnCount());
+ }
+ }
+ collector.checkThat(counts, is(ImmutableSet.of(6)));
+ }
+
+ /**
+ * Tests whether the {@link ArrowFlightJdbcDriver} closes the statement after the ResultSet is
+ * exhausted when {@link org.apache.calcite.avatica.AvaticaStatement#closeOnCompletion()} is called.
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test
+ public void testShouldCloseStatementWhenIsCloseOnCompletion() throws Exception {
+ Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD);
+
+ statement.closeOnCompletion();
+
+ resultSetNextUntilDone(resultSet);
+
+ collector.checkThat(statement.isClosed(), is(true));
+ }
+
+ /**
+ * Tests whether the {@link ArrowFlightJdbcDriver} closes the statement after the ResultSet, with a max rows
+ * limit, is exhausted when {@link org.apache.calcite.avatica.AvaticaStatement#closeOnCompletion()} is called.
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test
+ public void testShouldCloseStatementWhenIsCloseOnCompletionWithMaxRowsLimit() throws Exception {
+ Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD);
+
+ final long maxRowsLimit = 3;
+ statement.setLargeMaxRows(maxRowsLimit);
+ statement.closeOnCompletion();
+
+ resultSetNextUntilDone(resultSet);
+
+ collector.checkThat(statement.isClosed(), is(true));
+ }
+
+ /**
+ * Tests whether the {@link ArrowFlightJdbcDriver} does not close the statement after the ResultSet, with a max
+ * rows limit, is exhausted when {@link org.apache.calcite.avatica.AvaticaStatement#closeOnCompletion()} was not
+ * called.
+ *
+ * @throws Exception If the connection fails to be established.
+ */
+ @Test
+ public void testShouldNotCloseStatementWhenIsNotCloseOnCompletionWithMaxRowsLimit()
+ throws Exception {
+ try (Statement statement = connection.createStatement();
+ ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+
+ final long maxRowsLimit = 3;
+ statement.setLargeMaxRows(maxRowsLimit);
+
+ collector.checkThat(statement.isClosed(), is(false));
+ resultSetNextUntilDone(resultSet);
+ collector.checkThat(resultSet.isClosed(), is(false));
+ collector.checkThat(resultSet, is(instanceOf(ArrowFlightJdbcFlightStreamResultSet.class)));
+ }
+ }
+
+ @Test
+ public void testShouldCancelQueryUponCancelAfterQueryingResultSet() throws SQLException {
+ try (final Statement statement = connection.createStatement();
+ final ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+ final int column = RANDOM.nextInt(resultSet.getMetaData().getColumnCount()) + 1;
+ collector.checkThat(resultSet.isClosed(), is(false));
+ collector.checkThat(resultSet.next(), is(true));
+ collector.checkSucceeds(() -> resultSet.getObject(column));
+ statement.cancel();
+ // Should reset `ResultSet`; keep both `ResultSet` and `Connection` open.
+ collector.checkThat(statement.isClosed(), is(false));
+ collector.checkThat(resultSet.isClosed(), is(false));
+ collector.checkThat(resultSet.getMetaData().getColumnCount(), is(0));
+ }
+ }
+
+ @Test
+ public void testShouldInterruptFlightStreamsIfQueryIsCancelledMidQuerying()
+ throws SQLException, InterruptedException {
+ try (final Statement statement = connection.createStatement()) {
+ final CountDownLatch latch = new CountDownLatch(1);
+ final Set<Exception> exceptions = synchronizedSet(new HashSet<>(1));
+ final Thread thread = new Thread(() -> {
+ try (final ResultSet resultSet = statement.executeQuery(
+ CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+ final int cachedColumnCount = resultSet.getMetaData().getColumnCount();
+ Thread.sleep(300);
+ while (resultSet.next()) {
+ resultSet.getObject(RANDOM.nextInt(cachedColumnCount) + 1);
+ }
+ } catch (final SQLException | InterruptedException e) {
+ exceptions.add(e);
+ } finally {
+ latch.countDown();
+ }
+ });
+ thread.setName("Test Case: interrupt query execution before first retrieval");
+ thread.start();
+ statement.cancel();
+ thread.join();
+ collector.checkThat(
+ exceptions.stream()
+ .map(Exception::getMessage)
+ .map(StringBuilder::new)
+ .reduce(StringBuilder::append)
+ .orElseThrow(IllegalArgumentException::new)
+ .toString(),
+ is("Statement canceled"));
+ }
+ }
+
+ @Test
+ public void testShouldInterruptFlightStreamsIfQueryIsCancelledMidProcessingForTimeConsumingQueries()
+ throws SQLException, InterruptedException {
+ final String query = CoreMockedSqlProducers.LEGACY_CANCELLATION_SQL_CMD;
+ try (final Statement statement = connection.createStatement()) {
+ final Set<Exception> exceptions = synchronizedSet(new HashSet<>(1));
+ final Thread thread = new Thread(() -> {
+ try (final ResultSet ignored = statement.executeQuery(query)) {
+ fail();
+ } catch (final SQLException e) {
+ exceptions.add(e);
+ }
+ });
+ thread.setName("Test Case: interrupt query execution mid-process");
+ thread.setPriority(Thread.MAX_PRIORITY);
+ thread.start();
+ Thread.sleep(5000); // Let the other thread attempt to retrieve results.
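+ // Cancel mid-execution: the worker thread is expected to surface a
+ // cancellation SQLException, which is asserted on below.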
+ statement.cancel();
+ thread.join();
+ collector.checkThat(
+ exceptions.stream()
+ .map(Exception::getMessage)
+ .map(StringBuilder::new)
+ .reduce(StringBuilder::append)
+ .orElseThrow(IllegalStateException::new)
+ .toString(),
+ anyOf(is(format("Error while executing SQL \"%s\": Query canceled", query)),
+ allOf(containsString(format("Error while executing SQL \"%s\"", query)),
+ containsString("CANCELLED"))));
+ }
+ }
+
+ @Test
+ public void testShouldInterruptFlightStreamsIfQueryTimeoutIsOver() throws SQLException {
+ final String query = CoreMockedSqlProducers.LEGACY_CANCELLATION_SQL_CMD;
+ final int timeoutValue = 2;
+ final String timeoutUnit = "SECONDS";
+ try (final Statement statement = connection.createStatement()) {
+ statement.setQueryTimeout(timeoutValue);
+ final Set<Exception> exceptions = new HashSet<>(1);
+ try {
+ statement.executeQuery(query);
+ } catch (final Exception e) {
+ exceptions.add(e);
+ }
+ final Throwable comparisonCause = exceptions.stream()
+ .findFirst()
+ .orElseThrow(RuntimeException::new)
+ .getCause()
+ .getCause();
+ collector.checkThat(comparisonCause,
+ is(instanceOf(SQLTimeoutException.class)));
+ collector.checkThat(comparisonCause.getMessage(),
+ is(format("Query timed out after %d %s", timeoutValue, timeoutUnit)));
+ }
+ }
+
+ @Test
+ public void testFlightStreamsQueryShouldNotTimeout() throws SQLException {
+ final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD;
+ final int timeoutValue = 5;
+ try (Statement statement = connection.createStatement()) {
+ statement.setQueryTimeout(timeoutValue);
+ ResultSet resultSet = statement.executeQuery(query);
+ CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector);
+ resultSet.close();
+ }
+ }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/TokenAuthenticationTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/TokenAuthenticationTest.java
new file mode 100644
index 00000000000..56c8c178f21
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/TokenAuthenticationTest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+
+import org.apache.arrow.driver.jdbc.authentication.TokenAuthentication;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
+import org.apache.arrow.util.AutoCloseables;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.ClassRule;
+import org.junit.Test;
+
+public class TokenAuthenticationTest {
+  private static final MockFlightSqlProducer FLIGHT_SQL_PRODUCER = new MockFlightSqlProducer();
+
+  @ClassRule
+  public static FlightServerTestRule FLIGHT_SERVER_TEST_RULE;
+
+  static {
+    FLIGHT_SERVER_TEST_RULE = new FlightServerTestRule.Builder()
+        .host("localhost")
+        .randomPort()
+        .authentication(new TokenAuthentication.Builder()
+            .token("1234")
+            .build())
+        .producer(FLIGHT_SQL_PRODUCER)
+        .build();
+  }
+
+  @AfterClass
+  public static void tearDownAfterClass() {
+    AutoCloseables.closeNoChecked(FLIGHT_SQL_PRODUCER);
+  }
+
+  @Test(expected = SQLException.class)
+  public void connectUsingTokenAuthenticationShouldFail() throws SQLException {
+    try (Connection ignored = FLIGHT_SERVER_TEST_RULE.getConnection(false, "invalid")) {
+      Assert.fail();
+    }
+  }
+
+  @Test
+  public void connectUsingTokenAuthenticationShouldSucceed() throws SQLException {
+    try (Connection connection = FLIGHT_SERVER_TEST_RULE.getConnection(false, "1234")) {
+      Assert.assertFalse(connection.isClosed());
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java
new file mode 100644
index 00000000000..4b3744372c0
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorFactoryTest.java
@@ -0,0 +1,496 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc.accessor; + +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.binary.ArrowFlightJdbcBinaryVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDurationVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcIntervalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcDenseUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcFixedSizeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcLargeListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcListVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcMapVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcStructVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.complex.ArrowFlightJdbcUnionVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBaseIntVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcBitVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcDecimalVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat4VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.numeric.ArrowFlightJdbcFloat8VectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.DurationVector; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.LargeVarCharVector; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.Test; + +public class ArrowFlightJdbcAccessorFactoryTest { + public static final IntSupplier GET_CURRENT_ROW = () -> 0; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Test + public void createAccessorForUInt1Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt1Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForUInt2Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt2Vector()) { + ArrowFlightJdbcAccessor accessor = + 
ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForUInt4Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt4Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForUInt8Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createUInt8Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForTinyIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTinyIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForSmallIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createSmallIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForBigIntVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createBigIntVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBaseIntVectorAccessor); + } + } + + @Test + public void createAccessorForFloat4Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFloat4Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcFloat4VectorAccessor); + } + } + + @Test + public void createAccessorForFloat8Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFloat8Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcFloat8VectorAccessor); + } + } + + @Test + public void createAccessorForBitVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createBitVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBitVectorAccessor); + } + } + + @Test + public void 
createAccessorForDecimalVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDecimalVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDecimalVectorAccessor); + } + } + + @Test + public void createAccessorForDecimal256Vector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDecimal256Vector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDecimalVectorAccessor); + } + } + + @Test + public void createAccessorForVarBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createVarBinaryVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBinaryVectorAccessor); + } + } + + @Test + public void createAccessorForLargeVarBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createLargeVarBinaryVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBinaryVectorAccessor); + } + } + + @Test + public void createAccessorForFixedSizeBinaryVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFixedSizeBinaryVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcBinaryVectorAccessor); + } + } + + @Test + public void createAccessorForTimeStampVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeStampMilliVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeStampVectorAccessor); + } + } + + @Test + public void createAccessorForTimeNanoVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeNanoVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForTimeMicroVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeMicroVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForTimeMilliVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeMilliVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForTimeSecVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createTimeSecVector()) { + 
ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcTimeVectorAccessor); + } + } + + @Test + public void createAccessorForDateDayVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDateDayVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDateVectorAccessor); + } + } + + @Test + public void createAccessorForDateMilliVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createDateMilliVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDateVectorAccessor); + } + } + + @Test + public void createAccessorForVarCharVector() { + try ( + ValueVector valueVector = new VarCharVector("", rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcVarCharVectorAccessor); + } + } + + @Test + public void createAccessorForLargeVarCharVector() { + try (ValueVector valueVector = new LargeVarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcVarCharVectorAccessor); + } + } + + @Test + public void createAccessorForDurationVector() { + try (ValueVector valueVector = + new DurationVector("", + new FieldType(true, new ArrowType.Duration(TimeUnit.MILLISECOND), null), + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDurationVectorAccessor); + } + } + + @Test + public void createAccessorForIntervalDayVector() { + try (ValueVector valueVector = new IntervalDayVector("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcIntervalVectorAccessor); + } + } + + @Test + public void createAccessorForIntervalYearVector() { + try (ValueVector valueVector = new IntervalYearVector("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcIntervalVectorAccessor); + } + } + + @Test + public void createAccessorForUnionVector() { + try (ValueVector valueVector = new UnionVector("", rootAllocatorTestRule.getRootAllocator(), + null, null)) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcUnionVectorAccessor); + } + } + + @Test + public void createAccessorForDenseUnionVector() { + try ( + ValueVector 
valueVector = new DenseUnionVector("", rootAllocatorTestRule.getRootAllocator(), + null, null)) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcDenseUnionVectorAccessor); + } + } + + @Test + public void createAccessorForStructVector() { + try (ValueVector valueVector = StructVector.empty("", + rootAllocatorTestRule.getRootAllocator())) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcStructVectorAccessor); + } + } + + @Test + public void createAccessorForListVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createListVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcListVectorAccessor); + } + } + + @Test + public void createAccessorForLargeListVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createLargeListVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcLargeListVectorAccessor); + } + } + + @Test + public void createAccessorForFixedSizeListVector() { + try (ValueVector valueVector = rootAllocatorTestRule.createFixedSizeListVector()) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcFixedSizeListVectorAccessor); + } + } + + @Test + public void createAccessorForMapVector() { + try (ValueVector valueVector = MapVector.empty("", rootAllocatorTestRule.getRootAllocator(), + true)) { + ArrowFlightJdbcAccessor accessor = + ArrowFlightJdbcAccessorFactory.createAccessor(valueVector, GET_CURRENT_ROW, + (boolean wasNull) -> { + }); + + Assert.assertTrue(accessor instanceof ArrowFlightJdbcMapVectorAccessor); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java new file mode 100644 index 00000000000..099b0122179 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/ArrowFlightJdbcAccessorTest.java @@ -0,0 +1,358 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor;
+
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
+import java.sql.SQLException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+@RunWith(MockitoJUnitRunner.class)
+public class ArrowFlightJdbcAccessorTest {
+
+  static class MockedArrowFlightJdbcAccessor extends ArrowFlightJdbcAccessor {
+
+    protected MockedArrowFlightJdbcAccessor() {
+      super(() -> 0, (boolean wasNull) -> {
+      });
+    }
+
+    @Override
+    public Class<?> getObjectClass() {
+      return Long.class;
+    }
+  }
+
+  @Mock
+  MockedArrowFlightJdbcAccessor accessor;
+
+  @Test
+  public void testShouldGetObjectWithByteClassReturnGetByte() throws SQLException {
+    byte expected = Byte.MAX_VALUE;
+    when(accessor.getByte()).thenReturn(expected);
+
+    when(accessor.getObject(Byte.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Byte.class), (Object) expected);
+    verify(accessor).getByte();
+  }
+
+  @Test
+  public void testShouldGetObjectWithShortClassReturnGetShort() throws SQLException {
+    short expected = Short.MAX_VALUE;
+    when(accessor.getShort()).thenReturn(expected);
+
+    when(accessor.getObject(Short.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Short.class), (Object) expected);
+    verify(accessor).getShort();
+  }
+
+  @Test
+  public void testShouldGetObjectWithIntegerClassReturnGetInt() throws SQLException {
+    int expected = Integer.MAX_VALUE;
+    when(accessor.getInt()).thenReturn(expected);
+
+    when(accessor.getObject(Integer.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Integer.class), (Object) expected);
+    verify(accessor).getInt();
+  }
+
+  @Test
+  public void testShouldGetObjectWithLongClassReturnGetLong() throws SQLException {
+    long expected = Long.MAX_VALUE;
+    when(accessor.getLong()).thenReturn(expected);
+
+    when(accessor.getObject(Long.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Long.class), (Object) expected);
+    verify(accessor).getLong();
+  }
+
+  @Test
+  public void testShouldGetObjectWithFloatClassReturnGetFloat() throws SQLException {
+    float expected = Float.MAX_VALUE;
+    when(accessor.getFloat()).thenReturn(expected);
+
+    when(accessor.getObject(Float.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Float.class), (Object) expected);
+    verify(accessor).getFloat();
+  }
+
+  @Test
+  public void testShouldGetObjectWithDoubleClassReturnGetDouble() throws SQLException {
+    double expected = Double.MAX_VALUE;
+    when(accessor.getDouble()).thenReturn(expected);
+
+    when(accessor.getObject(Double.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Double.class), (Object) expected);
+    verify(accessor).getDouble();
+  }
+
+  @Test
+  public void testShouldGetObjectWithBooleanClassReturnGetBoolean() throws SQLException {
+    when(accessor.getBoolean()).thenReturn(true);
+
+    when(accessor.getObject(Boolean.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Boolean.class), true);
+    verify(accessor).getBoolean();
+  }
+
+  @Test
+  public void testShouldGetObjectWithBigDecimalClassReturnGetBigDecimal() throws SQLException {
+    BigDecimal expected = BigDecimal.TEN;
+    when(accessor.getBigDecimal()).thenReturn(expected);
+
+    when(accessor.getObject(BigDecimal.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(BigDecimal.class), expected);
+    verify(accessor).getBigDecimal();
+  }
+
+  @Test
+  public void testShouldGetObjectWithStringClassReturnGetString() throws SQLException {
+    String expected = "STRING_VALUE";
+    when(accessor.getString()).thenReturn(expected);
+
+    when(accessor.getObject(String.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(String.class), expected);
+    verify(accessor).getString();
+  }
+
+  @Test
+  public void testShouldGetObjectWithByteArrayClassReturnGetBytes() throws SQLException {
+    byte[] expected = "STRING_VALUE".getBytes(StandardCharsets.UTF_8);
+    when(accessor.getBytes()).thenReturn(expected);
+
+    when(accessor.getObject(byte[].class)).thenCallRealMethod();
+
+    Assert.assertArrayEquals(accessor.getObject(byte[].class), expected);
+    verify(accessor).getBytes();
+  }
+
+  @Test
+  public void testShouldGetObjectWithObjectClassReturnGetObject() throws SQLException {
+    Object expected = new Object();
+    when(accessor.getObject()).thenReturn(expected);
+
+    when(accessor.getObject(Object.class)).thenCallRealMethod();
+
+    Assert.assertEquals(accessor.getObject(Object.class), expected);
+    verify(accessor).getObject();
+  }
+
+  @Test
+  public void testShouldGetObjectWithAccessorsObjectClassReturnGetObject() throws SQLException {
+    Class<Long> objectClass = Long.class;
+
+    when(accessor.getObject(objectClass)).thenCallRealMethod();
+
+    accessor.getObject(objectClass);
+    verify(accessor).getObject(objectClass);
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetBoolean() throws SQLException {
+    when(accessor.getBoolean()).thenCallRealMethod();
+    accessor.getBoolean();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetByte() throws SQLException {
+    when(accessor.getByte()).thenCallRealMethod();
+    accessor.getByte();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetShort() throws SQLException {
+    when(accessor.getShort()).thenCallRealMethod();
+    accessor.getShort();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetInt() throws SQLException {
+    when(accessor.getInt()).thenCallRealMethod();
+    accessor.getInt();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetLong() throws SQLException {
+    when(accessor.getLong()).thenCallRealMethod();
+    accessor.getLong();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetFloat() throws SQLException {
+    when(accessor.getFloat()).thenCallRealMethod();
+    accessor.getFloat();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetDouble() throws SQLException {
+    when(accessor.getDouble()).thenCallRealMethod();
+    accessor.getDouble();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetBigDecimal() throws SQLException {
+    when(accessor.getBigDecimal()).thenCallRealMethod();
+    accessor.getBigDecimal();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetBytes() throws SQLException {
+    when(accessor.getBytes()).thenCallRealMethod();
+    accessor.getBytes();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetAsciiStream() throws SQLException {
+    when(accessor.getAsciiStream()).thenCallRealMethod();
+    accessor.getAsciiStream();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetUnicodeStream() throws SQLException {
+    when(accessor.getUnicodeStream()).thenCallRealMethod();
+    accessor.getUnicodeStream();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetBinaryStream() throws SQLException {
+    when(accessor.getBinaryStream()).thenCallRealMethod();
+    accessor.getBinaryStream();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetObject() throws SQLException {
+    when(accessor.getObject()).thenCallRealMethod();
+    accessor.getObject();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetObjectMap() throws SQLException {
+    Map<String, Class<?>> map = new HashMap<>();
+    when(accessor.getObject(map)).thenCallRealMethod();
+    accessor.getObject(map);
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetCharacterStream() throws SQLException {
+    when(accessor.getCharacterStream()).thenCallRealMethod();
+    accessor.getCharacterStream();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetRef() throws SQLException {
+    when(accessor.getRef()).thenCallRealMethod();
+    accessor.getRef();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetBlob() throws SQLException {
+    when(accessor.getBlob()).thenCallRealMethod();
+    accessor.getBlob();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetClob() throws SQLException {
+    when(accessor.getClob()).thenCallRealMethod();
+    accessor.getClob();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetArray() throws SQLException {
+    when(accessor.getArray()).thenCallRealMethod();
+    accessor.getArray();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetStruct() throws SQLException {
+    when(accessor.getStruct()).thenCallRealMethod();
+    accessor.getStruct();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetURL() throws SQLException {
+    when(accessor.getURL()).thenCallRealMethod();
+    accessor.getURL();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetNClob() throws SQLException {
+    when(accessor.getNClob()).thenCallRealMethod();
+    accessor.getNClob();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetSQLXML() throws SQLException {
+    when(accessor.getSQLXML()).thenCallRealMethod();
+    accessor.getSQLXML();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetNString() throws SQLException {
+    when(accessor.getNString()).thenCallRealMethod();
+    accessor.getNString();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetNCharacterStream() throws SQLException {
+    when(accessor.getNCharacterStream()).thenCallRealMethod();
+    accessor.getNCharacterStream();
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetDate() throws SQLException {
+    when(accessor.getDate(null)).thenCallRealMethod();
+    accessor.getDate(null);
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetTime() throws SQLException {
+    when(accessor.getTime(null)).thenCallRealMethod();
+    accessor.getTime(null);
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetTimestamp() throws SQLException {
+    when(accessor.getTimestamp(null)).thenCallRealMethod();
+    accessor.getTimestamp(null);
+  }
+
+  @Test(expected = SQLException.class)
+  public void testShouldFailToGetBigDecimalWithValue() throws SQLException {
+    when(accessor.getBigDecimal(0)).thenCallRealMethod();
+    accessor.getBigDecimal(0);
+  }
+}
diff --git
a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessorTest.java new file mode 100644 index 00000000000..57e7ecfe025 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/ArrowFlightJdbcNullVectorAccessorTest.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl; + +import org.junit.Assert; +import org.junit.Test; + +public class ArrowFlightJdbcNullVectorAccessorTest { + + ArrowFlightJdbcNullVectorAccessor accessor = + new ArrowFlightJdbcNullVectorAccessor((boolean wasNull) -> { + }); + + @Test + public void testShouldWasNullReturnTrue() { + Assert.assertTrue(accessor.wasNull()); + } + + @Test + public void testShouldGetObjectReturnNull() { + Assert.assertNull(accessor.getObject()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessorTest.java new file mode 100644 index 00000000000..f4d256c4cf8 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/binary/ArrowFlightJdbcBinaryVectorAccessorTest.java @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.binary;
+
+import static java.nio.charset.StandardCharsets.US_ASCII;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.io.InputStream;
+import java.io.Reader;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.function.Supplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.FixedSizeBinaryVector;
+import org.apache.arrow.vector.LargeVarBinaryVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.commons.io.IOUtils;
+import org.hamcrest.CoreMatchers;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ArrowFlightJdbcBinaryVectorAccessorTest {
+
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  private ValueVector vector;
+  private final Supplier<ValueVector> vectorSupplier;
+
+  private final AccessorTestUtils.AccessorSupplier<ArrowFlightJdbcBinaryVectorAccessor>
+      accessorSupplier = (vector, getCurrentRow) -> {
+        ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> {
+        };
+        if (vector instanceof VarBinaryVector) {
+          return new ArrowFlightJdbcBinaryVectorAccessor(((VarBinaryVector) vector), getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof LargeVarBinaryVector) {
+          return new ArrowFlightJdbcBinaryVectorAccessor(((LargeVarBinaryVector) vector),
+              getCurrentRow, noOpWasNullConsumer);
+        } else if (vector instanceof FixedSizeBinaryVector) {
+          return new ArrowFlightJdbcBinaryVectorAccessor(((FixedSizeBinaryVector) vector),
+              getCurrentRow, noOpWasNullConsumer);
+        }
+        return null;
+      };
+
+  private final AccessorTestUtils.AccessorIterator<ArrowFlightJdbcBinaryVectorAccessor>
+      accessorIterator =
+      new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier);
+
+  @Parameterized.Parameters(name = "{1}")
+  public static Collection<Object[]> data() {
+    return Arrays.asList(new Object[][] {
+        {(Supplier<ValueVector>) () -> rootAllocatorTestRule.createVarBinaryVector(),
+            "VarBinaryVector"},
+        {(Supplier<ValueVector>) () -> rootAllocatorTestRule.createLargeVarBinaryVector(),
+            "LargeVarBinaryVector"},
+        {(Supplier<ValueVector>) () -> rootAllocatorTestRule.createFixedSizeBinaryVector(),
+            "FixedSizeBinaryVector"},
+    });
+  }
+
+  public ArrowFlightJdbcBinaryVectorAccessorTest(Supplier<ValueVector> vectorSupplier,
+      String vectorType) {
+    this.vectorSupplier = vectorSupplier;
+  }
+
+  @Before
+  public void setup() {
+    this.vector = vectorSupplier.get();
+  }
+
+  @After
+  public void tearDown() {
+    this.vector.close();
+  }
+
+  @Test
+  public void testShouldGetStringReturnExpectedString() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getString,
+        (accessor) -> is(new String(accessor.getBytes(), UTF_8)));
+  }
+
+  @Test
+  public void testShouldGetStringReturnNull() throws Exception {
+    vector.reset();
+    vector.setValueCount(5);
+
+    accessorIterator
+        .assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getString,
+            CoreMatchers.nullValue());
+  }
+
+  @Test
+
public void testShouldGetBytesReturnExpectedByteArray() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getBytes, + (accessor, currentRow) -> { + if (vector instanceof VarBinaryVector) { + return is(((VarBinaryVector) vector).get(currentRow)); + } else if (vector instanceof LargeVarBinaryVector) { + return is(((LargeVarBinaryVector) vector).get(currentRow)); + } else if (vector instanceof FixedSizeBinaryVector) { + return is(((FixedSizeBinaryVector) vector).get(currentRow)); + } + return null; + }); + } + + @Test + public void testShouldGetBytesReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getBytes(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetObjectReturnAsGetBytes() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBinaryVectorAccessor::getObject, + (accessor) -> is(accessor.getBytes())); + } + + @Test + public void testShouldGetObjectReturnNull() { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getObject(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetUnicodeStreamReturnCorrectInputStream() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + InputStream inputStream = accessor.getUnicodeStream(); + String actualString = IOUtils.toString(inputStream, UTF_8); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetUnicodeStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getUnicodeStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetAsciiStreamReturnCorrectInputStream() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + InputStream inputStream = accessor.getAsciiStream(); + String actualString = IOUtils.toString(inputStream, US_ASCII); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetAsciiStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getAsciiStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetBinaryStreamReturnCurrentInputStream() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + InputStream inputStream = accessor.getBinaryStream(); + String actualString = IOUtils.toString(inputStream, UTF_8); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetBinaryStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = 
accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getBinaryStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetCharacterStreamReturnCorrectReader() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Reader characterStream = accessor.getCharacterStream(); + String actualString = IOUtils.toString(characterStream); + collector.checkThat(accessor.wasNull(), is(false)); + collector.checkThat(actualString, is(accessor.getString())); + }); + } + + @Test + public void testShouldGetCharacterStreamReturnNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + ArrowFlightJdbcBinaryVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getCharacterStream(), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessorTest.java new file mode 100644 index 00000000000..36af5134626 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDateVectorAccessorTest.java @@ -0,0 +1,254 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.calendar;
+
+import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor.getTimeUnitForVector;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.not;
+
+import java.sql.Date;
+import java.sql.Timestamp;
+import java.time.LocalDateTime;
+import java.util.Arrays;
+import java.util.Calendar;
+import java.util.Collection;
+import java.util.TimeZone;
+import java.util.concurrent.TimeUnit;
+import java.util.function.Supplier;
+
+import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor;
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.BaseFixedWidthVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DateMilliVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.util.Text;
+import org.hamcrest.CoreMatchers;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(Parameterized.class)
+public class ArrowFlightJdbcDateVectorAccessorTest {
+
+  public static final String AMERICA_VANCOUVER = "America/Vancouver";
+
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  private BaseFixedWidthVector vector;
+  private final Supplier<BaseFixedWidthVector> vectorSupplier;
+
+  private final AccessorTestUtils.AccessorSupplier<ArrowFlightJdbcDateVectorAccessor>
+      accessorSupplier = (vector, getCurrentRow) -> {
+        if (vector instanceof DateDayVector) {
+          return new ArrowFlightJdbcDateVectorAccessor((DateDayVector) vector, getCurrentRow,
+              (boolean wasNull) -> {
+              });
+        } else if (vector instanceof DateMilliVector) {
+          return new ArrowFlightJdbcDateVectorAccessor((DateMilliVector) vector, getCurrentRow,
+              (boolean wasNull) -> {
+              });
+        }
+        return null;
+      };
+
+  private final AccessorTestUtils.AccessorIterator<ArrowFlightJdbcDateVectorAccessor>
+      accessorIterator =
+      new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier);
+
+  @Parameterized.Parameters(name = "{1}")
+  public static Collection<Object[]> data() {
+    return Arrays.asList(new Object[][] {
+        {(Supplier<BaseFixedWidthVector>) () -> rootAllocatorTestRule.createDateDayVector(),
+            "DateDayVector"},
+        {(Supplier<BaseFixedWidthVector>) () -> rootAllocatorTestRule.createDateMilliVector(),
+            "DateMilliVector"},
+    });
+  }
+
+  public ArrowFlightJdbcDateVectorAccessorTest(Supplier<BaseFixedWidthVector> vectorSupplier,
+      String vectorType) {
+    this.vectorSupplier = vectorSupplier;
+  }
+
+  @Before
+  public void setup() {
+    this.vector = vectorSupplier.get();
+  }
+
+  @After
+  public void tearDown() {
+    this.vector.close();
+  }
+
+  @Test
+  public void testShouldGetTimestampReturnValidTimestampWithoutCalendar() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTimestamp(null),
+        (accessor, currentRow) -> is(getTimestampForVector(currentRow)));
+  }
+
+  @Test
+  public void testShouldGetObjectWithDateClassReturnValidDateWithoutCalendar() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getObject(Date.class),
+        (accessor, currentRow) -> is(new Date(getTimestampForVector(currentRow).getTime())));
+  }
+
+  @Test
+  public void
testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); + final Timestamp result = accessor.getTimestamp(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimestampReturnNull() { + vector.setNull(0); + ArrowFlightJdbcDateVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTimestamp(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetDateReturnValidDateWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getDate(null), + (accessor, currentRow) -> is(new Date(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Date resultWithoutCalendar = accessor.getDate(null); + final Date result = accessor.getDate(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetDateReturnNull() { + vector.setNull(0); + ArrowFlightJdbcDateVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getDate(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + private Timestamp getTimestampForVector(int currentRow) { + Object object = vector.getObject(currentRow); + + Timestamp expectedTimestamp = null; + if (object instanceof LocalDateTime) { + expectedTimestamp = Timestamp.valueOf((LocalDateTime) object); + } else if (object instanceof Number) { + long value = ((Number) object).longValue(); + TimeUnit timeUnit = getTimeUnitForVector(vector); + long millis = timeUnit.toMillis(value); + expectedTimestamp = new Timestamp(millis); + } + return expectedTimestamp; + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator + .assertAccessorGetter(vector, ArrowFlightJdbcDateVectorAccessor::getObjectClass, + equalTo(Date.class)); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithoutCalendar() throws Exception { + assertGetStringIsConsistentWithVarCharAccessor(null); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithCalendar() throws Exception { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + assertGetStringIsConsistentWithVarCharAccessor(calendar); + } + + @Test + public void testValidateGetStringTimeZoneConsistency() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final TimeZone defaultTz = TimeZone.getDefault(); + try { + final String string = accessor.getString(); // Should always be UTC as no calendar is provided + + // 
Validate with UTC + Date date = accessor.getDate(null); + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + collector.checkThat(date.toString(), is(string)); + + // Validate with different TZ + TimeZone.setDefault(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + collector.checkThat(date.toString(), not(string)); + + collector.checkThat(accessor.wasNull(), is(false)); + } finally { + // Set default Tz back + TimeZone.setDefault(defaultTz); + } + }); + } + + private void assertGetStringIsConsistentWithVarCharAccessor(Calendar calendar) throws Exception { + try (VarCharVector varCharVector = new VarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + varCharVector.allocateNew(1); + ArrowFlightJdbcVarCharVectorAccessor varCharVectorAccessor = + new ArrowFlightJdbcVarCharVectorAccessor(varCharVector, () -> 0, (boolean wasNull) -> { + }); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final String string = accessor.getString(); + varCharVector.set(0, new Text(string)); + varCharVector.setValueCount(1); + + Date dateFromVarChar = varCharVectorAccessor.getDate(calendar); + Date date = accessor.getDate(calendar); + + collector.checkThat(date, is(dateFromVarChar)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessorTest.java new file mode 100644 index 00000000000..64ddb573f1b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcDurationVectorAccessorTest.java @@ -0,0 +1,115 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.calendar;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.time.Duration;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.DurationVector;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+public class ArrowFlightJdbcDurationVectorAccessorTest {
+
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  private DurationVector vector;
+
+  private final AccessorTestUtils.AccessorSupplier<ArrowFlightJdbcDurationVectorAccessor>
+      accessorSupplier =
+      (vector, getCurrentRow) -> new ArrowFlightJdbcDurationVectorAccessor((DurationVector) vector,
+          getCurrentRow, (boolean wasNull) -> {
+          });
+
+  private final AccessorTestUtils.AccessorIterator<ArrowFlightJdbcDurationVectorAccessor>
+      accessorIterator =
+      new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier);
+
+  @Before
+  public void setup() {
+    FieldType fieldType = new FieldType(true, new ArrowType.Duration(TimeUnit.MILLISECOND), null);
+    this.vector = new DurationVector("", fieldType, rootAllocatorTestRule.getRootAllocator());
+
+    int valueCount = 10;
+    this.vector.setValueCount(valueCount);
+    for (int i = 0; i < valueCount; i++) {
+      this.vector.set(i, java.util.concurrent.TimeUnit.DAYS.toMillis(i + 1));
+    }
+  }
+
+  @After
+  public void tearDown() {
+    this.vector.close();
+  }
+
+  @Test
+  public void getObject() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDurationVectorAccessor::getObject,
+        (accessor, currentRow) -> is(Duration.ofDays(currentRow + 1)));
+  }
+
+  @Test
+  public void getObjectForNull() throws Exception {
+    int valueCount = vector.getValueCount();
+    for (int i = 0; i < valueCount; i++) {
+      vector.setNull(i);
+    }
+
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDurationVectorAccessor::getObject,
+        (accessor, currentRow) -> equalTo(null));
+  }
+
+  @Test
+  public void getString() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcAccessor::getString,
+        (accessor, currentRow) -> is(Duration.ofDays(currentRow + 1).toString()));
+  }
+
+  @Test
+  public void getStringForNull() throws Exception {
+    int valueCount = vector.getValueCount();
+    for (int i = 0; i < valueCount; i++) {
+      vector.setNull(i);
+    }
+
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcAccessor::getString,
+        (accessor, currentRow) -> equalTo(null));
+  }
+
+  @Test
+  public void testShouldGetObjectClass() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcAccessor::getObjectClass,
+        (accessor, currentRow) -> equalTo(Duration.class));
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessorTest.java
new file mode 100644
index 00000000000..ea228692202
--- /dev/null
+++
b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcIntervalVectorAccessorTest.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalDay; +import static org.apache.arrow.driver.jdbc.utils.IntervalStringUtils.formatIntervalYear; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.joda.time.Period.parse; + +import java.time.Duration; +import java.time.Period; +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.IntervalDayVector; +import org.apache.arrow.vector.IntervalYearVector; +import org.apache.arrow.vector.ValueVector; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcIntervalVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final Supplier vectorSupplier; + private ValueVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof IntervalDayVector) { + return new ArrowFlightJdbcIntervalVectorAccessor((IntervalDayVector) vector, + getCurrentRow, noOpWasNullConsumer); + } else if (vector instanceof IntervalYearVector) { + return new ArrowFlightJdbcIntervalVectorAccessor((IntervalYearVector) vector, + getCurrentRow, noOpWasNullConsumer); + } + return null; + }; + + final AccessorTestUtils.AccessorIterator accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> { + IntervalDayVector vector = + new IntervalDayVector("", rootAllocatorTestRule.getRootAllocator()); + + int valueCount = 10; + vector.setValueCount(valueCount); + for (int i = 0; i < valueCount; i++) { + vector.set(i, i + 1, (i + 1) 
* 1000);
+          }
+          return vector;
+        }, "IntervalDayVector"},
+        {(Supplier) () -> {
+          IntervalYearVector vector =
+              new IntervalYearVector("", rootAllocatorTestRule.getRootAllocator());
+
+          int valueCount = 10;
+          vector.setValueCount(valueCount);
+          for (int i = 0; i < valueCount; i++) {
+            vector.set(i, i + 1);
+          }
+          return vector;
+        }, "IntervalYearVector"},
+    });
+  }
+
+  public ArrowFlightJdbcIntervalVectorAccessorTest(Supplier vectorSupplier,
+                                                   String vectorType) {
+    this.vectorSupplier = vectorSupplier;
+  }
+
+  @Before
+  public void setup() {
+    this.vector = vectorSupplier.get();
+  }
+
+  @After
+  public void tearDown() {
+    this.vector.close();
+  }
+
+  @Test
+  public void testShouldGetObjectReturnValidObject() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getObject,
+        (accessor, currentRow) -> is(getExpectedObject(vector, currentRow)));
+  }
+
+  @Test
+  public void testShouldGetObjectPassingObjectClassAsParameterReturnValidObject() throws Exception {
+    Class objectClass = getExpectedObjectClassForVector(vector);
+    accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getObject(objectClass),
+        (accessor, currentRow) -> is(getExpectedObject(vector, currentRow)));
+  }
+
+  @Test
+  public void testShouldGetObjectReturnNull() throws Exception {
+    setAllNullOnVector(vector);
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getObject,
+        (accessor, currentRow) -> equalTo(null));
+  }
+
+  private String getStringOnVector(ValueVector vector, int index) {
+    // check for null before calling toString(), otherwise a null expected object would NPE
+    Object object = getExpectedObject(vector, index);
+    if (object == null) {
+      return null;
+    } else if (vector instanceof IntervalDayVector) {
+      return formatIntervalDay(parse(object.toString()));
+    } else if (vector instanceof IntervalYearVector) {
+      return formatIntervalYear(parse(object.toString()));
+    }
+    return null;
+  }
+
+  @Test
+  public void testShouldGetIntervalYear() {
+    Assert.assertEquals("-002-00", formatIntervalYear(parse("P-2Y")));
+    Assert.assertEquals("-001-01", formatIntervalYear(parse("P-1Y-1M")));
+    Assert.assertEquals("-001-02", formatIntervalYear(parse("P-1Y-2M")));
+    Assert.assertEquals("-002-03", formatIntervalYear(parse("P-2Y-3M")));
+    Assert.assertEquals("-002-04", formatIntervalYear(parse("P-2Y-4M")));
+    Assert.assertEquals("-011-01", formatIntervalYear(parse("P-11Y-1M")));
+    Assert.assertEquals("+002-00", formatIntervalYear(parse("P+2Y")));
+    Assert.assertEquals("+001-01", formatIntervalYear(parse("P+1Y1M")));
+    Assert.assertEquals("+001-02", formatIntervalYear(parse("P+1Y2M")));
+    Assert.assertEquals("+002-03", formatIntervalYear(parse("P+2Y3M")));
+    Assert.assertEquals("+002-04", formatIntervalYear(parse("P+2Y4M")));
+    Assert.assertEquals("+011-01", formatIntervalYear(parse("P+11Y1M")));
+  }
+
+  @Test
+  public void testShouldGetIntervalDay() {
+    Assert.assertEquals("-001 00:00:00.000", formatIntervalDay(parse("PT-24H")));
+    Assert.assertEquals("+001 00:00:00.000", formatIntervalDay(parse("PT+24H")));
+    Assert.assertEquals("-000 01:00:00.000", formatIntervalDay(parse("PT-1H")));
+    Assert.assertEquals("-000 01:00:00.001", formatIntervalDay(parse("PT-1H-0M-00.001S")));
+    Assert.assertEquals("-000 01:01:01.000", formatIntervalDay(parse("PT-1H-1M-1S")));
+    Assert.assertEquals("-000 02:02:02.002", formatIntervalDay(parse("PT-2H-2M-02.002S")));
+    Assert.assertEquals("-000 23:59:59.999", formatIntervalDay(parse("PT-23H-59M-59.999S")));
+    Assert.assertEquals("-000 11:59:00.100", formatIntervalDay(parse("PT-11H-59M-00.100S")));
Assert.assertEquals("-000 05:02:03.000", formatIntervalDay(parse("PT-5H-2M-3S"))); + Assert.assertEquals("-000 22:22:22.222", formatIntervalDay(parse("PT-22H-22M-22.222S"))); + Assert.assertEquals("+000 01:00:00.000", formatIntervalDay(parse("PT+1H"))); + Assert.assertEquals("+000 01:00:00.001", formatIntervalDay(parse("PT+1H0M00.001S"))); + Assert.assertEquals("+000 01:01:01.000", formatIntervalDay(parse("PT+1H1M1S"))); + Assert.assertEquals("+000 02:02:02.002", formatIntervalDay(parse("PT+2H2M02.002S"))); + Assert.assertEquals("+000 23:59:59.999", formatIntervalDay(parse("PT+23H59M59.999S"))); + Assert.assertEquals("+000 11:59:00.100", formatIntervalDay(parse("PT+11H59M00.100S"))); + Assert.assertEquals("+000 05:02:03.000", formatIntervalDay(parse("PT+5H2M3S"))); + Assert.assertEquals("+000 22:22:22.222", formatIntervalDay(parse("PT+22H22M22.222S"))); + } + + @Test + public void testIntervalDayWithJodaPeriodObject() { + Assert.assertEquals("+1567 00:00:00.000", + formatIntervalDay(new org.joda.time.Period().plusDays(1567))); + Assert.assertEquals("-1567 00:00:00.000", + formatIntervalDay(new org.joda.time.Period().minusDays(1567))); + } + + @Test + public void testShouldGetStringReturnCorrectString() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getString, + (accessor, currentRow) -> is(getStringOnVector(vector, currentRow))); + } + + @Test + public void testShouldGetStringReturnNull() throws Exception { + setAllNullOnVector(vector); + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcIntervalVectorAccessor::getString, + (accessor, currentRow) -> equalTo(null)); + } + + @Test + public void testShouldGetObjectClassReturnCorrectClass() throws Exception { + Class expectedObjectClass = getExpectedObjectClassForVector(vector); + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcIntervalVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(expectedObjectClass)); + } + + private Class getExpectedObjectClassForVector(ValueVector vector) { + if (vector instanceof IntervalDayVector) { + return Duration.class; + } else if (vector instanceof IntervalYearVector) { + return Period.class; + } + return null; + } + + private void setAllNullOnVector(ValueVector vector) { + int valueCount = vector.getValueCount(); + if (vector instanceof IntervalDayVector) { + for (int i = 0; i < valueCount; i++) { + ((IntervalDayVector) vector).setNull(i); + } + } else if (vector instanceof IntervalYearVector) { + for (int i = 0; i < valueCount; i++) { + ((IntervalYearVector) vector).setNull(i); + } + } + } + + private Object getExpectedObject(ValueVector vector, int currentRow) { + if (vector instanceof IntervalDayVector) { + return Duration.ofDays(currentRow + 1).plusMillis((currentRow + 1) * 1000L); + } else if (vector instanceof IntervalYearVector) { + return Period.ofMonths(currentRow + 1); + } + return null; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java new file mode 100644 index 00000000000..38d842724b9 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeStampVectorAccessorTest.java @@ -0,0 +1,322 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor.getTimeUnitForVector; +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor.getTimeZoneForVector; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Date; +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collection; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Assume; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcTimeStampVectorAccessorTest { + + public static final String AMERICA_VANCOUVER = "America/Vancouver"; + public static final String ASIA_BANGKOK = "Asia/Bangkok"; + public static final String AMERICA_SAO_PAULO = "America/Sao_Paulo"; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + private final String timeZone; + + private TimeStampVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcTimeStampVectorAccessor( + (TimeStampVector) vector, getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1} - TimeZone: {2}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoVector(), + "TimeStampNanoVector", + null}, + {(Supplier) () -> 
rootAllocatorTestRule.createTimeStampNanoTZVector("UTC"), + "TimeStampNanoTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoTZVector( + AMERICA_VANCOUVER), + "TimeStampNanoTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampNanoTZVector( + ASIA_BANGKOK), + "TimeStampNanoTZVector", + ASIA_BANGKOK}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroVector(), + "TimeStampMicroVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroTZVector( + "UTC"), + "TimeStampMicroTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroTZVector( + AMERICA_VANCOUVER), + "TimeStampMicroTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMicroTZVector( + ASIA_BANGKOK), + "TimeStampMicroTZVector", + ASIA_BANGKOK}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliVector(), + "TimeStampMilliVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliTZVector( + "UTC"), + "TimeStampMilliTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliTZVector( + AMERICA_VANCOUVER), + "TimeStampMilliTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampMilliTZVector( + ASIA_BANGKOK), + "TimeStampMilliTZVector", + ASIA_BANGKOK}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecVector(), + "TimeStampSecVector", + null}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecTZVector("UTC"), + "TimeStampSecTZVector", + "UTC"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecTZVector( + AMERICA_VANCOUVER), + "TimeStampSecTZVector", + AMERICA_VANCOUVER}, + {(Supplier) () -> rootAllocatorTestRule.createTimeStampSecTZVector( + ASIA_BANGKOK), + "TimeStampSecTZVector", + ASIA_BANGKOK} + }); + } + + public ArrowFlightJdbcTimeStampVectorAccessorTest(Supplier vectorSupplier, + String vectorType, + String timeZone) { + this.vectorSupplier = vectorSupplier; + this.timeZone = timeZone; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTimestamp(null), + (accessor, currentRow) -> is(getTimestampForVector(currentRow))); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_SAO_PAULO); + Calendar calendar = Calendar.getInstance(timeZone); + + TimeZone timeZoneForVector = getTimeZoneForVector(vector); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); + final Timestamp result = accessor.getTimestamp(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimestampReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeStampVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTimestamp(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + 
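+  // The getDate and getTime tests below mirror the getTimestamp tests above: expected
+  // values are derived from getTimestamp(null), and the calendar variants are checked
+  // against the offset between the calendar's zone and the vector's own time zone.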
+ @Test + public void testShouldGetDateReturnValidDateWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getDate(null), + (accessor, currentRow) -> is(new Date(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_SAO_PAULO); + Calendar calendar = Calendar.getInstance(timeZone); + + TimeZone timeZoneForVector = getTimeZoneForVector(vector); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Date resultWithoutCalendar = accessor.getDate(null); + final Date result = accessor.getDate(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetDateReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeStampVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getDate(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTime(null), + (accessor, currentRow) -> is(new Time(getTimestampForVector(currentRow).getTime()))); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_SAO_PAULO); + Calendar calendar = Calendar.getInstance(timeZone); + + TimeZone timeZoneForVector = getTimeZoneForVector(vector); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Time resultWithoutCalendar = accessor.getTime(null); + final Time result = accessor.getTime(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()) - + timeZoneForVector.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimeReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeStampVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTime(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + private Timestamp getTimestampForVector(int currentRow) { + Object object = vector.getObject(currentRow); + + Timestamp expectedTimestamp = null; + if (object instanceof LocalDateTime) { + expectedTimestamp = Timestamp.valueOf((LocalDateTime) object); + } else if (object instanceof Long) { + TimeUnit timeUnit = getTimeUnitForVector(vector); + long millis = timeUnit.toMillis((Long) object); + long offset = TimeZone.getTimeZone(timeZone).getOffset(millis); + expectedTimestamp = new Timestamp(millis + offset); + } + return expectedTimestamp; + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcTimeStampVectorAccessor::getObjectClass, + equalTo(Timestamp.class)); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithoutCalendar() throws Exception { + assertGetStringIsConsistentWithVarCharAccessor(null); + } + + @Test 
+ public void testShouldGetStringBeConsistentWithVarCharAccessorWithCalendar() throws Exception { + // Ignore for TimeStamp vectors with TZ, as VarChar accessor won't consider their TZ + Assume.assumeTrue( + vector instanceof TimeStampNanoVector || vector instanceof TimeStampMicroVector || + vector instanceof TimeStampMilliVector || vector instanceof TimeStampSecVector); + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + assertGetStringIsConsistentWithVarCharAccessor(calendar); + } + + private void assertGetStringIsConsistentWithVarCharAccessor(Calendar calendar) throws Exception { + try (VarCharVector varCharVector = new VarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + varCharVector.allocateNew(1); + ArrowFlightJdbcVarCharVectorAccessor varCharVectorAccessor = + new ArrowFlightJdbcVarCharVectorAccessor(varCharVector, () -> 0, (boolean wasNull) -> { + }); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final String string = accessor.getString(); + varCharVector.set(0, new Text(string)); + varCharVector.setValueCount(1); + + Timestamp timestampFromVarChar = varCharVectorAccessor.getTimestamp(calendar); + Timestamp timestamp = accessor.getTimestamp(calendar); + + collector.checkThat(timestamp, is(timestampFromVarChar)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessorTest.java new file mode 100644 index 00000000000..d2f7eb336af --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/calendar/ArrowFlightJdbcTimeVectorAccessorTest.java @@ -0,0 +1,263 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.calendar; + +import static org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor.getTimeUnitForVector; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; + +import java.sql.Time; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.Calendar; +import java.util.Collection; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.accessor.impl.text.ArrowFlightJdbcVarCharVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.BaseFixedWidthVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.util.Text; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcTimeVectorAccessorTest { + + public static final String AMERICA_VANCOUVER = "America/Vancouver"; + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private BaseFixedWidthVector vector; + private final Supplier vectorSupplier; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof TimeNanoVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeNanoVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TimeMicroVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMicroVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TimeMilliVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeMilliVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TimeSecVector) { + return new ArrowFlightJdbcTimeVectorAccessor((TimeSecVector) vector, getCurrentRow, + noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createTimeNanoVector(), + "TimeNanoVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeMicroVector(), + "TimeMicroVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeMilliVector(), + "TimeMilliVector"}, + {(Supplier) () -> rootAllocatorTestRule.createTimeSecVector(), + "TimeSecVector"} + }); + } + + public ArrowFlightJdbcTimeVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = 
vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTimestamp(null), + (accessor, currentRow) -> is(getTimestampForVector(currentRow))); + } + + @Test + public void testShouldGetTimestampReturnValidTimestampWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Timestamp resultWithoutCalendar = accessor.getTimestamp(null); + final Timestamp result = accessor.getTimestamp(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimestampReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTimestamp(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithoutCalendar() throws Exception { + accessorIterator.assertAccessorGetter(vector, accessor -> accessor.getTime(null), + (accessor, currentRow) -> { + Timestamp expectedTimestamp = getTimestampForVector(currentRow); + return is(new Time(expectedTimestamp.getTime())); + }); + } + + @Test + public void testShouldGetTimeReturnValidTimeWithCalendar() throws Exception { + TimeZone timeZone = TimeZone.getTimeZone(AMERICA_VANCOUVER); + Calendar calendar = Calendar.getInstance(timeZone); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final Time resultWithoutCalendar = accessor.getTime(null); + final Time result = accessor.getTime(calendar); + + long offset = timeZone.getOffset(resultWithoutCalendar.getTime()); + + collector.checkThat(resultWithoutCalendar.getTime() - result.getTime(), is(offset)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + + @Test + public void testShouldGetTimeReturnNull() { + vector.setNull(0); + ArrowFlightJdbcTimeVectorAccessor accessor = accessorSupplier.supply(vector, () -> 0); + collector.checkThat(accessor.getTime(null), CoreMatchers.equalTo(null)); + collector.checkThat(accessor.wasNull(), is(true)); + } + + private Timestamp getTimestampForVector(int currentRow) { + Object object = vector.getObject(currentRow); + + Timestamp expectedTimestamp = null; + if (object instanceof LocalDateTime) { + expectedTimestamp = Timestamp.valueOf((LocalDateTime) object); + } else if (object instanceof Number) { + long value = ((Number) object).longValue(); + TimeUnit timeUnit = getTimeUnitForVector(vector); + long millis = timeUnit.toMillis(value); + expectedTimestamp = new Timestamp(millis); + } + return expectedTimestamp; + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcTimeVectorAccessor::getObjectClass, + equalTo(Time.class)); + } + + @Test + public void testShouldGetStringBeConsistentWithVarCharAccessorWithoutCalendar() throws Exception { + assertGetStringIsConsistentWithVarCharAccessor(null); + } + + @Test + public void 
testShouldGetStringBeConsistentWithVarCharAccessorWithCalendar() throws Exception { + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + assertGetStringIsConsistentWithVarCharAccessor(calendar); + } + + @Test + public void testValidateGetStringTimeZoneConsistency() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final TimeZone defaultTz = TimeZone.getDefault(); + try { + final String string = accessor.getString(); // Should always be UTC as no calendar is provided + + // Validate with UTC + Time time = accessor.getTime(null); + TimeZone.setDefault(TimeZone.getTimeZone("UTC")); + collector.checkThat(time.toString(), is(string)); + + // Validate with different TZ + TimeZone.setDefault(TimeZone.getTimeZone(AMERICA_VANCOUVER)); + collector.checkThat(time.toString(), not(string)); + + collector.checkThat(accessor.wasNull(), is(false)); + } finally { + // Set default Tz back + TimeZone.setDefault(defaultTz); + } + }); + } + + private void assertGetStringIsConsistentWithVarCharAccessor(Calendar calendar) throws Exception { + try (VarCharVector varCharVector = new VarCharVector("", + rootAllocatorTestRule.getRootAllocator())) { + varCharVector.allocateNew(1); + ArrowFlightJdbcVarCharVectorAccessor varCharVectorAccessor = + new ArrowFlightJdbcVarCharVectorAccessor(varCharVector, () -> 0, (boolean wasNull) -> { + }); + + accessorIterator.iterate(vector, (accessor, currentRow) -> { + final String string = accessor.getString(); + varCharVector.set(0, new Text(string)); + varCharVector.setValueCount(1); + + Time timeFromVarChar = varCharVectorAccessor.getTime(calendar); + Time time = accessor.getTime(calendar); + + collector.checkThat(time, is(timeFromVarChar)); + collector.checkThat(accessor.wasNull(), is(false)); + }); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java new file mode 100644 index 00000000000..b2eb8f1dbee --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcListAccessorTest.java @@ -0,0 +1,185 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.hamcrest.CoreMatchers.equalTo; + +import java.sql.Array; +import java.sql.ResultSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class AbstractArrowFlightJdbcListAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final Supplier vectorSupplier; + private ValueVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof ListVector) { + return new ArrowFlightJdbcListVectorAccessor((ListVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof LargeListVector) { + return new ArrowFlightJdbcLargeListVectorAccessor((LargeListVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof FixedSizeListVector) { + return new ArrowFlightJdbcFixedSizeListVectorAccessor((FixedSizeListVector) vector, + getCurrentRow, noOpWasNullConsumer); + } + return null; + }; + + final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createListVector(), "ListVector"}, + {(Supplier) () -> rootAllocatorTestRule.createLargeListVector(), + "LargeListVector"}, + {(Supplier) () -> rootAllocatorTestRule.createFixedSizeListVector(), + "FixedSizeListVector"}, + }); + } + + public AbstractArrowFlightJdbcListAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = this.vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetObjectClassReturnCorrectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(List.class)); + } + + @Test + public void testShouldGetObjectReturnValidList() throws Exception { + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getObject, + (accessor, currentRow) -> equalTo( + Arrays.asList(0, (currentRow), (currentRow) * 2, (currentRow) * 3, (currentRow) * 4))); + } + + @Test + public void testShouldGetObjectReturnNull() throws Exception { + vector.clear(); + 
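+    // clear() followed by allocateNewSafe() leaves every slot unset, so all five rows
+    // read back as null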
vector.allocateNewSafe(); + vector.setValueCount(5); + + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getObject, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetArrayReturnValidArray() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Array array = accessor.getArray(); + assert array != null; + + Object[] arrayObject = (Object[]) array.getArray(); + + collector.checkThat(arrayObject, equalTo( + new Object[] {0, currentRow, (currentRow) * 2, (currentRow) * 3, (currentRow) * 4})); + }); + } + + @Test + public void testShouldGetArrayReturnNull() throws Exception { + vector.clear(); + vector.allocateNewSafe(); + vector.setValueCount(5); + + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcListVectorAccessor::getArray, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetArrayReturnValidArrayPassingOffsets() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Array array = accessor.getArray(); + assert array != null; + + Object[] arrayObject = (Object[]) array.getArray(1, 3); + + collector.checkThat(arrayObject, equalTo( + new Object[] {currentRow, (currentRow) * 2, (currentRow) * 3})); + }); + } + + @Test + public void testShouldGetArrayGetResultSetReturnValidResultSet() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + Array array = accessor.getArray(); + assert array != null; + + try (ResultSet rs = array.getResultSet()) { + int count = 0; + while (rs.next()) { + final int value = rs.getInt(1); + collector.checkThat(value, equalTo(currentRow * count)); + count++; + } + collector.checkThat(count, equalTo(5)); + } + }); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessorTest.java new file mode 100644 index 00000000000..2b53b27dc9e --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/AbstractArrowFlightJdbcUnionVectorAccessorTest.java @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.SQLException; +import java.util.Calendar; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.ArrowFlightJdbcNullVectorAccessor; +import org.apache.arrow.vector.NullVector; +import org.apache.arrow.vector.ValueVector; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.Spy; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class AbstractArrowFlightJdbcUnionVectorAccessorTest { + + @Mock + ArrowFlightJdbcAccessor innerAccessor; + @Spy + AbstractArrowFlightJdbcUnionVectorAccessorMock accessor; + + @Before + public void setup() { + when(accessor.getAccessor()).thenReturn(innerAccessor); + } + + @Test + public void testGetNCharacterStreamUsesSpecificAccessor() throws SQLException { + accessor.getNCharacterStream(); + verify(innerAccessor).getNCharacterStream(); + } + + @Test + public void testGetNStringUsesSpecificAccessor() throws SQLException { + accessor.getNString(); + verify(innerAccessor).getNString(); + } + + @Test + public void testGetSQLXMLUsesSpecificAccessor() throws SQLException { + accessor.getSQLXML(); + verify(innerAccessor).getSQLXML(); + } + + @Test + public void testGetNClobUsesSpecificAccessor() throws SQLException { + accessor.getNClob(); + verify(innerAccessor).getNClob(); + } + + @Test + public void testGetURLUsesSpecificAccessor() throws SQLException { + accessor.getURL(); + verify(innerAccessor).getURL(); + } + + @Test + public void testGetStructUsesSpecificAccessor() throws SQLException { + accessor.getStruct(); + verify(innerAccessor).getStruct(); + } + + @Test + public void testGetArrayUsesSpecificAccessor() throws SQLException { + accessor.getArray(); + verify(innerAccessor).getArray(); + } + + @Test + public void testGetClobUsesSpecificAccessor() throws SQLException { + accessor.getClob(); + verify(innerAccessor).getClob(); + } + + @Test + public void testGetBlobUsesSpecificAccessor() throws SQLException { + accessor.getBlob(); + verify(innerAccessor).getBlob(); + } + + @Test + public void testGetRefUsesSpecificAccessor() throws SQLException { + accessor.getRef(); + verify(innerAccessor).getRef(); + } + + @Test + public void testGetCharacterStreamUsesSpecificAccessor() throws SQLException { + accessor.getCharacterStream(); + verify(innerAccessor).getCharacterStream(); + } + + @Test + public void testGetBinaryStreamUsesSpecificAccessor() throws SQLException { + accessor.getBinaryStream(); + verify(innerAccessor).getBinaryStream(); + } + + @Test + public void testGetUnicodeStreamUsesSpecificAccessor() throws SQLException { + accessor.getUnicodeStream(); + verify(innerAccessor).getUnicodeStream(); + } + + @Test + public void testGetAsciiStreamUsesSpecificAccessor() throws SQLException { + accessor.getAsciiStream(); + verify(innerAccessor).getAsciiStream(); + } + + @Test + public void testGetBytesUsesSpecificAccessor() throws SQLException { + accessor.getBytes(); + verify(innerAccessor).getBytes(); + } + + @Test + public void testGetBigDecimalUsesSpecificAccessor() throws SQLException { + accessor.getBigDecimal(); + verify(innerAccessor).getBigDecimal(); + } + + @Test + public void testGetDoubleUsesSpecificAccessor() throws SQLException { + 
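+    // as with the other getters, the union accessor should delegate to the accessor
+    // for the currently selected type id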
accessor.getDouble(); + verify(innerAccessor).getDouble(); + } + + @Test + public void testGetFloatUsesSpecificAccessor() throws SQLException { + accessor.getFloat(); + verify(innerAccessor).getFloat(); + } + + @Test + public void testGetLongUsesSpecificAccessor() throws SQLException { + accessor.getLong(); + verify(innerAccessor).getLong(); + } + + @Test + public void testGetIntUsesSpecificAccessor() throws SQLException { + accessor.getInt(); + verify(innerAccessor).getInt(); + } + + @Test + public void testGetShortUsesSpecificAccessor() throws SQLException { + accessor.getShort(); + verify(innerAccessor).getShort(); + } + + @Test + public void testGetByteUsesSpecificAccessor() throws SQLException { + accessor.getByte(); + verify(innerAccessor).getByte(); + } + + @Test + public void testGetBooleanUsesSpecificAccessor() throws SQLException { + accessor.getBoolean(); + verify(innerAccessor).getBoolean(); + } + + @Test + public void testGetStringUsesSpecificAccessor() throws SQLException { + accessor.getString(); + verify(innerAccessor).getString(); + } + + @Test + public void testGetObjectClassUsesSpecificAccessor() { + accessor.getObjectClass(); + verify(innerAccessor).getObjectClass(); + } + + @Test + public void testGetObjectWithClassUsesSpecificAccessor() throws SQLException { + accessor.getObject(Object.class); + verify(innerAccessor).getObject(Object.class); + } + + @Test + public void testGetTimestampUsesSpecificAccessor() throws SQLException { + Calendar calendar = Calendar.getInstance(); + accessor.getTimestamp(calendar); + verify(innerAccessor).getTimestamp(calendar); + } + + @Test + public void testGetTimeUsesSpecificAccessor() throws SQLException { + Calendar calendar = Calendar.getInstance(); + accessor.getTime(calendar); + verify(innerAccessor).getTime(calendar); + } + + @Test + public void testGetDateUsesSpecificAccessor() throws SQLException { + Calendar calendar = Calendar.getInstance(); + accessor.getDate(calendar); + verify(innerAccessor).getDate(calendar); + } + + @Test + public void testGetObjectUsesSpecificAccessor() throws SQLException { + Map> map = mock(Map.class); + accessor.getObject(map); + verify(innerAccessor).getObject(map); + } + + @Test + public void testGetBigDecimalWithScaleUsesSpecificAccessor() throws SQLException { + accessor.getBigDecimal(2); + verify(innerAccessor).getBigDecimal(2); + } + + private static class AbstractArrowFlightJdbcUnionVectorAccessorMock + extends AbstractArrowFlightJdbcUnionVectorAccessor { + protected AbstractArrowFlightJdbcUnionVectorAccessorMock() { + super(() -> 0, (boolean wasNull) -> { + }); + } + + @Override + protected ArrowFlightJdbcAccessor createAccessorForVector(ValueVector vector) { + return new ArrowFlightJdbcNullVectorAccessor((boolean wasNull) -> { + }); + } + + @Override + protected byte getCurrentTypeId() { + return 0; + } + + @Override + protected ValueVector getVectorByTypeId(byte typeId) { + return new NullVector(); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessorTest.java new file mode 100644 index 00000000000..41d5eb97e85 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcDenseUnionVectorAccessorTest.java @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software 
Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.complex;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.sql.Timestamp;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.complex.DenseUnionVector;
+import org.apache.arrow.vector.holders.NullableBigIntHolder;
+import org.apache.arrow.vector.holders.NullableFloat8Holder;
+import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+public class ArrowFlightJdbcDenseUnionVectorAccessorTest {
+
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  private DenseUnionVector vector;
+
+  private final AccessorTestUtils.AccessorSupplier
+      accessorSupplier =
+          (vector, getCurrentRow) -> new ArrowFlightJdbcDenseUnionVectorAccessor(
+              (DenseUnionVector) vector, getCurrentRow, (boolean wasNull) -> {
+                // no operation
+              });
+
+  private final AccessorTestUtils.AccessorIterator
+      accessorIterator =
+          new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier);
+
+  @Before
+  public void setup() throws Exception {
+    this.vector = DenseUnionVector.empty("", rootAllocatorTestRule.getRootAllocator());
+    this.vector.allocateNew();
+
+    // register one type id per minor type used below, then write one value of each
+    byte bigIntTypeId =
+        this.vector.registerNewTypeId(Field.nullable("", Types.MinorType.BIGINT.getType()));
+    byte float8TypeId =
+        this.vector.registerNewTypeId(Field.nullable("", Types.MinorType.FLOAT8.getType()));
+    byte timestampMilliTypeId =
+        this.vector.registerNewTypeId(Field.nullable("", Types.MinorType.TIMESTAMPMILLI.getType()));
+
+    NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder();
+    nullableBigIntHolder.isSet = 1;
+    nullableBigIntHolder.value = Long.MAX_VALUE;
+    this.vector.setTypeId(0, bigIntTypeId);
+    this.vector.setSafe(0, nullableBigIntHolder);
+
+    NullableFloat8Holder nullableFloat8Holder = new NullableFloat8Holder();
+    nullableFloat8Holder.isSet = 1;
+    nullableFloat8Holder.value = Math.PI;
+    this.vector.setTypeId(1, float8TypeId);
+    this.vector.setSafe(1, nullableFloat8Holder);
+
+    NullableTimeStampMilliHolder nullableTimeStampMilliHolder = new NullableTimeStampMilliHolder();
+    nullableTimeStampMilliHolder.isSet = 1;
+    nullableTimeStampMilliHolder.value = 1625702400000L;
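+    // row 2 holds the timestamp value; row 3 below is written with isSet = 0 and row 4
+    // is never written, so both are expected to read back as null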
this.vector.setTypeId(2, timestampMilliTypeId); + this.vector.setSafe(2, nullableTimeStampMilliHolder); + + nullableBigIntHolder.isSet = 0; + this.vector.setTypeId(3, bigIntTypeId); + this.vector.setSafe(3, nullableBigIntHolder); + + this.vector.setValueCount(5); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void getObject() throws Exception { + List result = accessorIterator.toList(vector); + List expected = Arrays.asList( + Long.MAX_VALUE, + Math.PI, + new Timestamp(1625702400000L), + null, + null); + + collector.checkThat(result, is(expected)); + } + + @Test + public void getObjectForNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcUnionVectorAccessor::getObject, equalTo(null)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessorTest.java new file mode 100644 index 00000000000..7a81da4240b --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcMapVectorAccessorTest.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import java.sql.Array; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.UnionMapWriter; +import org.apache.arrow.vector.util.JsonStringHashMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcMapVectorAccessorTest { + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private MapVector vector; + + @Before + public void setup() { + vector = MapVector.empty("", rootAllocatorTestRule.getRootAllocator(), false); + UnionMapWriter writer = vector.getWriter(); + writer.allocate(); + writer.setPosition(0); // optional + writer.startMap(); + writer.startEntry(); + writer.key().integer().writeInt(1); + writer.value().integer().writeInt(11); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(2); + writer.value().integer().writeInt(22); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(3); + writer.value().integer().writeInt(33); + writer.endEntry(); + writer.endMap(); + + writer.setPosition(1); + writer.startMap(); + writer.startEntry(); + writer.key().integer().writeInt(2); + writer.endEntry(); + writer.endMap(); + + writer.setPosition(2); + writer.startMap(); + writer.startEntry(); + writer.key().integer().writeInt(0); + writer.value().integer().writeInt(2000); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(1); + writer.value().integer().writeInt(2001); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(2); + writer.value().integer().writeInt(2002); + writer.endEntry(); + writer.startEntry(); + writer.key().integer().writeInt(3); + writer.value().integer().writeInt(2003); + writer.endEntry(); + writer.endMap(); + + writer.setValueCount(3); + } + + @After + public void tearDown() { + vector.close(); + } + + @Test + public void testShouldGetObjectReturnValidMap() { + AccessorTestUtils.Cursor cursor = new AccessorTestUtils.Cursor(vector.getValueCount()); + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, cursor::getCurrentRow, (boolean wasNull) -> { + }); + + Map expected = new JsonStringHashMap<>(); + expected.put(1, 11); + expected.put(2, 22); + expected.put(3, 33); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertFalse(accessor.wasNull()); + + cursor.next(); + expected = new JsonStringHashMap<>(); + expected.put(2, null); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertFalse(accessor.wasNull()); + + cursor.next(); + expected = new JsonStringHashMap<>(); + expected.put(0, 2000); + expected.put(1, 2001); + expected.put(2, 2002); + expected.put(3, 2003); + Assert.assertEquals(expected, accessor.getObject()); + Assert.assertFalse(accessor.wasNull()); + } + + @Test + public void testShouldGetObjectReturnNull() { + vector.setNull(0); + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, () -> 0, 
(boolean wasNull) -> { + }); + + Assert.assertNull(accessor.getObject()); + Assert.assertTrue(accessor.wasNull()); + } + + @Test + public void testShouldGetArrayReturnValidArray() throws SQLException { + AccessorTestUtils.Cursor cursor = new AccessorTestUtils.Cursor(vector.getValueCount()); + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, cursor::getCurrentRow, (boolean wasNull) -> { + }); + + Array array = accessor.getArray(); + Assert.assertNotNull(array); + Assert.assertFalse(accessor.wasNull()); + + try (ResultSet resultSet = array.getResultSet()) { + Assert.assertTrue(resultSet.next()); + Map entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(1, entry.get("key")); + Assert.assertEquals(11, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(2, entry.get("key")); + Assert.assertEquals(22, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(3, entry.get("key")); + Assert.assertEquals(33, entry.get("value")); + Assert.assertFalse(resultSet.next()); + } + + cursor.next(); + array = accessor.getArray(); + Assert.assertNotNull(array); + Assert.assertFalse(accessor.wasNull()); + try (ResultSet resultSet = array.getResultSet()) { + Assert.assertTrue(resultSet.next()); + Map entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(2, entry.get("key")); + Assert.assertNull(entry.get("value")); + Assert.assertFalse(resultSet.next()); + } + + cursor.next(); + array = accessor.getArray(); + Assert.assertNotNull(array); + Assert.assertFalse(accessor.wasNull()); + try (ResultSet resultSet = array.getResultSet()) { + Assert.assertTrue(resultSet.next()); + Map entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(0, entry.get("key")); + Assert.assertEquals(2000, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(1, entry.get("key")); + Assert.assertEquals(2001, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(2, entry.get("key")); + Assert.assertEquals(2002, entry.get("value")); + Assert.assertTrue(resultSet.next()); + entry = resultSet.getObject(1, Map.class); + Assert.assertEquals(3, entry.get("key")); + Assert.assertEquals(2003, entry.get("value")); + Assert.assertFalse(resultSet.next()); + } + } + + @Test + public void testShouldGetArrayReturnNull() { + vector.setNull(0); + ((StructVector) vector.getDataVector()).setNull(0); + + ArrowFlightJdbcMapVectorAccessor accessor = + new ArrowFlightJdbcMapVectorAccessor(vector, () -> 0, (boolean wasNull) -> { + }); + + Assert.assertNull(accessor.getArray()); + Assert.assertTrue(accessor.wasNull()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java new file mode 100644 index 00000000000..b3c85fc0ab1 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcStructVectorAccessorTest.java @@ -0,0 +1,209 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.complex; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.nullValue; + +import java.sql.SQLException; +import java.sql.Struct; +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.holders.NullableBitHolder; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.apache.arrow.vector.util.JsonStringHashMap; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class ArrowFlightJdbcStructVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private StructVector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcStructVectorAccessor((StructVector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setUp() throws Exception { + Map metadata = new HashMap<>(); + metadata.put("k1", "v1"); + FieldType type = new FieldType(true, ArrowType.Struct.INSTANCE, null, metadata); + vector = new StructVector("", rootAllocatorTestRule.getRootAllocator(), type, null); + vector.allocateNew(); + + IntVector intVector = + vector.addOrGet("int", FieldType.nullable(Types.MinorType.INT.getType()), IntVector.class); + Float8Vector float8Vector = + vector.addOrGet("float8", FieldType.nullable(Types.MinorType.FLOAT8.getType()), + Float8Vector.class); + + intVector.setSafe(0, 100); + float8Vector.setSafe(0, 100.05); + vector.setIndexDefined(0); + intVector.setSafe(1, 200); + float8Vector.setSafe(1, 200.1); + vector.setIndexDefined(1); + + vector.setValueCount(2); + } + + @After + public void tearDown() throws Exception { + vector.close(); + } + + @Test + public void testShouldGetObjectClassReturnMapClass() throws Exception { + 
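// Struct rows are read back as java.util.Map (field name -> value), so Map.class is the expected object class. +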
accessorIterator.assertAccessorGetter(vector,
+        ArrowFlightJdbcStructVectorAccessor::getObjectClass,
+        (accessor, currentRow) -> equalTo(Map.class));
+  }
+
+  @Test
+  public void testShouldGetObjectReturnValidMap() throws Exception {
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcStructVectorAccessor::getObject,
+        (accessor, currentRow) -> {
+          Map<String, Object> expected = new HashMap<>();
+          expected.put("int", 100 * (currentRow + 1));
+          expected.put("float8", 100.05 * (currentRow + 1));
+
+          return equalTo(expected);
+        });
+  }
+
+  @Test
+  public void testShouldGetObjectReturnNull() throws Exception {
+    vector.setNull(0);
+    vector.setNull(1);
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcStructVectorAccessor::getObject,
+        (accessor, currentRow) -> nullValue());
+  }
+
+  @Test
+  public void testShouldGetStructReturnValidStruct() throws Exception {
+    accessorIterator.iterate(vector, (accessor, currentRow) -> {
+      Struct struct = accessor.getStruct();
+      Assert.assertNotNull(struct);
+
+      Object[] expected = new Object[] {
+          100 * (currentRow + 1),
+          100.05 * (currentRow + 1)
+      };
+
+      collector.checkThat(struct.getAttributes(), equalTo(expected));
+    });
+  }
+
+  @Test
+  public void testShouldGetStructReturnNull() throws Exception {
+    vector.setNull(0);
+    vector.setNull(1);
+    accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcStructVectorAccessor::getStruct,
+        (accessor, currentRow) -> nullValue());
+  }
+
+  @Test
+  public void testShouldGetObjectWorkWithNestedComplexData() throws SQLException {
+    try (StructVector rootVector = StructVector.empty("",
+        rootAllocatorTestRule.getRootAllocator())) {
+      StructVector structVector = rootVector.addOrGetStruct("struct");
+
+      FieldType intFieldType = FieldType.nullable(Types.MinorType.INT.getType());
+      IntVector intVector = structVector.addOrGet("int", intFieldType, IntVector.class);
+      FieldType float8FieldType = FieldType.nullable(Types.MinorType.FLOAT8.getType());
+      Float8Vector float8Vector =
+          structVector.addOrGet("float8", float8FieldType, Float8Vector.class);
+
+      ListVector listVector = rootVector.addOrGetList("list");
+      UnionListWriter listWriter = listVector.getWriter();
+      listWriter.allocate();
+
+      UnionVector unionVector = rootVector.addOrGetUnion("union");
+
+      intVector.setSafe(0, 100);
+      intVector.setValueCount(1);
+      float8Vector.setSafe(0, 100.05);
+      float8Vector.setValueCount(1);
+      structVector.setIndexDefined(0);
+
+      listWriter.setPosition(0);
+      listWriter.startList();
+      listWriter.bigInt().writeBigInt(Long.MAX_VALUE);
+      listWriter.bigInt().writeBigInt(Long.MIN_VALUE);
+      listWriter.endList();
+      listVector.setValueCount(1);
+
+      unionVector.setType(0, Types.MinorType.BIT);
+      NullableBitHolder holder = new NullableBitHolder();
+      holder.isSet = 1;
+      holder.value = 1;
+      unionVector.setSafe(0, holder);
+      unionVector.setValueCount(1);
+
+      rootVector.setIndexDefined(0);
+      rootVector.setValueCount(1);
+
+      Map<String, Object> expected = new JsonStringHashMap<>();
+      Map<String, Object> nestedStruct = new JsonStringHashMap<>();
+      nestedStruct.put("int", 100);
+      nestedStruct.put("float8", 100.05);
+      expected.put("struct", nestedStruct);
+      JsonStringArrayList<Long> nestedList = new JsonStringArrayList<>();
+      nestedList.add(Long.MAX_VALUE);
+      nestedList.add(Long.MIN_VALUE);
+      expected.put("list", nestedList);
+      expected.put("union", true);
+
+      ArrowFlightJdbcStructVectorAccessor accessor =
+          new ArrowFlightJdbcStructVectorAccessor(rootVector, () -> 0, (boolean wasNull) -> {
+          });
+
+      Assert.assertEquals(expected, accessor.getObject());
+      Assert.assertEquals(expected.toString(), accessor.getString());
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessorTest.java
new file mode 100644
index 00000000000..9ec9388ff87
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/complex/ArrowFlightJdbcUnionVectorAccessorTest.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.complex;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.sql.Timestamp;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.complex.UnionVector;
+import org.apache.arrow.vector.holders.NullableBigIntHolder;
+import org.apache.arrow.vector.holders.NullableFloat8Holder;
+import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder;
+import org.apache.arrow.vector.types.Types;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+public class ArrowFlightJdbcUnionVectorAccessorTest {
+
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  private UnionVector vector;
+
+  private final AccessorTestUtils.AccessorSupplier<ArrowFlightJdbcUnionVectorAccessor>
+      accessorSupplier =
+          (vector, getCurrentRow) -> new ArrowFlightJdbcUnionVectorAccessor((UnionVector) vector,
+              getCurrentRow, (boolean wasNull) -> {
+              });
+
+  private final AccessorTestUtils.AccessorIterator<ArrowFlightJdbcUnionVectorAccessor>
+      accessorIterator =
+          new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier);
+
+  @Before
+  public void setup() {
+    this.vector = UnionVector.empty("", rootAllocatorTestRule.getRootAllocator());
+    this.vector.allocateNew();
+
+    NullableBigIntHolder nullableBigIntHolder = new NullableBigIntHolder();
+    nullableBigIntHolder.isSet = 1;
+    nullableBigIntHolder.value = Long.MAX_VALUE;
+    this.vector.setType(0, Types.MinorType.BIGINT);
+    this.vector.setSafe(0, nullableBigIntHolder);
+
+    NullableFloat8Holder nullableFloat8Holder = new NullableFloat8Holder();
+    nullableFloat8Holder.isSet = 1;
+    nullableFloat8Holder.value = Math.PI;
+    this.vector.setType(1, Types.MinorType.FLOAT8);
+    this.vector.setSafe(1, nullableFloat8Holder);
+
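+    // Index 2 below holds a TIMESTAMPMILLI value; getObject() is expected to surface it
+    // as java.sql.Timestamp rather than a raw epoch-millis long.
+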
NullableTimeStampMilliHolder nullableTimeStampMilliHolder = new NullableTimeStampMilliHolder(); + nullableTimeStampMilliHolder.isSet = 1; + nullableTimeStampMilliHolder.value = 1625702400000L; + this.vector.setType(2, Types.MinorType.TIMESTAMPMILLI); + this.vector.setSafe(2, nullableTimeStampMilliHolder); + + nullableBigIntHolder.isSet = 0; + this.vector.setType(3, Types.MinorType.BIGINT); + this.vector.setSafe(3, nullableBigIntHolder); + + this.vector.setValueCount(5); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void getObject() throws Exception { + List result = accessorIterator.toList(vector); + List expected = Arrays.asList( + Long.MAX_VALUE, + Math.PI, + new Timestamp(1625702400000L), + null, + null); + + collector.checkThat(result, is(expected)); + } + + @Test + public void getObjectForNull() throws Exception { + vector.reset(); + vector.setValueCount(5); + + accessorIterator.assertAccessorGetter(vector, + AbstractArrowFlightJdbcUnionVectorAccessor::getObject, + equalTo(null)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorTest.java new file mode 100644 index 00000000000..5e54b545a85 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorTest.java @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.numeric;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.function.Supplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory;
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.BaseIntVector;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.UInt1Vector;
+import org.apache.arrow.vector.UInt2Vector;
+import org.apache.arrow.vector.UInt4Vector;
+import org.apache.arrow.vector.UInt8Vector;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+
+@RunWith(Parameterized.class)
+public class ArrowFlightJdbcBaseIntVectorAccessorTest {
+
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  private BaseIntVector vector;
+  private final Supplier<BaseIntVector> vectorSupplier;
+
+  private final AccessorTestUtils.AccessorSupplier<ArrowFlightJdbcBaseIntVectorAccessor>
+      accessorSupplier = (vector, getCurrentRow) -> {
+        ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> {
+        };
+        if (vector instanceof UInt1Vector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((UInt1Vector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof UInt2Vector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((UInt2Vector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof UInt4Vector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof UInt8Vector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((UInt8Vector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof TinyIntVector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof SmallIntVector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof IntVector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((IntVector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        } else if (vector instanceof BigIntVector) {
+          return new ArrowFlightJdbcBaseIntVectorAccessor((BigIntVector) vector, getCurrentRow,
+              noOpWasNullConsumer);
+        }
+        throw new UnsupportedOperationException();
+      };
+
+  private final AccessorTestUtils.AccessorIterator<ArrowFlightJdbcBaseIntVectorAccessor>
+      accessorIterator = new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier);
+
+  @Parameterized.Parameters(name = "{1}")
+  public static Collection<Object[]> data() {
+    return Arrays.asList(new Object[][] {
+        {(Supplier<BaseIntVector>) () -> rootAllocatorTestRule.createIntVector(), "IntVector"},
+        {(Supplier<BaseIntVector>) () -> rootAllocatorTestRule.createSmallIntVector(),
+            "SmallIntVector"},
+        {(Supplier<BaseIntVector>) () -> rootAllocatorTestRule.createTinyIntVector(),
+            "TinyIntVector"},
+        {(Supplier<BaseIntVector>) () -> rootAllocatorTestRule.createBigIntVector(),
+            "BigIntVector"},
+        {(Supplier<BaseIntVector>)
() -> rootAllocatorTestRule.createUInt1Vector(), "UInt1Vector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt2Vector(), "UInt2Vector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt4Vector(), "UInt4Vector"}, + {(Supplier) () -> rootAllocatorTestRule.createUInt8Vector(), "UInt8Vector"} + }); + } + + public ArrowFlightJdbcBaseIntVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldConvertToByteMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getByte, + (accessor, currentRow) -> equalTo((byte) accessor.getLong())); + } + + @Test + public void testShouldConvertToShortMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getShort, + (accessor, currentRow) -> equalTo((short) accessor.getLong())); + } + + @Test + public void testShouldConvertToIntegerMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getInt, + (accessor, currentRow) -> equalTo((int) accessor.getLong())); + } + + @Test + public void testShouldConvertToFloatMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getFloat, + (accessor, currentRow) -> equalTo((float) accessor.getLong())); + } + + @Test + public void testShouldConvertToDoubleMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getDouble, + (accessor, currentRow) -> equalTo((double) accessor.getLong())); + } + + @Test + public void testShouldConvertToBooleanMethodFromBaseIntVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBaseIntVectorAccessor::getBoolean, + (accessor, currentRow) -> equalTo(accessor.getLong() != 0L)); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcBaseIntVectorAccessor::getObjectClass, + equalTo(Long.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorUnitTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorUnitTest.java new file mode 100644 index 00000000000..2e64b6fb402 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBaseIntVectorAccessorUnitTest.java @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.hamcrest.CoreMatchers; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcBaseIntVectorAccessorUnitTest { + + @ClassRule + public static RootAllocatorTestRule rule = new RootAllocatorTestRule(); + private static UInt4Vector int4Vector; + private static UInt8Vector int8Vector; + private static IntVector intVectorWithNull; + private static TinyIntVector tinyIntVector; + private static SmallIntVector smallIntVector; + private static IntVector intVector; + private static BigIntVector bigIntVector; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof UInt1Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt1Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt2Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt2Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt4Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt4Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof UInt8Vector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((UInt8Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof TinyIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((TinyIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof SmallIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((SmallIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof IntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((IntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof BigIntVector) { + return new ArrowFlightJdbcBaseIntVectorAccessor((BigIntVector) vector, getCurrentRow, + noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + 
accessorIterator = new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @BeforeClass + public static void setup() { + int4Vector = new UInt4Vector("ID", rule.getRootAllocator()); + int4Vector.setSafe(0, 0x80000001); + int4Vector.setValueCount(1); + + int8Vector = new UInt8Vector("ID", rule.getRootAllocator()); + int8Vector.setSafe(0, 0xFFFFFFFFFFFFFFFFL); + int8Vector.setValueCount(1); + + intVectorWithNull = new IntVector("ID", rule.getRootAllocator()); + intVectorWithNull.setNull(0); + intVectorWithNull.setValueCount(1); + + tinyIntVector = new TinyIntVector("ID", rule.getRootAllocator()); + tinyIntVector.setSafe(0, 0xAA); + tinyIntVector.setValueCount(1); + + smallIntVector = new SmallIntVector("ID", rule.getRootAllocator()); + smallIntVector.setSafe(0, 0xAABB); + smallIntVector.setValueCount(1); + + intVector = new IntVector("ID", rule.getRootAllocator()); + intVector.setSafe(0, 0xAABBCCDD); + intVector.setValueCount(1); + + bigIntVector = new BigIntVector("ID", rule.getRootAllocator()); + bigIntVector.setSafe(0, 0xAABBCCDDEEFFAABBL); + bigIntVector.setValueCount(1); + } + + @AfterClass + public static void tearDown() throws Exception { + AutoCloseables.close(bigIntVector, intVector, smallIntVector, tinyIntVector, int4Vector, + int8Vector, intVectorWithNull, rule); + } + + @Test + public void testShouldGetStringFromUnsignedValue() throws Exception { + accessorIterator.assertAccessorGetter(int8Vector, + ArrowFlightJdbcBaseIntVectorAccessor::getString, equalTo("18446744073709551615")); + } + + @Test + public void testShouldGetBytesFromIntVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(intVector, ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } + + @Test + public void testShouldGetStringFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, + ArrowFlightJdbcBaseIntVectorAccessor::getString, CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetObjectFromInt() throws Exception { + accessorIterator.assertAccessorGetter(intVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo(0xAABBCCDD)); + } + + @Test + public void testShouldGetObjectFromTinyInt() throws Exception { + accessorIterator.assertAccessorGetter(tinyIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo((byte) 0xAA)); + } + + @Test + public void testShouldGetObjectFromSmallInt() throws Exception { + accessorIterator.assertAccessorGetter(smallIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo((short) 0xAABB)); + } + + @Test + public void testShouldGetObjectFromBigInt() throws Exception { + accessorIterator.assertAccessorGetter(bigIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo(0xAABBCCDDEEFFAABBL)); + } + + @Test + public void testShouldGetObjectFromUnsignedInt() throws Exception { + accessorIterator.assertAccessorGetter(int4Vector, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, equalTo(0x80000001)); + } + + @Test + public void testShouldGetObjectFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, + ArrowFlightJdbcBaseIntVectorAccessor::getObject, CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBigDecimalFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, + ArrowFlightJdbcBaseIntVectorAccessor::getBigDecimal, CoreMatchers.nullValue()); + } + + @Test + public void 
testShouldGetBigDecimalWithScaleFromIntVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(intVectorWithNull, accessor -> accessor.getBigDecimal(2), + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBytesFromSmallVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(smallIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } + + @Test + public void testShouldGetBytesFromTinyIntVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(tinyIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } + + @Test + public void testShouldGetBytesFromBigIntVectorThrowsSqlException() throws Exception { + accessorIterator.assertAccessorGetterThrowingException(bigIntVector, + ArrowFlightJdbcBaseIntVectorAccessor::getBytes); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java new file mode 100644 index 00000000000..809d6e8d353 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcBitVectorAccessorTest.java @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.accessor.impl.numeric;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.math.BigDecimal;
+
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils;
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils.AccessorIterator;
+import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils.CheckedFunction;
+import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule;
+import org.apache.arrow.vector.BitVector;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+public class ArrowFlightJdbcBitVectorAccessorTest {
+  @ClassRule
+  public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule();
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+  private final AccessorTestUtils.AccessorSupplier<ArrowFlightJdbcBitVectorAccessor>
+      accessorSupplier = (vector, getCurrentRow) -> new ArrowFlightJdbcBitVectorAccessor(
+          (BitVector) vector, getCurrentRow, (boolean wasNull) -> {
+          });
+  private final AccessorIterator<ArrowFlightJdbcBitVectorAccessor>
+      accessorIterator =
+          new AccessorIterator<>(collector, accessorSupplier);
+  private BitVector vector;
+  private BitVector vectorWithNull;
+  private boolean[] arrayToAssert;
+
+  @Before
+  public void setup() {
+    this.arrayToAssert = new boolean[] {false, true};
+    this.vector = rootAllocatorTestRule.createBitVector();
+    this.vectorWithNull = rootAllocatorTestRule.createBitVectorForNullTests();
+  }
+
+  @After
+  public void tearDown() {
+    this.vector.close();
+    this.vectorWithNull.close();
+  }
+
+  private <T> void iterate(final CheckedFunction<ArrowFlightJdbcBitVectorAccessor, T> function,
+      final T result,
+      final T resultIfFalse, final BitVector vector) throws Exception {
+    accessorIterator.assertAccessorGetter(vector, function,
+        ((accessor, currentRow) -> is(arrayToAssert[currentRow] ?
result : resultIfFalse)) + ); + } + + @Test + public void testShouldGetBooleanMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getBoolean, true, false, vector); + } + + @Test + public void testShouldGetByteMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getByte, (byte) 1, (byte) 0, vector); + } + + @Test + public void testShouldGetShortMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getShort, (short) 1, (short) 0, vector); + } + + @Test + public void testShouldGetIntMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getInt, 1, 0, vector); + + } + + @Test + public void testShouldGetLongMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getLong, (long) 1, (long) 0, vector); + + } + + @Test + public void testShouldGetFloatMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getFloat, (float) 1, (float) 0, vector); + + } + + @Test + public void testShouldGetDoubleMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getDouble, (double) 1, (double) 0, vector); + + } + + @Test + public void testShouldGetBigDecimalMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getBigDecimal, BigDecimal.ONE, BigDecimal.ZERO, + vector); + } + + @Test + public void testShouldGetBigDecimalMethodFromBitVectorFromNull() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getBigDecimal, null, null, vectorWithNull); + + } + + @Test + public void testShouldGetObjectMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getObject, true, false, vector); + + } + + @Test + public void testShouldGetObjectMethodFromBitVectorFromNull() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getObject, null, null, vectorWithNull); + + } + + @Test + public void testShouldGetStringMethodFromBitVector() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getString, "true", "false", vector); + + } + + @Test + public void testShouldGetStringMethodFromBitVectorFromNull() throws Exception { + iterate(ArrowFlightJdbcBitVectorAccessor::getString, null, null, vectorWithNull); + + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcBitVectorAccessor::getObjectClass, + equalTo(Boolean.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessorTest.java new file mode 100644 index 00000000000..b7bd7c40fef --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcDecimalVectorAccessorTest.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.Collection; +import java.util.function.Supplier; + +import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessorFactory; +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.ValueVector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +@RunWith(Parameterized.class) +public class ArrowFlightJdbcDecimalVectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + private final Supplier vectorSupplier; + private ValueVector vector; + private ValueVector vectorWithNull; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = (vector, getCurrentRow) -> { + ArrowFlightJdbcAccessorFactory.WasNullConsumer noOpWasNullConsumer = (boolean wasNull) -> { + }; + if (vector instanceof DecimalVector) { + return new ArrowFlightJdbcDecimalVectorAccessor((DecimalVector) vector, getCurrentRow, + noOpWasNullConsumer); + } else if (vector instanceof Decimal256Vector) { + return new ArrowFlightJdbcDecimalVectorAccessor((Decimal256Vector) vector, getCurrentRow, + noOpWasNullConsumer); + } + return null; + }; + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Parameterized.Parameters(name = "{1}") + public static Collection data() { + return Arrays.asList(new Object[][] { + {(Supplier) () -> rootAllocatorTestRule.createDecimalVector(), + "DecimalVector"}, + {(Supplier) () -> rootAllocatorTestRule.createDecimal256Vector(), + "Decimal256Vector"}, + }); + } + + public ArrowFlightJdbcDecimalVectorAccessorTest(Supplier vectorSupplier, + String vectorType) { + this.vectorSupplier = vectorSupplier; + } + + @Before + public void setup() { + this.vector = vectorSupplier.get(); + + this.vectorWithNull = vectorSupplier.get(); + this.vectorWithNull.clear(); + this.vectorWithNull.setValueCount(5); + } + + @After + public void tearDown() { + this.vector.close(); + this.vectorWithNull.close(); + } + + @Test + public void testShouldGetBigDecimalFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcDecimalVectorAccessor::getBigDecimal, + (accessor, currentRow) -> CoreMatchers.notNullValue()); + } + + @Test + public void testShouldGetDoubleMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, 
ArrowFlightJdbcDecimalVectorAccessor::getDouble, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().doubleValue())); + } + + @Test + public void testShouldGetFloatMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getFloat, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().floatValue())); + } + + @Test + public void testShouldGetLongMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getLong, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().longValue())); + } + + @Test + public void testShouldGetIntMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getInt, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().intValue())); + } + + @Test + public void testShouldGetShortMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getShort, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().shortValue())); + } + + @Test + public void testShouldGetByteMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getByte, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().byteValue())); + } + + @Test + public void testShouldGetStringMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getString, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal().toString())); + } + + @Test + public void testShouldGetBooleanMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getBoolean, + (accessor, currentRow) -> equalTo(!accessor.getBigDecimal().equals(BigDecimal.ZERO))); + } + + @Test + public void testShouldGetObjectMethodFromDecimalVector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcDecimalVectorAccessor::getObject, + (accessor, currentRow) -> equalTo(accessor.getBigDecimal())); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcDecimalVectorAccessor::getObjectClass, + (accessor, currentRow) -> equalTo(BigDecimal.class)); + } + + @Test + public void testShouldGetBigDecimalMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getBigDecimal, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetObjectMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getObject, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetStringMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getString, + (accessor, currentRow) -> CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetByteMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getByte, + (accessor, currentRow) -> is((byte) 0)); + } + + @Test + public void 
testShouldGetShortMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getShort, + (accessor, currentRow) -> is((short) 0)); + } + + @Test + public void testShouldGetIntMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getInt, + (accessor, currentRow) -> is(0)); + } + + @Test + public void testShouldGetLongMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getLong, + (accessor, currentRow) -> is((long) 0)); + } + + @Test + public void testShouldGetFloatMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getFloat, + (accessor, currentRow) -> is(0.0f)); + } + + @Test + public void testShouldGetDoubleMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getDouble, + (accessor, currentRow) -> is(0.0D)); + } + + @Test + public void testShouldGetBooleanMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcDecimalVectorAccessor::getBoolean, + (accessor, currentRow) -> is(false)); + } + + @Test + public void testShouldGetBigDecimalWithScaleMethodFromDecimalVectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, accessor -> accessor.getBigDecimal(2), + (accessor, currentRow) -> CoreMatchers.nullValue()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessorTest.java new file mode 100644 index 00000000000..74a65715ec0 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat4VectorAccessorTest.java @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Float4Vector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.rules.ExpectedException; + +public class ArrowFlightJdbcFloat4VectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Rule + public ExpectedException exceptionCollector = ExpectedException.none(); + + private Float4Vector vector; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcFloat4VectorAccessor((Float4Vector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() { + this.vector = rootAllocatorTestRule.createFloat4Vector(); + } + + @After + public void tearDown() { + this.vector.close(); + } + + @Test + public void testShouldGetFloatMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getFloat, + (accessor, currentRow) -> is(vector.get(currentRow))); + } + + @Test + public void testShouldGetObjectMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getObject, + (accessor) -> is(accessor.getFloat())); + } + + @Test + public void testShouldGetStringMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getString, + accessor -> is(Float.toString(accessor.getFloat()))); + } + + @Test + public void testShouldGetStringMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getString, + CoreMatchers.nullValue()); + } + } + + @Test + public void testShouldGetFloatMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getFloat, is(0.0f)); + } + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getBigDecimal, + CoreMatchers.nullValue()); + } + } + + @Test + public void 
testShouldGetObjectMethodFromFloat4VectorWithNull() throws Exception { + try (final Float4Vector float4Vector = new Float4Vector("ID", + rootAllocatorTestRule.getRootAllocator())) { + float4Vector.setNull(0); + float4Vector.setValueCount(1); + + accessorIterator.assertAccessorGetter(float4Vector, + ArrowFlightJdbcFloat4VectorAccessor::getObject, + CoreMatchers.nullValue()); + } + } + + @Test + public void testShouldGetBooleanMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getBoolean, + accessor -> is(accessor.getFloat() != 0.0f)); + } + + @Test + public void testShouldGetByteMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getByte, + accessor -> is((byte) accessor.getFloat())); + } + + @Test + public void testShouldGetShortMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getShort, + accessor -> is((short) accessor.getFloat())); + } + + @Test + public void testShouldGetIntMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getInt, + accessor -> is((int) accessor.getFloat())); + } + + @Test + public void testShouldGetLongMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getLong, + accessor -> is((long) accessor.getFloat())); + } + + @Test + public void testShouldGetDoubleMethodFromFloat4Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat4VectorAccessor::getDouble, + accessor -> is((double) accessor.getFloat())); + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat4Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + float value = accessor.getFloat(); + if (Float.isInfinite(value) || Float.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(), is(BigDecimal.valueOf(value))); + }); + } + + @Test + public void testShouldGetBigDecimalWithScaleMethodFromFloat4Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + float value = accessor.getFloat(); + if (Float.isInfinite(value) || Float.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(9), + is(BigDecimal.valueOf(value).setScale(9, RoundingMode.HALF_UP))); + }); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator.assertAccessorGetter(vector, + ArrowFlightJdbcFloat4VectorAccessor::getObjectClass, + accessor -> equalTo(Float.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessorTest.java new file mode 100644 index 00000000000..26758287a96 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/numeric/ArrowFlightJdbcFloat8VectorAccessorTest.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.numeric; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.sql.SQLException; + +import org.apache.arrow.driver.jdbc.utils.AccessorTestUtils; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.vector.Float8Vector; +import org.hamcrest.CoreMatchers; +import org.junit.After; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.rules.ExpectedException; + +public class ArrowFlightJdbcFloat8VectorAccessorTest { + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Rule + public ExpectedException exceptionCollector = ExpectedException.none(); + + + private Float8Vector vector; + private Float8Vector vectorWithNull; + + private final AccessorTestUtils.AccessorSupplier + accessorSupplier = + (vector, getCurrentRow) -> new ArrowFlightJdbcFloat8VectorAccessor((Float8Vector) vector, + getCurrentRow, (boolean wasNull) -> { + }); + + private final AccessorTestUtils.AccessorIterator + accessorIterator = + new AccessorTestUtils.AccessorIterator<>(collector, accessorSupplier); + + @Before + public void setup() { + this.vector = rootAllocatorTestRule.createFloat8Vector(); + this.vectorWithNull = rootAllocatorTestRule.createFloat8VectorForNullTests(); + } + + @After + public void tearDown() { + this.vector.close(); + this.vectorWithNull.close(); + } + + @Test + public void testShouldGetDoubleMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getDouble, + (accessor, currentRow) -> is(vector.getValueAsDouble(currentRow))); + } + + @Test + public void testShouldGetObjectMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getObject, + (accessor) -> is(accessor.getDouble())); + } + + @Test + public void testShouldGetStringMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getString, + (accessor) -> is(Double.toString(accessor.getDouble()))); + } + + @Test + public void testShouldGetBooleanMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getBoolean, + (accessor) -> is(accessor.getDouble() != 0.0)); + } + + @Test + public void testShouldGetByteMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, 
ArrowFlightJdbcFloat8VectorAccessor::getByte, + (accessor) -> is((byte) accessor.getDouble())); + } + + @Test + public void testShouldGetShortMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getShort, + (accessor) -> is((short) accessor.getDouble())); + } + + @Test + public void testShouldGetIntMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getInt, + (accessor) -> is((int) accessor.getDouble())); + } + + @Test + public void testShouldGetLongMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getLong, + (accessor) -> is((long) accessor.getDouble())); + } + + @Test + public void testShouldGetFloatMethodFromFloat8Vector() throws Exception { + accessorIterator.assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getFloat, + (accessor) -> is((float) accessor.getDouble())); + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat8Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + double value = accessor.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(), is(BigDecimal.valueOf(value))); + }); + } + + @Test + public void testShouldGetObjectClass() throws Exception { + accessorIterator + .assertAccessorGetter(vector, ArrowFlightJdbcFloat8VectorAccessor::getObjectClass, + equalTo(Double.class)); + } + + @Test + public void testShouldGetStringMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator + .assertAccessorGetter(vectorWithNull, ArrowFlightJdbcFloat8VectorAccessor::getString, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetFloatMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator + .assertAccessorGetter(vectorWithNull, ArrowFlightJdbcFloat8VectorAccessor::getFloat, + is(0.0f)); + } + + @Test + public void testShouldGetBigDecimalMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator.assertAccessorGetter(vectorWithNull, + ArrowFlightJdbcFloat8VectorAccessor::getBigDecimal, + CoreMatchers.nullValue()); + } + + @Test + public void testShouldGetBigDecimalWithScaleMethodFromFloat4Vector() throws Exception { + accessorIterator.iterate(vector, (accessor, currentRow) -> { + double value = accessor.getDouble(); + if (Double.isInfinite(value) || Double.isNaN(value)) { + exceptionCollector.expect(SQLException.class); + } + collector.checkThat(accessor.getBigDecimal(9), + is(BigDecimal.valueOf(value).setScale(9, RoundingMode.HALF_UP))); + }); + } + + @Test + public void testShouldGetObjectMethodFromFloat8VectorWithNull() throws Exception { + accessorIterator + .assertAccessorGetter(vectorWithNull, ArrowFlightJdbcFloat8VectorAccessor::getObject, + CoreMatchers.nullValue()); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessorTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessorTest.java new file mode 100644 index 00000000000..799c517dd56 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/accessor/impl/text/ArrowFlightJdbcVarCharVectorAccessorTest.java @@ -0,0 +1,733 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.accessor.impl.text; + +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.apache.commons.io.IOUtils.toByteArray; +import static org.apache.commons.io.IOUtils.toCharArray; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.mockito.Mockito.when; + +import java.io.InputStream; +import java.io.Reader; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.SQLException; +import java.sql.Time; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; +import java.util.Calendar; +import java.util.TimeZone; +import java.util.function.IntSupplier; + +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcDateVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeStampVectorAccessor; +import org.apache.arrow.driver.jdbc.accessor.impl.calendar.ArrowFlightJdbcTimeVectorAccessor; +import org.apache.arrow.driver.jdbc.utils.RootAllocatorTestRule; +import org.apache.arrow.driver.jdbc.utils.ThrowableAssertionUtils; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.util.Text; +import org.junit.Assert; +import org.junit.Before; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.rules.ExpectedException; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + + +@RunWith(MockitoJUnitRunner.class) +public class ArrowFlightJdbcVarCharVectorAccessorTest { + + private ArrowFlightJdbcVarCharVectorAccessor accessor; + private final SimpleDateFormat dateTimeFormat = + new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"); + private final SimpleDateFormat timeFormat = new SimpleDateFormat("HH:mm:ss.SSSXXX"); + + @ClassRule + public static RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + + @Mock + private ArrowFlightJdbcVarCharVectorAccessor.Getter getter; + + @Rule + public ErrorCollector collector = new ErrorCollector(); + + @Rule + public ExpectedException thrown = ExpectedException.none(); + + @Before + public void setUp() { + IntSupplier currentRowSupplier = () -> 0; + accessor = + new ArrowFlightJdbcVarCharVectorAccessor(getter, currentRowSupplier, (boolean wasNull) -> { + }); + } + + @Test + public void testShouldGetStringFromNullReturnNull() { + when(getter.get(0)).thenReturn(null); + final String result = accessor.getString(); + + collector.checkThat(result, equalTo(null)); + 
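  // getString() is expected to surface a SQL NULL as a Java null rather than throw.
+  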
} + + @Test + public void testShouldGetStringReturnValidString() { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + final String result = accessor.getString(); + + collector.checkThat(result, instanceOf(String.class)); + collector.checkThat(result, equalTo(value.toString())); + } + + @Test + public void testShouldGetObjectReturnValidString() { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + final String result = accessor.getObject(); + + collector.checkThat(result, instanceOf(String.class)); + collector.checkThat(result, equalTo(value.toString())); + } + + @Test + public void testShouldGetByteThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for byte."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getByte(); + } + + @Test + public void testShouldGetByteThrowsExceptionForOutOfRangePositiveValue() throws Exception { + Text value = new Text("128"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getByte(); + } + + @Test + public void testShouldGetByteThrowsExceptionForOutOfRangeNegativeValue() throws Exception { + Text value = new Text("-129"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getByte(); + } + + @Test + public void testShouldGetByteReturnValidPositiveByte() throws Exception { + Text value = new Text("127"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + byte result = accessor.getByte(); + + collector.checkThat(result, instanceOf(Byte.class)); + collector.checkThat(result, equalTo((byte) 127)); + } + + @Test + public void testShouldGetByteReturnValidNegativeByte() throws Exception { + Text value = new Text("-128"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + byte result = accessor.getByte(); + + collector.checkThat(result, instanceOf(Byte.class)); + collector.checkThat(result, equalTo((byte) -128)); + } + + @Test + public void testShouldGetShortThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for short."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getShort(); + } + + @Test + public void testShouldGetShortThrowsExceptionForOutOfRangePositiveValue() throws Exception { + Text value = new Text("32768"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getShort(); + } + + @Test + public void testShouldGetShortThrowsExceptionForOutOfRangeNegativeValue() throws Exception { + Text value = new Text("-32769"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getShort(); + } + + @Test + public void testShouldGetShortReturnValidPositiveShort() throws Exception { + Text value = new Text("32767"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + short result = accessor.getShort(); + + collector.checkThat(result, instanceOf(Short.class)); + collector.checkThat(result, equalTo((short) 32767)); + } + + @Test + public void testShouldGetShortReturnValidNegativeShort() throws Exception { + Text value = new Text("-32768"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + short result = accessor.getShort(); + + collector.checkThat(result, instanceOf(Short.class)); + collector.checkThat(result, equalTo((short) -32768)); + } + + @Test + 
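// Non-numeric text must surface as a SQLException from getInt(), not as a raw NumberFormatException.
+  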
public void testShouldGetIntThrowsExceptionForNonNumericValue() throws Exception {
+    Text value = new Text("Invalid value for int.");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getInt();
+  }
+
+  @Test
+  public void testShouldGetIntThrowsExceptionForOutOfRangePositiveValue() throws Exception {
+    Text value = new Text("2147483648");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getInt();
+  }
+
+  @Test
+  public void testShouldGetIntThrowsExceptionForOutOfRangeNegativeValue() throws Exception {
+    Text value = new Text("-2147483649");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getInt();
+  }
+
+  @Test
+  public void testShouldGetIntReturnValidPositiveInteger() throws Exception {
+    Text value = new Text("2147483647");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    int result = accessor.getInt();
+
+    collector.checkThat(result, instanceOf(Integer.class));
+    collector.checkThat(result, equalTo(2147483647));
+  }
+
+  @Test
+  public void testShouldGetIntReturnValidNegativeInteger() throws Exception {
+    Text value = new Text("-2147483648");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    int result = accessor.getInt();
+
+    collector.checkThat(result, instanceOf(Integer.class));
+    collector.checkThat(result, equalTo(-2147483648));
+  }
+
+  @Test
+  public void testShouldGetLongThrowsExceptionForNonNumericValue() throws Exception {
+    Text value = new Text("Invalid value for long.");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getLong();
+  }
+
+  @Test
+  public void testShouldGetLongThrowsExceptionForOutOfRangePositiveValue() throws Exception {
+    Text value = new Text("9223372036854775808");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getLong();
+  }
+
+  @Test
+  public void testShouldGetLongThrowsExceptionForOutOfRangeNegativeValue() throws Exception {
+    Text value = new Text("-9223372036854775809");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getLong();
+  }
+
+  @Test
+  public void testShouldGetLongReturnValidPositiveLong() throws Exception {
+    Text value = new Text("9223372036854775807");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    long result = accessor.getLong();
+
+    collector.checkThat(result, instanceOf(Long.class));
+    collector.checkThat(result, equalTo(9223372036854775807L));
+  }
+
+  @Test
+  public void testShouldGetLongReturnValidNegativeLong() throws Exception {
+    Text value = new Text("-9223372036854775808");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    long result = accessor.getLong();
+
+    collector.checkThat(result, instanceOf(Long.class));
+    collector.checkThat(result, equalTo(-9223372036854775808L));
+  }
+
+  @Test
+  public void testShouldGetBigDecimalWithParametersThrowsExceptionForNonNumericValue() throws Exception {
+    Text value = new Text("Invalid value for BigDecimal.");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getBigDecimal(1);
+  }
+
+  @Test
+  public void testShouldGetBigDecimalThrowsExceptionForNonNumericValue() throws Exception {
+    Text value = new Text("Invalid value for BigDecimal.");
+    when(getter.get(0)).thenReturn(value.copyBytes());
+
+    thrown.expect(SQLException.class);
+    accessor.getBigDecimal();
+  }
+
+  @Test
+  public void 
testShouldGetBigDecimalReturnValidPositiveBigDecimal() throws Exception { + Text value = new Text("9223372036854775807000.999"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + BigDecimal result = accessor.getBigDecimal(); + + collector.checkThat(result, instanceOf(BigDecimal.class)); + collector.checkThat(result, equalTo(new BigDecimal("9223372036854775807000.999"))); + } + + @Test + public void testShouldGetBigDecimalReturnValidNegativeBigDecimal() throws Exception { + Text value = new Text("-9223372036854775807000.999"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + BigDecimal result = accessor.getBigDecimal(); + + collector.checkThat(result, instanceOf(BigDecimal.class)); + collector.checkThat(result, equalTo(new BigDecimal("-9223372036854775807000.999"))); + } + + @Test + public void testShouldGetDoubleThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for double."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getDouble(); + } + + @Test + public void testShouldGetDoubleReturnValidPositiveDouble() throws Exception { + Text value = new Text("1.7976931348623157E308D"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(1.7976931348623157E308D)); + } + + @Test + public void testShouldGetDoubleReturnValidNegativeDouble() throws Exception { + Text value = new Text("-1.7976931348623157E308D"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(-1.7976931348623157E308D)); + } + + @Test + public void testShouldGetDoubleWorkWithPositiveInfinity() throws Exception { + Text value = new Text("Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(Double.POSITIVE_INFINITY)); + } + + @Test + public void testShouldGetDoubleWorkWithNegativeInfinity() throws Exception { + Text value = new Text("-Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(Double.NEGATIVE_INFINITY)); + } + + @Test + public void testShouldGetDoubleWorkWithNaN() throws Exception { + Text value = new Text("NaN"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + double result = accessor.getDouble(); + + collector.checkThat(result, instanceOf(Double.class)); + collector.checkThat(result, equalTo(Double.NaN)); + } + + @Test + public void testShouldGetFloatThrowsExceptionForNonNumericValue() throws Exception { + Text value = new Text("Invalid value for float."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getFloat(); + } + + @Test + public void testShouldGetFloatReturnValidPositiveFloat() throws Exception { + Text value = new Text("3.4028235E38F"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(3.4028235E38F)); + } + + @Test + public void testShouldGetFloatReturnValidNegativeFloat() throws Exception { + Text value = new 
Text("-3.4028235E38F"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(-3.4028235E38F)); + } + + @Test + public void testShouldGetFloatWorkWithPositiveInfinity() throws Exception { + Text value = new Text("Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(Float.POSITIVE_INFINITY)); + } + + @Test + public void testShouldGetFloatWorkWithNegativeInfinity() throws Exception { + Text value = new Text("-Infinity"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(Float.NEGATIVE_INFINITY)); + } + + @Test + public void testShouldGetFloatWorkWithNaN() throws Exception { + Text value = new Text("NaN"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + float result = accessor.getFloat(); + + collector.checkThat(result, instanceOf(Float.class)); + collector.checkThat(result, equalTo(Float.NaN)); + } + + @Test + public void testShouldGetDateThrowsExceptionForNonDateValue() throws Exception { + Text value = new Text("Invalid value for date."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getDate(null); + } + + @Test + public void testShouldGetDateReturnValidDateWithoutCalendar() throws Exception { + Text value = new Text("2021-07-02"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Date result = accessor.getDate(null); + + collector.checkThat(result, instanceOf(Date.class)); + + Calendar calendar = Calendar.getInstance(); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T00:00:00.000Z")); + } + + @Test + public void testShouldGetDateReturnValidDateWithCalendar() throws Exception { + Text value = new Text("2021-07-02"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/Sao_Paulo")); + Date result = accessor.getDate(calendar); + + calendar = Calendar.getInstance(TimeZone.getTimeZone("Etc/UTC")); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T03:00:00.000Z")); + } + + @Test + public void testShouldGetTimeThrowsExceptionForNonTimeValue() throws Exception { + Text value = new Text("Invalid value for time."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getTime(null); + } + + @Test + public void testShouldGetTimeReturnValidDateWithoutCalendar() throws Exception { + Text value = new Text("02:30:00"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Time result = accessor.getTime(null); + + Calendar calendar = Calendar.getInstance(); + calendar.setTime(result); + + collector.checkThat(timeFormat.format(calendar.getTime()), equalTo("02:30:00.000Z")); + } + + @Test + public void testShouldGetTimeReturnValidDateWithCalendar() throws Exception { + Text value = new Text("02:30:00"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/Sao_Paulo")); + Time result = accessor.getTime(calendar); + + calendar = Calendar.getInstance(TimeZone.getTimeZone("Etc/UTC")); + 
calendar.setTime(result); + + collector.checkThat(timeFormat.format(calendar.getTime()), equalTo("05:30:00.000Z")); + } + + @Test + public void testShouldGetTimestampThrowsExceptionForNonTimeValue() throws Exception { + Text value = new Text("Invalid value for timestamp."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + thrown.expect(SQLException.class); + accessor.getTimestamp(null); + } + + @Test + public void testShouldGetTimestampReturnValidDateWithoutCalendar() throws Exception { + Text value = new Text("2021-07-02 02:30:00.000"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Timestamp result = accessor.getTimestamp(null); + + Calendar calendar = Calendar.getInstance(); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T02:30:00.000Z")); + } + + @Test + public void testShouldGetTimestampReturnValidDateWithCalendar() throws Exception { + Text value = new Text("2021-07-02 02:30:00.000"); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("America/Sao_Paulo")); + Timestamp result = accessor.getTimestamp(calendar); + + calendar = Calendar.getInstance(TimeZone.getTimeZone("Etc/UTC")); + calendar.setTime(result); + + collector.checkThat(dateTimeFormat.format(calendar.getTime()), + equalTo("2021-07-02T05:30:00.000Z")); + } + + private void assertGetBoolean(Text value, boolean expectedResult) throws SQLException { + when(getter.get(0)).thenReturn(value == null ? null : value.copyBytes()); + boolean result = accessor.getBoolean(); + collector.checkThat(result, equalTo(expectedResult)); + } + + private void assertGetBooleanForSQLException(Text value) { + when(getter.get(0)).thenReturn(value == null ? null : value.copyBytes()); + ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, () -> accessor.getBoolean()); + } + + @Test + public void testShouldGetBooleanThrowsSQLExceptionForInvalidValue() { + assertGetBooleanForSQLException(new Text("anything")); + } + + @Test + public void testShouldGetBooleanThrowsSQLExceptionForEmpty() { + assertGetBooleanForSQLException(new Text("")); + } + + @Test + public void testShouldGetBooleanReturnFalseFor0() throws Exception { + assertGetBoolean(new Text("0"), false); + } + + @Test + public void testShouldGetBooleanReturnFalseForFalseString() throws Exception { + assertGetBoolean(new Text("false"), false); + } + + @Test + public void testShouldGetBooleanReturnFalseForNull() throws Exception { + assertGetBoolean(null, false); + } + + @Test + public void testShouldGetBytesReturnValidByteArray() { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + final byte[] result = accessor.getBytes(); + + collector.checkThat(result, instanceOf(byte[].class)); + collector.checkThat(result, equalTo(value.toString().getBytes(UTF_8))); + } + + @Test + public void testShouldGetUnicodeStreamReturnValidInputStream() throws Exception { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + try (final InputStream result = accessor.getUnicodeStream()) { + byte[] resultBytes = toByteArray(result); + + collector.checkThat(new String(resultBytes, UTF_8), + equalTo(value.toString())); + } + } + + @Test + public void testShouldGetAsciiStreamReturnValidInputStream() throws Exception { + Text valueText = new Text("Value for Test."); + byte[] valueAscii = valueText.toString().getBytes(US_ASCII); + 
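// The stubbed getter below still returns the UTF-8 bytes; for this pure-ASCII payload
+    // the US_ASCII projection above must match the stream byte-for-byte.
+    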
when(getter.get(0)).thenReturn(valueText.copyBytes()); + + try (final InputStream result = accessor.getAsciiStream()) { + byte[] resultBytes = toByteArray(result); + + Assert.assertArrayEquals(valueAscii, resultBytes); + } + } + + @Test + public void testShouldGetCharacterStreamReturnValidReader() throws Exception { + Text value = new Text("Value for Test."); + when(getter.get(0)).thenReturn(value.copyBytes()); + + try (Reader result = accessor.getCharacterStream()) { + char[] resultChars = toCharArray(result); + + collector.checkThat(new String(resultChars), equalTo(value.toString())); + } + } + + @Test + public void testShouldGetTimeStampBeConsistentWithTimeStampAccessor() throws Exception { + try (TimeStampVector timeStampVector = rootAllocatorTestRule.createTimeStampMilliVector()) { + ArrowFlightJdbcTimeStampVectorAccessor timeStampVectorAccessor = + new ArrowFlightJdbcTimeStampVectorAccessor(timeStampVector, () -> 0, + (boolean wasNull) -> { + }); + + Text value = new Text(timeStampVectorAccessor.getString()); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Timestamp timestamp = accessor.getTimestamp(null); + collector.checkThat(timestamp, equalTo(timeStampVectorAccessor.getTimestamp(null))); + } + } + + @Test + public void testShouldGetTimeBeConsistentWithTimeAccessor() throws Exception { + try (TimeMilliVector timeVector = rootAllocatorTestRule.createTimeMilliVector()) { + ArrowFlightJdbcTimeVectorAccessor timeVectorAccessor = + new ArrowFlightJdbcTimeVectorAccessor(timeVector, () -> 0, (boolean wasNull) -> { + }); + + Text value = new Text(timeVectorAccessor.getString()); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Time time = accessor.getTime(null); + collector.checkThat(time, equalTo(timeVectorAccessor.getTime(null))); + } + } + + @Test + public void testShouldGetDateBeConsistentWithDateAccessor() throws Exception { + try (DateMilliVector dateVector = rootAllocatorTestRule.createDateMilliVector()) { + ArrowFlightJdbcDateVectorAccessor dateVectorAccessor = + new ArrowFlightJdbcDateVectorAccessor(dateVector, () -> 0, (boolean wasNull) -> { + }); + + Text value = new Text(dateVectorAccessor.getString()); + when(getter.get(0)).thenReturn(value.copyBytes()); + + Date date = accessor.getDate(null); + collector.checkThat(date, equalTo(dateVectorAccessor.getDate(null))); + } + } + + @Test + public void testShouldGetObjectClassReturnString() { + final Class clazz = accessor.getObjectClass(); + collector.checkThat(clazz, equalTo(String.class)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/Authentication.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/Authentication.java new file mode 100644 index 00000000000..5fe2b0dc057 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/Authentication.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.authentication;
+
+import java.util.Properties;
+
+import org.apache.arrow.flight.auth2.CallHeaderAuthenticator;
+
+public interface Authentication {
+  /**
+   * Creates a {@link CallHeaderAuthenticator} which is used to authenticate the connection.
+   *
+   * @return a CallHeaderAuthenticator.
+   */
+  CallHeaderAuthenticator authenticate();
+
+  /**
+   * Uses the validCredentials variable to populate the Properties object.
+   *
+   * @param properties the Properties object that will be populated.
+   */
+  void populateProperties(Properties properties);
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/TokenAuthentication.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/TokenAuthentication.java
new file mode 100644
index 00000000000..605705d1ca9
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/TokenAuthentication.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc.authentication; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.CallHeaders; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; + +public class TokenAuthentication implements Authentication { + private final List validCredentials; + + public TokenAuthentication(List validCredentials) { + this.validCredentials = validCredentials; + } + + @Override + public CallHeaderAuthenticator authenticate() { + return new CallHeaderAuthenticator() { + @Override + public AuthResult authenticate(CallHeaders incomingHeaders) { + String authorization = incomingHeaders.get("authorization"); + if (!validCredentials.contains(authorization)) { + throw CallStatus.UNAUTHENTICATED.withDescription("Invalid credentials.").toRuntimeException(); + } + return new AuthResult() { + @Override + public String getPeerIdentity() { + return authorization; + } + }; + } + }; + } + + @Override + public void populateProperties(Properties properties) { + this.validCredentials.forEach(value -> properties.put( + ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.TOKEN.camelName(), value)); + } + + public static final class Builder { + private final List tokenList = new ArrayList<>(); + + public TokenAuthentication.Builder token(String token) { + tokenList.add("Bearer " + token); + return this; + } + + public TokenAuthentication build() { + return new TokenAuthentication(tokenList); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/UserPasswordAuthentication.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/UserPasswordAuthentication.java new file mode 100644 index 00000000000..5dc97c858f3 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/authentication/UserPasswordAuthentication.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.authentication; + +import java.util.HashMap; +import java.util.Map; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.auth2.BasicCallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.CallHeaderAuthenticator; +import org.apache.arrow.flight.auth2.GeneratedBearerTokenAuthenticator; + +public class UserPasswordAuthentication implements Authentication { + + private final Map validCredentials; + + public UserPasswordAuthentication(Map validCredentials) { + this.validCredentials = validCredentials; + } + + private String getCredentials(String key) { + return validCredentials.getOrDefault(key, null); + } + + @Override + public CallHeaderAuthenticator authenticate() { + return new GeneratedBearerTokenAuthenticator( + new BasicCallHeaderAuthenticator((username, password) -> { + if (validCredentials.containsKey(username) && getCredentials(username).equals(password)) { + return () -> username; + } + throw CallStatus.UNAUTHENTICATED.withDescription("Invalid credentials.").toRuntimeException(); + })); + } + + @Override + public void populateProperties(Properties properties) { + validCredentials.forEach((key, value) -> { + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USER.camelName(), key); + properties.put(ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD.camelName(), value); + }); + } + + public static class Builder { + Map credentials = new HashMap<>(); + + public Builder user(String username, String password) { + credentials.put(username, password); + return this; + } + + public UserPasswordAuthentication build() { + return new UserPasswordAuthentication(credentials); + } + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtilsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtilsTest.java new file mode 100644 index 00000000000..d61436fd6e2 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/client/utils/ClientAuthenticationUtilsTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.client.utils; + +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.io.InputStream; +import java.lang.reflect.Method; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; + +import org.bouncycastle.openssl.jcajce.JcaPEMWriter; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public class ClientAuthenticationUtilsTest { + @Mock + KeyStore keyStoreMock; + + @Test + public void testGetCertificatesInputStream() throws IOException, KeyStoreException { + JcaPEMWriter pemWriterMock = mock(JcaPEMWriter.class); + Certificate certificateMock = mock(Certificate.class); + Enumeration alias = Collections.enumeration(Arrays.asList("test1", "test2")); + + Mockito.when(keyStoreMock.aliases()).thenReturn(alias); + Mockito.when(keyStoreMock.isCertificateEntry("test1")).thenReturn(true); + Mockito.when(keyStoreMock.getCertificate("test1")).thenReturn(certificateMock); + + ClientAuthenticationUtils.getCertificatesInputStream(keyStoreMock, pemWriterMock); + Mockito.verify(pemWriterMock).writeObject(certificateMock); + Mockito.verify(pemWriterMock).flush(); + } + + @Test + public void testGetKeyStoreInstance() throws IOException, + KeyStoreException, CertificateException, NoSuchAlgorithmException { + try (MockedStatic keyStoreMockedStatic = Mockito.mockStatic(KeyStore.class)) { + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getKeyStoreInstance(Mockito.any())) + .thenReturn(keyStoreMock); + + KeyStore receiveKeyStore = ClientAuthenticationUtils.getKeyStoreInstance("test1"); + Mockito + .verify(keyStoreMock) + .load(null, null); + + Assert.assertEquals(receiveKeyStore, keyStoreMock); + } + } + + @Test + public void testGetCertificateInputStreamFromMacSystem() throws IOException, + KeyStoreException, CertificateException, NoSuchAlgorithmException { + InputStream mock = mock(InputStream.class); + + try (MockedStatic keyStoreMockedStatic = createKeyStoreStaticMock(); + MockedStatic + clientAuthenticationUtilsMockedStatic = createClientAuthenticationUtilsStaticMock()) { + + setOperatingSystemMock(clientAuthenticationUtilsMockedStatic, false, true); + keyStoreMockedStatic.when(() -> ClientAuthenticationUtils + .getKeyStoreInstance("KeychainStore")) + .thenReturn(keyStoreMock); + keyStoreMockedStatic.when(() -> ClientAuthenticationUtils + .getCertificatesInputStream(Mockito.any())) + .thenReturn(mock); + + InputStream inputStream = ClientAuthenticationUtils.getCertificateInputStreamFromSystem("test"); + Assert.assertEquals(inputStream, mock); + } + } + + @Test + public void testGetCertificateInputStreamFromWindowsSystem() throws IOException, + KeyStoreException, CertificateException, NoSuchAlgorithmException { + InputStream mock = mock(InputStream.class); + + try (MockedStatic keyStoreMockedStatic = createKeyStoreStaticMock(); + MockedStatic + clientAuthenticationUtilsMockedStatic = createClientAuthenticationUtilsStaticMock()) { + + setOperatingSystemMock(clientAuthenticationUtilsMockedStatic, true, false); + keyStoreMockedStatic + .when(() -> 
ClientAuthenticationUtils.getKeyStoreInstance("Windows-ROOT")) + .thenReturn(keyStoreMock); + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getKeyStoreInstance("Windows-MY")) + .thenReturn(keyStoreMock); + keyStoreMockedStatic + .when(() -> ClientAuthenticationUtils.getCertificatesInputStream(Mockito.any())) + .thenReturn(mock); + + InputStream inputStream = ClientAuthenticationUtils.getCertificateInputStreamFromSystem("test"); + Assert.assertEquals(inputStream, mock); + } + } + + private MockedStatic createKeyStoreStaticMock() { + return Mockito.mockStatic(KeyStore.class); + } + + private MockedStatic createClientAuthenticationUtilsStaticMock() { + return Mockito.mockStatic(ClientAuthenticationUtils.class , invocationOnMock -> { + Method method = invocationOnMock.getMethod(); + if (method.getName().equals("getCertificateInputStreamFromSystem")) { + return invocationOnMock.callRealMethod(); + } + return invocationOnMock.getMock(); + }); + } + + private void setOperatingSystemMock(MockedStatic clientAuthenticationUtilsMockedStatic, + boolean isWindows, boolean isMac) { + clientAuthenticationUtilsMockedStatic.when(ClientAuthenticationUtils::isMac).thenReturn(isMac); + clientAuthenticationUtilsMockedStatic.when(ClientAuthenticationUtils::isWindows).thenReturn(isWindows); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/AccessorTestUtils.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/AccessorTestUtils.java new file mode 100644 index 00000000000..bc1e8a04203 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/AccessorTestUtils.java @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static org.hamcrest.CoreMatchers.is;
+
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Consumer;
+import java.util.function.Function;
+import java.util.function.IntSupplier;
+import java.util.function.Supplier;
+
+import org.apache.arrow.driver.jdbc.accessor.ArrowFlightJdbcAccessor;
+import org.apache.arrow.vector.ValueVector;
+import org.hamcrest.Matcher;
+import org.junit.rules.ErrorCollector;
+
+public class AccessorTestUtils {
+  @FunctionalInterface
+  public interface CheckedFunction<T, R> {
+    R apply(T t) throws SQLException;
+  }
+
+  public interface AccessorSupplier<T extends ArrowFlightJdbcAccessor> {
+    T supply(ValueVector vector, IntSupplier getCurrentRow);
+  }
+
+  public interface AccessorConsumer<T extends ArrowFlightJdbcAccessor> {
+    void accept(T accessor, int currentRow) throws Exception;
+  }
+
+  public interface MatcherGetter<T, R> {
+    Matcher<R> get(T accessor, int currentRow);
+  }
+
+  public static class Cursor {
+    int currentRow = 0;
+    int limit;
+
+    public Cursor(int limit) {
+      this.limit = limit;
+    }
+
+    public void next() {
+      currentRow++;
+    }
+
+    boolean hasNext() {
+      return currentRow < limit;
+    }
+
+    public int getCurrentRow() {
+      return currentRow;
+    }
+  }
+
+  public static class AccessorIterator<T extends ArrowFlightJdbcAccessor> {
+    private final ErrorCollector collector;
+    private final AccessorSupplier<T> accessorSupplier;
+
+    public AccessorIterator(ErrorCollector collector, AccessorSupplier<T> accessorSupplier) {
+      this.collector = collector;
+      this.accessorSupplier = accessorSupplier;
+    }
+
+    public void iterate(ValueVector vector, AccessorConsumer<T> accessorConsumer) throws Exception {
+      int valueCount = vector.getValueCount();
+      if (valueCount == 0) {
+        throw new IllegalArgumentException("Vector is empty");
+      }
+
+      Cursor cursor = new Cursor(valueCount);
+      T accessor = accessorSupplier.supply(vector, cursor::getCurrentRow);
+
+      while (cursor.hasNext()) {
+        accessorConsumer.accept(accessor, cursor.getCurrentRow());
+        cursor.next();
+      }
+    }
+
+    public void iterate(ValueVector vector, Consumer<T> accessorConsumer) throws Exception {
+      iterate(vector, (accessor, currentRow) -> accessorConsumer.accept(accessor));
+    }
+
+    public List<Object> toList(ValueVector vector) throws Exception {
+      List<Object> result = new ArrayList<>();
+      iterate(vector, (accessor, currentRow) -> result.add(accessor.getObject()));
+
+      return result;
+    }
+
+    public <R> void assertAccessorGetter(ValueVector vector, CheckedFunction<T, R> getter,
+        MatcherGetter<T, R> matcherGetter) throws Exception {
+      iterate(vector, (accessor, currentRow) -> {
+        R object = getter.apply(accessor);
+        boolean wasNull = accessor.wasNull();
+
+        collector.checkThat(object, matcherGetter.get(accessor, currentRow));
+        collector.checkThat(wasNull, is(accessor.getObject() == null));
+      });
+    }
+
+    public <R> void assertAccessorGetterThrowingException(ValueVector vector, CheckedFunction<T, R> getter)
+        throws Exception {
+      iterate(vector, (accessor, currentRow) ->
+          ThrowableAssertionUtils.simpleAssertThrowableClass(SQLException.class, () -> getter.apply(accessor)));
+    }
+
+    public <R> void assertAccessorGetter(ValueVector vector, CheckedFunction<T, R> getter,
+        Function<T, Matcher<R>> matcherGetter) throws Exception {
+      assertAccessorGetter(vector, getter, (accessor, currentRow) -> matcherGetter.apply(accessor));
+    }
+
+    public <R> void assertAccessorGetter(ValueVector vector, CheckedFunction<T, R> getter,
+        Supplier<Matcher<R>> matcherGetter) throws Exception {
+      assertAccessorGetter(vector, getter, (accessor, currentRow) -> matcherGetter.get());
+    }
+
+    public <R> void assertAccessorGetter(ValueVector 
vector, CheckedFunction<T, R> getter,
+        Matcher<R> matcher) throws Exception {
+      assertAccessorGetter(vector, getter, (accessor, currentRow) -> matcher);
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
new file mode 100644
index 00000000000..4fb07428af4
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionConfigImplTest.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static java.lang.Runtime.getRuntime;
+import static java.util.Arrays.asList;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.HOST;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PASSWORD;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.PORT;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.THREAD_POOL_SIZE;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USER;
+import static org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty.USE_ENCRYPTION;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.util.List;
+import java.util.Properties;
+import java.util.Random;
+import java.util.function.Function;
+
+import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameter;
+import org.junit.runners.Parameterized.Parameters;
+
+@RunWith(Parameterized.class)
+public final class ArrowFlightConnectionConfigImplTest {
+
+  private static final Random RANDOM = new Random(12L);
+
+  private final Properties properties = new Properties();
+  private ArrowFlightConnectionConfigImpl arrowFlightConnectionConfig;
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+
+  @Parameter
+  public ArrowFlightConnectionProperty property;
+
+  @Parameter(value = 1)
+  public Object value;
+
+  @Parameter(value = 2)
+  public Function<ArrowFlightConnectionConfigImpl, ?> arrowFlightConnectionConfigFunction;
+
+  @Before
+  public void setUp() {
+    arrowFlightConnectionConfig = new ArrowFlightConnectionConfigImpl(properties);
+    
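// The config keeps a live reference to the Properties instance, so the entry put
+    // below is still visible to the getter exercised in testGetProperty().
+    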
properties.put(property.camelName(), value);
+  }
+
+  @Test
+  public void testGetProperty() {
+    collector.checkThat(arrowFlightConnectionConfigFunction.apply(arrowFlightConnectionConfig),
+        is(value));
+  }
+
+  @Parameters(name = "<{0}> as <{1}>")
+  public static List<Object[]> provideParameters() {
+    return asList(new Object[][] {
+        {HOST, "host",
+            (Function<ArrowFlightConnectionConfigImpl, String>) ArrowFlightConnectionConfigImpl::getHost},
+        {PORT,
+            RANDOM.nextInt(Short.toUnsignedInt(Short.MAX_VALUE)),
+            (Function<ArrowFlightConnectionConfigImpl, Integer>) ArrowFlightConnectionConfigImpl::getPort},
+        {USER, "user",
+            (Function<ArrowFlightConnectionConfigImpl, String>) ArrowFlightConnectionConfigImpl::getUser},
+        {PASSWORD, "password",
+            (Function<ArrowFlightConnectionConfigImpl, String>) ArrowFlightConnectionConfigImpl::getPassword},
+        {USE_ENCRYPTION, RANDOM.nextBoolean(),
+            (Function<ArrowFlightConnectionConfigImpl, Boolean>) ArrowFlightConnectionConfigImpl::useEncryption},
+        {THREAD_POOL_SIZE,
+            RANDOM.nextInt(getRuntime().availableProcessors()),
+            (Function<ArrowFlightConnectionConfigImpl, Integer>) ArrowFlightConnectionConfigImpl::threadPoolSize},
+    });
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionPropertyTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionPropertyTest.java
new file mode 100644
index 00000000000..25a48612cbd
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ArrowFlightConnectionPropertyTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.apache.arrow.util.AutoCloseables.close; +import static org.mockito.MockitoAnnotations.openMocks; + +import java.util.ArrayList; +import java.util.List; +import java.util.Properties; + +import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; +import org.junit.After; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; +import org.mockito.Mock; + +@RunWith(Parameterized.class) +public final class ArrowFlightConnectionPropertyTest { + + @Mock + public Properties properties; + + private AutoCloseable mockitoResource; + + @Parameter + public ArrowFlightConnectionProperty arrowFlightConnectionProperty; + + @Before + public void setUp() { + mockitoResource = openMocks(this); + } + + @After + public void tearDown() throws Exception { + close(mockitoResource); + } + + @Test + public void testWrapIsUnsupported() { + ThrowableAssertionUtils.simpleAssertThrowableClass(UnsupportedOperationException.class, + () -> arrowFlightConnectionProperty.wrap(properties)); + } + + @Test + public void testRequiredPropertyThrows() { + Assume.assumeTrue(arrowFlightConnectionProperty.required()); + ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, + () -> arrowFlightConnectionProperty.get(new Properties())); + } + + @Test + public void testOptionalPropertyReturnsDefault() { + Assume.assumeTrue(!arrowFlightConnectionProperty.required()); + Assert.assertEquals(arrowFlightConnectionProperty.defaultValue(), + arrowFlightConnectionProperty.get(new Properties())); + } + + @Parameters + public static List provideParameters() { + final ArrowFlightConnectionProperty[] arrowFlightConnectionProperties = + ArrowFlightConnectionProperty.values(); + final List parameters = new ArrayList<>(arrowFlightConnectionProperties.length); + for (final ArrowFlightConnectionProperty arrowFlightConnectionProperty : arrowFlightConnectionProperties) { + parameters.add(new Object[] {arrowFlightConnectionProperty}); + } + return parameters; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapperTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapperTest.java new file mode 100644 index 00000000000..6044f3a363c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConnectionWrapperTest.java @@ -0,0 +1,443 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.driver.jdbc.utils; + +import static java.lang.String.format; +import static java.util.stream.IntStream.range; +import static org.hamcrest.CoreMatchers.allOf; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLClientInfoException; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.Random; + +import org.apache.arrow.driver.jdbc.ArrowFlightConnection; +import org.apache.arrow.util.AutoCloseables; +import org.apache.calcite.avatica.AvaticaConnection; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; + +@RunWith(MockitoJUnitRunner.class) +public final class ConnectionWrapperTest { + + private static final String SCHEMA_NAME = "SCHEMA"; + private static final String PLACEHOLDER_QUERY = "SELECT * FROM DOES_NOT_MATTER"; + private static final int[] COLUMN_INDICES = range(0, 10).toArray(); + private static final String[] COLUMN_NAMES = + Arrays.stream(COLUMN_INDICES).mapToObj(i -> format("col%d", i)).toArray(String[]::new); + private static final String TYPE_NAME = "TYPE_NAME"; + private static final String SAVEPOINT_NAME = "SAVEPOINT"; + private static final String CLIENT_INFO = "CLIENT_INFO"; + private static final int RESULT_SET_TYPE = ResultSet.TYPE_FORWARD_ONLY; + private static final int RESULT_SET_CONCURRENCY = ResultSet.CONCUR_READ_ONLY; + private static final int RESULT_SET_HOLDABILITY = ResultSet.HOLD_CURSORS_OVER_COMMIT; + private static final int GENERATED_KEYS = Statement.NO_GENERATED_KEYS; + private static final Random RANDOM = new Random(Long.MAX_VALUE); + private static final int TIMEOUT = RANDOM.nextInt(Integer.MAX_VALUE); + + @Mock + public AvaticaConnection underlyingConnection; + private ConnectionWrapper connectionWrapper; + @Rule + public final ErrorCollector collector = new ErrorCollector(); + + @Before + public void setUp() { + connectionWrapper = new ConnectionWrapper(underlyingConnection); + } + + @After + public void tearDown() throws Exception { + AutoCloseables.close(connectionWrapper, underlyingConnection); + } + + @Test + public void testUnwrappingUnderlyingConnectionShouldReturnUnderlyingConnection() { + collector.checkThat( + collector.checkSucceeds(() -> connectionWrapper.unwrap(Object.class)), + is(sameInstance(underlyingConnection))); + collector.checkThat( + collector.checkSucceeds(() -> connectionWrapper.unwrap(Connection.class)), + is(sameInstance(underlyingConnection))); + collector.checkThat( + collector.checkSucceeds(() -> connectionWrapper.unwrap(AvaticaConnection.class)), + is(sameInstance(underlyingConnection))); + ThrowableAssertionUtils.simpleAssertThrowableClass(ClassCastException.class, + () -> connectionWrapper.unwrap(ArrowFlightConnection.class)); + ThrowableAssertionUtils.simpleAssertThrowableClass(ClassCastException.class, + () -> connectionWrapper.unwrap(ConnectionWrapper.class)); + } + + @Test + public void testCreateStatementShouldCreateStatementFromUnderlyingConnection() + throws SQLException { + collector.checkThat( + 
connectionWrapper.createStatement(), + is(sameInstance(verify(underlyingConnection, times(1)).createStatement()))); + collector.checkThat( + connectionWrapper.createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY, + RESULT_SET_HOLDABILITY), + is(verify(underlyingConnection, times(1)) + .createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY, RESULT_SET_HOLDABILITY))); + collector.checkThat( + connectionWrapper.createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY), + is(verify(underlyingConnection, times(1)) + .createStatement(RESULT_SET_TYPE, RESULT_SET_CONCURRENCY))); + } + + @Test + public void testPrepareStatementShouldPrepareStatementFromUnderlyingConnection() + throws SQLException { + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY), + is(sameInstance( + verify(underlyingConnection, times(1)).prepareStatement(PLACEHOLDER_QUERY)))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, COLUMN_INDICES), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, COLUMN_INDICES)), + nullValue()))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, COLUMN_NAMES), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, COLUMN_NAMES)), + nullValue()))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, RESULT_SET_TYPE, + RESULT_SET_CONCURRENCY), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, RESULT_SET_TYPE, RESULT_SET_CONCURRENCY)), + nullValue()))); + collector.checkThat( + connectionWrapper.prepareStatement(PLACEHOLDER_QUERY, GENERATED_KEYS), + is(allOf(sameInstance(verify(underlyingConnection, times(1)) + .prepareStatement(PLACEHOLDER_QUERY, GENERATED_KEYS)), + nullValue()))); + } + + @Test + public void testPrepareCallShouldPrepareCallFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.prepareCall(PLACEHOLDER_QUERY), + is(sameInstance( + verify(underlyingConnection, times(1)).prepareCall(PLACEHOLDER_QUERY)))); + collector.checkThat( + connectionWrapper.prepareCall(PLACEHOLDER_QUERY, RESULT_SET_TYPE, RESULT_SET_CONCURRENCY), + is(verify(underlyingConnection, times(1)) + .prepareCall(PLACEHOLDER_QUERY, RESULT_SET_TYPE, RESULT_SET_CONCURRENCY))); + } + + @Test + public void testNativeSqlShouldGetNativeSqlFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.nativeSQL(PLACEHOLDER_QUERY), + is(sameInstance( + verify(underlyingConnection, times(1)).nativeSQL(PLACEHOLDER_QUERY)))); + } + + @Test + public void testSetAutoCommitShouldSetAutoCommitInUnderlyingConnection() throws SQLException { + connectionWrapper.setAutoCommit(true); + verify(underlyingConnection, times(1)).setAutoCommit(true); + connectionWrapper.setAutoCommit(false); + verify(underlyingConnection, times(1)).setAutoCommit(false); + } + + @Test + public void testGetAutoCommitShouldGetAutoCommitFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getAutoCommit(), + is(verify(underlyingConnection, times(1)).getAutoCommit())); + } + + @Test + public void testCommitShouldCommitToUnderlyingConnection() throws SQLException { + connectionWrapper.commit(); + verify(underlyingConnection, times(1)).commit(); + } + + @Test + public void testRollbackShouldRollbackFromUnderlyingConnection() throws SQLException { + connectionWrapper.rollback(); + 
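// Pure delegation check: rollback() must be forwarded to the wrapped connection exactly once.
+    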
verify(underlyingConnection, times(1)).rollback();
+  }
+
+  @Test
+  public void testCloseShouldCloseUnderlyingConnection() throws SQLException {
+    connectionWrapper.close();
+    verify(underlyingConnection, times(1)).close();
+  }
+
+  @Test
+  public void testIsClosedShouldGetStatusFromUnderlyingConnection() throws SQLException {
+    collector.checkThat(
+        connectionWrapper.isClosed(), is(verify(underlyingConnection, times(1)).isClosed()));
+  }
+
+  @Test
+  public void testGetMetadataShouldGetMetadataFromUnderlyingConnection() throws SQLException {
+    collector.checkThat(
+        connectionWrapper.getMetaData(), is(verify(underlyingConnection, times(1)).getMetaData()));
+  }
+
+  @Test
+  public void testSetReadOnlyShouldSetUnderlyingConnectionAsReadOnly() throws SQLException {
+    connectionWrapper.setReadOnly(false);
+    verify(underlyingConnection, times(1)).setReadOnly(false);
+    connectionWrapper.setReadOnly(true);
+    verify(underlyingConnection, times(1)).setReadOnly(true);
+  }
+
+  @Test
+  public void testIsReadOnlyShouldGetStatusFromUnderlyingConnection() throws SQLException {
+    collector.checkThat(connectionWrapper.isReadOnly(),
+        is(verify(underlyingConnection, times(1)).isReadOnly()));
+  }
+
+  @Test
+  public void testSetCatalogShouldSetCatalogInUnderlyingConnection() throws SQLException {
+    final String catalog = "CATALOG";
+    connectionWrapper.setCatalog(catalog);
+    verify(underlyingConnection, times(1)).setCatalog(catalog);
+  }
+
+  @Test
+  public void testGetCatalogShouldGetCatalogFromUnderlyingConnection() throws SQLException {
+    collector.checkThat(
+        connectionWrapper.getCatalog(),
+        is(allOf(sameInstance(verify(underlyingConnection, times(1)).getCatalog()), nullValue())));
+  }
+
+  @Test
+  public void testSetTransactionIsolationShouldSetUnderlyingTransactionIsolation()
+      throws SQLException {
+    final int transactionIsolation = Connection.TRANSACTION_NONE;
+    connectionWrapper.setTransactionIsolation(Connection.TRANSACTION_NONE);
+    verify(underlyingConnection, times(1)).setTransactionIsolation(transactionIsolation);
+  }
+
+  @Test
+  public void testGetTransactionIsolationShouldGetUnderlyingConnectionIsolation()
+      throws SQLException {
+    collector.checkThat(
+        connectionWrapper.getTransactionIsolation(),
+        is(equalTo(verify(underlyingConnection, times(1)).getTransactionIsolation())));
+  }
+
+  @Test
+  public void testGetWarningsShouldGetWarningsFromUnderlyingConnection() throws SQLException {
+    collector.checkThat(
+        connectionWrapper.getWarnings(),
+        is(allOf(
+            sameInstance(verify(underlyingConnection, times(1)).getWarnings()),
+            nullValue())));
+  }
+
+  @Test
+  public void testClearWarningsShouldClearWarningsFromUnderlyingConnection() throws SQLException {
+    connectionWrapper.clearWarnings();
+    verify(underlyingConnection, times(1)).clearWarnings();
+  }
+
+  @Test
+  public void testGetTypeMapShouldGetTypeMapFromUnderlyingConnection() throws SQLException {
+    when(underlyingConnection.getTypeMap()).thenReturn(null);
+    collector.checkThat(
+        connectionWrapper.getTypeMap(),
+        is(verify(underlyingConnection, times(1)).getTypeMap()));
+  }
+
+  @Test
+  public void testSetTypeMapShouldSetTypeMapInUnderlyingConnection() throws SQLException {
+    connectionWrapper.setTypeMap(null);
+    verify(underlyingConnection, times(1)).setTypeMap(null);
+  }
+
+  @Test
+  public void testSetHoldabilityShouldSetUnderlyingConnection() throws SQLException {
+    connectionWrapper.setHoldability(RESULT_SET_HOLDABILITY);
+    verify(underlyingConnection, times(1)).setHoldability(RESULT_SET_HOLDABILITY);
+  }
+
+  @Test
+  public void 
testGetHoldabilityShouldGetHoldabilityFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getHoldability(), + is(equalTo(verify(underlyingConnection, times(1)).getHoldability()))); + } + + @Test + public void testSetSavepointShouldSetSavepointInUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.setSavepoint(), + is(allOf( + sameInstance(verify(underlyingConnection, times(1)).setSavepoint()), + nullValue()))); + collector.checkThat( + connectionWrapper.setSavepoint(SAVEPOINT_NAME), + is(sameInstance( + verify(underlyingConnection, times(1)).setSavepoint(SAVEPOINT_NAME)))); + } + + @Test + public void testRollbackShouldRollbackInUnderlyingConnection() throws SQLException { + connectionWrapper.rollback(null); + verify(underlyingConnection, times(1)).rollback(null); + } + + @Test + public void testReleaseSavepointShouldReleaseSavepointFromUnderlyingConnection() + throws SQLException { + connectionWrapper.releaseSavepoint(null); + verify(underlyingConnection, times(1)).releaseSavepoint(null); + } + + @Test + public void testCreateClobShouldCreateClobFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createClob(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createClob()), nullValue()))); + } + + @Test + public void testCreateBlobShouldCreateBlobFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createBlob(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createBlob()), nullValue()))); + } + + @Test + public void testCreateNClobShouldCreateNClobFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createNClob(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createNClob()), nullValue()))); + } + + @Test + public void testCreateSQLXMLShouldCreateSQLXMLFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.createSQLXML(), + is(allOf(sameInstance( + verify(underlyingConnection, times(1)).createSQLXML()), nullValue()))); + } + + @Test + public void testIsValidShouldReturnWhetherUnderlyingConnectionIsValid() throws SQLException { + collector.checkThat( + connectionWrapper.isValid(TIMEOUT), + is(verify(underlyingConnection, times(1)).isValid(TIMEOUT))); + } + + @Test + public void testSetClientInfoShouldSetClientInfoInUnderlyingConnection() + throws SQLClientInfoException { + connectionWrapper.setClientInfo(null); + verify(underlyingConnection, times(1)).setClientInfo(null); + } + + @Test + public void testGetClientInfoShouldGetClientInfoFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getClientInfo(CLIENT_INFO), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).getClientInfo(CLIENT_INFO)), + nullValue()))); + collector.checkThat( + connectionWrapper.getClientInfo(), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).getClientInfo()), + nullValue()))); + } + + @Test + public void testCreateArrayOfShouldCreateArrayFromUnderlyingConnection() throws SQLException { + final Object[] elements = range(0, 100).boxed().toArray(); + collector.checkThat( + connectionWrapper.createArrayOf(TYPE_NAME, elements), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).createArrayOf(TYPE_NAME, elements)), + nullValue()))); + } + + @Test + public void testCreateStructShouldCreateStructFromUnderlyingConnection() throws SQLException { + 
final Object[] attributes = range(0, 120).boxed().toArray(); + collector.checkThat( + connectionWrapper.createStruct(TYPE_NAME, attributes), + is(allOf( + sameInstance( + verify(underlyingConnection, times(1)).createStruct(TYPE_NAME, attributes)), + nullValue()))); + } + + @Test + public void testSetSchemaShouldSetSchemaInUnderlyingConnection() throws SQLException { + connectionWrapper.setSchema(SCHEMA_NAME); + verify(underlyingConnection, times(1)).setSchema(SCHEMA_NAME); + } + + @Test + public void testGetSchemaShouldGetSchemaFromUnderlyingConnection() throws SQLException { + collector.checkThat( + connectionWrapper.getSchema(), + is(allOf( + sameInstance(verify(underlyingConnection, times(1)).getSchema()), + nullValue()))); + } + + @Test + public void testAbortShouldAbortUnderlyingConnection() throws SQLException { + connectionWrapper.abort(null); + verify(underlyingConnection, times(1)).abort(null); + } + + @Test + public void testSetNetworkTimeoutShouldSetNetworkTimeoutInUnderlyingConnection() + throws SQLException { + connectionWrapper.setNetworkTimeout(null, TIMEOUT); + verify(underlyingConnection, times(1)).setNetworkTimeout(null, TIMEOUT); + } + + @Test + public void testGetNetworkTimeoutShouldGetNetworkTimeoutFromUnderlyingConnection() + throws SQLException { + collector.checkThat( + connectionWrapper.getNetworkTimeout(), + is(equalTo(verify(underlyingConnection, times(1)).getNetworkTimeout()))); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConvertUtilsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConvertUtilsTest.java new file mode 100644 index 00000000000..5cea3749283 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ConvertUtilsTest.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static org.hamcrest.CoreMatchers.equalTo;
+
+import java.util.List;
+
+import org.apache.arrow.flight.sql.FlightSqlColumnMetadata;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.calcite.avatica.ColumnMetaData;
+import org.apache.calcite.avatica.proto.Common;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+
+import com.google.common.collect.ImmutableList;
+
+public class ConvertUtilsTest {
+
+  @Rule
+  public ErrorCollector collector = new ErrorCollector();
+
+  @Test
+  public void testShouldSetOnColumnMetaDataBuilder() {
+
+    final Common.ColumnMetaData.Builder builder = Common.ColumnMetaData.newBuilder();
+    final FlightSqlColumnMetadata expectedColumnMetaData = new FlightSqlColumnMetadata.Builder()
+        .catalogName("catalog1")
+        .schemaName("schema1")
+        .tableName("table1")
+        .isAutoIncrement(true)
+        .isCaseSensitive(true)
+        .isReadOnly(true)
+        .isSearchable(true)
+        .precision(20)
+        .scale(10)
+        .build();
+    ConvertUtils.setOnColumnMetaDataBuilder(builder, expectedColumnMetaData.getMetadataMap());
+    assertBuilder(builder, expectedColumnMetaData);
+  }
+
+  @Test
+  public void testShouldConvertArrowFieldsToColumnMetaDataList() {
+
+    final List<Field> listField = ImmutableList.of(
+        new Field("col1",
+            new FieldType(true, ArrowType.Utf8.INSTANCE, null,
+                new FlightSqlColumnMetadata.Builder()
+                    .catalogName("catalog1")
+                    .schemaName("schema1")
+                    .tableName("table1")
+                    .build().getMetadataMap()
+            ), null));
+
+    final List<ColumnMetaData> expectedColumnMetaData = ImmutableList.of(
+        ColumnMetaData.fromProto(
+            Common.ColumnMetaData.newBuilder()
+                .setCatalogName("catalog1")
+                .setSchemaName("schema1")
+                .setTableName("table1")
+                .build()));
+
+    final List<ColumnMetaData> actualColumnMetaData =
+        ConvertUtils.convertArrowFieldsToColumnMetaDataList(listField);
+    assertColumnMetaData(expectedColumnMetaData, actualColumnMetaData);
+  }
+
+  private void assertColumnMetaData(final List<ColumnMetaData> expected,
+                                    final List<ColumnMetaData> actual) {
+    collector.checkThat(expected.size(), equalTo(actual.size()));
+    int size = expected.size();
+    for (int i = 0; i < size; i++) {
+      final ColumnMetaData expectedColumnMetaData = expected.get(i);
+      final ColumnMetaData actualColumnMetaData = actual.get(i);
+      collector.checkThat(expectedColumnMetaData.catalogName, equalTo(actualColumnMetaData.catalogName));
+      collector.checkThat(expectedColumnMetaData.schemaName, equalTo(actualColumnMetaData.schemaName));
+      collector.checkThat(expectedColumnMetaData.tableName, equalTo(actualColumnMetaData.tableName));
+      collector.checkThat(expectedColumnMetaData.readOnly, equalTo(actualColumnMetaData.readOnly));
+      collector.checkThat(expectedColumnMetaData.autoIncrement, equalTo(actualColumnMetaData.autoIncrement));
+      collector.checkThat(expectedColumnMetaData.precision, equalTo(actualColumnMetaData.precision));
+      collector.checkThat(expectedColumnMetaData.scale, equalTo(actualColumnMetaData.scale));
+      collector.checkThat(expectedColumnMetaData.caseSensitive, equalTo(actualColumnMetaData.caseSensitive));
+      collector.checkThat(expectedColumnMetaData.searchable, equalTo(actualColumnMetaData.searchable));
+    }
+  }
+
+  private void assertBuilder(final Common.ColumnMetaData.Builder builder,
+                             final FlightSqlColumnMetadata flightSqlColumnMetaData) {
+
+    final Integer precision = flightSqlColumnMetaData.getPrecision();
+    final Integer scale = flightSqlColumnMetaData.getScale();
+
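+    // The protobuf builder reports 0 for precision and scale when they were never
+    // set on the Arrow metadata, hence the null guards in the checks below.
+    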
collector.checkThat(flightSqlColumnMetaData.getCatalogName(), equalTo(builder.getCatalogName())); + collector.checkThat(flightSqlColumnMetaData.getSchemaName(), equalTo(builder.getSchemaName())); + collector.checkThat(flightSqlColumnMetaData.getTableName(), equalTo(builder.getTableName())); + collector.checkThat(flightSqlColumnMetaData.isAutoIncrement(), equalTo(builder.getAutoIncrement())); + collector.checkThat(flightSqlColumnMetaData.isCaseSensitive(), equalTo(builder.getCaseSensitive())); + collector.checkThat(flightSqlColumnMetaData.isSearchable(), equalTo(builder.getSearchable())); + collector.checkThat(flightSqlColumnMetaData.isReadOnly(), equalTo(builder.getReadOnly())); + collector.checkThat(precision == null ? 0 : precision, equalTo(builder.getPrecision())); + collector.checkThat(scale == null ? 0 : scale, equalTo(builder.getScale())); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java new file mode 100644 index 00000000000..cf359849a71 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/CoreMockedSqlProducers.java @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static java.lang.String.format;
+import static org.hamcrest.CoreMatchers.equalTo;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.sql.Date;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Consumer;
+import java.util.stream.IntStream;
+
+import org.apache.arrow.flight.FlightProducer.ServerStreamListener;
+import org.apache.arrow.flight.sql.FlightSqlColumnMetadata;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.UInt4Vector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.DateUnit;
+import org.apache.arrow.vector.types.FloatingPointPrecision;
+import org.apache.arrow.vector.types.TimeUnit;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.Text;
+import org.junit.rules.ErrorCollector;
+
+import com.google.common.collect.ImmutableList;
+
+/**
+ * Standard {@link MockFlightSqlProducer} instances for tests.
+ */
+// TODO Remove this once all tests are refactored to use only the queries they need.
+public final class CoreMockedSqlProducers {
+
+  public static final String LEGACY_REGULAR_SQL_CMD = "SELECT * FROM TEST";
+  public static final String LEGACY_METADATA_SQL_CMD = "SELECT * FROM METADATA";
+  public static final String LEGACY_CANCELLATION_SQL_CMD = "SELECT * FROM TAKES_FOREVER";
+
+  private CoreMockedSqlProducers() {
+    // Prevent instantiation.
+  }
+
+  /**
+   * Gets the {@link MockFlightSqlProducer} for legacy tests and backward compatibility.
+   *
+   * @return a new producer.
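+   *
+   *         <p>An illustrative sketch of how tests consume it; the {@code connection}
+   *         and {@code collector} here are assumed to come from the test's own setup:
+   *         <pre>{@code
+   *         MockFlightSqlProducer producer = CoreMockedSqlProducers.getLegacyProducer();
+   *         try (ResultSet resultSet = connection.createStatement()
+   *             .executeQuery(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD)) {
+   *           CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet, collector);
+   *         }
+   *         }</pre>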
+   */
+  public static MockFlightSqlProducer getLegacyProducer() {
+
+    final MockFlightSqlProducer producer = new MockFlightSqlProducer();
+    addLegacyRegularSqlCmdSupport(producer);
+    addLegacyMetadataSqlCmdSupport(producer);
+    addLegacyCancellationSqlCmdSupport(producer);
+    return producer;
+  }
+
+  private static void addLegacyRegularSqlCmdSupport(final MockFlightSqlProducer producer) {
+    final Schema querySchema = new Schema(ImmutableList.of(
+        new Field(
+            "ID",
+            new FieldType(true, new ArrowType.Int(64, true),
+                null),
+            null),
+        new Field(
+            "Name",
+            new FieldType(true, new ArrowType.Utf8(), null),
+            null),
+        new Field(
+            "Age",
+            new FieldType(true, new ArrowType.Int(32, false),
+                null),
+            null),
+        new Field(
+            "Salary",
+            new FieldType(true, new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE),
+                null),
+            null),
+        new Field(
+            "Hire Date",
+            new FieldType(true, new ArrowType.Date(DateUnit.DAY), null),
+            null),
+        new Field(
+            "Last Sale",
+            new FieldType(true, new ArrowType.Timestamp(TimeUnit.MILLISECOND, null),
+                null),
+            null)
+    ));
+    final List<Consumer<ServerStreamListener>> resultProducers = new ArrayList<>();
+    IntStream.range(0, 10).forEach(page -> {
+      resultProducers.add(listener -> {
+        final int rowsPerPage = 5000;
+        try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+             final VectorSchemaRoot root = VectorSchemaRoot.create(querySchema, allocator)) {
+          root.allocateNew();
+          listener.start(root);
+          int batchSize = 500;
+          int indexOnBatch = 0;
+          int resultsOffset = page * rowsPerPage;
+          for (int i = 0; i < rowsPerPage; i++) {
+            ((BigIntVector) root.getVector("ID"))
+                .setSafe(indexOnBatch, (long) Integer.MAX_VALUE + 1 + i + resultsOffset);
+            ((VarCharVector) root.getVector("Name"))
+                .setSafe(indexOnBatch, new Text("Test Name #" + (resultsOffset + i)));
+            ((UInt4Vector) root.getVector("Age"))
+                .setSafe(indexOnBatch, (int) Short.MAX_VALUE + 1 + i + resultsOffset);
+            ((Float8Vector) root.getVector("Salary"))
+                .setSafe(indexOnBatch,
+                    Math.scalb((double) (i + resultsOffset) / 2, i + resultsOffset));
+            ((DateDayVector) root.getVector("Hire Date"))
+                .setSafe(indexOnBatch, i + resultsOffset);
+            ((TimeStampMilliVector) root.getVector("Last Sale"))
+                .setSafe(indexOnBatch, Long.MAX_VALUE - i - resultsOffset);
+            indexOnBatch++;
+            if (indexOnBatch == batchSize) {
+              root.setRowCount(indexOnBatch);
+              if (listener.isCancelled()) {
+                return;
+              }
+              listener.putNext();
+              root.allocateNew();
+              indexOnBatch = 0;
+            }
+          }
+          if (listener.isCancelled()) {
+            return;
+          }
+          root.setRowCount(indexOnBatch);
+          listener.putNext();
+        } finally {
+          listener.completed();
+        }
+      });
+    });
+    producer.addSelectQuery(LEGACY_REGULAR_SQL_CMD, querySchema, resultProducers);
+  }
+
+  private static void addLegacyMetadataSqlCmdSupport(final MockFlightSqlProducer producer) {
+    final Schema metadataSchema = new Schema(ImmutableList.of(
+        new Field(
+            "integer0",
+            new FieldType(true, new ArrowType.Int(64, true),
+                null, new FlightSqlColumnMetadata.Builder()
+                .catalogName("CATALOG_NAME_1")
+                .schemaName("SCHEMA_NAME_1")
+                .tableName("TABLE_NAME_1")
+                .typeName("TYPE_NAME_1")
+                .precision(10)
+                .scale(0)
+                .isAutoIncrement(true)
+                .isCaseSensitive(false)
+                .isReadOnly(true)
+                .isSearchable(true)
+                .build().getMetadataMap()),
+            null),
+        new Field(
+            "string1",
+            new FieldType(true, new ArrowType.Utf8(),
+                null, new FlightSqlColumnMetadata.Builder()
+                .catalogName("CATALOG_NAME_2")
+                .schemaName("SCHEMA_NAME_2")
+                .tableName("TABLE_NAME_2")
+                .typeName("TYPE_NAME_2")
+                .precision(65535)
+                .scale(0)
+                .isAutoIncrement(false)
+                .isCaseSensitive(true)
+                .isReadOnly(false)
+                .isSearchable(true)
+                .build().getMetadataMap()),
+            null),
+        new Field(
+            "float2",
+            new FieldType(true, new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE),
+                null, new FlightSqlColumnMetadata.Builder()
+                .catalogName("CATALOG_NAME_3")
+                .schemaName("SCHEMA_NAME_3")
+                .tableName("TABLE_NAME_3")
+                .typeName("TYPE_NAME_3")
+                .precision(15)
+                .scale(20)
+                .isAutoIncrement(false)
+                .isCaseSensitive(false)
+                .isReadOnly(false)
+                .isSearchable(true)
+                .build().getMetadataMap()),
+            null)));
+    final Consumer<ServerStreamListener> formula = listener -> {
+      try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+           final VectorSchemaRoot root = VectorSchemaRoot.create(metadataSchema, allocator)) {
+        root.allocateNew();
+        ((BigIntVector) root.getVector("integer0")).setSafe(0, 1);
+        ((VarCharVector) root.getVector("string1")).setSafe(0, new Text("teste"));
+        ((Float4Vector) root.getVector("float2")).setSafe(0, (float) 4.1);
+        root.setRowCount(1);
+        listener.start(root);
+        listener.putNext();
+      } finally {
+        listener.completed();
+      }
+    };
+    producer.addSelectQuery(LEGACY_METADATA_SQL_CMD, metadataSchema,
+        Collections.singletonList(formula));
+  }
+
+  private static void addLegacyCancellationSqlCmdSupport(final MockFlightSqlProducer producer) {
+    producer.addSelectQuery(
+        LEGACY_CANCELLATION_SQL_CMD,
+        new Schema(Collections.singletonList(new Field(
+            "integer0",
+            new FieldType(true, new ArrowType.Int(64, true), null),
+            null))),
+        Collections.singletonList(listener -> {
+          // Should keep hanging until canceled.
+        }));
+  }
+
+  /**
+   * Asserts that the values in the provided {@link ResultSet} are expected for the
+   * legacy {@link MockFlightSqlProducer}.
+   *
+   * @param resultSet the result set.
+   * @param collector the {@link ErrorCollector} to use.
+   * @throws SQLException on error.
+   */
+  public static void assertLegacyRegularSqlResultSet(final ResultSet resultSet,
+                                                     final ErrorCollector collector)
+      throws SQLException {
+    final int expectedRowCount = 50_000;
+
+    final long[] expectedIds = new long[expectedRowCount];
+    final List<String> expectedNames = new ArrayList<>(expectedRowCount);
+    final int[] expectedAges = new int[expectedRowCount];
+    final double[] expectedSalaries = new double[expectedRowCount];
+    final List<Date> expectedHireDates = new ArrayList<>(expectedRowCount);
+    final List<Timestamp> expectedLastSales = new ArrayList<>(expectedRowCount);
+
+    final long[] actualIds = new long[expectedRowCount];
+    final List<String> actualNames = new ArrayList<>(expectedRowCount);
+    final int[] actualAges = new int[expectedRowCount];
+    final double[] actualSalaries = new double[expectedRowCount];
+    final List<Date> actualHireDates = new ArrayList<>(expectedRowCount);
+    final List<Timestamp> actualLastSales = new ArrayList<>(expectedRowCount);
+
+    int actualRowCount = 0;
+
+    for (; resultSet.next(); actualRowCount++) {
+      expectedIds[actualRowCount] = (long) Integer.MAX_VALUE + 1 + actualRowCount;
+      expectedNames.add(format("Test Name #%d", actualRowCount));
+      expectedAges[actualRowCount] = (int) Short.MAX_VALUE + 1 + actualRowCount;
+      expectedSalaries[actualRowCount] = Math.scalb((double) actualRowCount / 2, actualRowCount);
+      expectedHireDates.add(new Date(86_400_000L * actualRowCount));
+      expectedLastSales.add(new Timestamp(Long.MAX_VALUE - actualRowCount));
+
+      actualIds[actualRowCount] = (long) resultSet.getObject(1);
+      actualNames.add((String) resultSet.getObject(2));
+      actualAges[actualRowCount] = (int) resultSet.getObject(3);
+      actualSalaries[actualRowCount] = (double) resultSet.getObject(4);
+      actualHireDates.add((Date) resultSet.getObject(5));
+      actualLastSales.add((Timestamp) resultSet.getObject(6));
+    }
+    collector.checkThat(actualRowCount, is(equalTo(expectedRowCount)));
+    collector.checkThat(actualIds, is(expectedIds));
+    collector.checkThat(actualNames, is(expectedNames));
+    collector.checkThat(actualAges, is(expectedAges));
+    collector.checkThat(actualSalaries, is(expectedSalaries));
+    collector.checkThat(actualHireDates, is(expectedHireDates));
+    collector.checkThat(actualLastSales, is(expectedLastSales));
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtilsTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtilsTest.java
new file mode 100644
index 00000000000..adb892fcdc7
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/DateTimeUtilsTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.hamcrest.CoreMatchers.is; + +import java.sql.Timestamp; +import java.time.Instant; +import java.util.Calendar; +import java.util.TimeZone; + +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.ErrorCollector; + +public class DateTimeUtilsTest { + + @ClassRule + public static final ErrorCollector collector = new ErrorCollector(); + private final TimeZone defaultTimezone = TimeZone.getTimeZone("UTC"); + private final TimeZone alternateTimezone = TimeZone.getTimeZone("America/Vancouver"); + private final long positiveEpochMilli = 959817600000L; // 2000-06-01 00:00:00 UTC + private final long negativeEpochMilli = -618105600000L; // 1950-06-01 00:00:00 UTC + + @Test + public void testShouldGetOffsetWithSameTimeZone() { + final TimeZone currentTimezone = TimeZone.getDefault(); + + final long epochMillis = positiveEpochMilli; + final long offset = defaultTimezone.getOffset(epochMillis); + + TimeZone.setDefault(defaultTimezone); + + try { // Trying to guarantee timezone returns to its original value + final long expected = epochMillis + offset; + final long actual = DateTimeUtils.applyCalendarOffset(epochMillis, Calendar.getInstance(defaultTimezone)); + + collector.checkThat(actual, is(expected)); + } finally { + // Reset Timezone + TimeZone.setDefault(currentTimezone); + } + } + + @Test + public void testShouldGetOffsetWithDifferentTimeZone() { + final TimeZone currentTimezone = TimeZone.getDefault(); + + final long epochMillis = negativeEpochMilli; + final long offset = alternateTimezone.getOffset(epochMillis); + + TimeZone.setDefault(alternateTimezone); + + try { // Trying to guarantee timezone returns to its original value + final long expectedEpochMillis = epochMillis + offset; + final long actualEpochMillis = DateTimeUtils.applyCalendarOffset(epochMillis, Calendar.getInstance( + defaultTimezone)); + + collector.checkThat(actualEpochMillis, is(expectedEpochMillis)); + } finally { + // Reset Timezone + TimeZone.setDefault(currentTimezone); + } + } + + @Test + public void testShouldGetTimestampPositive() { + long epochMilli = positiveEpochMilli; + final Instant instant = Instant.ofEpochMilli(epochMilli); + + final Timestamp expected = Timestamp.from(instant); + final Timestamp actual = DateTimeUtils.getTimestampValue(epochMilli); + + collector.checkThat(expected, is(actual)); + } + + @Test + public void testShouldGetTimestampNegative() { + final long epochMilli = negativeEpochMilli; + final Instant instant = Instant.ofEpochMilli(epochMilli); + + final Timestamp expected = Timestamp.from(instant); + final Timestamp actual = DateTimeUtils.getTimestampValue(epochMilli); + + collector.checkThat(expected, is(actual)); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightSqlTestCertificates.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightSqlTestCertificates.java new file mode 100644 index 00000000000..a2b1864c026 --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightSqlTestCertificates.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
* The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import java.io.File;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Utility class for unit tests that need to reference the certificate params.
+ */
+public class FlightSqlTestCertificates {
+
+  public static final String TEST_DATA_ENV_VAR = "ARROW_TEST_DATA";
+  public static final String TEST_DATA_PROPERTY = "arrow.test.dataRoot";
+
+  static Path getTestDataRoot() {
+    String path = System.getenv(TEST_DATA_ENV_VAR);
+    if (path == null) {
+      path = System.getProperty(TEST_DATA_PROPERTY);
+    }
+    return Paths.get(Objects.requireNonNull(path,
+        String.format("Could not find test data path. Set the environment variable %s or the JVM property %s.",
+            TEST_DATA_ENV_VAR, TEST_DATA_PROPERTY)));
+  }
+
+  /**
+   * Gets the path to the directory holding the certificates and keys used in the
+   * encrypted Flight tests.
+   *
+   * @return the path to the directory with the certificates and keys.
+   */
+  static Path getFlightTestDataRoot() {
+    return getTestDataRoot().resolve("flight");
+  }
+
+  /**
+   * Creates {@link CertKeyPair} objects from the example certificate and key files.
+   *
+   * @return a list of {@link CertKeyPair}.
+   */
+  public static List<CertKeyPair> exampleTlsCerts() {
+    final Path root = getFlightTestDataRoot();
+    return Arrays.asList(new CertKeyPair(root.resolve("cert0.pem")
+        .toFile(), root.resolve("cert0.pkcs1").toFile()),
+        new CertKeyPair(root.resolve("cert1.pem")
+            .toFile(), root.resolve("cert1.pkcs1").toFile()));
+  }
+
+  public static class CertKeyPair {
+
+    public final File cert;
+    public final File key;
+
+    public CertKeyPair(File cert, File key) {
+      this.cert = cert;
+      this.key = key;
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java
new file mode 100644
index 00000000000..b474da55a7f
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueueTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.CoreMatchers.nullValue;
+import static org.mockito.Mockito.mock;
+
+import java.util.concurrent.CompletionService;
+
+import org.apache.arrow.flight.FlightStream;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ErrorCollector;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+/**
+ * Tests for {@link FlightStreamQueue}.
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class FlightStreamQueueTest {
+
+  @Rule
+  public final ErrorCollector collector = new ErrorCollector();
+  @Mock
+  private CompletionService<FlightStream> mockedService;
+  private FlightStreamQueue queue;
+
+  @Before
+  public void setUp() {
+    queue = new FlightStreamQueue(mockedService);
+  }
+
+  @Test
+  public void testNextShouldRetrieveNullIfEmpty() throws Exception {
+    collector.checkThat(queue.next(), is(nullValue()));
+  }
+
+  @Test
+  public void testNextShouldThrowExceptionUponClose() throws Exception {
+    queue.close();
+    ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, () -> queue.next());
+  }
+
+  @Test
+  public void testEnqueueShouldThrowExceptionUponClose() throws Exception {
+    queue.close();
+    ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class,
+        () -> queue.enqueue(mock(FlightStream.class)));
+  }
+
+  @Test
+  public void testCheckOpen() throws Exception {
+    collector.checkSucceeds(() -> {
+      queue.checkOpen();
+      return true;
+    });
+    queue.close();
+    ThrowableAssertionUtils.simpleAssertThrowableClass(IllegalStateException.class, () -> queue.checkOpen());
+  }
+
+  @Test
+  public void testShouldCloseQueue() throws Exception {
+    collector.checkThat(queue.isClosed(), is(false));
+    queue.close();
+    collector.checkThat(queue.isClosed(), is(true));
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
new file mode 100644
index 00000000000..cc8fae9722f
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/MockFlightSqlProducer.java
@@ -0,0 +1,539 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc.utils; + +import static com.google.protobuf.Any.pack; +import static com.google.protobuf.ByteString.copyFrom; +import static java.lang.String.format; +import static java.util.UUID.randomUUID; +import static java.util.stream.Collectors.toList; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.charset.StandardCharsets; +import java.util.AbstractMap.SimpleImmutableEntry; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.UUID; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import java.util.stream.IntStream; + +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.Criteria; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.PutResult; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.SqlInfoBuilder; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; +import org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; +import org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.ipc.WriteChannel; +import org.apache.arrow.vector.ipc.message.MessageSerializer; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.calcite.avatica.Meta.StatementType; + +import com.google.protobuf.Any; +import com.google.protobuf.ByteString; +import com.google.protobuf.Message; + +/** + * An ad-hoc {@link FlightSqlProducer} for tests. 
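+ *
+ * <p>A minimal usage sketch (illustrative only; the {@code schema} below and the
+ * server/JDBC wiring around the producer are assumptions, not part of this class):
+ * <pre>{@code
+ * MockFlightSqlProducer producer = new MockFlightSqlProducer();
+ * producer.addSelectQuery("SELECT * FROM TEST", schema,
+ *     Collections.singletonList(listener -> {
+ *       // stream VectorSchemaRoot batches through the listener, then finish
+ *       listener.completed();
+ *     }));
+ * producer.addUpdateQuery("UPDATE TEST SET X = 0", 10L);
+ * }</pre>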
+ */
+public final class MockFlightSqlProducer implements FlightSqlProducer {
+
+  private final Map<String, Entry<Schema, List<UUID>>> queryResults = new HashMap<>();
+  private final Map<UUID, Consumer<ServerStreamListener>> selectResultProviders = new HashMap<>();
+  private final Map<ByteString, String> preparedStatements = new HashMap<>();
+  private final Map<Message, Consumer<ServerStreamListener>> catalogQueriesResults =
+      new HashMap<>();
+  private final Map<String, BiConsumer<FlightStream, StreamListener<PutResult>>>
+      updateResultProviders =
+      new HashMap<>();
+  private SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder();
+
+  private static FlightInfo getFlightInfoExportedAndImportedKeys(final Message message,
+                                                                 final FlightDescriptor descriptor) {
+    return getFlightInfo(message, Schemas.GET_IMPORTED_KEYS_SCHEMA, descriptor);
+  }
+
+  private static FlightInfo getFlightInfo(final Message message, final Schema schema,
+                                          final FlightDescriptor descriptor) {
+    return new FlightInfo(
+        schema,
+        descriptor,
+        Collections.singletonList(new FlightEndpoint(new Ticket(Any.pack(message).toByteArray()))),
+        -1, -1);
+  }
+
+  public static ByteBuffer serializeSchema(final Schema schema) {
+    final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+    try {
+      MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), schema);
+
+      return ByteBuffer.wrap(outputStream.toByteArray());
+    } catch (final IOException e) {
+      throw new RuntimeException("Failed to serialize schema", e);
+    }
+  }
+
+  /**
+   * Registers a new {@link StatementType#SELECT} SQL query.
+   *
+   * @param sqlCommand the SQL command under which to register the new query.
+   * @param schema the schema to use for the query result.
+   * @param resultProviders the result providers for this query.
+   */
+  public void addSelectQuery(final String sqlCommand, final Schema schema,
+                             final List<Consumer<ServerStreamListener>> resultProviders) {
+    final int providers = resultProviders.size();
+    final List<UUID> uuids =
+        IntStream.range(0, providers)
+            .mapToObj(index -> new UUID(sqlCommand.hashCode(), Integer.hashCode(index)))
+            .collect(toList());
+    queryResults.put(sqlCommand, new SimpleImmutableEntry<>(schema, uuids));
+    IntStream.range(0, providers)
+        .forEach(
+            index -> this.selectResultProviders.put(uuids.get(index), resultProviders.get(index)));
+  }
+
+  /**
+   * Registers a new {@link StatementType#UPDATE} SQL query.
+   *
+   * @param sqlCommand the SQL command.
+   * @param updatedRows the number of rows affected.
+   */
+  public void addUpdateQuery(final String sqlCommand, final long updatedRows) {
+    addUpdateQuery(sqlCommand, ((flightStream, putResultStreamListener) -> {
+      final DoPutUpdateResult result =
+          DoPutUpdateResult.newBuilder().setRecordCount(updatedRows).build();
+      try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+           final ArrowBuf buffer = allocator.buffer(result.getSerializedSize())) {
+        buffer.writeBytes(result.toByteArray());
+        putResultStreamListener.onNext(PutResult.metadata(buffer));
+      } catch (final Throwable throwable) {
+        putResultStreamListener.onError(throwable);
+      } finally {
+        putResultStreamListener.onCompleted();
+      }
+    }));
+  }
+
+  /**
+   * Adds a catalog query to the results.
+   *
+   * @param message the {@link Message} corresponding to the catalog query request type to register.
+   * @param resultsProvider the results provider.
+   */
+  public void addCatalogQuery(final Message message,
+                              final Consumer<ServerStreamListener> resultsProvider) {
+    catalogQueriesResults.put(message, resultsProvider);
+  }
+
+  /**
+   * Registers a new {@link StatementType#UPDATE} SQL query.
+   *
+   * @param sqlCommand the SQL command.
+   * @param resultsProvider consumer for producing update results.
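+   *
+   * <p>For illustration, a provider that only acknowledges the upload could be
+   * registered as follows (a sketch, not part of the original patch):
+   * <pre>{@code
+   * producer.addUpdateQuery("UPDATE TEST SET X = 0", (flightStream, listener) -> {
+   *   listener.onCompleted();
+   * });
+   * }</pre>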
+   */
+  void addUpdateQuery(final String sqlCommand,
+                      final BiConsumer<FlightStream, StreamListener<PutResult>> resultsProvider) {
+    Preconditions.checkState(
+        updateResultProviders.putIfAbsent(sqlCommand, resultsProvider) == null,
+        format("Attempted to overwrite pre-existing query: <%s>.", sqlCommand));
+  }
+
+  @Override
+  public void createPreparedStatement(final ActionCreatePreparedStatementRequest request,
+                                      final CallContext callContext,
+                                      final StreamListener<Result> listener) {
+    try {
+      final ByteString preparedStatementHandle =
+          copyFrom(randomUUID().toString().getBytes(StandardCharsets.UTF_8));
+      final String query = request.getQuery();
+
+      final ActionCreatePreparedStatementResult.Builder resultBuilder =
+          ActionCreatePreparedStatementResult.newBuilder()
+              .setPreparedStatementHandle(preparedStatementHandle);
+
+      final Entry<Schema, List<UUID>> entry = queryResults.get(query);
+      if (entry != null) {
+        preparedStatements.put(preparedStatementHandle, query);
+
+        final Schema datasetSchema = entry.getKey();
+        final ByteString datasetSchemaBytes =
+            ByteString.copyFrom(serializeSchema(datasetSchema));
+
+        resultBuilder.setDatasetSchema(datasetSchemaBytes);
+      } else if (updateResultProviders.containsKey(query)) {
+        preparedStatements.put(preparedStatementHandle, query);
+
+      } else {
+        listener.onError(
+            CallStatus.INVALID_ARGUMENT.withDescription("Query not found").toRuntimeException());
+        return;
+      }
+
+      listener.onNext(new Result(pack(resultBuilder.build()).toByteArray()));
+    } catch (final Throwable t) {
+      listener.onError(t);
+    } finally {
+      listener.onCompleted();
+    }
+  }
+
+  @Override
+  public void closePreparedStatement(
+      final ActionClosePreparedStatementRequest actionClosePreparedStatementRequest,
+      final CallContext callContext, final StreamListener<Result> streamListener) {
+    // TODO Implement this method.
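+    // Until then, completing the listener keeps clients from blocking while they
+    // wait for the close acknowledgement.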
+    streamListener.onCompleted();
+  }
+
+  @Override
+  public FlightInfo getFlightInfoStatement(final CommandStatementQuery commandStatementQuery,
+                                           final CallContext callContext,
+                                           final FlightDescriptor flightDescriptor) {
+    final String query = commandStatementQuery.getQuery();
+    final Entry<Schema, List<UUID>> queryInfo =
+        Preconditions.checkNotNull(queryResults.get(query),
+            format("Query not registered: <%s>.", query));
+    final List<FlightEndpoint> endpoints =
+        queryInfo.getValue().stream()
+            .map(TicketConversionUtils::getTicketBytesFromUuid)
+            .map(TicketConversionUtils::getTicketStatementQueryFromHandle)
+            .map(TicketConversionUtils::getEndpointFromMessage)
+            .collect(toList());
+    return new FlightInfo(queryInfo.getKey(), flightDescriptor, endpoints, -1, -1);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoPreparedStatement(
+      final CommandPreparedStatementQuery commandPreparedStatementQuery,
+      final CallContext callContext,
+      final FlightDescriptor flightDescriptor) {
+    final ByteString preparedStatementHandle =
+        commandPreparedStatementQuery.getPreparedStatementHandle();
+
+    final String query = Preconditions.checkNotNull(
+        preparedStatements.get(preparedStatementHandle),
+        format("No query registered under handle: <%s>.", preparedStatementHandle));
+    final Entry<Schema, List<UUID>> queryInfo =
+        Preconditions.checkNotNull(queryResults.get(query),
+            format("Query not registered: <%s>.", query));
+    final List<FlightEndpoint> endpoints =
+        queryInfo.getValue().stream()
+            .map(TicketConversionUtils::getTicketBytesFromUuid)
+            .map(TicketConversionUtils::getCommandPreparedStatementQueryFromHandle)
+            .map(TicketConversionUtils::getEndpointFromMessage)
+            .collect(toList());
+    return new FlightInfo(queryInfo.getKey(), flightDescriptor, endpoints, -1, -1);
+  }
+
+  @Override
+  public SchemaResult getSchemaStatement(final CommandStatementQuery commandStatementQuery,
+                                         final CallContext callContext,
+                                         final FlightDescriptor flightDescriptor) {
+    final String query = commandStatementQuery.getQuery();
+    final Entry<Schema, List<UUID>> queryInfo =
+        Preconditions.checkNotNull(queryResults.get(query),
+            format("Query not registered: <%s>.", query));
+
+    return new SchemaResult(queryInfo.getKey());
+  }
+
+  @Override
+  public void getStreamStatement(final TicketStatementQuery ticketStatementQuery,
+                                 final CallContext callContext,
+                                 final ServerStreamListener serverStreamListener) {
+    final UUID uuid = UUID.fromString(ticketStatementQuery.getStatementHandle().toStringUtf8());
+    Preconditions.checkNotNull(
+        selectResultProviders.get(uuid),
+        "No consumer was registered for the specified UUID: <%s>.", uuid)
+        .accept(serverStreamListener);
+  }
+
+  @Override
+  public void getStreamPreparedStatement(
+      final CommandPreparedStatementQuery commandPreparedStatementQuery,
+      final CallContext callContext,
+      final ServerStreamListener serverStreamListener) {
+    final UUID uuid =
+        UUID.fromString(commandPreparedStatementQuery.getPreparedStatementHandle().toStringUtf8());
+    Preconditions.checkNotNull(
+        selectResultProviders.get(uuid),
+        "No consumer was registered for the specified UUID: <%s>.", uuid)
+        .accept(serverStreamListener);
+  }
+
+  @Override
+  public Runnable acceptPutStatement(final CommandStatementUpdate commandStatementUpdate,
+                                     final CallContext callContext,
+                                     final FlightStream flightStream,
+                                     final StreamListener<PutResult> streamListener) {
+    return () -> {
+      final String query = commandStatementUpdate.getQuery();
+      final BiConsumer<FlightStream, StreamListener<PutResult>> resultProvider =
+          Preconditions.checkNotNull(
+              updateResultProviders.get(query),
+              format("No consumer found for query: <%s>.", query));
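+      // The registered provider drains the uploaded stream and reports the affected
+      // row count (or an error) through the PutResult listener.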
+      resultProvider.accept(flightStream, streamListener);
+    };
+  }
+
+  @Override
+  public Runnable acceptPutPreparedStatementUpdate(
+      final CommandPreparedStatementUpdate commandPreparedStatementUpdate,
+      final CallContext callContext, final FlightStream flightStream,
+      final StreamListener<PutResult> streamListener) {
+    final ByteString handle = commandPreparedStatementUpdate.getPreparedStatementHandle();
+    final String query = Preconditions.checkNotNull(
+        preparedStatements.get(handle),
+        format("No query registered under handle: <%s>.", handle));
+    return acceptPutStatement(
+        CommandStatementUpdate.newBuilder().setQuery(query).build(), callContext, flightStream,
+        streamListener);
+  }
+
+  @Override
+  public Runnable acceptPutPreparedStatementQuery(
+      final CommandPreparedStatementQuery commandPreparedStatementQuery,
+      final CallContext callContext, final FlightStream flightStream,
+      final StreamListener<PutResult> streamListener) {
+    // TODO Implement this method.
+    throw CallStatus.UNIMPLEMENTED.toRuntimeException();
+  }
+
+  @Override
+  public FlightInfo getFlightInfoSqlInfo(final CommandGetSqlInfo commandGetSqlInfo,
+                                         final CallContext callContext,
+                                         final FlightDescriptor flightDescriptor) {
+    return getFlightInfo(commandGetSqlInfo, Schemas.GET_SQL_INFO_SCHEMA, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamSqlInfo(final CommandGetSqlInfo commandGetSqlInfo,
+                               final CallContext callContext,
+                               final ServerStreamListener serverStreamListener) {
+    sqlInfoBuilder.send(commandGetSqlInfo.getInfoList(), serverStreamListener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, CallContext context,
+                                          FlightDescriptor descriptor) {
+    // TODO Implement this
+    return null;
+  }
+
+  @Override
+  public void getStreamTypeInfo(FlightSql.CommandGetXdbcTypeInfo request, CallContext context,
+                                ServerStreamListener listener) {
+    // TODO Implement this
+  }
+
+  @Override
+  public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs commandGetCatalogs,
+                                          final CallContext callContext,
+                                          final FlightDescriptor flightDescriptor) {
+    return getFlightInfo(commandGetCatalogs, Schemas.GET_CATALOGS_SCHEMA, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamCatalogs(final CallContext callContext,
+                                final ServerStreamListener serverStreamListener) {
+    final CommandGetCatalogs command = CommandGetCatalogs.getDefaultInstance();
+    getStreamCatalogFunctions(command, serverStreamListener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas commandGetSchemas,
+                                         final CallContext callContext,
+                                         final FlightDescriptor flightDescriptor) {
+    return getFlightInfo(commandGetSchemas, Schemas.GET_SCHEMAS_SCHEMA, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamSchemas(final CommandGetDbSchemas commandGetSchemas,
+                               final CallContext callContext,
+                               final ServerStreamListener serverStreamListener) {
+    getStreamCatalogFunctions(commandGetSchemas, serverStreamListener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoTables(final CommandGetTables commandGetTables,
+                                        final CallContext callContext,
+                                        final FlightDescriptor flightDescriptor) {
+    return getFlightInfo(commandGetTables, Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamTables(final CommandGetTables commandGetTables,
+                              final CallContext callContext,
+                              final ServerStreamListener serverStreamListener) {
+    getStreamCatalogFunctions(commandGetTables, serverStreamListener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoTableTypes(final CommandGetTableTypes commandGetTableTypes,
+                                            final CallContext callContext,
+                                            final FlightDescriptor flightDescriptor) {
+    return getFlightInfo(commandGetTableTypes, Schemas.GET_TABLE_TYPES_SCHEMA, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamTableTypes(final CallContext callContext,
+                                  final ServerStreamListener serverStreamListener) {
+    final CommandGetTableTypes command = CommandGetTableTypes.getDefaultInstance();
+    getStreamCatalogFunctions(command, serverStreamListener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoPrimaryKeys(final CommandGetPrimaryKeys commandGetPrimaryKeys,
+                                             final CallContext callContext,
+                                             final FlightDescriptor flightDescriptor) {
+    return getFlightInfo(commandGetPrimaryKeys, Schemas.GET_PRIMARY_KEYS_SCHEMA, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamPrimaryKeys(final CommandGetPrimaryKeys commandGetPrimaryKeys,
+                                   final CallContext callContext,
+                                   final ServerStreamListener serverStreamListener) {
+    getStreamCatalogFunctions(commandGetPrimaryKeys, serverStreamListener);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoExportedKeys(final CommandGetExportedKeys commandGetExportedKeys,
+                                              final CallContext callContext,
+                                              final FlightDescriptor flightDescriptor) {
+    return getFlightInfoExportedAndImportedKeys(commandGetExportedKeys, flightDescriptor);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoImportedKeys(final CommandGetImportedKeys commandGetImportedKeys,
+                                              final CallContext callContext,
+                                              final FlightDescriptor flightDescriptor) {
+    return getFlightInfoExportedAndImportedKeys(commandGetImportedKeys, flightDescriptor);
+  }
+
+  @Override
+  public FlightInfo getFlightInfoCrossReference(
+      final CommandGetCrossReference commandGetCrossReference,
+      final CallContext callContext,
+      final FlightDescriptor flightDescriptor) {
+    return getFlightInfoExportedAndImportedKeys(commandGetCrossReference, flightDescriptor);
+  }
+
+  @Override
+  public void getStreamExportedKeys(final CommandGetExportedKeys commandGetExportedKeys,
+                                    final CallContext callContext,
+                                    final ServerStreamListener serverStreamListener) {
+    getStreamCatalogFunctions(commandGetExportedKeys, serverStreamListener);
+  }
+
+  @Override
+  public void getStreamImportedKeys(final CommandGetImportedKeys commandGetImportedKeys,
+                                    final CallContext callContext,
+                                    final ServerStreamListener serverStreamListener) {
+    getStreamCatalogFunctions(commandGetImportedKeys, serverStreamListener);
+  }
+
+  @Override
+  public void getStreamCrossReference(final CommandGetCrossReference commandGetCrossReference,
+                                      final CallContext callContext,
+                                      final ServerStreamListener serverStreamListener) {
+    getStreamCatalogFunctions(commandGetCrossReference, serverStreamListener);
+  }
+
+  @Override
+  public void close() {
+    // No-op.
+  }
+
+  @Override
+  public void listFlights(final CallContext callContext, final Criteria criteria,
+                          final StreamListener<FlightInfo> streamListener) {
+    // TODO Implement this method.
+    throw CallStatus.UNIMPLEMENTED.toRuntimeException();
+  }
+
+  private void getStreamCatalogFunctions(final Message ticket,
+                                         final ServerStreamListener serverStreamListener) {
+    Preconditions.checkNotNull(
+        catalogQueriesResults.get(ticket),
+        format("Query not registered for ticket: <%s>", ticket))
+        .accept(serverStreamListener);
+  }
+
+  public SqlInfoBuilder getSqlInfoBuilder() {
+    return sqlInfoBuilder;
+  }
+
+  private static final class TicketConversionUtils {
+    private TicketConversionUtils() {
+      // Prevent instantiation.
+    }
+
+    private static ByteString getTicketBytesFromUuid(final UUID uuid) {
+      return ByteString.copyFromUtf8(uuid.toString());
+    }
+
+    private static TicketStatementQuery getTicketStatementQueryFromHandle(final ByteString handle) {
+      return TicketStatementQuery.newBuilder().setStatementHandle(handle).build();
+    }
+
+    private static CommandPreparedStatementQuery getCommandPreparedStatementQueryFromHandle(
+        final ByteString handle) {
+      return CommandPreparedStatementQuery.newBuilder().setPreparedStatementHandle(handle).build();
+    }
+
+    private static FlightEndpoint getEndpointFromMessage(final Message message) {
+      return new FlightEndpoint(new Ticket(Any.pack(message).toByteArray()));
+    }
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ResultSetTestUtils.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ResultSetTestUtils.java
new file mode 100644
index 00000000000..d5ce7fb8fb3
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ResultSetTestUtils.java
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static java.util.stream.IntStream.range;
+import static org.hamcrest.CoreMatchers.is;
+
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.Function;
+
+import org.apache.arrow.util.Preconditions;
+import org.junit.rules.ErrorCollector;
+
+/**
+ * Utility class for tests that need to assert that the values in a {@link ResultSet}
+ * are the expected ones.
+ */
+public final class ResultSetTestUtils {
+  private final ErrorCollector collector;
+
+  public ResultSetTestUtils(final ErrorCollector collector) {
+    this.collector =
+        Preconditions.checkNotNull(collector, "Error collector cannot be null.");
+  }
+
+  /**
+   * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected.
+   *
+   * @param resultSet the {@code ResultSet} to assert.
+   * @param expectedResults the rows and columns representing the only values the {@code resultSet}
+   *                        is expected to have.
+   * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet}
+   *                  has the expected values.
+   * @param <T> the type to be found in the expected results for the {@code resultSet}.
+   * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly.
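+   *
+   * <p>For illustration (a sketch; the query and the expected values are assumptions):
+   * <pre>{@code
+   * ResultSetTestUtils.testData(
+   *     statement.executeQuery("SELECT ID, NAME FROM TEST"),
+   *     Arrays.asList(Arrays.asList(1L, "foo"), Arrays.asList(2L, "bar")),
+   *     collector);
+   * }</pre>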
+ */ + public static void testData(final ResultSet resultSet, final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + testData( + resultSet, + range(1, resultSet.getMetaData().getColumnCount() + 1).toArray(), + expectedResults, + collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnNames the column names to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public static void testData(final ResultSet resultSet, final List columnNames, + final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + testData( + resultSet, + data -> { + final List columns = new ArrayList<>(); + for (final String columnName : columnNames) { + try { + columns.add((T) resultSet.getObject(columnName)); + } catch (final SQLException e) { + collector.addError(e); + } + } + return columns; + }, + expectedResults, + collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnIndices the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public static void testData(final ResultSet resultSet, final int[] columnIndices, + final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + testData( + resultSet, + data -> { + final List columns = new ArrayList<>(); + for (final int columnIndex : columnIndices) { + try { + columns.add((T) resultSet.getObject(columnIndex)); + } catch (final SQLException e) { + collector.addError(e); + } + } + return columns; + }, + expectedResults, + collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param dataConsumer the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} + * is expected to have. + * @param collector the {@link ErrorCollector} to use for asserting that the {@code resultSet} + * has the expected values. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. 
+ */ + public static void testData(final ResultSet resultSet, + final Function> dataConsumer, + final List> expectedResults, + final ErrorCollector collector) + throws SQLException { + final List> actualResults = new ArrayList<>(); + while (resultSet.next()) { + actualResults.add(dataConsumer.apply(resultSet)); + } + collector.checkThat(actualResults, is(expectedResults)); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + public void testData(final ResultSet resultSet, final List> expectedResults) + throws SQLException { + testData(resultSet, expectedResults, collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnNames the column names to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public void testData(final ResultSet resultSet, final List columnNames, + final List> expectedResults) throws SQLException { + testData(resultSet, columnNames, expectedResults, collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param columnIndices the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. + */ + @SuppressWarnings("unchecked") + public void testData(final ResultSet resultSet, final int[] columnIndices, + final List> expectedResults) throws SQLException { + testData(resultSet, columnIndices, expectedResults, collector); + } + + /** + * Checks that the values (rows and columns) in the provided {@link ResultSet} are expected. + * + * @param resultSet the {@code ResultSet} to assert. + * @param dataConsumer the column indices to fetch in the {@code ResultSet} for comparison. + * @param expectedResults the rows and columns representing the only values the {@code resultSet} is expected to have. + * @param the type to be found in the expected results for the {@code resultSet}. + * @throws SQLException if querying the {@code ResultSet} fails at some point unexpectedly. 
+ */ + public void testData(final ResultSet resultSet, + final Function> dataConsumer, + final List> expectedResults) throws SQLException { + testData(resultSet, dataConsumer, expectedResults, collector); + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java new file mode 100644 index 00000000000..a200fc8d39c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/RootAllocatorTestRule.java @@ -0,0 +1,820 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.math.BigDecimal; +import java.util.Random; +import java.util.concurrent.TimeUnit; +import java.util.stream.IntStream; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.Decimal256Vector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FixedSizeBinaryVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.LargeVarBinaryVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMicroVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeNanoVector; +import org.apache.arrow.vector.TimeSecVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.TimeStampMicroVector; +import org.apache.arrow.vector.TimeStampMilliTZVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TimeStampNanoTZVector; +import org.apache.arrow.vector.TimeStampNanoVector; +import org.apache.arrow.vector.TimeStampSecTZVector; +import org.apache.arrow.vector.TimeStampSecVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.UInt1Vector; +import org.apache.arrow.vector.UInt2Vector; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.UInt8Vector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.complex.FixedSizeListVector; +import org.apache.arrow.vector.complex.LargeListVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.impl.UnionFixedSizeListWriter; +import 
org.apache.arrow.vector.complex.impl.UnionLargeListWriter; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +public class RootAllocatorTestRule implements TestRule, AutoCloseable { + + public static final byte MAX_VALUE = Byte.MAX_VALUE; + private final BufferAllocator rootAllocator = new RootAllocator(); + + private final Random random = new Random(10); + + @Override + public Statement apply(Statement base, Description description) { + return new Statement() { + @Override + public void evaluate() throws Throwable { + try { + base.evaluate(); + } finally { + close(); + } + } + }; + } + + public BufferAllocator getRootAllocator() { + return rootAllocator; + } + + @Override + public void close() throws Exception { + this.rootAllocator.getChildAllocators().forEach(BufferAllocator::close); + AutoCloseables.close(this.rootAllocator); + } + + /** + * Create a Float8Vector to be used in the accessor tests. + * + * @return Float8Vector + */ + public Float8Vector createFloat8Vector() { + double[] doubleVectorValues = new double[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE, + Float.MAX_VALUE, + -Float.MAX_VALUE, + Float.NEGATIVE_INFINITY, + Float.POSITIVE_INFINITY, + Float.MIN_VALUE, + -Float.MIN_VALUE, + Double.MAX_VALUE, + -Double.MAX_VALUE, + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + Double.MIN_VALUE, + -Double.MIN_VALUE, + }; + + Float8Vector result = new Float8Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < doubleVectorValues.length) { + result.setSafe(i, doubleVectorValues[i]); + } else { + result.setSafe(i, random.nextDouble()); + } + } + + return result; + } + + public Float8Vector createFloat8VectorForNullTests() { + final Float8Vector float8Vector = new Float8Vector("ID", this.getRootAllocator()); + float8Vector.allocateNew(1); + float8Vector.setNull(0); + float8Vector.setValueCount(1); + + return float8Vector; + } + + /** + * Create a Float4Vector to be used in the accessor tests. + * + * @return Float4Vector + */ + public Float4Vector createFloat4Vector() { + + float[] floatVectorValues = new float[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE, + Float.MAX_VALUE, + -Float.MAX_VALUE, + Float.NEGATIVE_INFINITY, + Float.POSITIVE_INFINITY, + Float.MIN_VALUE, + -Float.MIN_VALUE, + }; + + Float4Vector result = new Float4Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < floatVectorValues.length) { + result.setSafe(i, floatVectorValues[i]); + } else { + result.setSafe(i, random.nextFloat()); + } + } + + return result; + } + + /** + * Create a IntVector to be used in the accessor tests. 
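+   * <p>The leading slots hold boundary values; the remaining slots are filled with random ints.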
+ * + * @return IntVector + */ + public IntVector createIntVector() { + + int[] intVectorValues = new int[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + }; + + IntVector result = new IntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < intVectorValues.length) { + result.setSafe(i, intVectorValues[i]); + } else { + result.setSafe(i, random.nextInt()); + } + } + + return result; + } + + /** + * Create a SmallIntVector to be used in the accessor tests. + * + * @return SmallIntVector + */ + public SmallIntVector createSmallIntVector() { + + short[] smallIntVectorValues = new short[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + }; + + SmallIntVector result = new SmallIntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < smallIntVectorValues.length) { + result.setSafe(i, smallIntVectorValues[i]); + } else { + result.setSafe(i, random.nextInt(Short.MAX_VALUE)); + } + } + + return result; + } + + /** + * Create a TinyIntVector to be used in the accessor tests. + * + * @return TinyIntVector + */ + public TinyIntVector createTinyIntVector() { + + byte[] byteVectorValues = new byte[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + }; + + TinyIntVector result = new TinyIntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < byteVectorValues.length) { + result.setSafe(i, byteVectorValues[i]); + } else { + result.setSafe(i, random.nextInt(Byte.MAX_VALUE)); + } + } + + return result; + } + + /** + * Create a BigIntVector to be used in the accessor tests. + * + * @return BigIntVector + */ + public BigIntVector createBigIntVector() { + + long[] longVectorValues = new long[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE, + }; + + BigIntVector result = new BigIntVector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < longVectorValues.length) { + result.setSafe(i, longVectorValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + /** + * Create a UInt1Vector to be used in the accessor tests. + * + * @return UInt1Vector + */ + public UInt1Vector createUInt1Vector() { + + short[] uInt1VectorValues = new short[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + }; + + UInt1Vector result = new UInt1Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt1VectorValues.length) { + result.setSafe(i, uInt1VectorValues[i]); + } else { + result.setSafe(i, random.nextInt(0x100)); + } + } + + return result; + } + + /** + * Create a UInt2Vector to be used in the accessor tests. 
+ * + * @return UInt2Vector + */ + public UInt2Vector createUInt2Vector() { + + int[] uInt2VectorValues = new int[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + }; + + UInt2Vector result = new UInt2Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt2VectorValues.length) { + result.setSafe(i, uInt2VectorValues[i]); + } else { + result.setSafe(i, random.nextInt(0x10000)); + } + } + + return result; + } + + /** + * Create a UInt4Vector to be used in the accessor tests. + * + * @return UInt4Vector + */ + public UInt4Vector createUInt4Vector() { + + + int[] uInt4VectorValues = new int[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE + }; + + UInt4Vector result = new UInt4Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt4VectorValues.length) { + result.setSafe(i, uInt4VectorValues[i]); + } else { + result.setSafe(i, random.nextInt(Integer.MAX_VALUE)); + } + } + + return result; + } + + /** + * Create a UInt8Vector to be used in the accessor tests. + * + * @return UInt8Vector + */ + public UInt8Vector createUInt8Vector() { + + long[] uInt8VectorValues = new long[] { + 0, + 1, + -1, + Byte.MIN_VALUE, + Byte.MAX_VALUE, + Short.MIN_VALUE, + Short.MAX_VALUE, + Integer.MIN_VALUE, + Integer.MAX_VALUE, + Long.MIN_VALUE, + Long.MAX_VALUE + }; + + UInt8Vector result = new UInt8Vector("", this.getRootAllocator()); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < uInt8VectorValues.length) { + result.setSafe(i, uInt8VectorValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + /** + * Create a VarBinaryVector to be used in the accessor tests. + * + * @return VarBinaryVector + */ + public VarBinaryVector createVarBinaryVector() { + return createVarBinaryVector(""); + } + + /** + * Create a VarBinaryVector to be used in the accessor tests. + * + * @return VarBinaryVector + */ + public VarBinaryVector createVarBinaryVector(final String fieldName) { + VarBinaryVector valueVector = new VarBinaryVector(fieldName, this.getRootAllocator()); + valueVector.allocateNew(3); + valueVector.setSafe(0, (fieldName + "__BINARY_DATA_0001").getBytes()); + valueVector.setSafe(1, (fieldName + "__BINARY_DATA_0002").getBytes()); + valueVector.setSafe(2, (fieldName + "__BINARY_DATA_0003").getBytes()); + valueVector.setValueCount(3); + + return valueVector; + } + + /** + * Create a LargeVarBinaryVector to be used in the accessor tests. + * + * @return LargeVarBinaryVector + */ + public LargeVarBinaryVector createLargeVarBinaryVector() { + LargeVarBinaryVector valueVector = new LargeVarBinaryVector("", this.getRootAllocator()); + valueVector.allocateNew(3); + valueVector.setSafe(0, "BINARY_DATA_0001".getBytes()); + valueVector.setSafe(1, "BINARY_DATA_0002".getBytes()); + valueVector.setSafe(2, "BINARY_DATA_0003".getBytes()); + valueVector.setValueCount(3); + + return valueVector; + } + + /** + * Create a FixedSizeBinaryVector to be used in the accessor tests. 
+ * + * @return FixedSizeBinaryVector + */ + public FixedSizeBinaryVector createFixedSizeBinaryVector() { + FixedSizeBinaryVector valueVector = new FixedSizeBinaryVector("", this.getRootAllocator(), 16); + valueVector.allocateNew(3); + valueVector.setSafe(0, "BINARY_DATA_0001".getBytes()); + valueVector.setSafe(1, "BINARY_DATA_0002".getBytes()); + valueVector.setSafe(2, "BINARY_DATA_0003".getBytes()); + valueVector.setValueCount(3); + + return valueVector; + } + + /** + * Create a UInt8Vector to be used in the accessor tests. + * + * @return UInt8Vector + */ + public DecimalVector createDecimalVector() { + + BigDecimal[] bigDecimalValues = new BigDecimal[] { + new BigDecimal(0), + new BigDecimal(1), + new BigDecimal(-1), + new BigDecimal(Byte.MIN_VALUE), + new BigDecimal(Byte.MAX_VALUE), + new BigDecimal(-Short.MAX_VALUE), + new BigDecimal(Short.MIN_VALUE), + new BigDecimal(Integer.MIN_VALUE), + new BigDecimal(Integer.MAX_VALUE), + new BigDecimal(Long.MIN_VALUE), + new BigDecimal(-Long.MAX_VALUE), + new BigDecimal("170141183460469231731687303715884105727") + }; + + DecimalVector result = new DecimalVector("ID", this.getRootAllocator(), 39, 0); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < bigDecimalValues.length) { + result.setSafe(i, bigDecimalValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + /** + * Create a UInt8Vector to be used in the accessor tests. + * + * @return UInt8Vector + */ + public Decimal256Vector createDecimal256Vector() { + + BigDecimal[] bigDecimalValues = new BigDecimal[] { + new BigDecimal(0), + new BigDecimal(1), + new BigDecimal(-1), + new BigDecimal(Byte.MIN_VALUE), + new BigDecimal(Byte.MAX_VALUE), + new BigDecimal(-Short.MAX_VALUE), + new BigDecimal(Short.MIN_VALUE), + new BigDecimal(Integer.MIN_VALUE), + new BigDecimal(Integer.MAX_VALUE), + new BigDecimal(Long.MIN_VALUE), + new BigDecimal(-Long.MAX_VALUE), + new BigDecimal("170141183460469231731687303715884105727"), + new BigDecimal("17014118346046923173168234157303715884105727"), + new BigDecimal("1701411834604692317316823415265417303715884105727"), + new BigDecimal("-17014118346046923173168234152654115451237303715884105727"), + new BigDecimal("-17014118346046923173168234152654115451231545157303715884105727"), + new BigDecimal("1701411834604692315815656534152654115451231545157303715884105727"), + new BigDecimal("30560141183460469231581565634152654115451231545157303715884105727"), + new BigDecimal( + "57896044618658097711785492504343953926634992332820282019728792003956564819967"), + new BigDecimal( + "-56896044618658097711785492504343953926634992332820282019728792003956564819967") + }; + + Decimal256Vector result = new Decimal256Vector("ID", this.getRootAllocator(), 77, 0); + result.setValueCount(MAX_VALUE); + for (int i = 0; i < MAX_VALUE; i++) { + if (i < bigDecimalValues.length) { + result.setSafe(i, bigDecimalValues[i]); + } else { + result.setSafe(i, random.nextLong()); + } + } + + return result; + } + + public TimeStampNanoVector createTimeStampNanoVector() { + TimeStampNanoVector valueVector = new TimeStampNanoVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toNanos(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toNanos(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampNanoTZVector createTimeStampNanoTZVector(String timeZone) { + TimeStampNanoTZVector valueVector = + new 
TimeStampNanoTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toNanos(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toNanos(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMicroVector createTimeStampMicroVector() { + TimeStampMicroVector valueVector = new TimeStampMicroVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toMicros(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toMicros(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMicroTZVector createTimeStampMicroTZVector(String timeZone) { + TimeStampMicroTZVector valueVector = + new TimeStampMicroTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toMicros(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toMicros(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMilliVector createTimeStampMilliVector() { + TimeStampMilliVector valueVector = new TimeStampMilliVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, 1625702400000L); + valueVector.setSafe(1, 1625788800000L); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampMilliTZVector createTimeStampMilliTZVector(String timeZone) { + TimeStampMilliTZVector valueVector = + new TimeStampMilliTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, 1625702400000L); + valueVector.setSafe(1, 1625788800000L); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampSecVector createTimeStampSecVector() { + TimeStampSecVector valueVector = new TimeStampSecVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toSeconds(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toSeconds(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public TimeStampSecTZVector createTimeStampSecTZVector(String timeZone) { + TimeStampSecTZVector valueVector = + new TimeStampSecTZVector("", this.getRootAllocator(), timeZone); + valueVector.allocateNew(2); + valueVector.setSafe(0, TimeUnit.MILLISECONDS.toSeconds(1625702400000L)); + valueVector.setSafe(1, TimeUnit.MILLISECONDS.toSeconds(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public BitVector createBitVector() { + BitVector valueVector = new BitVector("Value", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1); + valueVector.setValueCount(2); + + return valueVector; + } + + public BitVector createBitVectorForNullTests() { + final BitVector bitVector = new BitVector("ID", this.getRootAllocator()); + bitVector.allocateNew(2); + bitVector.setNull(0); + bitVector.setValueCount(1); + + return bitVector; + } + + public TimeNanoVector createTimeNanoVector() { + TimeNanoVector valueVector = new TimeNanoVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1_000_000_000L); // 1 second + valueVector.setSafe(2, 60 * 1_000_000_000L); // 1 minute + valueVector.setSafe(3, 60 * 60 * 1_000_000_000L); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1) * 1_000_000_000L); // 
23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public TimeMicroVector createTimeMicroVector() { + TimeMicroVector valueVector = new TimeMicroVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1_000_000L); // 1 second + valueVector.setSafe(2, 60 * 1_000_000L); // 1 minute + valueVector.setSafe(3, 60 * 60 * 1_000_000L); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1) * 1_000_000L); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public TimeMilliVector createTimeMilliVector() { + TimeMilliVector valueVector = new TimeMilliVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1_000); // 1 second + valueVector.setSafe(2, 60 * 1_000); // 1 minute + valueVector.setSafe(3, 60 * 60 * 1_000); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1) * 1_000); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public TimeSecVector createTimeSecVector() { + TimeSecVector valueVector = new TimeSecVector("", this.getRootAllocator()); + valueVector.allocateNew(5); + valueVector.setSafe(0, 0); + valueVector.setSafe(1, 1); // 1 second + valueVector.setSafe(2, 60); // 1 minute + valueVector.setSafe(3, 60 * 60); // 1 hour + valueVector.setSafe(4, (24 * 60 * 60 - 1)); // 23:59:59 + valueVector.setValueCount(5); + + return valueVector; + } + + public DateDayVector createDateDayVector() { + DateDayVector valueVector = new DateDayVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, (int) TimeUnit.MILLISECONDS.toDays(1625702400000L)); + valueVector.setSafe(1, (int) TimeUnit.MILLISECONDS.toDays(1625788800000L)); + valueVector.setValueCount(2); + + return valueVector; + } + + public DateMilliVector createDateMilliVector() { + DateMilliVector valueVector = new DateMilliVector("", this.getRootAllocator()); + valueVector.allocateNew(2); + valueVector.setSafe(0, 1625702400000L); + valueVector.setSafe(1, 1625788800000L); + valueVector.setValueCount(2); + + return valueVector; + } + + public ListVector createListVector() { + return createListVector(""); + } + + public ListVector createListVector(String fieldName) { + ListVector valueVector = ListVector.empty(fieldName, this.getRootAllocator()); + valueVector.setInitialCapacity(MAX_VALUE); + + UnionListWriter writer = valueVector.getWriter(); + + IntStream range = IntStream.range(0, MAX_VALUE); + + range.forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + valueVector.setValueCount(MAX_VALUE); + + return valueVector; + } + + public LargeListVector createLargeListVector() { + LargeListVector valueVector = LargeListVector.empty("", this.getRootAllocator()); + valueVector.setInitialCapacity(MAX_VALUE); + + UnionLargeListWriter writer = valueVector.getWriter(); + + IntStream range = IntStream.range(0, MAX_VALUE); + + range.forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + valueVector.setValueCount(MAX_VALUE); + + return valueVector; + } + + public FixedSizeListVector createFixedSizeListVector() { + FixedSizeListVector valueVector = FixedSizeListVector.empty("", 5, this.getRootAllocator()); + valueVector.setInitialCapacity(MAX_VALUE); + + 
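+    // Each of the MAX_VALUE rows is written as a fixed-size list of five ints.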
UnionFixedSizeListWriter writer = valueVector.getWriter(); + + IntStream range = IntStream.range(0, MAX_VALUE); + + range.forEach(row -> { + writer.startList(); + writer.setPosition(row); + IntStream.range(0, 5).map(j -> j * row).forEach(writer::writeInt); + writer.setValueCount(5); + writer.endList(); + }); + + valueVector.setValueCount(MAX_VALUE); + + return valueVector; + } +} diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java new file mode 100644 index 00000000000..5c7c873e55c --- /dev/null +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/SqlTypesTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.driver.jdbc.utils; + +import static org.apache.arrow.driver.jdbc.utils.SqlTypes.getSqlTypeIdFromArrowType; +import static org.apache.arrow.driver.jdbc.utils.SqlTypes.getSqlTypeNameFromArrowType; +import static org.junit.Assert.assertEquals; + +import java.sql.Types; + +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.IntervalUnit; +import org.apache.arrow.vector.types.TimeUnit; +import org.apache.arrow.vector.types.UnionMode; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.junit.Test; + +public class SqlTypesTest { + + @Test + public void testGetSqlTypeIdFromArrowType() { + assertEquals(Types.TINYINT, getSqlTypeIdFromArrowType(new ArrowType.Int(8, true))); + assertEquals(Types.SMALLINT, getSqlTypeIdFromArrowType(new ArrowType.Int(16, true))); + assertEquals(Types.INTEGER, getSqlTypeIdFromArrowType(new ArrowType.Int(32, true))); + assertEquals(Types.BIGINT, getSqlTypeIdFromArrowType(new ArrowType.Int(64, true))); + + assertEquals(Types.BINARY, getSqlTypeIdFromArrowType(new ArrowType.FixedSizeBinary(1024))); + assertEquals(Types.VARBINARY, getSqlTypeIdFromArrowType(new ArrowType.Binary())); + assertEquals(Types.LONGVARBINARY, getSqlTypeIdFromArrowType(new ArrowType.LargeBinary())); + + assertEquals(Types.VARCHAR, getSqlTypeIdFromArrowType(new ArrowType.Utf8())); + assertEquals(Types.LONGVARCHAR, getSqlTypeIdFromArrowType(new ArrowType.LargeUtf8())); + + assertEquals(Types.DATE, getSqlTypeIdFromArrowType(new ArrowType.Date(DateUnit.MILLISECOND))); + assertEquals(Types.TIME, + getSqlTypeIdFromArrowType(new ArrowType.Time(TimeUnit.MILLISECOND, 32))); + assertEquals(Types.TIMESTAMP, + getSqlTypeIdFromArrowType(new ArrowType.Timestamp(TimeUnit.MILLISECOND, ""))); + + assertEquals(Types.BOOLEAN, getSqlTypeIdFromArrowType(new ArrowType.Bool())); + + assertEquals(Types.DECIMAL, 
getSqlTypeIdFromArrowType(new ArrowType.Decimal(0, 0, 64))); + assertEquals(Types.DOUBLE, + getSqlTypeIdFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))); + assertEquals(Types.FLOAT, + getSqlTypeIdFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); + + assertEquals(Types.ARRAY, getSqlTypeIdFromArrowType(new ArrowType.List())); + assertEquals(Types.ARRAY, getSqlTypeIdFromArrowType(new ArrowType.LargeList())); + assertEquals(Types.ARRAY, getSqlTypeIdFromArrowType(new ArrowType.FixedSizeList(10))); + + assertEquals(Types.JAVA_OBJECT, getSqlTypeIdFromArrowType(new ArrowType.Struct())); + assertEquals(Types.JAVA_OBJECT, + getSqlTypeIdFromArrowType(new ArrowType.Duration(TimeUnit.MILLISECOND))); + assertEquals(Types.JAVA_OBJECT, + getSqlTypeIdFromArrowType(new ArrowType.Interval(IntervalUnit.DAY_TIME))); + assertEquals(Types.JAVA_OBJECT, + getSqlTypeIdFromArrowType(new ArrowType.Union(UnionMode.Dense, null))); + assertEquals(Types.JAVA_OBJECT, getSqlTypeIdFromArrowType(new ArrowType.Map(true))); + + assertEquals(Types.NULL, getSqlTypeIdFromArrowType(new ArrowType.Null())); + } + + @Test + public void testGetSqlTypeNameFromArrowType() { + assertEquals("TINYINT", getSqlTypeNameFromArrowType(new ArrowType.Int(8, true))); + assertEquals("SMALLINT", getSqlTypeNameFromArrowType(new ArrowType.Int(16, true))); + assertEquals("INTEGER", getSqlTypeNameFromArrowType(new ArrowType.Int(32, true))); + assertEquals("BIGINT", getSqlTypeNameFromArrowType(new ArrowType.Int(64, true))); + + assertEquals("BINARY", getSqlTypeNameFromArrowType(new ArrowType.FixedSizeBinary(1024))); + assertEquals("VARBINARY", getSqlTypeNameFromArrowType(new ArrowType.Binary())); + assertEquals("LONGVARBINARY", getSqlTypeNameFromArrowType(new ArrowType.LargeBinary())); + + assertEquals("VARCHAR", getSqlTypeNameFromArrowType(new ArrowType.Utf8())); + assertEquals("LONGVARCHAR", getSqlTypeNameFromArrowType(new ArrowType.LargeUtf8())); + + assertEquals("DATE", getSqlTypeNameFromArrowType(new ArrowType.Date(DateUnit.MILLISECOND))); + assertEquals("TIME", getSqlTypeNameFromArrowType(new ArrowType.Time(TimeUnit.MILLISECOND, 32))); + assertEquals("TIMESTAMP", + getSqlTypeNameFromArrowType(new ArrowType.Timestamp(TimeUnit.MILLISECOND, ""))); + + assertEquals("BOOLEAN", getSqlTypeNameFromArrowType(new ArrowType.Bool())); + + assertEquals("DECIMAL", getSqlTypeNameFromArrowType(new ArrowType.Decimal(0, 0, 64))); + assertEquals("DOUBLE", + getSqlTypeNameFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE))); + assertEquals("FLOAT", + getSqlTypeNameFromArrowType(new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE))); + + assertEquals("ARRAY", getSqlTypeNameFromArrowType(new ArrowType.List())); + assertEquals("ARRAY", getSqlTypeNameFromArrowType(new ArrowType.LargeList())); + assertEquals("ARRAY", getSqlTypeNameFromArrowType(new ArrowType.FixedSizeList(10))); + + assertEquals("JAVA_OBJECT", getSqlTypeNameFromArrowType(new ArrowType.Struct())); + + assertEquals("JAVA_OBJECT", + getSqlTypeNameFromArrowType(new ArrowType.Duration(TimeUnit.MILLISECOND))); + assertEquals("JAVA_OBJECT", + getSqlTypeNameFromArrowType(new ArrowType.Interval(IntervalUnit.DAY_TIME))); + assertEquals("JAVA_OBJECT", + getSqlTypeNameFromArrowType(new ArrowType.Union(UnionMode.Dense, null))); + assertEquals("JAVA_OBJECT", getSqlTypeNameFromArrowType(new ArrowType.Map(true))); + + assertEquals("NULL", getSqlTypeNameFromArrowType(new ArrowType.Null())); + } +} diff --git 
a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java
new file mode 100644
index 00000000000..f1bd44539ac
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/ThrowableAssertionUtils.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+/**
+ * Utility class to avoid upgrading JUnit to version >= 4.13 and keep using code to assert a {@link Throwable}.
+ * This should be removed as soon as we can use the proper assertThrows/checkThrows.
+ */
+public class ThrowableAssertionUtils {
+  private ThrowableAssertionUtils() {
+  }
+
+  public static <T extends Throwable> void simpleAssertThrowableClass(
+      final Class<T> expectedThrowable, final ThrowingRunnable runnable) {
+    try {
+      runnable.run();
+    } catch (Throwable actualThrown) {
+      if (expectedThrowable.isInstance(actualThrown)) {
+        return;
+      } else {
+        final String mismatchMessage = String.format("unexpected exception type thrown;\nexpected: %s\nactual: %s",
+            formatClass(expectedThrowable),
+            formatClass(actualThrown.getClass()));
+
+        throw new AssertionError(mismatchMessage, actualThrown);
+      }
+    }
+    final String notThrownMessage = String.format("expected %s to be thrown, but nothing was thrown",
+        formatClass(expectedThrowable));
+    throw new AssertionError(notThrownMessage);
+  }
+
+  private static String formatClass(final Class<?> value) {
+    // Fallback for anonymous inner classes
+    final String className = value.getCanonicalName();
+    return className == null ? value.getName() : className;
+  }
+
+  public interface ThrowingRunnable {
+    void run() throws Throwable;
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/UrlParserTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/UrlParserTest.java
new file mode 100644
index 00000000000..4e764ab322c
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/UrlParserTest.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.driver.jdbc.utils;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+import java.util.Map;
+
+import org.junit.jupiter.api.Test;
+
+class UrlParserTest {
+  @Test
+  void parse() {
+    final Map<String, String> parsed = UrlParser.parse("foo=bar&123=456", "&");
+    assertEquals("bar", parsed.get("foo"));
+    assertEquals("456", parsed.get("123"));
+  }
+
+  @Test
+  void parseEscaped() {
+    final Map<String, String> parsed = UrlParser.parse("foo=bar%26&%26123=456", "&");
+    assertEquals("bar&", parsed.get("foo"));
+    assertEquals("456", parsed.get("&123"));
+  }
+
+  @Test
+  void parseEmpty() {
+    final Map<String, String> parsed = UrlParser.parse("a=&b&foo=bar&123=456", "&");
+    assertEquals("", parsed.get("a"));
+    assertNull(parsed.get("b"));
+    assertEquals("bar", parsed.get("foo"));
+    assertEquals("456", parsed.get("123"));
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformerTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformerTest.java
new file mode 100644
index 00000000000..1804b42cecb
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/utils/VectorSchemaRootTransformerTest.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.arrow.driver.jdbc.utils; + +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +import com.google.common.collect.ImmutableList; + +public class VectorSchemaRootTransformerTest { + + @Rule + public RootAllocatorTestRule rootAllocatorTestRule = new RootAllocatorTestRule(); + private final BufferAllocator rootAllocator = rootAllocatorTestRule.getRootAllocator(); + + @Test + public void testTransformerBuilderWorksCorrectly() throws Exception { + final VarBinaryVector field1 = rootAllocatorTestRule.createVarBinaryVector("FIELD_1"); + final VarBinaryVector field2 = rootAllocatorTestRule.createVarBinaryVector("FIELD_2"); + final VarBinaryVector field3 = rootAllocatorTestRule.createVarBinaryVector("FIELD_3"); + + try (final VectorSchemaRoot originalRoot = VectorSchemaRoot.of(field1, field2, field3); + final VectorSchemaRoot clonedRoot = cloneVectorSchemaRoot(originalRoot)) { + + final VectorSchemaRootTransformer.Builder builder = + new VectorSchemaRootTransformer.Builder(originalRoot.getSchema(), + rootAllocator); + + builder.renameFieldVector("FIELD_3", "FIELD_3_RENAMED"); + builder.addEmptyField("EMPTY_FIELD", new ArrowType.Bool()); + builder.renameFieldVector("FIELD_2", "FIELD_2_RENAMED"); + builder.renameFieldVector("FIELD_1", "FIELD_1_RENAMED"); + + final VectorSchemaRootTransformer transformer = builder.build(); + + final Schema transformedSchema = new Schema(ImmutableList.of( + Field.nullable("FIELD_3_RENAMED", new ArrowType.Binary()), + Field.nullable("EMPTY_FIELD", new ArrowType.Bool()), + Field.nullable("FIELD_2_RENAMED", new ArrowType.Binary()), + Field.nullable("FIELD_1_RENAMED", new ArrowType.Binary()) + )); + try (final VectorSchemaRoot transformedRoot = createVectorSchemaRoot(transformedSchema)) { + Assert.assertSame(transformedRoot, transformer.transform(clonedRoot, transformedRoot)); + Assert.assertEquals(transformedSchema, transformedRoot.getSchema()); + + final int rowCount = originalRoot.getRowCount(); + Assert.assertEquals(rowCount, transformedRoot.getRowCount()); + + final VarBinaryVector originalField1 = + (VarBinaryVector) originalRoot.getVector("FIELD_1"); + final VarBinaryVector originalField2 = + (VarBinaryVector) originalRoot.getVector("FIELD_2"); + final VarBinaryVector originalField3 = + (VarBinaryVector) originalRoot.getVector("FIELD_3"); + + final VarBinaryVector transformedField1 = + (VarBinaryVector) transformedRoot.getVector("FIELD_1_RENAMED"); + final VarBinaryVector transformedField2 = + (VarBinaryVector) transformedRoot.getVector("FIELD_2_RENAMED"); + final VarBinaryVector transformedField3 = + (VarBinaryVector) transformedRoot.getVector("FIELD_3_RENAMED"); + final FieldVector emptyField = transformedRoot.getVector("EMPTY_FIELD"); + + for (int i = 0; i < rowCount; i++) { + Assert.assertArrayEquals(originalField1.getObject(i), transformedField1.getObject(i)); + Assert.assertArrayEquals(originalField2.getObject(i), transformedField2.getObject(i)); + 
Assert.assertArrayEquals(originalField3.getObject(i), transformedField3.getObject(i));
+          Assert.assertNull(emptyField.getObject(i));
+        }
+      }
+    }
+  }
+
+  private VectorSchemaRoot cloneVectorSchemaRoot(final VectorSchemaRoot originalRoot) {
+    final VectorUnloader vectorUnloader = new VectorUnloader(originalRoot);
+    try (final ArrowRecordBatch recordBatch = vectorUnloader.getRecordBatch()) {
+      final VectorSchemaRoot clonedRoot = createVectorSchemaRoot(originalRoot.getSchema());
+      final VectorLoader vectorLoader = new VectorLoader(clonedRoot);
+      vectorLoader.load(recordBatch);
+      return clonedRoot;
+    }
+  }
+
+  private VectorSchemaRoot createVectorSchemaRoot(final Schema schema) {
+    final List<FieldVector> fieldVectors = schema.getFields().stream()
+        .map(field -> field.createVector(rootAllocator))
+        .collect(Collectors.toList());
+    return new VectorSchemaRoot(fieldVectors);
+  }
+}
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/keyStore.jks b/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/keyStore.jks
new file mode 100644
index 0000000000000000000000000000000000000000..32a9bedea500a03beaea22af78c5a484f52e504c
GIT binary patch
literal 1537
[base85-encoded binary patch data omitted]

literal 0
HcmV?d00001

diff --git a/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/noCertificate.jks b/java/flight/flight-sql-jdbc-driver/src/test/resources/keys/noCertificate.jks
new file mode 100644
index 0000000000000000000000000000000000000000..071a1ebf97b3e637be5e53a1b3a70ce1b0df04b8
GIT binary patch
literal 2545
[base85-encoded binary patch data omitted]

literal 0
HcmV?d00001

diff --git a/java/flight/flight-sql-jdbc-driver/src/test/resources/logback.xml b/java/flight/flight-sql-jdbc-driver/src/test/resources/logback.xml
new file mode 100644
--- /dev/null
+++ b/java/flight/flight-sql-jdbc-driver/src/test/resources/logback.xml
@@ -0,0 +1,14 @@
+<?xml version="1.0" encoding="UTF-8" ?>
+<configuration>
+
+  <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
+    <encoder>
+      <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
+    </encoder>
+  </appender>
+
+  <root level="INFO">
+    <appender-ref ref="STDOUT"/>
+  </root>
+
+</configuration>
diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
index 06b3c9dbe20..1d7305fcf2f 100644
--- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
+++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java
@@ -260,6 +260,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception {
         new Field("ID", new FieldType(false, MinorType.INT.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("INTEGER")
                 .schemaName("APP")
                 .tableName("FOREIGNTABLE")
                 .precision(10)
@@ -269,6 +270,7 @@
         new Field("FOREIGNNAME", new FieldType(true, MinorType.VARCHAR.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("VARCHAR")
                 .schemaName("APP")
                 .tableName("FOREIGNTABLE")
                 .precision(100)
@@ -278,6 +280,7 @@
         new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("INTEGER")
                 .schemaName("APP")
                 .tableName("FOREIGNTABLE")
                 .precision(10)
@@ -293,6 +296,7 @@
         new Field("ID", new FieldType(false, MinorType.INT.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("INTEGER")
                 .schemaName("APP")
                 .tableName("INTTABLE")
                 .precision(10)
@@ -302,6 +306,7 @@
         new Field("KEYNAME", new FieldType(true, MinorType.VARCHAR.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("VARCHAR")
                 .schemaName("APP")
                 .tableName("INTTABLE")
                 .precision(100)
@@ -311,6 +316,7 @@ public void testGetTablesResultFilteredWithSchema()
throws Exception {
         new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("INTEGER")
                 .schemaName("APP")
                 .tableName("INTTABLE")
                 .precision(10)
@@ -320,6 +326,7 @@ public void testGetTablesResultFilteredWithSchema() throws Exception {
         new Field("FOREIGNID", new FieldType(true, MinorType.INT.getType(), null,
             new FlightSqlColumnMetadata.Builder()
                 .catalogName("")
+                .typeName("INTEGER")
                 .schemaName("APP")
                 .tableName("INTTABLE")
                 .precision(10)
diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java
index baf162cb919..d66b8df9283 100644
--- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java
+++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java
@@ -576,6 +576,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet
       final String catalogName = columnsData.getString("TABLE_CAT");
       final String schemaName = columnsData.getString("TABLE_SCHEM");
       final String tableName = columnsData.getString("TABLE_NAME");
+      final String typeName = columnsData.getString("TYPE_NAME");
       final String fieldName = columnsData.getString("COLUMN_NAME");
       final int dataType = columnsData.getInt("DATA_TYPE");
       final boolean isNullable = columnsData.getInt("NULLABLE") != DatabaseMetaData.columnNoNulls;
@@ -590,6 +591,7 @@ private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMet
               .catalogName(catalogName)
               .schemaName(schemaName)
               .tableName(tableName)
+              .typeName(typeName)
               .precision(precision)
               .scale(scale)
               .isAutoIncrement(isAutoIncrement)
diff --git a/java/flight/pom.xml b/java/flight/pom.xml
index dad0f05d7af..d8b02bee7ab 100644
--- a/java/flight/pom.xml
+++ b/java/flight/pom.xml
@@ -28,6 +28,7 @@
     <module>flight-core</module>
     <module>flight-grpc</module>
    <module>flight-sql</module>
+    <module>flight-sql-jdbc-driver</module>
     <module>flight-integration-tests</module>
diff --git a/java/pom.xml b/java/pom.xml
index 8abe7ad9dca..486765df2d3 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -150,6 +150,7 @@
             <exclude>**/client/build/**</exclude>
             <exclude>**/*.tbl</exclude>
             <exclude>**/*.iml</exclude>
+            <exclude>**/flight.properties</exclude>

From 72b539f54233f6610b01ec7381755a84c652d151 Mon Sep 17 00:00:00 2001
From: Vibhatha Lakmal Abeykoon
Date: Thu, 15 Sep 2022 01:15:58 +0530
Subject: [PATCH 063/133] ARROW-17521: [Python] Add python bindings for
 NamedTableProvider for Substrait consumer (#14024)

This PR includes a basic version to use the NamedTable feature in Substrait.
The idea is to provide the flexibility to write Python tests with in-memory
PyArrow tables.
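
As a rough usage sketch (the table name "t" and the already-serialized plan
buffer are illustrative, not part of this patch), a NamedTable relation can be
resolved from an in-memory table like so:

    import pyarrow as pa
    import pyarrow.substrait as substrait

    test_table = pa.Table.from_pydict({"x": [1, 2, 3]})

    def table_provider(names):
        # Called once per NamedTable relation with its list of names.
        if names == ["t"]:
            return test_table
        raise Exception("Unrecognized table: " + ".".join(names))

    # `plan` is assumed to be a Buffer holding a serialized Substrait plan
    # whose read relation names the table ["t"].
    reader = substrait.run_query(plan, table_provider=table_provider)
    result = reader.read_all()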
Authored-by: Vibhatha Abeykoon
Signed-off-by: Weston Pace
---
 cpp/src/arrow/compute/exec/exec_plan.cc       |   4 +
 cpp/src/arrow/compute/exec/exec_plan.h        |   5 +
 .../arrow/engine/substrait/function_test.cc   |   4 +-
 .../engine/substrait/relation_internal.cc     |   6 +
 cpp/src/arrow/engine/substrait/serde_test.cc  |  13 +-
 cpp/src/arrow/engine/substrait/util.cc        |  23 ++--
 cpp/src/arrow/engine/substrait/util.h         |   8 +-
 python/pyarrow/_exec_plan.pyx                 |   2 +-
 python/pyarrow/_substrait.pyx                 |  97 +++++++++++++-
 python/pyarrow/includes/libarrow.pxd          |   1 +
 .../pyarrow/includes/libarrow_substrait.pxd   |  28 +++-
 python/pyarrow/tests/test_substrait.py        | 126 ++++++++++++++++++
 12 files changed, 289 insertions(+), 28 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc
index b6a3916de1f..00415495aa8 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.cc
+++ b/cpp/src/arrow/compute/exec/exec_plan.cc
@@ -643,6 +643,10 @@ Declaration Declaration::Sequence(std::vector<Declaration> decls) {
   return out;
 }
 
+bool Declaration::IsValid(ExecFactoryRegistry* registry) const {
+  return !this->factory_name.empty() && this->options != nullptr;
+}
+
 namespace internal {
 
 void RegisterSourceNode(ExecFactoryRegistry*);
diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h
index 263f3634a5a..a9481e21a6e 100644
--- a/cpp/src/arrow/compute/exec/exec_plan.h
+++ b/cpp/src/arrow/compute/exec/exec_plan.h
@@ -451,6 +451,8 @@ inline Result<ExecNode*> MakeExecNode(
 struct ARROW_EXPORT Declaration {
   using Input = util::Variant<ExecNode*, Declaration>;
 
+  Declaration() {}
+
   Declaration(std::string factory_name, std::vector<Input> inputs,
               std::shared_ptr<ExecNodeOptions> options, std::string label)
       : factory_name{std::move(factory_name)},
@@ -514,6 +516,9 @@ struct ARROW_EXPORT Declaration {
   Result<ExecNode*> AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry =
                                                   default_exec_factory_registry()) const;
 
+  // Check that a declaration is minimally valid (has a factory name and options)
+  bool IsValid(ExecFactoryRegistry* registry = default_exec_factory_registry()) const;
+
   std::string factory_name;
   std::vector<Input> inputs;
   std::shared_ptr<ExecNodeOptions> options;
diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc
index 0bcb475d310..3465f00e132 100644
--- a/cpp/src/arrow/engine/substrait/function_test.cc
+++ b/cpp/src/arrow/engine/substrait/function_test.cc
@@ -132,8 +132,8 @@ void CheckValidTestCases(const std::vector<FunctionTestCase>& valid_cases) {
     ASSERT_FINISHES_OK(plan->finished());
 
     // Could also modify the Substrait plan with an emit to drop the leading columns
-    ASSERT_OK_AND_ASSIGN(output_table,
-                         output_table->SelectColumns({output_table->num_columns() - 1}));
+    int result_column = output_table->num_columns() - 1;  // last column holds result
+    ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns({result_column}));
 
     ASSERT_OK_AND_ASSIGN(
         std::shared_ptr<Table> expected_output,
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index 4213895b616..3911373b7b7 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -135,8 +135,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
       const substrait::ReadRel::NamedTable& named_table = read.named_table();
       std::vector<std::string> table_names(named_table.names().begin(),
                                            named_table.names().end());
+      if (table_names.empty()) {
+        return Status::Invalid("names for NamedTable not provided");
+      }
       ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
                             named_table_provider(table_names));
+      if
(!source_decl.IsValid()) { + return Status::Invalid("Invalid NamedTable Source"); + } return ProcessEmit(std::move(read), DeclarationInfo{std::move(source_decl), base_schema}, std::move(base_schema)); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 251c2bfe352..b50e1c6084c 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1924,7 +1924,6 @@ TEST(Substrait, BasicPlanRoundTripping) { ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait-tempdir-")); - std::cout << "file_path_str " << tempdir->path().ToString() << std::endl; ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); @@ -2189,7 +2188,7 @@ TEST(Substrait, ProjectRel) { } }, "namedTable": { - "names": [] + "names": ["A"] } } } @@ -2313,7 +2312,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { } }, "namedTable": { - "names": [] + "names": ["A"] } } } @@ -2396,7 +2395,7 @@ TEST(Substrait, ReadRelWithEmit) { } }, "namedTable": { - "names" : [] + "names" : ["A"] } } } @@ -2501,7 +2500,7 @@ TEST(Substrait, FilterRelWithEmit) { } }, "namedTable": { - "names" : [] + "names" : ["A"] } } } @@ -2885,7 +2884,7 @@ TEST(Substrait, AggregateRel) { } }, "namedTable" : { - "names": [] + "names": ["A"] } } }, @@ -3004,7 +3003,7 @@ TEST(Substrait, AggregateRelEmit) { } }, "namedTable" : { - "names" : [] + "names" : ["A"] } } }, diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 936bde5c652..f51666ef858 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -63,8 +63,12 @@ class SubstraitSinkConsumer : public compute::SinkNodeConsumer { class SubstraitExecutor { public: explicit SubstraitExecutor(std::shared_ptr plan, - compute::ExecContext exec_context) - : plan_(std::move(plan)), plan_started_(false), exec_context_(exec_context) {} + compute::ExecContext exec_context, + const ConversionOptions& conversion_options = {}) + : plan_(std::move(plan)), + plan_started_(false), + exec_context_(exec_context), + conversion_options_(conversion_options) {} ~SubstraitExecutor() { ARROW_UNUSED(this->Close()); } @@ -95,8 +99,8 @@ class SubstraitExecutor { return sink_consumer_; }; ARROW_ASSIGN_OR_RAISE( - declarations_, - engine::DeserializePlans(substrait_buffer, consumer_factory, registry)); + declarations_, engine::DeserializePlans(substrait_buffer, consumer_factory, + registry, nullptr, conversion_options_)); return Status::OK(); } @@ -107,19 +111,20 @@ class SubstraitExecutor { bool plan_started_; compute::ExecContext exec_context_; std::shared_ptr sink_consumer_; + const ConversionOptions& conversion_options_; }; } // namespace Result> ExecuteSerializedPlan( - const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry, - compute::FunctionRegistry* func_registry) { - // TODO(ARROW-15732) + const Buffer& substrait_buffer, const ExtensionIdRegistry* registry, + compute::FunctionRegistry* func_registry, + const ConversionOptions& conversion_options) { compute::ExecContext exec_context(arrow::default_memory_pool(), ::arrow::internal::GetCpuThreadPool(), func_registry); ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); - SubstraitExecutor executor(std::move(plan), exec_context); - RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry)); + SubstraitExecutor executor(std::move(plan), exec_context, 
conversion_options);
+  RETURN_NOT_OK(executor.Init(substrait_buffer, registry));
   ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute());
   // check closing here, not in destructor, to expose error to caller
   RETURN_NOT_OK(executor.Close());
diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h
index 3ac9320e1da..90cb4e3dd2a 100644
--- a/cpp/src/arrow/engine/substrait/util.h
+++ b/cpp/src/arrow/engine/substrait/util.h
@@ -20,6 +20,7 @@
 #include <memory>
 #include "arrow/compute/registry.h"
 #include "arrow/engine/substrait/api.h"
+#include "arrow/engine/substrait/options.h"
 #include "arrow/util/iterator.h"
 #include "arrow/util/optional.h"
@@ -27,10 +28,13 @@
 namespace arrow {
 namespace engine {

-/// \brief Retrieve a RecordBatchReader from a Substrait plan.
+using PythonTableProvider =
+    std::function<Result<std::shared_ptr<Table>>(const std::vector<std::string>&)>;
+
 ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
     const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR,
-    compute::FunctionRegistry* func_registry = NULLPTR);
+    compute::FunctionRegistry* func_registry = NULLPTR,
+    const ConversionOptions& conversion_options = {});

 /// \brief Get a Serialized Plan from a Substrait JSON plan.
 /// This is a helper method for Python tests.
diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx
index 89e474f4390..9506caf7d28 100644
--- a/python/pyarrow/_exec_plan.pyx
+++ b/python/pyarrow/_exec_plan.pyx
@@ -92,7 +92,7 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads
             node_factory = "table_source"
             c_in_table = pyarrow_unwrap_table(ipt)
             c_tablesourceopts = make_shared[CTableSourceNodeOptions](
-                c_in_table, 1 << 20)
+                c_in_table)
             c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
                 c_tablesourceopts)
         elif isinstance(ipt, Dataset):
diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx
index 05794a95a20..47a519cf16b 100644
--- a/python/pyarrow/_substrait.pyx
+++ b/python/pyarrow/_substrait.pyx
@@ -17,15 +17,38 @@
 # cython: language_level = 3

 from cython.operator cimport dereference as deref
+from libcpp.vector cimport vector as std_vector

 from pyarrow import Buffer
-from pyarrow.lib import frombytes
+from pyarrow.lib import frombytes, tobytes
 from pyarrow.lib cimport *
 from pyarrow.includes.libarrow cimport *
 from pyarrow.includes.libarrow_substrait cimport *


-def run_query(plan):
+cdef CDeclaration _create_named_table_provider(dict named_args, const std_vector[c_string]& names):
+    cdef:
+        c_string c_name
+        shared_ptr[CTable] c_in_table
+        shared_ptr[CTableSourceNodeOptions] c_tablesourceopts
+        shared_ptr[CExecNodeOptions] c_input_node_opts
+        vector[CDeclaration.Input] no_c_inputs
+
+    py_names = []
+    for i in range(names.size()):
+        c_name = names[i]
+        py_names.append(frombytes(c_name))
+
+    py_table = named_args["provider"](py_names)
+    c_in_table = pyarrow_unwrap_table(py_table)
+    c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table)
+    c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
+        c_tablesourceopts)
+    return CDeclaration(tobytes("table_source"),
+                        no_c_inputs, c_input_node_opts)
+
+
+def run_query(plan, table_provider=None):
     """
     Execute a Substrait plan and read the results as a RecordBatchReader.

@@ -33,6 +56,63 @@ def run_query(plan):
     ----------
     plan : Buffer
         The serialized Substrait plan to execute.
+    table_provider : object (optional)
+        A function to resolve any NamedTable relation to a table.
+        The function will receive a single argument which will be a list
+        of strings representing the table name and should return a pyarrow.Table.
+
+    Returns
+    -------
+    RecordBatchReader
+        A reader containing the result of the executed query
+
+    Examples
+    --------
+    >>> import pyarrow as pa
+    >>> from pyarrow.lib import tobytes
+    >>> import pyarrow.substrait as substrait
+    >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+    >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
+    >>> def table_provider(names):
+    ...     if not names:
+    ...         raise Exception("No names provided")
+    ...     elif names[0] == "t1":
+    ...         return test_table_1
+    ...     elif names[0] == "t2":
+    ...         return test_table_2
+    ...     else:
+    ...         raise Exception("Unrecognized table name")
+    ...
+    >>> substrait_query = '''
+    ... {
+    ...     "relations": [
+    ...         {"rel": {
+    ...             "read": {
+    ...                 "base_schema": {
+    ...                     "struct": {
+    ...                         "types": [
+    ...                             {"i64": {}}
+    ...                         ]
+    ...                     },
+    ...                     "names": [
+    ...                         "x"
+    ...                     ]
+    ...                 },
+    ...                 "namedTable": {
+    ...                     "names": ["t1"]
+    ...                 }
+    ...             }
+    ...         }}
+    ...     ]
+    ... }
+    ... '''
+    >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    >>> reader = pa.substrait.run_query(buf, table_provider)
+    >>> reader.read_all()
+    pyarrow.Table
+    x: int64
+    ----
+    x: [[1,2,3]]
     """

     cdef:
@@ -41,10 +121,21 @@ def run_query(plan):
         CResult[shared_ptr[CRecordBatchReader]] c_res_reader
         shared_ptr[CRecordBatchReader] c_reader
         RecordBatchReader reader
         c_string c_str_plan
         shared_ptr[CBuffer] c_buf_plan
+        function[CNamedTableProvider] c_named_table_provider
+        CConversionOptions c_conversion_options

     c_buf_plan = pyarrow_unwrap_buffer(plan)
+
+    if table_provider is not None:
+        named_table_args = {
+            "provider": table_provider
+        }
+        c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
+            &_create_named_table_provider, named_table_args)
+
     with nogil:
-        c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan))
+        c_res_reader = ExecuteSerializedPlan(
+            deref(c_buf_plan), default_extension_id_registry(),
+            GetFunctionRegistry(), c_conversion_options)

     c_reader = GetResultValue(c_res_reader)
diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd
index be273975f94..489d73bf27e 100644
--- a/python/pyarrow/includes/libarrow.pxd
+++ b/python/pyarrow/includes/libarrow.pxd
@@ -2574,6 +2574,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h" namespace "arrow::compute" nogil:
         c_string label
         vector[Input] inputs

+        CDeclaration()
         CDeclaration(c_string factory_name, CExecNodeOptions options)
         CDeclaration(c_string factory_name, vector[Input] inputs,
                      shared_ptr[CExecNodeOptions] options)
diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd
index 0b3ace75d92..04990380d97 100644
--- a/python/pyarrow/includes/libarrow_substrait.pxd
+++ b/python/pyarrow/includes/libarrow_substrait.pxd
@@ -22,10 +22,22 @@ from libcpp.vector cimport vector as std_vector
 from pyarrow.includes.common cimport *
 from pyarrow.includes.libarrow cimport *

-
-cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil:
-    CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer)
-    CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json)
+ctypedef CResult[CDeclaration] CNamedTableProvider(const std_vector[c_string]&)
+
+cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" nogil:
+    cdef enum ConversionStrictness \
+            "arrow::engine::ConversionStrictness":
+        EXACT_ROUNDTRIP \
+            "arrow::engine::ConversionStrictness::EXACT_ROUNDTRIP"
+        PRESERVE_STRUCTURE \
+            "arrow::engine::ConversionStrictness::PRESERVE_STRUCTURE"
+        BEST_EFFORT \
+            "arrow::engine::ConversionStrictness::BEST_EFFORT"
+
+    cdef cppclass CConversionOptions \
+            "arrow::engine::ConversionOptions":
+        ConversionStrictness conversion_strictness
+        function[CNamedTableProvider] named_table_provider

 cdef extern from "arrow/engine/substrait/extension_set.h" \
     namespace "arrow::engine" nogil:
@@ -34,3 +46,11 @@
     std_vector[c_string] GetSupportedSubstraitFunctions()

     ExtensionIdRegistry* default_extension_id_registry()
+
+
+cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil:
+    CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(
+        const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry,
+        CFunctionRegistry* func_registry, const CConversionOptions& conversion_options)
+
+    CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json)
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index c8fa6afcb9f..c8fd8048aa4 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -165,3 +165,129 @@ def test_get_supported_functions():
                         'functions_arithmetic.yaml', 'add')
     assert has_function(supported_functions,
                         'functions_arithmetic.yaml', 'sum')
+
+
+def test_named_table():
+    test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+    test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
+
+    def table_provider(names):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        elif names[0] == "t2":
+            return test_table_2
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+        "relations": [
+            {"rel": {
+                "read": {
+                    "base_schema": {
+                        "struct": {
+                            "types": [
+                                {"i64": {}}
+                            ]
+                        },
+                        "names": [
+                            "x"
+                        ]
+                    },
+                    "namedTable": {
+                        "names": ["t1"]
+                    }
+                }
+            }}
+        ]
+    }
+    """
+
+    buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    reader = pa.substrait.run_query(buf, table_provider)
+    res_tb = reader.read_all()
+    assert res_tb == test_table_1
+
+
+def test_named_table_invalid_table_name():
+    test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+    def table_provider(names):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+        "relations": [
+            {"rel": {
+                "read": {
+                    "base_schema": {
+                        "struct": {
+                            "types": [
+                                {"i64": {}}
+                            ]
+                        },
+                        "names": [
+                            "x"
+                        ]
+                    },
+                    "namedTable": {
+                        "names": ["t3"]
+                    }
+                }
+            }}
+        ]
+    }
+    """
+
+    buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
+    exec_message = "Invalid NamedTable Source"
+    with pytest.raises(ArrowInvalid, match=exec_message):
+        substrait.run_query(buf, table_provider)
+
+
+def test_named_table_empty_names():
+    test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
+
+    def table_provider(names):
+        if not names:
+            raise Exception("No names provided")
+        elif names[0] == "t1":
+            return test_table_1
+        else:
+            raise Exception("Unrecognized table name")
+
+    substrait_query = """
+    {
+        "relations": [
+            {"rel": {
+                "read": {
+                    "base_schema": {
+                        "struct": {
+                            "types": [
+                                {"i64": {}}
+                            ]
+                        },
+                        "names": [
+                            "x"
+                        ]
+                    },
+                    "namedTable": {
+                        "names": []
+                    }
+                }
+            }}
+        ]
+    }
+    """
+    query = tobytes(substrait_query)
+    buf = pa._substrait._parse_json_plan(query)
+    exec_message = "names for NamedTable not 
provided" + with pytest.raises(ArrowInvalid, match=exec_message): + substrait.run_query(buf, table_provider) From 2749fef975493811a3139ce384cb8c2dec26f80d Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 15 Sep 2022 07:52:35 +0800 Subject: [PATCH 064/133] MINOR: [CI][Conan] Fix a typo (#14124) Authored-by: Jin Shang Signed-off-by: Sutou Kouhei --- ci/conan/all/conanfile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/conan/all/conanfile.py b/ci/conan/all/conanfile.py index a87478d6e40..aa832ab681f 100644 --- a/ci/conan/all/conanfile.py +++ b/ci/conan/all/conanfile.py @@ -186,7 +186,7 @@ def validate(self): if self.options.with_openssl == False and self._with_openssl(True): raise ConanInvalidConfiguration("with_openssl options is required (or choose auto)") if self.options.with_llvm == False and self._with_llvm(True): - raise ConanInvalidConfiguration("with_openssl options is required (or choose auto)") + raise ConanInvalidConfiguration("with_llvm options is required (or choose auto)") if self.options.with_cuda: raise ConanInvalidConfiguration("CCI has no cuda recipe (yet)") if self.options.with_orc: From 5e49174d69deb9d1cbbdf82bc8041b90098f560b Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Thu, 15 Sep 2022 12:37:21 +0200 Subject: [PATCH 065/133] ARROW-17694: [C++] Remove std::optional backport (#14105) Just use the C++17 standard library version. Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- .github/workflows/cpp.yml | 9 +- LICENSE.txt | 28 - c_glib/arrow-glib/compute.cpp | 8 +- ci/scripts/python_wheel_macos_build.sh | 2 +- ci/vcpkg/universal2-osx-static-debug.cmake | 2 +- ci/vcpkg/universal2-osx-static-release.cmake | 2 +- .../arrow/compute_register_example.cc | 2 +- .../execution_plan_documentation_examples.cc | 32 +- cpp/examples/arrow/join_example.cc | 2 +- cpp/examples/minimal_build/run_static.sh | 2 +- cpp/gdb_arrow.py | 24 - cpp/src/arrow/array/array_binary.h | 4 +- cpp/src/arrow/array/array_primitive.h | 8 +- cpp/src/arrow/array/array_test.cc | 4 +- .../arrow/compute/exec/asof_join_benchmark.cc | 2 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 26 +- .../arrow/compute/exec/asof_join_node_test.cc | 4 +- cpp/src/arrow/compute/exec/benchmark_util.cc | 4 +- cpp/src/arrow/compute/exec/exec_plan.cc | 18 +- cpp/src/arrow/compute/exec/exec_plan.h | 6 +- cpp/src/arrow/compute/exec/expression.cc | 52 +- cpp/src/arrow/compute/exec/expression_test.cc | 6 +- .../arrow/compute/exec/hash_join_node_test.cc | 18 +- cpp/src/arrow/compute/exec/options.h | 16 +- cpp/src/arrow/compute/exec/plan_test.cc | 72 +- cpp/src/arrow/compute/exec/sink_node.cc | 14 +- cpp/src/arrow/compute/exec/source_node.cc | 14 +- cpp/src/arrow/compute/exec/subtree_internal.h | 6 +- cpp/src/arrow/compute/exec/subtree_test.cc | 16 +- cpp/src/arrow/compute/exec/test_util.cc | 6 +- cpp/src/arrow/compute/exec/test_util.h | 10 +- cpp/src/arrow/compute/exec/tpch_benchmark.cc | 4 +- cpp/src/arrow/compute/exec/tpch_node.cc | 26 +- cpp/src/arrow/compute/exec/tpch_node.h | 4 +- cpp/src/arrow/compute/exec/tpch_node_test.cc | 6 +- cpp/src/arrow/compute/exec/union_node_test.cc | 2 +- cpp/src/arrow/compute/exec/util.h | 4 +- .../arrow/compute/kernels/codegen_internal.h | 2 +- .../arrow/compute/kernels/hash_aggregate.cc | 36 +- .../compute/kernels/hash_aggregate_test.cc | 4 +- .../compute/kernels/scalar_cast_string.cc | 2 +- .../arrow/compute/kernels/scalar_compare.cc | 8 +- .../arrow/compute/kernels/scalar_if_else.cc | 4 +- .../compute/kernels/vector_replace_test.cc | 4 +- 
cpp/src/arrow/compute/kernels/vector_sort.cc | 2 +- cpp/src/arrow/config.cc | 4 +- cpp/src/arrow/config.h | 6 +- cpp/src/arrow/csv/reader.cc | 8 +- cpp/src/arrow/csv/writer_test.cc | 6 +- cpp/src/arrow/dataset/dataset.cc | 12 +- cpp/src/arrow/dataset/dataset.h | 6 +- cpp/src/arrow/dataset/dataset_test.cc | 7 +- cpp/src/arrow/dataset/dataset_writer_test.cc | 16 +- cpp/src/arrow/dataset/file_base.cc | 10 +- cpp/src/arrow/dataset/file_base.h | 8 +- cpp/src/arrow/dataset/file_csv.cc | 6 +- cpp/src/arrow/dataset/file_csv.h | 2 +- cpp/src/arrow/dataset/file_ipc.cc | 6 +- cpp/src/arrow/dataset/file_ipc.h | 2 +- cpp/src/arrow/dataset/file_orc.cc | 6 +- cpp/src/arrow/dataset/file_orc.h | 2 +- cpp/src/arrow/dataset/file_parquet.cc | 24 +- cpp/src/arrow/dataset/file_parquet.h | 12 +- cpp/src/arrow/dataset/file_parquet_test.cc | 14 +- cpp/src/arrow/dataset/file_test.cc | 8 +- cpp/src/arrow/dataset/partition.cc | 12 +- cpp/src/arrow/dataset/partition.h | 8 +- cpp/src/arrow/dataset/scanner.cc | 26 +- cpp/src/arrow/dataset/scanner_benchmark.cc | 2 +- cpp/src/arrow/dataset/scanner_test.cc | 34 +- cpp/src/arrow/dataset/test_util.h | 10 +- .../engine/simple_extension_type_internal.h | 6 +- .../engine/substrait/expression_internal.cc | 6 +- .../arrow/engine/substrait/extension_set.cc | 39 +- .../arrow/engine/substrait/extension_set.h | 20 +- .../arrow/engine/substrait/extension_types.cc | 8 +- .../arrow/engine/substrait/extension_types.h | 6 +- cpp/src/arrow/engine/substrait/serde_test.cc | 4 +- cpp/src/arrow/engine/substrait/util.cc | 6 +- cpp/src/arrow/engine/substrait/util.h | 3 +- cpp/src/arrow/filesystem/gcsfs.h | 4 +- cpp/src/arrow/filesystem/path_util.cc | 6 +- cpp/src/arrow/filesystem/path_util.h | 6 +- cpp/src/arrow/filesystem/s3fs.cc | 8 +- cpp/src/arrow/flight/cookie_internal.cc | 4 +- cpp/src/arrow/flight/cookie_internal.h | 4 +- cpp/src/arrow/flight/flight_internals_test.cc | 26 +- cpp/src/arrow/flight/sql/client_test.cc | 14 +- cpp/src/arrow/flight/sql/server.cc | 2 +- cpp/src/arrow/flight/sql/server.h | 14 +- cpp/src/arrow/flight/sql/server_test.cc | 10 +- cpp/src/arrow/flight/sql/test_app_cli.cc | 14 +- cpp/src/arrow/flight/sql/types.h | 6 +- cpp/src/arrow/flight/transport.cc | 6 +- cpp/src/arrow/flight/transport.h | 8 +- .../flight/transport/grpc/util_internal.cc | 20 +- .../flight/transport/ucx/ucx_internal.cc | 2 +- cpp/src/arrow/memory_pool.cc | 6 +- cpp/src/arrow/public_api_test.cc | 4 +- cpp/src/arrow/stl_iterator.h | 6 +- cpp/src/arrow/stl_iterator_test.cc | 4 +- cpp/src/arrow/stl_test.cc | 15 +- cpp/src/arrow/testing/gtest_util.cc | 6 +- cpp/src/arrow/testing/gtest_util.h | 21 +- cpp/src/arrow/testing/matchers.h | 14 +- cpp/src/arrow/testing/util.cc | 8 +- cpp/src/arrow/testing/util.h | 4 +- cpp/src/arrow/util/async_generator.h | 12 +- cpp/src/arrow/util/async_generator_test.cc | 4 +- cpp/src/arrow/util/async_util.cc | 8 +- cpp/src/arrow/util/async_util.h | 2 +- cpp/src/arrow/util/async_util_test.cc | 4 +- cpp/src/arrow/util/cancel_test.cc | 12 +- cpp/src/arrow/util/cpu_info.cc | 6 +- cpp/src/arrow/util/future.h | 10 +- cpp/src/arrow/util/io_util_test.cc | 2 +- cpp/src/arrow/util/iterator.h | 18 +- cpp/src/arrow/util/optional.h | 35 - cpp/src/arrow/util/reflection_test.cc | 32 +- cpp/src/arrow/util/string.cc | 6 +- cpp/src/arrow/util/string.h | 6 +- cpp/src/arrow/util/task_group.cc | 2 +- cpp/src/arrow/vendored/optional.hpp | 1553 ----------------- cpp/src/gandiva/cache.h | 4 +- cpp/src/gandiva/lru_cache.h | 7 +- cpp/src/gandiva/lru_cache_test.cc | 2 +- 
cpp/src/parquet/level_conversion.cc | 4 +- cpp/src/parquet/statistics.cc | 12 +- cpp/src/parquet/stream_reader.h | 8 +- cpp/src/parquet/stream_reader_test.cc | 2 +- cpp/src/parquet/stream_writer.h | 8 +- dev/release/verify-release-candidate.sh | 2 +- .../autobrew/apache-arrow.rb | 1 + dev/tasks/tasks.yml | 6 +- docs/source/cpp/datatypes.rst | 6 +- docs/source/cpp/gdb.rst | 3 +- docs/source/cpp/streaming_execution.rst | 6 +- python/CMakeLists.txt | 2 +- python/pyarrow/includes/common.pxd | 9 + python/pyarrow/includes/libarrow.pxd | 7 - python/pyarrow/includes/libarrow_fs.pxd | 2 +- python/pyarrow/src/gdb.cc | 5 - python/pyarrow/src/python_test.cc | 4 +- python/pyarrow/tests/test_gdb.py | 7 - r/configure.win | 7 +- r/src/Makevars.ucrt | 3 + r/src/compute-exec.cpp | 6 +- r/src/config.cpp | 5 +- 148 files changed, 653 insertions(+), 2311 deletions(-) delete mode 100644 cpp/src/arrow/util/optional.h delete mode 100644 cpp/src/arrow/vendored/optional.hpp diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index d4c5c4c8979..8f4702dacb6 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -257,7 +257,7 @@ jobs: - name: Install ccache shell: bash run: | - ci/scripts/install_ccache.sh 4.6.2 /usr + ci/scripts/install_ccache.sh 4.6.3 /usr - name: Setup ccache shell: bash run: | @@ -271,8 +271,11 @@ jobs: uses: actions/cache@v2 with: path: ${{ steps.ccache-info.outputs.cache-dir }} - key: cpp-ccache-windows-${{ hashFiles('cpp/**') }} - restore-keys: cpp-ccache-windows- + key: cpp-ccache-windows-${{ env.CACHE_VERSION }}-${{ hashFiles('cpp/**') }} + restore-keys: cpp-ccache-windows-${{ env.CACHE_VERSION }}- + env: + # We can invalidate the current cache by updating this. + CACHE_VERSION: "2022-09-13" - name: Build shell: cmd run: | diff --git a/LICENSE.txt b/LICENSE.txt index a82c22aecea..6532b8790c3 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -2059,34 +2059,6 @@ René Nyffenegger rene.nyffenegger@adp-gmbh.ch -------------------------------------------------------------------------------- -The file cpp/src/arrow/vendored/optional.hpp has the following license - -Boost Software License - Version 1.0 - August 17th, 2003 - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. 
- --------------------------------------------------------------------------------- - This project includes code from Folly. * cpp/src/arrow/vendored/ProducerConsumerQueue.h diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index f3a29be5e43..3404af794de 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -938,7 +938,7 @@ garrow_source_node_options_new_record_batch_reader( arrow_reader->schema(), [arrow_reader]() { using ExecBatch = arrow::compute::ExecBatch; - using ExecBatchOptional = arrow::util::optional; + using ExecBatchOptional = std::optional; auto arrow_record_batch_result = arrow_reader->Next(); if (!arrow_record_batch_result.ok()) { return arrow::AsyncGeneratorEnd(); @@ -979,7 +979,7 @@ garrow_source_node_options_new_record_batch(GArrowRecordBatch *record_batch) state->record_batch->schema(), [state]() { using ExecBatch = arrow::compute::ExecBatch; - using ExecBatchOptional = arrow::util::optional; + using ExecBatchOptional = std::optional; if (!state->generated) { state->generated = true; return arrow::Future::MakeFinished( @@ -1296,7 +1296,7 @@ garrow_aggregate_node_options_new(GList *aggregations, typedef struct GArrowSinkNodeOptionsPrivate_ { - arrow::AsyncGenerator> generator; + arrow::AsyncGenerator> generator; GArrowRecordBatchReader *reader; } GArrowSinkNodeOptionsPrivate; @@ -1333,7 +1333,7 @@ garrow_sink_node_options_init(GArrowSinkNodeOptions *object) { auto priv = GARROW_SINK_NODE_OPTIONS_GET_PRIVATE(object); new(&(priv->generator)) - arrow::AsyncGenerator>(); + arrow::AsyncGenerator>(); } static void diff --git a/ci/scripts/python_wheel_macos_build.sh b/ci/scripts/python_wheel_macos_build.sh index 92494fb4b76..cdd2bd3a400 100755 --- a/ci/scripts/python_wheel_macos_build.sh +++ b/ci/scripts/python_wheel_macos_build.sh @@ -34,7 +34,7 @@ rm -rf ${source_dir}/python/pyarrow/*.so.* echo "=== (${PYTHON_VERSION}) Set SDK, C++ and Wheel flags ===" export _PYTHON_HOST_PLATFORM="macosx-${MACOSX_DEPLOYMENT_TARGET}-${arch}" -export MACOSX_DEPLOYMENT_TARGET=${MACOSX_DEPLOYMENT_TARGET:-10.13} +export MACOSX_DEPLOYMENT_TARGET=${MACOSX_DEPLOYMENT_TARGET:-10.14} export SDKROOT=${SDKROOT:-$(xcrun --sdk macosx --show-sdk-path)} if [ $arch = "arm64" ]; then diff --git a/ci/vcpkg/universal2-osx-static-debug.cmake b/ci/vcpkg/universal2-osx-static-debug.cmake index 29e4b0e63c5..580b4604d52 100644 --- a/ci/vcpkg/universal2-osx-static-debug.cmake +++ b/ci/vcpkg/universal2-osx-static-debug.cmake @@ -21,6 +21,6 @@ set(VCPKG_LIBRARY_LINKAGE static) set(VCPKG_CMAKE_SYSTEM_NAME Darwin) set(VCPKG_OSX_ARCHITECTURES "x86_64;arm64") -set(VCPKG_OSX_DEPLOYMENT_TARGET "10.13") +set(VCPKG_OSX_DEPLOYMENT_TARGET "10.14") set(VCPKG_BUILD_TYPE debug) diff --git a/ci/vcpkg/universal2-osx-static-release.cmake b/ci/vcpkg/universal2-osx-static-release.cmake index 8111169fab2..7247d0af351 100644 --- a/ci/vcpkg/universal2-osx-static-release.cmake +++ b/ci/vcpkg/universal2-osx-static-release.cmake @@ -21,6 +21,6 @@ set(VCPKG_LIBRARY_LINKAGE static) set(VCPKG_CMAKE_SYSTEM_NAME Darwin) set(VCPKG_OSX_ARCHITECTURES "x86_64;arm64") -set(VCPKG_OSX_DEPLOYMENT_TARGET "10.13") +set(VCPKG_OSX_DEPLOYMENT_TARGET "10.14") set(VCPKG_BUILD_TYPE release) diff --git a/cpp/examples/arrow/compute_register_example.cc b/cpp/examples/arrow/compute_register_example.cc index 2a76e8595b6..d8debd9c3e1 100644 --- a/cpp/examples/arrow/compute_register_example.cc +++ b/cpp/examples/arrow/compute_register_example.cc @@ -149,7 +149,7 @@ arrow::Status RunComputeRegister(int argc, char** 
argv) { ARROW_RETURN_NOT_OK(maybe_plan.status()); ARROW_ASSIGN_OR_RAISE(auto plan, maybe_plan); - arrow::AsyncGenerator> source_gen, sink_gen; + arrow::AsyncGenerator> source_gen, sink_gen; ARROW_RETURN_NOT_OK( cp::Declaration::Sequence( { diff --git a/cpp/examples/arrow/execution_plan_documentation_examples.cc b/cpp/examples/arrow/execution_plan_documentation_examples.cc index b7c690bb278..9a2d682bbae 100644 --- a/cpp/examples/arrow/execution_plan_documentation_examples.cc +++ b/cpp/examples/arrow/execution_plan_documentation_examples.cc @@ -157,11 +157,11 @@ struct BatchesWithSchema { std::shared_ptr schema; // This method uses internal arrow utilities to // convert a vector of record batches to an AsyncGenerator of optional batches - arrow::AsyncGenerator> gen() const { + arrow::AsyncGenerator> gen() const { auto opt_batches = ::arrow::internal::MapVector( - [](cp::ExecBatch batch) { return arrow::util::make_optional(std::move(batch)); }, + [](cp::ExecBatch batch) { return std::make_optional(std::move(batch)); }, batches); - arrow::AsyncGenerator> gen; + arrow::AsyncGenerator> gen; gen = arrow::MakeVectorGenerator(std::move(opt_batches)); return gen; } @@ -259,7 +259,7 @@ arrow::Result MakeGroupableBatches(int multiplicity = 1) { arrow::Status ExecutePlanAndCollectAsTable( cp::ExecContext& exec_context, std::shared_ptr plan, std::shared_ptr schema, - arrow::AsyncGenerator> sink_gen) { + arrow::AsyncGenerator> sink_gen) { // translate sink_gen (async) to sink_reader (sync) std::shared_ptr sink_reader = cp::MakeGeneratorReader(schema, std::move(sink_gen), exec_context.memory_pool()); @@ -312,7 +312,7 @@ arrow::Status ScanSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(scan, cp::MakeExecNode("scan", plan.get(), {}, scan_node_options)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; ARROW_RETURN_NOT_OK( cp::MakeExecNode("sink", plan.get(), {scan}, cp::SinkNodeOptions{&sink_gen})); @@ -337,7 +337,7 @@ arrow::Status SourceSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(auto basic_data, MakeBasicBatches()); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; auto source_node_options = cp::SourceNodeOptions{basic_data.schema, basic_data.gen()}; @@ -367,7 +367,7 @@ arrow::Status TableSourceSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(auto table, GetTable()); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; int max_batch_size = 2; auto table_source_options = cp::TableSourceNodeOptions{table, max_batch_size}; @@ -427,7 +427,7 @@ arrow::Status ScanFilterSinkExample(cp::ExecContext& exec_context) { cp::FilterNodeOptions{filter_opt})); // finally, pipe the filter node into a sink node - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; ARROW_RETURN_NOT_OK( cp::MakeExecNode("sink", plan.get(), {filter}, cp::SinkNodeOptions{&sink_gen})); @@ -470,7 +470,7 @@ arrow::Status ScanProjectSinkExample(cp::ExecContext& exec_context) { std::cout << "Schema after projection : \n" << project->output_schema()->ToString() << std::endl; - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; ARROW_RETURN_NOT_OK( cp::MakeExecNode("sink", plan.get(), {project}, cp::SinkNodeOptions{&sink_gen})); auto schema = arrow::schema({arrow::field("a * 2", arrow::int32())}); @@ -496,7 +496,7 @@ arrow::Status SourceScalarAggregateSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(auto basic_data, MakeBasicBatches()); - arrow::AsyncGenerator> sink_gen; + 
arrow::AsyncGenerator> sink_gen; auto source_node_options = cp::SourceNodeOptions{basic_data.schema, basic_data.gen()}; @@ -532,7 +532,7 @@ arrow::Status SourceGroupAggregateSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(auto basic_data, MakeBasicBatches()); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; auto source_node_options = cp::SourceNodeOptions{basic_data.schema, basic_data.gen()}; @@ -640,7 +640,7 @@ arrow::Status SourceOrderBySinkExample(cp::ExecContext& exec_context) { std::cout << "basic data created" << std::endl; - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; auto source_node_options = cp::SourceNodeOptions{basic_data.schema, basic_data.gen()}; ARROW_ASSIGN_OR_RAISE(cp::ExecNode * source, @@ -670,7 +670,7 @@ arrow::Status SourceHashJoinSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, cp::ExecPlan::Make(&exec_context)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; cp::ExecNode* left_source; cp::ExecNode* right_source; @@ -714,7 +714,7 @@ arrow::Status SourceKSelectExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(auto input, MakeGroupableBatches()); ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, cp::ExecPlan::Make(&exec_context)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; ARROW_ASSIGN_OR_RAISE( cp::ExecNode * source, @@ -761,7 +761,7 @@ arrow::Status ScanFilterWriteExample(cp::ExecContext& exec_context, ARROW_ASSIGN_OR_RAISE(scan, cp::MakeExecNode("scan", plan.get(), {}, scan_node_options)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; std::string root_path = ""; std::string uri = "file://" + file_path; @@ -820,7 +820,7 @@ arrow::Status SourceUnionSinkExample(cp::ExecContext& exec_context) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, cp::ExecPlan::Make(&exec_context)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; cp::Declaration union_node{"union", cp::ExecNodeOptions{}}; cp::Declaration lhs{"source", diff --git a/cpp/examples/arrow/join_example.cc b/cpp/examples/arrow/join_example.cc index e531bfbfbf9..7bea588e3ad 100644 --- a/cpp/examples/arrow/join_example.cc +++ b/cpp/examples/arrow/join_example.cc @@ -89,7 +89,7 @@ arrow::Status DoHashJoin() { ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, cp::ExecPlan::Make(&exec_context)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; cp::ExecNode* left_source; cp::ExecNode* right_source; diff --git a/cpp/examples/minimal_build/run_static.sh b/cpp/examples/minimal_build/run_static.sh index cf2a9912f50..619811d09ac 100755 --- a/cpp/examples/minimal_build/run_static.sh +++ b/cpp/examples/minimal_build/run_static.sh @@ -102,7 +102,7 @@ echo rm -rf $EXAMPLE_BUILD_DIR mkdir -p $EXAMPLE_BUILD_DIR -${CXX:-c++} \ +${CXX:-c++} -std=c++17 \ -o $EXAMPLE_BUILD_DIR/arrow-example \ $EXAMPLE_DIR/example.cc \ $(PKG_CONFIG_PATH=$ARROW_BUILD_DIR/lib/pkgconfig \ diff --git a/cpp/gdb_arrow.py b/cpp/gdb_arrow.py index af3dad9c087..5421b4ffb15 100644 --- a/cpp/gdb_arrow.py +++ b/cpp/gdb_arrow.py @@ -2175,28 +2175,6 @@ def to_string(self): return f"arrow::util::string_view of size {size}, {data}" -class OptionalPrinter: - """ - Pretty-printer for arrow::util::optional. 
- """ - - def __init__(self, name, val): - self.val = val - - def to_string(self): - data_type = self.val.type.template_argument(0) - # XXX We rely on internal details of our vendored optional - # implementation, as inlined methods may not be callable from gdb. - if not self.val['has_value_']: - inner = "nullopt" - else: - data_ptr = self.val['contained']['data'].address - assert data_ptr - inner = data_ptr.reinterpret_cast( - data_type.pointer()).dereference() - return f"arrow::util::optional<{data_type}>({inner})" - - class VariantPrinter: """ Pretty-printer for arrow::util::Variant. @@ -2436,10 +2414,8 @@ def to_string(self): "arrow::SimpleTable": TablePrinter, "arrow::Status": StatusPrinter, "arrow::Table": TablePrinter, - "arrow::util::optional": OptionalPrinter, "arrow::util::string_view": StringViewPrinter, "arrow::util::Variant": VariantPrinter, - "nonstd::optional_lite::optional": OptionalPrinter, "nonstd::sv_lite::basic_string_view": StringViewPrinter, } diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h index 04ee804987f..cc04d792002 100644 --- a/cpp/src/arrow/array/array_binary.h +++ b/cpp/src/arrow/array/array_binary.h @@ -75,7 +75,7 @@ class BaseBinaryArray : public FlatArray { raw_value_offsets_[i + 1] - pos); } - util::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } @@ -240,7 +240,7 @@ class ARROW_EXPORT FixedSizeBinaryArray : public PrimitiveArray { return util::string_view(reinterpret_cast(GetValue(i)), byte_width()); } - util::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } diff --git a/cpp/src/arrow/array/array_primitive.h b/cpp/src/arrow/array/array_primitive.h index 740a4806a4d..e6df92e3b78 100644 --- a/cpp/src/arrow/array/array_primitive.h +++ b/cpp/src/arrow/array/array_primitive.h @@ -54,7 +54,7 @@ class ARROW_EXPORT BooleanArray : public PrimitiveArray { bool GetView(int64_t i) const { return Value(i); } - util::optional operator[](int64_t i) const { return *IteratorType(*this, i); } + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } /// \brief Return the number of false (0) values among the valid /// values. Result is not cached. @@ -111,7 +111,7 @@ class NumericArray : public PrimitiveArray { // For API compatibility with BinaryArray etc. 
value_type GetView(int64_t i) const { return Value(i); } - util::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } @@ -152,7 +152,7 @@ class ARROW_EXPORT DayTimeIntervalArray : public PrimitiveArray { IteratorType end() const { return IteratorType(*this, length()); } - util::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } @@ -188,7 +188,7 @@ class ARROW_EXPORT MonthDayNanoIntervalArray : public PrimitiveArray { IteratorType end() const { return IteratorType(*this, length()); } - util::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index d438557a330..9256d4ad0b7 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -2290,7 +2290,7 @@ TEST_F(TestFWBinaryArray, ArrayIndexOperator) { auto fsba = checked_pointer_cast(arr); ASSERT_EQ("abc", (*fsba)[0].value()); - ASSERT_EQ(util::nullopt, (*fsba)[1]); + ASSERT_EQ(std::nullopt, (*fsba)[1]); ASSERT_EQ("def", (*fsba)[2].value()); } @@ -3538,7 +3538,7 @@ TYPED_TEST(TestPrimitiveArray, IndexOperator) { ASSERT_EQ(this->values_[i], res.value()); } else { ASSERT_FALSE(res.has_value()); - ASSERT_EQ(res, util::nullopt); + ASSERT_EQ(res, std::nullopt); } } } diff --git a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc index 7d8abc0ba4c..a0362eb1ba8 100644 --- a/cpp/src/arrow/compute/exec/asof_join_benchmark.cc +++ b/cpp/src/arrow/compute/exec/asof_join_benchmark.cc @@ -88,7 +88,7 @@ static void TableJoinOverhead(benchmark::State& state, } ASSERT_OK_AND_ASSIGN(arrow::compute::ExecNode * join_node, MakeExecNode(factory_name, plan.get(), input_nodes, options)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK(MakeExecNode("sink", plan.get(), {join_node}, SinkNodeOptions{&sink_gen})); state.ResumeTiming(); ASSERT_FINISHES_OK(StartAndCollect(plan.get(), sink_gen)); diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 869456a5775..35e7b1c6cc6 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -36,7 +37,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/make_unique.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" namespace arrow { @@ -99,11 +99,11 @@ class ConcurrentQueue { queue_ = std::queue(); } - util::optional TryPop() { + std::optional TryPop() { // Try to pop the oldest value from the queue (or return nullopt if none) std::unique_lock lock(mutex_); if (queue_.empty()) { - return util::nullopt; + return std::nullopt; } else { auto item = queue_.front(); queue_.pop(); @@ -156,10 +156,10 @@ struct MemoStore { e.time = time; } - util::optional GetEntryForKey(ByType key) const { + std::optional GetEntryForKey(ByType key) const { auto e = entries_.find(key); - if (entries_.end() == e) return util::nullopt; - return util::optional(&e->second); + if (entries_.end() == e) return std::nullopt; + return std::optional(&e->second); } void RemoveEntriesWithLesserTime(OnType ts) { @@ -263,7 +263,7 @@ class InputState { return dst_offset; } - const util::optional& MapSrcToDst(col_index_t src) const { + const std::optional& 
MapSrcToDst(col_index_t src) const { return src_to_dst_[src]; } @@ -436,16 +436,16 @@ class InputState { return Status::OK(); } - util::optional GetMemoEntryForKey(ByType key) { + std::optional GetMemoEntryForKey(ByType key) { return memo_.GetEntryForKey(key); } - util::optional GetMemoTimeForKey(ByType key) { + std::optional GetMemoTimeForKey(ByType key) { auto r = GetMemoEntryForKey(key); if (r.has_value()) { return (*r)->time; } else { - return util::nullopt; + return std::nullopt; } } @@ -492,7 +492,7 @@ class InputState { // Stores latest known values for the various keys MemoStore memo_; // Mapping of source columns to destination columns - std::vector> src_to_dst_; + std::vector> src_to_dst_; }; template @@ -555,7 +555,7 @@ class CompositeReferenceTable { // Get the state for that key from all on the RHS -- assumes it's up to date // (the RHS state comes from the memoized row references) for (size_t i = 1; i < in.size(); ++i) { - util::optional opt_entry = in[i]->GetMemoEntryForKey(key); + std::optional opt_entry = in[i]->GetMemoEntryForKey(key); if (opt_entry.has_value()) { DCHECK(*opt_entry); if ((*opt_entry)->time + tolerance >= lhs_latest_time) { @@ -588,7 +588,7 @@ class CompositeReferenceTable { int n_src_cols = state.at(i_table)->get_schema()->num_fields(); { for (col_index_t i_src_col = 0; i_src_col < n_src_cols; ++i_src_col) { - util::optional i_dst_col_opt = + std::optional i_dst_col_opt = state[i_table]->MapSrcToDst(i_src_col); if (!i_dst_col_opt) continue; col_index_t i_dst_col = *i_dst_col_opt; diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 48d1ae6410b..2e4bb06176a 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -222,7 +222,7 @@ void CheckRunOutput(const BatchesWithSchema& l_batches, join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r1_batches.schema, r1_batches.gen(false, false)}}); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) .AddToPlan(plan.get())); @@ -267,7 +267,7 @@ void DoInvalidPlanTest(const BatchesWithSchema& l_batches, "source", SourceNodeOptions{r_batches.schema, r_batches.gen(false, false)}}); if (fail_on_plan_creation) { - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) .AddToPlan(plan.get())); EXPECT_FINISHES_AND_RAISES_WITH_MESSAGE_THAT(Invalid, diff --git a/cpp/src/arrow/compute/exec/benchmark_util.cc b/cpp/src/arrow/compute/exec/benchmark_util.cc index 5bac508854f..d4e14540a54 100644 --- a/cpp/src/arrow/compute/exec/benchmark_util.cc +++ b/cpp/src/arrow/compute/exec/benchmark_util.cc @@ -42,7 +42,7 @@ Status BenchmarkIsolatedNodeOverhead(benchmark::State& state, arrow::compute::ExecNodeOptions& options) { for (auto _ : state) { state.PauseTiming(); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, arrow::compute::ExecPlan::Make(&ctx)); @@ -119,7 +119,7 @@ Status BenchmarkNodeOverhead( state.PauseTiming(); ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, arrow::compute::ExecPlan::Make(&ctx)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; arrow::compute::Declaration source = arrow::compute::Declaration( {"source", arrow::compute::SourceNodeOptions{data.schema, diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 
00415495aa8..322b5f5e456 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -17,6 +17,7 @@ #include "arrow/compute/exec/exec_plan.h" +#include #include #include #include @@ -33,7 +34,6 @@ #include "arrow/util/async_generator.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/tracing_internal.h" namespace arrow { @@ -339,12 +339,12 @@ const ExecPlanImpl* ToDerived(const ExecPlan* ptr) { return checked_cast(ptr); } -util::optional GetNodeIndex(const std::vector& nodes, - const ExecNode* node) { +std::optional GetNodeIndex(const std::vector& nodes, + const ExecNode* node) { for (int i = 0; i < static_cast(nodes.size()); ++i) { if (nodes[i] == node) return i; } - return util::nullopt; + return std::nullopt; } } // namespace @@ -569,8 +569,8 @@ void MapNode::Finish(Status finish_st /*= Status::OK()*/) { } std::shared_ptr MakeGeneratorReader( - std::shared_ptr schema, - std::function>()> gen, MemoryPool* pool) { + std::shared_ptr schema, std::function>()> gen, + MemoryPool* pool) { struct Impl : RecordBatchReader { std::shared_ptr schema() const override { return schema_; } @@ -596,7 +596,7 @@ std::shared_ptr MakeGeneratorReader( MemoryPool* pool_; std::shared_ptr schema_; - Iterator> iterator_; + Iterator> iterator_; }; auto out = std::make_shared(); @@ -703,12 +703,12 @@ ExecFactoryRegistry* default_exec_factory_registry() { return &instance; } -Result>()>> MakeReaderGenerator( +Result>()>> MakeReaderGenerator( std::shared_ptr reader, ::arrow::internal::Executor* io_executor, int max_q, int q_restart) { auto batch_it = MakeMapIterator( [](std::shared_ptr batch) { - return util::make_optional(ExecBatch(*batch)); + return std::make_optional(ExecBatch(*batch)); }, MakeIteratorFromReader(reader)); diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index a9481e21a6e..93d06551241 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -30,7 +31,6 @@ #include "arrow/util/cancel.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" #include "arrow/util/tracing.h" #include "arrow/util/visibility.h" @@ -530,7 +530,7 @@ struct ARROW_EXPORT Declaration { /// The RecordBatchReader does not impose any ordering on emitted batches. 
ARROW_EXPORT std::shared_ptr MakeGeneratorReader( - std::shared_ptr, std::function>()>, + std::shared_ptr, std::function>()>, MemoryPool*); constexpr int kDefaultBackgroundMaxQ = 32; @@ -540,7 +540,7 @@ constexpr int kDefaultBackgroundQRestart = 16; /// /// Useful as a source node for an Exec plan ARROW_EXPORT -Result>()>> MakeReaderGenerator( +Result>()>> MakeReaderGenerator( std::shared_ptr reader, arrow::internal::Executor* io_executor, int max_q = kDefaultBackgroundMaxQ, int q_restart = kDefaultBackgroundQRestart); diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 06f36c7f5ad..16942a0f80f 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -17,6 +17,7 @@ #include "arrow/compute/exec/expression.h" +#include #include #include @@ -31,7 +32,6 @@ #include "arrow/util/hash_util.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/string.h" #include "arrow/util/value_parsing.h" #include "arrow/util/vector.h" @@ -309,13 +309,12 @@ bool Expression::IsNullLiteral() const { } namespace { -util::optional GetNullHandling( - const Expression::Call& call) { +std::optional GetNullHandling(const Expression::Call& call) { DCHECK_NE(call.function, nullptr); if (call.function->kind() == compute::Function::SCALAR) { return static_cast(call.kernel)->null_handling; } - return util::nullopt; + return std::nullopt; } } // namespace @@ -616,8 +615,8 @@ ArgumentsAndFlippedArguments(const Expression::Call& call) { template ::value_type> -util::optional FoldLeft(It begin, It end, const BinOp& bin_op) { - if (begin == end) return util::nullopt; +std::optional FoldLeft(It begin, It end, const BinOp& bin_op) { + if (begin == end) return std::nullopt; Out folded = std::move(*begin++); while (begin != end) { @@ -738,18 +737,18 @@ std::vector GuaranteeConjunctionMembers( /// Recognizes expressions of the form: /// equal(a, 2) /// is_null(a) -util::optional> ExtractOneFieldValue( +std::optional> ExtractOneFieldValue( const Expression& guarantee) { auto call = guarantee.call(); - if (!call) return util::nullopt; + if (!call) return std::nullopt; // search for an equality conditions between a field and a literal if (call->function_name == "equal") { auto ref = call->arguments[0].field_ref(); - if (!ref) return util::nullopt; + if (!ref) return std::nullopt; auto lit = call->arguments[1].literal(); - if (!lit) return util::nullopt; + if (!lit) return std::nullopt; return std::make_pair(*ref, *lit); } @@ -757,12 +756,12 @@ util::optional> ExtractOneFieldValue( // ... or a known null field if (call->function_name == "is_null") { auto ref = call->arguments[0].field_ref(); - if (!ref) return util::nullopt; + if (!ref) return std::nullopt; return std::make_pair(*ref, Datum(std::make_shared())); } - return util::nullopt; + return std::nullopt; } // Conjunction members which are represented in known_values are erased from @@ -953,24 +952,24 @@ struct Inequality { // possibly disjuncted with an "is_null" Expression. 
// cmp(a, 2) // cmp(a, 2) or is_null(a) - static util::optional ExtractOne(const Expression& guarantee) { + static std::optional ExtractOne(const Expression& guarantee) { auto call = guarantee.call(); - if (!call) return util::nullopt; + if (!call) return std::nullopt; if (call->function_name == "or_kleene") { // expect the LHS to be a usable field inequality auto out = ExtractOneFromComparison(call->arguments[0]); - if (!out) return util::nullopt; + if (!out) return std::nullopt; // expect the RHS to be an is_null expression auto call_rhs = call->arguments[1].call(); - if (!call_rhs) return util::nullopt; - if (call_rhs->function_name != "is_null") return util::nullopt; + if (!call_rhs) return std::nullopt; + if (call_rhs->function_name != "is_null") return std::nullopt; // ... and that it references the same target auto target = call_rhs->arguments[0].field_ref(); - if (!target) return util::nullopt; - if (*target != out->target) return util::nullopt; + if (!target) return std::nullopt; + if (*target != out->target) return std::nullopt; out->nullable = true; return out; @@ -980,26 +979,25 @@ struct Inequality { return ExtractOneFromComparison(guarantee); } - static util::optional ExtractOneFromComparison( - const Expression& guarantee) { + static std::optional ExtractOneFromComparison(const Expression& guarantee) { auto call = guarantee.call(); - if (!call) return util::nullopt; + if (!call) return std::nullopt; if (auto cmp = Comparison::Get(call->function_name)) { // not_equal comparisons are not very usable as guarantees - if (*cmp == Comparison::NOT_EQUAL) return util::nullopt; + if (*cmp == Comparison::NOT_EQUAL) return std::nullopt; auto target = call->arguments[0].field_ref(); - if (!target) return util::nullopt; + if (!target) return std::nullopt; auto bound = call->arguments[1].literal(); - if (!bound) return util::nullopt; - if (!bound->is_scalar()) return util::nullopt; + if (!bound) return std::nullopt; + if (!bound->is_scalar()) return std::nullopt; return Inequality{*cmp, /*target=*/*target, *bound, /*nullable=*/false}; } - return util::nullopt; + return std::nullopt; } /// The given expression simplifies to `value` if the inequality diff --git a/cpp/src/arrow/compute/exec/expression_test.cc b/cpp/src/arrow/compute/exec/expression_test.cc index b4466d827eb..4cb4c272485 100644 --- a/cpp/src/arrow/compute/exec/expression_test.cc +++ b/cpp/src/arrow/compute/exec/expression_test.cc @@ -86,7 +86,7 @@ void ExpectResultsEqual(Actual&& actual, Expected&& expected) { } } -const auto no_change = util::nullopt; +const auto no_change = std::nullopt; TEST(ExpressionUtils, Comparison) { auto Expect = [](Result expected, Datum l, Datum r) { @@ -122,7 +122,7 @@ TEST(ExpressionUtils, Comparison) { } TEST(ExpressionUtils, StripOrderPreservingCasts) { - auto Expect = [](Expression expr, util::optional expected_stripped) { + auto Expect = [](Expression expr, std::optional expected_stripped) { ASSERT_OK_AND_ASSIGN(expr, expr.Bind(*kBoringSchema)); if (!expected_stripped) { expected_stripped = expr; @@ -499,7 +499,7 @@ TEST(Expression, BindLiteral) { } } -void ExpectBindsTo(Expression expr, util::optional expected, +void ExpectBindsTo(Expression expr, std::optional expected, Expression* bound_out = nullptr, const Schema& schema = *kBoringSchema) { if (!expected) { diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 3bb778b82ae..b45af654450 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ 
b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -85,7 +85,7 @@ void CheckRunOutput(JoinType type, const BatchesWithSchema& l_batches, join.inputs.emplace_back(Declaration{ "source", SourceNodeOptions{r_batches.schema, r_batches.gen(parallel, /*slow=*/false)}}); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence({join, {"sink", SinkNodeOptions{&sink_gen}}}) .AddToPlan(plan.get())); @@ -915,7 +915,7 @@ Result> HashJoinWithExecPlan( ExecNode * join, MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ARROW_ASSIGN_OR_RAISE( std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen})); @@ -964,7 +964,7 @@ TEST(HashJoin, Suffix) { ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ExecNode* left_source; ExecNode* right_source; @@ -1335,7 +1335,7 @@ void TestHashJoinDictionaryHelper( {(swap_sides ? r_source : l_source), (swap_sides ? l_source : r_source)}, join_options)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN( std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen})); ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen)); @@ -1756,7 +1756,7 @@ TEST(HashJoin, DictNegative) { ASSERT_OK_AND_ASSIGN( ExecNode * join, MakeExecNode("hashjoin", plan.get(), {l_source, r_source}, join_options)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {join}, SinkNodeOptions{&sink_gen})); @@ -1806,7 +1806,7 @@ void TestSimpleJoinHelper(BatchesWithSchema input_left, BatchesWithSchema input_ BatchesWithSchema expected) { ExecContext exec_ctx; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ExecNode* left_source; ExecNode* right_source; @@ -2001,7 +2001,7 @@ TEST(HashJoin, ResidualFilter) { default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ExecNode* left_source; ExecNode* right_source; @@ -2079,7 +2079,7 @@ TEST(HashJoin, TrivialResidualFilter) { parallel ? 
arrow::internal::GetCpuThreadPool() : nullptr); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get())); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ExecNode* left_source; ExecNode* right_source; @@ -2242,7 +2242,7 @@ void TestSingleChainOfHashJoins(Random64Bit& rng) { ASSERT_OK_AND_ASSIGN(joins[i], MakeExecNode("hashjoin", plan.get(), inputs, opts[i])); } - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK( MakeExecNode("sink", plan.get(), {joins.back()}, SinkNodeOptions{&sink_gen})); ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index e0172bff7f7..c5edc0610c5 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -29,13 +30,12 @@ #include "arrow/result.h" #include "arrow/util/async_generator.h" #include "arrow/util/async_util.h" -#include "arrow/util/optional.h" #include "arrow/util/visibility.h" namespace arrow { namespace compute { -using AsyncExecBatchGenerator = AsyncGenerator>; +using AsyncExecBatchGenerator = AsyncGenerator>; /// \addtogroup execnode-options /// @{ @@ -51,14 +51,14 @@ class ARROW_EXPORT ExecNodeOptions { class ARROW_EXPORT SourceNodeOptions : public ExecNodeOptions { public: SourceNodeOptions(std::shared_ptr output_schema, - std::function>()> generator) + std::function>()> generator) : output_schema(std::move(output_schema)), generator(std::move(generator)) {} static Result> FromTable(const Table& table, arrow::internal::Executor*); std::shared_ptr output_schema; - std::function>()> generator; + std::function>()> generator; }; /// \brief An extended Source node which accepts a table @@ -166,7 +166,7 @@ struct ARROW_EXPORT BackpressureOptions { /// Emitted batches will not be ordered. class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions { public: - explicit SinkNodeOptions(std::function>()>* generator, + explicit SinkNodeOptions(std::function>()>* generator, BackpressureOptions backpressure = {}, BackpressureMonitor** backpressure_monitor = NULLPTR) : generator(generator), @@ -178,7 +178,7 @@ class ARROW_EXPORT SinkNodeOptions : public ExecNodeOptions { /// This will be set when the node is added to the plan and should be used to consume /// data from the plan. If this function is not called frequently enough then the sink /// node will start to accumulate data and may apply backpressure. - std::function>()>* generator; + std::function>()>* generator; /// \brief Options to control when to apply backpressure /// /// This is optional, the default is to never apply backpressure. 
@@ -250,7 +250,7 @@ class ARROW_EXPORT OrderBySinkNodeOptions : public SinkNodeOptions {
  public:
   explicit OrderBySinkNodeOptions(
       SortOptions sort_options,
-      std::function<Future<util::optional<ExecBatch>>()>* generator)
+      std::function<Future<std::optional<ExecBatch>>()>* generator)
       : SinkNodeOptions(generator), sort_options(std::move(sort_options)) {}
 
   SortOptions sort_options;
@@ -427,7 +427,7 @@ class ARROW_EXPORT SelectKSinkNodeOptions : public SinkNodeOptions {
  public:
   explicit SelectKSinkNodeOptions(
       SelectKOptions select_k_options,
-      std::function<Future<util::optional<ExecBatch>>()>* generator)
+      std::function<Future<std::optional<ExecBatch>>()>* generator)
       : SinkNodeOptions(generator), select_k_options(std::move(select_k_options)) {}
 
   /// SelectK options
diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index e06c41c7489..1dd071975ee 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -220,7 +220,7 @@ TEST(ExecPlanExecution, SourceSink) {
       SCOPED_TRACE(parallel ? "parallel" : "single threaded");
       ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-      AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+      AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
       auto basic_data = MakeBasicBatches();
 
@@ -239,7 +239,7 @@ TEST(ExecPlanExecution, SourceSink) {
 }
 
 TEST(ExecPlanExecution, UseSinkAfterExecution) {
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
   {
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
     auto basic_data = MakeBasicBatches();
@@ -260,7 +260,7 @@ TEST(ExecPlanExecution, UseSinkAfterExecution) {
 TEST(ExecPlanExecution, TableSourceSink) {
   for (int batch_size : {1, 4}) {
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     auto exp_batches = MakeBasicBatches();
     ASSERT_OK_AND_ASSIGN(auto table,
@@ -281,7 +281,7 @@ TEST(ExecPlanExecution, TableSourceSink) {
 
 TEST(ExecPlanExecution, TableSourceSinkError) {
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   auto exp_batches = MakeBasicBatches();
   ASSERT_OK_AND_ASSIGN(auto table,
@@ -297,7 +297,7 @@ TEST(ExecPlanExecution, TableSourceSinkError) {
 }
 
 TEST(ExecPlanExecution, SinkNodeBackpressure) {
-  util::optional<ExecBatch> batch =
+  std::optional<ExecBatch> batch =
       ExecBatchFromJSON({int32(), boolean()},
                         "[[4, false], [5, null], [6, false], [7, false], [null, true]]");
   constexpr uint32_t kPauseIfAbove = 4;
@@ -307,8 +307,8 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) {
   uint32_t resume_if_below_bytes =
       kResumeIfBelow * static_cast<uint32_t>(batch->TotalBufferSize());
   EXPECT_OK_AND_ASSIGN(std::shared_ptr<ExecPlan> plan, ExecPlan::Make());
-  PushGenerator<util::optional<ExecBatch>> batch_producer;
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  PushGenerator<std::optional<ExecBatch>> batch_producer;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
   BackpressureMonitor* backpressure_monitor;
   BackpressureOptions backpressure_options(resume_if_below_bytes, pause_if_above_bytes);
   std::shared_ptr<Schema> schema_ = schema({field("data", uint32())});
@@ -349,14 +349,14 @@ TEST(ExecPlanExecution, SinkNodeBackpressure) {
   ASSERT_FALSE(backpressure_monitor->is_paused());
 
   // Cleanup
-  batch_producer.producer().Push(IterationEnd<util::optional<ExecBatch>>());
+  batch_producer.producer().Push(IterationEnd<std::optional<ExecBatch>>());
   plan->StopProducing();
   ASSERT_FINISHES_OK(plan->finished());
 }
 
 TEST(ExecPlan, ToString) {
   auto basic_data = MakeBasicBatches();
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
   ASSERT_OK(Declaration::Sequence(
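The backpressure test above also shows the producer side of the migrated type. Feeding a plan by hand with a push generator follows this pattern (a compressed sketch of the test's setup; FeedPlan is our name and batch stands for any ExecBatch built elsewhere):

#include <optional>

#include "arrow/compute/exec.h"
#include "arrow/util/async_generator.h"

namespace cp = arrow::compute;

void FeedPlan(cp::ExecBatch batch) {
  arrow::PushGenerator<std::optional<cp::ExecBatch>> batch_producer;
  auto producer = batch_producer.producer();
  // Enqueue data wrapped in an engaged optional...
  producer.Push(std::make_optional(std::move(batch)));
  // ...then terminate the stream with the end-of-iteration token.
  producer.Push(arrow::IterationEnd<std::optional<cp::ExecBatch>>());
}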
"parallel" : "single threaded"); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto basic_data = MakeBasicBatches(); @@ -483,16 +483,16 @@ TEST(ExecPlanExecution, SourceOrderBy) { TEST(ExecPlanExecution, SourceSinkError) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto basic_data = MakeBasicBatches(); auto it = basic_data.batches.begin(); - AsyncGenerator> error_source_gen = - [&]() -> Result> { + AsyncGenerator> error_source_gen = + [&]() -> Result> { if (it == basic_data.batches.end()) { return Status::Invalid("Artificial error"); } - return util::make_optional(*it++); + return std::make_optional(*it++); }; ASSERT_OK(Declaration::Sequence( @@ -693,7 +693,7 @@ TEST(ExecPlanExecution, StressSourceSink) { int num_batches = (slow && !parallel) ? 30 : 300; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto random_data = MakeRandomBatches( schema({field("a", int32()), field("b", boolean())}), num_batches); @@ -723,7 +723,7 @@ TEST(ExecPlanExecution, StressSourceOrderBy) { int num_batches = (slow && !parallel) ? 30 : 300; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto random_data = MakeRandomBatches(input_schema, num_batches); @@ -760,7 +760,7 @@ TEST(ExecPlanExecution, StressSourceGroupedSumStop) { int num_batches = (slow && !parallel) ? 30 : 300; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto random_data = MakeRandomBatches(input_schema, num_batches); @@ -795,7 +795,7 @@ TEST(ExecPlanExecution, StressSourceSinkStopped) { int num_batches = (slow && !parallel) ? 30 : 300; ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto random_data = MakeRandomBatches( schema({field("a", int32()), field("b", boolean())}), num_batches); @@ -823,7 +823,7 @@ TEST(ExecPlanExecution, SourceFilterSink) { auto basic_data = MakeBasicBatches(); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence( { @@ -845,7 +845,7 @@ TEST(ExecPlanExecution, SourceProjectSink) { auto basic_data = MakeBasicBatches(); ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK(Declaration::Sequence( { @@ -911,7 +911,7 @@ TEST(ExecPlanExecution, SourceGroupedSum) { auto input = MakeGroupableBatches(/*multiplicity=*/parallel ? 
@@ -945,7 +945,7 @@ TEST(ExecPlanExecution, SourceMinMaxScalar) {
                          R"({"min": -8, "max": 12})")});
 
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   // NOTE: Test `ScalarAggregateNode` by omitting `keys` attribute
   ASSERT_OK(Declaration::Sequence(
@@ -976,7 +976,7 @@ TEST(ExecPlanExecution, NestedSourceFilter) {
   ])");
 
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   ASSERT_OK(Declaration::Sequence(
                 {
@@ -1005,7 +1005,7 @@ TEST(ExecPlanExecution, NestedSourceProjectGroupedSum) {
   ])");
 
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   ASSERT_OK(
       Declaration::Sequence(
@@ -1037,7 +1037,7 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumFilter) {
     auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
 
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     ASSERT_OK(
         Declaration::Sequence(
@@ -1076,7 +1076,7 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumOrderBy) {
     auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
 
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     SortOptions options({SortKey("str", SortOrder::Descending)});
     ASSERT_OK(
@@ -1116,7 +1116,7 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumTopK) {
     auto input = MakeGroupableBatches(/*multiplicity=*/batch_multiplicity);
 
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     SelectKOptions options = SelectKOptions::TopKDefault(/*k=*/1, {"str"});
     ASSERT_OK(Declaration::Sequence(
@@ -1145,7 +1145,7 @@ TEST(ExecPlanExecution, SourceFilterProjectGroupedSumTopK) {
 
 TEST(ExecPlanExecution, SourceScalarAggSink) {
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   auto basic_data = MakeBasicBatches();
 
@@ -1175,7 +1175,7 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) {
   // and need to keep a copy/strong reference to function options
   {
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     auto basic_data = MakeBasicBatches();
 
@@ -1202,7 +1202,7 @@ TEST(ExecPlanExecution, AggregationPreservesOptions) {
   }
   {
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     auto data = MakeGroupableBatches(/*multiplicity=*/100);
 
@@ -1234,7 +1234,7 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
   // ARROW-9056: scalar aggregation can be done over scalars, taking
   // into account batch.length > 1 (e.g. a partition column)
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   BatchesWithSchema scalar_data;
   scalar_data.batches = {
@@ -1280,7 +1280,7 @@ TEST(ExecPlanExecution, ScalarSourceScalarAggSink) {
 TEST(ExecPlanExecution, ScalarSourceGroupedSum) {
   // ARROW-14630: ensure grouped aggregation with a scalar key/array input doesn't error
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   BatchesWithSchema scalar_data;
   scalar_data.batches = {
@@ -1321,7 +1321,7 @@ TEST(ExecPlanExecution, SelfInnerHashJoinSink) {
         default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     ExecNode* left_source;
     ExecNode* right_source;
@@ -1378,7 +1378,7 @@ TEST(ExecPlanExecution, SelfOuterHashJoinSink) {
         default_memory_pool(), parallel ? arrow::internal::GetCpuThreadPool() : nullptr);
     ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(exec_ctx.get()));
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     ExecNode* left_source;
     ExecNode* right_source;
@@ -1428,7 +1428,7 @@ TEST(ExecPlanExecution, SelfOuterHashJoinSink) {
 
 TEST(ExecPlan, RecordBatchReaderSourceSink) {
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   // set up a RecordBatchReader:
   auto input = MakeBasicBatches();
@@ -1464,7 +1464,7 @@ TEST(ExecPlan, SourceEnforcesBatchLimit) {
       schema({field("a", int32()), field("b", boolean())}), /*num_batches=*/3,
       /*batch_size=*/static_cast<int32_t>(std::floor(ExecPlan::kMaxBatchSize * 3.5)));
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   ASSERT_OK(Declaration::Sequence(
                 {
diff --git a/cpp/src/arrow/compute/exec/sink_node.cc b/cpp/src/arrow/compute/exec/sink_node.cc
index 8af4e8e996c..96a34bff437 100644
--- a/cpp/src/arrow/compute/exec/sink_node.cc
+++ b/cpp/src/arrow/compute/exec/sink_node.cc
@@ -17,6 +17,7 @@
 // under the License.
 
 #include <mutex>
+#include <optional>
 
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/exec.h"
@@ -34,7 +35,6 @@
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/future.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/tracing_internal.h"
 #include "arrow/util/unreachable.h"
@@ -89,7 +89,7 @@ class BackpressureReservoir : public BackpressureMonitor {
 class SinkNode : public ExecNode {
  public:
   SinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
-           AsyncGenerator<util::optional<ExecBatch>>* generator,
+           AsyncGenerator<std::optional<ExecBatch>>* generator,
            BackpressureOptions backpressure,
            BackpressureMonitor** backpressure_monitor_out)
       : ExecNode(plan, std::move(inputs), {"collected"}, {},
@@ -102,12 +102,12 @@ class SinkNode : public ExecNode {
       *backpressure_monitor_out = &backpressure_queue_;
     }
     auto node_destroyed_capture = node_destroyed_;
-    *generator = [this, node_destroyed_capture]() -> Future<util::optional<ExecBatch>> {
+    *generator = [this, node_destroyed_capture]() -> Future<std::optional<ExecBatch>> {
       if (*node_destroyed_capture) {
        return Status::Invalid(
             "Attempt to consume data after the plan has been destroyed");
       }
-      return push_gen_().Then([this](const util::optional<ExecBatch>& batch) {
+      return push_gen_().Then([this](const std::optional<ExecBatch>& batch) {
         if (batch) {
           RecordBackpressureBytesFreed(*batch);
         }
@@ -247,8 +247,8 @@ class SinkNode : public ExecNode {
 
   // Needs to be a shared_ptr as the push generator can technically outlive the node
   BackpressureReservoir backpressure_queue_;
-  PushGenerator<util::optional<ExecBatch>> push_gen_;
-  PushGenerator<util::optional<ExecBatch>>::Producer producer_;
+  PushGenerator<std::optional<ExecBatch>> push_gen_;
+  PushGenerator<std::optional<ExecBatch>>::Producer producer_;
   std::shared_ptr<bool> node_destroyed_;
 };
 
@@ -404,7 +404,7 @@ static Result<ExecNode*> MakeTableConsumingSinkNode(
 struct OrderBySinkNode final : public SinkNode {
   OrderBySinkNode(ExecPlan* plan, std::vector<ExecNode*> inputs,
                   std::unique_ptr<OrderByImpl> impl,
-                  AsyncGenerator<util::optional<ExecBatch>>* generator)
+                  AsyncGenerator<std::optional<ExecBatch>>* generator)
       : SinkNode(plan, std::move(inputs), generator, /*backpressure=*/{},
                  /*backpressure_monitor_out=*/nullptr),
         impl_(std::move(impl)) {}
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index a640cf737ef..1d51a5c1d28 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include <mutex>
+#include <optional>
 
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/exec_plan.h"
@@ -31,7 +32,6 @@
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/future.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/tracing_internal.h"
 #include "arrow/util/unreachable.h"
@@ -47,7 +47,7 @@ namespace {
 
 struct SourceNode : ExecNode {
   SourceNode(ExecPlan* plan, std::shared_ptr<Schema> output_schema,
-             AsyncGenerator<util::optional<ExecBatch>> generator)
+             AsyncGenerator<std::optional<ExecBatch>> generator)
       : ExecNode(plan, {}, {}, std::move(output_schema),
                  /*num_outputs=*/1),
         generator_(std::move(generator)) {}
@@ -112,7 +112,7 @@ struct SourceNode : ExecNode {
         lock.unlock();
 
         return generator_().Then(
-            [=](const util::optional<ExecBatch>& maybe_morsel)
+            [=](const std::optional<ExecBatch>& maybe_morsel)
                 -> Future<ControlFlow<int>> {
               std::unique_lock<std::mutex> lock(mutex_);
               if (IsIterationEnd(maybe_morsel) || stop_requested_) {
@@ -221,7 +221,7 @@ struct SourceNode : ExecNode {
   bool stop_requested_{false};
   bool started_ = false;
   int batch_count_{0};
-  AsyncGenerator<util::optional<ExecBatch>> generator_;
+  AsyncGenerator<std::optional<ExecBatch>> generator_;
 };
 
 struct TableSourceNode : public SourceNode {
@@ -257,13 +257,13 @@ struct TableSourceNode : public SourceNode {
     return Status::OK();
   }
 
-  static arrow::AsyncGenerator<util::optional<ExecBatch>> TableGenerator(
+  static arrow::AsyncGenerator<std::optional<ExecBatch>> TableGenerator(
      const Table& table, const int64_t batch_size) {
     auto batches = ConvertTableToExecBatches(table, batch_size);
     auto opt_batches =
-        MapVector([](ExecBatch batch) { return util::make_optional(std::move(batch)); },
+        MapVector([](ExecBatch batch) { return std::make_optional(std::move(batch)); },
                   std::move(batches));
-    AsyncGenerator<util::optional<ExecBatch>> gen;
+    AsyncGenerator<std::optional<ExecBatch>> gen;
     gen = MakeVectorGenerator(std::move(opt_batches));
     return gen;
   }
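The TableGenerator pattern above generalizes to any batch vector. A standalone sketch, using only the utilities that appear in this diff (MakeBatchGenerator is our name for it):

#include <optional>
#include <utility>
#include <vector>

#include "arrow/compute/exec.h"
#include "arrow/util/async_generator.h"
#include "arrow/util/vector.h"

namespace cp = arrow::compute;

// Wrap each batch in an engaged optional so the resulting generator has a
// natural end token, then expose the vector as an async generator.
arrow::AsyncGenerator<std::optional<cp::ExecBatch>> MakeBatchGenerator(
    std::vector<cp::ExecBatch> batches) {
  auto opt_batches = arrow::internal::MapVector(
      [](cp::ExecBatch batch) { return std::make_optional(std::move(batch)); },
      std::move(batches));
  return arrow::MakeVectorGenerator(std::move(opt_batches));
}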
diff --git a/cpp/src/arrow/compute/exec/subtree_internal.h b/cpp/src/arrow/compute/exec/subtree_internal.h
index 72d419df225..9e55af6068f 100644
--- a/cpp/src/arrow/compute/exec/subtree_internal.h
+++ b/cpp/src/arrow/compute/exec/subtree_internal.h
@@ -18,13 +18,13 @@
 #pragma once
 
 #include <algorithm>
+#include <optional>
 #include <unordered_map>
 #include <unordered_set>
 #include <utility>
 #include <vector>
 
 #include "arrow/compute/exec/expression.h"
-#include "arrow/util/optional.h"
 
 namespace arrow {
 namespace compute {
@@ -64,7 +64,7 @@ struct SubtreeImpl {
   struct Encoded {
     // An external index identifying the corresponding object (e.g. a Fragment) of the
     // guarantee.
-    util::optional<int> index;
+    std::optional<int> index;
 
     // An encoded expression representing a guarantee.
     expression_codes guarantee;
   };
@@ -112,7 +112,7 @@ struct SubtreeImpl {
   void GenerateSubtrees(expression_codes guarantee, std::vector<Encoded>* encoded) {
     while (!guarantee.empty()) {
      if (subtree_exprs_.insert(guarantee).second) {
-        Encoded encoded_subtree{/*index=*/util::nullopt, guarantee};
+        Encoded encoded_subtree{/*index=*/std::nullopt, guarantee};
         encoded->push_back(std::move(encoded_subtree));
       }
       guarantee.resize(guarantee.size() - 1);
diff --git a/cpp/src/arrow/compute/exec/subtree_test.cc b/cpp/src/arrow/compute/exec/subtree_test.cc
index 97213104454..9e6e86dbd4f 100644
--- a/cpp/src/arrow/compute/exec/subtree_test.cc
+++ b/cpp/src/arrow/compute/exec/subtree_test.cc
@@ -327,9 +327,9 @@ TEST(Subtree, GetSubtreeExpression) {
   const auto code_a = tree.GetOrInsert(expr_a);
   const auto code_b = tree.GetOrInsert(expr_b);
   ASSERT_EQ(expr_a,
-            tree.GetSubtreeExpression(SubtreeImpl::Encoded{util::nullopt, {code_a}}));
+            tree.GetSubtreeExpression(SubtreeImpl::Encoded{std::nullopt, {code_a}}));
   ASSERT_EQ(expr_b, tree.GetSubtreeExpression(
-                        SubtreeImpl::Encoded{util::nullopt, {code_a, code_b}}));
+                        SubtreeImpl::Encoded{std::nullopt, {code_a, code_b}}));
 }
 
 class FakeFragment {
@@ -363,14 +363,14 @@ TEST(Subtree, EncodeFragments) {
   EXPECT_THAT(
       encoded,
       testing::UnorderedElementsAreArray({
-          SubtreeImpl::Encoded{util::make_optional<int>(0),
+          SubtreeImpl::Encoded{std::make_optional<int>(0),
                                SubtreeImpl::expression_codes({0, 1})},
-          SubtreeImpl::Encoded{util::make_optional<int>(1),
+          SubtreeImpl::Encoded{std::make_optional<int>(1),
                                SubtreeImpl::expression_codes({2, 3})},
-          SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({0})},
-          SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({2})},
-          SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({0, 1})},
-          SubtreeImpl::Encoded{util::nullopt, SubtreeImpl::expression_codes({2, 3})},
+          SubtreeImpl::Encoded{std::nullopt, SubtreeImpl::expression_codes({0})},
+          SubtreeImpl::Encoded{std::nullopt, SubtreeImpl::expression_codes({2})},
+          SubtreeImpl::Encoded{std::nullopt, SubtreeImpl::expression_codes({0, 1})},
+          SubtreeImpl::Encoded{std::nullopt, SubtreeImpl::expression_codes({2, 3})},
       }));
 }
 }  // namespace compute
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index cc26143179a..8c8c3f6b3b2 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -25,6 +25,7 @@
 #include <limits>
 #include <memory>
 #include <mutex>
+#include <optional>
 #include <string>
 #include <unordered_map>
 #include <vector>
@@ -46,7 +47,6 @@
 #include "arrow/util/async_generator.h"
 #include "arrow/util/iterator.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/unreachable.h"
 #include "arrow/util/vector.h"
 
@@ -180,7 +180,7 @@ Future<> StartAndFinish(ExecPlan* plan) {
 }
 
 Future<std::vector<ExecBatch>> StartAndCollect(
-    ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen) {
+    ExecPlan* plan, AsyncGenerator<std::optional<ExecBatch>> gen) {
   RETURN_NOT_OK(plan->Validate());
   RETURN_NOT_OK(plan->StartProducing());
 
@@ -190,7 +190,7 @@ Future<std::vector<ExecBatch>> StartAndCollect(
       .Then([collected_fut]() -> Result<std::vector<ExecBatch>> {
         ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
         return ::arrow::internal::MapVector(
-            [](util::optional<ExecBatch> batch) { return std::move(*batch); },
+            [](std::optional<ExecBatch> batch) { return std::move(*batch); },
             std::move(collected));
       });
 }
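Every replacement in these files follows the same mechanical mapping. Reduced to an isolated demo (illustrative only, not taken from the patch):

#include <optional>
#include <string>

// util::optional<T>       ->  std::optional<T>
// util::nullopt           ->  std::nullopt
// util::make_optional(v)  ->  std::make_optional(v)
bool Demo() {
  std::optional<std::string> maybe_name = std::nullopt;  // disengaged
  maybe_name = std::make_optional<std::string>("part");  // engaged
  return maybe_name.has_value();                         // true
}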
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index ac9a4ae4ced..5b6e8226b7e 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -60,11 +60,11 @@ struct BatchesWithSchema {
   std::vector<ExecBatch> batches;
   std::shared_ptr<Schema> schema;
 
-  AsyncGenerator<util::optional<ExecBatch>> gen(bool parallel, bool slow) const {
+  AsyncGenerator<std::optional<ExecBatch>> gen(bool parallel, bool slow) const {
     auto opt_batches = ::arrow::internal::MapVector(
-        [](ExecBatch batch) { return util::make_optional(std::move(batch)); }, batches);
+        [](ExecBatch batch) { return std::make_optional(std::move(batch)); }, batches);
 
-    AsyncGenerator<util::optional<ExecBatch>> gen;
+    AsyncGenerator<std::optional<ExecBatch>> gen;
 
     if (parallel) {
       // emulate batches completing initial decode-after-scan on a cpu thread
@@ -81,7 +81,7 @@ struct BatchesWithSchema {
 
     if (slow) {
       gen =
-          MakeMappedGenerator(std::move(gen), [](const util::optional<ExecBatch>& batch) {
+          MakeMappedGenerator(std::move(gen), [](const std::optional<ExecBatch>& batch) {
            SleepABit();
            return batch;
          });
@@ -96,7 +96,7 @@ Future<> StartAndFinish(ExecPlan* plan);
 
 ARROW_TESTING_EXPORT
 Future<std::vector<ExecBatch>> StartAndCollect(
-    ExecPlan* plan, AsyncGenerator<util::optional<ExecBatch>> gen);
+    ExecPlan* plan, AsyncGenerator<std::optional<ExecBatch>> gen);
 
 ARROW_TESTING_EXPORT
 BatchesWithSchema MakeBasicBatches();
diff --git a/cpp/src/arrow/compute/exec/tpch_benchmark.cc b/cpp/src/arrow/compute/exec/tpch_benchmark.cc
index 54ac7cbdbf5..5aad5370b73 100644
--- a/cpp/src/arrow/compute/exec/tpch_benchmark.cc
+++ b/cpp/src/arrow/compute/exec/tpch_benchmark.cc
@@ -28,7 +28,7 @@ namespace arrow {
 namespace compute {
 namespace internal {
 
-std::shared_ptr<ExecPlan> Plan_Q1(AsyncGenerator<util::optional<ExecBatch>>* sink_gen,
+std::shared_ptr<ExecPlan> Plan_Q1(AsyncGenerator<std::optional<ExecBatch>>* sink_gen,
                                   int scale_factor) {
   ExecContext* ctx = default_exec_context();
   *ctx = ExecContext(default_memory_pool(), arrow::internal::GetCpuThreadPool());
@@ -109,7 +109,7 @@ std::shared_ptr<ExecPlan> Plan_Q1(AsyncGenerator<util::optional<ExecBatch>>* sin
 static void BM_Tpch_Q1(benchmark::State& st) {
   for (auto _ : st) {
     st.PauseTiming();
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
     std::shared_ptr<ExecPlan> plan = Plan_Q1(&sink_gen, static_cast<int>(st.range(0)));
     st.ResumeTiming();
     auto fut = StartAndCollect(plan.get(), sink_gen);
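With BatchesWithSchema::gen updated, a test-style pipeline wires source and sink as in this sketch (RunFixture is illustrative; it blocks so the plan outlives its own execution):

#include <optional>
#include <vector>

#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/compute/exec/test_util.h"

namespace cp = arrow::compute;

arrow::Result<std::vector<cp::ExecBatch>> RunFixture(cp::BatchesWithSchema data) {
  ARROW_ASSIGN_OR_RAISE(auto plan, cp::ExecPlan::Make());
  arrow::AsyncGenerator<std::optional<cp::ExecBatch>> sink_gen;
  ARROW_RETURN_NOT_OK(
      cp::Declaration::Sequence(
          {
              {"source", cp::SourceNodeOptions{data.schema,
                                               data.gen(/*parallel=*/false,
                                                        /*slow=*/false)}},
              {"sink", cp::SinkNodeOptions{&sink_gen}},
          })
          .AddToPlan(plan.get()));
  // Block until the plan finishes; this keeps `plan` alive for the duration.
  return cp::StartAndCollect(plan.get(), sink_gen).result();
}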
diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc b/cpp/src/arrow/compute/exec/tpch_node.cc
index 978a8fb1ff7..40d44dccccf 100644
--- a/cpp/src/arrow/compute/exec/tpch_node.cc
+++ b/cpp/src/arrow/compute/exec/tpch_node.cc
@@ -664,7 +664,7 @@ class PartAndPartSupplierGenerator {
     return SetOutputColumns(cols, kPartsuppTypes, kPartsuppNameMap, partsupp_cols_);
   }
 
-  Result<util::optional<ExecBatch>> NextPartBatch(size_t thread_index) {
+  Result<std::optional<ExecBatch>> NextPartBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     {
       std::lock_guard<std::mutex> lock(part_output_queue_mutex_);
@@ -673,7 +673,7 @@ class PartAndPartSupplierGenerator {
         part_output_queue_.pop();
         return std::move(batch);
       } else if (part_rows_generated_ == part_rows_to_generate_) {
-        return util::nullopt;
+        return std::nullopt;
       } else {
         tld.partkey_start = part_rows_generated_;
         tld.part_to_generate =
@@ -719,7 +719,7 @@ class PartAndPartSupplierGenerator {
     return ExecBatch::Make(std::move(part_result));
   }
 
-  Result<util::optional<ExecBatch>> NextPartSuppBatch(size_t thread_index) {
+  Result<std::optional<ExecBatch>> NextPartSuppBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     {
       std::lock_guard<std::mutex> lock(partsupp_output_queue_mutex_);
@@ -732,7 +732,7 @@ class PartAndPartSupplierGenerator {
     {
       std::lock_guard<std::mutex> lock(part_output_queue_mutex_);
       if (part_rows_generated_ == part_rows_to_generate_) {
-        return util::nullopt;
+        return std::nullopt;
       } else {
         tld.partkey_start = part_rows_generated_;
         tld.part_to_generate =
@@ -1324,7 +1324,7 @@ class OrdersAndLineItemGenerator {
     return SetOutputColumns(cols, kLineitemTypes, kLineitemNameMap, lineitem_cols_);
   }
 
-  Result<util::optional<ExecBatch>> NextOrdersBatch(size_t thread_index) {
+  Result<std::optional<ExecBatch>> NextOrdersBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     {
       std::lock_guard<std::mutex> lock(orders_output_queue_mutex_);
@@ -1333,7 +1333,7 @@ class OrdersAndLineItemGenerator {
         orders_output_queue_.pop();
         return std::move(batch);
       } else if (orders_rows_generated_ == orders_rows_to_generate_) {
-        return util::nullopt;
+        return std::nullopt;
       } else {
         tld.orderkey_start = orders_rows_generated_;
         tld.orders_to_generate =
@@ -1379,7 +1379,7 @@ class OrdersAndLineItemGenerator {
     return ExecBatch::Make(std::move(orders_result));
   }
 
-  Result<util::optional<ExecBatch>> NextLineItemBatch(size_t thread_index) {
+  Result<std::optional<ExecBatch>> NextLineItemBatch(size_t thread_index) {
     ThreadLocalData& tld = thread_local_data_[thread_index];
     ExecBatch queued;
     bool from_queue = false;
@@ -1401,7 +1401,7 @@ class OrdersAndLineItemGenerator {
       std::lock_guard<std::mutex> lock(orders_output_queue_mutex_);
       if (orders_rows_generated_ == orders_rows_to_generate_) {
         if (from_queue) return std::move(queued);
-        return util::nullopt;
+        return std::nullopt;
       }
 
       tld.orderkey_start = orders_rows_generated_;
@@ -2709,7 +2709,7 @@ class PartGenerator : public TpchTableGenerator {
  private:
   Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
-    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+    ARROW_ASSIGN_OR_RAISE(std::optional<ExecBatch> maybe_batch,
                           gen_->NextPartBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->part_batches_generated();
@@ -2771,7 +2771,7 @@ class PartSuppGenerator : public TpchTableGenerator {
  private:
   Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
-    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+    ARROW_ASSIGN_OR_RAISE(std::optional<ExecBatch> maybe_batch,
                           gen_->NextPartSuppBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->partsupp_batches_generated();
@@ -3090,7 +3090,7 @@ class OrdersGenerator : public TpchTableGenerator {
  private:
   Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
-    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+    ARROW_ASSIGN_OR_RAISE(std::optional<ExecBatch> maybe_batch,
                           gen_->NextOrdersBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->orders_batches_generated();
@@ -3152,7 +3152,7 @@ class LineitemGenerator : public TpchTableGenerator {
  private:
   Status ProduceCallback(size_t thread_index) {
     if (done_.load()) return Status::OK();
-    ARROW_ASSIGN_OR_RAISE(util::optional<ExecBatch> maybe_batch,
+    ARROW_ASSIGN_OR_RAISE(std::optional<ExecBatch> maybe_batch,
                           gen_->NextLineItemBatch(thread_index));
     if (!maybe_batch.has_value()) {
       int64_t batches_generated = gen_->lineitem_batches_generated();
@@ -3541,7 +3541,7 @@ Result<ExecNode*> TpchGenImpl::Region(std::vector<std::string> columns) {
 
 Result<std::unique_ptr<TpchGen>> TpchGen::Make(ExecPlan* plan, double scale_factor,
                                                int64_t batch_size,
-                                               util::optional<int64_t> seed) {
+                                               std::optional<int64_t> seed) {
   if (!seed.has_value()) seed = GetRandomSeed();
   return std::unique_ptr<TpchGen>(new TpchGenImpl(plan, scale_factor, batch_size, *seed));
 }
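The TPC-H generators above all share one shape: drain a queue when possible, return std::nullopt once generation is complete. Reduced to its skeleton (illustrative class and names, not the actual generator; GenerateMore is an assumed refill step defined elsewhere):

#include <mutex>
#include <optional>
#include <queue>

#include "arrow/compute/exec.h"
#include "arrow/result.h"

namespace cp = arrow::compute;

class BatchQueue {
 public:
  // Returns a queued batch, or std::nullopt once production has finished.
  arrow::Result<std::optional<cp::ExecBatch>> Next() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (!queue_.empty()) {
      cp::ExecBatch batch = std::move(queue_.front());
      queue_.pop();
      return std::move(batch);  // implicitly wraps into an engaged optional
    }
    if (done_) return std::nullopt;  // nothing left to generate
    return GenerateMore();           // hypothetical refill step
  }

 private:
  arrow::Result<std::optional<cp::ExecBatch>> GenerateMore();  // assumed elsewhere

  std::mutex mutex_;
  std::queue<cp::ExecBatch> queue_;
  bool done_ = false;
};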
diff --git a/cpp/src/arrow/compute/exec/tpch_node.h b/cpp/src/arrow/compute/exec/tpch_node.h
index fb9376982b1..061b66ca436 100644
--- a/cpp/src/arrow/compute/exec/tpch_node.h
+++ b/cpp/src/arrow/compute/exec/tpch_node.h
@@ -18,13 +18,13 @@
 #pragma once
 
 #include <cstdint>
+#include <optional>
 #include <memory>
 #include <vector>
 
 #include "arrow/compute/type_fwd.h"
 #include "arrow/result.h"
 #include "arrow/status.h"
-#include "arrow/util/optional.h"
 
 namespace arrow {
 namespace compute {
@@ -44,7 +44,7 @@ class ARROW_EXPORT TpchGen {
    */
   static Result<std::unique_ptr<TpchGen>> Make(
       ExecPlan* plan, double scale_factor = 1.0, int64_t batch_size = 4096,
-      util::optional<int64_t> seed = util::nullopt);
+      std::optional<int64_t> seed = std::nullopt);
 
   // The below methods will create and add an ExecNode to the plan that generates
   // data for the desired table.  If columns is empty, all columns will be generated.
diff --git a/cpp/src/arrow/compute/exec/tpch_node_test.cc b/cpp/src/arrow/compute/exec/tpch_node_test.cc
index fc26ce90c2e..133dbfdf43c 100644
--- a/cpp/src/arrow/compute/exec/tpch_node_test.cc
+++ b/cpp/src/arrow/compute/exec/tpch_node_test.cc
@@ -50,7 +50,7 @@ using TableNodeFn = Result<ExecNode*> (TpchGen::*)(std::vector<std::string>);
 constexpr double kDefaultScaleFactor = 0.1;
 
 Status AddTableAndSinkToPlan(ExecPlan& plan, TpchGen& gen,
-                             AsyncGenerator<util::optional<ExecBatch>>& sink_gen,
+                             AsyncGenerator<std::optional<ExecBatch>>& sink_gen,
                              TableNodeFn table) {
   ARROW_ASSIGN_OR_RAISE(ExecNode * table_node, ((gen.*table)({})));
   Declaration sink("sink", {Declaration::Input(table_node)}, SinkNodeOptions{&sink_gen});
@@ -64,7 +64,7 @@ Result<std::vector<ExecBatch>> GenerateTable(TableNodeFn table,
   ARROW_ASSIGN_OR_RAISE(std::shared_ptr<ExecPlan> plan, ExecPlan::Make(&ctx));
   ARROW_ASSIGN_OR_RAISE(std::unique_ptr<TpchGen> gen,
                         TpchGen::Make(plan.get(), scale_factor));
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
   ARROW_RETURN_NOT_OK(AddTableAndSinkToPlan(*plan, *gen, sink_gen, table));
 
   auto fut = StartAndCollect(plan.get(), sink_gen);
   return fut.MoveResult();
@@ -618,7 +618,7 @@ TEST(TpchNode, AllTables) {
       &VerifyOrders,   &VerifyLineitem, &VerifyNation, &VerifyRegion,
   };
 
-  std::array<AsyncGenerator<util::optional<ExecBatch>>, kNumTables> gens;
+  std::array<AsyncGenerator<std::optional<ExecBatch>>, kNumTables> gens;
   ExecContext ctx(default_memory_pool(), arrow::internal::GetCpuThreadPool());
   ASSERT_OK_AND_ASSIGN(std::shared_ptr<ExecPlan> plan, ExecPlan::Make(&ctx));
   ASSERT_OK_AND_ASSIGN(std::unique_ptr<TpchGen> gen,
diff --git a/cpp/src/arrow/compute/exec/union_node_test.cc b/cpp/src/arrow/compute/exec/union_node_test.cc
index 41aaac26d2b..d14bfe16e5f 100644
--- a/cpp/src/arrow/compute/exec/union_node_test.cc
+++ b/cpp/src/arrow/compute/exec/union_node_test.cc
@@ -90,7 +90,7 @@ struct TestUnionNode : public ::testing::Test {
           "source", SourceNodeOptions{batch.schema, batch.gen(parallel, /*slow=*/false)}});
     }
 
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+    AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
     // Test UnionNode::Make with zero inputs
     if (batches.size() == 0) {
diff --git a/cpp/src/arrow/compute/exec/util.h b/cpp/src/arrow/compute/exec/util.h
index 7e716808fa0..e1797771fe0 100644
--- a/cpp/src/arrow/compute/exec/util.h
+++ b/cpp/src/arrow/compute/exec/util.h
@@ -19,6 +19,7 @@
 
 #include <atomic>
 #include <cstdint>
+#include <optional>
 #include <thread>
 #include <unordered_map>
 #include <vector>
 
@@ -33,7 +34,6 @@
 #include "arrow/util/cpu_info.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/mutex.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/thread_pool.h"
 
 #if defined(__clang__) || defined(__GNUC__)
@@ -246,7 +246,7 @@ class ARROW_EXPORT AtomicCounter {
 
   int count() const { return count_.load(); }
 
-  util::optional<int> total() const {
+  std::optional<int> total() const {
     int total = total_.load();
     if (total == -1) return {};
     return total;
diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h
index a6ede14176c..a20b4ce1476 100644
--- a/cpp/src/arrow/compute/kernels/codegen_internal.h
+++ b/cpp/src/arrow/compute/kernels/codegen_internal.h
@@ -20,6 +20,7 @@
 #include <cstring>
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
 
@@ -46,7 +47,6 @@
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/make_unique.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/string_view.h"
 #include "arrow/visit_data_inline.h"
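On the caller side of the new TpchGen::Make signature, the seed argument now reads as follows (hypothetical calls mirroring the default argument above, and assuming the arrow::compute::internal namespace used by the benchmark file):

#include <cstdint>
#include <optional>

#include "arrow/compute/exec/tpch_node.h"

namespace cpi = arrow::compute::internal;

arrow::Status MakeGenerators(arrow::compute::ExecPlan* plan) {
  // Explicit seed for reproducible data...
  ARROW_ASSIGN_OR_RAISE(auto seeded,
                        cpi::TpchGen::Make(plan, /*scale_factor=*/1.0,
                                           /*batch_size=*/4096,
                                           std::make_optional<int64_t>(42)));
  // ...or std::nullopt (the default) to let the generator pick a random seed.
  ARROW_ASSIGN_OR_RAISE(auto random,
                        cpi::TpchGen::Make(plan, 1.0, 4096, std::nullopt));
  return arrow::Status::OK();
}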
"arrow/visit_data_inline.h" diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 4537c32eb38..068fcab95e4 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -1435,7 +1435,7 @@ struct GroupedMinMaxImpl enable_if_base_binary MakeOffsetsValues( - ArrayData* array, const std::vector>& values) { + ArrayData* array, const std::vector>& values) { using offset_type = typename T::offset_type; ARROW_ASSIGN_OR_RAISE( auto raw_offsets, @@ -1447,7 +1447,7 @@ struct GroupedMinMaxImpl& value = values[i]; + const std::optional& value = values[i]; DCHECK(value.has_value()); if (value->size() > static_cast(std::numeric_limits::max()) || @@ -1463,7 +1463,7 @@ struct GroupedMinMaxImpl& value = values[i]; + const std::optional& value = values[i]; DCHECK(value.has_value()); std::memcpy(data->mutable_data() + offset, value->data(), value->size()); offset += value->size(); @@ -1476,7 +1476,7 @@ struct GroupedMinMaxImpl enable_if_same MakeOffsetsValues( - ArrayData* array, const std::vector>& values) { + ArrayData* array, const std::vector>& values) { const uint8_t* null_bitmap = array->buffers[0]->data(); const int32_t slot_width = checked_cast(*array->type).byte_width(); @@ -1485,7 +1485,7 @@ struct GroupedMinMaxImpl& value = values[i]; + const std::optional& value = values[i]; DCHECK(value.has_value()); std::memcpy(data->mutable_data() + offset, value->data(), slot_width); } else { @@ -1504,7 +1504,7 @@ struct GroupedMinMaxImpl> mins_, maxes_; + std::vector> mins_, maxes_; TypedBufferBuilder has_values_, has_nulls_; std::shared_ptr type_; ScalarAggregateOptions options_; @@ -2128,7 +2128,7 @@ struct GroupedOneImpl::value || template enable_if_base_binary MakeOffsetsValues( - ArrayData* array, const std::vector>& values) { + ArrayData* array, const std::vector>& values) { using offset_type = typename T::offset_type; ARROW_ASSIGN_OR_RAISE( auto raw_offsets, @@ -2140,7 +2140,7 @@ struct GroupedOneImpl::value || offset_type total_length = 0; for (size_t i = 0; i < values.size(); i++) { if (bit_util::GetBit(null_bitmap, i)) { - const util::optional& value = values[i]; + const std::optional& value = values[i]; DCHECK(value.has_value()); if (value->size() > static_cast(std::numeric_limits::max()) || @@ -2156,7 +2156,7 @@ struct GroupedOneImpl::value || int64_t offset = 0; for (size_t i = 0; i < values.size(); i++) { if (bit_util::GetBit(null_bitmap, i)) { - const util::optional& value = values[i]; + const std::optional& value = values[i]; DCHECK(value.has_value()); std::memcpy(data->mutable_data() + offset, value->data(), value->size()); offset += value->size(); @@ -2169,7 +2169,7 @@ struct GroupedOneImpl::value || template enable_if_same MakeOffsetsValues( - ArrayData* array, const std::vector>& values) { + ArrayData* array, const std::vector>& values) { const uint8_t* null_bitmap = array->buffers[0]->data(); const int32_t slot_width = checked_cast(*array->type).byte_width(); @@ -2178,7 +2178,7 @@ struct GroupedOneImpl::value || int64_t offset = 0; for (size_t i = 0; i < values.size(); i++) { if (bit_util::GetBit(null_bitmap, i)) { - const util::optional& value = values[i]; + const std::optional& value = values[i]; DCHECK(value.has_value()); std::memcpy(data->mutable_data() + offset, value->data(), slot_width); } else { @@ -2195,7 +2195,7 @@ struct GroupedOneImpl::value || ExecContext* ctx_; Allocator allocator_; int64_t num_groups_; - std::vector> ones_; + std::vector> ones_; 
   TypedBufferBuilder<bool> has_one_;
   std::shared_ptr<DataType> out_type_;
 };
@@ -2467,7 +2467,7 @@ struct GroupedListImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
   template <typename T>
   enable_if_base_binary<T, Status> MakeOffsetsValues(
-      ArrayData* array, const std::vector<util::optional<std::string>>& values) {
+      ArrayData* array, const std::vector<std::optional<std::string>>& values) {
     using offset_type = typename T::offset_type;
     ARROW_ASSIGN_OR_RAISE(
         auto raw_offsets,
@@ -2479,7 +2479,7 @@ struct GroupedListImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
     offset_type total_length = 0;
     for (size_t i = 0; i < values.size(); i++) {
       if (bit_util::GetBit(null_bitmap, i)) {
-        const util::optional<std::string>& value = values[i];
+        const std::optional<std::string>& value = values[i];
         DCHECK(value.has_value());
         if (value->size() >
                 static_cast<size_t>(std::numeric_limits<offset_type>::max()) ||
@@ -2495,7 +2495,7 @@ struct GroupedListImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
     int64_t offset = 0;
     for (size_t i = 0; i < values.size(); i++) {
       if (bit_util::GetBit(null_bitmap, i)) {
-        const util::optional<std::string>& value = values[i];
+        const std::optional<std::string>& value = values[i];
         DCHECK(value.has_value());
         std::memcpy(data->mutable_data() + offset, value->data(), value->size());
         offset += value->size();
@@ -2508,7 +2508,7 @@ struct GroupedListImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
 
   template <typename T>
   enable_if_same<T, FixedSizeBinaryType, Status> MakeOffsetsValues(
-      ArrayData* array, const std::vector<util::optional<std::string>>& values) {
+      ArrayData* array, const std::vector<std::optional<std::string>>& values) {
     const uint8_t* null_bitmap = array->buffers[0]->data();
     const int32_t slot_width =
         checked_cast<const FixedSizeBinaryType&>(*array->type).byte_width();
@@ -2517,7 +2517,7 @@ struct GroupedListImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
     int64_t offset = 0;
     for (size_t i = 0; i < values.size(); i++) {
       if (bit_util::GetBit(null_bitmap, i)) {
-        const util::optional<std::string>& value = values[i];
+        const std::optional<std::string>& value = values[i];
         DCHECK(value.has_value());
         std::memcpy(data->mutable_data() + offset, value->data(), slot_width);
       } else {
@@ -2534,7 +2534,7 @@ struct GroupedListImpl<Type, enable_if_t<is_base_binary_type<Type>::value ||
 
   ExecContext* ctx_;
   Allocator allocator_;
   int64_t num_groups_, num_args_ = 0;
-  std::vector<util::optional<std::string>> values_;
+  std::vector<std::optional<std::string>> values_;
   TypedBufferBuilder<uint32_t> groups_;
   TypedBufferBuilder<bool> values_bitmap_;
   std::shared_ptr<DataType> out_type_;
diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
index f599f9abb60..f4dc74b7c89 100644
--- a/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
+++ b/cpp/src/arrow/compute/kernels/hash_aggregate_test.cc
@@ -131,7 +131,7 @@ Result<Datum> GroupByUsingExecPlan(const BatchesWithSchema& input,
   }
 
   ARROW_ASSIGN_OR_RAISE(auto plan, ExecPlan::Make(ctx));
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
   RETURN_NOT_OK(
       Declaration::Sequence(
          {
@@ -152,7 +152,7 @@ Result<Datum> GroupByUsingExecPlan(const BatchesWithSchema& input,
       .Then([collected_fut]() -> Result<std::vector<ExecBatch>> {
         ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result());
         return ::arrow::internal::MapVector(
-            [](util::optional<ExecBatch> batch) { return std::move(*batch); },
+            [](std::optional<ExecBatch> batch) { return std::move(*batch); },
             std::move(collected));
       });
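The MakeOffsetsValues overloads above repeatedly copy engaged std::optional<std::string> values into binary buffers. The same idea expressed with a builder, for brevity (a sketch; not the kernels' actual bitmap-plus-memcpy allocation strategy):

#include <memory>
#include <optional>
#include <string>
#include <vector>

#include "arrow/array/builder_binary.h"
#include "arrow/result.h"

// Turn per-group optional strings into a binary array; disengaged entries
// become nulls rather than being driven by an external validity bitmap.
arrow::Result<std::shared_ptr<arrow::Array>> ValuesToArray(
    const std::vector<std::optional<std::string>>& values) {
  arrow::BinaryBuilder builder;
  for (const auto& value : values) {
    if (value.has_value()) {
      ARROW_RETURN_NOT_OK(builder.Append(*value));
    } else {
      ARROW_RETURN_NOT_OK(builder.AppendNull());
    }
  }
  return builder.Finish();
}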
diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
index 6b21a532392..7a77b63e37a 100644
--- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include <limits>
+#include <optional>
 
 #include "arrow/array/array_base.h"
 #include "arrow/array/builder_binary.h"
@@ -26,7 +27,6 @@
 #include "arrow/result.h"
 #include "arrow/util/formatting.h"
 #include "arrow/util/int_util.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/utf8_internal.h"
 #include "arrow/visit_data_inline.h"
 
diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc
index cfe10855314..290a0e5df66 100644
--- a/cpp/src/arrow/compute/kernels/scalar_compare.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc
@@ -18,12 +18,12 @@
 #include <cmath>
 #include <limits>
 #include <mutex>
+#include <optional>
 
 #include "arrow/compute/api_scalar.h"
 #include "arrow/compute/kernels/common.h"
 #include "arrow/util/bit_util.h"
 #include "arrow/util/bitmap_ops.h"
-#include "arrow/util/optional.h"
 
 namespace arrow {
 
@@ -640,7 +640,7 @@ struct BinaryScalarMinMax {
     RETURN_NOT_OK(builder.ReserveData(estimated_final_size));
 
     for (int64_t row = 0; row < batch.length; row++) {
-      util::optional<string_view> result;
+      std::optional<string_view> result;
       auto visit_value = [&](string_view value) {
         result = !result ? value : Op::Call(*result, value);
       };
@@ -651,7 +651,7 @@ struct BinaryScalarMinMax {
         if (scalar.is_valid) {
           visit_value(UnboxScalar<T>::Unbox(scalar));
         } else if (!options.skip_nulls) {
-          result = util::nullopt;
+          result = std::nullopt;
          break;
         }
       } else {
@@ -664,7 +664,7 @@ struct BinaryScalarMinMax {
           visit_value(
              string_view(reinterpret_cast<const char*>(data + offsets[row]), length));
         } else if (!options.skip_nulls) {
-          result = util::nullopt;
+          result = std::nullopt;
           break;
         }
       }
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
index 672a8b27977..8c941934a1e 100644
--- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -67,7 +67,7 @@ Status CheckIdenticalTypes(const ExecValue* begin, int count) {
 constexpr uint64_t kAllNull = 0;
 constexpr uint64_t kAllValid = ~kAllNull;
 
-util::optional<uint64_t> GetConstantValidityWord(const ExecValue& data) {
+std::optional<uint64_t> GetConstantValidityWord(const ExecValue& data) {
   if (data.is_scalar()) {
     return data.scalar->is_valid ? kAllValid : kAllNull;
   }
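BinaryScalarMinMax above uses the optional as a fold accumulator: disengaged until the first value, reset to std::nullopt when a null is hit and nulls are not skipped. Standalone (illustrative, using std::string_view in place of Arrow's string_view alias):

#include <algorithm>
#include <optional>
#include <string_view>
#include <vector>

// Fold a row's values into an optional minimum; a null aborts the fold
// unless skip_nulls is set, mirroring the kernel's control flow.
std::optional<std::string_view> RowMin(
    const std::vector<std::optional<std::string_view>>& row, bool skip_nulls) {
  std::optional<std::string_view> result;
  for (const auto& value : row) {
    if (!value.has_value()) {
      if (!skip_nulls) return std::nullopt;
      continue;
    }
    result = !result ? *value : std::min(*result, *value);
  }
  return result;
}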
@@ -91,7 +91,7 @@ struct IfElseNullPromoter {
   enum { COND_CONST = 1, LEFT_CONST = 2, RIGHT_CONST = 4 };
   int64_t constant_validity_flag;
-  util::optional<uint64_t> cond_const, left_const, right_const;
+  std::optional<uint64_t> cond_const, left_const, right_const;
   Bitmap cond_data, cond_valid, left_valid, right_valid;
 
   IfElseNullPromoter(KernelContext* ctx, const ExecValue& cond_d, const ExecValue& left_d,
diff --git a/cpp/src/arrow/compute/kernels/vector_replace_test.cc b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
index 589952ba700..b83b6973313 100644
--- a/cpp/src/arrow/compute/kernels/vector_replace_test.cc
+++ b/cpp/src/arrow/compute/kernels/vector_replace_test.cc
@@ -419,7 +419,7 @@ TYPED_TEST(TestReplaceNumeric, ReplaceWithMaskRandom) {
         rand.ArrayOf(boolean(), length, /*null_probability=*/0.01));
     const int64_t num_replacements = std::count_if(
         mask->begin(), mask->end(),
-        [](util::optional<bool> value) { return value.has_value() && *value; });
+        [](std::optional<bool> value) { return value.has_value() && *value; });
     auto replacements = checked_pointer_cast<ArrayType>(
         rand.ArrayOf(*field("a", ty, options), num_replacements));
     auto expected = this->NaiveImpl(*array, *mask, *replacements);
@@ -1045,7 +1045,7 @@ TYPED_TEST(TestReplaceBinary, ReplaceWithMaskRandom) {
         rand.ArrayOf(boolean(), length, /*null_probability=*/0.01));
     const int64_t num_replacements = std::count_if(
         mask->begin(), mask->end(),
-        [](util::optional<bool> value) { return value.has_value() && *value; });
+        [](std::optional<bool> value) { return value.has_value() && *value; });
     auto replacements = checked_pointer_cast<ArrayType>(
         rand.ArrayOf(*field("a", ty, options), num_replacements));
     auto expected = this->NaiveImpl(*array, *mask, *replacements);
diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc
index 28307ecca37..94c80b9f80d 100644
--- a/cpp/src/arrow/compute/kernels/vector_sort.cc
+++ b/cpp/src/arrow/compute/kernels/vector_sort.cc
@@ -20,6 +20,7 @@
 #include <iterator>
 #include <limits>
 #include <numeric>
+#include <optional>
 #include <queue>
 #include <utility>
 #include <vector>
 
@@ -36,7 +37,6 @@
 #include "arrow/table.h"
 #include "arrow/type_traits.h"
 #include "arrow/util/checked_cast.h"
-#include "arrow/util/optional.h"
 #include "arrow/visit_type_inline.h"
 #include "arrow/visitor.h"
 
diff --git a/cpp/src/arrow/config.cc b/cpp/src/arrow/config.cc
index a93a8feae1d..9e32e543732 100644
--- a/cpp/src/arrow/config.cc
+++ b/cpp/src/arrow/config.cc
@@ -63,7 +63,7 @@ std::string MakeSimdLevelString(QueryFlagFunction&& query_flag) {
   }
 }
 
-util::optional<std::string> timezone_db_path;
+std::optional<std::string> timezone_db_path;
 
 };  // namespace
 
@@ -80,7 +80,7 @@ RuntimeInfo GetRuntimeInfo() {
 #if !USE_OS_TZDB
   info.timezone_db_path = timezone_db_path;
 #else
-  info.timezone_db_path = util::optional<std::string>();
+  info.timezone_db_path = std::optional<std::string>();
 #endif
   return info;
 }
diff --git a/cpp/src/arrow/config.h b/cpp/src/arrow/config.h
index 87e31cc456a..617d6c268b5 100644
--- a/cpp/src/arrow/config.h
+++ b/cpp/src/arrow/config.h
@@ -17,11 +17,11 @@
 
 #pragma once
 
+#include <optional>
 #include <string>
 
 #include "arrow/status.h"
 #include "arrow/util/config.h"  // IWYU pragma: export
-#include "arrow/util/optional.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -70,7 +70,7 @@ struct RuntimeInfo {
   bool using_os_timezone_db;
 
   /// The path to the timezone database; by default None.
-  util::optional<std::string> timezone_db_path;
+  std::optional<std::string> timezone_db_path;
 };
 
 /// \brief Get runtime build info.
@@ -89,7 +89,7 @@ RuntimeInfo GetRuntimeInfo();
 struct GlobalOptions {
   /// Path to text timezone database.  This is only configurable on Windows,
   /// which does not have a compatible OS timezone database.
-  util::optional<std::string> timezone_db_path;
+  std::optional<std::string> timezone_db_path;
 };
 
 ARROW_EXPORT
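Reading the migrated field back out of GetRuntimeInfo then looks like this (a sketch; the output formatting is ours, not part of the patch):

#include <iostream>
#include <optional>
#include <string>

#include "arrow/config.h"

void PrintTzdbPath() {
  arrow::RuntimeInfo info = arrow::GetRuntimeInfo();
  const std::optional<std::string>& path = info.timezone_db_path;
  if (path.has_value()) {
    std::cout << "timezone db: " << *path << "\n";
  } else {
    std::cout << "using the OS timezone database\n";
  }
}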
diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc
index ba754399b75..d770fa734f5 100644
--- a/cpp/src/arrow/csv/reader.cc
+++ b/cpp/src/arrow/csv/reader.cc
@@ -22,6 +22,7 @@
 #include <functional>
 #include <limits>
 #include <memory>
+#include <optional>
 #include <sstream>
 #include <string>
 #include <unordered_map>
@@ -46,7 +47,6 @@
 #include "arrow/util/iterator.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/optional.h"
 #include "arrow/util/task_group.h"
 #include "arrow/util/thread_pool.h"
 #include "arrow/util/utf8_internal.h"
@@ -166,7 +166,7 @@ namespace {
 
 // This is a callable that can be used to transform an iterator.  The source iterator
 // will contain buffers of data and the output iterator will contain delimited CSV
-// blocks.  util::optional is used so that there is an end token (required by the
+// blocks.  std::optional is used so that there is an end token (required by the
 // iterator APIs (e.g. Visit)) even though an empty optional is never used in this code.
 class BlockReader {
  public:
@@ -1212,8 +1212,8 @@ class CSVRowCounter : public ReaderMixin,
     // count_cb must return a value instead of Status/Future<> to work with
     // MakeMappedGenerator, and it must use a type with a valid end value to work with
     // IterationEnd.
-    std::function<Result<util::optional<int64_t>>(const CSVBlock&)> count_cb =
-        [self](const CSVBlock& maybe_block) -> Result<util::optional<int64_t>> {
+    std::function<Result<std::optional<int64_t>>(const CSVBlock&)> count_cb =
+        [self](const CSVBlock& maybe_block) -> Result<std::optional<int64_t>> {
       ARROW_ASSIGN_OR_RAISE(
          auto parser,
          self->Parse(maybe_block.partial, maybe_block.completion, maybe_block.buffer,
diff --git a/cpp/src/arrow/csv/writer_test.cc b/cpp/src/arrow/csv/writer_test.cc
index c7f9433688c..d8f13bdbbe6 100644
--- a/cpp/src/arrow/csv/writer_test.cc
+++ b/cpp/src/arrow/csv/writer_test.cc
@@ -18,6 +18,7 @@
 #include "gtest/gtest.h"
 
 #include <memory>
+#include <optional>
 #include <string>
 #include <vector>
 
@@ -31,14 +32,13 @@
 #include "arrow/testing/matchers.h"
 #include "arrow/type.h"
 #include "arrow/type_fwd.h"
-#include "arrow/util/optional.h"
 
 namespace arrow {
 namespace csv {
 
 struct WriterTestParams {
   WriterTestParams(std::shared_ptr<Schema> schema, std::string batch_data,
-                   WriteOptions options, util::optional<std::string> expected_output,
+                   WriteOptions options, std::optional<std::string> expected_output,
                    Status expected_status = Status::OK())
       : schema(std::move(schema)),
         batch_data(std::move(batch_data)),
@@ -48,7 +48,7 @@ struct WriterTestParams {
   std::shared_ptr<Schema> schema;
   std::string batch_data;
   WriteOptions options;
-  util::optional<std::string> expected_output;
+  std::optional<std::string> expected_output;
   Status expected_status;
 };
 
diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc
index 1e4c9b7f719..6faaa953bb3 100644
--- a/cpp/src/arrow/dataset/dataset.cc
+++ b/cpp/src/arrow/dataset/dataset.cc
@@ -55,9 +55,9 @@ Result<std::shared_ptr<Schema>> Fragment::ReadPhysicalSchema() {
   return physical_schema_;
 }
 
-Future<util::optional<int64_t>> Fragment::CountRows(compute::Expression,
-                                                    const std::shared_ptr<ScanOptions>&) {
-  return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+Future<std::optional<int64_t>> Fragment::CountRows(compute::Expression,
+                                                   const std::shared_ptr<ScanOptions>&) {
+  return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
 }
 
 Result<std::shared_ptr<Schema>> InMemoryFragment::ReadPhysicalSchemaImpl() {
@@ -129,16 +129,16 @@ Result<RecordBatchGenerator> InMemoryFragment::ScanBatchesAsync(
                                             options->batch_size);
 }
 
-Future<util::optional<int64_t>> InMemoryFragment::CountRows(
+Future<std::optional<int64_t>> InMemoryFragment::CountRows(
     compute::Expression predicate, const std::shared_ptr<ScanOptions>& options) {
   if (ExpressionHasFieldRefs(predicate)) {
-    return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+    return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
   }
   int64_t total = 0;
   for (const auto& batch : record_batches_) {
     total += batch->num_rows();
   }
-  return Future<util::optional<int64_t>>::MakeFinished(total);
+  return Future<std::optional<int64_t>>::MakeFinished(total);
 }
 
 Dataset::Dataset(std::shared_ptr<Schema> schema, compute::Expression partition_expression)
diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h
index 9f4fee52154..62181b60ba4 100644
--- a/cpp/src/arrow/dataset/dataset.h
+++ b/cpp/src/arrow/dataset/dataset.h
@@ -21,6 +21,7 @@
 
 #include <functional>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
 #include <vector>
 
@@ -30,7 +31,6 @@
 #include "arrow/dataset/visibility.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/mutex.h"
-#include "arrow/util/optional.h"
 
 namespace arrow {
 namespace dataset {
@@ -64,7 +64,7 @@ class ARROW_DS_EXPORT Fragment : public std::enable_shared_from_this<Fragment> {
   ///
   /// If this is not possible, resolve with an empty optional.  The fragment can perform
   /// I/O (e.g. to read metadata) before it deciding whether it can satisfy the request.
-  virtual Future<util::optional<int64_t>> CountRows(
+  virtual Future<std::optional<int64_t>> CountRows(
       compute::Expression predicate, const std::shared_ptr<ScanOptions>& options);
 
   virtual std::string type_name() const = 0;
@@ -120,7 +120,7 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment {
   Result<RecordBatchGenerator> ScanBatchesAsync(
       const std::shared_ptr<ScanOptions>& options) override;
-  Future<util::optional<int64_t>> CountRows(
+  Future<std::optional<int64_t>> CountRows(
       compute::Expression predicate,
       const std::shared_ptr<ScanOptions>& options) override;
 
diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc
index 35b6e8129e2..cb155d7b962 100644
--- a/cpp/src/arrow/dataset/dataset_test.cc
+++ b/cpp/src/arrow/dataset/dataset_test.cc
@@ -17,6 +17,8 @@
 
 #include "arrow/dataset/dataset.h"
 
+#include <optional>
+
 #include "arrow/dataset/dataset_internal.h"
 #include "arrow/dataset/discovery.h"
 #include "arrow/dataset/partition.h"
@@ -24,7 +26,6 @@
 #include "arrow/filesystem/mockfs.h"
 #include "arrow/stl.h"
 #include "arrow/testing/generator.h"
-#include "arrow/util/optional.h"
 
 namespace arrow {
 namespace dataset {
@@ -485,7 +486,7 @@ inline std::shared_ptr<Schema> SchemaFromNames(const std::vector<std::string> na
 
 class TestSchemaUnification : public TestUnionDataset {
  public:
-  using i32 = util::optional<int32_t>;
+  using i32 = std::optional<int32_t>;
   using PathAndContent = std::vector<std::pair<std::string, std::string>>;
 
   void SetUp() override {
@@ -595,7 +596,7 @@ class TestSchemaUnification : public TestUnionDataset {
   std::shared_ptr<Dataset> dataset_;
 };
 
-using util::nullopt;
+using std::nullopt;
 
 TEST_F(TestSchemaUnification, SelectStar) {
   // This is a `SELECT * FROM dataset` where it ensures:
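Consumers of Fragment::CountRows now test a std::optional<int64_t> to distinguish "count known from metadata" from "must scan". A sketch (FastRowCount and its -1 sentinel are ours; the fragment and scan options come from elsewhere):

#include <cstdint>
#include <memory>
#include <optional>

#include "arrow/dataset/dataset.h"
#include "arrow/result.h"

namespace ds = arrow::dataset;

// Returns the metadata-backed row count, or -1 when the fragment cannot
// answer without scanning (the empty-optional case).
arrow::Result<int64_t> FastRowCount(const std::shared_ptr<ds::Fragment>& fragment,
                                    arrow::compute::Expression predicate,
                                    const std::shared_ptr<ds::ScanOptions>& options) {
  ARROW_ASSIGN_OR_RAISE(std::optional<int64_t> maybe_count,
                        fragment->CountRows(predicate, options).result());
  return maybe_count.has_value() ? *maybe_count : -1;
}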
diff --git a/cpp/src/arrow/dataset/dataset_writer_test.cc b/cpp/src/arrow/dataset/dataset_writer_test.cc
index f4e7344cdb2..edc9bc8bbc1 100644
--- a/cpp/src/arrow/dataset/dataset_writer_test.cc
+++ b/cpp/src/arrow/dataset/dataset_writer_test.cc
@@ -19,6 +19,7 @@
 
 #include <chrono>
 #include <memory>
+#include <optional>
 #include <vector>
 
 #include "arrow/array/builder_primitive.h"
@@ -30,7 +31,6 @@
 #include "arrow/table.h"
 #include "arrow/testing/future_util.h"
 #include "arrow/testing/gtest_util.h"
-#include "arrow/util/optional.h"
 #include "gtest/gtest.h"
 
 namespace arrow {
@@ -114,13 +114,13 @@ class DatasetWriterTestFixture : public testing::Test {
     return batch;
   }
 
-  util::optional<MockFileInfo> FindFile(const std::string& filename) {
+  std::optional<MockFileInfo> FindFile(const std::string& filename) {
    for (const auto& mock_file : filesystem_->AllFiles()) {
      if (mock_file.full_path == filename) {
        return mock_file;
      }
    }
-    return util::nullopt;
+    return std::nullopt;
   }
 
   void AssertVisited(const std::vector<std::string>& actual_paths,
@@ -150,7 +150,7 @@ class DatasetWriterTestFixture : public testing::Test {
     return batch;
   }
 
-  void AssertFileCreated(const util::optional<MockFileInfo>& maybe_file,
+  void AssertFileCreated(const std::optional<MockFileInfo>& maybe_file,
                          const std::string& expected_filename) {
     ASSERT_TRUE(maybe_file.has_value())
         << "The file " << expected_filename << " was not created";
@@ -167,7 +167,7 @@ class DatasetWriterTestFixture : public testing::Test {
   void AssertCreatedData(const std::vector<ExpectedFile>& expected_files) {
     counter_ = 0;
     for (const auto& expected_file : expected_files) {
-      util::optional<MockFileInfo> written_file = FindFile(expected_file.filename);
+      std::optional<MockFileInfo> written_file = FindFile(expected_file.filename);
       AssertFileCreated(written_file, expected_file.filename);
       int num_batches = 0;
       AssertBatchesEqual(*MakeBatch(expected_file.start, expected_file.num_rows),
@@ -178,21 +178,21 @@ class DatasetWriterTestFixture : public testing::Test {
 
   void AssertFilesCreated(const std::vector<std::string>& expected_files) {
     for (const std::string& expected_file : expected_files) {
-      util::optional<MockFileInfo> written_file = FindFile(expected_file);
+      std::optional<MockFileInfo> written_file = FindFile(expected_file);
       AssertFileCreated(written_file, expected_file);
     }
   }
 
   void AssertNotFiles(const std::vector<std::string>& expected_non_files) {
     for (const auto& expected_non_file : expected_non_files) {
-      util::optional<MockFileInfo> file = FindFile(expected_non_file);
+      std::optional<MockFileInfo> file = FindFile(expected_non_file);
       ASSERT_FALSE(file.has_value());
     }
   }
 
   void AssertEmptyFiles(const std::vector<std::string>& expected_empty_files) {
     for (const auto& expected_empty_file : expected_empty_files) {
-      util::optional<MockFileInfo> file = FindFile(expected_empty_file);
+      std::optional<MockFileInfo> file = FindFile(expected_empty_file);
       ASSERT_TRUE(file.has_value());
       ASSERT_EQ("", file->data);
     }
diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc
index 81bf10abe30..64daf08fd03 100644
--- a/cpp/src/arrow/dataset/file_base.cc
+++ b/cpp/src/arrow/dataset/file_base.cc
@@ -65,7 +65,7 @@ Result<std::shared_ptr<io::RandomAccessFile>> FileSource::Open() const {
 }
 
 Result<std::shared_ptr<io::InputStream>> FileSource::OpenCompressed(
-    util::optional<Compression::type> compression) const {
+    std::optional<Compression::type> compression) const {
   ARROW_ASSIGN_OR_RAISE(auto file, Open());
   auto actual_compression = Compression::type::UNCOMPRESSED;
   if (!compression.has_value()) {
@@ -100,10 +100,10 @@ bool FileSource::Equals(const FileSource& other) const {
          compression_ == other.compression_;
 }
 
-Future<util::optional<int64_t>> FileFormat::CountRows(
+Future<std::optional<int64_t>> FileFormat::CountRows(
     const std::shared_ptr<FileFragment>&, compute::Expression,
     const std::shared_ptr<ScanOptions>&) {
-  return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+  return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
 }
 
 Result<std::shared_ptr<FileFragment>> FileFormat::MakeFragment(
@@ -135,12 +135,12 @@ Result<RecordBatchGenerator> FileFragment::ScanBatchesAsync(
   return format_->ScanBatchesAsync(options, self);
 }
 
-Future<util::optional<int64_t>> FileFragment::CountRows(
+Future<std::optional<int64_t>> FileFragment::CountRows(
     compute::Expression predicate, const std::shared_ptr<ScanOptions>& options) {
   ARROW_ASSIGN_OR_RAISE(predicate, compute::SimplifyWithGuarantee(std::move(predicate),
                                                                   partition_expression_));
   if (!predicate.IsSatisfiable()) {
-    return Future<util::optional<int64_t>>::MakeFinished(0);
+    return Future<std::optional<int64_t>>::MakeFinished(0);
   }
   auto self = checked_pointer_cast<FileFragment>(shared_from_this());
   return format()->CountRows(self, std::move(predicate), options);
 }
diff --git a/cpp/src/arrow/dataset/file_base.h b/cpp/src/arrow/dataset/file_base.h
index 7b0f5ffcf2e..586c58b3f52 100644
--- a/cpp/src/arrow/dataset/file_base.h
+++ b/cpp/src/arrow/dataset/file_base.h
@@ -112,7 +112,7 @@ class ARROW_DS_EXPORT FileSource : public util::EqualityComparable<FileSource> {
   /// \param[in] compression If nullopt, guess the compression scheme from the
   ///     filename, else decompress with the given codec
   Result<std::shared_ptr<io::InputStream>> OpenCompressed(
-      util::optional<Compression::type> compression = util::nullopt) const;
+      std::optional<Compression::type> compression = std::nullopt) const;
 
   /// \brief equality comparison with another FileSource
   bool Equals(const FileSource& other) const;
@@ -154,7 +154,7 @@ class ARROW_DS_EXPORT FileFormat : public std::enable_shared_from_this<FileFormat> {
       const std::shared_ptr<ScanOptions>& options,
       const std::shared_ptr<FileFragment>& file) const = 0;
 
-  virtual Future<util::optional<int64_t>> CountRows(
+  virtual Future<std::optional<int64_t>> CountRows(
       const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
       const std::shared_ptr<ScanOptions>& options);
 
@@ -187,7 +187,7 @@ class ARROW_DS_EXPORT FileFragment : public Fragment,
  public:
   Result<RecordBatchGenerator> ScanBatchesAsync(
       const std::shared_ptr<ScanOptions>& options) override;
-  Future<util::optional<int64_t>> CountRows(
+  Future<std::optional<int64_t>> CountRows(
      compute::Expression predicate,
      const std::shared_ptr<ScanOptions>& options) override;
 
@@ -344,7 +344,7 @@ class ARROW_DS_EXPORT FileWriter {
   std::shared_ptr<FileWriteOptions> options_;
   std::shared_ptr<io::OutputStream> destination_;
   fs::FileLocator destination_locator_;
-  util::optional<int64_t> bytes_written_;
+  std::optional<int64_t> bytes_written_;
 };
 
 /// \brief Options for writing a dataset.
diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index 4cb331b0afc..bfc710105ed 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -290,11 +290,11 @@ Result<RecordBatchGenerator> CsvFileFormat::ScanBatchesAsync(
   return generator;
 }
 
-Future<util::optional<int64_t>> CsvFileFormat::CountRows(
+Future<std::optional<int64_t>> CsvFileFormat::CountRows(
     const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
     const std::shared_ptr<ScanOptions>& options) {
   if (ExpressionHasFieldRefs(predicate)) {
-    return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+    return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
   }
   auto self = checked_pointer_cast<CsvFileFormat>(shared_from_this());
   ARROW_ASSIGN_OR_RAISE(
@@ -309,7 +309,7 @@ Future<util::optional<int64_t>> CsvFileFormat::CountRows(
   return csv::CountRowsAsync(options->io_context, std::move(input),
                              ::arrow::internal::GetCpuThreadPool(), read_options,
                             self->parse_options)
-      .Then([](int64_t count) { return util::make_optional<int64_t>(count); });
+      .Then([](int64_t count) { return std::make_optional<int64_t>(count); });
 }
 
 //
diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h
index e58ed87b427..a3d214ef494 100644
--- a/cpp/src/arrow/dataset/file_csv.h
+++ b/cpp/src/arrow/dataset/file_csv.h
@@ -57,7 +57,7 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
      const std::shared_ptr<ScanOptions>& scan_options,
      const std::shared_ptr<FileFragment>& file) const override;
 
-  Future<util::optional<int64_t>> CountRows(
+  Future<std::optional<int64_t>> CountRows(
      const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
      const std::shared_ptr<ScanOptions>& options) override;
 
diff --git a/cpp/src/arrow/dataset/file_ipc.cc b/cpp/src/arrow/dataset/file_ipc.cc
index 7c45a5d7056..2650db499ce 100644
--- a/cpp/src/arrow/dataset/file_ipc.cc
+++ b/cpp/src/arrow/dataset/file_ipc.cc
@@ -175,15 +175,15 @@ Result<RecordBatchGenerator> IpcFileFormat::ScanBatchesAsync(
   return MakeFromFuture(open_reader.Then(reopen_reader).Then(open_generator));
 }
 
-Future<util::optional<int64_t>> IpcFileFormat::CountRows(
+Future<std::optional<int64_t>> IpcFileFormat::CountRows(
     const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
     const std::shared_ptr<ScanOptions>& options) {
   if (ExpressionHasFieldRefs(predicate)) {
-    return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+    return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
   }
   auto self = checked_pointer_cast<IpcFileFormat>(shared_from_this());
   return DeferNotOk(options->io_context.executor()->Submit(
-      [self, file]() -> Result<util::optional<int64_t>> {
+      [self, file]() -> Result<std::optional<int64_t>> {
        ARROW_ASSIGN_OR_RAISE(auto reader, OpenReader(file->source()));
        return reader->CountRows();
      }));
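The OpenCompressed default shown in file_base.h above is exercised like this (a sketch; ReadMaybeCompressed is our name and the streams are simply opened, not consumed):

#include <optional>

#include "arrow/dataset/file_base.h"
#include "arrow/result.h"
#include "arrow/util/compression.h"

namespace ds = arrow::dataset;

arrow::Status ReadMaybeCompressed(const ds::FileSource& source) {
  // Explicit codec...
  ARROW_ASSIGN_OR_RAISE(auto gzip_stream,
                        source.OpenCompressed(arrow::Compression::GZIP));
  // ...or std::nullopt (the default) to guess the codec from the file name.
  ARROW_ASSIGN_OR_RAISE(auto guessed_stream, source.OpenCompressed(std::nullopt));
  return arrow::Status::OK();
}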
diff --git a/cpp/src/arrow/dataset/file_ipc.h b/cpp/src/arrow/dataset/file_ipc.h
index 6dc40da1fac..8b97046271b 100644
--- a/cpp/src/arrow/dataset/file_ipc.h
+++ b/cpp/src/arrow/dataset/file_ipc.h
@@ -56,7 +56,7 @@ class ARROW_DS_EXPORT IpcFileFormat : public FileFormat {
      const std::shared_ptr<ScanOptions>& options,
      const std::shared_ptr<FileFragment>& file) const override;
 
-  Future<util::optional<int64_t>> CountRows(
+  Future<std::optional<int64_t>> CountRows(
      const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
      const std::shared_ptr<ScanOptions>& options) override;
 
diff --git a/cpp/src/arrow/dataset/file_orc.cc b/cpp/src/arrow/dataset/file_orc.cc
index 49102f3deae..cf04e5e7484 100644
--- a/cpp/src/arrow/dataset/file_orc.cc
+++ b/cpp/src/arrow/dataset/file_orc.cc
@@ -196,15 +196,15 @@ Result<RecordBatchGenerator> OrcFileFormat::ScanBatchesAsync(
   return iter_to_gen;
 }
 
-Future<util::optional<int64_t>> OrcFileFormat::CountRows(
+Future<std::optional<int64_t>> OrcFileFormat::CountRows(
     const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
     const std::shared_ptr<ScanOptions>& options) {
   if (ExpressionHasFieldRefs(predicate)) {
-    return Future<util::optional<int64_t>>::MakeFinished(util::nullopt);
+    return Future<std::optional<int64_t>>::MakeFinished(std::nullopt);
   }
   auto self = checked_pointer_cast<OrcFileFormat>(shared_from_this());
   return DeferNotOk(options->io_context.executor()->Submit(
-      [self, file]() -> Result<util::optional<int64_t>> {
+      [self, file]() -> Result<std::optional<int64_t>> {
        ARROW_ASSIGN_OR_RAISE(auto reader, OpenORCReader(file->source()));
        return reader->NumberOfRows();
      }));
diff --git a/cpp/src/arrow/dataset/file_orc.h b/cpp/src/arrow/dataset/file_orc.h
index 5bbe4df24ad..cbfb83670cb 100644
--- a/cpp/src/arrow/dataset/file_orc.h
+++ b/cpp/src/arrow/dataset/file_orc.h
@@ -55,7 +55,7 @@ class ARROW_DS_EXPORT OrcFileFormat : public FileFormat {
      const std::shared_ptr<ScanOptions>& options,
      const std::shared_ptr<FileFragment>& file) const override;
 
-  Future<util::optional<int64_t>> CountRows(
+  Future<std::optional<int64_t>> CountRows(
      const std::shared_ptr<FileFragment>& file, compute::Expression predicate,
      const std::shared_ptr<ScanOptions>& options) override;
 
diff --git a/cpp/src/arrow/dataset/file_parquet.cc b/cpp/src/arrow/dataset/file_parquet.cc
index 0d95e18171b..f07254e1115 100644
--- a/cpp/src/arrow/dataset/file_parquet.cc
+++ b/cpp/src/arrow/dataset/file_parquet.cc
@@ -98,7 +98,7 @@ Result<std::shared_ptr<SchemaManifest>> GetSchemaManifest(
   return manifest;
 }
 
-util::optional<compute::Expression> ColumnChunkStatisticsAsExpression(
+std::optional<compute::Expression> ColumnChunkStatisticsAsExpression(
     const SchemaField& schema_field, const parquet::RowGroupMetaData& metadata) {
   // For the remaining of this function, failure to extract/parse statistics
   // are ignored by returning nullptr.  The goal is two fold.  First
@@ -107,13 +107,13 @@ util::optional<compute::Expression> ColumnChunkStatisticsAsExpression(
 
   // For now, only leaf (primitive) types are supported.
if (!schema_field.is_leaf()) { - return util::nullopt; + return std::nullopt; } auto column_metadata = metadata.ColumnChunk(schema_field.column_index); auto statistics = column_metadata->statistics(); if (statistics == nullptr) { - return util::nullopt; + return std::nullopt; } const auto& field = schema_field.field; @@ -126,7 +126,7 @@ util::optional ColumnChunkStatisticsAsExpression( std::shared_ptr min, max; if (!StatisticsAsScalars(*statistics, &min, &max).ok()) { - return util::nullopt; + return std::nullopt; } auto maybe_min = min->CastTo(field->type()); @@ -155,7 +155,7 @@ util::optional ColumnChunkStatisticsAsExpression( return in_range; } - return util::nullopt; + return std::nullopt; } void AddColumnIndices(const SchemaField& schema_field, @@ -482,17 +482,17 @@ Result ParquetFileFormat::ScanBatchesAsync( return generator; } -Future> ParquetFileFormat::CountRows( +Future> ParquetFileFormat::CountRows( const std::shared_ptr& file, compute::Expression predicate, const std::shared_ptr& options) { auto parquet_file = checked_pointer_cast(file); if (parquet_file->metadata()) { ARROW_ASSIGN_OR_RAISE(auto maybe_count, parquet_file->TryCountRows(std::move(predicate))); - return Future>::MakeFinished(maybe_count); + return Future>::MakeFinished(maybe_count); } else { return DeferNotOk(options->io_context.executor()->Submit( - [parquet_file, predicate]() -> Result> { + [parquet_file, predicate]() -> Result> { RETURN_NOT_OK(parquet_file->EnsureCompleteMetadata()); return parquet_file->TryCountRows(predicate); })); @@ -512,7 +512,7 @@ Result> ParquetFileFormat::MakeFragment( std::shared_ptr physical_schema) { return std::shared_ptr(new ParquetFileFragment( std::move(source), shared_from_this(), std::move(partition_expression), - std::move(physical_schema), util::nullopt)); + std::move(physical_schema), std::nullopt)); } // @@ -573,7 +573,7 @@ ParquetFileFragment::ParquetFileFragment(FileSource source, std::shared_ptr format, compute::Expression partition_expression, std::shared_ptr physical_schema, - util::optional> row_groups) + std::optional> row_groups) : FileFragment(std::move(source), std::move(format), std::move(partition_expression), std::move(physical_schema)), parquet_format_(checked_cast(*format_)), @@ -738,7 +738,7 @@ Result> ParquetFileFragment::TestRowGroups( return row_groups; } -Result> ParquetFileFragment::TryCountRows( +Result> ParquetFileFragment::TryCountRows( compute::Expression predicate) { DCHECK_NE(metadata_, nullptr); if (ExpressionHasFieldRefs(predicate)) { @@ -757,7 +757,7 @@ Result> ParquetFileFragment::TryCountRows( // If the row group is entirely excluded, exclude it from the row count if (!expressions[i].IsSatisfiable()) continue; // Unless the row group is entirely included, bail out of fast path - if (expressions[i] != compute::literal(true)) return util::nullopt; + if (expressions[i] != compute::literal(true)) return std::nullopt; BEGIN_PARQUET_CATCH_EXCEPTIONS rows += metadata()->RowGroup((*row_groups_)[i])->num_rows(); END_PARQUET_CATCH_EXCEPTIONS diff --git a/cpp/src/arrow/dataset/file_parquet.h b/cpp/src/arrow/dataset/file_parquet.h index 6167c2894f6..05c02940d35 100644 --- a/cpp/src/arrow/dataset/file_parquet.h +++ b/cpp/src/arrow/dataset/file_parquet.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include #include #include @@ -30,7 +31,6 @@ #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/io/caching.h" -#include "arrow/util/optional.h" namespace parquet { class ParquetFileReader; @@ -99,7 +99,7 @@ class 
ARROW_DS_EXPORT ParquetFileFormat : public FileFormat { const std::shared_ptr& options, const std::shared_ptr& file) const override; - Future> CountRows( + Future> CountRows( const std::shared_ptr& file, compute::Expression predicate, const std::shared_ptr& options) override; @@ -167,7 +167,7 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { ParquetFileFragment(FileSource source, std::shared_ptr format, compute::Expression partition_expression, std::shared_ptr physical_schema, - util::optional> row_groups); + std::optional> row_groups); Status SetMetadata(std::shared_ptr metadata, std::shared_ptr manifest); @@ -185,13 +185,13 @@ class ARROW_DS_EXPORT ParquetFileFragment : public FileFragment { /// Try to count rows matching the predicate using metadata. Expects /// metadata to be present, and expects the predicate to have been /// simplified against the partition expression already. - Result> TryCountRows(compute::Expression predicate); + Result> TryCountRows(compute::Expression predicate); ParquetFileFormat& parquet_format_; /// Indices of row groups selected by this fragment, - /// or util::nullopt if all row groups are selected. - util::optional> row_groups_; + /// or std::nullopt if all row groups are selected. + std::optional> row_groups_; std::vector statistics_expressions_; std::vector statistics_expressions_complete_; diff --git a/cpp/src/arrow/dataset/file_parquet_test.cc b/cpp/src/arrow/dataset/file_parquet_test.cc index de048855cf2..a2a15762a58 100644 --- a/cpp/src/arrow/dataset/file_parquet_test.cc +++ b/cpp/src/arrow/dataset/file_parquet_test.cc @@ -231,7 +231,7 @@ TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) { auto fragment = MakeFragment(*source); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(kTotalNumRows), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(kTotalNumRows), fragment->CountRows(literal(true), options)); for (int i = 1; i <= kNumRowGroups; i++) { @@ -240,18 +240,18 @@ TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) { auto predicate = less_equal(field_ref("i64"), literal(i)); ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*reader->schema())); auto expected = i * (i + 1) / 2; - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(expected), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected), fragment->CountRows(predicate, options)); predicate = and_(less_equal(field_ref("i64"), literal(i)), greater_equal(field_ref("i64"), literal(i))); ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*reader->schema())); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(i), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(i), fragment->CountRows(predicate, options)); predicate = equal(field_ref("i64"), literal(i)); ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*reader->schema())); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(i), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(i), fragment->CountRows(predicate, options)); } @@ -278,15 +278,15 @@ TEST_F(TestParquetFileFormat, CountRowsPredicatePushdown) { ASSERT_OK_AND_ASSIGN( auto predicate, greater_equal(field_ref("i64"), literal(1)).Bind(*dataset_schema)); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(4), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(4), fragment->CountRows(predicate, options)); ASSERT_OK_AND_ASSIGN(predicate, is_null(field_ref("i64")).Bind(*dataset_schema)); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(3), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(3), fragment->CountRows(predicate, options)); ASSERT_OK_AND_ASSIGN(predicate, 
is_valid(field_ref("i64")).Bind(*dataset_schema)); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(4), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(4), fragment->CountRows(predicate, options)); } } diff --git a/cpp/src/arrow/dataset/file_test.cc b/cpp/src/arrow/dataset/file_test.cc index 4dfc6bc584d..6d866c196b3 100644 --- a/cpp/src/arrow/dataset/file_test.cc +++ b/cpp/src/arrow/dataset/file_test.cc @@ -351,7 +351,7 @@ TEST_F(TestFileSystemDataset, WriteProjected) { class FileSystemWriteTest : public testing::TestWithParam> { using PlanFactory = std::function( const FileSystemDatasetWriteOptions&, - std::function>()>*)>; + std::function>()>*)>; protected: bool IsParallel() { return std::get<0>(GetParam()); } @@ -379,7 +379,7 @@ class FileSystemWriteTest : public testing::TestWithParam "[[5, null], [6, false], [7, false]]")}; source_data.schema = schema({field("i32", int32()), field("bool", boolean())}); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN(auto plan, cp::ExecPlan::Make()); auto source_decl = cp::Declaration::Sequence( @@ -422,7 +422,7 @@ class FileSystemWriteTest : public testing::TestWithParam TEST_P(FileSystemWriteTest, Write) { auto plan_factory = [](const FileSystemDatasetWriteOptions& write_options, - std::function>()>* sink_gen) { + std::function>()>* sink_gen) { return std::vector{{"write", WriteNodeOptions{write_options}}}; }; TestDatasetWriteRoundTrip(plan_factory, /*has_output=*/false); @@ -431,7 +431,7 @@ TEST_P(FileSystemWriteTest, Write) { TEST_P(FileSystemWriteTest, TeeWrite) { auto plan_factory = [](const FileSystemDatasetWriteOptions& write_options, - std::function>()>* sink_gen) { + std::function>()>* sink_gen) { return std::vector{ {"tee", WriteNodeOptions{write_options}}, {"sink", cp::SinkNodeOptions{sink_gen}}, diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc index 26abc10e6b8..a9744d0aabf 100644 --- a/cpp/src/arrow/dataset/partition.cc +++ b/cpp/src/arrow/dataset/partition.cc @@ -330,12 +330,12 @@ Result KeyValuePartitioning::Format( return FormatValues(values); } -inline util::optional NextValid(const ScalarVector& values, int first_null) { +inline std::optional NextValid(const ScalarVector& values, int first_null) { auto it = std::find_if(values.begin() + first_null + 1, values.end(), [](const std::shared_ptr& v) { return v != nullptr; }); if (it == values.end()) { - return util::nullopt; + return std::nullopt; } return static_cast(it - values.begin()); @@ -473,7 +473,7 @@ class KeyValuePartitioningFactory : public PartitioningFactory { return it_inserted.first->second; } - Status InsertRepr(const std::string& name, util::optional repr) { + Status InsertRepr(const std::string& name, std::optional repr) { auto field_index = GetOrInsertField(name); if (repr.has_value()) { return InsertRepr(field_index, *repr); @@ -715,12 +715,12 @@ bool FilenamePartitioning::Equals(const Partitioning& other) const { return KeyValuePartitioning::Equals(other); } -Result> HivePartitioning::ParseKey( +Result> HivePartitioning::ParseKey( const std::string& segment, const HivePartitioningOptions& options) { auto name_end = string_view(segment).find_first_of('='); // Not round-trippable if (name_end == string_view::npos) { - return util::nullopt; + return std::nullopt; } // Static method, so we have no better place for it @@ -750,7 +750,7 @@ Result> HivePartitioning::ParseKey( } if (value == options.null_fallback) { - return Key{std::move(name), util::nullopt}; + return Key{std::move(name), std::nullopt}; } return 
Key{std::move(name), std::move(value)}; } diff --git a/cpp/src/arrow/dataset/partition.h b/cpp/src/arrow/dataset/partition.h index 2d8c8bb2746..faee0c676e2 100644 --- a/cpp/src/arrow/dataset/partition.h +++ b/cpp/src/arrow/dataset/partition.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -31,7 +32,6 @@ #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" #include "arrow/util/compare.h" -#include "arrow/util/optional.h" namespace arrow { @@ -174,7 +174,7 @@ class ARROW_DS_EXPORT KeyValuePartitioning : public Partitioning { /// of a scalar value struct Key { std::string name; - util::optional value; + std::optional value; }; Result Partition( @@ -289,8 +289,8 @@ class ARROW_DS_EXPORT HivePartitioning : public KeyValuePartitioning { std::string null_fallback() const { return hive_options_.null_fallback; } const HivePartitioningOptions& options() const { return hive_options_; } - static Result> ParseKey(const std::string& segment, - const HivePartitioningOptions& options); + static Result> ParseKey(const std::string& segment, + const HivePartitioningOptions& options); bool Equals(const Partitioning& other) const override; diff --git a/cpp/src/arrow/dataset/scanner.cc b/cpp/src/arrow/dataset/scanner.cc index badd18bf318..eb09a986c97 100644 --- a/cpp/src/arrow/dataset/scanner.cc +++ b/cpp/src/arrow/dataset/scanner.cc @@ -334,7 +334,7 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( } Result ToEnumeratedRecordBatch( - const util::optional& batch, const ScanOptions& options, + const std::optional& batch, const ScanOptions& options, const FragmentVector& fragments) { int num_fields = options.projected_schema->num_fields(); @@ -363,7 +363,7 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(exec_context.get())); plan->SetUseLegacyBatching(use_legacy_batching); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; auto exprs = scan_options_->projection.call()->arguments; auto names = checked_cast( @@ -402,7 +402,7 @@ Result AsyncScanner::ScanBatchesUnorderedAsync( return MakeMappedGenerator( std::move(sink_gen), [sink_gen, options, stop_producing, - shared_fragments](const util::optional& batch) + shared_fragments](const std::optional& batch) -> Future { return ToEnumeratedRecordBatch(batch, *options, *shared_fragments); }); @@ -655,7 +655,7 @@ Result AsyncScanner::CountRows() { fragment_gen = MakeMappedGenerator( std::move(fragment_gen), [&](const std::shared_ptr& fragment) { return fragment->CountRows(options->filter, options) - .Then([&, fragment](util::optional fast_count) mutable + .Then([&, fragment](std::optional fast_count) mutable -> std::shared_ptr { if (fast_count) { // fast path: got row count directly; skip scanning this fragment @@ -669,7 +669,7 @@ Result AsyncScanner::CountRows() { }); }); - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; RETURN_NOT_OK( compute::Declaration::Sequence( @@ -912,7 +912,7 @@ Result MakeScanNode(compute::ExecPlan* plan, auto gen = MakeMappedGenerator( std::move(batch_gen), [scan_options](const EnumeratedRecordBatch& partial) - -> Result> { + -> Result> { // TODO(ARROW-13263) fragments may be able to attach more guarantees to batches // than this, for example parquet's row group stats. 
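
The .Then callback above shows why CountRows returns an optional rather than a plain integer: the scanner sums fast counts where they exist and scans only the fragments that answered std::nullopt. A simplified synchronous sketch of that dispatch, with a hypothetical Fragment stand-in in place of Arrow's async machinery:

    #include <cstdint>
    #include <optional>
    #include <vector>

    struct Fragment {
      std::optional<int64_t> fast_count;  // what CountRows(predicate) reported
      int64_t scanned_rows = 0;           // what a full scan would find
      int64_t ScanAndCount() const { return scanned_rows; }
    };

    int64_t CountAllRows(const std::vector<Fragment>& fragments) {
      int64_t total = 0;
      for (const auto& fragment : fragments) {
        if (fragment.fast_count) {
          total += *fragment.fast_count;     // fast path: metadata answered
        } else {
          total += fragment.ScanAndCount();  // slow path: scan this fragment only
        }
      }
      return total;
    }
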
Failing to do this leaves // perf on the table because row group stats could be used to skip kernel execs in @@ -924,7 +924,7 @@ Result MakeScanNode(compute::ExecPlan* plan, auto guarantee = partial.fragment.value->partition_expression(); ARROW_ASSIGN_OR_RAISE( - util::optional batch, + std::optional batch, compute::MakeExecBatch(*scan_options->dataset_schema, partial.record_batch.value, guarantee)); @@ -978,7 +978,7 @@ Result MakeOrderedSinkNode(compute::ExecPlan* plan, } auto input = inputs[0]; - AsyncGenerator> unordered; + AsyncGenerator> unordered; ARROW_ASSIGN_OR_RAISE(auto node, compute::MakeExecNode("sink", plan, std::move(inputs), compute::SinkNodeOptions{&unordered})); @@ -1009,8 +1009,8 @@ Result MakeOrderedSinkNode(compute::ExecPlan* plan, return fragment_index(batch) < 0; }; - auto left_after_right = [=](const util::optional& left, - const util::optional& right) { + auto left_after_right = [=](const std::optional& left, + const std::optional& right) { // Before any comes first if (is_before_any(*left)) { return false; @@ -1026,8 +1026,8 @@ Result MakeOrderedSinkNode(compute::ExecPlan* plan, return fragment_index(*left) > fragment_index(*right); }; - auto is_next = [=](const util::optional& prev, - const util::optional& next) { + auto is_next = [=](const std::optional& prev, + const std::optional& next) { // Only true if next is the first batch if (is_before_any(*prev)) { return fragment_index(*next) == 0 && batch_index(*next) == 0; @@ -1044,7 +1044,7 @@ Result MakeOrderedSinkNode(compute::ExecPlan* plan, const auto& sink_options = checked_cast(options); *sink_options.generator = MakeSequencingGenerator(std::move(unordered), left_after_right, is_next, - util::make_optional(std::move(before_any))); + std::make_optional(std::move(before_any))); return node; } diff --git a/cpp/src/arrow/dataset/scanner_benchmark.cc b/cpp/src/arrow/dataset/scanner_benchmark.cc index 6d314d9d9a6..b0254089a95 100644 --- a/cpp/src/arrow/dataset/scanner_benchmark.cc +++ b/cpp/src/arrow/dataset/scanner_benchmark.cc @@ -149,7 +149,7 @@ void MinimalEndToEndScan(size_t num_batches, size_t batch_size, bool async_mode) compute::ProjectNodeOptions{{a_times_2}, {}, async_mode})); // finally, pipe the project node into a sink node - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, compute::MakeExecNode("sink", plan.get(), {project}, compute::SinkNodeOptions{&sink_gen})); diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc index 804e82b57db..0768014b862 100644 --- a/cpp/src/arrow/dataset/scanner_test.cc +++ b/cpp/src/arrow/dataset/scanner_test.cc @@ -481,16 +481,16 @@ class CountRowsOnlyFragment : public InMemoryFragment { public: using InMemoryFragment::InMemoryFragment; - Future> CountRows( - compute::Expression predicate, const std::shared_ptr&) override { + Future> CountRows(compute::Expression predicate, + const std::shared_ptr&) override { if (compute::FieldsInExpression(predicate).size() > 0) { - return Future>::MakeFinished(util::nullopt); + return Future>::MakeFinished(std::nullopt); } int64_t sum = 0; for (const auto& batch : record_batches_) { sum += batch->num_rows(); } - return Future>::MakeFinished(sum); + return Future>::MakeFinished(sum); } Result ScanBatchesAsync( const std::shared_ptr&) override { @@ -502,9 +502,9 @@ class ScanOnlyFragment : public InMemoryFragment { public: using InMemoryFragment::InMemoryFragment; - Future> CountRows( - compute::Expression predicate, const std::shared_ptr&) override { - 
return Future>::MakeFinished(util::nullopt); + Future> CountRows(compute::Expression predicate, + const std::shared_ptr&) override { + return Future>::MakeFinished(std::nullopt); } Result ScanBatchesAsync( const std::shared_ptr&) override { @@ -532,14 +532,14 @@ class CountFailFragment : public InMemoryFragment { public: explicit CountFailFragment(RecordBatchVector record_batches) : InMemoryFragment(std::move(record_batches)), - count(Future>::Make()) {} + count(Future>::Make()) {} - Future> CountRows( - compute::Expression, const std::shared_ptr&) override { + Future> CountRows(compute::Expression, + const std::shared_ptr&) override { return count; } - Future> count; + Future> count; }; TEST_P(TestScanner, CountRowsFailure) { SetSchema({field("i32", int32()), field("f64", float64())}); @@ -557,7 +557,7 @@ TEST_P(TestScanner, CountRowsFailure) { ASSERT_RAISES(Invalid, scanner->CountRows()); // Fragment 2 doesn't complete until after the count stops - should not break anything // under ASan, etc. - fragment2->count.MarkFinished(util::nullopt); + fragment2->count.MarkFinished(std::nullopt); } TEST_P(TestScanner, CountRowsWithMetadata) { @@ -1358,7 +1358,7 @@ struct TestPlan { .Then([collected_fut]() -> Result> { ARROW_ASSIGN_OR_RAISE(auto collected, collected_fut.result()); return ::arrow::internal::MapVector( - [](util::optional batch) { return std::move(*batch); }, + [](std::optional batch) { return std::move(*batch); }, std::move(collected)); }); } @@ -1366,7 +1366,7 @@ struct TestPlan { compute::ExecPlan* get() { return plan.get(); } std::shared_ptr plan; - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; }; struct DatasetAndBatches { @@ -1763,7 +1763,7 @@ TEST(ScanNode, MinimalEndToEnd) { compute::ProjectNodeOptions{{a_times_2}})); // finally, pipe the project node into a sink node - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, compute::MakeExecNode("ordered_sink", plan.get(), {project}, compute::SinkNodeOptions{&sink_gen})); @@ -1863,7 +1863,7 @@ TEST(ScanNode, MinimalScalarAggEndToEnd) { "sum", nullptr, "a * 2", "sum(a * 2)"}}})); // finally, pipe the aggregate node into a sink node - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, compute::MakeExecNode("sink", plan.get(), {aggregate}, compute::SinkNodeOptions{&sink_gen})); @@ -1953,7 +1953,7 @@ TEST(ScanNode, MinimalGroupedAggEndToEnd) { /*keys=*/{"b"}})); // finally, pipe the aggregate node into a sink node - AsyncGenerator> sink_gen; + AsyncGenerator> sink_gen; ASSERT_OK_AND_ASSIGN(compute::ExecNode * sink, compute::MakeExecNode("sink", plan.get(), {aggregate}, compute::SinkNodeOptions{&sink_gen})); diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 09409745159..05a98693896 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -534,26 +534,26 @@ class FileFormatFixtureMixin : public ::testing::Test { auto source = this->GetFileSource(reader.get()); auto fragment = this->MakeFragment(*source); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(expected_rows()), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected_rows()), fragment->CountRows(literal(true), options)); fragment = this->MakeFragment(*source, equal(field_ref("part"), literal(2))); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(expected_rows()), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected_rows()), fragment->CountRows(literal(true), options)); auto predicate = 
equal(field_ref("part"), literal(1)); ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema)); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(0), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(0), fragment->CountRows(predicate, options)); predicate = equal(field_ref("part"), literal(2)); ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema)); - ASSERT_FINISHES_OK_AND_EQ(util::make_optional(expected_rows()), + ASSERT_FINISHES_OK_AND_EQ(std::make_optional(expected_rows()), fragment->CountRows(predicate, options)); predicate = equal(call("add", {field_ref("f64"), literal(3)}), literal(2)); ASSERT_OK_AND_ASSIGN(predicate, predicate.Bind(*full_schema)); - ASSERT_FINISHES_OK_AND_EQ(util::nullopt, fragment->CountRows(predicate, options)); + ASSERT_FINISHES_OK_AND_EQ(std::nullopt, fragment->CountRows(predicate, options)); } void TestFragmentEquals() { auto options = std::make_shared(); diff --git a/cpp/src/arrow/engine/simple_extension_type_internal.h b/cpp/src/arrow/engine/simple_extension_type_internal.h index b177425a9a9..66d86088a76 100644 --- a/cpp/src/arrow/engine/simple_extension_type_internal.h +++ b/cpp/src/arrow/engine/simple_extension_type_internal.h @@ -18,13 +18,13 @@ #pragma once #include +#include #include #include #include #include "arrow/extension_type.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/reflection_internal.h" #include "arrow/util/string.h" @@ -106,7 +106,7 @@ class SimpleExtensionType : public ExtensionType { kProperties->ForEach(*this); } - void Fail() { params_ = util::nullopt; } + void Fail() { params_ = std::nullopt; } void Init(util::string_view class_name, util::string_view repr, size_t num_properties) { @@ -144,7 +144,7 @@ class SimpleExtensionType : public ExtensionType { prop.set(&*params_, std::move(value)); } - util::optional params_; + std::optional params_; std::vector members_; }; Result> Deserialize( diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 9b364741a35..1f9d234bff7 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -51,7 +51,7 @@ Status DecodeArg(const substrait::FunctionArgument& arg, uint32_t idx, call->SetEnumArg(idx, enum_val.specified()); break; case substrait::FunctionArgument::Enum::EnumKindCase::kUnspecified: - call->SetEnumArg(idx, util::nullopt); + call->SetEnumArg(idx, std::nullopt); break; default: return Status::Invalid("Unrecognized enum kind case: ", @@ -138,7 +138,7 @@ Result FromProto(const substrait::Expression& expr, case substrait::Expression::kSelection: { if (!expr.selection().has_direct_reference()) break; - util::optional out; + std::optional out; if (expr.selection().has_expression()) { ARROW_ASSIGN_OR_RAISE( out, FromProto(expr.selection().expression(), ext_set, conversion_options)); @@ -906,7 +906,7 @@ Result> EncodeSubstraitCa substrait::FunctionArgument* arg = scalar_fn->add_arguments(); if (call.HasEnumArg(i)) { auto enum_val = internal::make_unique(); - ARROW_ASSIGN_OR_RAISE(util::optional enum_arg, + ARROW_ASSIGN_OR_RAISE(std::optional enum_arg, call.GetEnumArg(i)); if (enum_arg) { enum_val->set_specified(enum_arg->to_string()); diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 0e1f5ebc664..926fe846fff 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -78,24 +78,24 @@ Id 
IdStorage::Emplace(Id id) { return {owned_uri, owned_name}; } -util::optional IdStorage::Find(Id id) const { - util::optional maybe_owned_uri = FindUri(id.uri); +std::optional IdStorage::Find(Id id) const { + std::optional maybe_owned_uri = FindUri(id.uri); if (!maybe_owned_uri) { - return util::nullopt; + return std::nullopt; } auto name_itr = names_.find(id.name); if (name_itr == names_.end()) { - return util::nullopt; + return std::nullopt; } else { return Id{*maybe_owned_uri, *name_itr}; } } -util::optional IdStorage::FindUri(util::string_view uri) const { +std::optional IdStorage::FindUri(util::string_view uri) const { auto uri_itr = uris_.find(uri); if (uri_itr == uris_.end()) { - return util::nullopt; + return std::nullopt; } return *uri_itr; } @@ -111,8 +111,7 @@ util::string_view IdStorage::EmplaceUri(util::string_view uri) { return *uri_itr; } -Result> SubstraitCall::GetEnumArg( - uint32_t index) const { +Result> SubstraitCall::GetEnumArg(uint32_t index) const { if (index >= size_) { return Status::Invalid("Expected Substrait call to have an enum argument at index ", index, " but it did not have enough arguments"); @@ -129,7 +128,7 @@ bool SubstraitCall::HasEnumArg(uint32_t index) const { return enum_args_.find(index) != enum_args_.end(); } -void SubstraitCall::SetEnumArg(uint32_t index, util::optional enum_arg) { +void SubstraitCall::SetEnumArg(uint32_t index, std::optional enum_arg) { size_ = std::max(size_, index + 1); enum_args_[index] = std::move(enum_arg); } @@ -203,7 +202,7 @@ Result ExtensionSet::Make( set.registry_ = registry; for (auto& uri : uris) { - util::optional maybe_uri_internal = registry->FindUri(uri.second); + std::optional maybe_uri_internal = registry->FindUri(uri.second); if (maybe_uri_internal) { set.uris_[uri.first] = *maybe_uri_internal; } else { @@ -233,7 +232,7 @@ Result ExtensionSet::Make( for (const auto& function_id : function_ids) { if (function_id.second.empty()) continue; RETURN_NOT_OK(set.CheckHasUri(function_id.second.uri)); - util::optional maybe_id_internal = registry->FindId(function_id.second); + std::optional maybe_id_internal = registry->FindId(function_id.second); if (maybe_id_internal) { set.functions_[function_id.first] = *maybe_id_internal; } else { @@ -309,9 +308,9 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { virtual ~ExtensionIdRegistryImpl() {} - util::optional FindUri(util::string_view uri) const override { + std::optional FindUri(util::string_view uri) const override { if (parent_) { - util::optional parent_uri = parent_->FindUri(uri); + std::optional parent_uri = parent_->FindUri(uri); if (parent_uri) { return parent_uri; } @@ -319,9 +318,9 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return ids_.FindUri(uri); } - util::optional FindId(Id id) const override { + std::optional FindId(Id id) const override { if (parent_) { - util::optional parent_id = parent_->FindId(id); + std::optional parent_id = parent_->FindId(id); if (parent_id) { return parent_id; } @@ -329,7 +328,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return ids_.Find(id); } - util::optional GetType(const DataType& type) const override { + std::optional GetType(const DataType& type) const override { if (auto index = GetIndex(type_to_index_, &type)) { return TypeRecord{type_ids_[*index], types_[*index]}; } @@ -339,7 +338,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return {}; } - util::optional GetType(Id id) const override { + std::optional GetType(Id id) const override { if (auto index = GetIndex(id_to_index_, 
id)) { return TypeRecord{type_ids_[*index], types_[*index]}; } @@ -605,7 +604,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { }; template -using EnumParser = std::function(util::optional)>; +using EnumParser = std::function(std::optional)>; template EnumParser GetEnumParser(const std::vector& options) { @@ -613,7 +612,7 @@ EnumParser GetEnumParser(const std::vector& options) { for (std::size_t i = 0; i < options.size(); i++) { parse_map[options[i]] = static_cast(i + 1); } - return [parse_map](util::optional enum_val) -> Result { + return [parse_map](std::optional enum_val) -> Result { if (!enum_val) { // Assumes 0 is always kUnspecified in Enum return static_cast(0); @@ -640,7 +639,7 @@ static EnumParser kOverflowParser = template Result ParseEnumArg(const SubstraitCall& call, uint32_t arg_index, const EnumParser& parser) { - ARROW_ASSIGN_OR_RAISE(util::optional enum_arg, + ARROW_ASSIGN_OR_RAISE(std::optional enum_arg, call.GetEnumArg(arg_index)); return parser(enum_arg); } diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index 410a19ecf61..e2b20f989ac 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -20,6 +20,7 @@ #pragma once #include +#include #include #include #include @@ -31,7 +32,6 @@ #include "arrow/type_fwd.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" namespace arrow { @@ -86,11 +86,11 @@ class IdStorage { /// \brief Get an equivalent id pointing into this storage /// /// If no id is found then nullopt will be returned - util::optional Find(Id id) const; + std::optional Find(Id id) const; /// \brief Get an equivalent view pointing into this storage for a URI /// /// If no URI is found then nullopt will be returned - util::optional FindUri(util::string_view uri) const; + std::optional FindUri(util::string_view uri) const; private: std::unordered_set uris_; @@ -119,8 +119,8 @@ class SubstraitCall { bool is_hash() const { return is_hash_; } bool HasEnumArg(uint32_t index) const; - Result> GetEnumArg(uint32_t index) const; - void SetEnumArg(uint32_t index, util::optional enum_arg); + Result> GetEnumArg(uint32_t index) const; + void SetEnumArg(uint32_t index, std::optional enum_arg); Result GetValueArg(uint32_t index) const; bool HasValueArg(uint32_t index) const; void SetValueArg(uint32_t index, compute::Expression value_arg); @@ -133,7 +133,7 @@ class SubstraitCall { // Only needed when converting from Substrait -> Arrow aggregates. 
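
SetEnumArg/GetEnumArg above store a value that is itself optional: an enum argument can be present but unspecified. The map therefore holds std::optional values, and "argument missing" stays a separate condition from "argument unspecified". A rough sketch of that two-level distinction, using std::string instead of util::string_view and a plain pointer instead of Result:

    #include <algorithm>
    #include <cstdint>
    #include <optional>
    #include <string>
    #include <unordered_map>

    class CallSketch {
     public:
      void SetEnumArg(uint32_t index, std::optional<std::string> enum_arg) {
        size_ = std::max(size_, index + 1);
        enum_args_[index] = std::move(enum_arg);
      }
      bool HasEnumArg(uint32_t index) const { return enum_args_.count(index) != 0; }
      // nullptr: no argument at this index. Otherwise points at the stored
      // value, which may itself be std::nullopt ("specified as unspecified").
      const std::optional<std::string>* GetEnumArg(uint32_t index) const {
        auto it = enum_args_.find(index);
        return it == enum_args_.end() ? nullptr : &it->second;
      }

     private:
      std::unordered_map<uint32_t, std::optional<std::string>> enum_args_;
      uint32_t size_ = 0;
    };
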
The // Arrow function name depends on whether or not there are any groups bool is_hash_; - std::unordered_map> enum_args_; + std::unordered_map> enum_args_; std::unordered_map value_args_; uint32_t size_ = 0; }; @@ -174,13 +174,13 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry { /// \brief Return a uri view owned by this registry /// /// If the URI has never been emplaced it will return nullopt - virtual util::optional FindUri(util::string_view uri) const = 0; + virtual std::optional FindUri(util::string_view uri) const = 0; /// \brief Return a id view owned by this registry /// /// If the id has never been emplaced it will return nullopt - virtual util::optional FindId(Id id) const = 0; - virtual util::optional GetType(const DataType&) const = 0; - virtual util::optional GetType(Id) const = 0; + virtual std::optional FindId(Id id) const = 0; + virtual std::optional GetType(const DataType&) const = 0; + virtual std::optional GetType(Id) const = 0; virtual Status CanRegisterType(Id, const std::shared_ptr& type) const = 0; virtual Status RegisterType(Id, std::shared_ptr) = 0; /// \brief Register a converter that converts an Arrow call to a Substrait call diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc index b8fd191b3fd..2b7211766ee 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.cc +++ b/cpp/src/arrow/engine/substrait/extension_types.cc @@ -115,18 +115,18 @@ bool UnwrapUuid(const DataType& t) { return false; } -util::optional UnwrapFixedChar(const DataType& t) { +std::optional UnwrapFixedChar(const DataType& t) { if (auto params = FixedCharType::GetIf(t)) { return params->length; } - return util::nullopt; + return std::nullopt; } -util::optional UnwrapVarChar(const DataType& t) { +std::optional UnwrapVarChar(const DataType& t) { if (auto params = VarCharType::GetIf(t)) { return params->length; } - return util::nullopt; + return std::nullopt; } bool UnwrapIntervalYear(const DataType& t) { diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h index d6db454ec30..c623d081b18 100644 --- a/cpp/src/arrow/engine/substrait/extension_types.h +++ b/cpp/src/arrow/engine/substrait/extension_types.h @@ -19,13 +19,13 @@ #pragma once +#include #include #include "arrow/buffer.h" #include "arrow/compute/function.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/type_fwd.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" namespace arrow { @@ -64,11 +64,11 @@ bool UnwrapUuid(const DataType&); /// Return FixedChar length if t is FixedChar, otherwise nullopt ARROW_ENGINE_EXPORT -util::optional UnwrapFixedChar(const DataType&); +std::optional UnwrapFixedChar(const DataType&); /// Return Varchar (max) length if t is VarChar, otherwise nullopt ARROW_ENGINE_EXPORT -util::optional UnwrapVarChar(const DataType& t); +std::optional UnwrapVarChar(const DataType& t); /// Return true if t is IntervalYear, otherwise false ARROW_ENGINE_EXPORT diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index b50e1c6084c..cb1eadcdbdf 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -80,7 +80,7 @@ Result> GetTableFromPlan( const std::shared_ptr& output_schema) { ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; auto sink_node_options = 
compute::SinkNodeOptions{&sink_gen}; auto sink_declaration = compute::Declaration({"sink", sink_node_options, "e"}); auto declarations = compute::Declaration::Sequence({other_declrs, sink_declaration}); @@ -1949,7 +1949,7 @@ TEST(Substrait, BasicPlanRoundTripping) { auto comp_right_value = compute::field_ref(filter_col_right); auto filter = compute::equal(comp_left_value, comp_right_value); - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; auto declarations = compute::Declaration::Sequence( {compute::Declaration( diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index f51666ef858..6587ea077dd 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -29,7 +29,7 @@ namespace { class SubstraitSinkConsumer : public compute::SinkNodeConsumer { public: explicit SubstraitSinkConsumer( - arrow::PushGenerator>::Producer producer) + arrow::PushGenerator>::Producer producer) : producer_(std::move(producer)) {} Status Consume(compute::ExecBatch batch) override { @@ -53,7 +53,7 @@ class SubstraitSinkConsumer : public compute::SinkNodeConsumer { std::shared_ptr schema() { return schema_; } private: - arrow::PushGenerator>::Producer producer_; + arrow::PushGenerator>::Producer producer_; std::shared_ptr schema_; }; @@ -105,7 +105,7 @@ class SubstraitExecutor { } private: - arrow::PushGenerator> generator_; + arrow::PushGenerator> generator_; std::vector declarations_; std::shared_ptr plan_; bool plan_started_; diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 90cb4e3dd2a..a616968d961 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ b/cpp/src/arrow/engine/substrait/util.h @@ -18,11 +18,12 @@ #pragma once #include +#include + #include "arrow/compute/registry.h" #include "arrow/engine/substrait/api.h" #include "arrow/engine/substrait/options.h" #include "arrow/util/iterator.h" -#include "arrow/util/optional.h" namespace arrow { diff --git a/cpp/src/arrow/filesystem/gcsfs.h b/cpp/src/arrow/filesystem/gcsfs.h index 77b8a0b201a..c3d03b5cb21 100644 --- a/cpp/src/arrow/filesystem/gcsfs.h +++ b/cpp/src/arrow/filesystem/gcsfs.h @@ -18,11 +18,11 @@ #pragma once #include +#include #include #include #include "arrow/filesystem/filesystem.h" -#include "arrow/util/optional.h" #include "arrow/util/uri.h" namespace arrow { @@ -70,7 +70,7 @@ struct ARROW_EXPORT GcsOptions { /// errors. /// /// The default policy is to retry for up to 15 minutes. - arrow::util::optional retry_limit_seconds; + std::optional retry_limit_seconds; /// \brief Default metadata for OpenOutputStream. 
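
retry_limit_seconds above is the typical optional configuration knob: std::nullopt means "use the library default" and is distinct from any concrete number of seconds. A minimal illustration of how such a field is usually consumed; the struct and method names are invented, and the 15-minute figure comes from the doc comment above:

    #include <chrono>
    #include <optional>

    struct RetryPolicySketch {
      std::optional<double> retry_limit_seconds;  // nullopt => default policy

      std::chrono::duration<double> effective_limit() const {
        constexpr double kDefaultSeconds = 15 * 60;  // "retry up to 15 minutes"
        return std::chrono::duration<double>(
            retry_limit_seconds.value_or(kDefaultSeconds));
      }
    };
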
/// diff --git a/cpp/src/arrow/filesystem/path_util.cc b/cpp/src/arrow/filesystem/path_util.cc index 1afc3b2a89b..2216a4bb258 100644 --- a/cpp/src/arrow/filesystem/path_util.cc +++ b/cpp/src/arrow/filesystem/path_util.cc @@ -189,10 +189,10 @@ bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) { return descendant.starts_with(std::string{kSep}); } -util::optional RemoveAncestor(util::string_view ancestor, - util::string_view descendant) { +std::optional RemoveAncestor(util::string_view ancestor, + util::string_view descendant) { if (!IsAncestorOf(ancestor, descendant)) { - return util::nullopt; + return std::nullopt; } auto relative_to_ancestor = descendant.substr(ancestor.size()); diff --git a/cpp/src/arrow/filesystem/path_util.h b/cpp/src/arrow/filesystem/path_util.h index d4083d3b5c9..ea8e56df5d4 100644 --- a/cpp/src/arrow/filesystem/path_util.h +++ b/cpp/src/arrow/filesystem/path_util.h @@ -17,12 +17,12 @@ #pragma once +#include #include #include #include #include "arrow/type_fwd.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" namespace arrow { @@ -79,8 +79,8 @@ ARROW_EXPORT bool IsAncestorOf(util::string_view ancestor, util::string_view descendant); ARROW_EXPORT -util::optional RemoveAncestor(util::string_view ancestor, - util::string_view descendant); +std::optional RemoveAncestor(util::string_view ancestor, + util::string_view descendant); /// Return a vector of ancestors between a base path and a descendant. /// For example, diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index 878f54812ba..db79810f5d7 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -90,7 +91,6 @@ #include "arrow/util/io_util.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/string.h" #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" @@ -657,7 +657,7 @@ class S3Client : public Aws::S3::S3Client { // We work around the issue by registering a DataReceivedEventHandler // which parses the XML response for embedded errors. 
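
RemoveAncestor above returns an optional string view so that "not a descendant" is distinguishable from a legitimately empty relative path. A self-contained approximation using std::string_view (the patched code still uses util::string_view) that skips the real helper's corner cases:

    #include <optional>
    #include <string_view>

    constexpr char kSep = '/';

    std::optional<std::string_view> RemoveAncestorSketch(
        std::string_view ancestor, std::string_view descendant) {
      if (descendant.substr(0, ancestor.size()) != ancestor) return std::nullopt;
      std::string_view relative = descendant.substr(ancestor.size());
      if (relative.empty()) return relative;              // same path: empty remainder
      if (relative.front() != kSep) return std::nullopt;  // "/database" is not under "/data"
      return relative.substr(1);
    }

    // RemoveAncestorSketch("/data", "/data/x.parquet") -> "x.parquet"
    // RemoveAncestorSketch("/data", "/database")       -> std::nullopt
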
- util::optional> aws_error; + std::optional> aws_error; auto handler = [&](const Aws::Http::HttpRequest* http_req, Aws::Http::HttpResponse* http_resp, @@ -762,7 +762,7 @@ class ClientBuilder { Aws::Client::ClientConfiguration* mutable_config() { return &client_config_; } Result> BuildClient( - util::optional io_context = util::nullopt) { + std::optional io_context = std::nullopt) { credentials_provider_ = options_.credentials_provider; if (!options_.region.empty()) { client_config_.region = ToAwsString(options_.region); @@ -1708,7 +1708,7 @@ class S3FileSystem::Impl : public std::enable_shared_from_this client_; - util::optional backend_; + std::optional backend_; const int32_t kListObjectsMaxKeys = 1000; // At most 1000 keys per multiple-delete request diff --git a/cpp/src/arrow/flight/cookie_internal.cc b/cpp/src/arrow/flight/cookie_internal.cc index 1a15da92676..380ea56976d 100644 --- a/cpp/src/arrow/flight/cookie_internal.cc +++ b/cpp/src/arrow/flight/cookie_internal.cc @@ -47,7 +47,7 @@ namespace arrow { namespace flight { namespace internal { -using CookiePair = arrow::util::optional>; +using CookiePair = std::optional>; using CookieHeaderPair = const std::pair&; @@ -139,7 +139,7 @@ CookiePair Cookie::ParseCookieAttribute(const std::string& cookie_header_value, if (std::string::npos == equals_pos) { // No cookie attribute. *start_pos = std::string::npos; - return arrow::util::nullopt; + return std::nullopt; } std::string::size_type semi_col_pos = cookie_header_value.find(';', equals_pos); diff --git a/cpp/src/arrow/flight/cookie_internal.h b/cpp/src/arrow/flight/cookie_internal.h index 6b3af516bb6..b87c8052266 100644 --- a/cpp/src/arrow/flight/cookie_internal.h +++ b/cpp/src/arrow/flight/cookie_internal.h @@ -21,13 +21,13 @@ #include #include +#include #include #include #include #include "arrow/flight/client_middleware.h" #include "arrow/result.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" namespace arrow { @@ -65,7 +65,7 @@ class ARROW_FLIGHT_EXPORT Cookie { /// function returns. /// /// \return Optional cookie key value pair. 
- static arrow::util::optional> ParseCookieAttribute( + static std::optional> ParseCookieAttribute( const std::string& cookie_header_value, std::string::size_type* start_pos); /// \brief Function to fix cookie format date string so it is accepted by Windows diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 84040a1a476..1275db6a8d4 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -359,12 +359,12 @@ class TestCookieParsing : public ::testing::Test { void VerifyCookieAttributeParsing( const std::string cookie_str, std::string::size_type start_pos, - const util::optional> cookie_attribute, + const std::optional> cookie_attribute, const std::string::size_type start_pos_after) { - util::optional> attr = + std::optional> attr = internal::Cookie::ParseCookieAttribute(cookie_str, &start_pos); - if (cookie_attribute == util::nullopt) { + if (cookie_attribute == std::nullopt) { EXPECT_EQ(cookie_attribute, attr); } else { EXPECT_EQ(cookie_attribute.value(), attr.value()); @@ -454,7 +454,7 @@ TEST_F(TestCookieParsing, DateConversion) { } TEST_F(TestCookieParsing, ParseCookieAttribute) { - VerifyCookieAttributeParsing("", 0, util::nullopt, std::string::npos); + VerifyCookieAttributeParsing("", 0, std::nullopt, std::string::npos); std::string cookie_string = "attr0=0; attr1=1; attr2=2; attr3=3"; auto attr_length = std::string("attr0=0;").length(); @@ -470,8 +470,8 @@ TEST_F(TestCookieParsing, ParseCookieAttribute) { VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length + 1)), std::make_pair("attr3", "3"), std::string::npos); VerifyCookieAttributeParsing(cookie_string, (start_pos += (attr_length - 1)), - util::nullopt, std::string::npos); - VerifyCookieAttributeParsing(cookie_string, std::string::npos, util::nullopt, + std::nullopt, std::string::npos); + VerifyCookieAttributeParsing(cookie_string, std::string::npos, std::nullopt, std::string::npos); } @@ -491,28 +491,28 @@ TEST(TransportErrorHandling, ReconstructStatus) { EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr(". Also, server sent unknown or invalid Arrow status code -1"), - internal::ReconstructStatus("-1", current, util::nullopt, util::nullopt, - util::nullopt, /*detail=*/nullptr)); + internal::ReconstructStatus("-1", current, std::nullopt, std::nullopt, std::nullopt, + /*detail=*/nullptr)); EXPECT_RAISES_WITH_MESSAGE_THAT( Invalid, ::testing::HasSubstr( ". 
Also, server sent unknown or invalid Arrow status code foobar"), - internal::ReconstructStatus("foobar", current, util::nullopt, util::nullopt, - util::nullopt, /*detail=*/nullptr)); + internal::ReconstructStatus("foobar", current, std::nullopt, std::nullopt, + std::nullopt, /*detail=*/nullptr)); // Override code EXPECT_RAISES_WITH_MESSAGE_THAT( AlreadyExists, ::testing::HasSubstr("Base error message"), internal::ReconstructStatus( std::to_string(static_cast(StatusCode::AlreadyExists)), current, - util::nullopt, util::nullopt, util::nullopt, /*detail=*/nullptr)); + std::nullopt, std::nullopt, std::nullopt, /*detail=*/nullptr)); // Override message EXPECT_RAISES_WITH_MESSAGE_THAT( AlreadyExists, ::testing::HasSubstr("Custom error message"), internal::ReconstructStatus( std::to_string(static_cast(StatusCode::AlreadyExists)), current, - "Custom error message", util::nullopt, util::nullopt, /*detail=*/nullptr)); + "Custom error message", std::nullopt, std::nullopt, /*detail=*/nullptr)); // With detail EXPECT_RAISES_WITH_MESSAGE_THAT( @@ -521,7 +521,7 @@ TEST(TransportErrorHandling, ReconstructStatus) { ::testing::HasSubstr(". Detail: Detail message")), internal::ReconstructStatus( std::to_string(static_cast(StatusCode::AlreadyExists)), current, - "Custom error message", "Detail message", util::nullopt, /*detail=*/nullptr)); + "Custom error message", "Detail message", std::nullopt, /*detail=*/nullptr)); // With detail and bin auto reconstructed = internal::ReconstructStatus( diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc index b9eeda76b00..acd078a8477 100644 --- a/cpp/src/arrow/flight/sql/client_test.cc +++ b/cpp/src/arrow/flight/sql/client_test.cc @@ -213,7 +213,7 @@ TEST_F(TestFlightSqlClient, TestGetExported) { ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - TableRef table_ref = {util::make_optional(catalog), util::make_optional(schema), table}; + TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table}; ASSERT_OK(sql_client_.GetExportedKeys(call_options_, table_ref)); } @@ -231,7 +231,7 @@ TEST_F(TestFlightSqlClient, TestGetImported) { ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - TableRef table_ref = {util::make_optional(catalog), util::make_optional(schema), table}; + TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table}; ASSERT_OK(sql_client_.GetImportedKeys(call_options_, table_ref)); } @@ -249,7 +249,7 @@ TEST_F(TestFlightSqlClient, TestGetPrimary) { ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - TableRef table_ref = {util::make_optional(catalog), util::make_optional(schema), table}; + TableRef table_ref = {std::make_optional(catalog), std::make_optional(schema), table}; ASSERT_OK(sql_client_.GetPrimaryKeys(call_options_, table_ref)); } @@ -273,10 +273,10 @@ TEST_F(TestFlightSqlClient, TestGetCrossReference) { ON_CALL(sql_client_, GetFlightInfo).WillByDefault(ReturnEmptyFlightInfo); EXPECT_CALL(sql_client_, GetFlightInfo(Ref(call_options_), descriptor)); - TableRef pk_table_ref = {util::make_optional(pk_catalog), - util::make_optional(pk_schema), pk_table}; - TableRef fk_table_ref = {util::make_optional(fk_catalog), - util::make_optional(fk_schema), fk_table}; + TableRef 
pk_table_ref = {std::make_optional(pk_catalog), std::make_optional(pk_schema), + pk_table}; + TableRef fk_table_ref = {std::make_optional(fk_catalog), std::make_optional(fk_schema), + fk_table}; ASSERT_OK(sql_client_.GetCrossReference(call_options_, pk_table_ref, fk_table_ref)); } diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 78fbff0c33a..a8f3ed8a80c 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -33,7 +33,7 @@ #include "arrow/util/checked_cast.h" #define PROPERTY_TO_OPTIONAL(COMMAND, PROPERTY) \ - COMMAND.has_##PROPERTY() ? util::make_optional(COMMAND.PROPERTY()) : util::nullopt + COMMAND.has_##PROPERTY() ? std::make_optional(COMMAND.PROPERTY()) : std::nullopt namespace arrow { namespace flight { diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 49e239a0cdd..91dad98843f 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -21,6 +21,7 @@ #pragma once #include +#include #include #include @@ -29,7 +30,6 @@ #include "arrow/flight/sql/types.h" #include "arrow/flight/sql/visibility.h" #include "arrow/flight/types.h" -#include "arrow/util/optional.h" namespace arrow { namespace flight { @@ -79,19 +79,19 @@ struct ARROW_FLIGHT_SQL_EXPORT GetSqlInfo { /// \brief A request to list database schemas. struct ARROW_FLIGHT_SQL_EXPORT GetDbSchemas { /// \brief An optional database catalog to filter on. - util::optional catalog; + std::optional catalog; /// \brief An optional database schema to filter on. - util::optional db_schema_filter_pattern; + std::optional db_schema_filter_pattern; }; /// \brief A request to list database tables. struct ARROW_FLIGHT_SQL_EXPORT GetTables { /// \brief An optional database catalog to filter on. - util::optional catalog; + std::optional catalog; /// \brief An optional database schema to filter on. - util::optional db_schema_filter_pattern; + std::optional db_schema_filter_pattern; /// \brief An optional table name to filter on. - util::optional table_name_filter_pattern; + std::optional table_name_filter_pattern; /// \brief A list of table types to filter on. std::vector table_types; /// \brief Whether to include the Arrow schema in the response. @@ -101,7 +101,7 @@ struct ARROW_FLIGHT_SQL_EXPORT GetTables { /// \brief A request to get SQL data type information. struct ARROW_FLIGHT_SQL_EXPORT GetXdbcTypeInfo { /// \brief A specific SQL type ID to fetch information about. - util::optional data_type; + std::optional data_type; }; /// \brief A request to list primary keys of a table. 
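
The PROPERTY_TO_OPTIONAL macro rewritten above bridges protobuf's has_foo()/foo() convention and std::optional. A stand-alone illustration with a hand-written stand-in for a generated message (no protobuf dependency; GetTablesCommandSketch is invented for the example):

    #include <iostream>
    #include <optional>
    #include <string>

    #define PROPERTY_TO_OPTIONAL(COMMAND, PROPERTY) \
      COMMAND.has_##PROPERTY() ? std::make_optional(COMMAND.PROPERTY()) : std::nullopt

    struct GetTablesCommandSketch {
      bool has_catalog() const { return has_catalog_; }
      std::string catalog() const { return catalog_; }
      bool has_catalog_ = false;
      std::string catalog_;
    };

    int main() {
      GetTablesCommandSketch cmd;
      std::optional<std::string> catalog = PROPERTY_TO_OPTIONAL(cmd, catalog);
      std::cout << catalog.value_or("<unset>") << "\n";  // prints <unset>
      cmd.has_catalog_ = true;
      cmd.catalog_ = "main";
      catalog = PROPERTY_TO_OPTIONAL(cmd, catalog);
      std::cout << catalog.value_or("<unset>") << "\n";  // prints main
    }

The conditional operator converts the std::nullopt arm to the std::optional type of the other arm, so the macro expands to a single expression, and protobuf's unset/empty distinction survives: an absent catalog is std::nullopt, not an empty string.
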
diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 69081acdb59..7ba3ca4a243 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -640,7 +640,7 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementUpdate) { TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { FlightCallOptions options = {}; - TableRef table_ref = {util::nullopt, util::nullopt, "int%"}; + TableRef table_ref = {std::nullopt, std::nullopt, "int%"}; ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetPrimaryKeys(options, table_ref)); ASSERT_OK_AND_ASSIGN(auto stream, @@ -664,7 +664,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetPrimaryKeys) { TEST_F(TestFlightSqlServer, TestCommandGetImportedKeys) { FlightCallOptions options = {}; - TableRef table_ref = {util::nullopt, util::nullopt, "intTable"}; + TableRef table_ref = {std::nullopt, std::nullopt, "intTable"}; ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetImportedKeys(options, table_ref)); ASSERT_OK_AND_ASSIGN(auto stream, @@ -696,7 +696,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetImportedKeys) { TEST_F(TestFlightSqlServer, TestCommandGetExportedKeys) { FlightCallOptions options = {}; - TableRef table_ref = {util::nullopt, util::nullopt, "foreignTable"}; + TableRef table_ref = {std::nullopt, std::nullopt, "foreignTable"}; ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetExportedKeys(options, table_ref)); ASSERT_OK_AND_ASSIGN(auto stream, @@ -728,8 +728,8 @@ TEST_F(TestFlightSqlServer, TestCommandGetExportedKeys) { TEST_F(TestFlightSqlServer, TestCommandGetCrossReference) { FlightCallOptions options = {}; - TableRef pk_table_ref = {util::nullopt, util::nullopt, "foreignTable"}; - TableRef fk_table_ref = {util::nullopt, util::nullopt, "intTable"}; + TableRef pk_table_ref = {std::nullopt, std::nullopt, "foreignTable"}; + TableRef fk_table_ref = {std::nullopt, std::nullopt, "intTable"}; ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetCrossReference( options, pk_table_ref, fk_table_ref)); diff --git a/cpp/src/arrow/flight/sql/test_app_cli.cc b/cpp/src/arrow/flight/sql/test_app_cli.cc index 7989210dd09..2df05d2875b 100644 --- a/cpp/src/arrow/flight/sql/test_app_cli.cc +++ b/cpp/src/arrow/flight/sql/test_app_cli.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include "arrow/array/builder_binary.h" #include "arrow/array/builder_primitive.h" @@ -29,7 +30,6 @@ #include "arrow/pretty_print.h" #include "arrow/status.h" #include "arrow/table.h" -#include "arrow/util/optional.h" using arrow::Result; using arrow::Schema; @@ -159,16 +159,16 @@ Status RunMain() { info, sql_client.GetTables(call_options, &FLAGS_catalog, &FLAGS_schema, &FLAGS_table, false, nullptr)); } else if (FLAGS_command == "GetExportedKeys") { - TableRef table_ref = {arrow::util::make_optional(FLAGS_catalog), - arrow::util::make_optional(FLAGS_schema), FLAGS_table}; + TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetExportedKeys(call_options, table_ref)); } else if (FLAGS_command == "GetImportedKeys") { - TableRef table_ref = {arrow::util::make_optional(FLAGS_catalog), - arrow::util::make_optional(FLAGS_schema), FLAGS_table}; + TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetImportedKeys(call_options, table_ref)); } else if (FLAGS_command == "GetPrimaryKeys") { - TableRef table_ref = 
{arrow::util::make_optional(FLAGS_catalog), - arrow::util::make_optional(FLAGS_schema), FLAGS_table}; + TableRef table_ref = {std::make_optional(FLAGS_catalog), + std::make_optional(FLAGS_schema), FLAGS_table}; ARROW_ASSIGN_OR_RAISE(info, sql_client.GetPrimaryKeys(call_options, table_ref)); } else if (FLAGS_command == "GetSqlInfo") { ARROW_ASSIGN_OR_RAISE(info, sql_client.GetSqlInfo(call_options, {})); diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h index a6c2648e7c4..20c7952d8d7 100644 --- a/cpp/src/arrow/flight/sql/types.h +++ b/cpp/src/arrow/flight/sql/types.h @@ -18,13 +18,13 @@ #pragma once #include +#include #include #include #include #include "arrow/flight/sql/visibility.h" #include "arrow/type_fwd.h" -#include "arrow/util/optional.h" #include "arrow/util/variant.h" namespace arrow { @@ -838,9 +838,9 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions { /// \brief A SQL %table reference, optionally containing table's catalog and db_schema. struct ARROW_FLIGHT_SQL_EXPORT TableRef { /// \brief The table's catalog. - util::optional catalog; + std::optional catalog; /// \brief The table's database schema. - util::optional db_schema; + std::optional db_schema; /// \brief The table name. std::string table; }; diff --git a/cpp/src/arrow/flight/transport.cc b/cpp/src/arrow/flight/transport.cc index 0da81a567eb..7f0b1cf929d 100644 --- a/cpp/src/arrow/flight/transport.cc +++ b/cpp/src/arrow/flight/transport.cc @@ -299,9 +299,9 @@ Status TransportStatus::ToStatus() const { } Status ReconstructStatus(const std::string& code_str, const Status& current_status, - util::optional message, - util::optional detail_message, - util::optional detail_bin, + std::optional message, + std::optional detail_message, + std::optional detail_bin, std::shared_ptr detail) { // Bounce through std::string to get a proper null-terminated C string StatusCode status_code = current_status.code(); diff --git a/cpp/src/arrow/flight/transport.h b/cpp/src/arrow/flight/transport.h index 66ded71fbe9..6406734e6e7 100644 --- a/cpp/src/arrow/flight/transport.h +++ b/cpp/src/arrow/flight/transport.h @@ -58,6 +58,7 @@ #include #include +#include #include #include #include @@ -65,7 +66,6 @@ #include "arrow/flight/type_fwd.h" #include "arrow/flight/visibility.h" #include "arrow/type_fwd.h" -#include "arrow/util/optional.h" namespace arrow { namespace ipc { @@ -265,9 +265,9 @@ struct ARROW_FLIGHT_EXPORT TransportStatus { /// back to an Arrow status. ARROW_FLIGHT_EXPORT Status ReconstructStatus(const std::string& code_str, const Status& current_status, - util::optional message, - util::optional detail_message, - util::optional detail_bin, + std::optional message, + std::optional detail_message, + std::optional detail_bin, std::shared_ptr detail); } // namespace internal diff --git a/cpp/src/arrow/flight/transport/grpc/util_internal.cc b/cpp/src/arrow/flight/transport/grpc/util_internal.cc index 0455dc119a9..be858fb6d44 100644 --- a/cpp/src/arrow/flight/transport/grpc/util_internal.cc +++ b/cpp/src/arrow/flight/transport/grpc/util_internal.cc @@ -55,25 +55,25 @@ static bool FromGrpcContext(const ::grpc::ClientContext& ctx, if (code_val == trailers.end()) return false; const auto message_val = trailers.find(kGrpcStatusMessageHeader); - const util::optional message = + const std::optional message = message_val == trailers.end() - ? util::nullopt - : util::optional( + ? 
std::nullopt + : std::optional( std::string(message_val->second.data(), message_val->second.size())); const auto detail_val = trailers.find(kGrpcStatusDetailHeader); - const util::optional detail_message = + const std::optional detail_message = detail_val == trailers.end() - ? util::nullopt - : util::optional( + ? std::nullopt + : std::optional( std::string(detail_val->second.data(), detail_val->second.size())); const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey); - const util::optional detail_bin = + const std::optional detail_bin = grpc_detail_val == trailers.end() - ? util::nullopt - : util::optional(std::string(grpc_detail_val->second.data(), - grpc_detail_val->second.size())); + ? std::nullopt + : std::optional(std::string(grpc_detail_val->second.data(), + grpc_detail_val->second.size())); std::string code_str(code_val->second.data(), code_val->second.size()); *status = internal::ReconstructStatus(code_str, current_status, std::move(message), diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index abcf7911255..373333663f8 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -274,7 +274,7 @@ Status HeadersFrame::GetStatus(Status* out) { *out = transport_status.ToStatus(); util::string_view detail_str, bin_str; - util::optional message, detail_message, detail_bin; + std::optional message, detail_message, detail_bin; if (!Get(kHeaderStatusCode).Value(&code_str).ok()) { // No Arrow status sent, go with the transport status return Status::OK(); diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc index dff6fd83076..99cb0682462 100644 --- a/cpp/src/arrow/memory_pool.cc +++ b/cpp/src/arrow/memory_pool.cc @@ -25,6 +25,7 @@ #include #include #include +#include #if defined(sun) || defined(__sun) #include @@ -40,7 +41,6 @@ #include "arrow/util/int_util_overflow.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" // IWYU pragma: keep -#include "arrow/util/optional.h" #include "arrow/util/string.h" #include "arrow/util/thread_pool.h" #include "arrow/util/ubsan.h" @@ -103,8 +103,8 @@ const std::vector& SupportedBackends() { // Return the MemoryPoolBackend selected by the user through the // ARROW_DEFAULT_MEMORY_POOL environment variable, if any. 
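The hunk that follows migrates this getenv-based helper from util::optional to
std::optional. As a minimal sketch of the underlying pattern (standard library
only; the GetEnvOptional name is illustrative, not Arrow's):

    #include <cstdlib>
    #include <optional>
    #include <string>

    // Read an environment variable; an unset variable maps to std::nullopt.
    std::optional<std::string> GetEnvOptional(const char* name) {
      const char* value = std::getenv(name);
      if (value == nullptr) {
        return std::nullopt;
      }
      return std::make_optional(std::string(value));
    }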
-util::optional UserSelectedBackend() { - static auto user_selected_backend = []() -> util::optional { +std::optional UserSelectedBackend() { + static auto user_selected_backend = []() -> std::optional { auto unsupported_backend = [](const std::string& name) { std::vector supported; for (const auto backend : SupportedBackends()) { diff --git a/cpp/src/arrow/public_api_test.cc b/cpp/src/arrow/public_api_test.cc index a2aa624d092..a611dd7920c 100644 --- a/cpp/src/arrow/public_api_test.cc +++ b/cpp/src/arrow/public_api_test.cc @@ -109,7 +109,7 @@ TEST(Misc, SetTimezoneConfig) { #else auto fs = std::make_shared(); - util::optional tzdata_result = GetTestTimezoneDatabaseRoot(); + std::optional tzdata_result = GetTestTimezoneDatabaseRoot(); std::string tzdata_dir; if (tzdata_result.has_value()) { tzdata_dir = tzdata_result.value(); @@ -129,7 +129,7 @@ TEST(Misc, SetTimezoneConfig) { ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("tzdata")); // Validate that setting tzdb to that dir fails - arrow::GlobalOptions options = {util::make_optional(tempdir->path().ToString())}; + arrow::GlobalOptions options = {std::make_optional(tempdir->path().ToString())}; ASSERT_NOT_OK(arrow::Initialize(options)); // Copy tzdb data from ~/Downloads diff --git a/cpp/src/arrow/stl_iterator.h b/cpp/src/arrow/stl_iterator.h index e1eeb33fbae..5f2acfb071b 100644 --- a/cpp/src/arrow/stl_iterator.h +++ b/cpp/src/arrow/stl_iterator.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "arrow/chunked_array.h" @@ -27,7 +28,6 @@ #include "arrow/type_fwd.h" #include "arrow/type_traits.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" namespace arrow { namespace stl { @@ -49,7 +49,7 @@ template > class ArrayIterator { public: - using value_type = arrow::util::optional; + using value_type = std::optional; using difference_type = int64_t; using pointer = value_type*; using reference = value_type&; @@ -138,7 +138,7 @@ template > class ChunkedArrayIterator { public: - using value_type = arrow::util::optional; + using value_type = std::optional; using difference_type = int64_t; using pointer = value_type*; using reference = value_type&; diff --git a/cpp/src/arrow/stl_iterator_test.cc b/cpp/src/arrow/stl_iterator_test.cc index d4a011e4507..652a66cb516 100644 --- a/cpp/src/arrow/stl_iterator_test.cc +++ b/cpp/src/arrow/stl_iterator_test.cc @@ -30,8 +30,8 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -using util::nullopt; -using util::optional; +using std::nullopt; +using std::optional; namespace stl { diff --git a/cpp/src/arrow/stl_test.cc b/cpp/src/arrow/stl_test.cc index 52dda54ce18..ec12db2d74c 100644 --- a/cpp/src/arrow/stl_test.cc +++ b/cpp/src/arrow/stl_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -32,7 +33,6 @@ #include "arrow/testing/gtest_util.h" #include "arrow/type.h" #include "arrow/type_fwd.h" -#include "arrow/util/optional.h" using primitive_types_tuple = std::tuple; @@ -101,10 +101,10 @@ struct TestInt32Type { namespace arrow { using optional_types_tuple = - std::tuple, util::optional, util::optional, - util::optional, util::optional, util::optional, - util::optional, util::optional, util::optional, - util::optional>; + std::tuple, std::optional, std::optional, + std::optional, std::optional, std::optional, + std::optional, std::optional, std::optional, + std::optional>; template <> struct CTypeTraits { @@ -291,9 +291,8 @@ TEST(TestTableFromTupleVector, 
NullableTypesWithBoostOptional) {
   std::vector<types_tuple> rows{
       types_tuple(-1, -2, -3, -4, 1, 2, 3, 4, true, std::string("Tests")),
       types_tuple(-10, -20, -30, -40, 10, 20, 30, 40, false, std::string("Other")),
-      types_tuple(util::nullopt, util::nullopt, util::nullopt, util::nullopt,
-                  util::nullopt, util::nullopt, util::nullopt, util::nullopt,
-                  util::nullopt, util::nullopt),
+      types_tuple(std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
+                  std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt),
   };
   std::shared_ptr<Table> table;
   ASSERT_OK(TableFromTupleRange(default_memory_pool(), rows, names, &table));
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index a4d86708800..2ba944e41f1 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -466,10 +466,10 @@ std::shared_ptr<Table>
TableFromJSON(const std::shared_ptr& schema, return *Table::FromRecordBatches(schema, std::move(batches)); } -Result> PrintArrayDiff(const ChunkedArray& expected, - const ChunkedArray& actual) { +Result> PrintArrayDiff(const ChunkedArray& expected, + const ChunkedArray& actual) { if (actual.Equals(expected)) { - return util::nullopt; + return std::nullopt; } std::stringstream ss; diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 8ce5049452a..b6bfcb8e2d3 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -38,7 +39,6 @@ #include "arrow/type_fwd.h" #include "arrow/type_traits.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" #include "arrow/util/string_builder.h" #include "arrow/util/string_view.h" #include "arrow/util/type_fwd.h" @@ -270,7 +270,7 @@ ARROW_TESTING_EXPORT void AssertSchemaNotEqual(const std::shared_ptr& lh const std::shared_ptr& rhs, bool check_metadata = false); -ARROW_TESTING_EXPORT Result> PrintArrayDiff( +ARROW_TESTING_EXPORT Result> PrintArrayDiff( const ChunkedArray& expected, const ChunkedArray& actual); ARROW_TESTING_EXPORT void AssertTablesEqual(const Table& expected, const Table& actual, @@ -541,21 +541,4 @@ void PrintTo(const basic_string_view& view, std::ostream* os) { } } // namespace sv_lite - -namespace optional_lite { - -template -void PrintTo(const optional& opt, std::ostream* os) { - if (opt.has_value()) { - *os << "{"; - ::testing::internal::UniversalPrint(*opt, os); - *os << "}"; - } else { - *os << "nullopt"; - } -} - -inline void PrintTo(const decltype(nullopt)&, std::ostream* os) { *os << "nullopt"; } - -} // namespace optional_lite } // namespace nonstd diff --git a/cpp/src/arrow/testing/matchers.h b/cpp/src/arrow/testing/matchers.h index 25607c1ff5b..4d5bb695757 100644 --- a/cpp/src/arrow/testing/matchers.h +++ b/cpp/src/arrow/testing/matchers.h @@ -220,14 +220,14 @@ class ResultMatcher { class ErrorMatcher { public: explicit ErrorMatcher(StatusCode code, - util::optional> message_matcher) + std::optional> message_matcher) : code_(code), message_matcher_(std::move(message_matcher)) {} template operator testing::Matcher() const { // NOLINT runtime/explicit struct Impl : testing::MatcherInterface { explicit Impl(StatusCode code, - util::optional> message_matcher) + std::optional> message_matcher) : code_(code), message_matcher_(std::move(message_matcher)) {} void DescribeTo(::std::ostream* os) const override { @@ -270,7 +270,7 @@ class ErrorMatcher { } const StatusCode code_; - const util::optional> message_matcher_; + const std::optional> message_matcher_; }; return testing::Matcher(new Impl(code_, message_matcher_)); @@ -278,7 +278,7 @@ class ErrorMatcher { private: const StatusCode code_; - const util::optional> message_matcher_; + const std::optional> message_matcher_; }; class OkMatcher { @@ -324,7 +324,7 @@ inline OkMatcher Ok() { return {}; } // Returns a matcher that matches the StatusCode of a Status or Result. // Do not use Raises(StatusCode::OK) to match a non error code. -inline ErrorMatcher Raises(StatusCode code) { return ErrorMatcher(code, util::nullopt); } +inline ErrorMatcher Raises(StatusCode code) { return ErrorMatcher(code, std::nullopt); } // Returns a matcher that matches the StatusCode and message of a Status or Result. 
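As the preceding hunks show, ErrorMatcher now carries its message matcher as a
std::optional, and Raises(code) passes std::nullopt, i.e. a disengaged optional
means "do not check the message". A minimal sketch of that idiom (standalone,
not the Arrow API; MessageCheck is illustrative):

    #include <functional>
    #include <optional>
    #include <string>

    struct MessageCheck {
      // Disengaged -> accept any message; engaged -> apply the predicate.
      std::optional<std::function<bool(const std::string&)>> predicate;

      bool Matches(const std::string& message) const {
        return !predicate.has_value() || (*predicate)(message);
      }
    };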
template @@ -421,7 +421,7 @@ template ::ArrayType, typename BuilderType = typename TypeTraits::BuilderType, typename ValueType = typename ::arrow::stl::detail::DefaultValueAccessor::ValueType> -DataEqMatcher DataEqArray(T type, const std::vector>& values) { +DataEqMatcher DataEqArray(T type, const std::vector>& values) { // FIXME(bkietz) broken until DataType is move constructible BuilderType builder(std::make_shared(std::move(type)), default_memory_pool()); DCHECK_OK(builder.Reserve(static_cast(values.size()))); @@ -453,7 +453,7 @@ inline DataEqMatcher DataEqScalar(const std::shared_ptr& type, /// Constructs a scalar against which arguments are matched template ::ScalarType, typename ValueType = typename ScalarType::ValueType> -DataEqMatcher DataEqScalar(T type, util::optional value) { +DataEqMatcher DataEqScalar(T type, std::optional value) { ScalarType expected(std::make_shared(std::move(type))); if (value) { diff --git a/cpp/src/arrow/testing/util.cc b/cpp/src/arrow/testing/util.cc index bc8e1e26995..b5985448076 100644 --- a/cpp/src/arrow/testing/util.cc +++ b/cpp/src/arrow/testing/util.cc @@ -111,12 +111,12 @@ Status GetTestResourceRoot(std::string* out) { return Status::OK(); } -util::optional GetTestTimezoneDatabaseRoot() { +std::optional GetTestTimezoneDatabaseRoot() { const char* c_root = std::getenv("ARROW_TIMEZONE_DATABASE"); if (!c_root) { - return util::optional(); + return std::optional(); } - return util::make_optional(std::string(c_root)); + return std::make_optional(std::string(c_root)); } Status InitTestTimezoneDatabase() { @@ -125,7 +125,7 @@ Status InitTestTimezoneDatabase() { if (!maybe_tzdata.has_value()) return Status::OK(); auto tzdata_path = std::string(maybe_tzdata.value()); - arrow::GlobalOptions options = {util::make_optional(tzdata_path)}; + arrow::GlobalOptions options = {std::make_optional(tzdata_path)}; ARROW_RETURN_NOT_OK(arrow::Initialize(options)); return Status::OK(); } diff --git a/cpp/src/arrow/testing/util.h b/cpp/src/arrow/testing/util.h index 457713f969b..4f4b03438fd 100644 --- a/cpp/src/arrow/testing/util.h +++ b/cpp/src/arrow/testing/util.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include "arrow/testing/visibility.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" namespace arrow { @@ -112,7 +112,7 @@ UnionTypeFactories() { ARROW_TESTING_EXPORT Status GetTestResourceRoot(std::string*); // Return the value of the ARROW_TIMEZONE_DATABASE environment variable -ARROW_TESTING_EXPORT util::optional GetTestTimezoneDatabaseRoot(); +ARROW_TESTING_EXPORT std::optional GetTestTimezoneDatabaseRoot(); // Set the Timezone database based on the ARROW_TIMEZONE_DATABASE env variable // This is only relevant on Windows, since other OSs have compatible databases built-in diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index 9819b5ce923..d4a9c2829a7 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "arrow/util/async_util.h" @@ -30,7 +31,6 @@ #include "arrow/util/io_util.h" #include "arrow/util/iterator.h" #include "arrow/util/mutex.h" -#include "arrow/util/optional.h" #include "arrow/util/queue.h" #include "arrow/util/thread_pool.h" @@ -498,7 +498,7 @@ class TransformingGenerator { } // See comment on TransformingIterator::Pump - Result> Pump() { + Result> Pump() { if (!finished_ && last_value_.has_value()) { 
ARROW_ASSIGN_OR_RAISE(TransformFlow next, transformer_(*last_value_)); if (next.ReadyForNext()) { @@ -517,12 +517,12 @@ class TransformingGenerator { if (finished_) { return IterationTraits::End(); } - return util::nullopt; + return std::nullopt; } AsyncGenerator generator_; Transformer transformer_; - util::optional last_value_; + std::optional last_value_; bool finished_; }; @@ -839,7 +839,7 @@ class PushGenerator { util::Mutex mutex; std::deque> result_q; - util::optional> consumer_fut; + std::optional> consumer_fut; bool finished = false; }; @@ -1726,7 +1726,7 @@ class BackgroundGenerator { bool should_shutdown; // If the queue is empty, the consumer will create a waiting future and wait for it std::queue> queue; - util::optional> waiting_future; + std::optional> waiting_future; // Every background task is given a future to complete when it is entirely finished // processing and ready for the next task to start or for State to be destroyed Future<> task_finished; diff --git a/cpp/src/arrow/util/async_generator_test.cc b/cpp/src/arrow/util/async_generator_test.cc index e75ca577c77..2b4c869bedc 100644 --- a/cpp/src/arrow/util/async_generator_test.cc +++ b/cpp/src/arrow/util/async_generator_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -31,7 +32,6 @@ #include "arrow/type_fwd.h" #include "arrow/util/async_generator.h" #include "arrow/util/async_util.h" -#include "arrow/util/optional.h" #include "arrow/util/test_common.h" #include "arrow/util/vector.h" @@ -1846,7 +1846,7 @@ TEST(PushGenerator, CloseEarly) { } TEST(PushGenerator, DanglingProducer) { - util::optional> gen; + std::optional> gen; gen.emplace(); auto producer = gen->producer(); diff --git a/cpp/src/arrow/util/async_util.cc b/cpp/src/arrow/util/async_util.cc index 7e8c5513aab..3ac4519a6f0 100644 --- a/cpp/src/arrow/util/async_util.cc +++ b/cpp/src/arrow/util/async_util.cc @@ -34,14 +34,14 @@ class ThrottleImpl : public AsyncTaskScheduler::Throttle { public: explicit ThrottleImpl(int max_concurrent_cost) : available_cost_(max_concurrent_cost) {} - util::optional> TryAcquire(int amt) override { + std::optional> TryAcquire(int amt) override { std::lock_guard lk(mutex_); if (backoff_.is_valid()) { return backoff_; } if (amt <= available_cost_) { available_cost_ -= amt; - return nullopt; + return std::nullopt; } backoff_ = Future<>::Make(); return backoff_; @@ -151,7 +151,7 @@ class AsyncTaskSchedulerImpl : public AsyncTaskScheduler { queue_->Push(std::move(task)); return true; } - util::optional> maybe_backoff = throttle_->TryAcquire(task->cost()); + std::optional> maybe_backoff = throttle_->TryAcquire(task->cost()); if (maybe_backoff) { queue_->Push(std::move(task)); lk.unlock(); @@ -238,7 +238,7 @@ class AsyncTaskSchedulerImpl : public AsyncTaskScheduler { void ContinueTasksUnlocked(std::unique_lock&& lk) { while (!queue_->Empty()) { int next_cost = queue_->Peek().cost(); - util::optional> maybe_backoff = throttle_->TryAcquire(next_cost); + std::optional> maybe_backoff = throttle_->TryAcquire(next_cost); if (maybe_backoff) { lk.unlock(); if (!maybe_backoff->TryAddCallback([this] { diff --git a/cpp/src/arrow/util/async_util.h b/cpp/src/arrow/util/async_util.h index 653654668fd..707f70d471f 100644 --- a/cpp/src/arrow/util/async_util.h +++ b/cpp/src/arrow/util/async_util.h @@ -140,7 +140,7 @@ class ARROW_EXPORT AsyncTaskScheduler { /// acquired and the caller can proceed. If a future is returned then the caller /// should wait for the future to complete first. 
When the returned future completes /// the permits have NOT been acquired and the caller must call Acquire again - virtual util::optional> TryAcquire(int amt) = 0; + virtual std::optional> TryAcquire(int amt) = 0; /// Release amt permits /// /// This will possibly complete waiting futures and should probably not be diff --git a/cpp/src/arrow/util/async_util_test.cc b/cpp/src/arrow/util/async_util_test.cc index 25a3ca77cea..dfb688f70d1 100644 --- a/cpp/src/arrow/util/async_util_test.cc +++ b/cpp/src/arrow/util/async_util_test.cc @@ -200,9 +200,9 @@ TEST(AsyncTaskScheduler, SubSchedulerNoTasks) { class CustomThrottle : public AsyncTaskScheduler::Throttle { public: - virtual util::optional> TryAcquire(int amt) { + virtual std::optional> TryAcquire(int amt) { if (gate_.is_finished()) { - return nullopt; + return std::nullopt; } else { return gate_; } diff --git a/cpp/src/arrow/util/cancel_test.cc b/cpp/src/arrow/util/cancel_test.cc index b9bf94ba43a..bca78034c04 100644 --- a/cpp/src/arrow/util/cancel_test.cc +++ b/cpp/src/arrow/util/cancel_test.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -35,7 +36,6 @@ #include "arrow/util/future.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" namespace arrow { @@ -100,7 +100,7 @@ TEST_F(CancelTest, Unstoppable) { TEST_F(CancelTest, SourceVanishes) { { - util::optional source{StopSource()}; + std::optional source{StopSource()}; StopToken token = source->token(); ASSERT_FALSE(token.IsStopRequested()); ASSERT_OK(token.Poll()); @@ -110,7 +110,7 @@ TEST_F(CancelTest, SourceVanishes) { ASSERT_OK(token.Poll()); } { - util::optional source{StopSource()}; + std::optional source{StopSource()}; StopToken token = source->token(); source->RequestStop(); @@ -125,7 +125,7 @@ static void noop_signal_handler(int signum) { } #ifndef _WIN32 -static util::optional signal_stop_source; +static std::optional signal_stop_source; static void signal_handler(int signum) { signal_stop_source->RequestStopFromSignal(signum); @@ -207,8 +207,8 @@ class SignalCancelTest : public CancelTest { #else const int expected_signal_ = SIGALRM; #endif - util::optional guard_; - util::optional stop_token_; + std::optional guard_; + std::optional stop_token_; }; TEST_F(SignalCancelTest, Register) { diff --git a/cpp/src/arrow/util/cpu_info.cc b/cpp/src/arrow/util/cpu_info.cc index fbe55aec0c1..9bc33f04570 100644 --- a/cpp/src/arrow/util/cpu_info.cc +++ b/cpp/src/arrow/util/cpu_info.cc @@ -41,13 +41,13 @@ #include #include #include +#include #include #include #include "arrow/result.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/string.h" #undef CPUINFO_ARCH_X86 @@ -226,7 +226,7 @@ void OsRetrieveCpuInfo(int64_t* hardware_flags, CpuInfo::Vendor* vendor, #elif defined(__APPLE__) //------------------------------ MACOS ------------------------------// -util::optional IntegerSysCtlByName(const char* name) { +std::optional IntegerSysCtlByName(const char* name) { size_t len = sizeof(int64_t); int64_t data = 0; if (sysctlbyname(name, &data, &len, nullptr, 0) == 0) { @@ -238,7 +238,7 @@ util::optional IntegerSysCtlByName(const char* name) { auto st = IOErrorFromErrno(errno, "sysctlbyname failed for '", name, "'"); ARROW_LOG(WARNING) << st.ToString(); } - return util::nullopt; + return std::nullopt; } void OsRetrieveCacheSize(std::array* cache_sizes) { diff --git a/cpp/src/arrow/util/future.h b/cpp/src/arrow/util/future.h index 3be4e334b14..7fc0d5063f6 
100644 --- a/cpp/src/arrow/util/future.h +++ b/cpp/src/arrow/util/future.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -32,7 +33,6 @@ #include "arrow/util/config.h" #include "arrow/util/functional.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" #include "arrow/util/tracing.h" #include "arrow/util/type_fwd.h" #include "arrow/util/visibility.h" @@ -781,18 +781,18 @@ Future<> AllFinished(const std::vector>& futures); struct Continue { template - operator util::optional() && { // NOLINT explicit + operator std::optional() && { // NOLINT explicit return {}; } }; template -util::optional Break(T break_value = {}) { - return util::optional{std::move(break_value)}; +std::optional Break(T break_value = {}) { + return std::optional{std::move(break_value)}; } template -using ControlFlow = util::optional; +using ControlFlow = std::optional; /// \brief Loop through an asynchronous sequence /// diff --git a/cpp/src/arrow/util/io_util_test.cc b/cpp/src/arrow/util/io_util_test.cc index 57c75fff3c7..f4fcc26d072 100644 --- a/cpp/src/arrow/util/io_util_test.cc +++ b/cpp/src/arrow/util/io_util_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -41,7 +42,6 @@ #include "arrow/util/cpu_info.h" #include "arrow/util/io_util.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/windows_compatibility.h" #include "arrow/util/windows_fixup.h" diff --git a/cpp/src/arrow/util/iterator.h b/cpp/src/arrow/util/iterator.h index 2f42803d26f..0eae7f6a857 100644 --- a/cpp/src/arrow/util/iterator.h +++ b/cpp/src/arrow/util/iterator.h @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -30,7 +31,6 @@ #include "arrow/util/compare.h" #include "arrow/util/functional.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" #include "arrow/util/visibility.h" namespace arrow { @@ -66,16 +66,16 @@ bool IsIterationEnd(const T& val) { } template -struct IterationTraits> { +struct IterationTraits> { /// \brief by default when iterating through a sequence of optional, /// nullopt indicates the end of iteration. /// Specialize IterationTraits if different end semantics are required. - static util::optional End() { return util::nullopt; } + static std::optional End() { return std::nullopt; } /// \brief by default when iterating through a sequence of optional, /// nullopt (!has_value()) indicates the end of iteration. /// Specialize IterationTraits if different end semantics are required. 
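Under the IterationTraits specialization above, a disengaged std::optional is
the end-of-iteration sentinel. A minimal sketch of consuming such a sequence
(standalone; DrainToVector is illustrative, not Arrow's Iterator API):

    #include <functional>
    #include <optional>
    #include <vector>

    // Pull values until the producer yields std::nullopt.
    template <typename T>
    std::vector<T> DrainToVector(const std::function<std::optional<T>()>& next) {
      std::vector<T> out;
      while (std::optional<T> item = next()) {
        out.push_back(*item);
      }
      return out;
    }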
- static bool IsEnd(const util::optional& val) { return !val.has_value(); } + static bool IsEnd(const std::optional& val) { return !val.has_value(); } // TODO(bkietz) The range-for loop over Iterator> yields // Result> which is unnecessary (since only the unyielded end optional @@ -227,7 +227,7 @@ struct TransformFlow { bool finished_ = false; bool ready_for_next_ = false; - util::optional yield_value_; + std::optional yield_value_; }; struct TransformFinish { @@ -263,7 +263,7 @@ class TransformIterator { Result Next() { while (!finished_) { - ARROW_ASSIGN_OR_RAISE(util::optional next, Pump()); + ARROW_ASSIGN_OR_RAISE(std::optional next, Pump()); if (next.has_value()) { return std::move(*next); } @@ -278,7 +278,7 @@ class TransformIterator { // * If an invalid status is encountered that will be returned // * If finished it will return IterationTraits::End() // * If a value is returned by the transformer that will be returned - Result> Pump() { + Result> Pump() { if (!finished_ && last_value_.has_value()) { auto next_res = transformer_(*last_value_); if (!next_res.ok()) { @@ -302,12 +302,12 @@ class TransformIterator { if (finished_) { return IterationTraits::End(); } - return util::nullopt; + return std::nullopt; } Iterator it_; Transformer transformer_; - util::optional last_value_; + std::optional last_value_; bool finished_ = false; }; diff --git a/cpp/src/arrow/util/optional.h b/cpp/src/arrow/util/optional.h deleted file mode 100644 index e1c32e76134..00000000000 --- a/cpp/src/arrow/util/optional.h +++ /dev/null @@ -1,35 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -#pragma once - -#define optional_CONFIG_SELECT_OPTIONAL optional_OPTIONAL_NONSTD - -#include "arrow/vendored/optional.hpp" // IWYU pragma: export - -namespace arrow { -namespace util { - -template -using optional = nonstd::optional; - -using nonstd::bad_optional_access; -using nonstd::make_optional; -using nonstd::nullopt; - -} // namespace util -} // namespace arrow diff --git a/cpp/src/arrow/util/reflection_test.cc b/cpp/src/arrow/util/reflection_test.cc index fb3d3b8fb02..8ca9077ddc6 100644 --- a/cpp/src/arrow/util/reflection_test.cc +++ b/cpp/src/arrow/util/reflection_test.cc @@ -79,7 +79,7 @@ struct FromStringImpl { props.ForEach(*this); } - void Fail() { obj_ = util::nullopt; } + void Fail() { obj_ = std::nullopt; } void Init(util::string_view class_name, util::string_view repr, size_t num_properties) { if (!repr.starts_with(class_name)) return Fail(); @@ -116,7 +116,7 @@ struct FromStringImpl { prop.set(&*obj_, std::move(value)); } - util::optional obj_ = Class{}; + std::optional obj_ = Class{}; std::vector members_; }; @@ -146,7 +146,7 @@ std::string ToString(const Person& obj) { void PrintTo(const Person& obj, std::ostream* os) { *os << ToString(obj); } -util::optional PersonFromString(util::string_view repr) { +std::optional PersonFromString(util::string_view repr) { return FromStringImpl("Person", repr, kPersonProperties).obj_; } @@ -174,23 +174,23 @@ TEST(Reflection, FromStringToDataMembers) { EXPECT_EQ(PersonFromString(ToString(genos)), genos); - EXPECT_EQ(PersonFromString(""), util::nullopt); - EXPECT_EQ(PersonFromString("Per"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{age:19,name:Genos"), util::nullopt); + EXPECT_EQ(PersonFromString(""), std::nullopt); + EXPECT_EQ(PersonFromString("Per"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{age:19,name:Genos"), std::nullopt); - EXPECT_EQ(PersonFromString("Person{name:Genos"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{age:19,name:Genos,extra:Cyborg}"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{name:Genos,age:19"), util::nullopt); + EXPECT_EQ(PersonFromString("Person{name:Genos"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{age:19,name:Genos,extra:Cyborg}"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{name:Genos,age:19"), std::nullopt); - EXPECT_EQ(PersonFromString("Fake{age:19,name:Genos}"), util::nullopt); + EXPECT_EQ(PersonFromString("Fake{age:19,name:Genos}"), std::nullopt); - EXPECT_EQ(PersonFromString("Person{age,name:Genos}"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{age:nineteen,name:Genos}"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{age:19 ,name:Genos}"), util::nullopt); - EXPECT_EQ(PersonFromString("Person{age:19,moniker:Genos}"), util::nullopt); + EXPECT_EQ(PersonFromString("Person{age,name:Genos}"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{age:nineteen,name:Genos}"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{age:19 ,name:Genos}"), std::nullopt); + EXPECT_EQ(PersonFromString("Person{age:19,moniker:Genos}"), std::nullopt); - EXPECT_EQ(PersonFromString("Person{age: 19, name: Genos}"), util::nullopt); + EXPECT_EQ(PersonFromString("Person{age: 19, name: Genos}"), std::nullopt); } enum class PersonType : int8_t { diff --git a/cpp/src/arrow/util/string.cc b/cpp/src/arrow/util/string.cc index 00ab8e64c47..09df881a9b0 100644 --- a/cpp/src/arrow/util/string.cc +++ b/cpp/src/arrow/util/string.cc @@ -182,11 +182,11 @@ 
std::string AsciiToUpper(util::string_view value) { return result; } -util::optional Replace(util::string_view s, util::string_view token, - util::string_view replacement) { +std::optional Replace(util::string_view s, util::string_view token, + util::string_view replacement) { size_t token_start = s.find(token); if (token_start == std::string::npos) { - return util::nullopt; + return std::nullopt; } return s.substr(0, token_start).to_string() + replacement.to_string() + s.substr(token_start + token.size()).to_string(); diff --git a/cpp/src/arrow/util/string.h b/cpp/src/arrow/util/string.h index b2baa0ebeda..fd9a3d1e063 100644 --- a/cpp/src/arrow/util/string.h +++ b/cpp/src/arrow/util/string.h @@ -17,11 +17,11 @@ #pragma once +#include #include #include #include "arrow/result.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" #include "arrow/util/visibility.h" @@ -74,8 +74,8 @@ std::string AsciiToUpper(util::string_view value); /// \brief Search for the first instance of a token and replace it or return nullopt if /// the token is not found. ARROW_EXPORT -util::optional Replace(util::string_view s, util::string_view token, - util::string_view replacement); +std::optional Replace(util::string_view s, util::string_view token, + util::string_view replacement); /// \brief Get boolean value from string /// diff --git a/cpp/src/arrow/util/task_group.cc b/cpp/src/arrow/util/task_group.cc index 0679b6ef1f6..96b7fbb8b3b 100644 --- a/cpp/src/arrow/util/task_group.cc +++ b/cpp/src/arrow/util/task_group.cc @@ -207,7 +207,7 @@ class ThreadedTaskGroup : public TaskGroup { std::mutex mutex_; std::condition_variable cv_; Status status_; - util::optional> completion_future_; + std::optional> completion_future_; }; } // namespace diff --git a/cpp/src/arrow/vendored/optional.hpp b/cpp/src/arrow/vendored/optional.hpp deleted file mode 100644 index e266bb20be2..00000000000 --- a/cpp/src/arrow/vendored/optional.hpp +++ /dev/null @@ -1,1553 +0,0 @@ -// Vendored from git tag v3.2.0 - -// Copyright (c) 2014-2018 Martin Moene -// -// https://github.com/martinmoene/optional-lite -// -// Distributed under the Boost Software License, Version 1.0. -// (See accompanying file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) - -#pragma once - -#ifndef NONSTD_OPTIONAL_LITE_HPP -#define NONSTD_OPTIONAL_LITE_HPP - -#define optional_lite_MAJOR 3 -#define optional_lite_MINOR 2 -#define optional_lite_PATCH 0 - -#define optional_lite_VERSION optional_STRINGIFY(optional_lite_MAJOR) "." optional_STRINGIFY(optional_lite_MINOR) "." optional_STRINGIFY(optional_lite_PATCH) - -#define optional_STRINGIFY( x ) optional_STRINGIFY_( x ) -#define optional_STRINGIFY_( x ) #x - -// optional-lite configuration: - -#define optional_OPTIONAL_DEFAULT 0 -#define optional_OPTIONAL_NONSTD 1 -#define optional_OPTIONAL_STD 2 - -#if !defined( optional_CONFIG_SELECT_OPTIONAL ) -# define optional_CONFIG_SELECT_OPTIONAL ( optional_HAVE_STD_OPTIONAL ? optional_OPTIONAL_STD : optional_OPTIONAL_NONSTD ) -#endif - -// Control presence of exception handling (try and auto discover): - -#ifndef optional_CONFIG_NO_EXCEPTIONS -# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND) -# define optional_CONFIG_NO_EXCEPTIONS 0 -# else -# define optional_CONFIG_NO_EXCEPTIONS 1 -# endif -#endif - -// C++ language version detection (C++20 is speculative): -// Note: VC14.0/1900 (VS2015) lacks too much from C++14. 
- -#ifndef optional_CPLUSPLUS -# if defined(_MSVC_LANG ) && !defined(__clang__) -# define optional_CPLUSPLUS (_MSC_VER == 1900 ? 201103L : _MSVC_LANG ) -# else -# define optional_CPLUSPLUS __cplusplus -# endif -#endif - -#define optional_CPP98_OR_GREATER ( optional_CPLUSPLUS >= 199711L ) -#define optional_CPP11_OR_GREATER ( optional_CPLUSPLUS >= 201103L ) -#define optional_CPP11_OR_GREATER_ ( optional_CPLUSPLUS >= 201103L ) -#define optional_CPP14_OR_GREATER ( optional_CPLUSPLUS >= 201402L ) -#define optional_CPP17_OR_GREATER ( optional_CPLUSPLUS >= 201703L ) -#define optional_CPP20_OR_GREATER ( optional_CPLUSPLUS >= 202000L ) - -// C++ language version (represent 98 as 3): - -#define optional_CPLUSPLUS_V ( optional_CPLUSPLUS / 100 - (optional_CPLUSPLUS > 200000 ? 2000 : 1994) ) - -// Use C++17 std::optional if available and requested: - -#if optional_CPP17_OR_GREATER && defined(__has_include ) -# if __has_include( ) -# define optional_HAVE_STD_OPTIONAL 1 -# else -# define optional_HAVE_STD_OPTIONAL 0 -# endif -#else -# define optional_HAVE_STD_OPTIONAL 0 -#endif - -#define optional_USES_STD_OPTIONAL ( (optional_CONFIG_SELECT_OPTIONAL == optional_OPTIONAL_STD) || ((optional_CONFIG_SELECT_OPTIONAL == optional_OPTIONAL_DEFAULT) && optional_HAVE_STD_OPTIONAL) ) - -// -// in_place: code duplicated in any-lite, expected-lite, optional-lite, value-ptr-lite, variant-lite: -// - -#ifndef nonstd_lite_HAVE_IN_PLACE_TYPES -#define nonstd_lite_HAVE_IN_PLACE_TYPES 1 - -// C++17 std::in_place in : - -#if optional_CPP17_OR_GREATER - -#include - -namespace nonstd { - -using std::in_place; -using std::in_place_type; -using std::in_place_index; -using std::in_place_t; -using std::in_place_type_t; -using std::in_place_index_t; - -#define nonstd_lite_in_place_t( T) std::in_place_t -#define nonstd_lite_in_place_type_t( T) std::in_place_type_t -#define nonstd_lite_in_place_index_t(K) std::in_place_index_t - -#define nonstd_lite_in_place( T) std::in_place_t{} -#define nonstd_lite_in_place_type( T) std::in_place_type_t{} -#define nonstd_lite_in_place_index(K) std::in_place_index_t{} - -} // namespace nonstd - -#else // optional_CPP17_OR_GREATER - -#include - -namespace nonstd { -namespace detail { - -template< class T > -struct in_place_type_tag {}; - -template< std::size_t K > -struct in_place_index_tag {}; - -} // namespace detail - -struct in_place_t {}; - -template< class T > -inline in_place_t in_place( detail::in_place_type_tag /*unused*/ = detail::in_place_type_tag() ) -{ - return in_place_t(); -} - -template< std::size_t K > -inline in_place_t in_place( detail::in_place_index_tag /*unused*/ = detail::in_place_index_tag() ) -{ - return in_place_t(); -} - -template< class T > -inline in_place_t in_place_type( detail::in_place_type_tag /*unused*/ = detail::in_place_type_tag() ) -{ - return in_place_t(); -} - -template< std::size_t K > -inline in_place_t in_place_index( detail::in_place_index_tag /*unused*/ = detail::in_place_index_tag() ) -{ - return in_place_t(); -} - -// mimic templated typedef: - -#define nonstd_lite_in_place_t( T) nonstd::in_place_t(&)( nonstd::detail::in_place_type_tag ) -#define nonstd_lite_in_place_type_t( T) nonstd::in_place_t(&)( nonstd::detail::in_place_type_tag ) -#define nonstd_lite_in_place_index_t(K) nonstd::in_place_t(&)( nonstd::detail::in_place_index_tag ) - -#define nonstd_lite_in_place( T) nonstd::in_place_type -#define nonstd_lite_in_place_type( T) nonstd::in_place_type -#define nonstd_lite_in_place_index(K) nonstd::in_place_index - -} // namespace nonstd - -#endif // 
optional_CPP17_OR_GREATER -#endif // nonstd_lite_HAVE_IN_PLACE_TYPES - -// -// Using std::optional: -// - -#if optional_USES_STD_OPTIONAL - -#include - -namespace nonstd { - - using std::optional; - using std::bad_optional_access; - using std::hash; - - using std::nullopt; - using std::nullopt_t; - - using std::operator==; - using std::operator!=; - using std::operator<; - using std::operator<=; - using std::operator>; - using std::operator>=; - using std::make_optional; - using std::swap; -} - -#else // optional_USES_STD_OPTIONAL - -#include -#include - -// optional-lite alignment configuration: - -#ifndef optional_CONFIG_MAX_ALIGN_HACK -# define optional_CONFIG_MAX_ALIGN_HACK 0 -#endif - -#ifndef optional_CONFIG_ALIGN_AS -// no default, used in #if defined() -#endif - -#ifndef optional_CONFIG_ALIGN_AS_FALLBACK -# define optional_CONFIG_ALIGN_AS_FALLBACK double -#endif - -// Compiler warning suppression: - -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wundef" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wundef" -#elif defined(_MSC_VER ) -# pragma warning( push ) -#endif - -// half-open range [lo..hi): -#define optional_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) ) - -// Compiler versions: -// -// MSVC++ 6.0 _MSC_VER == 1200 (Visual Studio 6.0) -// MSVC++ 7.0 _MSC_VER == 1300 (Visual Studio .NET 2002) -// MSVC++ 7.1 _MSC_VER == 1310 (Visual Studio .NET 2003) -// MSVC++ 8.0 _MSC_VER == 1400 (Visual Studio 2005) -// MSVC++ 9.0 _MSC_VER == 1500 (Visual Studio 2008) -// MSVC++ 10.0 _MSC_VER == 1600 (Visual Studio 2010) -// MSVC++ 11.0 _MSC_VER == 1700 (Visual Studio 2012) -// MSVC++ 12.0 _MSC_VER == 1800 (Visual Studio 2013) -// MSVC++ 14.0 _MSC_VER == 1900 (Visual Studio 2015) -// MSVC++ 14.1 _MSC_VER >= 1910 (Visual Studio 2017) - -#if defined(_MSC_VER ) && !defined(__clang__) -# define optional_COMPILER_MSVC_VER (_MSC_VER ) -# define optional_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900 ) ) ) -#else -# define optional_COMPILER_MSVC_VER 0 -# define optional_COMPILER_MSVC_VERSION 0 -#endif - -#define optional_COMPILER_VERSION( major, minor, patch ) ( 10 * (10 * (major) + (minor) ) + (patch) ) - -#if defined(__GNUC__) && !defined(__clang__) -# define optional_COMPILER_GNUC_VERSION optional_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#else -# define optional_COMPILER_GNUC_VERSION 0 -#endif - -#if defined(__clang__) -# define optional_COMPILER_CLANG_VERSION optional_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) -#else -# define optional_COMPILER_CLANG_VERSION 0 -#endif - -#if optional_BETWEEN(optional_COMPILER_MSVC_VERSION, 70, 140 ) -# pragma warning( disable: 4345 ) // initialization behavior changed -#endif - -#if optional_BETWEEN(optional_COMPILER_MSVC_VERSION, 70, 150 ) -# pragma warning( disable: 4814 ) // in C++14 'constexpr' will not imply 'const' -#endif - -// Presence of language and library features: - -#define optional_HAVE(FEATURE) ( optional_HAVE_##FEATURE ) - -#ifdef _HAS_CPP0X -# define optional_HAS_CPP0X _HAS_CPP0X -#else -# define optional_HAS_CPP0X 0 -#endif - -// Unless defined otherwise below, consider VC14 as C++11 for optional-lite: - -#if optional_COMPILER_MSVC_VER >= 1900 -# undef optional_CPP11_OR_GREATER -# define optional_CPP11_OR_GREATER 1 -#endif - -#define optional_CPP11_90 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1500) -#define optional_CPP11_100 (optional_CPP11_OR_GREATER_ || 
optional_COMPILER_MSVC_VER >= 1600) -#define optional_CPP11_110 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1700) -#define optional_CPP11_120 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1800) -#define optional_CPP11_140 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1900) -#define optional_CPP11_141 (optional_CPP11_OR_GREATER_ || optional_COMPILER_MSVC_VER >= 1910) - -#define optional_CPP14_000 (optional_CPP14_OR_GREATER) -#define optional_CPP17_000 (optional_CPP17_OR_GREATER) - -// Presence of C++11 language features: - -#define optional_HAVE_CONSTEXPR_11 optional_CPP11_140 -#define optional_HAVE_IS_DEFAULT optional_CPP11_140 -#define optional_HAVE_NOEXCEPT optional_CPP11_140 -#define optional_HAVE_NULLPTR optional_CPP11_100 -#define optional_HAVE_REF_QUALIFIER optional_CPP11_140 - -// Presence of C++14 language features: - -#define optional_HAVE_CONSTEXPR_14 optional_CPP14_000 - -// Presence of C++17 language features: - -#define optional_HAVE_NODISCARD optional_CPP17_000 - -// Presence of C++ library features: - -#define optional_HAVE_CONDITIONAL optional_CPP11_120 -#define optional_HAVE_REMOVE_CV optional_CPP11_120 -#define optional_HAVE_TYPE_TRAITS optional_CPP11_90 - -#define optional_HAVE_TR1_TYPE_TRAITS (!! optional_COMPILER_GNUC_VERSION ) -#define optional_HAVE_TR1_ADD_POINTER (!! optional_COMPILER_GNUC_VERSION ) - -// C++ feature usage: - -#if optional_HAVE( CONSTEXPR_11 ) -# define optional_constexpr constexpr -#else -# define optional_constexpr /*constexpr*/ -#endif - -#if optional_HAVE( IS_DEFAULT ) -# define optional_is_default = default; -#else -# define optional_is_default {} -#endif - -#if optional_HAVE( CONSTEXPR_14 ) -# define optional_constexpr14 constexpr -#else -# define optional_constexpr14 /*constexpr*/ -#endif - -#if optional_HAVE( NODISCARD ) -# define optional_nodiscard [[nodiscard]] -#else -# define optional_nodiscard /*[[nodiscard]]*/ -#endif - -#if optional_HAVE( NOEXCEPT ) -# define optional_noexcept noexcept -#else -# define optional_noexcept /*noexcept*/ -#endif - -#if optional_HAVE( NULLPTR ) -# define optional_nullptr nullptr -#else -# define optional_nullptr NULL -#endif - -#if optional_HAVE( REF_QUALIFIER ) -// NOLINTNEXTLINE( bugprone-macro-parentheses ) -# define optional_ref_qual & -# define optional_refref_qual && -#else -# define optional_ref_qual /*&*/ -# define optional_refref_qual /*&&*/ -#endif - -// additional includes: - -#if optional_CONFIG_NO_EXCEPTIONS -// already included: -#else -# include -#endif - -#if optional_CPP11_OR_GREATER -# include -#endif - -#if optional_HAVE( INITIALIZER_LIST ) -# include -#endif - -#if optional_HAVE( TYPE_TRAITS ) -# include -#elif optional_HAVE( TR1_TYPE_TRAITS ) -# include -#endif - -// Method enabling - -#if optional_CPP11_OR_GREATER - -#define optional_REQUIRES_0(...) \ - template< bool B = (__VA_ARGS__), typename std::enable_if::type = 0 > - -#define optional_REQUIRES_T(...) \ - , typename = typename std::enable_if< (__VA_ARGS__), nonstd::optional_lite::detail::enabler >::type - -#define optional_REQUIRES_R(R, ...) \ - typename std::enable_if< (__VA_ARGS__), R>::type - -#define optional_REQUIRES_A(...) 
\ - , typename std::enable_if< (__VA_ARGS__), void*>::type = nullptr - -#endif - -// -// optional: -// - -namespace nonstd { namespace optional_lite { - -namespace std11 { - -#if optional_CPP11_OR_GREATER - using std::move; -#else - template< typename T > T & move( T & t ) { return t; } -#endif - -#if optional_HAVE( CONDITIONAL ) - using std::conditional; -#else - template< bool B, typename T, typename F > struct conditional { typedef T type; }; - template< typename T, typename F > struct conditional { typedef F type; }; -#endif // optional_HAVE_CONDITIONAL - -} // namespace std11 - -#if optional_CPP11_OR_GREATER - -/// type traits C++17: - -namespace std17 { - -#if optional_CPP17_OR_GREATER - -using std::is_swappable; -using std::is_nothrow_swappable; - -#elif optional_CPP11_OR_GREATER - -namespace detail { - -using std::swap; - -struct is_swappable -{ - template< typename T, typename = decltype( swap( std::declval(), std::declval() ) ) > - static std::true_type test( int /*unused*/ ); - - template< typename > - static std::false_type test(...); -}; - -struct is_nothrow_swappable -{ - // wrap noexcept(expr) in separate function as work-around for VC140 (VS2015): - - template< typename T > - static constexpr bool satisfies() - { - return noexcept( swap( std::declval(), std::declval() ) ); - } - - template< typename T > - static auto test( int /*unused*/ ) -> std::integral_constant()>{} - - template< typename > - static auto test(...) -> std::false_type; -}; - -} // namespace detail - -// is [nothow] swappable: - -template< typename T > -struct is_swappable : decltype( detail::is_swappable::test(0) ){}; - -template< typename T > -struct is_nothrow_swappable : decltype( detail::is_nothrow_swappable::test(0) ){}; - -#endif // optional_CPP17_OR_GREATER - -} // namespace std17 - -/// type traits C++20: - -namespace std20 { - -template< typename T > -struct remove_cvref -{ - typedef typename std::remove_cv< typename std::remove_reference::type >::type type; -}; - -} // namespace std20 - -#endif // optional_CPP11_OR_GREATER - -/// class optional - -template< typename T > -class optional; - -namespace detail { - -// for optional_REQUIRES_T - -#if optional_CPP11_OR_GREATER -enum class enabler{}; -#endif - -// C++11 emulation: - -struct nulltype{}; - -template< typename Head, typename Tail > -struct typelist -{ - typedef Head head; - typedef Tail tail; -}; - -#if optional_CONFIG_MAX_ALIGN_HACK - -// Max align, use most restricted type for alignment: - -#define optional_UNIQUE( name ) optional_UNIQUE2( name, __LINE__ ) -#define optional_UNIQUE2( name, line ) optional_UNIQUE3( name, line ) -#define optional_UNIQUE3( name, line ) name ## line - -#define optional_ALIGN_TYPE( type ) \ - type optional_UNIQUE( _t ); struct_t< type > optional_UNIQUE( _st ) - -template< typename T > -struct struct_t { T _; }; - -union max_align_t -{ - optional_ALIGN_TYPE( char ); - optional_ALIGN_TYPE( short int ); - optional_ALIGN_TYPE( int ); - optional_ALIGN_TYPE( long int ); - optional_ALIGN_TYPE( float ); - optional_ALIGN_TYPE( double ); - optional_ALIGN_TYPE( long double ); - optional_ALIGN_TYPE( char * ); - optional_ALIGN_TYPE( short int * ); - optional_ALIGN_TYPE( int * ); - optional_ALIGN_TYPE( long int * ); - optional_ALIGN_TYPE( float * ); - optional_ALIGN_TYPE( double * ); - optional_ALIGN_TYPE( long double * ); - optional_ALIGN_TYPE( void * ); - -#ifdef HAVE_LONG_LONG - optional_ALIGN_TYPE( long long ); -#endif - - struct Unknown; - - Unknown ( * optional_UNIQUE(_) )( Unknown ); - Unknown * Unknown::* 
optional_UNIQUE(_); - Unknown ( Unknown::* optional_UNIQUE(_) )( Unknown ); - - struct_t< Unknown ( * )( Unknown) > optional_UNIQUE(_); - struct_t< Unknown * Unknown::* > optional_UNIQUE(_); - struct_t< Unknown ( Unknown::* )(Unknown) > optional_UNIQUE(_); -}; - -#undef optional_UNIQUE -#undef optional_UNIQUE2 -#undef optional_UNIQUE3 - -#undef optional_ALIGN_TYPE - -#elif defined( optional_CONFIG_ALIGN_AS ) // optional_CONFIG_MAX_ALIGN_HACK - -// Use user-specified type for alignment: - -#define optional_ALIGN_AS( unused ) \ - optional_CONFIG_ALIGN_AS - -#else // optional_CONFIG_MAX_ALIGN_HACK - -// Determine POD type to use for alignment: - -#define optional_ALIGN_AS( to_align ) \ - typename type_of_size< alignment_types, alignment_of< to_align >::value >::type - -template< typename T > -struct alignment_of; - -template< typename T > -struct alignment_of_hack -{ - char c; - T t; - alignment_of_hack(); -}; - -template< size_t A, size_t S > -struct alignment_logic -{ - enum { value = A < S ? A : S }; -}; - -template< typename T > -struct alignment_of -{ - enum { value = alignment_logic< - sizeof( alignment_of_hack ) - sizeof(T), sizeof(T) >::value }; -}; - -template< typename List, size_t N > -struct type_of_size -{ - typedef typename std11::conditional< - N == sizeof( typename List::head ), - typename List::head, - typename type_of_size::type >::type type; -}; - -template< size_t N > -struct type_of_size< nulltype, N > -{ - typedef optional_CONFIG_ALIGN_AS_FALLBACK type; -}; - -template< typename T> -struct struct_t { T _; }; - -#define optional_ALIGN_TYPE( type ) \ - typelist< type , typelist< struct_t< type > - -struct Unknown; - -typedef - optional_ALIGN_TYPE( char ), - optional_ALIGN_TYPE( short ), - optional_ALIGN_TYPE( int ), - optional_ALIGN_TYPE( long), optional_ALIGN_TYPE(float), optional_ALIGN_TYPE(double), - optional_ALIGN_TYPE(long double), - - optional_ALIGN_TYPE(char*), optional_ALIGN_TYPE(short*), optional_ALIGN_TYPE(int*), - optional_ALIGN_TYPE(long*), optional_ALIGN_TYPE(float*), optional_ALIGN_TYPE(double*), - optional_ALIGN_TYPE(long double*), - - optional_ALIGN_TYPE(Unknown (*)(Unknown)), optional_ALIGN_TYPE(Unknown* Unknown::*), - optional_ALIGN_TYPE(Unknown (Unknown::*)(Unknown)), - - nulltype >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> alignment_types; - -#undef optional_ALIGN_TYPE - -#endif // optional_CONFIG_MAX_ALIGN_HACK - -/// C++03 constructed union to hold value. - -template -union storage_t { - // private: - // template< typename > friend class optional; - - typedef T value_type; - - storage_t() optional_is_default - - explicit storage_t(value_type const& v) { - construct_value(v); - } - - void construct_value(value_type const& v) { ::new (value_ptr()) value_type(v); } - -#if optional_CPP11_OR_GREATER - - explicit storage_t(value_type&& v) { construct_value(std::move(v)); } - - void construct_value(value_type&& v) { ::new (value_ptr()) value_type(std::move(v)); } - - template - void emplace(Args&&... args) { - ::new (value_ptr()) value_type(std::forward(args)...); - } - - template - void emplace(std::initializer_list il, Args&&... 
args) { - ::new (value_ptr()) value_type(il, std::forward(args)...); - } - -#endif - - void destruct_value() { value_ptr()->~T(); } - - optional_nodiscard value_type const* value_ptr() const { return as(); } - - value_type* value_ptr() { return as(); } - - optional_nodiscard value_type const& value() const optional_ref_qual { - return *value_ptr(); - } - - value_type& value() optional_ref_qual { return *value_ptr(); } - -#if optional_CPP11_OR_GREATER - - optional_nodiscard value_type const&& value() const optional_refref_qual { - return std::move(value()); - } - - value_type&& value() optional_refref_qual { return std::move(value()); } - -#endif - -#if optional_CPP11_OR_GREATER - - using aligned_storage_t = - typename std::aligned_storage::type; - aligned_storage_t data; - -#elif optional_CONFIG_MAX_ALIGN_HACK - - typedef struct { - unsigned char data[sizeof(value_type)]; - } aligned_storage_t; - - max_align_t hack; - aligned_storage_t data; - -#else - typedef optional_ALIGN_AS(value_type) align_as_type; - - typedef struct { - align_as_type data[1 + (sizeof(value_type) - 1) / sizeof(align_as_type)]; - } aligned_storage_t; - aligned_storage_t data; - -#undef optional_ALIGN_AS - -#endif // optional_CONFIG_MAX_ALIGN_HACK - - optional_nodiscard void* ptr() optional_noexcept { return &data; } - - optional_nodiscard void const* ptr() const optional_noexcept { return &data; } - - template - optional_nodiscard U* as() { - return reinterpret_cast(ptr()); - } - - template - optional_nodiscard U const* as() const { - return reinterpret_cast(ptr()); - } -}; - -} // namespace detail - -/// disengaged state tag - -struct nullopt_t { - struct init {}; - explicit optional_constexpr nullopt_t(init /*unused*/) optional_noexcept {} -}; - -#if optional_HAVE(CONSTEXPR_11) -constexpr nullopt_t nullopt{nullopt_t::init{}}; -#else -// extra parenthesis to prevent the most vexing parse: -const nullopt_t nullopt((nullopt_t::init())); -#endif - -/// optional access error - -#if !optional_CONFIG_NO_EXCEPTIONS - -class bad_optional_access : public std::logic_error { - public: - explicit bad_optional_access() : logic_error("bad optional access") {} -}; - -#endif // optional_CONFIG_NO_EXCEPTIONS - -/// optional - -template -class optional { - private: - template - friend class optional; - - typedef void (optional::*safe_bool)() const; - - public: - typedef T value_type; - - // x.x.3.1, constructors - - // 1a - default construct - optional_constexpr optional() optional_noexcept : has_value_(false), contained() {} - - // 1b - construct explicitly empty - // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions ) - optional_constexpr optional(nullopt_t /*unused*/) optional_noexcept : has_value_(false), - contained() {} - - // 2 - copy-construct - optional_constexpr14 optional( - optional const& other -#if optional_CPP11_OR_GREATER - optional_REQUIRES_A(true || std::is_copy_constructible::value) -#endif - ) - : has_value_(other.has_value()) { - if (other.has_value()) { - contained.construct_value(other.contained.value()); - } - } - -#if optional_CPP11_OR_GREATER - - // 3 (C++11) - move-construct from optional - optional_constexpr14 optional( - optional&& other optional_REQUIRES_A(true || std::is_move_constructible::value) - // NOLINTNEXTLINE( performance-noexcept-move-constructor ) - ) noexcept(std::is_nothrow_move_constructible::value) - : has_value_(other.has_value()) { - if (other.has_value()) { - contained.construct_value(std::move(other.contained.value())); - } - } - - // 4a (C++11) - explicit 
converting copy-construct from optional - template - explicit optional(optional const& other optional_REQUIRES_A( - std::is_constructible::value && - !std::is_constructible&>::value && - !std::is_constructible&&>::value && - !std::is_constructible const&>::value && - !std::is_constructible const&&>::value && - !std::is_convertible&, T>::value && - !std::is_convertible&&, T>::value && - !std::is_convertible const&, T>::value && - !std::is_convertible const&&, T>::value && - !std::is_convertible::value /*=> explicit */ - )) - : has_value_(other.has_value()) { - if (other.has_value()) { - contained.construct_value(T{other.contained.value()}); - } - } -#endif // optional_CPP11_OR_GREATER - - // 4b (C++98 and later) - non-explicit converting copy-construct from optional - template - // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions ) - optional( - optional const& other -#if optional_CPP11_OR_GREATER - optional_REQUIRES_A(std::is_constructible::value && - !std::is_constructible&>::value && - !std::is_constructible&&>::value && - !std::is_constructible const&>::value && - !std::is_constructible const&&>::value && - !std::is_convertible&, T>::value && - !std::is_convertible&&, T>::value && - !std::is_convertible const&, T>::value && - !std::is_convertible const&&, T>::value && - std::is_convertible::value /*=> non-explicit */ - ) -#endif // optional_CPP11_OR_GREATER - ) - : has_value_(other.has_value()) { - if (other.has_value()) { - contained.construct_value(other.contained.value()); - } - } - -#if optional_CPP11_OR_GREATER - - // 5a (C++11) - explicit converting move-construct from optional - template - explicit optional(optional&& other optional_REQUIRES_A( - std::is_constructible::value && - !std::is_constructible&>::value && - !std::is_constructible&&>::value && - !std::is_constructible const&>::value && - !std::is_constructible const&&>::value && - !std::is_convertible&, T>::value && - !std::is_convertible&&, T>::value && - !std::is_convertible const&, T>::value && - !std::is_convertible const&&, T>::value && - !std::is_convertible::value /*=> explicit */ - )) - : has_value_(other.has_value()) { - if (other.has_value()) { - contained.construct_value(T{std::move(other.contained.value())}); - } - } - - // 5a (C++11) - non-explicit converting move-construct from optional - template - // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions ) - optional(optional&& other optional_REQUIRES_A( - std::is_constructible::value && - !std::is_constructible&>::value && - !std::is_constructible&&>::value && - !std::is_constructible const&>::value && - !std::is_constructible const&&>::value && - !std::is_convertible&, T>::value && - !std::is_convertible&&, T>::value && - !std::is_convertible const&, T>::value && - !std::is_convertible const&&, T>::value && - std::is_convertible::value /*=> non-explicit */ - )) - : has_value_(other.has_value()) { - if (other.has_value()) { - contained.construct_value(std::move(other.contained.value())); - } - } - - // 6 (C++11) - in-place construct - template < - typename... Args optional_REQUIRES_T(std::is_constructible::value)> - optional_constexpr explicit optional(nonstd_lite_in_place_t(T), Args&&... args) - : has_value_(true), contained(T(std::forward(args)...)) {} - - // 7 (C++11) - in-place construct, initializer-list - template &, Args&&...>::value)> - optional_constexpr explicit optional(nonstd_lite_in_place_t(T), - std::initializer_list il, Args&&... 
args) - : has_value_(true), contained(T(il, std::forward(args)...)) {} - - // 8a (C++11) - explicit move construct from value - template - optional_constexpr explicit optional(U&& value optional_REQUIRES_A( - std::is_constructible::value && - !std::is_same::type, - nonstd_lite_in_place_t(U)>::value && - !std::is_same::type, optional >::value && - !std::is_convertible::value /*=> explicit */ - )) - : has_value_(true), contained(T{std::forward(value)}) {} - - // 8b (C++11) - non-explicit move construct from value - template - // NOLINTNEXTLINE( google-explicit-constructor, hicpp-explicit-conversions ) - optional_constexpr optional(U&& value optional_REQUIRES_A( - std::is_constructible::value && - !std::is_same::type, - nonstd_lite_in_place_t(U)>::value && - !std::is_same::type, optional >::value && - std::is_convertible::value /*=> non-explicit */ - )) - : has_value_(true), contained(std::forward(value)) {} - -#else // optional_CPP11_OR_GREATER - - // 8 (C++98) - optional(value_type const& value) : has_value_(true), contained(value) {} - -#endif // optional_CPP11_OR_GREATER - - // x.x.3.2, destructor - - ~optional() { - if (has_value()) { - contained.destruct_value(); - } - } - - // x.x.3.3, assignment - - // 1 (C++98and later) - assign explicitly empty - optional& operator=(nullopt_t /*unused*/) optional_noexcept { - reset(); - return *this; - } - - // 2 (C++98and later) - copy-assign from optional -#if optional_CPP11_OR_GREATER - // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature, - // misc-unconventional-assign-operator ) - optional_REQUIRES_R(optional&, true - // std::is_copy_constructible::value - // && std::is_copy_assignable::value - ) - operator=(optional const& other) noexcept( - std::is_nothrow_move_assignable::value&& - std::is_nothrow_move_constructible::value) -#else - optional& operator=(optional const& other) -#endif - { - if ((has_value() == true) && (other.has_value() == false)) { - reset(); - } else if ((has_value() == false) && (other.has_value() == true)) { - initialize(*other); - } else if ((has_value() == true) && (other.has_value() == true)) { - contained.value() = *other; - } - return *this; - } - -#if optional_CPP11_OR_GREATER - - // 3 (C++11) - move-assign from optional - // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature, - // misc-unconventional-assign-operator ) - optional_REQUIRES_R(optional&, true - // std::is_move_constructible::value - // && std::is_move_assignable::value - ) - operator=(optional&& other) noexcept { - if ((has_value() == true) && (other.has_value() == false)) { - reset(); - } else if ((has_value() == false) && (other.has_value() == true)) { - initialize(std::move(*other)); - } else if ((has_value() == true) && (other.has_value() == true)) { - contained.value() = std::move(*other); - } - return *this; - } - - // 4 (C++11) - move-assign from value - template - // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature, - // misc-unconventional-assign-operator ) - optional_REQUIRES_R( - optional&, - std::is_constructible::value&& std::is_assignable::value && - !std::is_same::type, - nonstd_lite_in_place_t(U)>::value && - !std::is_same::type, optional >::value && - !(std::is_scalar::value && - std::is_same::type>::value)) - operator=(U&& value) { - if (has_value()) { - contained.value() = std::forward(value); - } else { - initialize(T(std::forward(value))); - } - return *this; - } - -#else // optional_CPP11_OR_GREATER - - // 4 (C++98) - copy-assign from value - template - optional& operator=(U const& value) { - 
if (has_value()) - contained.value() = value; - else - initialize(T(value)); - return *this; - } - -#endif // optional_CPP11_OR_GREATER - - // 5 (C++98 and later) - converting copy-assign from optional - template -#if optional_CPP11_OR_GREATER - // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature, - // misc-unconventional-assign-operator ) - optional_REQUIRES_R(optional&, - std::is_constructible::value&& - std::is_assignable::value && - !std::is_constructible&>::value && - !std::is_constructible&&>::value && - !std::is_constructible const&>::value && - !std::is_constructible const&&>::value && - !std::is_convertible&, T>::value && - !std::is_convertible&&, T>::value && - !std::is_convertible const&, T>::value && - !std::is_convertible const&&, T>::value && - !std::is_assignable&>::value && - !std::is_assignable&&>::value && - !std::is_assignable const&>::value && - !std::is_assignable const&&>::value) -#else - optional& -#endif // optional_CPP11_OR_GREATER - operator=(optional const& other) { - return *this = optional(other); - } - -#if optional_CPP11_OR_GREATER - - // 6 (C++11) - converting move-assign from optional - template - // NOLINTNEXTLINE( cppcoreguidelines-c-copy-assignment-signature, - // misc-unconventional-assign-operator ) - optional_REQUIRES_R( - optional&, std::is_constructible::value&& std::is_assignable::value && - !std::is_constructible&>::value && - !std::is_constructible&&>::value && - !std::is_constructible const&>::value && - !std::is_constructible const&&>::value && - !std::is_convertible&, T>::value && - !std::is_convertible&&, T>::value && - !std::is_convertible const&, T>::value && - !std::is_convertible const&&, T>::value && - !std::is_assignable&>::value && - !std::is_assignable&&>::value && - !std::is_assignable const&>::value && - !std::is_assignable const&&>::value) - operator=(optional&& other) { - return *this = optional(std::move(other)); - } - - // 7 (C++11) - emplace - template < - typename... Args optional_REQUIRES_T(std::is_constructible::value)> - T& emplace(Args&&... args) { - *this = nullopt; - contained.emplace(std::forward(args)...); - has_value_ = true; - return contained.value(); - } - - // 8 (C++11) - emplace, initializer-list - template &, Args&&...>::value)> - T& emplace(std::initializer_list il, Args&&... 
args) { - *this = nullopt; - contained.emplace(il, std::forward(args)...); - has_value_ = true; - return contained.value(); - } - -#endif // optional_CPP11_OR_GREATER - - // x.x.3.4, swap - - void swap(optional& other) -#if optional_CPP11_OR_GREATER - noexcept(std::is_nothrow_move_constructible::value&& - std17::is_nothrow_swappable::value) -#endif - { - using std::swap; - if ((has_value() == true) && (other.has_value() == true)) { - swap(**this, *other); - } else if ((has_value() == false) && (other.has_value() == true)) { - initialize(std11::move(*other)); - other.reset(); - } else if ((has_value() == true) && (other.has_value() == false)) { - other.initialize(std11::move(**this)); - reset(); - } - } - - // x.x.3.5, observers - - optional_constexpr value_type const* operator->() const { - return assert(has_value()), contained.value_ptr(); - } - - optional_constexpr14 value_type* operator->() { - return assert(has_value()), contained.value_ptr(); - } - - optional_constexpr value_type const& operator*() const optional_ref_qual { - return assert(has_value()), contained.value(); - } - - optional_constexpr14 value_type& operator*() optional_ref_qual { - return assert(has_value()), contained.value(); - } - -#if optional_HAVE(REF_QUALIFIER) && \ - (!optional_COMPILER_GNUC_VERSION || optional_COMPILER_GNUC_VERSION >= 490) - - optional_constexpr value_type const&& operator*() const optional_refref_qual { - return std::move(**this); - } - - optional_constexpr14 value_type&& operator*() optional_refref_qual { - return std::move(**this); - } - -#endif - -#if optional_CPP11_OR_GREATER - optional_constexpr explicit operator bool() const optional_noexcept { - return has_value(); - } -#else - optional_constexpr operator safe_bool() const optional_noexcept { - return has_value() ? &optional::this_type_does_not_support_comparisons : 0; - } -#endif - - // NOLINTNEXTLINE( modernize-use-nodiscard ) - /*optional_nodiscard*/ optional_constexpr bool has_value() const optional_noexcept { - return has_value_; - } - - // NOLINTNEXTLINE( modernize-use-nodiscard ) - /*optional_nodiscard*/ optional_constexpr14 value_type const& value() const - optional_ref_qual { -#if optional_CONFIG_NO_EXCEPTIONS - assert(has_value()); -#else - if (!has_value()) { - throw bad_optional_access(); - } -#endif - return contained.value(); - } - - optional_constexpr14 value_type& value() optional_ref_qual { -#if optional_CONFIG_NO_EXCEPTIONS - assert(has_value()); -#else - if (!has_value()) { - throw bad_optional_access(); - } -#endif - return contained.value(); - } - -#if optional_HAVE(REF_QUALIFIER) && \ - (!optional_COMPILER_GNUC_VERSION || optional_COMPILER_GNUC_VERSION >= 490) - - // NOLINTNEXTLINE( modernize-use-nodiscard ) - /*optional_nodiscard*/ optional_constexpr value_type const&& value() const - optional_refref_qual { - return std::move(value()); - } - - optional_constexpr14 value_type&& value() optional_refref_qual { - return std::move(value()); - } - -#endif - -#if optional_CPP11_OR_GREATER - - template - optional_constexpr value_type value_or(U&& v) const optional_ref_qual { - return has_value() ? contained.value() : static_cast(std::forward(v)); - } - - template - optional_constexpr14 value_type value_or(U&& v) optional_refref_qual { - return has_value() ? std::move(contained.value()) - : static_cast(std::forward(v)); - } - -#else - - template - optional_constexpr value_type value_or(U const& v) const { - return has_value() ? 
contained.value() : static_cast(v); - } - -#endif // optional_CPP11_OR_GREATER - - // x.x.3.6, modifiers - - void reset() optional_noexcept { - if (has_value()) { - contained.destruct_value(); - } - - has_value_ = false; - } - - private: - void this_type_does_not_support_comparisons() const {} - - template - void initialize(V const& value) { - assert(!has_value()); - contained.construct_value(value); - has_value_ = true; - } - -#if optional_CPP11_OR_GREATER - template - void initialize(V&& value) { - assert(!has_value()); - contained.construct_value(std::move(value)); - has_value_ = true; - } - -#endif - - private: - bool has_value_; - detail::storage_t contained; -}; - -// Relational operators - -template -inline optional_constexpr bool operator==(optional const& x, optional const& y) { - return bool(x) != bool(y) ? false : !bool(x) ? true : *x == *y; -} - -template -inline optional_constexpr bool operator!=(optional const& x, optional const& y) { - return !(x == y); -} - -template -inline optional_constexpr bool operator<(optional const& x, optional const& y) { - return (!y) ? false : (!x) ? true : *x < *y; -} - -template -inline optional_constexpr bool operator>(optional const& x, optional const& y) { - return (y < x); -} - -template -inline optional_constexpr bool operator<=(optional const& x, optional const& y) { - return !(y < x); -} - -template -inline optional_constexpr bool operator>=(optional const& x, optional const& y) { - return !(x < y); -} - -// Comparison with nullopt - -template -inline optional_constexpr bool operator==(optional const& x, - nullopt_t /*unused*/) optional_noexcept { - return (!x); -} - -template -inline optional_constexpr bool operator==(nullopt_t /*unused*/, - optional const& x) optional_noexcept { - return (!x); -} - -template -inline optional_constexpr bool operator!=(optional const& x, - nullopt_t /*unused*/) optional_noexcept { - return bool(x); -} - -template -inline optional_constexpr bool operator!=(nullopt_t /*unused*/, - optional const& x) optional_noexcept { - return bool(x); -} - -template -inline optional_constexpr bool operator<(optional const& /*unused*/, - nullopt_t /*unused*/) optional_noexcept { - return false; -} - -template -inline optional_constexpr bool operator<(nullopt_t /*unused*/, - optional const& x) optional_noexcept { - return bool(x); -} - -template -inline optional_constexpr bool operator<=(optional const& x, - nullopt_t /*unused*/) optional_noexcept { - return (!x); -} - -template -inline optional_constexpr bool operator<=( - nullopt_t /*unused*/, optional const& /*unused*/) optional_noexcept { - return true; -} - -template -inline optional_constexpr bool operator>(optional const& x, - nullopt_t /*unused*/) optional_noexcept { - return bool(x); -} - -template -inline optional_constexpr bool operator>( - nullopt_t /*unused*/, optional const& /*unused*/) optional_noexcept { - return false; -} - -template -inline optional_constexpr bool operator>=(optional const& /*unused*/, - nullopt_t /*unused*/) optional_noexcept { - return true; -} - -template -inline optional_constexpr bool operator>=(nullopt_t /*unused*/, - optional const& x) optional_noexcept { - return (!x); -} - -// Comparison with T - -template -inline optional_constexpr bool operator==(optional const& x, U const& v) { - return bool(x) ? *x == v : false; -} - -template -inline optional_constexpr bool operator==(U const& v, optional const& x) { - return bool(x) ? 
v == *x : false; -} - -template -inline optional_constexpr bool operator!=(optional const& x, U const& v) { - return bool(x) ? *x != v : true; -} - -template -inline optional_constexpr bool operator!=(U const& v, optional const& x) { - return bool(x) ? v != *x : true; -} - -template -inline optional_constexpr bool operator<(optional const& x, U const& v) { - return bool(x) ? *x < v : true; -} - -template -inline optional_constexpr bool operator<(U const& v, optional const& x) { - return bool(x) ? v < *x : false; -} - -template -inline optional_constexpr bool operator<=(optional const& x, U const& v) { - return bool(x) ? *x <= v : true; -} - -template -inline optional_constexpr bool operator<=(U const& v, optional const& x) { - return bool(x) ? v <= *x : false; -} - -template -inline optional_constexpr bool operator>(optional const& x, U const& v) { - return bool(x) ? *x > v : false; -} - -template -inline optional_constexpr bool operator>(U const& v, optional const& x) { - return bool(x) ? v > *x : true; -} - -template -inline optional_constexpr bool operator>=(optional const& x, U const& v) { - return bool(x) ? *x >= v : false; -} - -template -inline optional_constexpr bool operator>=(U const& v, optional const& x) { - return bool(x) ? v >= *x : true; -} - -// Specialized algorithms - -template ::value&& std17::is_swappable::value) -#endif - > -void swap(optional& x, optional& y) -#if optional_CPP11_OR_GREATER - noexcept(noexcept(x.swap(y))) -#endif -{ - x.swap(y); -} - -#if optional_CPP11_OR_GREATER - -template -optional_constexpr optional::type> make_optional(T&& value) { - return optional::type>(std::forward(value)); -} - -template -optional_constexpr optional make_optional(Args&&... args) { - return optional(nonstd_lite_in_place(T), std::forward(args)...); -} - -template -optional_constexpr optional make_optional(std::initializer_list il, - Args&&... args) { - return optional(nonstd_lite_in_place(T), il, std::forward(args)...); -} - -#else - -template -optional make_optional(T const& value) { - return optional(value); -} - -#endif // optional_CPP11_OR_GREATER - -} // namespace optional_lite - -using optional_lite::bad_optional_access; -using optional_lite::nullopt; -using optional_lite::nullopt_t; -using optional_lite::optional; - -using optional_lite::make_optional; - -} // namespace nonstd - -#if optional_CPP11_OR_GREATER - -// specialize the std::hash algorithm: - -namespace std { - -template -struct hash > { - public: - std::size_t operator()(nonstd::optional const& v) const optional_noexcept { - return bool(v) ? std::hash{}(*v) : 0; - } -}; - -} // namespace std - -#endif // optional_CPP11_OR_GREATER - -#if defined(__clang__) -#pragma clang diagnostic pop -#elif defined(__GNUC__) -#pragma GCC diagnostic pop -#elif defined(_MSC_VER) -#pragma warning(pop) -#endif - -#endif // optional_USES_STD_OPTIONAL - -#endif // NONSTD_OPTIONAL_LITE_HPP diff --git a/cpp/src/gandiva/cache.h b/cpp/src/gandiva/cache.h index 8b80ff8c5e8..7cff9b02692 100644 --- a/cpp/src/gandiva/cache.h +++ b/cpp/src/gandiva/cache.h @@ -39,10 +39,10 @@ class Cache { Cache() : Cache(GetCapacity()) {} ValueType GetObjectCode(const KeyType& cache_key) { - arrow::util::optional result; + std::optional result; std::lock_guard lock(mtx_); result = cache_.get(cache_key); - return result != arrow::util::nullopt ? *result : nullptr; + return result != std::nullopt ? 
*result : nullptr; } void PutObjectCode(const KeyType& cache_key, const ValueType& module) { diff --git a/cpp/src/gandiva/lru_cache.h b/cpp/src/gandiva/lru_cache.h index 6602116b0a0..2fa7ccfbfe5 100644 --- a/cpp/src/gandiva/lru_cache.h +++ b/cpp/src/gandiva/lru_cache.h @@ -18,11 +18,10 @@ #pragma once #include +#include #include #include -#include "arrow/util/optional.h" - // modified from boost LRU cache -> the boost cache supported only an // ordered map. namespace gandiva { @@ -70,12 +69,12 @@ class LruCache { } } - arrow::util::optional get(const key_type& key) { + std::optional get(const key_type& key) { // lookup value in the cache typename map_type::iterator value_for_key = map_.find(key); if (value_for_key == map_.end()) { // value not in cache - return arrow::util::nullopt; + return std::nullopt; } // return the value, but first update its place in the most diff --git a/cpp/src/gandiva/lru_cache_test.cc b/cpp/src/gandiva/lru_cache_test.cc index 06c86d69032..ccd5867b322 100644 --- a/cpp/src/gandiva/lru_cache_test.cc +++ b/cpp/src/gandiva/lru_cache_test.cc @@ -50,7 +50,7 @@ TEST_F(TestLruCache, TestEvict) { cache_.insert(TestCacheKey(3), "hello"); // should have evicted key 1 ASSERT_EQ(2, cache_.size()); - ASSERT_EQ(cache_.get(TestCacheKey(1)), arrow::util::nullopt); + ASSERT_EQ(cache_.get(TestCacheKey(1)), std::nullopt); } TEST_F(TestLruCache, TestLruBehavior) { diff --git a/cpp/src/parquet/level_conversion.cc b/cpp/src/parquet/level_conversion.cc index ffdca476ddd..49ae15d6408 100644 --- a/cpp/src/parquet/level_conversion.cc +++ b/cpp/src/parquet/level_conversion.cc @@ -18,12 +18,12 @@ #include #include +#include #include "arrow/util/bit_run_reader.h" #include "arrow/util/bit_util.h" #include "arrow/util/cpu_info.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "parquet/exception.h" #include "parquet/level_comparison.h" @@ -36,7 +36,7 @@ namespace internal { namespace { using ::arrow::internal::CpuInfo; -using ::arrow::util::optional; +using ::std::optional; template void DefRepLevelsToListInfo(const int16_t* def_levels, const int16_t* rep_levels, diff --git a/cpp/src/parquet/statistics.cc b/cpp/src/parquet/statistics.cc index 591925554fa..13d20fcb33d 100644 --- a/cpp/src/parquet/statistics.cc +++ b/cpp/src/parquet/statistics.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -30,7 +31,6 @@ #include "arrow/util/bit_run_reader.h" #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" -#include "arrow/util/optional.h" #include "arrow/util/ubsan.h" #include "arrow/visit_data_inline.h" #include "parquet/encoding.h" @@ -276,7 +276,7 @@ template struct CompareHelper : public BinaryLikeCompareHelperBase {}; -using ::arrow::util::optional; +using ::std::optional; template ::arrow::enable_if_t::value, optional>> @@ -297,11 +297,11 @@ CleanStatistic(std::pair min_max) { // Ignore if one of the value is nan. 
if (std::isnan(min) || std::isnan(max)) { - return ::arrow::util::nullopt; + return ::std::nullopt; } if (min == std::numeric_limits::max() && max == std::numeric_limits::lowest()) { - return ::arrow::util::nullopt; + return ::std::nullopt; } T zero{}; @@ -319,7 +319,7 @@ CleanStatistic(std::pair min_max) { optional> CleanStatistic(std::pair min_max) { if (min_max.first.ptr == nullptr || min_max.second.ptr == nullptr) { - return ::arrow::util::nullopt; + return ::std::nullopt; } return min_max; } @@ -327,7 +327,7 @@ optional> CleanStatistic(std::pair min_max) { optional> CleanStatistic( std::pair min_max) { if (min_max.first.ptr == nullptr || min_max.second.ptr == nullptr) { - return ::arrow::util::nullopt; + return ::std::nullopt; } return min_max; } diff --git a/cpp/src/parquet/stream_reader.h b/cpp/src/parquet/stream_reader.h index 806b0e8ad9a..e16f8ee694c 100644 --- a/cpp/src/parquet/stream_reader.h +++ b/cpp/src/parquet/stream_reader.h @@ -22,10 +22,10 @@ #include #include #include +#include #include #include -#include "arrow/util/optional.h" #include "parquet/column_reader.h" #include "parquet/file_reader.h" #include "parquet/stream_writer.h" @@ -44,9 +44,9 @@ namespace parquet { /// Required and optional fields are supported: /// - Required fields are read using operator>>(T) /// - Optional fields are read with -/// operator>>(arrow::util::optional) +/// operator>>(std::optional) /// -/// Note that operator>>(arrow::util::optional) can be used to read +/// Note that operator>>(std::optional) can be used to read /// required fields. /// /// Similarly operator>>(T) can be used to read optional fields. @@ -58,7 +58,7 @@ namespace parquet { class PARQUET_EXPORT StreamReader { public: template - using optional = ::arrow::util::optional; + using optional = ::std::optional; // N.B. Default constructed objects are not usable. This // constructor is provided so that the object may be move diff --git a/cpp/src/parquet/stream_reader_test.cc b/cpp/src/parquet/stream_reader_test.cc index eb7b133740e..aa0ff25b10d 100644 --- a/cpp/src/parquet/stream_reader_test.cc +++ b/cpp/src/parquet/stream_reader_test.cc @@ -34,7 +34,7 @@ namespace test { template using optional = StreamReader::optional; -using ::arrow::util::nullopt; +using ::std::nullopt; struct TestData { static void init() { std::time(&ts_offset_); } diff --git a/cpp/src/parquet/stream_writer.h b/cpp/src/parquet/stream_writer.h index d0db850c341..5801011e166 100644 --- a/cpp/src/parquet/stream_writer.h +++ b/cpp/src/parquet/stream_writer.h @@ -21,10 +21,10 @@ #include #include #include +#include #include #include -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" #include "parquet/column_writer.h" #include "parquet/file_writer.h" @@ -48,11 +48,11 @@ namespace parquet { /// Required and optional fields are supported: /// - Required fields are written using operator<<(T) /// - Optional fields are written using -/// operator<<(arrow::util::optional). +/// operator<<(std::optional). /// /// Note that operator<<(T) can be used to write optional fields. /// -/// Similarly, operator<<(arrow::util::optional) can be used to +/// Similarly, operator<<(std::optional) can be used to /// write required fields. However if the optional parameter does not /// have a value (i.e. it is nullopt) then a ParquetException will be /// raised. @@ -62,7 +62,7 @@ namespace parquet { class PARQUET_EXPORT StreamWriter { public: template - using optional = ::arrow::util::optional; + using optional = ::std::optional; // N.B. 
Default constructed objects are not usable. This // constructor is provided so that the object may be move diff --git a/dev/release/verify-release-candidate.sh b/dev/release/verify-release-candidate.sh index b016988ba91..73fdd53b8f2 100755 --- a/dev/release/verify-release-candidate.sh +++ b/dev/release/verify-release-candidate.sh @@ -1039,7 +1039,7 @@ test_macos_wheels() { local check_flight=OFF else local python_versions="3.7m 3.8 3.9 3.10" - local platform_tags="macosx_10_9_x86_64 macosx_10_13_x86_64" + local platform_tags="macosx_10_14_x86_64" fi # verify arch-native wheels inside an arch-native conda environment diff --git a/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb b/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb index 1530a7b1d97..8677b55c159 100644 --- a/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb +++ b/dev/tasks/homebrew-formulae/autobrew/apache-arrow.rb @@ -46,6 +46,7 @@ def install -DARROW_BUILD_UTILITIES=ON -DARROW_COMPUTE=ON -DARROW_CSV=ON + -DARROW_CXXFLAGS="-D_LIBCPP_DISABLE_AVAILABILITY" -DARROW_DATASET=ON -DARROW_FILESYSTEM=ON -DARROW_GCS=ON diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index a0bb7e304b7..98af1a194bf 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -478,7 +478,7 @@ tasks: {############################## Wheel OSX ####################################} -{% for macos_version, macos_codename in [("10.13", "high-sierra")] %} +{% for macos_version, macos_codename in [("10.14", "mojave")] %} {% set platform_tag = "macosx_{}_x86_64".format(macos_version.replace('.', '_')) %} wheel-macos-{{ macos_codename }}-{{ python_tag }}-amd64: @@ -538,9 +538,9 @@ tasks: params: arch: universal2 python_version: "{{ python_version }}" - macos_deployment_target: "10.13" + macos_deployment_target: "10.14" artifacts: - - pyarrow-{no_rc_version}-{{ python_tag }}-{{ python_tag }}-macosx_10_13_universal2.whl + - pyarrow-{no_rc_version}-{{ python_tag }}-{{ python_tag }}-macosx_10_14_universal2.whl {% endfor %} {############################ Python sdist ####################################} diff --git a/docs/source/cpp/datatypes.rst b/docs/source/cpp/datatypes.rst index 2a31ac65a93..1d2133cbdf3 100644 --- a/docs/source/cpp/datatypes.rst +++ b/docs/source/cpp/datatypes.rst @@ -140,7 +140,7 @@ function for any numeric (integer or float) array: typename CType = typename DataType::c_type> arrow::enable_if_number SumArray(const ArrayType& array) { CType sum = 0; - for (arrow::util::optional value : array) { + for (std::optional value : array) { if (value.has_value()) { sum += value.value(); } @@ -192,7 +192,7 @@ here is how one might sum across columns of arbitrary numeric types: template arrow::enable_if_number Visit(const ArrayType& array) { - for (arrow::util::optional value : array) { + for (std::optional value : array) { if (value.has_value()) { partial += static_cast(value.value()); } @@ -205,4 +205,4 @@ Arrow also provides abstract visitor classes (:class:`arrow::TypeVisitor`, :class:`arrow::ScalarVisitor`, :class:`arrow::ArrayVisitor`) and an ``Accept()`` method on each of the corresponding base types (e.g. :func:`arrow::Array::Accept`). However, these are not able to be implemented using template functions, so you -will typically prefer using the inline type visitors. \ No newline at end of file +will typically prefer using the inline type visitors. 
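The datatypes.rst hunk above migrates the array-iteration examples from ``arrow::util::optional`` to ``std::optional``. A minimal sketch of the migrated pattern, assuming C++17 and the usual Arrow C++ headers (``SumInt64`` is an illustrative name, not part of the patch):

```cpp
// Minimal sketch of the post-migration iteration pattern shown in
// docs/source/cpp/datatypes.rst. Assumes C++17 and the Arrow C++ headers;
// SumInt64 is an illustrative name, not part of the patch.
#include <cstdint>
#include <optional>

#include "arrow/array.h"         // arrow::Int64Array
#include "arrow/stl_iterator.h"  // range-for over arrays yields optionals

int64_t SumInt64(const arrow::Int64Array& array) {
  int64_t sum = 0;
  // Each element is a std::optional<int64_t>; std::nullopt marks a null slot.
  for (std::optional<int64_t> value : array) {
    if (value.has_value()) {
      sum += *value;
    }
  }
  return sum;
}
```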
diff --git a/docs/source/cpp/gdb.rst b/docs/source/cpp/gdb.rst index beb9267f567..609f11a993a 100644 --- a/docs/source/cpp/gdb.rst +++ b/docs/source/cpp/gdb.rst @@ -165,5 +165,4 @@ Important utility classes are also covered: * :class:`arrow::Status` and :class:`arrow::Result` * :class:`arrow::Buffer` and subclasses * :class:`arrow::Decimal128`, :class:`arrow::Decimal256` -* :class:`arrow::util::string_view`, :class:`arrow::util::optional`, - :class:`arrow::util::Variant` +* :class:`arrow::util::string_view`, :class:`arrow::util::Variant` diff --git a/docs/source/cpp/streaming_execution.rst b/docs/source/cpp/streaming_execution.rst index daa5f4be2f0..88922a04aa7 100644 --- a/docs/source/cpp/streaming_execution.rst +++ b/docs/source/cpp/streaming_execution.rst @@ -593,7 +593,7 @@ be quite tricky to configure. To process data from files the scan operation is The source node requires some kind of function that can be called to poll for more data. This function should take no arguments and should return an -``arrow::Future>>``. +``arrow::Future>``. This function might be reading a file, iterating through an in memory structure, or receiving data from a network connection. The arrow library refers to these functions as ``arrow::AsyncGenerator`` and there are a number of utilities for working with these functions. For this example we use @@ -752,7 +752,7 @@ execution definition. :class:`arrow::compute::SinkNodeOptions` interface is used the required options. Similar to the source operator the sink operator exposes the output with a function that returns a record batch future each time it is called. It is expected the caller will repeatedly call this function until the generator function is exhausted (returns -``arrow::util::optional::nullopt``). If this function is not called often enough then record batches +``std::optional::nullopt``). If this function is not called often enough then record batches will accumulate in memory. An execution plan should only have one "terminal" node (one sink node). An :class:`ExecPlan` can terminate early due to cancellation or an error, before the output is fully consumed. 
However, the plan can be safely destroyed independently @@ -1003,4 +1003,4 @@ Complete Example: :start-after: (Doc section: Execution Plan Documentation Example) :end-before: (Doc section: Execution Plan Documentation Example) :linenos: - :lineno-match: \ No newline at end of file + :lineno-match: diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index e2c53974cfe..5a61227deed 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -59,7 +59,7 @@ set(CMAKE_MACOSX_RPATH 1) if(DEFINED ENV{MACOSX_DEPLOYMENT_TARGET}) set(CMAKE_OSX_DEPLOYMENT_TARGET $ENV{MACOSX_DEPLOYMENT_TARGET}) else() - set(CMAKE_OSX_DEPLOYMENT_TARGET 10.13) + set(CMAKE_OSX_DEPLOYMENT_TARGET 10.14) endif() # Generate a Clang compile_commands.json "compilation database" file for use diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd index 07a75d4a081..c4beca4d15e 100644 --- a/python/pyarrow/includes/common.pxd +++ b/python/pyarrow/includes/common.pxd @@ -35,6 +35,15 @@ cimport cpython cdef extern from * namespace "std" nogil: cdef shared_ptr[T] static_pointer_cast[T, U](shared_ptr[U]) + +cdef extern from "" namespace "std" nogil: + cdef cppclass optional[T]: + c_bool has_value() + T value() + optional(T&) + optional& operator=[U](U&) + + # vendored from the cymove project https://github.com/ozars/cymove cdef extern from * namespace "cymove" nogil: """ diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 489d73bf27e..d9dde7803ab 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -54,13 +54,6 @@ cdef extern from "arrow/util/decimal.h" namespace "arrow" nogil: cdef cppclass CDecimal256" arrow::Decimal256": c_string ToString(int32_t scale) const -cdef extern from "arrow/util/optional.h" namespace "arrow::util" nogil: - cdef cppclass c_optional"arrow::util::optional"[T]: - c_bool has_value() - T value() - c_optional(T&) - c_optional& operator=[U](U&) - cdef extern from "arrow/config.h" namespace "arrow" nogil: cdef cppclass CBuildInfo" arrow::BuildInfo": diff --git a/python/pyarrow/includes/libarrow_fs.pxd b/python/pyarrow/includes/libarrow_fs.pxd index 7984b54f587..bf22ead83ec 100644 --- a/python/pyarrow/includes/libarrow_fs.pxd +++ b/python/pyarrow/includes/libarrow_fs.pxd @@ -223,7 +223,7 @@ cdef extern from "arrow/filesystem/api.h" namespace "arrow::fs" nogil: c_string endpoint_override c_string scheme c_string default_bucket_location - c_optional[double] retry_limit_seconds + optional[double] retry_limit_seconds shared_ptr[const CKeyValueMetadata] default_metadata c_bool Equals(const CS3Options& other) diff --git a/python/pyarrow/src/gdb.cc b/python/pyarrow/src/gdb.cc index 297bc6dbffc..7541e524609 100644 --- a/python/pyarrow/src/gdb.cc +++ b/python/pyarrow/src/gdb.cc @@ -34,7 +34,6 @@ #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" -#include "arrow/util/optional.h" #include "arrow/util/string_view.h" #include "arrow/util/variant.h" @@ -122,10 +121,6 @@ void TestSession() { auto error_result = Result(error_status); auto error_detail_result = Result(error_detail_status); - // Optionals - util::optional int_optional{42}; - util::optional null_int_optional{}; - // Variants using VariantType = util::Variant; diff --git a/python/pyarrow/src/python_test.cc b/python/pyarrow/src/python_test.cc index 54086faa7ca..865d4cf5c0f 100644 --- a/python/pyarrow/src/python_test.cc +++ b/python/pyarrow/src/python_test.cc @@ -18,6 +18,7 @@ #include 
"gtest/gtest.h" #include +#include #include #include @@ -28,7 +29,6 @@ #include "arrow/table.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/decimal.h" -#include "arrow/util/optional.h" #include "arrow_to_pandas.h" #include "decimal.h" @@ -399,7 +399,7 @@ TEST(BuiltinConversionTest, TestMixedTypeFails) { template void DecimalTestFromPythonDecimalRescale(std::shared_ptr type, OwnedRef python_decimal, - ::arrow::util::optional expected) { + std::optional expected) { DecimalValue value; const auto& decimal_type = checked_cast(*type); diff --git a/python/pyarrow/tests/test_gdb.py b/python/pyarrow/tests/test_gdb.py index 1990198d9f1..6b76d9b626e 100644 --- a/python/pyarrow/tests/test_gdb.py +++ b/python/pyarrow/tests/test_gdb.py @@ -297,13 +297,6 @@ def test_buffer_heap(gdb_arrow): 'arrow::Buffer of size 3, mutable, "abc"') -def test_optionals(gdb_arrow): - check_stack_repr(gdb_arrow, "int_optional", - "arrow::util::optional(42)") - check_stack_repr(gdb_arrow, "null_int_optional", - "arrow::util::optional(nullopt)") - - def test_variants(gdb_arrow): check_stack_repr( gdb_arrow, "int_variant", diff --git a/r/configure.win b/r/configure.win index 8aa05ef5970..1eb72e15613 100755 --- a/r/configure.win +++ b/r/configure.win @@ -139,7 +139,12 @@ if [ "$ARROW_R_CXXFLAGS" ]; then PKG_CFLAGS="$PKG_CFLAGS $ARROW_R_CXXFLAGS" fi -echo "*** Writing Makevars.win" +echo "*** Writing $(pwd)/src/Makevars.win" sed -e "s|@cflags@|$PKG_CFLAGS|" -e "s|@libs@|$PKG_LIBS|" src/Makevars.in > src/Makevars.win + +echo "*** Contents of $(pwd)/src/Makevars.win" +cat src/Makevars.win +echo "*** /End contents" + # Success exit 0 diff --git a/r/src/Makevars.ucrt b/r/src/Makevars.ucrt index 52488eb2b85..a91dedc2d55 100644 --- a/r/src/Makevars.ucrt +++ b/r/src/Makevars.ucrt @@ -17,3 +17,6 @@ CRT=-ucrt include Makevars.win + +# XXX for some reason, this variable doesn't seem propagated from Makevars.win +CXX_STD=CXX17 diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index f9183a3a103..abcb418a2c2 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -25,10 +25,10 @@ #include #include #include -#include #include #include +#include namespace compute = ::arrow::compute; @@ -66,7 +66,7 @@ ExecPlan_prepare(const std::shared_ptr& plan, // For now, don't require R to construct SinkNodes. // Instead, just pass the node we should collect as an argument. - arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; // Sorting uses a different sink node; there is no general sort yet if (sort_options.size() > 0) { @@ -170,7 +170,7 @@ std::string ExecPlan_BuildAndShow(const std::shared_ptr& plan // For now, don't require R to construct SinkNodes. // Instead, just pass the node we should collect as an argument. 
- arrow::AsyncGenerator> sink_gen; + arrow::AsyncGenerator> sink_gen; // Sorting uses a different sink node; there is no general sort yet if (sort_options.size() > 0) { diff --git a/r/src/config.cpp b/r/src/config.cpp index 1d322205b5d..a45df73a64a 100644 --- a/r/src/config.cpp +++ b/r/src/config.cpp @@ -17,8 +17,9 @@ #include "./arrow_types.h" +#include + #include -#include // [[arrow::export]] std::vector build_info() { @@ -41,6 +42,6 @@ void set_timezone_database(cpp11::strings path) { } arrow::GlobalOptions options; - options.timezone_db_path = arrow::util::make_optional(paths[0]); + options.timezone_db_path = std::make_optional(paths[0]); arrow::StopIfNotOk(arrow::Initialize(options)); } From 7cfdfbb0d5472f8f8893398b51042a3ca1dd0adf Mon Sep 17 00:00:00 2001 From: Quang Hoang Date: Thu, 15 Sep 2022 19:06:53 +0700 Subject: [PATCH 066/133] ARROW-17052: [C++][Python][FlightRPC] expose flight structures serialize (#13986) Expose serialize and deserialize for flight structures in C++ and Python: Action, ActionType, Criteria, FlightEndpoint, Result and SchemaResult Notes of no change on Cython binding for: * Criteria and PutResult (FlightClient::DoPutResult) aren't exposed directly * ActionType is implemented with named-tuple Lead-authored-by: Quang Hoang Co-authored-by: Quang Hoang Xuan Signed-off-by: David Li --- cpp/src/arrow/flight/flight_internals_test.cc | 53 ++++- cpp/src/arrow/flight/serialization_internal.h | 1 + cpp/src/arrow/flight/types.cc | 185 +++++++++++++++++- cpp/src/arrow/flight/types.h | 82 ++++++++ python/pyarrow/_flight.pyx | 104 +++++++++- python/pyarrow/includes/libarrow_flight.pxd | 34 +++- python/pyarrow/tests/test_flight.py | 19 ++ 7 files changed, 466 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index 1275db6a8d4..f315e42a6a6 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -83,22 +83,60 @@ TEST(FlightTypes, LocationUnknownScheme) { } TEST(FlightTypes, RoundTripTypes) { + ActionType action_type{"action-type1", "action-type1-description"}; + ASSERT_OK_AND_ASSIGN(std::string action_type_serialized, + action_type.SerializeToString()); + ASSERT_OK_AND_ASSIGN(ActionType action_type_deserialized, + ActionType::Deserialize(action_type_serialized)); + ASSERT_EQ(action_type, action_type_deserialized); + + Criteria criteria{"criteria1"}; + ASSERT_OK_AND_ASSIGN(std::string criteria_serialized, criteria.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Criteria criteria_deserialized, + Criteria::Deserialize(criteria_serialized)); + ASSERT_EQ(criteria, criteria_deserialized); + + Action action{"action1", Buffer::FromString("action1-content")}; + ASSERT_OK_AND_ASSIGN(std::string action_serialized, action.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Action action_deserialized, + Action::Deserialize(action_serialized)); + ASSERT_EQ(action, action_deserialized); + + Result result{Buffer::FromString("result1-content")}; + ASSERT_OK_AND_ASSIGN(std::string result_serialized, result.SerializeToString()); + ASSERT_OK_AND_ASSIGN(Result result_deserialized, + Result::Deserialize(result_serialized)); + ASSERT_EQ(result, result_deserialized); + + BasicAuth basic_auth{"username1", "password1"}; + ASSERT_OK_AND_ASSIGN(std::string basic_auth_serialized, basic_auth.SerializeToString()); + ASSERT_OK_AND_ASSIGN(BasicAuth basic_auth_deserialized, + BasicAuth::Deserialize(basic_auth_serialized)); + ASSERT_EQ(basic_auth, basic_auth_deserialized); + 
+ SchemaResult schema_result{"schema_result1"}; + ASSERT_OK_AND_ASSIGN(std::string schema_result_serialized, + schema_result.SerializeToString()); + ASSERT_OK_AND_ASSIGN(SchemaResult schema_result_deserialized, + SchemaResult::Deserialize(schema_result_serialized)); + ASSERT_EQ(schema_result, schema_result_deserialized); + Ticket ticket{"foo"}; ASSERT_OK_AND_ASSIGN(std::string ticket_serialized, ticket.SerializeToString()); ASSERT_OK_AND_ASSIGN(Ticket ticket_deserialized, Ticket::Deserialize(ticket_serialized)); - ASSERT_EQ(ticket.ticket, ticket_deserialized.ticket); + ASSERT_EQ(ticket, ticket_deserialized); FlightDescriptor desc = FlightDescriptor::Command("select * from foo;"); ASSERT_OK_AND_ASSIGN(std::string desc_serialized, desc.SerializeToString()); ASSERT_OK_AND_ASSIGN(FlightDescriptor desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_TRUE(desc.Equals(desc_deserialized)); + ASSERT_EQ(desc, desc_deserialized); desc = FlightDescriptor::Path({"a", "b", "test.arrow"}); ASSERT_OK_AND_ASSIGN(desc_serialized, desc.SerializeToString()); ASSERT_OK_AND_ASSIGN(desc_deserialized, FlightDescriptor::Deserialize(desc_serialized)); - ASSERT_TRUE(desc.Equals(desc_deserialized)); + ASSERT_EQ(desc, desc_deserialized); FlightInfo::Data data; std::shared_ptr schema = @@ -114,10 +152,17 @@ TEST(FlightTypes, RoundTripTypes) { ASSERT_OK_AND_ASSIGN(std::string info_serialized, info->SerializeToString()); ASSERT_OK_AND_ASSIGN(std::unique_ptr info_deserialized, FlightInfo::Deserialize(info_serialized)); - ASSERT_TRUE(info->descriptor().Equals(info_deserialized->descriptor())); + ASSERT_EQ(info->descriptor(), info_deserialized->descriptor()); ASSERT_EQ(info->endpoints(), info_deserialized->endpoints()); ASSERT_EQ(info->total_records(), info_deserialized->total_records()); ASSERT_EQ(info->total_bytes(), info_deserialized->total_bytes()); + + FlightEndpoint flight_endpoint{ticket, {location1, location2}}; + ASSERT_OK_AND_ASSIGN(std::string flight_endpoint_serialized, + flight_endpoint.SerializeToString()); + ASSERT_OK_AND_ASSIGN(FlightEndpoint flight_endpoint_deserialized, + FlightEndpoint::Deserialize(flight_endpoint_serialized)); + ASSERT_EQ(flight_endpoint, flight_endpoint_deserialized); } TEST(FlightTypes, RoundtripStatus) { diff --git a/cpp/src/arrow/flight/serialization_internal.h b/cpp/src/arrow/flight/serialization_internal.h index c27bc79b315..0e1d7a6d843 100644 --- a/cpp/src/arrow/flight/serialization_internal.h +++ b/cpp/src/arrow/flight/serialization_internal.h @@ -60,6 +60,7 @@ Status FromProto(const pb::SchemaResult& pb_result, std::string* result); Status FromProto(const pb::BasicAuth& pb_basic_auth, BasicAuth* info); Status ToProto(const FlightDescriptor& descr, pb::FlightDescriptor* pb_descr); +Status ToProto(const FlightEndpoint& endpoint, pb::FlightEndpoint* pb_endpoint); Status ToProto(const FlightInfo& info, pb::FlightInfo* pb_info); Status ToProto(const ActionType& type, pb::ActionType* pb_type); Status ToProto(const Action& action, pb::Action* pb_action); diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index 6e80f40cfbf..a505e6d6e1e 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -162,13 +162,42 @@ Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, return GetSchema(dictionary_memo).Value(out); } +bool SchemaResult::Equals(const SchemaResult& other) const { + return raw_schema_ == other.raw_schema_; +} + +arrow::Result SchemaResult::SerializeToString() const { + pb::SchemaResult 
pb_schema_result; + RETURN_NOT_OK(internal::ToProto(*this, &pb_schema_result)); + + std::string out; + if (!pb_schema_result.SerializeToString(&out)) { + return Status::IOError("Serialized SchemaResult exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result SchemaResult::Deserialize( + arrow::util::string_view serialized) { + pb::SchemaResult pb_schema_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized SchemaResult size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_schema_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid SchemaResult"); + } + return SchemaResult{pb_schema_result.schema()}; +} + arrow::Result FlightDescriptor::SerializeToString() const { pb::FlightDescriptor pb_descriptor; RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor)); std::string out; if (!pb_descriptor.SerializeToString(&out)) { - return Status::IOError("Serialized descriptor exceeded 2 GiB limit"); + return Status::IOError("Serialized FlightDescriptor exceeded 2 GiB limit"); } return out; } @@ -186,7 +215,7 @@ arrow::Result FlightDescriptor::Deserialize( google::protobuf::io::ArrayInputStream input(serialized.data(), static_cast(serialized.size())); if (!pb_descriptor.ParseFromZeroCopyStream(&input)) { - return Status::Invalid("Not a valid descriptor"); + return Status::Invalid("Not a valid FlightDescriptor"); } FlightDescriptor out; RETURN_NOT_OK(internal::FromProto(pb_descriptor, &out)); @@ -206,7 +235,7 @@ arrow::Result Ticket::SerializeToString() const { std::string out; if (!pb_ticket.SerializeToString(&out)) { - return Status::IOError("Serialized ticket exceeded 2 GiB limit"); + return Status::IOError("Serialized Ticket exceeded 2 GiB limit"); } return out; } @@ -223,7 +252,7 @@ arrow::Result Ticket::Deserialize(arrow::util::string_view serialized) { google::protobuf::io::ArrayInputStream input(serialized.data(), static_cast(serialized.size())); if (!pb_ticket.ParseFromZeroCopyStream(&input)) { - return Status::Invalid("Not a valid ticket"); + return Status::Invalid("Not a valid Ticket"); } Ticket out; RETURN_NOT_OK(internal::FromProto(pb_ticket, &out)); @@ -370,10 +399,154 @@ bool FlightEndpoint::Equals(const FlightEndpoint& other) const { return ticket == other.ticket && locations == other.locations; } +arrow::Result FlightEndpoint::SerializeToString() const { + pb::FlightEndpoint pb_flight_endpoint; + RETURN_NOT_OK(internal::ToProto(*this, &pb_flight_endpoint)); + + std::string out; + if (!pb_flight_endpoint.SerializeToString(&out)) { + return Status::IOError("Serialized FlightEndpoint exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result FlightEndpoint::Deserialize( + arrow::util::string_view serialized) { + pb::FlightEndpoint pb_flight_endpoint; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized FlightEndpoint size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_flight_endpoint.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid FlightEndpoint"); + } + FlightEndpoint out; + RETURN_NOT_OK(internal::FromProto(pb_flight_endpoint, &out)); + return out; +} + bool ActionType::Equals(const ActionType& other) const { return type == other.type && description == other.description; } +arrow::Result ActionType::SerializeToString() const { + 
pb::ActionType pb_action_type; + RETURN_NOT_OK(internal::ToProto(*this, &pb_action_type)); + + std::string out; + if (!pb_action_type.SerializeToString(&out)) { + return Status::IOError("Serialized ActionType exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result ActionType::Deserialize(arrow::util::string_view serialized) { + pb::ActionType pb_action_type; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized ActionType size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_action_type.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid ActionType"); + } + ActionType out; + RETURN_NOT_OK(internal::FromProto(pb_action_type, &out)); + return out; +} + +bool Criteria::Equals(const Criteria& other) const { + return expression == other.expression; +} + +arrow::Result Criteria::SerializeToString() const { + pb::Criteria pb_criteria; + RETURN_NOT_OK(internal::ToProto(*this, &pb_criteria)); + + std::string out; + if (!pb_criteria.SerializeToString(&out)) { + return Status::IOError("Serialized Criteria exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result Criteria::Deserialize(arrow::util::string_view serialized) { + pb::Criteria pb_criteria; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized Criteria size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_criteria.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid Criteria"); + } + Criteria out; + RETURN_NOT_OK(internal::FromProto(pb_criteria, &out)); + return out; +} + +bool Action::Equals(const Action& other) const { + return (type == other.type) && + ((body == other.body) || (body && other.body && body->Equals(*other.body))); +} + +arrow::Result Action::SerializeToString() const { + pb::Action pb_action; + RETURN_NOT_OK(internal::ToProto(*this, &pb_action)); + + std::string out; + if (!pb_action.SerializeToString(&out)) { + return Status::IOError("Serialized Action exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result Action::Deserialize(arrow::util::string_view serialized) { + pb::Action pb_action; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized Action size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + if (!pb_action.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid Action"); + } + Action out; + RETURN_NOT_OK(internal::FromProto(pb_action, &out)); + return out; +} + +bool Result::Equals(const Result& other) const { + return (body == other.body) || (body && other.body && body->Equals(*other.body)); +} + +arrow::Result Result::SerializeToString() const { + pb::Result pb_result; + RETURN_NOT_OK(internal::ToProto(*this, &pb_result)); + + std::string out; + if (!pb_result.SerializeToString(&out)) { + return Status::IOError("Serialized Result exceeded 2 GiB limit"); + } + return out; +} + +arrow::Result Result::Deserialize(arrow::util::string_view serialized) { + pb::Result pb_result; + if (serialized.size() > static_cast(std::numeric_limits::max())) { + return Status::Invalid("Serialized Result size should not exceed 2 GiB"); + } + google::protobuf::io::ArrayInputStream input(serialized.data(), + static_cast(serialized.size())); + 
if (!pb_result.ParseFromZeroCopyStream(&input)) { + return Status::Invalid("Not a valid Result"); + } + Result out; + RETURN_NOT_OK(internal::FromProto(pb_result, &out)); + return out; +} + Status ResultStream::Next(std::unique_ptr* info) { return Next().Value(info); } Status MetadataRecordBatchReader::Next(FlightStreamChunk* next) { @@ -468,6 +641,10 @@ arrow::Result> SimpleResultStream::Next() { return std::unique_ptr(new Result(std::move(results_[position_++]))); } +bool BasicAuth::Equals(const BasicAuth& other) const { + return (username == other.username) && (password == other.password); +} + arrow::Result BasicAuth::Deserialize(arrow::util::string_view serialized) { pb::BasicAuth pb_result; if (serialized.size() > static_cast(std::numeric_limits::max())) { diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 2ec24ff5868..ae9867e44a1 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -148,12 +148,33 @@ struct ARROW_FLIGHT_EXPORT ActionType { friend bool operator!=(const ActionType& left, const ActionType& right) { return !(left == right); } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief Opaque selection criteria for ListFlights RPC struct ARROW_FLIGHT_EXPORT Criteria { /// Opaque criteria expression, dependent on server implementation std::string expression; + + bool Equals(const Criteria& other) const; + + friend bool operator==(const Criteria& left, const Criteria& right) { + return left.Equals(right); + } + friend bool operator!=(const Criteria& left, const Criteria& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief An action to perform with the DoAction RPC @@ -163,11 +184,41 @@ struct ARROW_FLIGHT_EXPORT Action { /// The action content as a Buffer std::shared_ptr body; + + bool Equals(const Action& other) const; + + friend bool operator==(const Action& left, const Action& right) { + return left.Equals(right); + } + friend bool operator!=(const Action& left, const Action& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief Opaque result returned after executing an action struct ARROW_FLIGHT_EXPORT Result { std::shared_ptr body; + + bool Equals(const Result& other) const; + + friend bool operator==(const Result& left, const Result& right) { + return left.Equals(right); + } + friend bool operator!=(const Result& left, const Result& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. 
+ static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief message for simple auth @@ -175,6 +226,15 @@ struct ARROW_FLIGHT_EXPORT BasicAuth { std::string username; std::string password; + bool Equals(const BasicAuth& other) const; + + friend bool operator==(const BasicAuth& left, const BasicAuth& right) { + return left.Equals(right); + } + friend bool operator!=(const BasicAuth& left, const BasicAuth& right) { + return !(left == right); + } + /// \brief Deserialize this message from its wire-format representation. static arrow::Result Deserialize(arrow::util::string_view serialized); /// \brief Serialize this message to its wire-format representation. @@ -377,6 +437,12 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint { friend bool operator!=(const FlightEndpoint& left, const FlightEndpoint& right) { return !(left == right); } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); }; /// \brief Staging data structure for messages about to be put on the wire @@ -394,6 +460,7 @@ struct ARROW_FLIGHT_EXPORT FlightPayload { /// \brief Schema result returned after a schema request RPC struct ARROW_FLIGHT_EXPORT SchemaResult { public: + SchemaResult() = default; explicit SchemaResult(std::string schema) : raw_schema_(std::move(schema)) {} /// \brief Factory method to construct a SchemaResult. @@ -412,6 +479,21 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { const std::string& serialized_schema() const { return raw_schema_; } + bool Equals(const SchemaResult& other) const; + + friend bool operator==(const SchemaResult& left, const SchemaResult& right) { + return left.Equals(right); + } + friend bool operator!=(const SchemaResult& left, const SchemaResult& right) { + return !(left == right); + } + + /// \brief Serialize this message to its wire-format representation. + arrow::Result SerializeToString() const; + + /// \brief Deserialize this message from its wire-format representation. + static arrow::Result Deserialize(arrow::util::string_view serialized); + private: std::string raw_schema_; }; diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index 2ad3f7128c4..16e4aad5a00 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -289,6 +289,31 @@ cdef class Action(_Weakrefable): type(action))) return ( action).action + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.action.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef Action action = Action.__new__(Action) + action.action = GetResultValue( + CAction.Deserialize(tobytes(serialized))) + return action + + def __eq__(self, Action other): + return self.action == other.action + _ActionType = collections.namedtuple('_ActionType', ['type', 'description']) @@ -327,6 +352,31 @@ cdef class Result(_Weakrefable): """Get the Buffer containing the result.""" return pyarrow_wrap_buffer(self.result.get().body) + def serialize(self): + """Get the wire-format representation of this type. 
+ + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.result.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef Result result = Result.__new__(Result) + result.result.reset(new CFlightResult(GetResultValue( + CFlightResult.Deserialize(tobytes(serialized))))) + return result + + def __eq__(self, Result other): + return deref(self.result.get()) == deref(other.result.get()) + cdef class BasicAuth(_Weakrefable): """A container for basic auth.""" @@ -360,13 +410,16 @@ cdef class BasicAuth(_Weakrefable): @staticmethod def deserialize(serialized): auth = BasicAuth() - check_flight_status( - CBasicAuth.Deserialize(serialized).Value(auth.basic_auth.get())) + auth.basic_auth.reset(new CBasicAuth(GetResultValue( + CBasicAuth.Deserialize(tobytes(serialized))))) return auth def serialize(self): return GetResultValue(self.basic_auth.get().SerializeToString()) + def __eq__(self, BasicAuth other): + return deref(self.basic_auth.get()) == deref(other.basic_auth.get()) + class DescriptorType(enum.Enum): """ @@ -686,6 +739,28 @@ cdef class FlightEndpoint(_Weakrefable): return [Location.wrap(location) for location in self.endpoint.locations] + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.endpoint.SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + cdef FlightEndpoint endpoint = FlightEndpoint.__new__(FlightEndpoint) + endpoint.endpoint = GetResultValue( + CFlightEndpoint.Deserialize(tobytes(serialized))) + return endpoint + def __repr__(self): return "".format( self.ticket, self.locations) @@ -721,6 +796,31 @@ cdef class SchemaResult(_Weakrefable): check_flight_status(self.result.get().GetSchema(&dummy_memo).Value(&schema)) return pyarrow_wrap_schema(schema) + def serialize(self): + """Get the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. + + """ + return GetResultValue(self.result.get().SerializeToString()) + + @classmethod + def deserialize(cls, serialized): + """Parse the wire-format representation of this type. + + Useful when interoperating with non-Flight systems (e.g. REST + services) that may want to return Flight types. 
+ + """ + cdef SchemaResult result = SchemaResult.__new__(SchemaResult) + result.result.reset(new CSchemaResult(GetResultValue( + CSchemaResult.Deserialize(tobytes(serialized))))) + return result + + def __eq__(self, SchemaResult other): + return deref(self.result.get()) == deref(other.result.get()) + cdef class FlightInfo(_Weakrefable): """A description of a Flight stream.""" diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 3698292b5a0..3b9ac54fe9d 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -28,15 +28,30 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CActionType" arrow::flight::ActionType": c_string type c_string description + bint operator==(CActionType) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CActionType] Deserialize(const c_string& serialized) cdef cppclass CAction" arrow::flight::Action": c_string type shared_ptr[CBuffer] body + bint operator==(CAction) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CAction] Deserialize(const c_string& serialized) cdef cppclass CFlightResult" arrow::flight::Result": CFlightResult() CFlightResult(CFlightResult) shared_ptr[CBuffer] body + bint operator==(CFlightResult) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CFlightResult] Deserialize(const c_string& serialized) cdef cppclass CBasicAuth" arrow::flight::BasicAuth": CBasicAuth() @@ -44,7 +59,7 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CBasicAuth(CBasicAuth) c_string username c_string password - + bint operator==(CBasicAuth) CResult[c_string] SerializeToString() @staticmethod @@ -68,11 +83,11 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CDescriptorType type c_string cmd vector[c_string] path + bint operator==(CFlightDescriptor) CResult[c_string] SerializeToString() @staticmethod CResult[CFlightDescriptor] Deserialize(const c_string& serialized) - bint operator==(CFlightDescriptor) cdef cppclass CTicket" arrow::flight::Ticket": CTicket() @@ -86,6 +101,11 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: cdef cppclass CCriteria" arrow::flight::Criteria": CCriteria() c_string expression + bint operator==(CCriteria) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CCriteria] Deserialize(const c_string& serialized) cdef cppclass CLocation" arrow::flight::Location": CLocation() @@ -111,6 +131,10 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: vector[CLocation] locations bint operator==(CFlightEndpoint) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CFlightEndpoint] Deserialize(const c_string& serialized) cdef cppclass CFlightInfo" arrow::flight::FlightInfo": CFlightInfo(CFlightInfo info) @@ -126,8 +150,14 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: const c_string& serialized) cdef cppclass CSchemaResult" arrow::flight::SchemaResult": + CSchemaResult() CSchemaResult(CSchemaResult result) CResult[shared_ptr[CSchema]] GetSchema(CDictionaryMemo* memo) + bint operator==(CSchemaResult) + CResult[c_string] SerializeToString() + + @staticmethod + CResult[CSchemaResult] Deserialize(const c_string& serialized) cdef cppclass CFlightListing" arrow::flight::FlightListing": CResult[unique_ptr[CFlightInfo]] Next() diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 905efa564b0..72d1fa5ec33 100644 --- 
a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -1560,9 +1560,22 @@ def block_read(): def test_roundtrip_types(): """Make sure serializable types round-trip.""" + action = flight.Action("action1", b"action1-body") + assert action == flight.Action.deserialize(action.serialize()) + ticket = flight.Ticket("foo") assert ticket == flight.Ticket.deserialize(ticket.serialize()) + result = flight.Result(b"result1") + assert result == flight.Result.deserialize(result.serialize()) + + basic_auth = flight.BasicAuth("username1", "password1") + assert basic_auth == flight.BasicAuth.deserialize(basic_auth.serialize()) + + schema_result = flight.SchemaResult(pa.schema([('a', pa.int32())])) + assert schema_result == flight.SchemaResult.deserialize( + schema_result.serialize()) + desc = flight.FlightDescriptor.for_command("test") assert desc == flight.FlightDescriptor.deserialize(desc.serialize()) @@ -1589,6 +1602,12 @@ def test_roundtrip_types(): assert info.total_records == info2.total_records assert info.endpoints == info2.endpoints + endpoint = flight.FlightEndpoint( + ticket, + ['grpc://test', flight.Location.for_grpc_tcp('localhost', 5005)] + ) + assert endpoint == flight.FlightEndpoint.deserialize(endpoint.serialize()) + def test_roundtrip_errors(): """Ensure that Flight errors propagate from server to client.""" From e8aa1b9b0b5bb62f1b9fdadfd81c1fa4bdab8f52 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 15 Sep 2022 11:46:24 -0400 Subject: [PATCH 067/133] MINOR: [Java] Bump commons-io from 2.6 to 2.7 in /java/flight/flight-sql-jdbc-driver (#14128) Bumps commons-io from 2.6 to 2.7. [![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=commons-io:commons-io&package-manager=maven&previous-version=2.6&new-version=2.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Signed-off-by: David Li 
---
 java/flight/flight-sql-jdbc-driver/pom.xml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/java/flight/flight-sql-jdbc-driver/pom.xml b/java/flight/flight-sql-jdbc-driver/pom.xml
index b8a49165adb..778a385d250 100644
--- a/java/flight/flight-sql-jdbc-driver/pom.xml
+++ b/java/flight/flight-sql-jdbc-driver/pom.xml
@@ -106,7 +106,7 @@
       <groupId>commons-io</groupId>
       <artifactId>commons-io</artifactId>
-      <version>2.6</version>
+      <version>2.7</version>
       <scope>test</scope>

From d8f64eecf34d7c8c347b364bac56ea078b0817f9 Mon Sep 17 00:00:00 2001
From: Alenka Frim 
Date: Thu, 15 Sep 2022 20:01:34 +0200
Subject: [PATCH 068/133] ARROW-17172: [C++][Python] test_cython_api fails on
 windows (#14133)

This PR adds `CONDA_DLL_SEARCH_MODIFICATION_ENABLE=1` to the AppVeyor setup to
make `test_cython.py` succeed; without it, the extension module being built
fails to load in a subprocess.

Lead-authored-by: Alenka Frim 
Co-authored-by: Alenka Frim 
Signed-off-by: Alenka Frim 
---
 ci/appveyor-cpp-setup.bat           | 4 ++++
 python/pyarrow/tests/test_cython.py | 2 --
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/ci/appveyor-cpp-setup.bat b/ci/appveyor-cpp-setup.bat
index f9390e9be5a..9e4e4ad5dce 100644
--- a/ci/appveyor-cpp-setup.bat
+++ b/ci/appveyor-cpp-setup.bat
@@ -55,6 +55,10 @@ mamba update -q -y -c conda-forge --all || exit /B
 @rem Create conda environment
 @rem

+@rem Workaround for ARROW-17172
+@rem This seems necessary for test_cython.py to succeed, otherwise
+@rem the extension module being built would fail loading in a subprocess.
+set CONDA_DLL_SEARCH_MODIFICATION_ENABLE=1
 set CONDA_PACKAGES=

 if "%ARROW_BUILD_GANDIVA%" == "ON" (
diff --git a/python/pyarrow/tests/test_cython.py b/python/pyarrow/tests/test_cython.py
index 6c1c47f2a24..7152f3f1d44 100644
--- a/python/pyarrow/tests/test_cython.py
+++ b/python/pyarrow/tests/test_cython.py
@@ -81,8 +81,6 @@ def check_cython_example_module(mod):
         mod.cast_scalar(scal, pa.list_(pa.int64()))


-@pytest.mark.skipif(sys.platform == "win32",
-                    reason="ARROW-17172: currently fails on windows")
 @pytest.mark.cython
 def test_cython_api(tmpdir):
     """

From 5c773bb9220c26df9c427d325563b4edfa2cd0b9 Mon Sep 17 00:00:00 2001
From: eitsupi <50911393+eitsupi@users.noreply.github.com>
Date: Fri, 16 Sep 2022 03:17:20 +0900
Subject: [PATCH 069/133] ARROW-17673: [R] `desc` in `dplyr::arrange` should
 allow `dplyr::` prefix (#14090)

Authored-by: SHIMA Tatsuya 
Signed-off-by: Neal Richardson 
---
 r/R/dplyr-arrange.R                   |  2 +-
 r/tests/testthat/test-dplyr-arrange.R | 26 ++++++++++++++++++++++++++
 2 files changed, 27 insertions(+), 1 deletion(-)

diff --git a/r/R/dplyr-arrange.R b/r/R/dplyr-arrange.R
index 2f9ef61bb31..39388394d5b 100644
--- a/r/R/dplyr-arrange.R
+++ b/r/R/dplyr-arrange.R
@@ -77,7 +77,7 @@ find_and_remove_desc <- function(quosure) {
     if (identical(expr[[1]], quote(`(`))) {
       # remove enclosing parentheses
       expr <- expr[[2]]
-    } else if (identical(expr[[1]], quote(desc))) {
+    } else if (identical(expr[[1]], quote(desc)) || identical(expr[[1]], quote(dplyr::desc))) {
       # ensure desc() has only one argument (when an R expression is a function
       # call, length == 2 means it has exactly one argument)
       if (length(expr) > 2) {
diff --git a/r/tests/testthat/test-dplyr-arrange.R b/r/tests/testthat/test-dplyr-arrange.R
index edec572d10f..d8afcc5d4a8 100644
--- a/r/tests/testthat/test-dplyr-arrange.R
+++ b/r/tests/testthat/test-dplyr-arrange.R
@@ -33,12 +33,24 @@ test_that("arrange() on integer, double, and character columns", {
       collect(),
     tbl
   )
+
compare_dplyr_binding( + .input %>% + arrange(int, dplyr::desc(dbl)) %>% + collect(), + tbl + ) compare_dplyr_binding( .input %>% arrange(int, desc(desc(dbl))) %>% collect(), tbl ) + compare_dplyr_binding( + .input %>% + arrange(int, dplyr::desc(dplyr::desc(dbl))) %>% + collect(), + tbl + ) compare_dplyr_binding( .input %>% arrange(int) %>% @@ -46,6 +58,13 @@ test_that("arrange() on integer, double, and character columns", { collect(), tbl ) + compare_dplyr_binding( + .input %>% + arrange(int) %>% + arrange(dplyr::desc(dbl)) %>% + collect(), + tbl + ) compare_dplyr_binding( .input %>% arrange(int + dbl, chr) %>% @@ -200,6 +219,13 @@ test_that("arrange() with bad inputs", { "expects only one argument", fixed = TRUE ) + expect_error( + tbl %>% + Table$create() %>% + arrange(dplyr::desc(int, chr)), + "expects only one argument", + fixed = TRUE + ) }) test_that("Can use across() within arrange()", { From 5c13049d9762120a2ff75f5782fe642ef019347b Mon Sep 17 00:00:00 2001 From: Jacob Wujciak-Jens Date: Thu, 15 Sep 2022 20:31:34 +0200 Subject: [PATCH 070/133] ARROW-16190: [CI][R] Implement CI on Apple M1 for R (#14099) Authored-by: Jacob Wujciak-Jens Signed-off-by: Neal Richardson --- dev/tasks/macros.jinja | 4 +-- dev/tasks/python-wheels/github.osx.arm64.yml | 16 +++++------ dev/tasks/r/github.macos.autobrew.yml | 4 +-- dev/tasks/r/github.packages.yml | 28 ++++++++++++-------- dev/tasks/tasks.yml | 3 ++- dev/tasks/verify-rc/github.macos.arm64.yml | 2 +- 6 files changed, 32 insertions(+), 25 deletions(-) diff --git a/dev/tasks/macros.jinja b/dev/tasks/macros.jinja index af3f6105645..fd46555b28b 100644 --- a/dev/tasks/macros.jinja +++ b/dev/tasks/macros.jinja @@ -354,8 +354,8 @@ on: {# use filter to cast to string and convert to lowercase to match yaml boolean #} {% set is_fork = (not is_upstream_b)|lower %} -{% set r_release = '4.2' %} -{% set r_oldrel = '4.1' %} +{% set r_release = {"ver": "4.2", "rt" : "42"} %} +{% set r_oldrel = {"ver": "4.1", "rt" : "40"} %} {%- macro github_set_env(env) -%} {% if env is defined %} diff --git a/dev/tasks/python-wheels/github.osx.arm64.yml b/dev/tasks/python-wheels/github.osx.arm64.yml index e5be422e2c3..a02661f4920 100644 --- a/dev/tasks/python-wheels/github.osx.arm64.yml +++ b/dev/tasks/python-wheels/github.osx.arm64.yml @@ -58,7 +58,7 @@ jobs: - name: Install Vcpkg env: MACOSX_DEPLOYMENT_TARGET: "11.0" - run: arch -arm64 arrow/ci/scripts/install_vcpkg.sh $VCPKG_ROOT $VCPKG_VERSION + run: arrow/ci/scripts/install_vcpkg.sh $VCPKG_ROOT $VCPKG_VERSION - name: Add Vcpkg to PATH run: echo ${VCPKG_ROOT} >> $GITHUB_PATH @@ -67,7 +67,7 @@ jobs: env: VCPKG_DEFAULT_TRIPLET: arm64-osx-static-release run: | - arch -arm64 vcpkg install \ + vcpkg install \ --clean-after-build \ --x-install-root=${VCPKG_ROOT}/installed \ --x-manifest-root=arrow/ci/vcpkg \ @@ -85,14 +85,14 @@ jobs: $PYTHON -m venv build-arm64-env source build-arm64-env/bin/activate pip install --upgrade pip wheel - arch -arm64 arrow/ci/scripts/python_wheel_macos_build.sh arm64 $(pwd)/arrow $(pwd)/build + arrow/ci/scripts/python_wheel_macos_build.sh arm64 $(pwd)/arrow $(pwd)/build {% if arch == "universal2" %} - name: Install AMD64 Packages env: VCPKG_DEFAULT_TRIPLET: amd64-osx-static-release run: | - arch -arm64 vcpkg install \ + vcpkg install \ --clean-after-build \ --x-install-root=${VCPKG_ROOT}/installed \ --x-manifest-root=arrow/ci/vcpkg \ @@ -138,10 +138,10 @@ jobs: # libffi has to be installed on the m1 runner which causes issues with # the cffi wheel. 
We build cffi with the flags pointing to the correct libffi location. LDFLAGS=-L$(brew --prefix libffi)/lib CFLAGS=-I$(brew --prefix libffi)/include \ - arch -arm64 pip install cffi --no-binary :all: - arch -arm64 pip install -r arrow/python/requirements-wheel-test.txt - PYTHON=python arch -arm64 arrow/ci/scripts/install_gcs_testbench.sh default - arch -arm64 arrow/ci/scripts/python_wheel_unix_test.sh $(pwd)/arrow + pip install cffi --no-binary :all: + pip install -r arrow/python/requirements-wheel-test.txt + PYTHON=python arrow/ci/scripts/install_gcs_testbench.sh default + arrow/ci/scripts/python_wheel_unix_test.sh $(pwd)/arrow {% if arch == "universal2" %} - name: Test Wheel on AMD64 diff --git a/dev/tasks/r/github.macos.autobrew.yml b/dev/tasks/r/github.macos.autobrew.yml index fd5492873db..2664455479a 100644 --- a/dev/tasks/r/github.macos.autobrew.yml +++ b/dev/tasks/r/github.macos.autobrew.yml @@ -31,8 +31,8 @@ jobs: - macos-11 - macos-10.13 # self-hosted r-version: - - {{ macros.r_release }} - - {{ macros.r_oldrel }} + - "{{ macros.r_release.ver }}" + - "{{ macros.r_oldrel.ver }}" steps: {{ macros.github_checkout_arrow()|indent }} - name: Configure autobrew script diff --git a/dev/tasks/r/github.packages.yml b/dev/tasks/r/github.packages.yml index 30dd0641d21..49ca49567a9 100644 --- a/dev/tasks/r/github.packages.yml +++ b/dev/tasks/r/github.packages.yml @@ -131,31 +131,36 @@ jobs: r-packages: needs: [source, windows-cpp] - name: {{ '${{ matrix.platform }} ${{ matrix.r_version.r }}' }} - runs-on: {{ '${{ matrix.platform }}' }} + name: {{ '${{ matrix.platform.name }} ${{ matrix.r_version.r }}' }} + runs-on: {{ '${{ matrix.platform.runs_on }}' }} strategy: fail-fast: false matrix: platform: - - windows-latest - - macos-10.13 # self-hosted - # - devops-managed # No M1 until the runner application runs native + - { runs_on: 'windows-latest', name: "Windows"} + - { runs_on: ["self-hosted", "macos-10.13"], name: "macOS High Sierra"} + - { runs_on: ["self-hosted", "macOS", "arm64", "devops-managed"], name: "macOS Big Sur" } r_version: - - { rtools: 40, r: "4.1" } - - { rtools: 42, r: "4.2" } + - { rtools: "{{ macros.r_release.rt }}", r: "{{ macros.r_release.ver }}" } + - { rtools: "{{ macros.r_oldrel.rt }}", r: "{{ macros.r_oldrel.ver }}" } steps: - uses: r-lib/actions/setup-r@v2 - if: matrix.platform != 'macos-10.13' + # expression marker prevents the ! 
being parsed as yaml tag
+      if: {{ "${{ !contains(matrix.platform.runs_on, 'self-hosted') }}" }}
       with:
         r-version: {{ '${{ matrix.r_version.r }}' }}
         rtools-version: {{ '${{ matrix.r_version.rtools }}' }}
         Ncpus: 2
     - name: Setup R Self-Hosted
-      if: matrix.platform == 'macos-10.13'
+      if: contains(matrix.platform.runs_on, 'self-hosted')
       run: |
+        if [ "{{ "${{ contains(matrix.platform.runs_on, 'arm64') }}" }}" == "true" ]; then
+          rig_arch="-arm64"
+        fi
         # rig is a system utility that allows for switching
         # between pre-installed R version on the self-hosted runners
-        rig default {{ '${{ matrix.r_version.r }}' }}
+        rig default {{ '${{ matrix.r_version.r }}' }}$rig_arch
+        rig system setup-user-lib
         rig system add-pak

 {{ macros.github_setup_local_r_repo(false, true)|indent }}
@@ -164,7 +169,8 @@
       shell: bash
       run: |
         tar -xzf repo/src/contrib/arrow_*.tar.gz arrow/DESCRIPTION
-    - uses: r-lib/actions/setup-r-dependencies@v2
+    - name: Install dependencies
+      uses: r-lib/actions/setup-r-dependencies@v2
       with:
         working-directory: 'arrow'
         extra-packages: cpp11
diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml
index 98af1a194bf..7cb20488009 100644
--- a/dev/tasks/tasks.yml
+++ b/dev/tasks/tasks.yml
@@ -958,6 +958,8 @@ tasks:
       - r-pkg__bin__windows__contrib__4.2__arrow_[0-9\.]+\.zip
       - r-pkg__bin__macosx__contrib__4.1__arrow_[0-9\.]+\.tgz
       - r-pkg__bin__macosx__contrib__4.2__arrow_[0-9\.]+\.tgz
+      - r-pkg__bin__macosx__big-sur-arm64__contrib__4.1__arrow_[0-9\.]+\.tgz
+      - r-pkg__bin__macosx__big-sur-arm64__contrib__4.2__arrow_[0-9\.]+\.tgz
       - r-pkg__src__contrib__arrow_[0-9\.]+\.tar\.gz
@@ -1085,7 +1087,6 @@
     env:
       PYTEST_ADDOPTS: "-k 'not test_cancellation'"
     github_runner: ["self-hosted", "macOS", "arm64"]
-    arch_emulation: arm64
     target: "wheels"

 ######################## Windows verification ##############################
diff --git a/dev/tasks/verify-rc/github.macos.arm64.yml b/dev/tasks/verify-rc/github.macos.arm64.yml
index 10f684e6ac1..79cdf4479eb 100644
--- a/dev/tasks/verify-rc/github.macos.arm64.yml
+++ b/dev/tasks/verify-rc/github.macos.arm64.yml
@@ -46,5 +46,5 @@ jobs:
           export PATH="$(brew --prefix node@16)/bin:$PATH"
           export PATH="$(brew --prefix ruby)/bin:$PATH"
           export PKG_CONFIG_PATH="$(brew --prefix ruby)/lib/pkgconfig"
-          arch -{{ arch_emulation|default("arm64") }} arrow/dev/release/verify-release-candidate.sh \
+          arrow/dev/release/verify-release-candidate.sh \
            {{ release|default("") }} {{ rc|default("") }}

From bced07de986b9f7baaa4885d7f03f4e46d1abed8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= 
Date: Thu, 15 Sep 2022 22:53:10 +0200
Subject: [PATCH 071/133] ARROW-17628: [CI][Packaging][Java] Publish latest
 nightly with SNAPSHOT version (#14135)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

In order to help automate nightly distribution, upload the latest nightly
with a fixed SNAPSHOT version.
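Conceptually, the rewrite is just a regular-expression substitution over the
dev version string. The sketch below illustrates the idea (the helper name
`to_snapshot` is illustrative only; the actual substitution lives in
`dev/archery/archery/crossbow/core.py`):

```python
import re

def to_snapshot(no_rc_version: str) -> str:
    # Replace a trailing '.devN' suffix with '-SNAPSHOT',
    # e.g. '10.0.0.dev235' -> '10.0.0-SNAPSHOT'
    return re.sub(r"\.(dev\d+)$", "-SNAPSHOT", no_rc_version)

assert to_snapshot("10.0.0.dev235") == "10.0.0-SNAPSHOT"
```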
The new artifact names will be like: ``` - flight-integration-tests-10.0.0-SNAPSHOT-tests.jar - flight-integration-tests-10.0.0-SNAPSHOT.jar - flight-integration-tests-10.0.0-SNAPSHOT.pom - flight-sql-10.0.0-SNAPSHOT-javadoc.jar - flight-sql-10.0.0-SNAPSHOT-sources.jar - flight-sql-10.0.0-SNAPSHOT-tests.jar - flight-sql-10.0.0-SNAPSHOT.jar - flight-sql-10.0.0-SNAPSHOT.pom ``` Authored-by: Raúl Cumplido Signed-off-by: Sutou Kouhei --- .github/workflows/java_nightly.yml | 8 +- dev/archery/archery/crossbow/core.py | 17 ++- dev/tasks/java-jars/github.yml | 2 +- dev/tasks/tasks.yml | 218 +++++++++++++-------------- 4 files changed, 129 insertions(+), 116 deletions(-) diff --git a/.github/workflows/java_nightly.yml b/.github/workflows/java_nightly.yml index 015e28984fb..badb9d94e52 100644 --- a/.github/workflows/java_nightly.yml +++ b/.github/workflows/java_nightly.yml @@ -97,15 +97,19 @@ jobs: - shell: bash name: Build Repository run: | + DATE=$(date +%Y-%m-%d) if [ -z $PREFIX ]; then - PREFIX=nightly-packaging-$(date +%Y-%m-%d)-0 + PREFIX=nightly-packaging-${DATE}-0 fi - PATTERN_TO_GET_LIB_AND_VERSION='([a-z].+)-([0-9]+.[0-9]+.[0-9]+.dev[0-9]+)' + PATTERN_TO_GET_LIB_AND_VERSION='([a-z].+)-([0-9]+.[0-9]+.[0-9]+-SNAPSHOT)' mkdir -p repo/org/apache/arrow/ for LIBRARY in $(ls binaries/$PREFIX/java-jars | grep -E '.jar|.pom' | grep dev); do [[ $LIBRARY =~ $PATTERN_TO_GET_LIB_AND_VERSION ]] mkdir -p repo/org/apache/arrow/${BASH_REMATCH[1]}/${BASH_REMATCH[2]} + mkdir -p repo/org/apache/arrow/${BASH_REMATCH[1]}/${DATE} + # Copy twice to maintain a latest snapshot and some earlier versions cp binaries/$PREFIX/java-jars/$LIBRARY repo/org/apache/arrow/${BASH_REMATCH[1]}/${BASH_REMATCH[2]} + cp binaries/$PREFIX/java-jars/$LIBRARY repo/org/apache/arrow/${BASH_REMATCH[1]}/${DATE} echo "Artifacts $LIBRARY configured" done - name: Prune Repository diff --git a/dev/archery/archery/crossbow/core.py b/dev/archery/archery/crossbow/core.py index 8045254871c..560a9f1e199 100644 --- a/dev/archery/archery/crossbow/core.py +++ b/dev/archery/archery/crossbow/core.py @@ -762,6 +762,14 @@ def __init__(self, head, branch, remote, version, email=None): # '0.16.1-dev10' self.no_rc_semver_version = \ re.sub(r'\.(dev\d+)\Z', r'-\1', self.no_rc_version) + # Substitute dev version for SNAPSHOT + # + # Example: + # + # '10.0.0.dev235' -> + # '10.0.0-SNAPSHOT' + self.no_rc_snapshot_version = re.sub( + r'\.(dev\d+)$', '-SNAPSHOT', self.no_rc_version) @classmethod def from_repo(cls, repo, head=None, branch=None, remote=None, version=None, @@ -1093,15 +1101,16 @@ def from_config(cls, config, target, tasks=None, groups=None, params=None): # instantiate the tasks tasks = {} - versions = {'version': target.version, - 'no_rc_version': target.no_rc_version, - 'no_rc_semver_version': target.no_rc_semver_version} + versions = { + 'version': target.version, + 'no_rc_version': target.no_rc_version, + 'no_rc_semver_version': target.no_rc_semver_version, + 'no_rc_snapshot_version': target.no_rc_snapshot_version} for task_name, task in task_definitions.items(): task = task.copy() artifacts = task.pop('artifacts', None) or [] # because of yaml artifacts = [fn.format(**versions) for fn in artifacts] tasks[task_name] = Task(task_name, artifacts=artifacts, **task) - return cls(target=target, tasks=tasks, params=params, template_searchpath=config.template_searchpath) diff --git a/dev/tasks/java-jars/github.yml b/dev/tasks/java-jars/github.yml index abaddc028e9..5375aa50822 100644 --- a/dev/tasks/java-jars/github.yml +++ 
b/dev/tasks/java-jars/github.yml @@ -118,7 +118,7 @@ jobs: run: | set -e pushd arrow/java - mvn versions:set -DnewVersion={{ arrow.no_rc_version }} + mvn versions:set -DnewVersion={{ arrow.no_rc_snapshot_version }} popd arrow/ci/scripts/java_full_build.sh \ $GITHUB_WORKSPACE/arrow \ diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index 7cb20488009..3792dc14905 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -818,115 +818,115 @@ tasks: ci: github template: java-jars/github.yml artifacts: - - arrow-algorithm-{no_rc_version}-javadoc.jar - - arrow-algorithm-{no_rc_version}-sources.jar - - arrow-algorithm-{no_rc_version}-tests.jar - - arrow-algorithm-{no_rc_version}.jar - - arrow-algorithm-{no_rc_version}.pom - - arrow-avro-{no_rc_version}-javadoc.jar - - arrow-avro-{no_rc_version}-sources.jar - - arrow-avro-{no_rc_version}-tests.jar - - arrow-avro-{no_rc_version}.jar - - arrow-avro-{no_rc_version}.pom - - arrow-c-data-{no_rc_version}-javadoc.jar - - arrow-c-data-{no_rc_version}-sources.jar - - arrow-c-data-{no_rc_version}-tests.jar - - arrow-c-data-{no_rc_version}.jar - - arrow-c-data-{no_rc_version}.pom - - arrow-compression-{no_rc_version}-javadoc.jar - - arrow-compression-{no_rc_version}-sources.jar - - arrow-compression-{no_rc_version}-tests.jar - - arrow-compression-{no_rc_version}.jar - - arrow-compression-{no_rc_version}.pom - - arrow-dataset-{no_rc_version}-javadoc.jar - - arrow-dataset-{no_rc_version}-sources.jar - - arrow-dataset-{no_rc_version}-tests.jar - - arrow-dataset-{no_rc_version}.jar - - arrow-dataset-{no_rc_version}.pom - - arrow-flight-{no_rc_version}.pom - - arrow-format-{no_rc_version}-javadoc.jar - - arrow-format-{no_rc_version}-sources.jar - - arrow-format-{no_rc_version}-tests.jar - - arrow-format-{no_rc_version}.jar - - arrow-format-{no_rc_version}.pom - - arrow-gandiva-{no_rc_version}-javadoc.jar - - arrow-gandiva-{no_rc_version}-sources.jar - - arrow-gandiva-{no_rc_version}-tests.jar - - arrow-gandiva-{no_rc_version}.jar - - arrow-gandiva-{no_rc_version}.pom - - arrow-java-root-{no_rc_version}-source-release.zip - - arrow-java-root-{no_rc_version}.pom - - arrow-jdbc-{no_rc_version}-javadoc.jar - - arrow-jdbc-{no_rc_version}-sources.jar - - arrow-jdbc-{no_rc_version}-tests.jar - - arrow-jdbc-{no_rc_version}.jar - - arrow-jdbc-{no_rc_version}.pom - - arrow-memory-core-{no_rc_version}-javadoc.jar - - arrow-memory-core-{no_rc_version}-sources.jar - - arrow-memory-core-{no_rc_version}-tests.jar - - arrow-memory-core-{no_rc_version}.jar - - arrow-memory-core-{no_rc_version}.pom - - arrow-memory-netty-{no_rc_version}-javadoc.jar - - arrow-memory-netty-{no_rc_version}-sources.jar - - arrow-memory-netty-{no_rc_version}-tests.jar - - arrow-memory-netty-{no_rc_version}.jar - - arrow-memory-netty-{no_rc_version}.pom - - arrow-memory-unsafe-{no_rc_version}-javadoc.jar - - arrow-memory-unsafe-{no_rc_version}-sources.jar - - arrow-memory-unsafe-{no_rc_version}-tests.jar - - arrow-memory-unsafe-{no_rc_version}.jar - - arrow-memory-unsafe-{no_rc_version}.pom - - arrow-memory-{no_rc_version}.pom - - arrow-orc-{no_rc_version}-javadoc.jar - - arrow-orc-{no_rc_version}-sources.jar - - arrow-orc-{no_rc_version}-tests.jar - - arrow-orc-{no_rc_version}.jar - - arrow-orc-{no_rc_version}.pom - - arrow-performance-{no_rc_version}-sources.jar - - arrow-performance-{no_rc_version}-tests.jar - - arrow-performance-{no_rc_version}.jar - - arrow-performance-{no_rc_version}.pom - - arrow-plasma-{no_rc_version}-javadoc.jar - - arrow-plasma-{no_rc_version}-sources.jar - - 
arrow-plasma-{no_rc_version}-tests.jar - - arrow-plasma-{no_rc_version}.jar - - arrow-plasma-{no_rc_version}.pom - - arrow-tools-{no_rc_version}-jar-with-dependencies.jar - - arrow-tools-{no_rc_version}-javadoc.jar - - arrow-tools-{no_rc_version}-sources.jar - - arrow-tools-{no_rc_version}-tests.jar - - arrow-tools-{no_rc_version}.jar - - arrow-tools-{no_rc_version}.pom - - arrow-vector-{no_rc_version}-javadoc.jar - - arrow-vector-{no_rc_version}-shade-format-flatbuffers.jar - - arrow-vector-{no_rc_version}-sources.jar - - arrow-vector-{no_rc_version}-tests.jar - - arrow-vector-{no_rc_version}.jar - - arrow-vector-{no_rc_version}.pom - - flight-core-{no_rc_version}-jar-with-dependencies.jar - - flight-core-{no_rc_version}-javadoc.jar - - flight-core-{no_rc_version}-shaded-ext.jar - - flight-core-{no_rc_version}-shaded.jar - - flight-core-{no_rc_version}-sources.jar - - flight-core-{no_rc_version}-tests.jar - - flight-core-{no_rc_version}.jar - - flight-core-{no_rc_version}.pom - - flight-grpc-{no_rc_version}-javadoc.jar - - flight-grpc-{no_rc_version}-sources.jar - - flight-grpc-{no_rc_version}-tests.jar - - flight-grpc-{no_rc_version}.jar - - flight-grpc-{no_rc_version}.pom - - flight-integration-tests-{no_rc_version}-jar-with-dependencies.jar - - flight-integration-tests-{no_rc_version}-javadoc.jar - - flight-integration-tests-{no_rc_version}-sources.jar - - flight-integration-tests-{no_rc_version}-tests.jar - - flight-integration-tests-{no_rc_version}.jar - - flight-integration-tests-{no_rc_version}.pom - - flight-sql-{no_rc_version}-javadoc.jar - - flight-sql-{no_rc_version}-sources.jar - - flight-sql-{no_rc_version}-tests.jar - - flight-sql-{no_rc_version}.jar - - flight-sql-{no_rc_version}.pom + - arrow-algorithm-{no_rc_snapshot_version}-javadoc.jar + - arrow-algorithm-{no_rc_snapshot_version}-sources.jar + - arrow-algorithm-{no_rc_snapshot_version}-tests.jar + - arrow-algorithm-{no_rc_snapshot_version}.jar + - arrow-algorithm-{no_rc_snapshot_version}.pom + - arrow-avro-{no_rc_snapshot_version}-javadoc.jar + - arrow-avro-{no_rc_snapshot_version}-sources.jar + - arrow-avro-{no_rc_snapshot_version}-tests.jar + - arrow-avro-{no_rc_snapshot_version}.jar + - arrow-avro-{no_rc_snapshot_version}.pom + - arrow-c-data-{no_rc_snapshot_version}-javadoc.jar + - arrow-c-data-{no_rc_snapshot_version}-sources.jar + - arrow-c-data-{no_rc_snapshot_version}-tests.jar + - arrow-c-data-{no_rc_snapshot_version}.jar + - arrow-c-data-{no_rc_snapshot_version}.pom + - arrow-compression-{no_rc_snapshot_version}-javadoc.jar + - arrow-compression-{no_rc_snapshot_version}-sources.jar + - arrow-compression-{no_rc_snapshot_version}-tests.jar + - arrow-compression-{no_rc_snapshot_version}.jar + - arrow-compression-{no_rc_snapshot_version}.pom + - arrow-dataset-{no_rc_snapshot_version}-javadoc.jar + - arrow-dataset-{no_rc_snapshot_version}-sources.jar + - arrow-dataset-{no_rc_snapshot_version}-tests.jar + - arrow-dataset-{no_rc_snapshot_version}.jar + - arrow-dataset-{no_rc_snapshot_version}.pom + - arrow-flight-{no_rc_snapshot_version}.pom + - arrow-format-{no_rc_snapshot_version}-javadoc.jar + - arrow-format-{no_rc_snapshot_version}-sources.jar + - arrow-format-{no_rc_snapshot_version}-tests.jar + - arrow-format-{no_rc_snapshot_version}.jar + - arrow-format-{no_rc_snapshot_version}.pom + - arrow-gandiva-{no_rc_snapshot_version}-javadoc.jar + - arrow-gandiva-{no_rc_snapshot_version}-sources.jar + - arrow-gandiva-{no_rc_snapshot_version}-tests.jar + - arrow-gandiva-{no_rc_snapshot_version}.jar + - 
arrow-gandiva-{no_rc_snapshot_version}.pom + - arrow-java-root-{no_rc_snapshot_version}-source-release.zip + - arrow-java-root-{no_rc_snapshot_version}.pom + - arrow-jdbc-{no_rc_snapshot_version}-javadoc.jar + - arrow-jdbc-{no_rc_snapshot_version}-sources.jar + - arrow-jdbc-{no_rc_snapshot_version}-tests.jar + - arrow-jdbc-{no_rc_snapshot_version}.jar + - arrow-jdbc-{no_rc_snapshot_version}.pom + - arrow-memory-core-{no_rc_snapshot_version}-javadoc.jar + - arrow-memory-core-{no_rc_snapshot_version}-sources.jar + - arrow-memory-core-{no_rc_snapshot_version}-tests.jar + - arrow-memory-core-{no_rc_snapshot_version}.jar + - arrow-memory-core-{no_rc_snapshot_version}.pom + - arrow-memory-netty-{no_rc_snapshot_version}-javadoc.jar + - arrow-memory-netty-{no_rc_snapshot_version}-sources.jar + - arrow-memory-netty-{no_rc_snapshot_version}-tests.jar + - arrow-memory-netty-{no_rc_snapshot_version}.jar + - arrow-memory-netty-{no_rc_snapshot_version}.pom + - arrow-memory-unsafe-{no_rc_snapshot_version}-javadoc.jar + - arrow-memory-unsafe-{no_rc_snapshot_version}-sources.jar + - arrow-memory-unsafe-{no_rc_snapshot_version}-tests.jar + - arrow-memory-unsafe-{no_rc_snapshot_version}.jar + - arrow-memory-unsafe-{no_rc_snapshot_version}.pom + - arrow-memory-{no_rc_snapshot_version}.pom + - arrow-orc-{no_rc_snapshot_version}-javadoc.jar + - arrow-orc-{no_rc_snapshot_version}-sources.jar + - arrow-orc-{no_rc_snapshot_version}-tests.jar + - arrow-orc-{no_rc_snapshot_version}.jar + - arrow-orc-{no_rc_snapshot_version}.pom + - arrow-performance-{no_rc_snapshot_version}-sources.jar + - arrow-performance-{no_rc_snapshot_version}-tests.jar + - arrow-performance-{no_rc_snapshot_version}.jar + - arrow-performance-{no_rc_snapshot_version}.pom + - arrow-plasma-{no_rc_snapshot_version}-javadoc.jar + - arrow-plasma-{no_rc_snapshot_version}-sources.jar + - arrow-plasma-{no_rc_snapshot_version}-tests.jar + - arrow-plasma-{no_rc_snapshot_version}.jar + - arrow-plasma-{no_rc_snapshot_version}.pom + - arrow-tools-{no_rc_snapshot_version}-jar-with-dependencies.jar + - arrow-tools-{no_rc_snapshot_version}-javadoc.jar + - arrow-tools-{no_rc_snapshot_version}-sources.jar + - arrow-tools-{no_rc_snapshot_version}-tests.jar + - arrow-tools-{no_rc_snapshot_version}.jar + - arrow-tools-{no_rc_snapshot_version}.pom + - arrow-vector-{no_rc_snapshot_version}-javadoc.jar + - arrow-vector-{no_rc_snapshot_version}-shade-format-flatbuffers.jar + - arrow-vector-{no_rc_snapshot_version}-sources.jar + - arrow-vector-{no_rc_snapshot_version}-tests.jar + - arrow-vector-{no_rc_snapshot_version}.jar + - arrow-vector-{no_rc_snapshot_version}.pom + - flight-core-{no_rc_snapshot_version}-jar-with-dependencies.jar + - flight-core-{no_rc_snapshot_version}-javadoc.jar + - flight-core-{no_rc_snapshot_version}-shaded-ext.jar + - flight-core-{no_rc_snapshot_version}-shaded.jar + - flight-core-{no_rc_snapshot_version}-sources.jar + - flight-core-{no_rc_snapshot_version}-tests.jar + - flight-core-{no_rc_snapshot_version}.jar + - flight-core-{no_rc_snapshot_version}.pom + - flight-grpc-{no_rc_snapshot_version}-javadoc.jar + - flight-grpc-{no_rc_snapshot_version}-sources.jar + - flight-grpc-{no_rc_snapshot_version}-tests.jar + - flight-grpc-{no_rc_snapshot_version}.jar + - flight-grpc-{no_rc_snapshot_version}.pom + - flight-integration-tests-{no_rc_snapshot_version}-jar-with-dependencies.jar + - flight-integration-tests-{no_rc_snapshot_version}-javadoc.jar + - flight-integration-tests-{no_rc_snapshot_version}-sources.jar + - 
flight-integration-tests-{no_rc_snapshot_version}-tests.jar
+      - flight-integration-tests-{no_rc_snapshot_version}.jar
+      - flight-integration-tests-{no_rc_snapshot_version}.pom
+      - flight-sql-{no_rc_snapshot_version}-javadoc.jar
+      - flight-sql-{no_rc_snapshot_version}-sources.jar
+      - flight-sql-{no_rc_snapshot_version}-tests.jar
+      - flight-sql-{no_rc_snapshot_version}.jar
+      - flight-sql-{no_rc_snapshot_version}.pom

 ############################## NuGet packages ###############################

From e63a13aacbf67897202c8a56fccb3a86f624a96e Mon Sep 17 00:00:00 2001
From: Jin Shang 
Date: Fri, 16 Sep 2022 05:21:41 +0800
Subject: [PATCH 072/133] ARROW-17742: [C++][Gandiva] Fix Gandiva utf8proc
 dependency in CMake presets (#14140)

Building the CMake presets `ninja-release-gandiva` and `ninja-debug-gandiva`
fails with a link error saying that utf8proc-related definitions are not
found. This is because `ARROW_WITH_UTF8PROC` is explicitly turned off in these
two presets but is required to build Gandiva.

Adding `set(ARROW_WITH_UTF8PROC ON)` when `ARROW_GANDIVA` is enabled solves
the issue. An alternative solution would be to enable `ARROW_WITH_UTF8PROC` in
the preset settings, but I think the current solution is more general.

Authored-by: Jin Shang 
Signed-off-by: Sutou Kouhei 
---
 cpp/CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 1a728bd7870..14584e17fb8 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -333,6 +333,7 @@ endif()

 if(ARROW_GANDIVA)
   set(ARROW_WITH_RE2 ON)
+  set(ARROW_WITH_UTF8PROC ON)
 endif()

 if(ARROW_BUILD_INTEGRATION AND ARROW_FLIGHT)

From 2e72e0a80853e498f44b9c28334a7b5c8a0dca59 Mon Sep 17 00:00:00 2001
From: Rok Mihevc 
Date: Fri, 16 Sep 2022 01:01:04 +0200
Subject: [PATCH 073/133] ARROW-17407: [Doc][FlightRPC] Flight/gRPC best
 practices (#13873)

We want to provide a best-practices and debugging section in:
[cpp flight docs](https://arrow.apache.org/docs/cpp/flight.html)
[python flight docs](https://arrow.apache.org/docs/python/flight.html)
[java flight docs](https://arrow.apache.org/docs/java/flight.html)
[R flight docs](https://arrow.apache.org/docs/r/articles/flight.html)

Lead-authored-by: Rok 
Co-authored-by: Rok Mihevc 
Co-authored-by: David Li 
Signed-off-by: Rok Mihevc 
---
 docs/source/cpp/flight.rst    | 191 ++++++++++++++++++++++++++++++++++
 docs/source/java/flight.rst   |   4 +
 docs/source/python/flight.rst |   3 +
 go/arrow/flight/doc.go        |  78 ++++++++++++++
 go/go.sum                     |   1 +
 go/parquet/doc.go             |   2 +-
 r/vignettes/flight.Rmd        |   2 +
 7 files changed, 280 insertions(+), 1 deletion(-)

diff --git a/docs/source/cpp/flight.rst b/docs/source/cpp/flight.rst
index a941ead9040..e07a84e91ee 100644
--- a/docs/source/cpp/flight.rst
+++ b/docs/source/cpp/flight.rst
@@ -172,6 +172,197 @@ request/response. On the server, they can inspect incoming headers and fail
 the request; hence, they can be used to implement custom authentication
 methods.

+.. _flight-best-practices:
+
+Best practices
+==============
+
+gRPC
+----
+
+When using the default gRPC transport, options can be passed to it via
+:member:`arrow::flight::FlightClientOptions::generic_options`. For example:
+
+.. tab-set::
+
+   .. tab-item:: C++
+
+      .. code-block:: cpp
+
+         auto options = FlightClientOptions::Defaults();
+         // Set the period after which a keepalive ping is sent on transport.
+         options.generic_options.emplace_back(GRPC_ARG_KEEPALIVE_TIME_MS, 60000);
+
+   .. tab-item:: Python
+
+      ..
code-block:: python + + # Set the period after which a keepalive ping is sent on transport. + generic_options = [("GRPC_ARG_KEEPALIVE_TIME_MS", 60000)] + client = pyarrow.flight.FlightClient(server_uri, generic_options=generic_options) + +Also see `best gRPC practices`_ and available `gRPC keys`_. + +Re-use clients whenever possible +-------------------------------- + +Creating and closing clients requires setup and teardown on the client and +server side which can take away from actually handling RPCs. Reuse clients +whenever possible to avoid this. Note that clients are thread-safe, so a +single client can be shared across multiple threads. + +Don’t round-robin load balance +------------------------------ + +`Round robin load balancing`_ means every client can have an open connection to +every server, causing an unexpected number of open connections and depleting +server resources. + +Debugging connection issues +--------------------------- + +When facing unexpected disconnects on long running connections use netstat to +monitor the number of open connections. If number of connections is much +greater than the number of clients it might cause issues. + +For debugging, certain environment variables enable logging in gRPC. For +example, ``env GRPC_VERBOSITY=info GRPC_TRACE=http`` will print the initial +headers (on both sides) so you can see if gRPC established the connection or +not. It will also print when a message is sent, so you can tell if the +connection is open or not. + +gRPC may not report connection errors until a call is actually made. +Hence, to detect connection errors when creating a client, some sort +of dummy RPC should be made. + +Memory management +----------------- + +Flight tries to reuse allocations made by gRPC to avoid redundant +data copies. However, this means that those allocations may not +be tracked by the Arrow memory pool, and that memory usage behavior, +such as whether free memory is returned to the system, is dependent +on the allocator that gRPC uses (usually the system allocator). + +A quick way of testing: attach to the process with a debugger and call +``malloc_trim``, or call :func:`ReleaseUnused ` +on the system pool. If memory usage drops, then likely, there is memory +allocated by gRPC or by the application that the system allocator was holding +on to. This can be adjusted in platform-specific ways; see an investigation +in ARROW-16697_ for an example of how this works on Linux/glibc. glibc malloc +can be explicitly told to dump caches. + +Excessive traffic +----------------- + +gRPC will spawn up to max threads quota of threads for concurrent clients. Those +threads are not necessarily cleaned up (a "cached thread pool" in Java parlance). +glibc malloc clears some per thread state and the default tuning never clears +caches in some workloads. + +gRPC's default behaviour allows one server to accept many connections from many +different clients, but if requests do a lot of work (as they may under Flight), +the server may not be able to keep up. Configuring clients to retry +with backoff (and potentially connect to a different node), would give more +consistent quality of service. + +.. tab-set:: + + .. tab-item:: C++ + + .. code-block:: cpp + + auto options = FlightClientOptions::Defaults(); + // Set the minimum time between subsequent connection attempts. + options.generic_options.emplace_back(GRPC_ARG_MIN_RECONNECT_BACKOFF_MS, 2000); + + .. tab-item:: Python + + .. 
code-block:: python + + # Set the minimum time between subsequent connection attempts. + generic_options = [("GRPC_ARG_MIN_RECONNECT_BACKOFF_MS", 2000)] + client = pyarrow.flight.FlightClient(server_uri, generic_options=generic_options) + + +Limiting DoPut Batch Size +-------------------------- + +You may wish to limit the maximum batch size a client can submit to a server through +DoPut, to prevent a request from taking up too much memory on the server. On +the client-side, set :member:`arrow::flight::FlightClientOptions::write_size_limit_bytes`. +On the server-side, set the gRPC option ``GRPC_ARG_MAX_RECEIVE_MESSAGE_LENGTH``. +The client-side option will return an error that can be retried with smaller batches, +while the server-side limit will close out the connection. Setting both can be wise, since +the former provides a better user experience but the latter may be necessary to defend +against impolite clients. + +Closing unresponsive connections +-------------------------------- + +1. A stale call can be closed using + :member:`arrow::flight::FlightCallOptions::stop_token`. This requires recording the + stop token at call establishment time. + + .. tab-set:: + + .. tab-item:: C++ + + .. code-block:: cpp + + StopSource stop_source; + FlightCallOptions options; + options.stop_token = stop_source.token(); + stop_source.RequestStop(Status::Cancelled("StopSource")); + flight_client->DoAction(options, {}); + + +2. Use call timeouts. (This is a general gRPC best practice.) + + .. tab-set:: + + .. tab-item:: C++ + + .. code-block:: cpp + + FlightCallOptions options; + options.timeout = TimeoutDuration{0.2}; + Status status = client->GetFlightInfo(options, FlightDescriptor{}).status(); + + .. tab-item:: Java + + .. code-block:: java + + Iterator results = client.doAction(new Action("hang"), CallOptions.timeout(0.2, TimeUnit.SECONDS)); + + .. tab-item:: Python + + .. code-block:: python + + options = pyarrow.flight.FlightCallOptions(timeout=0.2) + result = client.do_action(action, options=options) + + +3. Client timeouts are not great for long-running streaming calls, where it may + be hard to choose a timeout for the entire operation. Instead, what is often + desired is a per-read or per-write timeout so that the operation fails if it + isn't making progress. This can be implemented with a background thread that + calls Cancel() on a timer, with the main thread resetting the timer every time + an operation completes successfully. For a fully-worked out example, see the + Cookbook. + + .. note:: There is a long standing ticket for a per-write/per-read timeout + instead of a per call timeout (ARROW-6062_), but this is not (easily) + possible to implement with the blocking gRPC API. + +.. _best gRPC practices: https://grpc.io/docs/guides/performance/#general +.. _gRPC keys: https://grpc.github.io/grpc/cpp/group__grpc__arg__keys.html +.. _Round robin load balancing: https://github.com/grpc/grpc/blob/master/doc/load-balancing.md#round_robin +.. _ARROW-15764: https://issues.apache.org/jira/browse/ARROW-15764 +.. _ARROW-16697: https://issues.apache.org/jira/browse/ARROW-16697 +.. _ARROW-6062: https://issues.apache.org/jira/browse/ARROW-6062 + + Alternative Transports ====================== diff --git a/docs/source/java/flight.rst b/docs/source/java/flight.rst index 69a7d2b8d26..f62046ecd2a 100644 --- a/docs/source/java/flight.rst +++ b/docs/source/java/flight.rst @@ -201,6 +201,10 @@ request/response. 
On the server, they can inspect incoming headers and fail the request; hence, they can be used to implement custom authentication methods. +:ref:`Flight best practices ` +==================================================== + + .. _`FlightClient`: https://arrow.apache.org/docs/java/reference/org/apache/arrow/flight/FlightClient.html .. _`FlightProducer`: https://arrow.apache.org/docs/java/reference/org/apache/arrow/flight/FlightProducer.html .. _`FlightServer`: https://arrow.apache.org/docs/java/reference/org/apache/arrow/flight/FlightServer.html diff --git a/docs/source/python/flight.rst b/docs/source/python/flight.rst index d038bcce57c..f07b9511ccf 100644 --- a/docs/source/python/flight.rst +++ b/docs/source/python/flight.rst @@ -128,3 +128,6 @@ Middleware are fairly limited, but they can add headers to a request/response. On the server, they can inspect incoming headers and fail the request; hence, they can be used to implement custom authentication methods. + +:ref:`Flight best practices ` +==================================================== diff --git a/go/arrow/flight/doc.go b/go/arrow/flight/doc.go new file mode 100644 index 00000000000..68d1ca3458f --- /dev/null +++ b/go/arrow/flight/doc.go @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package flight contains server and client implementations for the Arrow Flight RPC +// +// Here we list best practices and common pitfalls for Arrow Flight usage. +// +// GRPC +// +// When using gRPC for transport all client methods take an optional list +// of gRPC CallOptions: https://pkg.go.dev/google.golang.org/grpc#CallOption. +// Additional headers can be used or read via +// https://pkg.go.dev/google.golang.org/grpc@v1.48.0/metadata with the context. +// Also see available gRPC keys +// (https://grpc.github.io/grpc/cpp/group__grpc__arg__keys.html) and a list of +// best gRPC practices (https://grpc.io/docs/guides/performance/#general). +// +// Re-use clients whenever possible +// +// Closing clients causes gRPC to close and clean up connections which can take +// several seconds per connection. This will stall server and client threads if +// done too frequently. Client reuse will avoid this issue. +// +// Don’t round-robin load balance +// +// Round robin balancing can cause every client to have an open connection to +// every server causing an unexpected number of open connections and a depletion +// of resources. +// +// Debugging +// +// Use netstat to see the number of open connections. +// For debug use env GODEBUG=http2debug=1 or GODEBUG=http2debug=2 for verbose +// http2 logs (using 2 is more verbose with frame dumps). This will print the +// initial headers (on both sides) so you can see if grpc established the +// connection or not. 
It will also print when a message is sent, so you can tell +// if the connection is open or not. +// +// Note: "connect" isn't really a connect and we’ve observed that gRPC does not +// give you the actual error until you first try to make a call. This can cause +// error being reported at unexpected times. +// +// Excessive traffic +// +// There are basically two ways to handle excessive traffic: +// * unbounded goroutines -> everyone gets serviced, but it might take forever. +// This is what you are seeing now. Default behaviour. +// * bounded thread pool -> Reject connections / requests when under load, and have +// clients retry with backoff. This also gives an opportunity to retry with a +// different node. Not everyone gets serviced but quality of service stays consistent. +// Can be set with https://pkg.go.dev/google.golang.org/grpc#NumStreamWorkers +// +// Closing unresponsive connections +// +// * Connection timeout (https://pkg.go.dev/context#WithTimeout) or +// (https://pkg.go.dev/context#WithCancel) can be set via context.Context. +// * There is a long standing ticket for a per-write/per-read timeout instead of a per +// call timeout (https://issues.apache.org/jira/browse/ARROW-6062), but this is not +// (easily) possible to implement with the blocking gRPC API. For now one can also do +// something like set up a background thread that calls cancel() on a timer and have +// the main thread reset the timer every time a write operation completes successfully +// (that means one needs to use to_batches() + write_batch and not write_table). + + +package flight diff --git a/go/go.sum b/go/go.sum index 098a7294b40..e658e9114ef 100644 --- a/go/go.sum +++ b/go/go.sum @@ -149,6 +149,7 @@ github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PK github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= diff --git a/go/parquet/doc.go b/go/parquet/doc.go index 7c97dd5c950..c6183618d29 100644 --- a/go/parquet/doc.go +++ b/go/parquet/doc.go @@ -17,7 +17,7 @@ // Package parquet provides an implementation of Apache Parquet for Go. // // Apache Parquet is an open-source columnar data storage format using the record -// shredding and assembly algorithm to accomodate complex data structures which +// shredding and assembly algorithm to accommodate complex data structures which // can then be used to efficiently store the data. // // This implementation is a native go implementation for reading and writing the diff --git a/r/vignettes/flight.Rmd b/r/vignettes/flight.Rmd index e8af5cad6f7..ec2ac938f58 100644 --- a/r/vignettes/flight.Rmd +++ b/r/vignettes/flight.Rmd @@ -85,3 +85,5 @@ client %>% Because `flight_get()` returns an Arrow data structure, you can directly pipe its result into a [dplyr](https://dplyr.tidyverse.org/) workflow. See `vignette("dataset", package = "arrow")` for more information on working with Arrow objects via a dplyr interface. 
+ +# [Flight best practices](../cpp/flight.html#best-practices) From 93626eebd0d9443be655336b65a4d17bfffe898c Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Thu, 15 Sep 2022 21:38:51 -0400 Subject: [PATCH 074/133] ARROW-15011: [R] Generate documentation for dplyr function bindings (#14014) Approach: * `register_binding` takes an additional optional argument, `notes`, where you can list any limitations or differences in behavior between the Arrow version and the R function * These notes are put in the `.cache` environment when the nse_funcs are built. * New script `data-raw/docgen.R` that reads `arrow:::.cache$docs` and writes out `dplyr-funcs-docs.R` containing roxygen. * Similarly, we pull the dplyr functions we s3_register and add them to the generated docs. Unfortunately, the notes about feature limitations aren't easily kept alongside the functions themselves because of how they're registered on load. The approach here creates a list in `arrow-package.R`, where the `.onLoad()` happens, and notes go there. * Docs and crossreferences are generated by roxygen2 as usual. I deferred filling in all of the function notes. See followup JIRAs on ARROW-17665. Authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/DESCRIPTION | 1 + r/Makefile | 1 + r/R/arrow-package.R | 51 ++++-- r/R/dplyr-funcs-augmented.R | 19 +- r/R/dplyr-funcs-datetime.R | 53 +++--- r/R/dplyr-funcs-doc.R | 332 +++++++++++++++++++++++++++++++++++ r/R/dplyr-funcs-string.R | 86 +++++---- r/R/dplyr-funcs-type.R | 43 +++-- r/R/dplyr-funcs.R | 17 +- r/R/expression.R | 11 +- r/_pkgdown.yml | 1 + r/data-raw/docgen.R | 192 ++++++++++++++++++++ r/man/acero.Rd | 339 ++++++++++++++++++++++++++++++++++++ r/man/add_filename.Rd | 23 +++ r/man/cast.Rd | 38 ++++ r/man/register_binding.Rd | 11 +- 16 files changed, 1109 insertions(+), 109 deletions(-) create mode 100644 r/R/dplyr-funcs-doc.R create mode 100644 r/data-raw/docgen.R create mode 100644 r/man/acero.Rd create mode 100644 r/man/add_filename.Rd create mode 100644 r/man/cast.Rd diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 7ae6a8de29f..7b60f0c510a 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -103,6 +103,7 @@ Collate: 'dplyr-funcs-augmented.R' 'dplyr-funcs-conditional.R' 'dplyr-funcs-datetime.R' + 'dplyr-funcs-doc.R' 'dplyr-funcs-math.R' 'dplyr-funcs-string.R' 'dplyr-funcs-type.R' diff --git a/r/Makefile b/r/Makefile index 1ddbe595dd2..cb76b4c9775 100644 --- a/r/Makefile +++ b/r/Makefile @@ -26,6 +26,7 @@ style-all: R -s -e 'styler::style_file(setdiff(dir(pattern = "R$$", recursive = TRUE), source(".styler_excludes.R")$$value))' doc: style + R -s -f data-raw/docgen.R R -s -e 'roxygen2::roxygenize()' -git add --all man/*.Rd diff --git a/r/R/arrow-package.R b/r/R/arrow-package.R index 53fb0280a50..e6b3f481e21 100644 --- a/r/R/arrow-package.R +++ b/r/R/arrow-package.R @@ -31,25 +31,50 @@ #' @keywords internal "_PACKAGE" +# TODO(ARROW-17666): Include notes about features not supported here. 
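+# Each generic listed below is registered as an S3 method in .onLoad(); the
+# NULL values are placeholders for notes about Arrow-specific behavior
+# (ARROW-17666) that the docs generated by data-raw/docgen.R can pick up.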
+supported_dplyr_methods <- list(
+  select = NULL,
+  filter = NULL,
+  collect = NULL,
+  summarise = NULL,
+  group_by = NULL,
+  groups = NULL,
+  group_vars = NULL,
+  group_by_drop_default = NULL,
+  ungroup = NULL,
+  mutate = NULL,
+  transmute = NULL,
+  arrange = NULL,
+  rename = NULL,
+  pull = NULL,
+  relocate = NULL,
+  compute = NULL,
+  collapse = NULL,
+  distinct = NULL,
+  left_join = NULL,
+  right_join = NULL,
+  inner_join = NULL,
+  full_join = NULL,
+  semi_join = NULL,
+  anti_join = NULL,
+  count = NULL,
+  tally = NULL,
+  rename_with = NULL,
+  union = NULL,
+  union_all = NULL,
+  glimpse = NULL,
+  show_query = NULL,
+  explain = NULL
+)
+
 #' @importFrom vctrs s3_register vec_size vec_cast vec_unique
 .onLoad <- function(...) {
   # Make sure C++ knows on which thread it is safe to call the R API
   InitializeMainRThread()

-  dplyr_methods <- paste0(
-    "dplyr::",
-    c(
-      "select", "filter", "collect", "summarise", "group_by", "groups",
-      "group_vars", "group_by_drop_default", "ungroup", "mutate", "transmute",
-      "arrange", "rename", "pull", "relocate", "compute", "collapse",
-      "distinct", "left_join", "right_join", "inner_join", "full_join",
-      "semi_join", "anti_join", "count", "tally", "rename_with", "union",
-      "union_all", "glimpse", "show_query", "explain"
-    )
-  )
   for (cl in c("Dataset", "ArrowTabular", "RecordBatchReader", "arrow_dplyr_query")) {
-    for (m in dplyr_methods) {
-      s3_register(m, cl)
+    for (m in names(supported_dplyr_methods)) {
+      s3_register(paste0("dplyr::", m), cl)
     }
   }
   s3_register("dplyr::tbl_vars", "arrow_dplyr_query")
diff --git a/r/R/dplyr-funcs-augmented.R b/r/R/dplyr-funcs-augmented.R
index 6e751d49f61..1067f15573b 100644
--- a/r/R/dplyr-funcs-augmented.R
+++ b/r/R/dplyr-funcs-augmented.R
@@ -15,8 +15,21 @@
 # specific language governing permissions and limitations
 # under the License.

+#' Add the data filename as a column
+#'
+#' This function only exists inside `arrow` `dplyr` queries, and it is only
+#' valid when querying on a `FileSystemDataset`.
+#'
+#' @return A `FieldRef` `Expression` that refers to the filename-augmented
+#'   column.
+#' @examples +#' \dontrun{ +#' open_dataset("nyc-taxi") %>% +#' mutate(file = add_filename()) +#' } +#' @keywords internal +add_filename <- function() Expression$field_ref("__filename") + register_bindings_augmented <- function() { - register_binding("add_filename", function() { - Expression$field_ref("__filename") - }) + register_binding("arrow::add_filename", add_filename) } diff --git a/r/R/dplyr-funcs-datetime.R b/r/R/dplyr-funcs-datetime.R index 9a010452b84..6106adbc5e4 100644 --- a/r/R/dplyr-funcs-datetime.R +++ b/r/R/dplyr-funcs-datetime.R @@ -649,55 +649,54 @@ register_bindings_datetime_parsers <- function() { build_expr("assume_timezone", coalesce_output, options = list(timezone = tz)) }) - } register_bindings_datetime_rounding <- function() { register_binding( - "round_date", + "lubridate::round_date", function(x, unit = "second", week_start = getOption("lubridate.week.start", 7)) { + opts <- parse_period_unit(unit) + if (opts$unit == 7L) { # weeks (unit = 7L) need to accommodate week_start + return(shift_temporal_to_week("round_temporal", x, week_start, options = opts)) + } - opts <- parse_period_unit(unit) - if (opts$unit == 7L) { # weeks (unit = 7L) need to accommodate week_start - return(shift_temporal_to_week("round_temporal", x, week_start, options = opts)) + Expression$create("round_temporal", x, options = opts) } - - Expression$create("round_temporal", x, options = opts) - }) + ) register_binding( - "floor_date", + "lubridate::floor_date", function(x, unit = "second", week_start = getOption("lubridate.week.start", 7)) { + opts <- parse_period_unit(unit) + if (opts$unit == 7L) { # weeks (unit = 7L) need to accommodate week_start + return(shift_temporal_to_week("floor_temporal", x, week_start, options = opts)) + } - opts <- parse_period_unit(unit) - if (opts$unit == 7L) { # weeks (unit = 7L) need to accommodate week_start - return(shift_temporal_to_week("floor_temporal", x, week_start, options = opts)) + Expression$create("floor_temporal", x, options = opts) } - - Expression$create("floor_temporal", x, options = opts) - }) + ) register_binding( - "ceiling_date", + "lubridate::ceiling_date", function(x, unit = "second", change_on_boundary = NULL, week_start = getOption("lubridate.week.start", 7)) { - opts <- parse_period_unit(unit) - if (is.null(change_on_boundary)) { - change_on_boundary <- ifelse(call_binding("is.Date", x), TRUE, FALSE) - } - opts$ceil_is_strictly_greater <- change_on_boundary - - if (opts$unit == 7L) { # weeks (unit = 7L) need to accommodate week_start - return(shift_temporal_to_week("ceil_temporal", x, week_start, options = opts)) - } + opts <- parse_period_unit(unit) + if (is.null(change_on_boundary)) { + change_on_boundary <- ifelse(call_binding("is.Date", x), TRUE, FALSE) + } + opts$ceil_is_strictly_greater <- change_on_boundary - Expression$create("ceil_temporal", x, options = opts) - }) + if (opts$unit == 7L) { # weeks (unit = 7L) need to accommodate week_start + return(shift_temporal_to_week("ceil_temporal", x, week_start, options = opts)) + } + Expression$create("ceil_temporal", x, options = opts) + } + ) } diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R new file mode 100644 index 00000000000..cac0310f49b --- /dev/null +++ b/r/R/dplyr-funcs-doc.R @@ -0,0 +1,332 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Generated by using data-raw/docgen.R -> do not edit by hand + +#' Functions available in Arrow dplyr queries +#' +#' The `arrow` package contains methods for 32 `dplyr` table functions, many of +#' which are "verbs" that do transformations to one or more tables. +#' The package also has mappings of 205 R functions to the corresponding +#' functions in the Arrow compute library. These allow you to write code inside +#' of `dplyr` methods that call R functions, including many in packages like +#' `stringr` and `lubridate`, and they will get translated to Arrow and run +#' on the Arrow query engine (Acero). This document lists all of the mapped +#' functions. +#' +#' # `dplyr` verbs +#' +#' Most verb functions return an `arrow_dplyr_query` object, similar in spirit +#' to a `dbplyr::tbl_lazy`. This means that the verbs do not eagerly evaluate +#' the query on the data. To run the query, call either `compute()`, +#' which returns an `arrow` [Table], or `collect()`, which pulls the resulting +#' Table into an R `data.frame`. +#' +#' * [`anti_join()`][dplyr::anti_join()] +#' * [`arrange()`][dplyr::arrange()] +#' * [`collapse()`][dplyr::collapse()] +#' * [`collect()`][dplyr::collect()] +#' * [`compute()`][dplyr::compute()] +#' * [`count()`][dplyr::count()] +#' * [`distinct()`][dplyr::distinct()] +#' * [`explain()`][dplyr::explain()] +#' * [`filter()`][dplyr::filter()] +#' * [`full_join()`][dplyr::full_join()] +#' * [`glimpse()`][dplyr::glimpse()] +#' * [`group_by()`][dplyr::group_by()] +#' * [`group_by_drop_default()`][dplyr::group_by_drop_default()] +#' * [`group_vars()`][dplyr::group_vars()] +#' * [`groups()`][dplyr::groups()] +#' * [`inner_join()`][dplyr::inner_join()] +#' * [`left_join()`][dplyr::left_join()] +#' * [`mutate()`][dplyr::mutate()] +#' * [`pull()`][dplyr::pull()] +#' * [`relocate()`][dplyr::relocate()] +#' * [`rename()`][dplyr::rename()] +#' * [`rename_with()`][dplyr::rename_with()] +#' * [`right_join()`][dplyr::right_join()] +#' * [`select()`][dplyr::select()] +#' * [`semi_join()`][dplyr::semi_join()] +#' * [`show_query()`][dplyr::show_query()] +#' * [`summarise()`][dplyr::summarise()] +#' * [`tally()`][dplyr::tally()] +#' * [`transmute()`][dplyr::transmute()] +#' * [`ungroup()`][dplyr::ungroup()] +#' * [`union()`][dplyr::union()] +#' * [`union_all()`][dplyr::union_all()] +#' +#' # Function mappings +#' +#' In the list below, any differences in behavior or support between Acero and +#' the R function are listed. If no notes follow the function name, then you +#' can assume that the function works in Acero just as it does in R. +#' +#' Functions can be called either as `pkg::fun()` or just `fun()`, i.e. both +#' `str_sub()` and `stringr::str_sub()` work. +#' +#' In addition to these functions, you can call any of Arrow's 243 compute +#' functions directly. Arrow has many functions that don't map to an existing R +#' function. 
In other cases where there is an R function mapping, you can still +#' call the Arrow function directly if you don't want the adaptations that the R +#' mapping has that make Acero behave like R. These functions are listed in the +#' [C++ documentation](https://arrow.apache.org/docs/cpp/compute.html), and +#' in the function registry in R, they are named with an `arrow_` prefix, such +#' as `arrow_ascii_is_decimal`. +#' +#' ## arrow +#' +#' * [`add_filename()`][arrow::add_filename()] +#' * [`cast()`][arrow::cast()] +#' +#' ## base +#' +#' * [`-`][-()] +#' * [`!`][!()] +#' * [`!=`][!=()] +#' * [`*`][*()] +#' * [`/`][/()] +#' * [`&`][&()] +#' * [`%/%`][%/%()] +#' * [`%%`][%%()] +#' * [`%in%`][%in%()] +#' * [`^`][^()] +#' * [`+`][+()] +#' * [`<`][<()] +#' * [`<=`][<=()] +#' * [`==`][==()] +#' * [`>`][>()] +#' * [`>=`][>=()] +#' * [`|`][|()] +#' * [`abs()`][base::abs()] +#' * [`acos()`][base::acos()] +#' * [`all()`][base::all()] +#' * [`any()`][base::any()] +#' * [`as.character()`][base::as.character()] +#' * [`as.Date()`][base::as.Date()] +#' * [`as.difftime()`][base::as.difftime()] +#' * [`as.double()`][base::as.double()] +#' * [`as.integer()`][base::as.integer()] +#' * [`as.logical()`][base::as.logical()] +#' * [`as.numeric()`][base::as.numeric()] +#' * [`asin()`][base::asin()] +#' * [`ceiling()`][base::ceiling()] +#' * [`cos()`][base::cos()] +#' * [`data.frame()`][base::data.frame()] +#' * [`difftime()`][base::difftime()] +#' * [`endsWith()`][base::endsWith()] +#' * [`exp()`][base::exp()] +#' * [`floor()`][base::floor()] +#' * [`format()`][base::format()] +#' * [`grepl()`][base::grepl()] +#' * [`gsub()`][base::gsub()] +#' * [`ifelse()`][base::ifelse()] +#' * [`is.character()`][base::is.character()] +#' * [`is.double()`][base::is.double()] +#' * [`is.factor()`][base::is.factor()] +#' * [`is.finite()`][base::is.finite()] +#' * [`is.infinite()`][base::is.infinite()] +#' * [`is.integer()`][base::is.integer()] +#' * [`is.list()`][base::is.list()] +#' * [`is.logical()`][base::is.logical()] +#' * [`is.na()`][base::is.na()] +#' * [`is.nan()`][base::is.nan()] +#' * [`is.numeric()`][base::is.numeric()] +#' * [`ISOdate()`][base::ISOdate()] +#' * [`ISOdatetime()`][base::ISOdatetime()] +#' * [`log()`][base::log()] +#' * [`log10()`][base::log10()] +#' * [`log1p()`][base::log1p()] +#' * [`log2()`][base::log2()] +#' * [`logb()`][base::logb()] +#' * [`max()`][base::max()] +#' * [`mean()`][base::mean()] +#' * [`min()`][base::min()] +#' * [`nchar()`][base::nchar()] +#' * [`paste()`][base::paste()]: the `collapse` argument is not yet supported +#' * [`paste0()`][base::paste0()]: the `collapse` argument is not yet supported +#' * [`pmax()`][base::pmax()] +#' * [`pmin()`][base::pmin()] +#' * [`round()`][base::round()] +#' * [`sign()`][base::sign()] +#' * [`sin()`][base::sin()] +#' * [`sqrt()`][base::sqrt()] +#' * [`startsWith()`][base::startsWith()] +#' * [`strftime()`][base::strftime()] +#' * [`strptime()`][base::strptime()] +#' * [`strrep()`][base::strrep()] +#' * [`strsplit()`][base::strsplit()] +#' * [`sub()`][base::sub()] +#' * [`substr()`][base::substr()] +#' * [`substring()`][base::substring()] +#' * [`sum()`][base::sum()] +#' * [`tan()`][base::tan()] +#' * [`tolower()`][base::tolower()] +#' * [`toupper()`][base::toupper()] +#' * [`trunc()`][base::trunc()] +#' +#' ## bit64 +#' +#' * [`as.integer64()`][bit64::as.integer64()] +#' * [`is.integer64()`][bit64::is.integer64()] +#' +#' ## dplyr +#' +#' * [`across()`][dplyr::across()]: only supported inside `mutate()`, `summarize()`, and `arrange()`; purrr-style 
lambda functions and use of `where()` selection helper not yet supported +#' * [`between()`][dplyr::between()] +#' * [`case_when()`][dplyr::case_when()] +#' * [`coalesce()`][dplyr::coalesce()] +#' * [`desc()`][dplyr::desc()] +#' * [`if_else()`][dplyr::if_else()] +#' * [`n()`][dplyr::n()] +#' * [`n_distinct()`][dplyr::n_distinct()] +#' +#' ## lubridate +#' +#' * [`am()`][lubridate::am()] +#' * [`as_date()`][lubridate::as_date()] +#' * [`as_datetime()`][lubridate::as_datetime()] +#' * [`ceiling_date()`][lubridate::ceiling_date()] +#' * [`date()`][lubridate::date()] +#' * [`date_decimal()`][lubridate::date_decimal()] +#' * [`day()`][lubridate::day()] +#' * [`ddays()`][lubridate::ddays()] +#' * [`decimal_date()`][lubridate::decimal_date()] +#' * [`dhours()`][lubridate::dhours()] +#' * [`dmicroseconds()`][lubridate::dmicroseconds()] +#' * [`dmilliseconds()`][lubridate::dmilliseconds()] +#' * [`dminutes()`][lubridate::dminutes()] +#' * [`dmonths()`][lubridate::dmonths()] +#' * [`dmy()`][lubridate::dmy()] +#' * [`dmy_h()`][lubridate::dmy_h()] +#' * [`dmy_hm()`][lubridate::dmy_hm()] +#' * [`dmy_hms()`][lubridate::dmy_hms()] +#' * [`dnanoseconds()`][lubridate::dnanoseconds()] +#' * [`dpicoseconds()`][lubridate::dpicoseconds()] +#' * [`dseconds()`][lubridate::dseconds()] +#' * [`dst()`][lubridate::dst()] +#' * [`dweeks()`][lubridate::dweeks()] +#' * [`dyears()`][lubridate::dyears()] +#' * [`dym()`][lubridate::dym()] +#' * [`epiweek()`][lubridate::epiweek()] +#' * [`epiyear()`][lubridate::epiyear()] +#' * [`fast_strptime()`][lubridate::fast_strptime()] +#' * [`floor_date()`][lubridate::floor_date()] +#' * [`format_ISO8601()`][lubridate::format_ISO8601()] +#' * [`hour()`][lubridate::hour()] +#' * [`is.Date()`][lubridate::is.Date()] +#' * [`is.instant()`][lubridate::is.instant()] +#' * [`is.POSIXct()`][lubridate::is.POSIXct()] +#' * [`is.timepoint()`][lubridate::is.timepoint()] +#' * [`isoweek()`][lubridate::isoweek()] +#' * [`isoyear()`][lubridate::isoyear()] +#' * [`leap_year()`][lubridate::leap_year()] +#' * [`make_date()`][lubridate::make_date()] +#' * [`make_datetime()`][lubridate::make_datetime()] +#' * [`make_difftime()`][lubridate::make_difftime()] +#' * [`mday()`][lubridate::mday()] +#' * [`mdy()`][lubridate::mdy()] +#' * [`mdy_h()`][lubridate::mdy_h()] +#' * [`mdy_hm()`][lubridate::mdy_hm()] +#' * [`mdy_hms()`][lubridate::mdy_hms()] +#' * [`minute()`][lubridate::minute()] +#' * [`month()`][lubridate::month()] +#' * [`my()`][lubridate::my()] +#' * [`myd()`][lubridate::myd()] +#' * [`parse_date_time()`][lubridate::parse_date_time()] +#' * [`pm()`][lubridate::pm()] +#' * [`qday()`][lubridate::qday()] +#' * [`quarter()`][lubridate::quarter()] +#' * [`round_date()`][lubridate::round_date()] +#' * [`second()`][lubridate::second()] +#' * [`semester()`][lubridate::semester()] +#' * [`tz()`][lubridate::tz()] +#' * [`wday()`][lubridate::wday()] +#' * [`week()`][lubridate::week()] +#' * [`yday()`][lubridate::yday()] +#' * [`ydm()`][lubridate::ydm()] +#' * [`ydm_h()`][lubridate::ydm_h()] +#' * [`ydm_hm()`][lubridate::ydm_hm()] +#' * [`ydm_hms()`][lubridate::ydm_hms()] +#' * [`year()`][lubridate::year()] +#' * [`ym()`][lubridate::ym()] +#' * [`ymd()`][lubridate::ymd()] +#' * [`ymd_h()`][lubridate::ymd_h()] +#' * [`ymd_hm()`][lubridate::ymd_hm()] +#' * [`ymd_hms()`][lubridate::ymd_hms()] +#' * [`yq()`][lubridate::yq()] +#' +#' ## methods +#' +#' * [`is()`][methods::is()] +#' +#' ## rlang +#' +#' * [`is_character()`][rlang::is_character()] +#' * [`is_double()`][rlang::is_double()] +#' * 
[`is_integer()`][rlang::is_integer()] +#' * [`is_list()`][rlang::is_list()] +#' * [`is_logical()`][rlang::is_logical()] +#' +#' ## stats +#' +#' * [`median()`][stats::median()] +#' * [`quantile()`][stats::quantile()] +#' * [`sd()`][stats::sd()] +#' * [`var()`][stats::var()] +#' +#' ## stringi +#' +#' * [`stri_reverse()`][stringi::stri_reverse()] +#' +#' ## stringr +#' +#' * [`str_c()`][stringr::str_c()]: the `collapse` argument is not yet supported +#' * [`str_count()`][stringr::str_count()] +#' * [`str_detect()`][stringr::str_detect()] +#' * [`str_dup()`][stringr::str_dup()] +#' * [`str_ends()`][stringr::str_ends()] +#' * [`str_length()`][stringr::str_length()] +#' * `str_like()`: not yet in a released version of `stringr`, but it is supported in `arrow` +#' * [`str_pad()`][stringr::str_pad()] +#' * [`str_replace()`][stringr::str_replace()] +#' * [`str_replace_all()`][stringr::str_replace_all()] +#' * [`str_split()`][stringr::str_split()] +#' * [`str_starts()`][stringr::str_starts()] +#' * [`str_sub()`][stringr::str_sub()] +#' * [`str_to_lower()`][stringr::str_to_lower()] +#' * [`str_to_title()`][stringr::str_to_title()] +#' * [`str_to_upper()`][stringr::str_to_upper()] +#' * [`str_trim()`][stringr::str_trim()] +#' +#' ## tibble +#' +#' * [`tibble()`][tibble::tibble()] +#' +#' ## tidyselect +#' +#' * [`all_of()`][tidyselect::all_of()] +#' * [`contains()`][tidyselect::contains()] +#' * [`ends_with()`][tidyselect::ends_with()] +#' * [`everything()`][tidyselect::everything()] +#' * [`last_col()`][tidyselect::last_col()] +#' * [`matches()`][tidyselect::matches()] +#' * [`num_range()`][tidyselect::num_range()] +#' * [`one_of()`][tidyselect::one_of()] +#' * [`starts_with()`][tidyselect::starts_with()] +#' +#' @name acero +NULL diff --git a/r/R/dplyr-funcs-string.R b/r/R/dplyr-funcs-string.R index b300d7c439e..eb2326ed056 100644 --- a/r/R/dplyr-funcs-string.R +++ b/r/R/dplyr-funcs-string.R @@ -161,32 +161,44 @@ register_bindings_string_join <- function() { } } - register_binding("base::paste", function(..., sep = " ", collapse = NULL, recycle0 = FALSE) { - assert_that( - is.null(collapse), - msg = "paste() with the collapse argument is not yet supported in Arrow" - ) - if (!inherits(sep, "Expression")) { - assert_that(!is.na(sep), msg = "Invalid separator") - } - arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep) - }) - - register_binding("base::paste0", function(..., collapse = NULL, recycle0 = FALSE) { - assert_that( - is.null(collapse), - msg = "paste0() with the collapse argument is not yet supported in Arrow" - ) - arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "") - }) - - register_binding("stringr::str_c", function(..., sep = "", collapse = NULL) { - assert_that( - is.null(collapse), - msg = "str_c() with the collapse argument is not yet supported in Arrow" - ) - arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep) - }) + register_binding( + "base::paste", + function(..., sep = " ", collapse = NULL, recycle0 = FALSE) { + assert_that( + is.null(collapse), + msg = "paste() with the collapse argument is not yet supported in Arrow" + ) + if (!inherits(sep, "Expression")) { + assert_that(!is.na(sep), msg = "Invalid separator") + } + arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., sep) + }, + notes = "the `collapse` argument is not yet supported" + ) + + register_binding( + "base::paste0", + function(..., collapse = NULL, recycle0 = FALSE) { + assert_that( + is.null(collapse), + msg = "paste0() with the collapse 
argument is not yet supported in Arrow" + ) + arrow_string_join_function(NullHandlingBehavior$REPLACE, "NA")(..., "") + }, + notes = "the `collapse` argument is not yet supported" + ) + + register_binding( + "stringr::str_c", + function(..., sep = "", collapse = NULL) { + assert_that( + is.null(collapse), + msg = "str_c() with the collapse argument is not yet supported in Arrow" + ) + arrow_string_join_function(NullHandlingBehavior$EMIT_NULL)(..., sep) + }, + notes = "the `collapse` argument is not yet supported" + ) } register_bindings_string_regex <- function() { @@ -227,15 +239,17 @@ register_bindings_string_regex <- function() { out }) - register_binding("stringr::str_like", function(string, - pattern, - ignore_case = TRUE) { - Expression$create( - "match_like", - string, - options = list(pattern = pattern, ignore_case = ignore_case) - ) - }) + register_binding( + "stringr::str_like", + function(string, pattern, ignore_case = TRUE) { + Expression$create( + "match_like", + string, + options = list(pattern = pattern, ignore_case = ignore_case) + ) + }, + notes = "not yet in a released version of `stringr`, but it is supported in `arrow`" + ) register_binding("stringr::str_count", function(string, pattern) { opts <- get_stringr_pattern_options(enexpr(pattern)) @@ -337,7 +351,7 @@ register_bindings_string_regex <- function() { register_binding("stringr::str_replace_all", arrow_stringr_string_replace_function(-1L)) register_binding("base::strsplit", function(x, split, fixed = FALSE, perl = FALSE, - useBytes = FALSE) { + useBytes = FALSE) { assert_that(is.string(split)) arrow_fun <- ifelse(fixed, "split_pattern", "split_pattern_regex") diff --git a/r/R/dplyr-funcs-type.R b/r/R/dplyr-funcs-type.R index 9925d0347f7..aa50cdebc5d 100644 --- a/r/R/dplyr-funcs-type.R +++ b/r/R/dplyr-funcs-type.R @@ -23,23 +23,34 @@ register_bindings_type <- function() { register_bindings_type_format() } -register_bindings_type_cast <- function() { - register_binding("cast", function(x, target_type, safe = TRUE, ...) { - opts <- cast_options(safe, ...) - opts$to_type <- as_type(target_type) - Expression$create("cast", x, options = opts) - }) +#' Change the type of an array or column +#' +#' This is a wrapper around the `$cast()` method that many Arrow objects have. +#' It is more convenient to call inside `dplyr` pipelines than the method. +#' +#' @param x an `Array`, `Table`, `Expression`, or similar Arrow data object. +#' @param to [DataType] to cast to; for [Table] and [RecordBatch], +#' it should be a [Schema]. +#' @param safe logical: only allow the type conversion if no data is lost +#' (truncation, overflow, etc.). Default is `TRUE` +#' @param ... specific `CastOptions` to set +#' @return an `Expression` +#' +#' @examples +#' \dontrun{ +#' mtcars %>% +#' arrow_table() %>% +#' mutate(cyl = cast(cyl, string())) +#' } +#' @keywords internal +#' @seealso https://arrow.apache.org/docs/cpp/api/compute.html for the list of +#' supported CastOptions. +cast <- function(x, to, safe = TRUE, ...) { + x$cast(to, safe = safe, ...) 
+} - register_binding("dictionary_encode", function(x, - null_encoding_behavior = c("mask", "encode")) { - behavior <- toupper(match.arg(null_encoding_behavior)) - null_encoding_behavior <- NullEncodingBehavior[[behavior]] - Expression$create( - "dictionary_encode", - x, - options = list(null_encoding_behavior = null_encoding_behavior) - ) - }) +register_bindings_type_cast <- function() { + register_binding("arrow::cast", cast) # as.* type casting functions # as.factor() is mapped in expression.R diff --git a/r/R/dplyr-funcs.R b/r/R/dplyr-funcs.R index 4dadff54b48..a66db112d98 100644 --- a/r/R/dplyr-funcs.R +++ b/r/R/dplyr-funcs.R @@ -59,13 +59,17 @@ NULL #' summarise) because the data mask has to be a list. #' @param registry An environment in which the functions should be #' assigned. -#' +#' @param notes string for the docs: note any limitations or differences in +#' behavior between the Arrow version and the R function. #' @return The previously registered binding or `NULL` if no previously #' registered function existed. #' @keywords internal #' -register_binding <- function(fun_name, fun, registry = nse_funcs, - update_cache = FALSE) { +register_binding <- function(fun_name, + fun, + registry = nse_funcs, + update_cache = FALSE, + notes = character(0)) { unqualified_name <- sub("^.*?:{+}", "", fun_name) previous_fun <- registry[[unqualified_name]] @@ -76,7 +80,8 @@ register_binding <- function(fun_name, fun, registry = nse_funcs, paste0( "A \"", unqualified_name, - "\" binding already exists in the registry and will be overwritten.") + "\" binding already exists in the registry and will be overwritten." + ) ) } @@ -85,6 +90,8 @@ register_binding <- function(fun_name, fun, registry = nse_funcs, registry[[unqualified_name]] <- fun registry[[fun_name]] <- fun + .cache$docs[[fun_name]] <- notes + if (update_cache) { fun_cache <- .cache$functions fun_cache[[unqualified_name]] <- fun @@ -131,7 +138,7 @@ call_binding_agg <- function(fun_name, ...) { # Called in .onLoad() create_binding_cache <- function() { - arrow_funcs <- list() + .cache$docs <- list() # Register all available Arrow Compute functions, namespaced as arrow_fun. all_arrow_funs <- list_compute_functions() diff --git a/r/R/expression.R b/r/R/expression.R index 09a8ea24608..7a5a600d956 100644 --- a/r/R/expression.R +++ b/r/R/expression.R @@ -76,7 +76,6 @@ "lubridate::yday" = "day_of_year", "lubridate::year" = "year", "lubridate::leap_year" = "is_leap_year" - ) .binary_function_map <- list( @@ -158,13 +157,9 @@ Expression <- R6Class("Expression", compute___expr__type_id(self, schema) }, cast = function(to_type, safe = TRUE, ...) { - opts <- list( - to_type = to_type, - allow_int_overflow = !safe, - allow_time_truncate = !safe, - allow_float_truncate = !safe - ) - Expression$create("cast", self, options = modifyList(opts, list(...))) + opts <- cast_options(safe, ...) + opts$to_type <- as_type(to_type) + Expression$create("cast", self, options = opts) } ), active = list( diff --git a/r/_pkgdown.yml b/r/_pkgdown.yml index dfb0998ddff..70bd7ac518c 100644 --- a/r/_pkgdown.yml +++ b/r/_pkgdown.yml @@ -216,6 +216,7 @@ reference: - codec_is_available - title: Computation contents: + - acero - call_function - match_arrow - value_counts diff --git a/r/data-raw/docgen.R b/r/data-raw/docgen.R new file mode 100644 index 00000000000..ef39bec272f --- /dev/null +++ b/r/data-raw/docgen.R @@ -0,0 +1,192 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# This code generates dplyr-funcs-doc.R. +# It requires that the package be installed. + +file_template <- "# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# \"License\"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# Generated by using data-raw/docgen.R -> do not edit by hand + +#' Functions available in Arrow dplyr queries +#' +#' The `arrow` package contains methods for %s `dplyr` table functions, many of +#' which are \"verbs\" that do transformations to one or more tables. +#' The package also has mappings of %s R functions to the corresponding +#' functions in the Arrow compute library. These allow you to write code inside +#' of `dplyr` methods that call R functions, including many in packages like +#' `stringr` and `lubridate`, and they will get translated to Arrow and run +#' on the Arrow query engine (Acero). This document lists all of the mapped +#' functions. +#' +#' # `dplyr` verbs +#' +#' Most verb functions return an `arrow_dplyr_query` object, similar in spirit +#' to a `dbplyr::tbl_lazy`. This means that the verbs do not eagerly evaluate +#' the query on the data. To run the query, call either `compute()`, +#' which returns an `arrow` [Table], or `collect()`, which pulls the resulting +#' Table into an R `data.frame`. +#' +%s +#' +#' # Function mappings +#' +#' In the list below, any differences in behavior or support between Acero and +#' the R function are listed. If no notes follow the function name, then you +#' can assume that the function works in Acero just as it does in R. +#' +#' Functions can be called either as `pkg::fun()` or just `fun()`, i.e. both +#' `str_sub()` and `stringr::str_sub()` work. +#' +#' In addition to these functions, you can call any of Arrow's %s compute +#' functions directly. Arrow has many functions that don't map to an existing R +#' function. In other cases where there is an R function mapping, you can still +#' call the Arrow function directly if you don't want the adaptations that the R +#' mapping has that make Acero behave like R. 
These functions are listed in the +#' [C++ documentation](https://arrow.apache.org/docs/cpp/compute.html), and +#' in the function registry in R, they are named with an `arrow_` prefix, such +#' as `arrow_ascii_is_decimal`. +#' +%s +#' +#' @name acero +NULL" + +library(dplyr) +library(purrr) + +# Functions that for whatever reason cause xref problems, so don't hyperlink +do_not_link <- c( + "stringr::str_like" # Still only in the unreleased version +) + +# Vectorized function to make entries for each function +render_fun <- function(fun, pkg_fun, notes) { + # Add () to fun if it's not an operator + not_operators <- grepl("^[[:alpha:]]", fun) + fun[not_operators] <- paste0(fun[not_operators], "()") + # Make it \code{} for better formatting + fun <- paste0("`", fun, "`") + # Wrap in \link{} + out <- ifelse( + pkg_fun %in% do_not_link, + fun, + paste0("[", fun, "][", pkg_fun, "()]") + ) + # Add notes after :, if exist + has_notes <- nzchar(notes) + out[has_notes] <- paste0(out[has_notes], ": ", notes[has_notes]) + # Make bullets + paste("*", out) +} + +# This renders a bulleted list under a package heading +render_pkg <- function(df, pkg) { + bullets <- df %>% + transmute(render_fun(fun, pkg_fun, notes)) %>% + pull() + # Add header + bullets <- c( + paste("##", pkg), + "", + bullets + ) + paste("#'", bullets, collapse = "\n") +} + +docs <- arrow:::.cache$docs + +# Add some functions + +# across() is handled by manipulating the quosures, not by nse_funcs +docs[["dplyr::across"]] <- c( + # TODO(ARROW-17387, ARROW-17389, ARROW-17390) + "only supported inside `mutate()`, `summarize()`, and `arrange()`;", + # TODO(ARROW-17366) + "purrr-style lambda functions", + "and use of `where()` selection helper not yet supported" +) +# desc() is a special helper handled inside of arrange() +docs[["dplyr::desc"]] <- character(0) + +# add tidyselect helpers by parsing the reexports file +tidyselect <- grep("^tidyselect::", readLines("R/reexports-tidyselect.R"), value = TRUE) + +docs <- c(docs, setNames(rep(list(NULL), length(tidyselect)), tidyselect)) + +fun_df <- tibble::tibble( + pkg_fun = names(docs), + notes = docs +) %>% + mutate( + has_pkg = grepl("::", pkg_fun), + fun = sub("^.*?:{+}", "", pkg_fun), + pkg = sub(":{+}.*$", "", pkg_fun), + # We will list operators under "base" (everything else must be pkg::fun) + pkg = if_else(has_pkg, pkg, "base"), + # Flatten notes to a single string + notes = map_chr(notes, ~ paste(., collapse = " ")) + ) %>% + arrange(pkg, fun) + +# Group by package name and render the lists +fun_doclets <- imap_chr(split(fun_df, fun_df$pkg), render_pkg) + +dplyr_verbs <- c( + arrow:::supported_dplyr_methods, + # Because this only has a method for arrow_dplyr_query, it's not in the main list + tbl_vars = NULL +) + +verb_bullets <- tibble::tibble( + fun = names(dplyr_verbs), + notes = dplyr_verbs +) %>% + mutate( + pkg_fun = paste0("dplyr::", fun), + notes = map_chr(notes, ~ paste(., collapse = " ")) + ) %>% + arrange(fun) %>% + transmute(render_fun(fun, pkg_fun, notes)) %>% + pull() + +writeLines( + sprintf( + file_template, + length(dplyr_verbs), + length(docs), + paste("#'", verb_bullets, collapse = "\n"), + length(arrow::list_compute_functions()), + paste(fun_doclets, collapse = "\n#'\n") + ), + "R/dplyr-funcs-doc.R" +) diff --git a/r/man/acero.Rd b/r/man/acero.Rd new file mode 100644 index 00000000000..5b5920f386e --- /dev/null +++ b/r/man/acero.Rd @@ -0,0 +1,339 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/dplyr-funcs-doc.R +\name{acero} 
+\alias{acero} +\title{Functions available in Arrow dplyr queries} +\description{ +The \code{arrow} package contains methods for 32 \code{dplyr} table functions, many of +which are "verbs" that do transformations to one or more tables. +The package also has mappings of 205 R functions to the corresponding +functions in the Arrow compute library. These allow you to write code inside +of \code{dplyr} methods that call R functions, including many in packages like +\code{stringr} and \code{lubridate}, and they will get translated to Arrow and run +on the Arrow query engine (Acero). This document lists all of the mapped +functions. +} +\section{\code{dplyr} verbs}{ +Most verb functions return an \code{arrow_dplyr_query} object, similar in spirit +to a \code{dbplyr::tbl_lazy}. This means that the verbs do not eagerly evaluate +the query on the data. To run the query, call either \code{compute()}, +which returns an \code{arrow} \link{Table}, or \code{collect()}, which pulls the resulting +Table into an R \code{data.frame}. +\itemize{ +\item \code{\link[dplyr:filter-joins]{anti_join()}} +\item \code{\link[dplyr:arrange]{arrange()}} +\item \code{\link[dplyr:compute]{collapse()}} +\item \code{\link[dplyr:compute]{collect()}} +\item \code{\link[dplyr:compute]{compute()}} +\item \code{\link[dplyr:count]{count()}} +\item \code{\link[dplyr:distinct]{distinct()}} +\item \code{\link[dplyr:explain]{explain()}} +\item \code{\link[dplyr:filter]{filter()}} +\item \code{\link[dplyr:mutate-joins]{full_join()}} +\item \code{\link[dplyr:glimpse]{glimpse()}} +\item \code{\link[dplyr:group_by]{group_by()}} +\item \code{\link[dplyr:group_by_drop_default]{group_by_drop_default()}} +\item \code{\link[dplyr:group_data]{group_vars()}} +\item \code{\link[dplyr:group_data]{groups()}} +\item \code{\link[dplyr:mutate-joins]{inner_join()}} +\item \code{\link[dplyr:mutate-joins]{left_join()}} +\item \code{\link[dplyr:mutate]{mutate()}} +\item \code{\link[dplyr:pull]{pull()}} +\item \code{\link[dplyr:relocate]{relocate()}} +\item \code{\link[dplyr:rename]{rename()}} +\item \code{\link[dplyr:rename]{rename_with()}} +\item \code{\link[dplyr:mutate-joins]{right_join()}} +\item \code{\link[dplyr:select]{select()}} +\item \code{\link[dplyr:filter-joins]{semi_join()}} +\item \code{\link[dplyr:explain]{show_query()}} +\item \code{\link[dplyr:summarise]{summarise()}} +\item \code{\link[dplyr:count]{tally()}} +\item \code{\link[dplyr:mutate]{transmute()}} +\item \code{\link[dplyr:group_by]{ungroup()}} +\item \code{\link[dplyr:reexports]{union()}} +\item \code{\link[dplyr:setops]{union_all()}} +} +} + +\section{Function mappings}{ +In the list below, any differences in behavior or support between Acero and +the R function are listed. If no notes follow the function name, then you +can assume that the function works in Acero just as it does in R. + +Functions can be called either as \code{pkg::fun()} or just \code{fun()}, i.e. both +\code{str_sub()} and \code{stringr::str_sub()} work. + +In addition to these functions, you can call any of Arrow's 243 compute +functions directly. Arrow has many functions that don't map to an existing R +function. In other cases where there is an R function mapping, you can still +call the Arrow function directly if you don't want the adaptations that the R +mapping has that make Acero behave like R. 
These functions are listed in the +\href{https://arrow.apache.org/docs/cpp/compute.html}{C++ documentation}, and +in the function registry in R, they are named with an \code{arrow_} prefix, such +as \code{arrow_ascii_is_decimal}. +\subsection{arrow}{ +\itemize{ +\item \code{\link[=add_filename]{add_filename()}} +\item \code{\link[=cast]{cast()}} +} +} + +\subsection{base}{ +\itemize{ +\item \code{\link[=-]{-}} +\item \code{\link[=!]{!}} +\item \code{\link[=!=]{!=}} +\item \code{\link[=*]{*}} +\item \code{\link[=/]{/}} +\item \code{\link[=&]{&}} +\item \code{\link[=\%/\%]{\%/\%}} +\item \code{\link[=\%\%]{\%\%}} +\item \code{\link[=\%in\%]{\%in\%}} +\item \code{\link[=^]{^}} +\item \code{\link[=+]{+}} +\item \code{\link[=<]{<}} +\item \code{\link[=<=]{<=}} +\item \code{\link[===]{==}} +\item \code{\link[=>]{>}} +\item \code{\link[=>=]{>=}} +\item \code{\link[=|]{|}} +\item \code{\link[base:MathFun]{abs()}} +\item \code{\link[base:Trig]{acos()}} +\item \code{\link[base:all]{all()}} +\item \code{\link[base:any]{any()}} +\item \code{\link[base:character]{as.character()}} +\item \code{\link[base:as.Date]{as.Date()}} +\item \code{\link[base:difftime]{as.difftime()}} +\item \code{\link[base:double]{as.double()}} +\item \code{\link[base:integer]{as.integer()}} +\item \code{\link[base:logical]{as.logical()}} +\item \code{\link[base:numeric]{as.numeric()}} +\item \code{\link[base:Trig]{asin()}} +\item \code{\link[base:Round]{ceiling()}} +\item \code{\link[base:Trig]{cos()}} +\item \code{\link[base:data.frame]{data.frame()}} +\item \code{\link[base:difftime]{difftime()}} +\item \code{\link[base:startsWith]{endsWith()}} +\item \code{\link[base:Log]{exp()}} +\item \code{\link[base:Round]{floor()}} +\item \code{\link[base:format]{format()}} +\item \code{\link[base:grep]{grepl()}} +\item \code{\link[base:grep]{gsub()}} +\item \code{\link[base:ifelse]{ifelse()}} +\item \code{\link[base:character]{is.character()}} +\item \code{\link[base:double]{is.double()}} +\item \code{\link[base:factor]{is.factor()}} +\item \code{\link[base:is.finite]{is.finite()}} +\item \code{\link[base:is.finite]{is.infinite()}} +\item \code{\link[base:integer]{is.integer()}} +\item \code{\link[base:list]{is.list()}} +\item \code{\link[base:logical]{is.logical()}} +\item \code{\link[base:NA]{is.na()}} +\item \code{\link[base:is.finite]{is.nan()}} +\item \code{\link[base:numeric]{is.numeric()}} +\item \code{\link[base:ISOdatetime]{ISOdate()}} +\item \code{\link[base:ISOdatetime]{ISOdatetime()}} +\item \code{\link[base:Log]{log()}} +\item \code{\link[base:Log]{log10()}} +\item \code{\link[base:Log]{log1p()}} +\item \code{\link[base:Log]{log2()}} +\item \code{\link[base:Log]{logb()}} +\item \code{\link[base:Extremes]{max()}} +\item \code{\link[base:mean]{mean()}} +\item \code{\link[base:Extremes]{min()}} +\item \code{\link[base:nchar]{nchar()}} +\item \code{\link[base:paste]{paste()}}: the \code{collapse} argument is not yet supported +\item \code{\link[base:paste]{paste0()}}: the \code{collapse} argument is not yet supported +\item \code{\link[base:Extremes]{pmax()}} +\item \code{\link[base:Extremes]{pmin()}} +\item \code{\link[base:Round]{round()}} +\item \code{\link[base:sign]{sign()}} +\item \code{\link[base:Trig]{sin()}} +\item \code{\link[base:MathFun]{sqrt()}} +\item \code{\link[base:startsWith]{startsWith()}} +\item \code{\link[base:strptime]{strftime()}} +\item \code{\link[base:strptime]{strptime()}} +\item \code{\link[base:strrep]{strrep()}} +\item \code{\link[base:strsplit]{strsplit()}} +\item \code{\link[base:grep]{sub()}} 
+\item \code{\link[base:substr]{substr()}} +\item \code{\link[base:substr]{substring()}} +\item \code{\link[base:sum]{sum()}} +\item \code{\link[base:Trig]{tan()}} +\item \code{\link[base:chartr]{tolower()}} +\item \code{\link[base:chartr]{toupper()}} +\item \code{\link[base:Round]{trunc()}} +} +} + +\subsection{bit64}{ +\itemize{ +\item \code{\link[bit64:as.integer64.character]{as.integer64()}} +\item \code{\link[bit64:bit64-package]{is.integer64()}} +} +} + +\subsection{dplyr}{ +\itemize{ +\item \code{\link[dplyr:across]{across()}}: only supported inside \code{mutate()}, \code{summarize()}, and \code{arrange()}; purrr-style lambda functions and use of \code{where()} selection helper not yet supported +\item \code{\link[dplyr:between]{between()}} +\item \code{\link[dplyr:case_when]{case_when()}} +\item \code{\link[dplyr:coalesce]{coalesce()}} +\item \code{\link[dplyr:desc]{desc()}} +\item \code{\link[dplyr:if_else]{if_else()}} +\item \code{\link[dplyr:context]{n()}} +\item \code{\link[dplyr:n_distinct]{n_distinct()}} +} +} + +\subsection{lubridate}{ +\itemize{ +\item \code{\link[lubridate:am]{am()}} +\item \code{\link[lubridate:as_date]{as_date()}} +\item \code{\link[lubridate:as_date]{as_datetime()}} +\item \code{\link[lubridate:round_date]{ceiling_date()}} +\item \code{\link[lubridate:date]{date()}} +\item \code{\link[lubridate:date_decimal]{date_decimal()}} +\item \code{\link[lubridate:day]{day()}} +\item \code{\link[lubridate:duration]{ddays()}} +\item \code{\link[lubridate:decimal_date]{decimal_date()}} +\item \code{\link[lubridate:duration]{dhours()}} +\item \code{\link[lubridate:duration]{dmicroseconds()}} +\item \code{\link[lubridate:duration]{dmilliseconds()}} +\item \code{\link[lubridate:duration]{dminutes()}} +\item \code{\link[lubridate:duration]{dmonths()}} +\item \code{\link[lubridate:ymd]{dmy()}} +\item \code{\link[lubridate:ymd_hms]{dmy_h()}} +\item \code{\link[lubridate:ymd_hms]{dmy_hm()}} +\item \code{\link[lubridate:ymd_hms]{dmy_hms()}} +\item \code{\link[lubridate:duration]{dnanoseconds()}} +\item \code{\link[lubridate:duration]{dpicoseconds()}} +\item \code{\link[lubridate:duration]{dseconds()}} +\item \code{\link[lubridate:dst]{dst()}} +\item \code{\link[lubridate:duration]{dweeks()}} +\item \code{\link[lubridate:duration]{dyears()}} +\item \code{\link[lubridate:ymd]{dym()}} +\item \code{\link[lubridate:week]{epiweek()}} +\item \code{\link[lubridate:year]{epiyear()}} +\item \code{\link[lubridate:parse_date_time]{fast_strptime()}} +\item \code{\link[lubridate:round_date]{floor_date()}} +\item \code{\link[lubridate:format_ISO8601]{format_ISO8601()}} +\item \code{\link[lubridate:hour]{hour()}} +\item \code{\link[lubridate:date_utils]{is.Date()}} +\item \code{\link[lubridate:is.instant]{is.instant()}} +\item \code{\link[lubridate:posix_utils]{is.POSIXct()}} +\item \code{\link[lubridate:is.instant]{is.timepoint()}} +\item \code{\link[lubridate:week]{isoweek()}} +\item \code{\link[lubridate:year]{isoyear()}} +\item \code{\link[lubridate:leap_year]{leap_year()}} +\item \code{\link[lubridate:make_datetime]{make_date()}} +\item \code{\link[lubridate:make_datetime]{make_datetime()}} +\item \code{\link[lubridate:make_difftime]{make_difftime()}} +\item \code{\link[lubridate:day]{mday()}} +\item \code{\link[lubridate:ymd]{mdy()}} +\item \code{\link[lubridate:ymd_hms]{mdy_h()}} +\item \code{\link[lubridate:ymd_hms]{mdy_hm()}} +\item \code{\link[lubridate:ymd_hms]{mdy_hms()}} +\item \code{\link[lubridate:minute]{minute()}} +\item \code{\link[lubridate:month]{month()}} +\item 
\code{\link[lubridate:ymd]{my()}} +\item \code{\link[lubridate:ymd]{myd()}} +\item \code{\link[lubridate:parse_date_time]{parse_date_time()}} +\item \code{\link[lubridate:am]{pm()}} +\item \code{\link[lubridate:day]{qday()}} +\item \code{\link[lubridate:quarter]{quarter()}} +\item \code{\link[lubridate:round_date]{round_date()}} +\item \code{\link[lubridate:second]{second()}} +\item \code{\link[lubridate:quarter]{semester()}} +\item \code{\link[lubridate:tz]{tz()}} +\item \code{\link[lubridate:day]{wday()}} +\item \code{\link[lubridate:week]{week()}} +\item \code{\link[lubridate:day]{yday()}} +\item \code{\link[lubridate:ymd]{ydm()}} +\item \code{\link[lubridate:ymd_hms]{ydm_h()}} +\item \code{\link[lubridate:ymd_hms]{ydm_hm()}} +\item \code{\link[lubridate:ymd_hms]{ydm_hms()}} +\item \code{\link[lubridate:year]{year()}} +\item \code{\link[lubridate:ymd]{ym()}} +\item \code{\link[lubridate:ymd]{ymd()}} +\item \code{\link[lubridate:ymd_hms]{ymd_h()}} +\item \code{\link[lubridate:ymd_hms]{ymd_hm()}} +\item \code{\link[lubridate:ymd_hms]{ymd_hms()}} +\item \code{\link[lubridate:ymd]{yq()}} +} +} + +\subsection{methods}{ +\itemize{ +\item \code{\link[methods:is]{is()}} +} +} + +\subsection{rlang}{ +\itemize{ +\item \code{\link[rlang:type-predicates]{is_character()}} +\item \code{\link[rlang:type-predicates]{is_double()}} +\item \code{\link[rlang:type-predicates]{is_integer()}} +\item \code{\link[rlang:type-predicates]{is_list()}} +\item \code{\link[rlang:type-predicates]{is_logical()}} +} +} + +\subsection{stats}{ +\itemize{ +\item \code{\link[stats:median]{median()}} +\item \code{\link[stats:quantile]{quantile()}} +\item \code{\link[stats:sd]{sd()}} +\item \code{\link[stats:cor]{var()}} +} +} + +\subsection{stringi}{ +\itemize{ +\item \code{\link[stringi:stri_reverse]{stri_reverse()}} +} +} + +\subsection{stringr}{ +\itemize{ +\item \code{\link[stringr:str_c]{str_c()}}: the \code{collapse} argument is not yet supported +\item \code{\link[stringr:str_count]{str_count()}} +\item \code{\link[stringr:str_detect]{str_detect()}} +\item \code{\link[stringr:str_dup]{str_dup()}} +\item \code{\link[stringr:str_starts]{str_ends()}} +\item \code{\link[stringr:str_length]{str_length()}} +\item \code{str_like()}: not yet in a released version of \code{stringr}, but it is supported in \code{arrow} +\item \code{\link[stringr:str_pad]{str_pad()}} +\item \code{\link[stringr:str_replace]{str_replace()}} +\item \code{\link[stringr:str_replace]{str_replace_all()}} +\item \code{\link[stringr:str_split]{str_split()}} +\item \code{\link[stringr:str_starts]{str_starts()}} +\item \code{\link[stringr:str_sub]{str_sub()}} +\item \code{\link[stringr:case]{str_to_lower()}} +\item \code{\link[stringr:case]{str_to_title()}} +\item \code{\link[stringr:case]{str_to_upper()}} +\item \code{\link[stringr:str_trim]{str_trim()}} +} +} + +\subsection{tibble}{ +\itemize{ +\item \code{\link[tibble:tibble]{tibble()}} +} +} + +\subsection{tidyselect}{ +\itemize{ +\item \code{\link[tidyselect:all_of]{all_of()}} +\item \code{\link[tidyselect:starts_with]{contains()}} +\item \code{\link[tidyselect:starts_with]{ends_with()}} +\item \code{\link[tidyselect:everything]{everything()}} +\item \code{\link[tidyselect:everything]{last_col()}} +\item \code{\link[tidyselect:starts_with]{matches()}} +\item \code{\link[tidyselect:starts_with]{num_range()}} +\item \code{\link[tidyselect:one_of]{one_of()}} +\item \code{\link[tidyselect:starts_with]{starts_with()}} +} +} +} + diff --git a/r/man/add_filename.Rd b/r/man/add_filename.Rd new file mode 100644 
index 00000000000..ca7ed0e4b17
--- /dev/null
+++ b/r/man/add_filename.Rd
@@ -0,0 +1,23 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/dplyr-funcs-augmented.R
+\name{add_filename}
+\alias{add_filename}
+\title{Add the data filename as a column}
+\usage{
+add_filename()
+}
+\value{
+A \code{FieldRef} \code{Expression} that refers to the filename augmented
+column.
+}
+\description{
+This function only exists inside \code{arrow} \code{dplyr} queries, and it is
+only valid when querying a \code{FileSystemDataset}.
+}
+\examples{
+\dontrun{
+open_dataset("nyc-taxi") \%>\%
+  mutate(file = add_filename())
+}
+}
+\keyword{internal}
diff --git a/r/man/cast.Rd b/r/man/cast.Rd
new file mode 100644
index 00000000000..88134f2e022
--- /dev/null
+++ b/r/man/cast.Rd
@@ -0,0 +1,38 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/dplyr-funcs-type.R
+\name{cast}
+\alias{cast}
+\title{Change the type of an array or column}
+\usage{
+cast(x, to, safe = TRUE, ...)
+}
+\arguments{
+\item{x}{an \code{Array}, \code{Table}, \code{Expression}, or similar Arrow data object.}
+
+\item{to}{\link{DataType} to cast to; for \link{Table} and \link{RecordBatch},
+it should be a \link{Schema}.}
+
+\item{safe}{logical: only allow the type conversion if no data is lost
+(truncation, overflow, etc.). Default is \code{TRUE}}
+
+\item{...}{specific \code{CastOptions} to set}
+}
+\value{
+an \code{Expression}
+}
+\description{
+This is a wrapper around the \verb{$cast()} method that many Arrow objects have.
+It is more convenient to call inside \code{dplyr} pipelines than the method.
+}
+\examples{
+\dontrun{
+mtcars \%>\%
+  arrow_table() \%>\%
+  mutate(cyl = cast(cyl, string()))
+}
+}
+\seealso{
+https://arrow.apache.org/docs/cpp/api/compute.html for the list of
+supported CastOptions.
+}
+\keyword{internal}
diff --git a/r/man/register_binding.Rd b/r/man/register_binding.Rd
index c53df707516..d2a4a380543 100644
--- a/r/man/register_binding.Rd
+++ b/r/man/register_binding.Rd
@@ -4,7 +4,13 @@
 \alias{register_binding}
 \title{Register compute bindings}
 \usage{
-register_binding(fun_name, fun, registry = nse_funcs, update_cache = FALSE)
+register_binding(
+  fun_name,
+  fun,
+  registry = nse_funcs,
+  update_cache = FALSE,
+  notes = character(0)
+)
 }
 \arguments{
 \item{fun_name}{A string containing a function name in the form \code{"function"} or
@@ -26,6 +32,9 @@ non-aggregate functions could be revisited...it is currently used as the
 data mask in mutate, filter, and aggregate (but not
 summarise) because the data mask has to be a list.}
 
+\item{notes}{string for the docs: note any limitations or differences in
+behavior between the Arrow version and the R function.}
+
 \item{agg_fun}{An aggregate function or \code{NULL} to un-register a previous
 aggregate function. This function must accept \code{Expression} objects as
 arguments and return a \code{list()} with components:

From 557acf524f6b73d73bdb9464e947b78b9d02fcea Mon Sep 17 00:00:00 2001
From: eitsupi <50911393+eitsupi@users.noreply.github.com>
Date: Fri, 16 Sep 2022 10:42:34 +0900
Subject: [PATCH 075/133] ARROW-17689: [R] Implement dplyr::across() inside
 group_by() (#14122)

Because the handling of `.add = TRUE` and of the deprecated `add` argument
has changed, test cases for both are also added.
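
For illustration, a minimal sketch of the behavior this change enables
(not part of the patch itself; it assumes the `arrow` and `dplyr`
packages are attached and uses a small made-up two-column table, with the
expected results shown as `#>` comments):

``` r
library(arrow, warn.conflicts = FALSE)
library(dplyr, warn.conflicts = FALSE)

tbl <- arrow_table(chr = c("a", "b", "c"), dbl2 = c(1, 2, 3))

# across() is now expanded inside group_by()
tbl %>%
  group_by(across(starts_with("d"))) %>%
  group_vars()
#> [1] "dbl2"

# .add = TRUE appends to the existing grouping instead of replacing it
tbl %>%
  group_by(dbl2) %>%
  group_by(chr, .add = TRUE) %>%
  group_vars()
#> [1] "dbl2" "chr"
```

The old `add` argument still works but now signals a deprecation warning.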
Authored-by: SHIMA Tatsuya Signed-off-by: Dewey Dunnington --- r/R/dplyr-group-by.R | 38 ++++----- r/tests/testthat/test-dplyr-group-by.R | 110 +++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 22 deletions(-) diff --git a/r/R/dplyr-group-by.R b/r/R/dplyr-group-by.R index c650799e8d0..57cf417c9ad 100644 --- a/r/R/dplyr-group-by.R +++ b/r/R/dplyr-group-by.R @@ -21,37 +21,31 @@ group_by.arrow_dplyr_query <- function(.data, ..., .add = FALSE, - add = .add, + add = NULL, .drop = dplyr::group_by_drop_default(.data)) { + if (!missing(add)) { + .Deprecated( + msg = paste("The `add` argument of `group_by()` is deprecated. Please use the `.add` argument instead.") + ) + .add <- add + } + .data <- as_adq(.data) - new_groups <- enquos(...) - # ... can contain expressions (i.e. can add (or rename?) columns) and so we - # need to identify those and add them on to the query with mutate. Specifically, - # we want to mark as new: - # * expressions (named or otherwise) - # * variables that have new names - # All others (i.e. simple references to variables) should not be (re)-added + expression_list <- expand_across(.data, quos(...)) + new_groups <- ensure_named_exprs(expression_list) - # Identify any groups with names which aren't in names of .data - new_group_ind <- map_lgl(new_groups, ~ !(quo_name(.x) %in% names(.data))) - # Identify any groups which don't have names - named_group_ind <- map_lgl(names(new_groups), nzchar) - # Retain any new groups identified above - new_groups <- new_groups[new_group_ind | named_group_ind] if (length(new_groups)) { - # now either use the name that was given in ... or if that is "" then use the expr - names(new_groups) <- imap_chr(new_groups, ~ ifelse(.y == "", quo_name(.x), .y)) - # Add them to the data .data <- dplyr::mutate(.data, !!!new_groups) } - if (".add" %in% names(formals(dplyr::group_by))) { - # For compatibility with dplyr >= 1.0 - gv <- dplyr::group_by_prepare(.data, ..., .add = .add)$group_names + + if (.add) { + gv <- union(dplyr::group_vars(.data), names(new_groups)) } else { - gv <- dplyr::group_by_prepare(.data, ..., add = add)$group_names + gv <- names(new_groups) } - .data$group_by_vars <- gv + + .data$group_by_vars <- gv %||% character() .data$drop_empty_groups <- ifelse(length(gv), .drop, dplyr::group_by_drop_default(.data)) .data } diff --git a/r/tests/testthat/test-dplyr-group-by.R b/r/tests/testthat/test-dplyr-group-by.R index c7380e96ec3..9bb6aa9600d 100644 --- a/r/tests/testthat/test-dplyr-group-by.R +++ b/r/tests/testthat/test-dplyr-group-by.R @@ -166,3 +166,113 @@ test_that("group_by() with namespaced functions", { tbl ) }) + +test_that("group_by() with .add", { + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(.add = FALSE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(.add = TRUE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(chr, .add = FALSE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(chr, .add = TRUE) %>% + collect(), + tbl + ) + compare_dplyr_binding( + .input %>% + group_by(chr, .add = FALSE) %>% + collect(), + tbl %>% + group_by(dbl2) + ) + compare_dplyr_binding( + .input %>% + group_by(chr, .add = TRUE) %>% + collect(), + tbl %>% + group_by(dbl2) + ) + suppressWarnings(compare_dplyr_binding( + .input %>% + group_by(dbl2) %>% + group_by(add = FALSE) %>% + collect(), + tbl, + warning = "deprecated" + )) + suppressWarnings(compare_dplyr_binding( + 
.input %>%
+      group_by(dbl2) %>%
+      group_by(add = TRUE) %>%
+      collect(),
+    tbl,
+    warning = "deprecated"
+  ))
+  expect_warning(
+    tbl %>%
+      arrow_table() %>%
+      group_by(add = TRUE) %>%
+      collect(),
+    "The `add` argument of `group_by\\(\\)` is deprecated"
+  )
+  expect_error(
+    suppressWarnings(
+      tbl %>%
+        arrow_table() %>%
+        group_by(add = dbl2) %>%
+        collect()
+    ),
+    "object 'dbl2' not found"
+  )
+})
+
+test_that("Can use across() within group_by()", {
+  test_groups <- c("dbl", "int", "chr")
+  compare_dplyr_binding(
+    .input %>%
+      group_by(across()) %>%
+      collect(),
+    tbl
+  )
+  compare_dplyr_binding(
+    .input %>%
+      group_by(across(starts_with("d"))) %>%
+      collect(),
+    tbl
+  )
+  compare_dplyr_binding(
+    .input %>%
+      group_by(across({{ test_groups }})) %>%
+      collect(),
+    tbl
+  )
+
+  # ARROW-12778 - `where()` is not yet supported
+  expect_error(
+    compare_dplyr_binding(
+      .input %>%
+        group_by(across(where(is.numeric))) %>%
+        collect(),
+      tbl
+    ),
+    "Unsupported selection helper"
+  )
+})

From 59883630fcd737079e18035a3269a31eb7e0495e Mon Sep 17 00:00:00 2001
From: Dewey Dunnington
Date: Thu, 15 Sep 2022 22:43:56 -0300
Subject: [PATCH 076/133] ARROW-17178: [R] Support head() in arrow_dplyr_query
 with user-defined function (#13706)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This PR adds support for more types of queries that include calls to R
code (i.e., `map_batches(..., .lazy = TRUE)`, user-defined functions in
mutates, arranges, and filters, and custom extension type behaviour).
Previously these queries failed because it wasn't possible to guarantee
that the exec plan would be completely executed within a call to
`RunWithCapturedR()`, where we establish an event loop on the main R
thread and launch a background thread to do "arrow stuff" that can
queue tasks to run on the R thread.

The approach I took here was to move more of the
ExecPlan-to-RecordBatchReader logic into a subclass of RecordBatchReader
that doesn't call `plan->StartProducing()` until the first batch has
been pulled. This lets you return a record batch reader and pass it
around at the R level (which is currently how `head()`, `tail()`, and a
few other things are implemented), and as long as it is drained all at
once (i.e., via `reader$read_table()`), the calls into R will work. R
code calls within an exec plan *won't* work with
`reader$read_next_batch()` or the C data interface, because there we
can't guarantee an event loop.

This also lets us inject some cancelability into the ExecPlan: once
#13635 (ARROW-11841) is in, we can check a StopToken for an interrupt
(for all exec plans).

The biggest benefit, in my view, is that the lifecycle of the ExecPlan
is more explicit: before, the plan was stopped when the object was
deleted, but that was written in a way that I didn't understand for a
long time. I think a reader subclass makes this clearer and may also
help with printing nested queries (since they're no longer eagerly
evaluated).

An example of something that didn't work before that now does:

``` r
library(arrow, warn.conflicts = FALSE)
#> Some features are not enabled in this build of Arrow. Run `arrow_info()` for more information.
library(dplyr, warn.conflicts = FALSE) register_scalar_function( "times_32", function(context, x) x * 32.0, int32(), float64(), auto_convert = TRUE ) record_batch(a = 1:1000) %>% dplyr::mutate(b = times_32(a)) %>% as_record_batch_reader() %>% as_arrow_table() #> Table #> 1000 rows x 2 columns #> $a #> $b record_batch(a = 1:1000) %>% dplyr::mutate(fun_result = times_32(a)) %>% head(11) %>% dplyr::collect() #> # A tibble: 11 × 2 #> a fun_result #> #> 1 1 32 #> 2 2 64 #> 3 3 96 #> 4 4 128 #> 5 5 160 #> 6 6 192 #> 7 7 224 #> 8 8 256 #> 9 9 288 #> 10 10 320 #> 11 11 352 ``` Created on 2022-07-25 by the [reprex package](https://reprex.tidyverse.org) (v2.0.1) Lead-authored-by: Dewey Dunnington Co-authored-by: Dewey Dunnington Signed-off-by: Dewey Dunnington --- r/R/arrowExports.R | 32 ++-- r/R/compute.R | 7 - r/R/dplyr.R | 9 +- r/R/query-engine.R | 84 +++------- r/R/record-batch-reader.R | 1 + r/R/table.R | 12 +- r/src/arrowExports.cpp | 78 ++++++---- r/src/arrow_types.h | 2 + r/src/compute-exec.cpp | 224 +++++++++++++++++---------- r/src/recordbatchreader.cpp | 81 ++++++++-- r/tests/testthat/test-compute.R | 92 +++++------ r/tests/testthat/test-query-engine.R | 60 +++++++ 12 files changed, 416 insertions(+), 266 deletions(-) diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index ab3358d6664..6e76cd64687 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -404,24 +404,32 @@ ExecPlan_create <- function(use_threads) { .Call(`_arrow_ExecPlan_create`, use_threads) } -ExecPlan_run <- function(plan, final_node, sort_options, metadata, head) { - .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head) +ExecPlanReader__batches <- function(reader) { + .Call(`_arrow_ExecPlanReader__batches`, reader) } -ExecPlan_read_table <- function(plan, final_node, sort_options, metadata, head) { - .Call(`_arrow_ExecPlan_read_table`, plan, final_node, sort_options, metadata, head) +Table__from_ExecPlanReader <- function(reader) { + .Call(`_arrow_Table__from_ExecPlanReader`, reader) } -ExecPlan_StopProducing <- function(plan) { - invisible(.Call(`_arrow_ExecPlan_StopProducing`, plan)) +ExecPlanReader__Plan <- function(reader) { + .Call(`_arrow_ExecPlanReader__Plan`, reader) } -ExecNode_output_schema <- function(node) { - .Call(`_arrow_ExecNode_output_schema`, node) +ExecPlanReader__PlanStatus <- function(reader) { + .Call(`_arrow_ExecPlanReader__PlanStatus`, reader) +} + +ExecPlan_run <- function(plan, final_node, sort_options, metadata, head) { + .Call(`_arrow_ExecPlan_run`, plan, final_node, sort_options, metadata, head) } -ExecPlan_BuildAndShow <- function(plan, final_node, sort_options, head) { - .Call(`_arrow_ExecPlan_BuildAndShow`, plan, final_node, sort_options, head) +ExecPlan_ToString <- function(plan) { + .Call(`_arrow_ExecPlan_ToString`, plan) +} + +ExecNode_output_schema <- function(node) { + .Call(`_arrow_ExecNode_output_schema`, node) } ExecNode_Scan <- function(plan, dataset, filter, materialized_field_names) { @@ -1728,6 +1736,10 @@ RecordBatchReader__schema <- function(reader) { .Call(`_arrow_RecordBatchReader__schema`, reader) } +RecordBatchReader__Close <- function(reader) { + invisible(.Call(`_arrow_RecordBatchReader__Close`, reader)) +} + RecordBatchReader__ReadNext <- function(reader) { .Call(`_arrow_RecordBatchReader__ReadNext`, reader) } diff --git a/r/R/compute.R b/r/R/compute.R index 636c9146ca3..a144e7d678a 100644 --- a/r/R/compute.R +++ b/r/R/compute.R @@ -385,13 +385,6 @@ register_scalar_function <- function(name, fun, in_type, out_type, update_cache = TRUE ) - 
# User-defined functions require some special handling - # in the query engine which currently require an opt-in using - # the R_ARROW_COLLECT_WITH_UDF environment variable while this - # behaviour is stabilized. - # TODO(ARROW-17178) remove the need for this! - Sys.setenv(R_ARROW_COLLECT_WITH_UDF = "true") - invisible(NULL) } diff --git a/r/R/dplyr.R b/r/R/dplyr.R index dffe269199c..86132d8ae4a 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -266,7 +266,7 @@ tail.arrow_dplyr_query <- function(x, n = 6L, ...) { #' show_exec_plan() show_exec_plan <- function(x) { adq <- as_adq(x) - plan <- ExecPlan$create() + # do not show the plan if we have a nested query (as this will force the # evaluation of the inner query/queries) # TODO see if we can remove after ARROW-16628 @@ -274,8 +274,11 @@ show_exec_plan <- function(x) { warn("The `ExecPlan` cannot be printed for a nested query.") return(invisible(x)) } - final_node <- plan$Build(adq) - cat(plan$BuildAndShow(final_node)) + + result <- as_record_batch_reader(adq) + cat(result$Plan()$ToString()) + result$Close() + invisible(x) } diff --git a/r/R/query-engine.R b/r/R/query-engine.R index c132b291b87..89a8c6a7f37 100644 --- a/r/R/query-engine.R +++ b/r/R/query-engine.R @@ -194,13 +194,11 @@ ExecPlan <- R6Class("ExecPlan", } node }, - Run = function(node, as_table = FALSE) { - # a section of this code is used by `BuildAndShow()` too - the 2 need to be in sync - # Start of chunk used in `BuildAndShow()` + Run = function(node) { assert_is(node, "ExecNode") # Sorting and head/tail (if sorted) are handled in the SinkNode, - # created in ExecPlan_run + # created in ExecPlan_build sorting <- node$extras$sort %||% list() select_k <- node$extras$head %||% -1L has_sorting <- length(sorting) > 0 @@ -214,16 +212,7 @@ ExecPlan <- R6Class("ExecPlan", sorting$orders <- as.integer(sorting$orders) } - # End of chunk used in `BuildAndShow()` - - # If we are going to return a Table anyway, we do this in one step and - # entirely in one C++ call to ensure that we can execute user-defined - # functions from the worker threads spawned by the ExecPlan. If not, we - # use ExecPlan_run which returns a RecordBatchReader that can be - # manipulated in R code (but that right now won't work with - # user-defined functions). - exec_fun <- if (as_table) ExecPlan_read_table else ExecPlan_run - out <- exec_fun( + out <- ExecPlan_run( self, node, sorting, @@ -240,18 +229,13 @@ ExecPlan <- R6Class("ExecPlan", slice_size <- node$extras$head %||% node$extras$tail if (!is.null(slice_size)) { out <- head(out, slice_size) - # We already have everything we need for the head, so StopProducing - self$Stop() } } else if (!is.null(node$extras$tail)) { # TODO(ARROW-16630): proper BottomK support # Reverse the row order to get back what we expect out <- as_arrow_table(out) out <- out[rev(seq_len(nrow(out))), , drop = FALSE] - # Put back into RBR - if (!as_table) { - out <- as_record_batch_reader(out) - } + out <- as_record_batch_reader(out) } # If arrange() created $temp_columns, make sure to omit them from the result @@ -261,11 +245,7 @@ ExecPlan <- R6Class("ExecPlan", if (length(node$extras$sort$temp_columns) > 0) { tab <- as_arrow_table(out) tab <- tab[, setdiff(names(tab), node$extras$sort$temp_columns), drop = FALSE] - if (!as_table) { - out <- as_record_batch_reader(tab) - } else { - out <- tab - } + out <- as_record_batch_reader(tab) } out @@ -279,40 +259,9 @@ ExecPlan <- R6Class("ExecPlan", ... 
) }, - # SinkNodes (involved in arrange and/or head/tail operations) are created in - # ExecPlan_run and are not captured by the regulat print method. We take a - # similar approach to expose them before calling the print method. - BuildAndShow = function(node) { - # a section of this code is copied from `Run()` - the 2 need to be in sync - # Start of chunk copied from `Run()` - - assert_is(node, "ExecNode") - - # Sorting and head/tail (if sorted) are handled in the SinkNode, - # created in ExecPlan_run - sorting <- node$extras$sort %||% list() - select_k <- node$extras$head %||% -1L - has_sorting <- length(sorting) > 0 - if (has_sorting) { - if (!is.null(node$extras$tail)) { - # Reverse the sort order and take the top K, then after we'll reverse - # the resulting rows so that it is ordered as expected - sorting$orders <- !sorting$orders - select_k <- node$extras$tail - } - sorting$orders <- as.integer(sorting$orders) - } - - # End of chunk copied from `Run()` - - ExecPlan_BuildAndShow( - self, - node, - sorting, - select_k - ) - }, - Stop = function() ExecPlan_StopProducing(self) + ToString = function() { + ExecPlan_ToString(self) + } ) ) # nolint end. @@ -396,6 +345,23 @@ ExecNode <- R6Class("ExecNode", ) ) +ExecPlanReader <- R6Class("ExecPlanReader", + inherit = RecordBatchReader, + public = list( + batches = function() ExecPlanReader__batches(self), + read_table = function() Table__from_ExecPlanReader(self), + Plan = function() ExecPlanReader__Plan(self), + PlanStatus = function() ExecPlanReader__PlanStatus(self), + ToString = function() { + sprintf( + "\n\n%s\n\nSee $Plan() for details.", + self$PlanStatus(), + super$ToString() + ) + } + ) +) + do_exec_plan_substrait <- function(substrait_plan) { if (is.string(substrait_plan)) { substrait_plan <- substrait__internal__SubstraitFromJSON(substrait_plan) diff --git a/r/R/record-batch-reader.R b/r/R/record-batch-reader.R index 3a985d8abce..e1dd52ed715 100644 --- a/r/R/record-batch-reader.R +++ b/r/R/record-batch-reader.R @@ -98,6 +98,7 @@ RecordBatchReader <- R6Class("RecordBatchReader", read_next_batch = function() RecordBatchReader__ReadNext(self), batches = function() RecordBatchReader__batches(self), read_table = function() Table__from_RecordBatchReader(self), + Close = function() RecordBatchReader__Close(self), export_to_c = function(stream_ptr) ExportRecordBatchReader(self, stream_ptr), ToString = function() self$schema$ToString() ), diff --git a/r/R/table.R b/r/R/table.R index d7e276415c5..c5291257792 100644 --- a/r/R/table.R +++ b/r/R/table.R @@ -328,15 +328,5 @@ as_arrow_table.RecordBatchReader <- function(x, ...) { #' @rdname as_arrow_table #' @export as_arrow_table.arrow_dplyr_query <- function(x, ...) 
{ - # See query-engine.R for ExecPlan/Nodes - plan <- ExecPlan$create() - final_node <- plan$Build(x) - - run_with_event_loop <- identical( - Sys.getenv("R_ARROW_COLLECT_WITH_UDF", ""), - "true" - ) - - result <- plan$Run(final_node, as_table = run_with_event_loop) - as_arrow_table(result) + as_arrow_table(as_record_batch_reader(x)) } diff --git a/r/src/arrowExports.cpp b/r/src/arrowExports.cpp index adb6636e9ee..26ec6e3d9b1 100644 --- a/r/src/arrowExports.cpp +++ b/r/src/arrowExports.cpp @@ -869,36 +869,55 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp -std::shared_ptr ExecPlan_run(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); -extern "C" SEXP _arrow_ExecPlan_run(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ +cpp11::list ExecPlanReader__batches(const std::shared_ptr& reader); +extern "C" SEXP _arrow_ExecPlanReader__batches(SEXP reader_sexp){ BEGIN_CPP11 - arrow::r::Input&>::type plan(plan_sexp); - arrow::r::Input&>::type final_node(final_node_sexp); - arrow::r::Input::type sort_options(sort_options_sexp); - arrow::r::Input::type metadata(metadata_sexp); - arrow::r::Input::type head(head_sexp); - return cpp11::as_sexp(ExecPlan_run(plan, final_node, sort_options, metadata, head)); + arrow::r::Input&>::type reader(reader_sexp); + return cpp11::as_sexp(ExecPlanReader__batches(reader)); +END_CPP11 +} +// compute-exec.cpp +std::shared_ptr Table__from_ExecPlanReader(const std::shared_ptr& reader); +extern "C" SEXP _arrow_Table__from_ExecPlanReader(SEXP reader_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type reader(reader_sexp); + return cpp11::as_sexp(Table__from_ExecPlanReader(reader)); +END_CPP11 +} +// compute-exec.cpp +std::shared_ptr ExecPlanReader__Plan(const std::shared_ptr& reader); +extern "C" SEXP _arrow_ExecPlanReader__Plan(SEXP reader_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type reader(reader_sexp); + return cpp11::as_sexp(ExecPlanReader__Plan(reader)); +END_CPP11 +} +// compute-exec.cpp +std::string ExecPlanReader__PlanStatus(const std::shared_ptr& reader); +extern "C" SEXP _arrow_ExecPlanReader__PlanStatus(SEXP reader_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type reader(reader_sexp); + return cpp11::as_sexp(ExecPlanReader__PlanStatus(reader)); END_CPP11 } // compute-exec.cpp -std::shared_ptr ExecPlan_read_table(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); -extern "C" SEXP _arrow_ExecPlan_read_table(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ +std::shared_ptr ExecPlan_run(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, cpp11::strings metadata, int64_t head); +extern "C" SEXP _arrow_ExecPlan_run(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP metadata_sexp, SEXP head_sexp){ BEGIN_CPP11 arrow::r::Input&>::type plan(plan_sexp); arrow::r::Input&>::type final_node(final_node_sexp); arrow::r::Input::type sort_options(sort_options_sexp); arrow::r::Input::type metadata(metadata_sexp); arrow::r::Input::type head(head_sexp); - return cpp11::as_sexp(ExecPlan_read_table(plan, final_node, sort_options, metadata, head)); + return cpp11::as_sexp(ExecPlan_run(plan, final_node, sort_options, metadata, head)); END_CPP11 } // compute-exec.cpp -void ExecPlan_StopProducing(const std::shared_ptr& plan); -extern "C" SEXP _arrow_ExecPlan_StopProducing(SEXP plan_sexp){ 
+std::string ExecPlan_ToString(const std::shared_ptr& plan); +extern "C" SEXP _arrow_ExecPlan_ToString(SEXP plan_sexp){ BEGIN_CPP11 arrow::r::Input&>::type plan(plan_sexp); - ExecPlan_StopProducing(plan); - return R_NilValue; + return cpp11::as_sexp(ExecPlan_ToString(plan)); END_CPP11 } // compute-exec.cpp @@ -910,17 +929,6 @@ BEGIN_CPP11 END_CPP11 } // compute-exec.cpp -std::string ExecPlan_BuildAndShow(const std::shared_ptr& plan, const std::shared_ptr& final_node, cpp11::list sort_options, int64_t head); -extern "C" SEXP _arrow_ExecPlan_BuildAndShow(SEXP plan_sexp, SEXP final_node_sexp, SEXP sort_options_sexp, SEXP head_sexp){ -BEGIN_CPP11 - arrow::r::Input&>::type plan(plan_sexp); - arrow::r::Input&>::type final_node(final_node_sexp); - arrow::r::Input::type sort_options(sort_options_sexp); - arrow::r::Input::type head(head_sexp); - return cpp11::as_sexp(ExecPlan_BuildAndShow(plan, final_node, sort_options, head)); -END_CPP11 -} -// compute-exec.cpp #if defined(ARROW_R_WITH_DATASET) std::shared_ptr ExecNode_Scan(const std::shared_ptr& plan, const std::shared_ptr& dataset, const std::shared_ptr& filter, std::vector materialized_field_names); extern "C" SEXP _arrow_ExecNode_Scan(SEXP plan_sexp, SEXP dataset_sexp, SEXP filter_sexp, SEXP materialized_field_names_sexp){ @@ -4466,6 +4474,15 @@ BEGIN_CPP11 END_CPP11 } // recordbatchreader.cpp +void RecordBatchReader__Close(const std::shared_ptr& reader); +extern "C" SEXP _arrow_RecordBatchReader__Close(SEXP reader_sexp){ +BEGIN_CPP11 + arrow::r::Input&>::type reader(reader_sexp); + RecordBatchReader__Close(reader); + return R_NilValue; +END_CPP11 +} +// recordbatchreader.cpp std::shared_ptr RecordBatchReader__ReadNext(const std::shared_ptr& reader); extern "C" SEXP _arrow_RecordBatchReader__ReadNext(SEXP reader_sexp){ BEGIN_CPP11 @@ -5286,11 +5303,13 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_io___CompressedOutputStream__Make", (DL_FUNC) &_arrow_io___CompressedOutputStream__Make, 2}, { "_arrow_io___CompressedInputStream__Make", (DL_FUNC) &_arrow_io___CompressedInputStream__Make, 2}, { "_arrow_ExecPlan_create", (DL_FUNC) &_arrow_ExecPlan_create, 1}, + { "_arrow_ExecPlanReader__batches", (DL_FUNC) &_arrow_ExecPlanReader__batches, 1}, + { "_arrow_Table__from_ExecPlanReader", (DL_FUNC) &_arrow_Table__from_ExecPlanReader, 1}, + { "_arrow_ExecPlanReader__Plan", (DL_FUNC) &_arrow_ExecPlanReader__Plan, 1}, + { "_arrow_ExecPlanReader__PlanStatus", (DL_FUNC) &_arrow_ExecPlanReader__PlanStatus, 1}, { "_arrow_ExecPlan_run", (DL_FUNC) &_arrow_ExecPlan_run, 5}, - { "_arrow_ExecPlan_read_table", (DL_FUNC) &_arrow_ExecPlan_read_table, 5}, - { "_arrow_ExecPlan_StopProducing", (DL_FUNC) &_arrow_ExecPlan_StopProducing, 1}, + { "_arrow_ExecPlan_ToString", (DL_FUNC) &_arrow_ExecPlan_ToString, 1}, { "_arrow_ExecNode_output_schema", (DL_FUNC) &_arrow_ExecNode_output_schema, 1}, - { "_arrow_ExecPlan_BuildAndShow", (DL_FUNC) &_arrow_ExecPlan_BuildAndShow, 4}, { "_arrow_ExecNode_Scan", (DL_FUNC) &_arrow_ExecNode_Scan, 4}, { "_arrow_ExecPlan_Write", (DL_FUNC) &_arrow_ExecPlan_Write, 14}, { "_arrow_ExecNode_Filter", (DL_FUNC) &_arrow_ExecNode_Filter, 2}, @@ -5617,6 +5636,7 @@ static const R_CallMethodDef CallEntries[] = { { "_arrow_RecordBatch__from_arrays", (DL_FUNC) &_arrow_RecordBatch__from_arrays, 2}, { "_arrow_RecordBatch__ReferencedBufferSize", (DL_FUNC) &_arrow_RecordBatch__ReferencedBufferSize, 1}, { "_arrow_RecordBatchReader__schema", (DL_FUNC) &_arrow_RecordBatchReader__schema, 1}, + { "_arrow_RecordBatchReader__Close", (DL_FUNC) 
&_arrow_RecordBatchReader__Close, 1}, { "_arrow_RecordBatchReader__ReadNext", (DL_FUNC) &_arrow_RecordBatchReader__ReadNext, 1}, { "_arrow_RecordBatchReader__batches", (DL_FUNC) &_arrow_RecordBatchReader__batches, 1}, { "_arrow_RecordBatchReader__from_batches", (DL_FUNC) &_arrow_RecordBatchReader__from_batches, 2}, diff --git a/r/src/arrow_types.h b/r/src/arrow_types.h index d9fee37e7f1..dd0dc24449e 100644 --- a/r/src/arrow_types.h +++ b/r/src/arrow_types.h @@ -58,6 +58,8 @@ class ExecNode; } // namespace compute } // namespace arrow +class ExecPlanReader; + #if defined(ARROW_R_WITH_PARQUET) #include #endif diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index abcb418a2c2..71dc6d8b2e1 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -56,118 +56,156 @@ std::shared_ptr MakeExecNodeOrStop( }); } -std::pair, std::shared_ptr> -ExecPlan_prepare(const std::shared_ptr& plan, - const std::shared_ptr& final_node, - cpp11::list sort_options, cpp11::strings metadata, int64_t head = -1) { - // a section of this code is copied and used in ExecPlan_BuildAndShow - the 2 need - // to be in sync - // Start of chunk used in ExecPlan_BuildAndShow +// This class is a special RecordBatchReader that holds a reference to the +// underlying exec plan so that (1) it can request that the ExecPlan *stop* +// producing when this object is deleted and (2) it can defer requesting +// the ExecPlan to *start* producing until the first batch has been pulled. +// This allows it to be transformed (e.g., using map_batches() or head()) +// and queried (i.e., used as input to another ExecPlan), at the R level +// while maintaining the ability for the entire plan to be executed at once +// (e.g., to support user-defined functions) or never executed at all (e.g., +// to support printing a nested ExecPlan without having to execute it). +class ExecPlanReader : public arrow::RecordBatchReader { + public: + enum ExecPlanReaderStatus { PLAN_NOT_STARTED, PLAN_RUNNING, PLAN_FINISHED }; + + ExecPlanReader(const std::shared_ptr& plan, + const std::shared_ptr& schema, + arrow::AsyncGenerator> sink_gen) + : schema_(schema), plan_(plan), sink_gen_(sink_gen), status_(PLAN_NOT_STARTED) {} + + std::string PlanStatus() const { + switch (status_) { + case PLAN_NOT_STARTED: + return "PLAN_NOT_STARTED"; + case PLAN_RUNNING: + return "PLAN_RUNNING"; + case PLAN_FINISHED: + return "PLAN_FINISHED"; + default: + return "UNKNOWN"; + } + } - // For now, don't require R to construct SinkNodes. - // Instead, just pass the node we should collect as an argument. 
- arrow::AsyncGenerator> sink_gen; + std::shared_ptr schema() const override { return schema_; } - // Sorting uses a different sink node; there is no general sort yet - if (sort_options.size() > 0) { - if (head >= 0) { - // Use the SelectK node to take only what we need - MakeExecNodeOrStop( - "select_k_sink", plan.get(), {final_node.get()}, - compute::SelectKSinkNodeOptions{ - arrow::compute::SelectKOptions( - head, std::dynamic_pointer_cast( - make_compute_options("sort_indices", sort_options)) - ->sort_keys), - &sink_gen}); + arrow::Status ReadNext(std::shared_ptr* batch_out) override { + // TODO(ARROW-11841) check a StopToken to potentially cancel this plan + + // If this is the first batch getting pulled, tell the exec plan to + // start producing + if (status_ == PLAN_NOT_STARTED) { + ARROW_RETURN_NOT_OK(StartProducing()); + } + + // If we've closed the reader, keep sending nullptr + // (consistent with what most RecordBatchReader subclasses do) + if (status_ == PLAN_FINISHED) { + batch_out->reset(); + return arrow::Status::OK(); + } + + auto out = sink_gen_().result(); + if (!out.ok()) { + StopProducing(); + return out.status(); + } + + if (out.ValueUnsafe()) { + auto batch_result = out.ValueUnsafe()->ToRecordBatch(schema_, gc_memory_pool()); + if (!batch_result.ok()) { + StopProducing(); + return batch_result.status(); + } + + *batch_out = batch_result.ValueUnsafe(); } else { - MakeExecNodeOrStop("order_by_sink", plan.get(), {final_node.get()}, - compute::OrderBySinkNodeOptions{ - *std::dynamic_pointer_cast( - make_compute_options("sort_indices", sort_options)), - &sink_gen}); + batch_out->reset(); + StopProducing(); } - } else { - MakeExecNodeOrStop("sink", plan.get(), {final_node.get()}, - compute::SinkNodeOptions{&sink_gen}); + + return arrow::Status::OK(); } - // End of chunk used in ExecPlan_BuildAndShow + arrow::Status Close() override { + StopProducing(); + return arrow::Status::OK(); + } - StopIfNotOk(plan->Validate()); + const std::shared_ptr& Plan() const { return plan_; } - // If the generator is destroyed before being completely drained, inform plan - std::shared_ptr stop_producing{nullptr, [plan](...) { - bool not_finished_yet = - plan->finished().TryAddCallback([&plan] { - return [plan](const arrow::Status&) {}; - }); + ~ExecPlanReader() { StopProducing(); } - if (not_finished_yet) { - plan->StopProducing(); - } - }}; + private: + std::shared_ptr schema_; + std::shared_ptr plan_; + arrow::AsyncGenerator> sink_gen_; + int status_; - // Attach metadata to the schema - auto out_schema = final_node->output_schema(); - if (metadata.size() > 0) { - auto kv = strings_to_kvm(metadata); - out_schema = out_schema->WithMetadata(kv); + arrow::Status StartProducing() { + ARROW_RETURN_NOT_OK(plan_->StartProducing()); + status_ = PLAN_RUNNING; + return arrow::Status::OK(); } - std::pair, std::shared_ptr> - out; - out.first = plan; - out.second = compute::MakeGeneratorReader( - out_schema, [stop_producing, plan, sink_gen] { return sink_gen(); }, - gc_memory_pool()); - return out; -} + void StopProducing() { + if (status_ == PLAN_RUNNING) { + // We're done with the plan, but it may still need some time + // to finish and clean up after itself. To do this, we give a + // callable with its own copy of the shared_ptr so + // that it can delete itself when it is safe to do so. 
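+      // Note that if TryAddCallback() returns false, the plan has already
+      // finished on its own and there is nothing left to stop; otherwise the
+      // callback's copy of the shared_ptr keeps the plan alive until it
+      // reports that it is finished.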
+ std::shared_ptr plan(plan_); + bool not_finished_yet = plan_->finished().TryAddCallback( + [&plan] { return [plan](const arrow::Status&) {}; }); + + if (not_finished_yet) { + plan_->StopProducing(); + } + } + + status_ = PLAN_FINISHED; + plan_.reset(); + sink_gen_ = arrow::MakeEmptyGenerator>(); + } +}; // [[arrow::export]] -std::shared_ptr ExecPlan_run( - const std::shared_ptr& plan, - const std::shared_ptr& final_node, cpp11::list sort_options, - cpp11::strings metadata, int64_t head = -1) { - auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); - StopIfNotOk(prepared_plan.first->StartProducing()); - return prepared_plan.second; +cpp11::list ExecPlanReader__batches( + const std::shared_ptr& reader) { + auto result = RunWithCapturedRIfPossible( + [&]() { return reader->ToRecordBatches(); }); + return arrow::r::to_r_list(ValueOrStop(result)); } // [[arrow::export]] -std::shared_ptr ExecPlan_read_table( - const std::shared_ptr& plan, - const std::shared_ptr& final_node, cpp11::list sort_options, - cpp11::strings metadata, int64_t head = -1) { - auto prepared_plan = ExecPlan_prepare(plan, final_node, sort_options, metadata, head); - +std::shared_ptr Table__from_ExecPlanReader( + const std::shared_ptr& reader) { auto result = RunWithCapturedRIfPossible>( - [&]() -> arrow::Result> { - ARROW_RETURN_NOT_OK(prepared_plan.first->StartProducing()); - return prepared_plan.second->ToTable(); - }); + [&]() { return reader->ToTable(); }); return ValueOrStop(result); } // [[arrow::export]] -void ExecPlan_StopProducing(const std::shared_ptr& plan) { - plan->StopProducing(); +std::shared_ptr ExecPlanReader__Plan( + const std::shared_ptr& reader) { + if (reader->PlanStatus() == "PLAN_FINISHED") { + cpp11::stop("Can't extract ExecPlan from a finished ExecPlanReader"); + } + + return reader->Plan(); } // [[arrow::export]] -std::shared_ptr ExecNode_output_schema( - const std::shared_ptr& node) { - return node->output_schema(); +std::string ExecPlanReader__PlanStatus(const std::shared_ptr& reader) { + return reader->PlanStatus(); } // [[arrow::export]] -std::string ExecPlan_BuildAndShow(const std::shared_ptr& plan, - const std::shared_ptr& final_node, - cpp11::list sort_options, int64_t head = -1) { - // a section of this code is copied from ExecPlan_prepare - the 2 need to be in sync - // Start of chunk copied from ExecPlan_prepare - +std::shared_ptr ExecPlan_run( + const std::shared_ptr& plan, + const std::shared_ptr& final_node, cpp11::list sort_options, + cpp11::strings metadata, int64_t head = -1) { // For now, don't require R to construct SinkNodes. // Instead, just pass the node we should collect as an argument. 
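+  // Note: unlike the previous ExecPlan_run(), this function no longer
+  // starts the plan here; it wraps the sink generator in an
+  // ExecPlanReader, which defers StartProducing() until the first
+  // batch is pulled.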
arrow::AsyncGenerator> sink_gen; @@ -196,11 +234,29 @@ std::string ExecPlan_BuildAndShow(const std::shared_ptr& plan compute::SinkNodeOptions{&sink_gen}); } - // End of chunk copied from ExecPlan_prepare + StopIfNotOk(plan->Validate()); + + // Attach metadata to the schema + auto out_schema = final_node->output_schema(); + if (metadata.size() > 0) { + auto kv = strings_to_kvm(metadata); + out_schema = out_schema->WithMetadata(kv); + } + + return std::make_shared(plan, out_schema, sink_gen); +} +// [[arrow::export]] +std::string ExecPlan_ToString(const std::shared_ptr& plan) { return plan->ToString(); } +// [[arrow::export]] +std::shared_ptr ExecNode_output_schema( + const std::shared_ptr& node) { + return node->output_schema(); +} + #if defined(ARROW_R_WITH_DATASET) #include diff --git a/r/src/recordbatchreader.cpp b/r/src/recordbatchreader.cpp index c571d282da1..d0c52acc416 100644 --- a/r/src/recordbatchreader.cpp +++ b/r/src/recordbatchreader.cpp @@ -27,6 +27,11 @@ std::shared_ptr RecordBatchReader__schema( return reader->schema(); } +// [[arrow::export]] +void RecordBatchReader__Close(const std::shared_ptr& reader) { + return arrow::StopIfNotOk(reader->Close()); +} + // [[arrow::export]] std::shared_ptr RecordBatchReader__ReadNext( const std::shared_ptr& reader) { @@ -111,19 +116,77 @@ std::shared_ptr Table__from_RecordBatchReader( return ValueOrStop(reader->ToTable()); } +// Because the head() operation can leave a RecordBatchReader whose contents +// will never be drained, we implement a wrapper class here that takes care +// to (1) return only the requested number of rows (or fewer) and (2) Close +// and release the underlying reader as soon as possible. This is mostly +// useful for the ExecPlanReader, whose Close() method also requests +// that the ExecPlan stop producing, but may also be useful for readers +// that point to an open file and whose Close() or delete method releases +// the file. 
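+// Rough sketch of the behavior: ReadNext() passes batches through,
+// slicing the final batch if it would overshoot the requested number of
+// rows, and Close()s the underlying reader as soon as the requested rows
+// have been delivered (or the input runs out of batches early).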
+class RecordBatchReaderHead : public arrow::RecordBatchReader { + public: + RecordBatchReaderHead(std::shared_ptr reader, + int64_t num_rows) + : schema_(reader->schema()), reader_(reader), num_rows_(num_rows) {} + + std::shared_ptr schema() const override { return schema_; } + + arrow::Status ReadNext(std::shared_ptr* batch_out) override { + if (!reader_) { + // Close() has been called + batch_out = nullptr; + return arrow::Status::OK(); + } + + ARROW_RETURN_NOT_OK(reader_->ReadNext(batch_out)); + if (batch_out->get()) { + num_rows_ -= batch_out->get()->num_rows(); + if (num_rows_ < 0) { + auto smaller_batch = + batch_out->get()->Slice(0, batch_out->get()->num_rows() + num_rows_); + *batch_out = smaller_batch; + } + + if (num_rows_ <= 0) { + // We've run out of num_rows before batches + ARROW_RETURN_NOT_OK(Close()); + } + } else { + // We've run out of batches before num_rows + ARROW_RETURN_NOT_OK(Close()); + } + + return arrow::Status::OK(); + } + + arrow::Status Close() override { + if (reader_) { + arrow::Status result = reader_->Close(); + reader_.reset(); + return result; + } else { + return arrow::Status::OK(); + } + } + + private: + std::shared_ptr schema_; + std::shared_ptr reader_; + int64_t num_rows_; +}; + // [[arrow::export]] std::shared_ptr RecordBatchReader__Head( const std::shared_ptr& reader, int64_t num_rows) { - std::vector> batches; - std::shared_ptr this_batch; - while (num_rows > 0) { - this_batch = ValueOrStop(reader->Next()); - if (this_batch == nullptr) break; - batches.push_back(this_batch->Slice(0, num_rows)); - num_rows -= this_batch->num_rows(); + if (num_rows <= 0) { + // If we are never going to pull any batches from this reader, close it + // immediately. + StopIfNotOk(reader->Close()); + return ValueOrStop(arrow::RecordBatchReader::Make({}, reader->schema())); + } else { + return std::make_shared(reader, num_rows); } - return ValueOrStop( - arrow::RecordBatchReader::Make(std::move(batches), reader->schema())); } // -------- RecordBatchStreamReader diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 5821c0fa2df..11c37519ae5 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -91,11 +91,7 @@ test_that("register_scalar_function() adds a compute function to the registry", int32(), float64(), auto_convert = TRUE ) - on.exit({ - unregister_binding("times_32", update_cache = TRUE) - # TODO(ARROW-17178) remove the need for this! - Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF") - }) + on.exit(unregister_binding("times_32", update_cache = TRUE)) expect_true("times_32" %in% names(asNamespace("arrow")$.cache$functions)) expect_true("times_32" %in% list_compute_functions()) @@ -127,11 +123,7 @@ test_that("arrow_scalar_function() with bad return type errors", { int32(), float64() ) - on.exit({ - unregister_binding("times_32_bad_return_type_array", update_cache = TRUE) - # TODO(ARROW-17178) remove the need for this! - Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF") - }) + on.exit(unregister_binding("times_32_bad_return_type_array", update_cache = TRUE)) expect_error( call_function("times_32_bad_return_type_array", Array$create(1L)), @@ -144,11 +136,7 @@ test_that("arrow_scalar_function() with bad return type errors", { int32(), float64() ) - on.exit({ - unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE) - # TODO(ARROW-17178) remove the need for this! 
- Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF") - }) + on.exit(unregister_binding("times_32_bad_return_type_scalar", update_cache = TRUE)) expect_error( call_function("times_32_bad_return_type_scalar", Array$create(1L)), @@ -166,11 +154,7 @@ test_that("register_scalar_function() can register multiple kernels", { out_type = function(in_types) in_types[[1]], auto_convert = TRUE ) - on.exit({ - unregister_binding("times_32", update_cache = TRUE) - # TODO(ARROW-17178) remove the need for this! - Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF") - }) + on.exit(unregister_binding("times_32", update_cache = TRUE)) expect_equal( call_function("times_32", Scalar$create(1L, int32())), @@ -189,9 +173,6 @@ test_that("register_scalar_function() can register multiple kernels", { }) test_that("register_scalar_function() errors for unsupported specifications", { - # TODO(ARROW-17178) remove the need for this! - on.exit(Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF")) - expect_error( register_scalar_function( "no_kernels", @@ -256,11 +237,7 @@ test_that("user-defined functions work during multi-threaded execution", { float64(), auto_convert = TRUE ) - on.exit({ - unregister_binding("times_32", update_cache = TRUE) - # TODO(ARROW-17178) remove the need for this! - Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF") - }) + on.exit(unregister_binding("times_32", update_cache = TRUE)) # check a regular collect() result <- open_dataset(tf_dataset) %>% @@ -282,7 +259,7 @@ test_that("user-defined functions work during multi-threaded execution", { expect_identical(result2$fun_result, example_df$value * 32) }) -test_that("user-defined error when called from an unsupported context", { +test_that("nested exec plans can contain user-defined functions", { skip_if_not_available("dataset") skip_if_not(CanRunWithCapturedR()) @@ -293,11 +270,7 @@ test_that("user-defined error when called from an unsupported context", { float64(), auto_convert = TRUE ) - on.exit({ - unregister_binding("times_32", update_cache = TRUE) - # TODO(ARROW-17178) remove the need for this! 
- Sys.unsetenv("R_ARROW_COLLECT_WITH_UDF") - }) + on.exit(unregister_binding("times_32", update_cache = TRUE)) stream_plan_with_udf <- function() { record_batch(a = 1:1000) %>% @@ -313,24 +286,35 @@ test_that("user-defined error when called from an unsupported context", { dplyr::collect() } - if (identical(tolower(Sys.info()[["sysname"]]), "windows")) { - expect_equal( - stream_plan_with_udf(), - record_batch(a = 1:1000) %>% - dplyr::mutate(b = times_32(a)) %>% - dplyr::collect(as_data_frame = FALSE) - ) - - result <- collect_plan_with_head() - expect_equal(nrow(result), 11) - } else { - expect_error( - stream_plan_with_udf(), - "Call to R \\(.*?\\) from a non-R thread from an unsupported context" - ) - expect_error( - collect_plan_with_head(), - "Call to R \\(.*?\\) from a non-R thread from an unsupported context" - ) - } + expect_equal( + stream_plan_with_udf(), + record_batch(a = 1:1000) %>% + dplyr::mutate(b = times_32(a)) %>% + dplyr::collect(as_data_frame = FALSE) + ) + + result <- collect_plan_with_head() + expect_equal(nrow(result), 11) +}) + +test_that("head() on exec plan containing user-defined functions", { + skip_if_not_available("dataset") + skip_if_not(CanRunWithCapturedR()) + + register_scalar_function( + "times_32", + function(context, x) x * 32.0, + int32(), + float64(), + auto_convert = TRUE + ) + on.exit(unregister_binding("times_32", update_cache = TRUE)) + + result <- record_batch(a = 1:1000) %>% + dplyr::mutate(b = times_32(a)) %>% + as_record_batch_reader() %>% + head(11) %>% + dplyr::collect() + + expect_equal(nrow(result), 11) }) diff --git a/r/tests/testthat/test-query-engine.R b/r/tests/testthat/test-query-engine.R index dd87335f876..f2190eb6684 100644 --- a/r/tests/testthat/test-query-engine.R +++ b/r/tests/testthat/test-query-engine.R @@ -17,6 +17,66 @@ library(dplyr, warn.conflicts = FALSE) +test_that("ExecPlanReader does not start evaluating a query", { + rbr <- as_record_batch_reader( + function(x) stop("This query will error if started"), + schema = schema(a = int32()) + ) + + reader <- as_record_batch_reader(as_adq(rbr)) + expect_identical(reader$PlanStatus(), "PLAN_NOT_STARTED") + expect_error(reader$read_table(), "This query will error if started") + expect_identical(reader$PlanStatus(), "PLAN_FINISHED") +}) + +test_that("ExecPlanReader evaluates nested exec plans lazily", { + reader <- as_record_batch_reader(as_adq(arrow_table(a = 1:10))) + expect_identical(reader$PlanStatus(), "PLAN_NOT_STARTED") + + head_reader <- head(reader, 4) + expect_identical(reader$PlanStatus(), "PLAN_NOT_STARTED") + + expect_equal( + head_reader$read_table(), + arrow_table(a = 1:4) + ) + + expect_identical(reader$PlanStatus(), "PLAN_FINISHED") +}) + +test_that("ExecPlanReader evaluates head() lazily", { + reader <- as_record_batch_reader(as_adq(arrow_table(a = 1:10))) + expect_identical(reader$PlanStatus(), "PLAN_NOT_STARTED") + + head_reader <- head(reader, 4) + expect_identical(reader$PlanStatus(), "PLAN_NOT_STARTED") + + expect_equal( + head_reader$read_table(), + arrow_table(a = 1:4) + ) + + expect_identical(reader$PlanStatus(), "PLAN_FINISHED") +}) + +test_that("ExecPlanReader evaluates head() lazily", { + # Make a rather long RecordBatchReader + reader <- RecordBatchReader$create( + batches = rep( + list(record_batch(line = letters)), + 100L + ) + ) + + # ...But only get 10 rows from it + query <- head(as_adq(reader), 10) + expect_identical(as_arrow_table(query)$num_rows, 10L) + + # Depending on exactly how quickly background threads respond to the + # request to 
cancel, reader$read_table()$num_rows > 0 may or may not + # evaluate to TRUE (i.e., the reader may or may not be completely drained). +}) + test_that("do_exec_plan_substrait can evaluate a simple plan", { skip_if_not_available("substrait") From 700f42fd684db2160ae3c0ed07a17960e1790c6d Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 15 Sep 2022 21:04:04 -0700 Subject: [PATCH 077/133] MINOR: [C++] Fix typo in dataset arg validation (#14150) Authored-by: Will Jones Signed-off-by: Yibo Cai --- cpp/src/arrow/dataset/dataset_writer.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/dataset/dataset_writer.cc b/cpp/src/arrow/dataset/dataset_writer.cc index 0f81fad3071..bad363b3818 100644 --- a/cpp/src/arrow/dataset/dataset_writer.cc +++ b/cpp/src/arrow/dataset/dataset_writer.cc @@ -442,7 +442,8 @@ Status ValidateOptions(const FileSystemDatasetWriteOptions& options) { return Status::Invalid("max_rows_per_group must be a positive number"); } if (options.max_rows_per_group < options.min_rows_per_group) { - return Status::Invalid("max_rows_per_group must be less than min_rows_per_group"); + return Status::Invalid( + "min_rows_per_group must be less than or equal to max_rows_per_group"); } if (options.max_rows_per_file > 0 && options.max_rows_per_file < options.max_rows_per_group) { From 5e6da78a4f696a638ea93cb69ff89a8e070c0105 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Fri, 16 Sep 2022 13:44:21 +0530 Subject: [PATCH 078/133] ARROW-16652: [Python] Cast compute kernel segfaults when called with a Table (#14044) Repro Script: ```python import pyarrow as pa import pyarrow.compute as pc table = pa.table({'a': [1, 2]}) pc.cast(table, pa.int64()) ``` Post PR we get, ``` pyarrow.lib.ArrowInvalid: Tried executing function with non-value type: Table ``` which is the same error we get if we replace `pc.cast(table, pa.int64())` with `pc.abs(table)`. Lead-authored-by: Kshiteej K Co-authored-by: kshitij12345 Signed-off-by: Alenka Frim --- cpp/src/arrow/compute/cast.cc | 4 +++- python/pyarrow/tests/test_compute.py | 7 +++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/cast.cc b/cpp/src/arrow/compute/cast.cc index 52aecf3e45a..a57c5e2c79a 100644 --- a/cpp/src/arrow/compute/cast.cc +++ b/cpp/src/arrow/compute/cast.cc @@ -94,7 +94,9 @@ class CastMetaFunction : public MetaFunction { const FunctionOptions* options, ExecContext* ctx) const override { ARROW_ASSIGN_OR_RAISE(auto cast_options, ValidateOptions(options)); - if (args[0].type()->Equals(*cast_options->to_type)) { + // args[0].type() could be a nullptr so check for that before + // we do anything with it. 
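+    // (As in the repro above, args[0] may be a Datum wrapping a Table,
+    // for which type() is null; such inputs now fall through to the
+    // regular execution path below, which raises Invalid rather than
+    // segfaulting.)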
+ if (args[0].type() && args[0].type()->Equals(*cast_options->to_type)) { return args[0]; } Result> result = diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index f2820b6e25f..c41c4fa9b3b 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -2889,3 +2889,10 @@ def test_expression_call_function(): with pytest.raises(TypeError): pc.add(1, field) + + +def test_cast_table_raises(): + table = pa.table({'a': [1, 2]}) + + with pytest.raises(pa.lib.ArrowInvalid): + pc.cast(table, pa.int64()) From b48d2287bef95ed195f6e3721dd34f97fd1735c2 Mon Sep 17 00:00:00 2001 From: andreoss Date: Fri, 16 Sep 2022 12:52:34 +0000 Subject: [PATCH 079/133] ARROW-17704: [Java][FlightRPC] Update to Junit 5 (#14103) Update flight module to Junit 5 ~~Relevant issue: https://issues.apache.org/jira/browse/ARROW-4740~~ See: https://issues.apache.org/jira/browse/ARROW-17704 Authored-by: andreoss Signed-off-by: David Li --- .../apache/arrow/flight/FlightTestUtil.java | 4 +- .../arrow/flight/TestApplicationMetadata.java | 34 +- .../org/apache/arrow/flight/TestAuth.java | 83 +-- .../apache/arrow/flight/TestBackPressure.java | 25 +- .../arrow/flight/TestBasicOperation.java | 82 +-- .../apache/arrow/flight/TestCallOptions.java | 26 +- .../arrow/flight/TestClientMiddleware.java | 35 +- .../arrow/flight/TestDictionaryUtils.java | 2 +- .../apache/arrow/flight/TestDoExchange.java | 40 +- .../arrow/flight/TestErrorMetadata.java | 28 +- .../apache/arrow/flight/TestFlightClient.java | 39 +- .../arrow/flight/TestFlightService.java | 10 +- .../apache/arrow/flight/TestLargeMessage.java | 6 +- .../org/apache/arrow/flight/TestLeak.java | 2 +- .../arrow/flight/TestMetadataVersion.java | 20 +- .../arrow/flight/TestServerMiddleware.java | 81 ++- .../arrow/flight/TestServerOptions.java | 37 +- .../java/org/apache/arrow/flight/TestTls.java | 12 +- .../arrow/flight/auth/TestBasicAuth.java | 24 +- .../arrow/flight/auth2/TestBasicAuth2.java | 32 +- .../flight/client/TestCookieHandling.java | 52 +- .../arrow/flight/grpc/TestStatusUtils.java | 18 +- .../apache/arrow/flight/perf/TestPerf.java | 5 +- .../arrow/flight/TestFlightGrpcUtils.java | 56 +- .../apache/arrow/flight/TestFlightSql.java | 557 +++++++++++------- ...qlInfoOptionsUtilsBitmaskCreationTest.java | 29 +- ...SqlInfoOptionsUtilsBitmaskParsingTest.java | 27 +- java/performance/pom.xml | 2 +- java/pom.xml | 31 +- 29 files changed, 746 insertions(+), 653 deletions(-) diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java index cd043b639b0..a0eb80daca6 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/FlightTestUtil.java @@ -28,7 +28,7 @@ import java.util.Random; import java.util.function.Function; -import org.junit.Assert; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.function.Executable; @@ -130,7 +130,7 @@ static boolean isNativeTransportAvailable() { */ public static CallStatus assertCode(FlightStatusCode code, Executable r) { final FlightRuntimeException ex = Assertions.assertThrows(FlightRuntimeException.class, r); - Assert.assertEquals(code, ex.status().code()); + Assertions.assertEquals(code, ex.status().code()); return ex.status(); } diff --git 
a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java index c7b3321af01..fb0345b134e 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestApplicationMetadata.java @@ -32,9 +32,9 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; /** * Tests for application-specific metadata support in Flight. @@ -51,16 +51,16 @@ public class TestApplicationMetadata { */ @Test // This test is consistently flaky on CI, unfortunately. - @Ignore + @Disabled public void retrieveMetadata() { test((allocator, client) -> { try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { byte i = 0; while (stream.next()) { final IntVector vector = (IntVector) stream.getRoot().getVector("a"); - Assert.assertEquals(1, vector.getValueCount()); - Assert.assertEquals(10, vector.get(0)); - Assert.assertEquals(i, stream.getLatestMetadata().getByte(0)); + Assertions.assertEquals(1, vector.getValueCount()); + Assertions.assertEquals(10, vector.get(0)); + Assertions.assertEquals(i, stream.getLatestMetadata().getByte(0)); i++; } } catch (Exception e) { @@ -81,7 +81,7 @@ public void arrow6136() { final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, listener); // Must attempt to retrieve the result to get any server-side errors. final CallStatus status = FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, writer::getResult); - Assert.assertEquals(MESSAGE_ARROW_6136, status.description()); + Assertions.assertEquals(MESSAGE_ARROW_6136, status.description()); } catch (Exception e) { throw new RuntimeException(e); } @@ -92,7 +92,7 @@ public void arrow6136() { * Ensure that a client can send metadata to the server. */ @Test - @Ignore + @Disabled public void uploadMetadataAsync() { final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); test((allocator, client) -> { @@ -104,8 +104,8 @@ public void uploadMetadataAsync() { @Override public void onNext(PutResult val) { - Assert.assertNotNull(val); - Assert.assertEquals(counter, val.getApplicationMetadata().getByte(0)); + Assertions.assertNotNull(val); + Assertions.assertEquals(counter, val.getApplicationMetadata().getByte(0)); counter++; } }; @@ -134,7 +134,7 @@ public void onNext(PutResult val) { * Ensure that a client can send metadata to the server. Uses the synchronous API. 
*/ @Test - @Ignore + @Disabled public void uploadMetadataSync() { final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); test((allocator, client) -> { @@ -153,8 +153,8 @@ public void uploadMetadataSync() { root.setRowCount(1); writer.putNext(metadata); try (final PutResult message = listener.poll(5000, TimeUnit.SECONDS)) { - Assert.assertNotNull(message); - Assert.assertEquals(i, message.getApplicationMetadata().getByte(0)); + Assertions.assertNotNull(message); + Assertions.assertEquals(i, message.getApplicationMetadata().getByte(0)); } catch (InterruptedException | ExecutionException e) { throw new RuntimeException(e); } @@ -170,7 +170,7 @@ public void uploadMetadataSync() { * Make sure that a {@link SyncPutListener} properly reclaims memory if ignored. */ @Test - @Ignore + @Disabled public void syncMemoryReclaimed() { final Schema schema = new Schema(Collections.singletonList(Field.nullable("a", new ArrowType.Int(32, true)))); test((allocator, client) -> { @@ -216,10 +216,10 @@ public void testMetadataEndianness() throws Exception { final FlightClient.ClientStreamListener writer = client.startPut(descriptor, root, reader); writer.completed(); try (final PutResult metadata = reader.read()) { - Assert.assertEquals(16, metadata.getApplicationMetadata().readableBytes()); + Assertions.assertEquals(16, metadata.getApplicationMetadata().readableBytes()); byte[] bytes = new byte[16]; metadata.getApplicationMetadata().readBytes(bytes); - Assert.assertArrayEquals(EndianFlightProducer.EXPECTED_BYTES, bytes); + Assertions.assertArrayEquals(EndianFlightProducer.EXPECTED_BYTES, bytes); } writer.getResult(); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java index 6f0ec9f0255..0da49c906fc 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestAuth.java @@ -24,56 +24,61 @@ import org.apache.arrow.flight.auth.ServerAuthHandler; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; public class TestAuth { /** An auth handler that does not send messages should not block the server forever. 
*/ - @Test(expected = RuntimeException.class) + @Test public void noMessages() throws Exception { - try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - final FlightServer s = FlightTestUtil - .getStartedServer( - location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( - new OneshotAuthHandler()).build()); - final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { - client.authenticate(new ClientAuthHandler() { - @Override - public void authenticate(ClientAuthSender outgoing, Iterator incoming) { - } + Assertions.assertThrows(RuntimeException.class, () -> { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer s = FlightTestUtil + .getStartedServer( + location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( + new OneshotAuthHandler()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + client.authenticate(new ClientAuthHandler() { + @Override + public void authenticate(ClientAuthSender outgoing, Iterator incoming) { + } - @Override - public byte[] getCallToken() { - return new byte[0]; - } - }); - } + @Override + public byte[] getCallToken() { + return new byte[0]; + } + }); + } + }); } /** An auth handler that sends an error should not block the server forever. */ - @Test(expected = RuntimeException.class) + @Test public void clientError() throws Exception { - try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - final FlightServer s = FlightTestUtil - .getStartedServer( - location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( - new OneshotAuthHandler()).build()); - final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { - client.authenticate(new ClientAuthHandler() { - @Override - public void authenticate(ClientAuthSender outgoing, Iterator incoming) { - outgoing.send(new byte[0]); - // Ensure the server-side runs - incoming.next(); - outgoing.onError(new RuntimeException("test")); - } + Assertions.assertThrows(RuntimeException.class, () -> { + try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); + final FlightServer s = FlightTestUtil + .getStartedServer( + location -> FlightServer.builder(allocator, location, new NoOpFlightProducer()).authHandler( + new OneshotAuthHandler()).build()); + final FlightClient client = FlightClient.builder(allocator, s.getLocation()).build()) { + client.authenticate(new ClientAuthHandler() { + @Override + public void authenticate(ClientAuthSender outgoing, Iterator incoming) { + outgoing.send(new byte[0]); + // Ensure the server-side runs + incoming.next(); + outgoing.onError(new RuntimeException("test")); + } - @Override - public byte[] getCallToken() { - return new byte[0]; - } - }); - } + @Override + public byte[] getCallToken() { + return new byte[0]; + } + }); + } + }); } private static class OneshotAuthHandler implements ServerAuthHandler { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java index 1a71c363e17..ae691f3ef90 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBackPressure.java @@ -30,9 +30,9 @@ import org.apache.arrow.vector.types.Types.MinorType; import 
org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import com.google.common.collect.ImmutableList; @@ -43,7 +43,7 @@ public class TestBackPressure { /** * Make sure that failing to consume one stream doesn't block other streams. */ - @Ignore + @Disabled @Test public void ensureIndependentSteams() throws Exception { ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location))); @@ -52,7 +52,7 @@ public void ensureIndependentSteams() throws Exception { /** * Make sure that failing to consume one stream doesn't block other streams. */ - @Ignore + @Disabled @Test public void ensureIndependentSteamsWithCallbacks() throws Exception { ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location, @@ -62,7 +62,7 @@ public void ensureIndependentSteamsWithCallbacks() throws Exception { /** * Test to make sure stream doesn't go faster than the consumer is consuming. */ - @Ignore + @Disabled @Test public void ensureWaitUntilProceed() throws Exception { ensureWaitUntilProceed(new PollingBackpressureStrategy(), false); @@ -72,7 +72,7 @@ public void ensureWaitUntilProceed() throws Exception { * Test to make sure stream doesn't go faster than the consumer is consuming using a callback-based * backpressure strategy. */ - @Ignore + @Disabled @Test public void ensureWaitUntilProceedWithCallbacks() throws Exception { ensureWaitUntilProceed(new RecordingCallbackBackpressureStrategy(), true); @@ -177,9 +177,14 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l root.clear(); } long expected = wait - epsilon; - Assert.assertTrue( - String.format("Expected a sleep of at least %dms but only slept for %d", expected, - bpStrategy.getSleepTime()), bpStrategy.getSleepTime() > expected); + Assertions.assertTrue( + bpStrategy.getSleepTime() > expected, + String.format( + "Expected a sleep of at least %dms but only slept for %d", + expected, + bpStrategy.getSleepTime() + ) + ); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java index e29cd07ced5..0a1d7f8a3f8 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestBasicOperation.java @@ -50,8 +50,8 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import com.google.common.base.Charsets; import com.google.protobuf.ByteString; @@ -65,8 +65,8 @@ public class TestBasicOperation { @Test public void fastPathDefaults() { - Assert.assertTrue(ArrowMessage.ENABLE_ZERO_COPY_READ); - Assert.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE); + Assertions.assertTrue(ArrowMessage.ENABLE_ZERO_COPY_READ); + Assertions.assertFalse(ArrowMessage.ENABLE_ZERO_COPY_WRITE); } /** @@ -75,7 +75,7 @@ public void fastPathDefaults() { @Test public void unknownScheme() throws URISyntaxException { final Location location = new Location("s3://unknown"); - Assert.assertEquals("s3", location.getUri().getScheme()); + 
Assertions.assertEquals("s3", location.getUri().getScheme()); } @Test @@ -83,7 +83,7 @@ public void unknownSchemeRemote() throws Exception { test(c -> { try { final FlightInfo info = c.getInfo(FlightDescriptor.path("test")); - Assert.assertEquals(new URI("https://example.com"), info.getEndpoints().get(0).getLocations().get(0).getUri()); + Assertions.assertEquals(new URI("https://example.com"), info.getEndpoints().get(0).getLocations().get(0).getUri()); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -93,7 +93,7 @@ public void unknownSchemeRemote() throws Exception { @Test public void roundTripTicket() throws Exception { final Ticket ticket = new Ticket(new byte[]{0, 1, 2, 3, 4, 5}); - Assert.assertEquals(ticket, Ticket.deserialize(ticket.serialize())); + Assertions.assertEquals(ticket, Ticket.deserialize(ticket.serialize())); } @Test @@ -116,17 +116,17 @@ public void roundTripInfo() throws Exception { Location.forGrpcInsecure("localhost", 50051)) ), 200, 500); - Assert.assertEquals(info1, FlightInfo.deserialize(info1.serialize())); - Assert.assertEquals(info2, FlightInfo.deserialize(info2.serialize())); - Assert.assertEquals(info3, FlightInfo.deserialize(info3.serialize())); + Assertions.assertEquals(info1, FlightInfo.deserialize(info1.serialize())); + Assertions.assertEquals(info2, FlightInfo.deserialize(info2.serialize())); + Assertions.assertEquals(info3, FlightInfo.deserialize(info3.serialize())); } @Test public void roundTripDescriptor() throws Exception { final FlightDescriptor cmd = FlightDescriptor.command("test command".getBytes(StandardCharsets.UTF_8)); - Assert.assertEquals(cmd, FlightDescriptor.deserialize(cmd.serialize())); + Assertions.assertEquals(cmd, FlightDescriptor.deserialize(cmd.serialize())); final FlightDescriptor path = FlightDescriptor.path("foo", "bar", "test.arrow"); - Assert.assertEquals(path, FlightDescriptor.deserialize(path.serialize())); + Assertions.assertEquals(path, FlightDescriptor.deserialize(path.serialize())); } @Test @@ -136,7 +136,7 @@ public void getDescriptors() throws Exception { for (FlightInfo i : c.listFlights(Criteria.ALL)) { count += 1; } - Assert.assertEquals(1, count); + Assertions.assertEquals(1, count); }); } @@ -147,7 +147,7 @@ public void getDescriptorsWithCriteria() throws Exception { for (FlightInfo i : c.listFlights(new Criteria(new byte[]{1}))) { count += 1; } - Assert.assertEquals(0, count); + Assertions.assertEquals(0, count); }); } @@ -180,21 +180,21 @@ public void doAction() throws Exception { test(c -> { Iterator stream = c.doAction(new Action("hello")); - Assert.assertTrue(stream.hasNext()); + Assertions.assertTrue(stream.hasNext()); Result r = stream.next(); - Assert.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody()); + Assertions.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody()); }); test(c -> { Iterator stream = c.doAction(new Action("hellooo")); - Assert.assertTrue(stream.hasNext()); + Assertions.assertTrue(stream.hasNext()); Result r = stream.next(); - Assert.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody()); + Assertions.assertArrayEquals("world".getBytes(Charsets.UTF_8), r.getBody()); - Assert.assertTrue(stream.hasNext()); + Assertions.assertTrue(stream.hasNext()); r = stream.next(); - Assert.assertArrayEquals("!".getBytes(Charsets.UTF_8), r.getBody()); - Assert.assertFalse(stream.hasNext()); + Assertions.assertArrayEquals("!".getBytes(Charsets.UTF_8), r.getBody()); + Assertions.assertFalse(stream.hasNext()); }); } @@ -240,7 +240,7 @@ public void 
putStream() throws Exception { public void propagateErrors() throws Exception { test(client -> { FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> { - client.doAction(new Action("invalid-action")).forEachRemaining(action -> Assert.fail()); + client.doAction(new Action("invalid-action")).forEachRemaining(action -> Assertions.fail()); }); }); } @@ -254,7 +254,7 @@ public void getStream() throws Exception { int value = 0; while (stream.next()) { for (int i = 0; i < root.getRowCount(); i++) { - Assert.assertEquals(value, iv.get(i)); + Assertions.assertEquals(value, iv.get(i)); value++; } } @@ -269,12 +269,12 @@ public void getStream() throws Exception { public void getStreamLargeBatch() throws Exception { test(c -> { try (final FlightStream stream = c.getStream(new Ticket(Producer.TICKET_LARGE_BATCH))) { - Assert.assertEquals(128, stream.getRoot().getFieldVectors().size()); - Assert.assertTrue(stream.next()); - Assert.assertEquals(65536, stream.getRoot().getRowCount()); - Assert.assertTrue(stream.next()); - Assert.assertEquals(65536, stream.getRoot().getRowCount()); - Assert.assertFalse(stream.next()); + Assertions.assertEquals(128, stream.getRoot().getFieldVectors().size()); + Assertions.assertTrue(stream.next()); + Assertions.assertEquals(65536, stream.getRoot().getRowCount()); + Assertions.assertTrue(stream.next()); + Assertions.assertEquals(65536, stream.getRoot().getRowCount()); + Assertions.assertFalse(stream.next()); } catch (Exception e) { throw new RuntimeException(e); } @@ -362,28 +362,28 @@ public void testProtobufRecordBatchCompatibility() throws Exception { final MethodDescriptor.Marshaller marshaller = ArrowMessage.createMarshaller(allocator); try (final ArrowMessage message = new ArrowMessage( unloader.getRecordBatch(), /* appMetadata */ null, /* tryZeroCopy */ false, IpcOption.DEFAULT)) { - Assert.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType()); + Assertions.assertEquals(ArrowMessage.HeaderType.RECORD_BATCH, message.getMessageType()); // Should have at least one empty body buffer (there may be multiple for e.g. 
data and validity) Iterator iterator = message.getBufs().iterator(); - Assert.assertTrue(iterator.hasNext()); + Assertions.assertTrue(iterator.hasNext()); while (iterator.hasNext()) { - Assert.assertEquals(0, iterator.next().capacity()); + Assertions.assertEquals(0, iterator.next().capacity()); } final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message) .toBuilder() .clearDataBody() .build(); - Assert.assertEquals(0, protobufData.getDataBody().size()); + Assertions.assertEquals(0, protobufData.getDataBody().size()); ArrowMessage parsedMessage = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray())); // Should have an empty body buffer Iterator parsedIterator = parsedMessage.getBufs().iterator(); - Assert.assertTrue(parsedIterator.hasNext()); - Assert.assertEquals(0, parsedIterator.next().capacity()); + Assertions.assertTrue(parsedIterator.hasNext()); + Assertions.assertEquals(0, parsedIterator.next().capacity()); // Should have only one (the parser synthesizes exactly one); in the case of empty buffers, this is equivalent - Assert.assertFalse(parsedIterator.hasNext()); + Assertions.assertFalse(parsedIterator.hasNext()); // Should not throw final ArrowRecordBatch rb = parsedMessage.asRecordBatch(); - Assert.assertEquals(rb.computeBodyLength(), 0); + Assertions.assertEquals(rb.computeBodyLength(), 0); } } } @@ -396,17 +396,17 @@ public void testProtobufSchemaCompatibility() throws Exception { final MethodDescriptor.Marshaller marshaller = ArrowMessage.createMarshaller(allocator); Flight.FlightDescriptor descriptor = FlightDescriptor.command(new byte[0]).toProtocol(); try (final ArrowMessage message = new ArrowMessage(descriptor, schema, IpcOption.DEFAULT)) { - Assert.assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); + Assertions.assertEquals(ArrowMessage.HeaderType.SCHEMA, message.getMessageType()); // Should have no body buffers - Assert.assertFalse(message.getBufs().iterator().hasNext()); + Assertions.assertFalse(message.getBufs().iterator().hasNext()); final Flight.FlightData protobufData = arrowMessageToProtobuf(marshaller, message) .toBuilder() .setDataBody(ByteString.EMPTY) .build(); - Assert.assertEquals(0, protobufData.getDataBody().size()); + Assertions.assertEquals(0, protobufData.getDataBody().size()); final ArrowMessage parsedMessage = marshaller.parse(new ByteArrayInputStream(protobufData.toByteArray())); // Should have no body buffers - Assert.assertFalse(parsedMessage.getBufs().iterator().hasNext()); + Assertions.assertFalse(parsedMessage.getBufs().iterator().hasNext()); // Should not throw parsedMessage.asSchema(); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java index d739189e080..adfa44ef9c8 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestCallOptions.java @@ -26,16 +26,16 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.junit.Assert; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import io.grpc.Metadata; public class TestCallOptions { @Test - @Ignore + @Disabled public void timeoutFires() { // Ignored due to CI flakiness test((client) -> { @@ -43,26 +43,26 @@ public void timeoutFires() 
{ Iterator results = client.doAction(new Action("hang"), CallOptions.timeout(1, TimeUnit.SECONDS)); try { results.next(); - Assert.fail("Call should have failed"); + Assertions.fail("Call should have failed"); } catch (RuntimeException e) { - Assert.assertTrue(e.getMessage(), e.getMessage().contains("deadline exceeded")); + Assertions.assertTrue(e.getMessage().contains("deadline exceeded"), e.getMessage()); } Instant end = Instant.now(); - Assert.assertTrue("Call took over 1500 ms despite timeout", Duration.between(start, end).toMillis() < 1500); + Assertions.assertTrue(Duration.between(start, end).toMillis() < 1500, "Call took over 1500 ms despite timeout"); }); } @Test - @Ignore + @Disabled public void underTimeout() { // Ignored due to CI flakiness test((client) -> { Instant start = Instant.now(); // This shouldn't fail and it should complete within the timeout Iterator results = client.doAction(new Action("fast"), CallOptions.timeout(2, TimeUnit.SECONDS)); - Assert.assertArrayEquals(new byte[]{42, 42}, results.next().getBody()); + Assertions.assertArrayEquals(new byte[]{42, 42}, results.next().getBody()); Instant end = Instant.now(); - Assert.assertTrue("Call took over 2500 ms despite timeout", Duration.between(start, end).toMillis() < 2500); + Assertions.assertTrue(Duration.between(start, end).toMillis() < 2500, "Call took over 2500 ms despite timeout"); }); } @@ -104,13 +104,13 @@ private void testHeaders(CallHeaders headers) { FlightServer s = FlightTestUtil.getStartedServer((location) -> FlightServer.builder(a, location, producer).build()); FlightClient client = FlightClient.builder(a, s.getLocation()).build()) { - Assert.assertFalse(client.doAction(new Action(""), new HeaderCallOption(headers)).hasNext()); + Assertions.assertFalse(client.doAction(new Action(""), new HeaderCallOption(headers)).hasNext()); final CallHeaders incomingHeaders = producer.headers(); for (String key : headers.keys()) { if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - Assert.assertArrayEquals(headers.getByte(key), incomingHeaders.getByte(key)); + Assertions.assertArrayEquals(headers.getByte(key), incomingHeaders.getByte(key)); } else { - Assert.assertEquals(headers.get(key), incomingHeaders.get(key)); + Assertions.assertEquals(headers.get(key), incomingHeaders.get(key)); } } } catch (InterruptedException | IOException e) { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java index f150a294aa4..a191a597f41 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestClientMiddleware.java @@ -28,15 +28,12 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; /** * A basic test of client middleware using a simplified OpenTracing-like example. 
*/ -@RunWith(JUnit4.class) public class TestClientMiddleware { /** @@ -65,9 +62,9 @@ public void middleware_propagateHeader() { FlightTestUtil.assertCode(FlightStatusCode.UNIMPLEMENTED, () -> client.listActions().forEach(actionType -> { })); }); - Assert.assertEquals(context.outgoingSpanId, context.incomingSpanId); - Assert.assertNotNull(context.finalStatus); - Assert.assertEquals(FlightStatusCode.UNIMPLEMENTED, context.finalStatus.code()); + Assertions.assertEquals(context.outgoingSpanId, context.incomingSpanId); + Assertions.assertNotNull(context.finalStatus); + Assertions.assertEquals(FlightStatusCode.UNIMPLEMENTED, context.finalStatus.code()); } /** Ensure both server and client can send and receive multi-valued headers (both binary and text values). */ @@ -87,18 +84,20 @@ public void testMultiValuedHeaders() { for (final Map.Entry> entry : EXPECTED_BINARY_HEADERS.entrySet()) { // Compare header values entry-by-entry because byte arrays don't compare via equals final List receivedValues = clientFactory.lastBinaryHeaders.get(entry.getKey()); - Assert.assertNotNull("Missing for header: " + entry.getKey(), receivedValues); - Assert.assertEquals( - "Missing or wrong value for header: " + entry.getKey(), - entry.getValue().size(), receivedValues.size()); + Assertions.assertNotNull(receivedValues, "Missing for header: " + entry.getKey()); + Assertions.assertEquals( + entry.getValue().size(), + receivedValues.size(), "Missing or wrong value for header: " + entry.getKey()); for (int i = 0; i < entry.getValue().size(); i++) { - Assert.assertArrayEquals(entry.getValue().get(i), receivedValues.get(i)); + Assertions.assertArrayEquals(entry.getValue().get(i), receivedValues.get(i)); } } for (final Map.Entry> entry : EXPECTED_TEXT_HEADERS.entrySet()) { - Assert.assertEquals( - "Missing or wrong value for header: " + entry.getKey(), - entry.getValue(), clientFactory.lastTextHeaders.get(entry.getKey())); + Assertions.assertEquals( + entry.getValue(), + clientFactory.lastTextHeaders.get(entry.getKey()), + "Missing or wrong value for header: " + entry.getKey() + ); } } @@ -329,11 +328,11 @@ public MultiHeaderClientMiddleware(MultiHeaderClientMiddlewareFactory factory) { public void onBeforeSendingHeaders(CallHeaders outgoingHeaders) { for (final Map.Entry> entry : EXPECTED_BINARY_HEADERS.entrySet()) { entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value)); - Assert.assertTrue(outgoingHeaders.containsKey(entry.getKey())); + Assertions.assertTrue(outgoingHeaders.containsKey(entry.getKey())); } for (final Map.Entry> entry : EXPECTED_TEXT_HEADERS.entrySet()) { entry.getValue().forEach((value) -> outgoingHeaders.insert(entry.getKey(), value)); - Assert.assertTrue(outgoingHeaders.containsKey(entry.getKey())); + Assertions.assertTrue(outgoingHeaders.containsKey(entry.getKey())); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java index b5bf117c628..b3a716ab3ce 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDictionaryUtils.java @@ -32,7 +32,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Test; +import org.junit.jupiter.api.Test; import com.google.common.collect.ImmutableList; diff --git 
a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java index 6c9b560342b..c2f8e755969 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestDoExchange.java @@ -17,12 +17,12 @@ package org.apache.arrow.flight; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -43,10 +43,10 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; public class TestDoExchange { static byte[] EXCHANGE_DO_GET = "do-get".getBytes(StandardCharsets.UTF_8); @@ -60,7 +60,7 @@ public class TestDoExchange { private FlightServer server; private FlightClient client; - @Before + @BeforeEach public void setUp() throws Exception { allocator = new RootAllocator(Integer.MAX_VALUE); final Location serverLocation = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, 0); @@ -70,7 +70,7 @@ public void setUp() throws Exception { client = FlightClient.builder(allocator, clientLocation).build(); } - @After + @AfterEach public void tearDown() throws Exception { AutoCloseables.close(client, server, allocator); } @@ -115,7 +115,7 @@ public void testDoExchangeDoGet() throws Exception { int value = 0; while (reader.next()) { for (int i = 0; i < root.getRowCount(); i++) { - assertFalse(String.format("Row %d should not be null", value), iv.isNull(i)); + assertFalse(iv.isNull(i), String.format("Row %d should not be null", value)); assertEquals(value, iv.get(i)); value++; } @@ -200,7 +200,7 @@ public void testDoExchangeEcho() throws Exception { stream.getWriter().completed(); // The server will end its side of the call, so this shouldn't block or indicate that // there is more data. 
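The bulk of this patch is mechanical, but one change is easy to get backwards: JUnit 4's Assert takes the failure message as the first argument, while JUnit 5's Assertions takes it last (optionally as a lazy Supplier<String>). The patch also mixes two equivalent call styles — qualified Assertions.assertEquals(...) in most files, and static imports here in TestDoExchange — both resolving to org.junit.jupiter.api.Assertions. A minimal sketch of the reordering (the values are illustrative):

    import static org.junit.jupiter.api.Assertions.assertEquals;
    import static org.junit.jupiter.api.Assertions.assertTrue;

    class MessageOrderSketch {
      void example(int expected, int actual) {
        // JUnit 4 was: Assert.assertEquals("values should match", expected, actual);
        assertEquals(expected, actual, "values should match");
        // JUnit 5 also accepts a Supplier, so the message is built only on failure.
        assertTrue(actual > 0, () -> "lazily built message for " + actual);
      }
    }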
- assertFalse("We should not be waiting for any messages", reader.next()); + assertFalse(reader.next(), "We should not be waiting for any messages"); } } @@ -233,7 +233,7 @@ public void testTransform() throws Exception { assertEquals(schema, reader.getSchema()); final VectorSchemaRoot root = reader.getRoot(); for (int batchIndex = 0; batchIndex < 10; batchIndex++) { - assertTrue("Didn't receive batch #" + batchIndex, reader.next()); + assertTrue(reader.next(), "Didn't receive batch #" + batchIndex); assertEquals(batchIndex, root.getRowCount()); for (final FieldVector rawVec : root.getFieldVectors()) { final IntVector vec = (IntVector) rawVec; @@ -244,9 +244,9 @@ public void testTransform() throws Exception { } // The server also sends back a metadata-only message containing the message count - assertTrue("There should be one extra message", reader.next()); + assertTrue(reader.next(), "There should be one extra message"); assertEquals(10, reader.getLatestMetadata().getInt(0)); - assertFalse("There should be no more data", reader.next()); + assertFalse(reader.next(), "There should be no more data"); } } @@ -289,7 +289,7 @@ public void testTransformZeroCopy() throws Exception { assertEquals(schema, reader.getSchema()); final VectorSchemaRoot root = reader.getRoot(); for (int batchIndex = 0; batchIndex < 100; batchIndex++) { - assertTrue("Didn't receive batch #" + batchIndex, reader.next()); + assertTrue(reader.next(), "Didn't receive batch #" + batchIndex); assertEquals(rowsPerBatch, root.getRowCount()); for (final FieldVector rawVec : root.getFieldVectors()) { final IntVector vec = (IntVector) rawVec; @@ -300,9 +300,9 @@ public void testTransformZeroCopy() throws Exception { } // The server also sends back a metadata-only message containing the message count - assertTrue("There should be one extra message", reader.next()); + assertTrue(reader.next(), "There should be one extra message"); assertEquals(100, reader.getLatestMetadata().getInt(0)); - assertFalse("There should be no more data", reader.next()); + assertFalse(reader.next(), "There should be no more data"); } } @@ -354,7 +354,7 @@ public void testServerCancelLeak() throws Exception { /** Have the client cancel without reading; ensure memory is not leaked. 
*/ @Test - @Ignore + @Disabled public void testClientCancel() throws Exception { try (final FlightClient.ExchangeReaderWriter stream = client.doExchange(FlightDescriptor.command(EXCHANGE_DO_GET))) { diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java index 2c62bc7fa68..1f1bbbe50fb 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestErrorMetadata.java @@ -20,8 +20,8 @@ import org.apache.arrow.flight.perf.impl.PerfOuterClass; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; @@ -58,21 +58,21 @@ public void testGrpcMetadata() throws Exception { }); PerfOuterClass.Perf newPerf = null; ErrorFlightMetadata metadata = flightStatus.metadata(); - Assert.assertNotNull(metadata); - Assert.assertEquals(2, metadata.keys().size()); - Assert.assertTrue(metadata.containsKey("grpc-status-details-bin")); + Assertions.assertNotNull(metadata); + Assertions.assertEquals(2, metadata.keys().size()); + Assertions.assertTrue(metadata.containsKey("grpc-status-details-bin")); Status status = marshaller.parseBytes(metadata.getByte("grpc-status-details-bin")); for (Any details : status.getDetailsList()) { if (details.is(PerfOuterClass.Perf.class)) { try { newPerf = details.unpack(PerfOuterClass.Perf.class); } catch (InvalidProtocolBufferException e) { - Assert.fail(); + Assertions.fail(); } } } - Assert.assertNotNull(newPerf); - Assert.assertEquals(perf, newPerf); + Assertions.assertNotNull(newPerf); + Assertions.assertEquals(perf, newPerf); } } @@ -89,17 +89,17 @@ public void testFlightMetadata() throws Exception { stream.next(); }); ErrorFlightMetadata metadata = flightStatus.metadata(); - Assert.assertNotNull(metadata); - Assert.assertEquals("foo", metadata.get("x-foo")); - Assert.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin")); + Assertions.assertNotNull(metadata); + Assertions.assertEquals("foo", metadata.get("x-foo")); + Assertions.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin")); flightStatus = FlightTestUtil.assertCode(FlightStatusCode.INVALID_ARGUMENT, () -> { client.getInfo(FlightDescriptor.command(new byte[0])); }); metadata = flightStatus.metadata(); - Assert.assertNotNull(metadata); - Assert.assertEquals("foo", metadata.get("x-foo")); - Assert.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin")); + Assertions.assertNotNull(metadata); + Assertions.assertEquals("foo", metadata.get("x-foo")); + Assertions.assertArrayEquals(new byte[]{1}, metadata.getByte("x-bar-bin")); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java index 30e351e941a..d6cc175b99d 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightClient.java @@ -40,10 +40,9 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; 
-import org.junit.Ignore; -import org.junit.Test; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; public class TestFlightClient { /** @@ -63,7 +62,7 @@ public void independentShutdown() throws Exception { final ClientStreamListener listener = client1.startPut(FlightDescriptor.path("test"), root, new AsyncPutListener()); try (final FlightClient client2 = FlightClient.builder(allocator, location).build()) { - client2.listActions().forEach(actionType -> Assert.assertNotNull(actionType.getType())); + client2.listActions().forEach(actionType -> Assertions.assertNotNull(actionType.getType())); } listener.completed(); listener.getResult(); @@ -74,7 +73,7 @@ public void independentShutdown() throws Exception { /** * ARROW-5978: make sure that we can properly close a client/stream after requesting dictionaries. */ - @Ignore // Unfortunately this test is flaky in CI. + @Disabled // Unfortunately this test is flaky in CI. @Test public void freeDictionaries() throws Exception { final Schema expectedSchema = new Schema(Collections @@ -88,18 +87,18 @@ public void freeDictionaries() throws Exception { final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); try (final FlightClient client = FlightClient.builder(allocator, location).build()) { try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { - Assert.assertTrue(stream.next()); - Assert.assertNotNull(stream.getDictionaryProvider().lookup(1)); + Assertions.assertTrue(stream.next()); + Assertions.assertNotNull(stream.getDictionaryProvider().lookup(1)); final VectorSchemaRoot root = stream.getRoot(); - Assert.assertEquals(expectedSchema, root.getSchema()); - Assert.assertEquals(6, root.getVector("encoded").getValueCount()); + Assertions.assertEquals(expectedSchema, root.getSchema()); + Assertions.assertEquals(6, root.getVector("encoded").getValueCount()); try (final ValueVector decoded = DictionaryEncoder .decode(root.getVector("encoded"), stream.getDictionaryProvider().lookup(1))) { - Assert.assertFalse(decoded.isNull(1)); - Assert.assertTrue(decoded instanceof VarCharVector); - Assert.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1)); + Assertions.assertFalse(decoded.isNull(1)); + Assertions.assertTrue(decoded instanceof VarCharVector); + Assertions.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1)); } - Assert.assertFalse(stream.next()); + Assertions.assertFalse(stream.next()); } // Closing stream fails if it doesn't free dictionaries; closing dictionaries fails (refcount goes negative) // if reference isn't retained in ArrowMessage @@ -110,7 +109,7 @@ public void freeDictionaries() throws Exception { /** * ARROW-5978: make sure that dictionary ownership can't be claimed twice. */ - @Ignore // Unfortunately this test is flaky in CI. + @Disabled // Unfortunately this test is flaky in CI. 
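JUnit 4's @Ignore maps to Jupiter's @Disabled, as above. One nicety not exercised by this patch: @Disabled accepts a reason string that test engines surface in reports, so the trailing "// Unfortunately this test is flaky in CI" comments could become annotation values. A sketch (the reason text mirrors the comment above):

    import org.junit.jupiter.api.Disabled;
    import org.junit.jupiter.api.Test;

    class FlakySketch {
      @Disabled("Unfortunately this test is flaky in CI")  // reason appears in test reports
      @Test
      void freeDictionaries() {
        // body unchanged by the migration
      }
    }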
@Test public void ownDictionaries() throws Exception { try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); @@ -121,8 +120,8 @@ public void ownDictionaries() throws Exception { final Location location = Location.forGrpcInsecure(FlightTestUtil.LOCALHOST, server.getPort()); try (final FlightClient client = FlightClient.builder(allocator, location).build()) { try (final FlightStream stream = client.getStream(new Ticket(new byte[0]))) { - Assert.assertTrue(stream.next()); - Assert.assertFalse(stream.next()); + Assertions.assertTrue(stream.next()); + Assertions.assertFalse(stream.next()); final DictionaryProvider provider = stream.takeDictionaryOwnership(); Assertions.assertThrows(IllegalStateException.class, stream::takeDictionaryOwnership); Assertions.assertThrows(IllegalStateException.class, stream::getDictionaryProvider); @@ -135,7 +134,7 @@ public void ownDictionaries() throws Exception { /** * ARROW-5978: make sure that dictionaries can be used after closing the stream. */ - @Ignore // Unfortunately this test is flaky in CI. + @Disabled // Unfortunately this test is flaky in CI. @Test public void useDictionariesAfterClose() throws Exception { try (final BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); @@ -160,9 +159,9 @@ public void useDictionariesAfterClose() throws Exception { } try (final ValueVector decoded = DictionaryEncoder .decode(root.getVector("encoded"), provider.lookup(1))) { - Assert.assertFalse(decoded.isNull(1)); - Assert.assertTrue(decoded instanceof VarCharVector); - Assert.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1)); + Assertions.assertFalse(decoded.isNull(1)); + Assertions.assertTrue(decoded instanceof VarCharVector); + Assertions.assertArrayEquals("one".getBytes(StandardCharsets.UTF_8), ((VarCharVector) decoded).get(1)); } root.close(); DictionaryUtils.closeDictionaries(root.getSchema(), provider); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java index 65ef12a8acf..fb47a84164b 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestFlightService.java @@ -23,9 +23,9 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; -import org.junit.After; -import org.junit.Before; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import io.grpc.stub.ServerCallStreamObserver; @@ -33,12 +33,12 @@ public class TestFlightService { private BufferAllocator allocator; - @Before + @BeforeEach public void setup() { allocator = new RootAllocator(Long.MAX_VALUE); } - @After + @AfterEach public void cleanup() throws Exception { AutoCloseables.close(allocator); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java index 629b6f5ebd8..7c7011a8cd2 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLargeMessage.java @@ -29,8 +29,8 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import 
org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; public class TestLargeMessage { /** @@ -51,7 +51,7 @@ public void getLargeMessage() throws Exception { int value = 0; final IntVector iv = (IntVector) root.getVector(field.getName()); for (int i = 0; i < root.getRowCount(); i++) { - Assert.assertEquals(value, iv.get(i)); + Assertions.assertEquals(value, iv.get(i)); value++; } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java index 6e28704997f..9c9da1249a3 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestLeak.java @@ -29,7 +29,7 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Test; +import org.junit.jupiter.api.Test; /** * Tests for scenarios where Flight could leak memory. diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java index 83a694bf34e..d6efa4ff800 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestMetadataVersion.java @@ -17,10 +17,10 @@ package org.apache.arrow.flight; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import java.nio.charset.StandardCharsets; import java.util.Arrays; @@ -36,9 +36,9 @@ import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; /** * Test clients/servers with different metadata versions. 
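TestMetadataVersion below completes the set of annotation renames used throughout this patch. For reference, the full JUnit 4 -> JUnit 5 mapping applied here is: @Before -> @BeforeEach, @After -> @AfterEach, @BeforeClass -> @BeforeAll, @AfterClass -> @AfterAll, @Ignore -> @Disabled, and @RunWith(JUnit4.class) is dropped outright (Jupiter discovers @Test methods without a runner; runner-style behavior comes from @ExtendWith extensions instead). A minimal lifecycle sketch:

    import org.junit.jupiter.api.AfterAll;
    import org.junit.jupiter.api.AfterEach;
    import org.junit.jupiter.api.BeforeAll;
    import org.junit.jupiter.api.BeforeEach;
    import org.junit.jupiter.api.Test;

    class LifecycleSketch {
      @BeforeAll
      static void setUpClass() { }    // once per class; static under the default lifecycle

      @BeforeEach
      void setUp() { }                // before every @Test

      @Test
      void example() { }

      @AfterEach
      void tearDown() { }             // after every @Test

      @AfterAll
      static void tearDownClass() { } // once per class
    }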
@@ -50,7 +50,7 @@ public class TestMetadataVersion { private static IpcOption optionV5; private static Schema unionSchema; - @BeforeClass + @BeforeAll public static void setUpClass() { allocator = new RootAllocator(Integer.MAX_VALUE); schema = new Schema(Collections.singletonList(Field.nullable("foo", new ArrowType.Int(32, true)))); @@ -62,7 +62,7 @@ public static void setUpClass() { optionV5 = IpcOption.DEFAULT; } - @AfterClass + @AfterAll public static void tearDownClass() { allocator.close(); } @@ -94,7 +94,7 @@ public void testUnionCheck() throws Exception { final FlightClient client = connect(server); final FlightStream stream = client.getStream(new Ticket("union".getBytes(StandardCharsets.UTF_8)))) { final FlightRuntimeException err = assertThrows(FlightRuntimeException.class, stream::next); - assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata")); + assertTrue(err.getMessage().contains("Cannot write union with V4 metadata"), err.getMessage()); } try (final FlightServer server = startServer(optionV4); @@ -105,7 +105,7 @@ public void testUnionCheck() throws Exception { final FlightClient.ClientStreamListener listener = client.startPut(descriptor, reader); final IllegalArgumentException err = assertThrows(IllegalArgumentException.class, () -> listener.start(root, null, optionV4)); - assertTrue(err.getMessage(), err.getMessage().contains("Cannot write union with V4 metadata")); + assertTrue(err.getMessage().contains("Cannot write union with V4 metadata"), err.getMessage()); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java index 1f3e35ca38d..79c5811c490 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerMiddleware.java @@ -30,12 +30,9 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Assert; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; -@RunWith(JUnit4.class) public class TestServerMiddleware { private static final RuntimeException EXPECTED_EXCEPTION = new RuntimeException("test"); @@ -56,9 +53,9 @@ public void doPutErrors() { } }, (recorder) -> { final CallStatus status = recorder.statusFuture.get(); - Assert.assertNotNull(status); - Assert.assertNotNull(status.cause()); - Assert.assertEquals(FlightStatusCode.INTERNAL, status.code()); + Assertions.assertNotNull(status); + Assertions.assertNotNull(status.cause()); + Assertions.assertEquals(FlightStatusCode.INTERNAL, status.code()); }); // Check the status after server shutdown (to make sure gRPC finishes pending calls on the server side) } @@ -79,10 +76,10 @@ public void doPutCustomCode() { } }, (recorder) -> { final CallStatus status = recorder.statusFuture.get(); - Assert.assertNotNull(status); - Assert.assertNull(status.cause()); - Assert.assertEquals(FlightStatusCode.UNAVAILABLE, status.code()); - Assert.assertEquals("description", status.description()); + Assertions.assertNotNull(status); + Assertions.assertNull(status.cause()); + Assertions.assertEquals(FlightStatusCode.UNAVAILABLE, status.code()); + Assertions.assertEquals("description", status.description()); }); } @@ -102,11 +99,11 @@ 
public void doPutUncaught() { }, (recorder) -> { final CallStatus status = recorder.statusFuture.get(); final Throwable err = recorder.errFuture.get(); - Assert.assertNotNull(status); - Assert.assertEquals(FlightStatusCode.OK, status.code()); - Assert.assertNull(status.cause()); - Assert.assertNotNull(err); - Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertNotNull(status); + Assertions.assertEquals(FlightStatusCode.OK, status.code()); + Assertions.assertNull(status.cause()); + Assertions.assertNotNull(err); + Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); }); } @@ -117,11 +114,11 @@ public void listFlightsUncaught() { }), (recorder) -> { final CallStatus status = recorder.statusFuture.get(); final Throwable err = recorder.errFuture.get(); - Assert.assertNotNull(status); - Assert.assertEquals(FlightStatusCode.OK, status.code()); - Assert.assertNull(status.cause()); - Assert.assertNotNull(err); - Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertNotNull(status); + Assertions.assertEquals(FlightStatusCode.OK, status.code()); + Assertions.assertNull(status.cause()); + Assertions.assertNotNull(err); + Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); }); } @@ -132,11 +129,11 @@ public void doActionUncaught() { }), (recorder) -> { final CallStatus status = recorder.statusFuture.get(); final Throwable err = recorder.errFuture.get(); - Assert.assertNotNull(status); - Assert.assertEquals(FlightStatusCode.OK, status.code()); - Assert.assertNull(status.cause()); - Assert.assertNotNull(err); - Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertNotNull(status); + Assertions.assertEquals(FlightStatusCode.OK, status.code()); + Assertions.assertNull(status.cause()); + Assertions.assertNotNull(err); + Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); }); } @@ -147,11 +144,11 @@ public void listActionsUncaught() { }), (recorder) -> { final CallStatus status = recorder.statusFuture.get(); final Throwable err = recorder.errFuture.get(); - Assert.assertNotNull(status); - Assert.assertEquals(FlightStatusCode.OK, status.code()); - Assert.assertNull(status.cause()); - Assert.assertNotNull(err); - Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertNotNull(status); + Assertions.assertEquals(FlightStatusCode.OK, status.code()); + Assertions.assertNull(status.cause()); + Assertions.assertNotNull(err); + Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); }); } @@ -162,10 +159,10 @@ public void getFlightInfoUncaught() { FlightTestUtil.assertCode(FlightStatusCode.INTERNAL, () -> client.getInfo(FlightDescriptor.path("test"))); }, (recorder) -> { final CallStatus status = recorder.statusFuture.get(); - Assert.assertNotNull(status); - Assert.assertEquals(FlightStatusCode.INTERNAL, status.code()); - Assert.assertNotNull(status.cause()); - Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), status.cause().getMessage()); + Assertions.assertNotNull(status); + Assertions.assertEquals(FlightStatusCode.INTERNAL, status.code()); + Assertions.assertNotNull(status.cause()); + Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), status.cause().getMessage()); }); } @@ -177,16 +174,16 @@ public void doGetUncaught() { while (stream.next()) { } } catch (Exception e) { - Assert.fail(e.toString()); + Assertions.fail(e.toString()); } }, (recorder) -> { final CallStatus 
status = recorder.statusFuture.get(); final Throwable err = recorder.errFuture.get(); - Assert.assertNotNull(status); - Assert.assertEquals(FlightStatusCode.OK, status.code()); - Assert.assertNull(status.cause()); - Assert.assertNotNull(err); - Assert.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); + Assertions.assertNotNull(status); + Assertions.assertEquals(FlightStatusCode.OK, status.code()); + Assertions.assertNull(status.cause()); + Assertions.assertNotNull(err); + Assertions.assertEquals(EXPECTED_EXCEPTION.getMessage(), err.getMessage()); }); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java index 363ad443e48..03f11cec10f 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestServerOptions.java @@ -17,8 +17,8 @@ package org.apache.arrow.flight; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; import java.io.File; import java.util.HashMap; @@ -35,17 +35,14 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.IntVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.junit.Assert; -import org.junit.Assume; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; import io.grpc.MethodDescriptor; import io.grpc.ServerServiceDefinition; import io.grpc.netty.NettyServerBuilder; -@RunWith(JUnit4.class) public class TestServerOptions { @Test @@ -61,7 +58,7 @@ public void builderConsumer() throws Exception { (location) -> FlightServer.builder(a, location, producer) .transportHint("grpc.builderConsumer", consumer).build() )) { - Assert.assertTrue(consumerCalled.get()); + Assertions.assertTrue(consumerCalled.get()); } } @@ -81,7 +78,7 @@ public void defaultExecutorClosed() throws Exception { assertNotNull(server.grpcExecutor); executor = server.grpcExecutor; } - Assert.assertTrue(executor.isShutdown()); + Assertions.assertTrue(executor.isShutdown()); } /** @@ -99,9 +96,9 @@ public void suppliedExecutorNotClosed() throws Exception { .executor(executor) .build() )) { - Assert.assertNull(server.grpcExecutor); + Assertions.assertNull(server.grpcExecutor); } - Assert.assertFalse(executor.isShutdown()); + Assertions.assertFalse(executor.isShutdown()); } finally { executor.shutdown(); } @@ -109,12 +106,12 @@ public void suppliedExecutorNotClosed() throws Exception { @Test public void domainSocket() throws Exception { - Assume.assumeTrue("We have a native transport available", FlightTestUtil.isNativeTransportAvailable()); + Assumptions.assumeTrue(FlightTestUtil.isNativeTransportAvailable(), "We have a native transport available"); final File domainSocket = File.createTempFile("flight-unit-test-", ".sock"); - Assert.assertTrue(domainSocket.delete()); + Assertions.assertTrue(domainSocket.delete()); // Domain socket paths have a platform-dependent limit. Set a conservative limit and skip the test if the temporary // file name is too long. (We do not assume a particular platform-dependent temporary directory path.) 
- Assume.assumeTrue("The domain socket path is not too long", domainSocket.getAbsolutePath().length() < 100); + Assumptions.assumeTrue(domainSocket.getAbsolutePath().length() < 100, "The domain socket path is not too long"); final Location location = Location.forGrpcDomainSocket(domainSocket.getAbsolutePath()); try ( BufferAllocator a = new RootAllocator(Long.MAX_VALUE); @@ -130,7 +127,7 @@ public void domainSocket() throws Exception { int value = 0; while (stream.next()) { for (int i = 0; i < root.getRowCount(); i++) { - Assert.assertEquals(value, iv.get(i)); + Assertions.assertEquals(value, iv.get(i)); value++; } } @@ -161,10 +158,10 @@ public void checkReflectionMetadata() { for (final MethodDescriptor descriptor : FlightServiceGrpc.getServiceDescriptor().getMethods()) { final String methodName = descriptor.getFullMethodName(); - Assert.assertTrue("Method is missing from ServerServiceDefinition: " + methodName, - definedMethods.containsKey(methodName)); - Assert.assertTrue("Method is missing from ServiceDescriptor: " + methodName, - definedMethods.containsKey(methodName)); + Assertions.assertTrue(definedMethods.containsKey(methodName), + "Method is missing from ServerServiceDefinition: " + methodName); + Assertions.assertTrue(definedMethods.containsKey(methodName), + "Method is missing from ServiceDescriptor: " + methodName); assertEquals(descriptor.getSchemaDescriptor(), definedMethods.get(methodName).getSchemaDescriptor()); assertEquals(descriptor.getSchemaDescriptor(), serviceMethods.get(methodName).getSchemaDescriptor()); diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java index c5cd871e2be..a552f635b9c 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/TestTls.java @@ -27,8 +27,8 @@ import org.apache.arrow.flight.FlightClient.Builder; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; /** * Tests for TLS in Flight. 
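TestServerOptions above also swaps JUnit 4's Assume for Jupiter's Assumptions; as with assertions, the message moves to the last parameter. A failed assumption marks the test skipped rather than failed, which is the intent of the platform-dependent domain-socket checks. A sketch (the condition shown is illustrative, not Arrow's FlightTestUtil check):

    import org.junit.jupiter.api.Assumptions;
    import org.junit.jupiter.api.Test;

    class NativeTransportSketch {
      @Test
      void domainSocket() {
        // Aborts (skips) the test instead of failing it when the condition is false.
        Assumptions.assumeTrue(
            System.getProperty("os.name").toLowerCase().contains("linux"),
            "domain sockets require a native transport");
        // ... body runs only where the assumption holds ...
      }
    }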
@@ -45,8 +45,8 @@ public void connectTls() { final FlightClient client = builder.trustedCertificates(roots).build()) { final Iterator responses = client.doAction(new Action("hello-world")); final byte[] response = responses.next().getBody(); - Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); - Assert.assertFalse(responses.hasNext()); + Assertions.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); + Assertions.assertFalse(responses.hasNext()); } catch (InterruptedException | IOException e) { throw new RuntimeException(e); } @@ -94,8 +94,8 @@ public void connectTlsDisableServerVerification() { try (final FlightClient client = builder.verifyServer(false).build()) { final Iterator responses = client.doAction(new Action("hello-world")); final byte[] response = responses.next().getBody(); - Assert.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); - Assert.assertFalse(responses.hasNext()); + Assertions.assertEquals("Hello, world!", new String(response, StandardCharsets.UTF_8)); + Assertions.assertFalse(responses.hasNext()); } catch (InterruptedException e) { throw new RuntimeException(e); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java index c18f5709b54..6ec507b5906 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth/TestBasicAuth.java @@ -38,11 +38,11 @@ import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import com.google.common.collect.ImmutableList; @@ -59,18 +59,18 @@ public class TestBasicAuth { @Test public void validAuth() { client.authenticateBasic(USERNAME, PASSWORD); - Assert.assertTrue(ImmutableList.copyOf(client.listFlights(Criteria.ALL)).size() == 0); + Assertions.assertTrue(ImmutableList.copyOf(client.listFlights(Criteria.ALL)).size() == 0); } // ARROW-7722: this test occasionally leaks memory - @Ignore + @Disabled @Test public void asyncCall() throws Exception { client.authenticateBasic(USERNAME, PASSWORD); client.listFlights(Criteria.ALL); try (final FlightStream s = client.getStream(new Ticket(new byte[1]))) { while (s.next()) { - Assert.assertEquals(4095, s.getRoot().getRowCount()); + Assertions.assertEquals(4095, s.getRoot().getRowCount()); } } } @@ -82,18 +82,18 @@ public void invalidAuth() { }); FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> { - client.listFlights(Criteria.ALL).forEach(action -> Assert.fail()); + client.listFlights(Criteria.ALL).forEach(action -> Assertions.fail()); }); } @Test public void didntAuth() { FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> { - client.listFlights(Criteria.ALL).forEach(action -> Assert.fail()); + client.listFlights(Criteria.ALL).forEach(action -> Assertions.fail()); }); } - @Before + @BeforeEach public void setup() throws IOException { allocator = new RootAllocator(Long.MAX_VALUE); final BasicServerAuthHandler.BasicAuthValidator validator = 
new BasicServerAuthHandler.BasicAuthValidator() { @@ -150,7 +150,7 @@ public void getStream(CallContext context, Ticket ticket, ServerStreamListener l client = FlightClient.builder(allocator, server.getLocation()).build(); } - @After + @AfterEach public void shutdown() throws Exception { AutoCloseables.close(client, server, allocator); } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java index 9bec32f1b72..310971ba958 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/auth2/TestBasicAuth2.java @@ -38,11 +38,11 @@ import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; @@ -59,7 +59,7 @@ public class TestBasicAuth2 { private FlightClient client; private FlightClient client2; - @Before + @BeforeEach public void setup() throws Exception { allocator = new RootAllocator(Long.MAX_VALUE); startServerAndClient(); @@ -108,7 +108,7 @@ private void startServerAndClient() throws IOException { .build(); } - @After + @AfterEach public void shutdown() throws Exception { AutoCloseables.close(client, client2, server, allocator); client = null; @@ -155,7 +155,7 @@ public void validAuthWithMultipleClientsWithDifferentCredentialsWithBearerAuthSe } // ARROW-7722: this test occasionally leaks memory - @Ignore + @Disabled @Test public void asyncCall() throws Exception { final CredentialCallOption bearerToken = client @@ -163,7 +163,7 @@ public void asyncCall() throws Exception { client.listFlights(Criteria.ALL, bearerToken); try (final FlightStream s = client.getStream(new Ticket(new byte[1]))) { while (s.next()) { - Assert.assertEquals(4095, s.getRoot().getRowCount()); + Assertions.assertEquals(4095, s.getRoot().getRowCount()); } } } @@ -181,7 +181,7 @@ public void didntAuthWithBearerAuthServer() throws IOException { private void testValidAuth(FlightClient client) { final CredentialCallOption bearerToken = client .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); - Assert.assertTrue(ImmutableList.copyOf(client + Assertions.assertTrue(ImmutableList.copyOf(client .listFlights(Criteria.ALL, bearerToken)) .isEmpty()); } @@ -192,10 +192,10 @@ private void testValidAuthWithMultipleClientsWithSameCredentials( .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); final CredentialCallOption bearerToken2 = client2 .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); - Assert.assertTrue(ImmutableList.copyOf(client1 + Assertions.assertTrue(ImmutableList.copyOf(client1 .listFlights(Criteria.ALL, bearerToken1)) .isEmpty()); - Assert.assertTrue(ImmutableList.copyOf(client2 + Assertions.assertTrue(ImmutableList.copyOf(client2 .listFlights(Criteria.ALL, bearerToken2)) .isEmpty()); } @@ -206,10 +206,10 @@ private void testValidAuthWithMultipleClientsWithDifferentCredentials( .authenticateBasicToken(USERNAME_1, PASSWORD_1).get(); final CredentialCallOption bearerToken2 = client2 
.authenticateBasicToken(USERNAME_2, PASSWORD_2).get(); - Assert.assertTrue(ImmutableList.copyOf(client1 + Assertions.assertTrue(ImmutableList.copyOf(client1 .listFlights(Criteria.ALL, bearerToken1)) .isEmpty()); - Assert.assertTrue(ImmutableList.copyOf(client2 + Assertions.assertTrue(ImmutableList.copyOf(client2 .listFlights(Criteria.ALL, bearerToken2)) .isEmpty()); } @@ -222,11 +222,11 @@ private void testInvalidAuth(FlightClient client) { client.authenticateBasicToken(NO_USERNAME, PASSWORD_1)); FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> - client.listFlights(Criteria.ALL).forEach(action -> Assert.fail())); + client.listFlights(Criteria.ALL).forEach(action -> Assertions.fail())); } private void didntAuth(FlightClient client) { FlightTestUtil.assertCode(FlightStatusCode.UNAUTHENTICATED, () -> - client.listFlights(Criteria.ALL).forEach(action -> Assert.fail())); + client.listFlights(Criteria.ALL).forEach(action -> Assertions.fail())); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java index f205f9a3b63..235bcbadb3b 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/client/TestCookieHandling.java @@ -36,11 +36,11 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; /** * Tests for correct handling of cookies from the FlightClient using {@link ClientCookieMiddleware}. 
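The cookie tests below depend on a JDK behavior that the in-test notes reference: java.net.HttpCookie infers cookie version 1 when a Set-Cookie header carries Max-Age, and version-1 cookies are serialized with quoted values, which is why the expectations switch from k=v to k="v". A small sketch of the JDK semantics only (the middleware's own cookie handling is Arrow's code):

    import java.net.HttpCookie;

    class CookieVersionSketch {
      public static void main(String[] args) {
        HttpCookie plain = HttpCookie.parse("Set-Cookie: k=v").get(0);
        HttpCookie aged = HttpCookie.parse("Set-Cookie: k=v; max-age=2").get(0);
        System.out.println(plain.getVersion()); // 0 -> renders as k=v
        System.out.println(aged.getVersion());  // 1 -> renders as k="v"
      }
    }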
@@ -55,13 +55,13 @@ public class TestCookieHandling { private ClientCookieMiddlewareTestFactory testFactory = new ClientCookieMiddlewareTestFactory(); private ClientCookieMiddleware cookieMiddleware = new ClientCookieMiddleware(testFactory); - @Before + @BeforeEach public void setup() throws Exception { allocator = new RootAllocator(Long.MAX_VALUE); startServerAndClient(); } - @After + @AfterEach public void cleanup() throws Exception { testFactory = new ClientCookieMiddlewareTestFactory(); cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); @@ -77,7 +77,7 @@ public void basicCookie() { headersToSend.insert(SET_COOKIE_HEADER, "k=v"); cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); } @Test @@ -86,20 +86,20 @@ public void cookieStaysAfterMultipleRequests() { headersToSend.insert(SET_COOKIE_HEADER, "k=v"); cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); headersToSend = new ErrorFlightMetadata(); cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); headersToSend = new ErrorFlightMetadata(); cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); } - @Ignore + @Disabled @Test public void cookieAutoExpires() { CallHeaders headersToSend = new ErrorFlightMetadata(); @@ -107,12 +107,12 @@ public void cookieAutoExpires() { cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); // Note: using max-age changes cookie version from 0->1, which quotes values. - Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); headersToSend = new ErrorFlightMetadata(); cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); try { Thread.sleep(5000); @@ -120,7 +120,7 @@ public void cookieAutoExpires() { } // Verify that the k cookie was discarded because it expired. - Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + Assertions.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); } @Test @@ -130,7 +130,7 @@ public void cookieExplicitlyExpires() { cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); // Note: using max-age changes cookie version from 0->1, which quotes values. 
- Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); // Note: The JDK treats Max-Age < 0 as not expired and treats 0 as expired. // This violates the RFC, which states that less than zero and zero should both be expired. @@ -140,10 +140,10 @@ public void cookieExplicitlyExpires() { cookieMiddleware.onHeadersReceived(headersToSend); // Verify that the k cookie was discarded because the server told the client it is expired. - Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + Assertions.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); } - @Ignore + @Disabled @Test public void cookieExplicitlyExpiresWithMaxAgeMinusOne() { CallHeaders headersToSend = new ErrorFlightMetadata(); @@ -151,7 +151,7 @@ public void cookieExplicitlyExpiresWithMaxAgeMinusOne() { cookieMiddleware = testFactory.onCallStarted(new CallInfo(FlightMethod.DO_ACTION)); cookieMiddleware.onHeadersReceived(headersToSend); // Note: using max-age changes cookie version from 0->1, which quotes values. - Assert.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=\"v\"", cookieMiddleware.getValidCookiesAsString()); headersToSend = new ErrorFlightMetadata(); @@ -162,7 +162,7 @@ public void cookieExplicitlyExpiresWithMaxAgeMinusOne() { cookieMiddleware.onHeadersReceived(headersToSend); // Verify that the k cookie was discarded because the server told the client it is expired. - Assert.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); + Assertions.assertTrue(cookieMiddleware.getValidCookiesAsString().isEmpty()); } @Test @@ -170,12 +170,12 @@ public void changeCookieValue() { CallHeaders headersToSend = new ErrorFlightMetadata(); headersToSend.insert(SET_COOKIE_HEADER, "k=v"); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", cookieMiddleware.getValidCookiesAsString()); headersToSend = new ErrorFlightMetadata(); headersToSend.insert(SET_COOKIE_HEADER, "k=v2"); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v2", cookieMiddleware.getValidCookiesAsString()); } @Test @@ -184,17 +184,17 @@ public void multipleCookiesWithSetCookie() { headersToSend.insert(SET_COOKIE_HEADER, "firstKey=firstVal"); headersToSend.insert(SET_COOKIE_HEADER, "secondKey=secondVal"); cookieMiddleware.onHeadersReceived(headersToSend); - Assert.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("firstKey=firstVal; secondKey=secondVal", cookieMiddleware.getValidCookiesAsString()); } @Test public void cookieStaysAfterMultipleRequestsEndToEnd() { client.handshake(); - Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); client.handshake(); - Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); client.listFlights(Criteria.ALL); - Assert.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); + Assertions.assertEquals("k=v", testFactory.clientCookieMiddleware.getValidCookiesAsString()); } /** diff 
--git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java index 5d76e8ae144..9912a26ea34 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/grpc/TestStatusUtils.java @@ -19,8 +19,8 @@ import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightStatusCode; -import org.junit.Assert; -import org.junit.Test; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; import io.grpc.Metadata; import io.grpc.Status; @@ -40,12 +40,12 @@ public void testParseTrailers() { CallStatus callStatus = StatusUtils.fromGrpcStatusAndTrailers(status, trailers); - Assert.assertEquals(FlightStatusCode.CANCELLED, callStatus.code()); - Assert.assertTrue(callStatus.metadata().containsKey(":status")); - Assert.assertEquals("502", callStatus.metadata().get(":status")); - Assert.assertTrue(callStatus.metadata().containsKey("date")); - Assert.assertEquals("Fri, 13 Sep 2015 11:23:58 GMT", callStatus.metadata().get("date")); - Assert.assertTrue(callStatus.metadata().containsKey("content-type")); - Assert.assertEquals("text/html", callStatus.metadata().get("content-type")); + Assertions.assertEquals(FlightStatusCode.CANCELLED, callStatus.code()); + Assertions.assertTrue(callStatus.metadata().containsKey(":status")); + Assertions.assertEquals("502", callStatus.metadata().get(":status")); + Assertions.assertTrue(callStatus.metadata().containsKey("date")); + Assertions.assertEquals("Fri, 13 Sep 2015 11:23:58 GMT", callStatus.metadata().get("date")); + Assertions.assertTrue(callStatus.metadata().containsKey("content-type")); + Assertions.assertEquals("text/html", callStatus.metadata().get("content-type")); } } diff --git a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java index 9e2d7cc544f..bc9f9cba305 100644 --- a/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java +++ b/java/flight/flight-core/src/test/java/org/apache/arrow/flight/perf/TestPerf.java @@ -38,7 +38,8 @@ import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; -import org.junit.Test; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import com.google.common.base.MoreObjects; import com.google.common.base.Stopwatch; @@ -49,7 +50,7 @@ import com.google.common.util.concurrent.MoreExecutors; import com.google.protobuf.ByteString; -@org.junit.Ignore +@Disabled public class TestPerf { public static final boolean VALIDATE = false; diff --git a/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java b/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java index 142a0f93734..9010f2d4a98 100644 --- a/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java +++ b/java/flight/flight-grpc/src/test/java/org/apache/arrow/flight/TestFlightGrpcUtils.java @@ -26,10 +26,10 @@ import org.apache.arrow.flight.auth.ServerAuthHandler; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.junit.After; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; +import 
org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import com.google.protobuf.Empty; @@ -49,7 +49,7 @@ public class TestFlightGrpcUtils { private BufferAllocator allocator; private String serverName; - @Before + @BeforeEach public void setup() throws IOException { //Defines flight service allocator = new RootAllocator(Integer.MAX_VALUE); @@ -69,7 +69,7 @@ public void setup() throws IOException { server.start(); } - @After + @AfterEach public void cleanup() { server.shutdownNow(); } @@ -95,7 +95,7 @@ public void testMultipleGrpcServices() throws IOException { //Define Test client as a blocking stub and call test method which correctly returns an empty protobuf object final TestServiceGrpc.TestServiceBlockingStub blockingStub = TestServiceGrpc.newBlockingStub(managedChannel); - Assert.assertEquals(Empty.newBuilder().build(), blockingStub.test(Empty.newBuilder().build())); + Assertions.assertEquals(Empty.newBuilder().build(), blockingStub.test(Empty.newBuilder().build())); } @Test @@ -111,9 +111,9 @@ public void testShutdown() throws IOException, InterruptedException { // Should be a no-op. flightClient.close(); - Assert.assertFalse(managedChannel.isShutdown()); - Assert.assertFalse(managedChannel.isTerminated()); - Assert.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false)); + Assertions.assertFalse(managedChannel.isShutdown()); + Assertions.assertFalse(managedChannel.isTerminated()); + Assertions.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false)); managedChannel.shutdownNow(); } @@ -126,22 +126,22 @@ public void testProxyChannel() throws IOException, InterruptedException { final FlightGrpcUtils.NonClosingProxyManagedChannel proxyChannel = new FlightGrpcUtils.NonClosingProxyManagedChannel(managedChannel); - Assert.assertFalse(proxyChannel.isShutdown()); - Assert.assertFalse(proxyChannel.isTerminated()); + Assertions.assertFalse(proxyChannel.isShutdown()); + Assertions.assertFalse(proxyChannel.isTerminated()); proxyChannel.shutdown(); - Assert.assertTrue(proxyChannel.isShutdown()); - Assert.assertTrue(proxyChannel.isTerminated()); - Assert.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false)); + Assertions.assertTrue(proxyChannel.isShutdown()); + Assertions.assertTrue(proxyChannel.isTerminated()); + Assertions.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false)); try { proxyChannel.newCall(null, null); - Assert.fail(); + Assertions.fail(); } catch (IllegalStateException e) { // This is expected, since the proxy channel is shut down. 
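The proxy-channel tests below keep the JUnit 4 era try/Assertions.fail()/catch idiom for expected exceptions. That still works under Jupiter, but assertThrows states the intent directly and returns the exception for further inspection; a possible follow-up tightening, not part of this patch:

    import static org.junit.jupiter.api.Assertions.assertThrows;

    import io.grpc.ManagedChannel;

    class ProxyChannelSketch {
      void shutdownChannelRejectsNewCalls(ManagedChannel proxyChannel) {
        // Replaces: try { proxyChannel.newCall(null, null); Assertions.fail(); }
        //           catch (IllegalStateException e) { /* expected */ }
        IllegalStateException e = assertThrows(
            IllegalStateException.class, () -> proxyChannel.newCall(null, null));
        // e.getMessage() and the like can then be asserted on if desired.
      }
    }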
} - Assert.assertFalse(managedChannel.isShutdown()); - Assert.assertFalse(managedChannel.isTerminated()); - Assert.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false)); + Assertions.assertFalse(managedChannel.isShutdown()); + Assertions.assertFalse(managedChannel.isTerminated()); + Assertions.assertEquals(ConnectivityState.IDLE, managedChannel.getState(false)); managedChannel.shutdownNow(); } @@ -155,22 +155,22 @@ public void testProxyChannelWithClosedChannel() throws IOException, InterruptedE final FlightGrpcUtils.NonClosingProxyManagedChannel proxyChannel = new FlightGrpcUtils.NonClosingProxyManagedChannel(managedChannel); - Assert.assertFalse(proxyChannel.isShutdown()); - Assert.assertFalse(proxyChannel.isTerminated()); + Assertions.assertFalse(proxyChannel.isShutdown()); + Assertions.assertFalse(proxyChannel.isTerminated()); managedChannel.shutdownNow(); - Assert.assertTrue(proxyChannel.isShutdown()); - Assert.assertTrue(proxyChannel.isTerminated()); - Assert.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false)); + Assertions.assertTrue(proxyChannel.isShutdown()); + Assertions.assertTrue(proxyChannel.isTerminated()); + Assertions.assertEquals(ConnectivityState.SHUTDOWN, proxyChannel.getState(false)); try { proxyChannel.newCall(null, null); - Assert.fail(); + Assertions.fail(); } catch (IllegalStateException e) { // This is expected, since the proxy channel is shut down. } - Assert.assertTrue(managedChannel.isShutdown()); - Assert.assertTrue(managedChannel.isTerminated()); - Assert.assertEquals(ConnectivityState.SHUTDOWN, managedChannel.getState(false)); + Assertions.assertTrue(managedChannel.isShutdown()); + Assertions.assertTrue(managedChannel.isTerminated()); + Assertions.assertEquals(ConnectivityState.SHUTDOWN, managedChannel.getState(false)); } /** diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index 1d7305fcf2f..d2f73b63737 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -67,12 +67,12 @@ import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.Text; import org.hamcrest.Matcher; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ErrorCollector; +import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; import com.google.common.collect.ImmutableList; @@ -95,10 +95,8 @@ public class TestFlightSql { private static BufferAllocator allocator; private static FlightServer server; private static FlightSqlClient sqlClient; - @Rule - public final ErrorCollector collector = new ErrorCollector(); - @BeforeClass + @BeforeAll public static void setUp() throws Exception { allocator = new RootAllocator(Integer.MAX_VALUE); @@ -136,7 +134,7 @@ public static void setUp() throws Exception { Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); } - @AfterClass + @AfterAll public static void tearDown() throws Exception { close(sqlClient, server, allocator); } @@ -177,13 +175,13 @@ private static List> getNonConformingResultsForGetSqlInfo( @Test public void 
testGetTablesSchema() { final FlightInfo info = sqlClient.getTables(null, null, null, null, true); - collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); + MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); } @Test public void testGetTablesSchemaExcludeSchema() { final FlightInfo info = sqlClient.getTables(null, null, null, null, false); - collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); + MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); } @Test @@ -192,36 +190,42 @@ public void testGetTablesResultNoSchema() throws Exception { sqlClient.getStream( sqlClient.getTables(null, null, null, null, false) .getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); - final List> results = getResults(stream); - final List> expectedResults = ImmutableList.of( - // catalog_name | schema_name | table_name | table_type | table_schema - asList(null /* TODO No catalog yet */, "SYS", "SYSALIASES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCHECKS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCOLPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCOLUMNS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCONGLOMERATES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCONSTRAINTS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSDEPENDS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSFILES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSFOREIGNKEYS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSKEYS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSROLES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSROUTINEPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSCHEMAS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSEQUENCES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSTATEMENTS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSTATISTICS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSTABLEPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSTABLES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSTRIGGERS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSUSERS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSVIEWS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYSIBM", "SYSDUMMY1", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), - asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); - collector.checkThat(results, is(expectedResults)); + Assertions.assertAll( + () -> { + MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); + }, + () -> { + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList(null /* TODO No catalog yet */, "SYS", "SYSALIASES", "SYSTEM TABLE"), + asList(null /* TODO No 
catalog yet */, "SYS", "SYSCHECKS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCOLPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCOLUMNS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCONGLOMERATES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSCONSTRAINTS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSDEPENDS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSFILES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSFOREIGNKEYS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSKEYS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSROLES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSROUTINEPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSCHEMAS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSEQUENCES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSTATEMENTS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSSTATISTICS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSTABLEPERMS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSTABLES", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSTRIGGERS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSUSERS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYS", "SYSVIEWS", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "SYSIBM", "SYSDUMMY1", "SYSTEM TABLE"), + asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), + asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); + MatcherAssert.assertThat(results, is(expectedResults)); + } + ); } } @@ -231,13 +235,18 @@ public void testGetTablesResultFilteredNoSchema() throws Exception { sqlClient.getStream( sqlClient.getTables(null, null, null, singletonList("TABLE"), false) .getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); - final List> results = getResults(stream); - final List> expectedResults = ImmutableList.of( - // catalog_name | schema_name | table_name | table_type | table_schema - asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), - asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); - collector.checkThat(results, is(expectedResults)); + + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)), + () -> { + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), + asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); + MatcherAssert.assertThat(results, is(expectedResults)); + } + ); } } @@ -247,104 +256,115 @@ public void testGetTablesResultFilteredWithSchema() throws Exception { sqlClient.getStream( sqlClient.getTables(null, null, null, singletonList("TABLE"), true) .getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); - final List> results = getResults(stream); - final List> 
expectedResults = ImmutableList.of( - // catalog_name | schema_name | table_name | table_type | table_schema - asList( - null /* TODO No catalog yet */, - "APP", - "FOREIGNTABLE", - "TABLE", - new Schema(asList( - new Field("ID", new FieldType(false, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("FOREIGNTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(true) - .build().getMetadataMap()), null), - new Field("FOREIGNNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("VARCHAR") - .schemaName("APP") - .tableName("FOREIGNTABLE") - .precision(100) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null), - new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("FOREIGNTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null))).toJson()), - asList( - null /* TODO No catalog yet */, - "APP", - "INTTABLE", - "TABLE", - new Schema(asList( - new Field("ID", new FieldType(false, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("INTTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(true) - .build().getMetadataMap()), null), - new Field("KEYNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("VARCHAR") - .schemaName("APP") - .tableName("INTTABLE") - .precision(100) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null), - new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("INTTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null), - new Field("FOREIGNID", new FieldType(true, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("INTTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null))).toJson())); - collector.checkThat(results, is(expectedResults)); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)), + () -> { + MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); + final List> results = getResults(stream); + final List> expectedResults = ImmutableList.of( + // catalog_name | schema_name | table_name | table_type | table_schema + asList( + null /* TODO No catalog yet */, + "APP", + "FOREIGNTABLE", + "TABLE", + new Schema(asList( + new Field("ID", new FieldType(false, MinorType.INT.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("INTEGER") + .schemaName("APP") + .tableName("FOREIGNTABLE") + .precision(10) + .scale(0) + .isAutoIncrement(true) + .build().getMetadataMap()), null), + new Field("FOREIGNNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("VARCHAR") + .schemaName("APP") + .tableName("FOREIGNTABLE") + .precision(100) + .scale(0) + .isAutoIncrement(false) + 
.build().getMetadataMap()), null), + new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("INTEGER") + .schemaName("APP") + .tableName("FOREIGNTABLE") + .precision(10) + .scale(0) + .isAutoIncrement(false) + .build().getMetadataMap()), null))).toJson()), + asList( + null /* TODO No catalog yet */, + "APP", + "INTTABLE", + "TABLE", + new Schema(asList( + new Field("ID", new FieldType(false, MinorType.INT.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("INTEGER") + .schemaName("APP") + .tableName("INTTABLE") + .precision(10) + .scale(0) + .isAutoIncrement(true) + .build().getMetadataMap()), null), + new Field("KEYNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("VARCHAR") + .schemaName("APP") + .tableName("INTTABLE") + .precision(100) + .scale(0) + .isAutoIncrement(false) + .build().getMetadataMap()), null), + new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("INTEGER") + .schemaName("APP") + .tableName("INTTABLE") + .precision(10) + .scale(0) + .isAutoIncrement(false) + .build().getMetadataMap()), null), + new Field("FOREIGNID", new FieldType(true, MinorType.INT.getType(), null, + new FlightSqlColumnMetadata.Builder() + .catalogName("") + .typeName("INTEGER") + .schemaName("APP") + .tableName("INTTABLE") + .precision(10) + .scale(0) + .isAutoIncrement(false) + .build().getMetadataMap()), null))).toJson())); + MatcherAssert.assertThat(results, is(expectedResults)); + } + ); } } @Test public void testSimplePreparedStatementSchema() throws Exception { try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable")) { - final Schema actualSchema = preparedStatement.getResultSetSchema(); - collector.checkThat(actualSchema, is(SCHEMA_INT_TABLE)); - - final FlightInfo info = preparedStatement.execute(); - collector.checkThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + Assertions.assertAll( + () -> { + final Schema actualSchema = preparedStatement.getResultSetSchema(); + MatcherAssert.assertThat(actualSchema, is(SCHEMA_INT_TABLE)); + + }, + () -> { + final FlightInfo info = preparedStatement.execute(); + MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + } + ); } } @@ -353,8 +373,10 @@ public void testSimplePreparedStatementResults() throws Exception { try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable"); final FlightStream stream = sqlClient.getStream( preparedStatement.execute().getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); - collector.checkThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), + () -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)) + ); } } @@ -376,8 +398,10 @@ public void testSimplePreparedStatementResultsWithParameterBinding() throws Exce .getEndpoints() .get(0).getTicket()); - collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); - collector.checkThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING)); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), + () -> MatcherAssert.assertThat(getResults(stream), 
is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING)) + ); } } } @@ -410,9 +434,10 @@ public void testSimplePreparedStatementUpdateResults() throws SQLException { deletePrepare.setParameters(deleteRoot); deletedRows = deletePrepare.executeUpdate(); } - - collector.checkThat(updatedRows, is(10L)); - collector.checkThat(deletedRows, is(10L)); + Assertions.assertAll( + () -> MatcherAssert.assertThat(updatedRows, is(10L)), + () -> MatcherAssert.assertThat(deletedRows, is(10L)) + ); } } } @@ -426,84 +451,108 @@ public void testSimplePreparedStatementUpdateResultsWithoutParameters() throws S final long deletedRows = deletePrepare.executeUpdate(); - collector.checkThat(updatedRows, is(1L)); - collector.checkThat(deletedRows, is(1L)); + Assertions.assertAll( + () -> MatcherAssert.assertThat(updatedRows, is(1L)), + () -> MatcherAssert.assertThat(deletedRows, is(1L)) + ); } } @Test public void testSimplePreparedStatementClosesProperly() { final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable"); - collector.checkThat(preparedStatement.isClosed(), is(false)); - preparedStatement.close(); - collector.checkThat(preparedStatement.isClosed(), is(true)); + Assertions.assertAll( + () -> { + MatcherAssert.assertThat(preparedStatement.isClosed(), is(false)); + }, + () -> { + preparedStatement.close(); + MatcherAssert.assertThat(preparedStatement.isClosed(), is(true)); + } + ); } @Test public void testGetCatalogsSchema() { final FlightInfo info = sqlClient.getCatalogs(); - collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); + MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); } @Test public void testGetCatalogsResults() throws Exception { try (final FlightStream stream = sqlClient.getStream(sqlClient.getCatalogs().getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); - List> catalogs = getResults(stream); - collector.checkThat(catalogs, is(emptyList())); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)), + () -> { + List> catalogs = getResults(stream); + MatcherAssert.assertThat(catalogs, is(emptyList())); + } + ); } } @Test public void testGetTableTypesSchema() { final FlightInfo info = sqlClient.getTableTypes(); - collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); + MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); } @Test public void testGetTableTypesResult() throws Exception { try (final FlightStream stream = sqlClient.getStream(sqlClient.getTableTypes().getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); - final List> tableTypes = getResults(stream); - final List> expectedTableTypes = ImmutableList.of( - // table_type - singletonList("SYNONYM"), - singletonList("SYSTEM TABLE"), - singletonList("TABLE"), - singletonList("VIEW") + Assertions.assertAll( + () -> { + MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); + }, + () -> { + final List> tableTypes = getResults(stream); + final List> expectedTableTypes = ImmutableList.of( + // table_type + singletonList("SYNONYM"), + singletonList("SYSTEM TABLE"), + singletonList("TABLE"), + singletonList("VIEW") + ); + MatcherAssert.assertThat(tableTypes, is(expectedTableTypes)); + } ); - 
collector.checkThat(tableTypes, is(expectedTableTypes)); } } @Test public void testGetSchemasSchema() { final FlightInfo info = sqlClient.getSchemas(null, null); - collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); + MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); } @Test public void testGetSchemasResult() throws Exception { try (final FlightStream stream = sqlClient.getStream(sqlClient.getSchemas(null, null).getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); - final List> schemas = getResults(stream); - final List> expectedSchemas = ImmutableList.of( - // catalog_name | schema_name - asList(null /* TODO Add catalog. */, "APP"), - asList(null /* TODO Add catalog. */, "NULLID"), - asList(null /* TODO Add catalog. */, "SQLJ"), - asList(null /* TODO Add catalog. */, "SYS"), - asList(null /* TODO Add catalog. */, "SYSCAT"), - asList(null /* TODO Add catalog. */, "SYSCS_DIAG"), - asList(null /* TODO Add catalog. */, "SYSCS_UTIL"), - asList(null /* TODO Add catalog. */, "SYSFUN"), - asList(null /* TODO Add catalog. */, "SYSIBM"), - asList(null /* TODO Add catalog. */, "SYSPROC"), - asList(null /* TODO Add catalog. */, "SYSSTAT")); - collector.checkThat(schemas, is(expectedSchemas)); + Assertions.assertAll( + () -> { + MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); + }, + () -> { + final List> schemas = getResults(stream); + final List> expectedSchemas = ImmutableList.of( + // catalog_name | schema_name + asList(null /* TODO Add catalog. */, "APP"), + asList(null /* TODO Add catalog. */, "NULLID"), + asList(null /* TODO Add catalog. */, "SQLJ"), + asList(null /* TODO Add catalog. */, "SYS"), + asList(null /* TODO Add catalog. */, "SYSCAT"), + asList(null /* TODO Add catalog. */, "SYSCS_DIAG"), + asList(null /* TODO Add catalog. */, "SYSCS_UTIL"), + asList(null /* TODO Add catalog. */, "SYSFUN"), + asList(null /* TODO Add catalog. */, "SYSIBM"), + asList(null /* TODO Add catalog. */, "SYSPROC"), + asList(null /* TODO Add catalog. 
*/, "SYSSTAT")); + MatcherAssert.assertThat(schemas, is(expectedSchemas)); + } + ); } } @@ -513,30 +562,37 @@ public void testGetPrimaryKey() { final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); final List> results = getResults(stream); - collector.checkThat(results.size(), is(1)); - final List result = results.get(0); - - collector.checkThat(result.get(0), is("")); - collector.checkThat(result.get(1), is("APP")); - collector.checkThat(result.get(2), is("INTTABLE")); - collector.checkThat(result.get(3), is("ID")); - collector.checkThat(result.get(4), is("1")); - collector.checkThat(result.get(5), notNullValue()); + Assertions.assertAll( + () -> MatcherAssert.assertThat(results.size(), is(1)), + () -> { + final List result = results.get(0); + Assertions.assertAll( + () -> MatcherAssert.assertThat(result.get(0), is("")), + () -> MatcherAssert.assertThat(result.get(1), is("APP")), + () -> MatcherAssert.assertThat(result.get(2), is("INTTABLE")), + () -> MatcherAssert.assertThat(result.get(3), is("ID")), + () -> MatcherAssert.assertThat(result.get(4), is("1")), + () -> MatcherAssert.assertThat(result.get(5), notNullValue()) + ); + } + ); } @Test public void testGetSqlInfoSchema() { final FlightInfo info = sqlClient.getSqlInfo(); - collector.checkThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); + MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); } @Test public void testGetSqlInfoResults() throws Exception { final FlightInfo info = sqlClient.getSqlInfo(); try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); - collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream)), is(emptyList())); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)), + () -> MatcherAssert.assertThat(getNonConformingResultsForGetSqlInfo(getResults(stream)), is(emptyList())) + ); } } @@ -545,8 +601,10 @@ public void testGetSqlInfoResultsWithSingleArg() throws Exception { final FlightSql.SqlInfo arg = FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME; final FlightInfo info = sqlClient.getSqlInfo(arg); try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); - collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream), arg), is(emptyList())); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)), + () -> MatcherAssert.assertThat(getNonConformingResultsForGetSqlInfo(getResults(stream), arg), is(emptyList())) + ); } } @@ -557,8 +615,16 @@ public void testGetSqlInfoResultsWithTwoArgs() throws Exception { FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION}; final FlightInfo info = sqlClient.getSqlInfo(args); try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); - collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream), args), is(emptyList())); + Assertions.assertAll( + () -> MatcherAssert.assertThat( + stream.getSchema(), + is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA) + ), + () -> MatcherAssert.assertThat( + 
getNonConformingResultsForGetSqlInfo(getResults(stream), args), + is(emptyList()) + ) + ); } } @@ -570,8 +636,16 @@ public void testGetSqlInfoResultsWithThreeArgs() throws Exception { FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR}; final FlightInfo info = sqlClient.getSqlInfo(args); try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); - collector.checkThat(getNonConformingResultsForGetSqlInfo(getResults(stream), args), is(emptyList())); + Assertions.assertAll( + () -> MatcherAssert.assertThat( + stream.getSchema(), + is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA) + ), + () -> MatcherAssert.assertThat( + getNonConformingResultsForGetSqlInfo(getResults(stream), args), + is(emptyList()) + ) + ); } } @@ -599,10 +673,14 @@ public void testGetCommandExportedKeys() { is("3"), // update_rule is("3")); // delete_rule - Assert.assertEquals(1, results.size()); + final List assertions = new ArrayList<>(); + Assertions.assertEquals(1, results.size()); for (int i = 0; i < matchers.size(); i++) { - collector.checkThat(results.get(0).get(i), matchers.get(i)); + final String actual = results.get(0).get(i); + final Matcher expected = matchers.get(i); + assertions.add(() -> MatcherAssert.assertThat(actual, expected)); } + Assertions.assertAll(assertions); } @Test @@ -629,10 +707,14 @@ public void testGetCommandImportedKeys() { is("3"), // update_rule is("3")); // delete_rule - Assert.assertEquals(1, results.size()); + Assertions.assertEquals(1, results.size()); + final List assertions = new ArrayList<>(); for (int i = 0; i < matchers.size(); i++) { - collector.checkThat(results.get(0).get(i), matchers.get(i)); + final String actual = results.get(0).get(i); + final Matcher expected = matchers.get(i); + assertions.add(() -> MatcherAssert.assertThat(actual, expected)); } + Assertions.assertAll(assertions); } @Test @@ -710,7 +792,7 @@ public void testGetTypeInfo() { asList("XML", "2009", null, null, null, emptyList().toString(), "1", "true", "0", "false", "false", "false", "XML", null, null, null, null, null, null)); - collector.checkThat(results, is(matchers)); + MatcherAssert.assertThat(results, is(matchers)); } @Test @@ -725,7 +807,7 @@ public void testGetTypeInfoWithFiltering() { asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", "BIGINT", "0", "0", null, null, "10", null)); - collector.checkThat(results, is(matchers)); + MatcherAssert.assertThat(results, is(matchers)); } @Test @@ -751,16 +833,20 @@ public void testGetCommandCrossReference() { is("3"), // update_rule is("3")); // delete_rule - Assert.assertEquals(1, results.size()); + Assertions.assertEquals(1, results.size()); + final List assertions = new ArrayList<>(); for (int i = 0; i < matchers.size(); i++) { - collector.checkThat(results.get(0).get(i), matchers.get(i)); + final String actual = results.get(0).get(i); + final Matcher expected = matchers.get(i); + assertions.add(() -> MatcherAssert.assertThat(actual, expected)); } + Assertions.assertAll(assertions); } @Test public void testCreateStatementSchema() throws Exception { final FlightInfo info = sqlClient.execute("SELECT * FROM intTable"); - collector.checkThat(info.getSchema(), is(SCHEMA_INT_TABLE)); + MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); // Consume statement to close connection before cache eviction try (FlightStream stream = 
sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { @@ -774,8 +860,14 @@ public void testCreateStatementSchema() throws Exception { public void testCreateStatementResults() throws Exception { try (final FlightStream stream = sqlClient .getStream(sqlClient.execute("SELECT * FROM intTable").getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); - collector.checkThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)); + Assertions.assertAll( + () -> { + MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); + }, + () -> { + MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)); + } + ); } } @@ -865,16 +957,24 @@ List> getResults(FlightStream stream) { @Test public void testExecuteUpdate() { - long insertedCount = sqlClient.executeUpdate("INSERT INTO INTTABLE (keyName, value) VALUES " + - "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)"); - collector.checkThat(insertedCount, is(3L)); - - long updatedCount = sqlClient.executeUpdate("UPDATE INTTABLE SET keyName = 'KEYNAME1' " + - "WHERE keyName = 'KEYNAME2' OR keyName = 'KEYNAME3'"); - collector.checkThat(updatedCount, is(2L)); - - long deletedCount = sqlClient.executeUpdate("DELETE FROM INTTABLE WHERE keyName = 'KEYNAME1'"); - collector.checkThat(deletedCount, is(3L)); + Assertions.assertAll( + () -> { + long insertedCount = sqlClient.executeUpdate("INSERT INTO INTTABLE (keyName, value) VALUES " + + "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)"); + MatcherAssert.assertThat(insertedCount, is(3L)); + + }, + () -> { + long updatedCount = sqlClient.executeUpdate("UPDATE INTTABLE SET keyName = 'KEYNAME1' " + + "WHERE keyName = 'KEYNAME2' OR keyName = 'KEYNAME3'"); + MatcherAssert.assertThat(updatedCount, is(2L)); + + }, + () -> { + long deletedCount = sqlClient.executeUpdate("DELETE FROM INTTABLE WHERE keyName = 'KEYNAME1'"); + MatcherAssert.assertThat(deletedCount, is(3L)); + } + ); } @Test @@ -882,10 +982,13 @@ public void testQueryWithNoResultsShouldNotHang() throws Exception { try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable WHERE 1 = 0"); final FlightStream stream = sqlClient .getStream(preparedStatement.execute().getEndpoints().get(0).getTicket())) { - collector.checkThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); - - final List> result = getResults(stream); - collector.checkThat(result, is(emptyList())); + Assertions.assertAll( + () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), + () -> { + final List> result = getResults(stream); + MatcherAssert.assertThat(result, is(emptyList())); + } + ); } } } diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java index 6f2b66646bb..dfb1b9da3e2 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskCreationTest.java @@ -22,29 +22,15 @@ import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_B; import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_C; import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.createBitmaskFromEnums; -import static org.hamcrest.CoreMatchers.is; import java.util.List; 
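// For reference, the JUnit 4 -> JUnit 5 mapping applied throughout this patch,
// summarized as a sketch (it assumes junit-jupiter-params is on the test
// classpath, as added to java/pom.xml further down):
//
//   @Before / @After              -> @BeforeEach / @AfterEach
//   @BeforeClass / @AfterClass    -> @BeforeAll / @AfterAll
//   Assert.*                      -> Assertions.*
//   ErrorCollector.checkThat(...) -> Assertions.assertAll(...) wrapping
//                                    MatcherAssert.assertThat(...) executables
//   @RunWith(Parameterized.class) with an @Parameters factory and @Parameter
//   fields                        -> @ParameterizedTest with
//                                    @MethodSource("provideParameters") and
//                                    plain method parameters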
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ErrorCollector;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameter;
-import org.junit.runners.Parameterized.Parameters;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
-@RunWith(Parameterized.class)
public final class SqlInfoOptionsUtilsBitmaskCreationTest {
-  @Parameter
-  public AdhocTestOption[] adhocTestOptions;
-  @Parameter(value = 1)
-  public long expectedBitmask;
-  @Rule
-  public final ErrorCollector collector = new ErrorCollector();
-
-  @Parameters
  public static List<Object[]> provideParameters() {
    return asList(
        new Object[][]{
@@ -59,8 +45,11 @@ public static List<Object[]> provideParameters() {
        });
  }
-  @Test
-  public void testShouldBuildBitmaskFromEnums() {
-    collector.checkThat(createBitmaskFromEnums(adhocTestOptions), is(expectedBitmask));
+  @ParameterizedTest
+  @MethodSource("provideParameters")
+  public void testShouldBuildBitmaskFromEnums(
+      AdhocTestOption[] adhocTestOptions, long expectedBitmask
+  ) {
+    Assertions.assertEquals(createBitmaskFromEnums(adhocTestOptions), expectedBitmask);
  }
}
diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java
index decee38ee0a..818326a582d 100644
--- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java
+++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/util/SqlInfoOptionsUtilsBitmaskParsingTest.java
@@ -24,31 +24,17 @@
import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_B;
import static org.apache.arrow.flight.sql.util.AdhocTestOption.OPTION_C;
import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.doesBitmaskTranslateToEnum;
-import static org.hamcrest.CoreMatchers.is;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ErrorCollector;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
-import org.junit.runners.Parameterized.Parameter;
-import org.junit.runners.Parameterized.Parameters;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
-@RunWith(Parameterized.class)
public final class SqlInfoOptionsUtilsBitmaskParsingTest {
-  @Parameter
-  public long bitmask;
-  @Parameter(value = 1)
-  public Set<AdhocTestOption> expectedOptions;
-  @Rule
-  public final ErrorCollector collector = new ErrorCollector();
-
-  @Parameters
  public static List<Object[]> provideParameters() {
    return asList(
        new Object[][]{
@@ -63,12 +49,13 @@ public static List<Object[]> provideParameters() {
        });
  }
-  @Test
-  public void testShouldFilterOutEnumsBasedOnBitmask() {
+  @ParameterizedTest
+  @MethodSource("provideParameters")
+  public void testShouldFilterOutEnumsBasedOnBitmask(long bitmask, Set<AdhocTestOption> expectedOptions) {
    final Set<AdhocTestOption> actualOptions = stream(AdhocTestOption.values())
        .filter(enumInstance -> doesBitmaskTranslateToEnum(enumInstance, bitmask))
        .collect(toCollection(() -> EnumSet.noneOf(AdhocTestOption.class)));
-    collector.checkThat(actualOptions, is(expectedOptions));
+    Assertions.assertEquals(actualOptions, expectedOptions);
  }
}
diff --git a/java/performance/pom.xml b/java/performance/pom.xml
index 479d5e5ab17..ba4c3476029 100644
--- a/java/performance/pom.xml
+++ b/java/performance/pom.xml
@@ -212,7 +212,7 @@
        maven-surefire-plugin
-        3.0.0-M3
+        3.0.0-M7
diff --git a/java/pom.xml b/java/pom.xml
index 486765df2d3..ea5c46334a3 100644
--- a/java/pom.xml
+++ b/java/pom.xml
@@ -29,8 +29,8 @@
    ${project.build.directory}/generated-sources
-    1.4.0
-    5.4.0
+    1.9.0
+    5.9.0
    1.7.25
    30.1.1-jre
    4.1.78.Final
@@ -400,6 +400,18 @@
          maven-surefire-plugin
          3.0.0-M7
+
+
+            org.junit.jupiter
+            junit-jupiter-engine
+            ${dep.junit.jupiter.version}
+
+
+            org.apache.maven.surefire
+            surefire-junit-platform
+            3.0.0-M7
+
+
          true
          true
@@ -417,7 +429,7 @@
        maven-failsafe-plugin
-        3.0.0-M3
+        3.0.0-M7
        ${project.build.directory}
@@ -614,13 +626,6 @@
      test
-
-      org.junit.platform
-      junit-platform-runner
-      ${dep.junit.platform.version}
-      test
-
    org.junit.jupiter
    junit-jupiter-engine
@@ -639,6 +644,12 @@
      ${dep.junit.jupiter.version}
      test
+
+      org.junit.jupiter
+      junit-jupiter-params
+      ${dep.junit.jupiter.version}
+      test
+
    junit
From 69266721472725f37095880a9b87d88f395462b6 Mon Sep 17 00:00:00 2001
From: Dewey Dunnington
Date: Fri, 16 Sep 2022 09:58:29 -0300
Subject: [PATCH 080/133] ARROW-17643: [R] Latest duckdb release is causing test failure (#14149)

#14065 experimented with ways to solve this from Arrow's end; however, the fix in https://github.com/duckdb/duckdb/pull/4712 is probably more robust. Even if that fix doesn't make it into the next DuckDB release, this test will start failing again when that release happens, in case we/I forget to check.

Authored-by: Dewey Dunnington
Signed-off-by: Neal Richardson
---
 r/tests/testthat/test-duckdb.R | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/r/tests/testthat/test-duckdb.R b/r/tests/testthat/test-duckdb.R
index 088d7a4bbd7..817ca878337 100644
--- a/r/tests/testthat/test-duckdb.R
+++ b/r/tests/testthat/test-duckdb.R
@@ -202,6 +202,10 @@ dbExecute(con, "PRAGMA threads=2")
on.exit(dbDisconnect(con, shutdown = TRUE), add = TRUE)
test_that("Joining, auto-cleanup enabled", {
+  # ARROW-17643: A change in duckdb 0.5.0 caused this test to fail but will
+  # be fixed in the next release.
+  skip_if_not(packageVersion("duckdb") > "0.5.0")
+
  ds <- InMemoryDataset$create(example_data)
  table_one_name <- "my_arrow_table_1"
From 68e0fa7499876fc0cf86b8be784a890226648645 Mon Sep 17 00:00:00 2001
From: Rasmus Johansen
Date: Fri, 16 Sep 2022 13:58:41 +0100
Subject: [PATCH 081/133] ARROW-17733: [C++] Take index_width into account when filling nulls in index buffer (#14129)

Take index_width into account when offsetting by position into out_data. Otherwise we offset position bytes into the array, when we actually want to offset position index slots, each index_width bytes wide, into the array.
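To make the arithmetic concrete (the numbers are hypothetical, chosen only to illustrate the off-by-width behaviour): with index_width = 2, position = 3 and run.length = 4, the old expression fills bytes [3, 11), which starts in the middle of index slot 1 and clobbers neighbouring indices, while the fixed expression fills bytes [6, 14), i.e. exactly index slots 3 through 6:

    // before: the start offset treats each index as one byte wide
    std::fill(out_data + position,
              out_data + position + (run.length * index_width), 0x00);
    // after: both bounds are scaled by the index width
    std::fill(out_data + (position * index_width),
              out_data + (position + run.length) * index_width, 0x00);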
Authored-by: Rasmus Johansen Signed-off-by: David Li --- cpp/src/arrow/array/concatenate.cc | 4 ++-- cpp/src/arrow/array/concatenate_test.cc | 11 +++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/array/concatenate.cc b/cpp/src/arrow/array/concatenate.cc index 9f77fbb5f43..aab734284fa 100644 --- a/cpp/src/arrow/array/concatenate.cc +++ b/cpp/src/arrow/array/concatenate.cc @@ -311,8 +311,8 @@ class ConcatenateImpl { /*dest_offset=*/position, run.length, transpose_map)); } else { - std::fill(out_data + position, - out_data + position + (run.length * index_width), 0x00); + std::fill(out_data + (position * index_width), + out_data + (position + run.length) * index_width, 0x00); } position += run.length; diff --git a/cpp/src/arrow/array/concatenate_test.cc b/cpp/src/arrow/array/concatenate_test.cc index aacd7518928..bff5d7eec1e 100644 --- a/cpp/src/arrow/array/concatenate_test.cc +++ b/cpp/src/arrow/array/concatenate_test.cc @@ -539,4 +539,15 @@ TEST_F(ConcatenateTest, OffsetOverflow) { ASSERT_RAISES(Invalid, Concatenate({fake_long, fake_long}).status()); } +TEST_F(ConcatenateTest, DictionaryConcatenateWithEmptyUint16) { + // Regression test for ARROW-17733 + auto dict_type = dictionary(uint16(), utf8()); + auto dict_one = DictArrayFromJSON(dict_type, "[]", "[]"); + auto dict_two = + DictArrayFromJSON(dict_type, "[0, 1, null, null, null, null]", "[\"A0\", \"A1\"]"); + ASSERT_OK_AND_ASSIGN(auto concat_actual, Concatenate({dict_one, dict_two})); + + AssertArraysEqual(*dict_two, *concat_actual); +} + } // namespace arrow From d571e93ad24d5111800540b42a3b8d56459edd9b Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 16 Sep 2022 11:05:46 -0400 Subject: [PATCH 082/133] ARROW-17730: [Go] Implement Take kernels for FSB and VarBinary (#14127) Authored-by: Matt Topol Signed-off-by: Matt Topol --- .../internal/kernels/vector_selection.go | 177 +++++++++++++++++- go/arrow/compute/vector_selection_test.go | 51 +++++ 2 files changed, 218 insertions(+), 10 deletions(-) diff --git a/go/arrow/compute/internal/kernels/vector_selection.go b/go/arrow/compute/internal/kernels/vector_selection.go index fa1c33be59d..c4bfcca8bcd 100644 --- a/go/arrow/compute/internal/kernels/vector_selection.go +++ b/go/arrow/compute/internal/kernels/vector_selection.go @@ -991,19 +991,171 @@ func binaryFilterImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, values, filter return nil } -func FilterFSB(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { +func takeExecImpl[T exec.UintTypes](ctx *exec.KernelCtx, outputLen int64, values, indices *exec.ArraySpan, out *exec.ExecResult, visitValid func(int64) error, visitNull func() error) error { var ( - values = &batch.Values[0].Array - selection = &batch.Values[1].Array - outputLength = getFilterOutputSize(selection, ctx.State.(FilterState).NullSelection) - valueSize = int64(values.Type.(arrow.FixedWidthDataType).Bytes()) - valueData = values.Buffers[1].Buf[values.Offset*valueSize:] + validityBuilder = validityBuilder{mem: exec.GetAllocator(ctx.Ctx)} + indicesValues = exec.GetSpanValues[T](indices, 1) + isValid = indices.Buffers[0].Buf + valuesHaveNulls = values.MayHaveNulls() + + indicesIsValid = bitutil.OptionalBitIndexer{Bitmap: isValid, Offset: int(indices.Offset)} + valuesIsValid = bitutil.OptionalBitIndexer{Bitmap: values.Buffers[0].Buf, Offset: int(values.Offset)} + bitCounter = bitutils.NewOptionalBitBlockCounter(isValid, indices.Offset, indices.Len) + pos int64 + ) + + validityBuilder.Reserve(outputLen) + for pos < 
indices.Len { + block := bitCounter.NextBlock() + indicesHaveNulls := block.Popcnt < block.Len + if !indicesHaveNulls && !valuesHaveNulls { + // fastest path, neither indices nor values have nulls + validityBuilder.UnsafeAppendN(int64(block.Len), true) + for i := 0; i < int(block.Len); i++ { + if err := visitValid(int64(indicesValues[pos])); err != nil { + return err + } + pos++ + } + } else if block.Popcnt > 0 { + // since we have to branch on whether indices are null or not, + // we combine the "non-null indices block but some values null" + // and "some null indices block but values non-null" into single loop + for i := 0; i < int(block.Len); i++ { + if (!indicesHaveNulls || indicesIsValid.GetBit(int(pos))) && valuesIsValid.GetBit(int(indicesValues[pos])) { + validityBuilder.UnsafeAppend(true) + if err := visitValid(int64(indicesValues[pos])); err != nil { + return err + } + } else { + validityBuilder.UnsafeAppend(false) + if err := visitNull(); err != nil { + return err + } + } + pos++ + } + } else { + // the whole block is null + validityBuilder.UnsafeAppendN(int64(block.Len), false) + for i := 0; i < int(block.Len); i++ { + if err := visitNull(); err != nil { + return err + } + } + pos += int64(block.Len) + } + } + + out.Len = int64(validityBuilder.bitLength) + out.Nulls = int64(validityBuilder.falseCount) + out.Buffers[0].WrapBuffer(validityBuilder.Finish()) + return nil +} + +func takeExec(ctx *exec.KernelCtx, outputLen int64, values, indices *exec.ArraySpan, out *exec.ExecResult, visitValid func(int64) error, visitNull func() error) error { + indexWidth := indices.Type.(arrow.FixedWidthDataType).Bytes() + + switch indexWidth { + case 1: + return takeExecImpl[uint8](ctx, outputLen, values, indices, out, visitValid, visitNull) + case 2: + return takeExecImpl[uint16](ctx, outputLen, values, indices, out, visitValid, visitNull) + case 4: + return takeExecImpl[uint32](ctx, outputLen, values, indices, out, visitValid, visitNull) + case 8: + return takeExecImpl[uint64](ctx, outputLen, values, indices, out, visitValid, visitNull) + default: + return fmt.Errorf("%w: invalid index width", arrow.ErrInvalid) + } +} + +type outputFn func(*exec.KernelCtx, int64, *exec.ArraySpan, *exec.ArraySpan, *exec.ExecResult, func(int64) error, func() error) error +type implFn func(*exec.KernelCtx, *exec.ExecSpan, int64, *exec.ExecResult, outputFn) error + +func FilterExec(impl implFn, fn outputFn) exec.ArrayKernelExec { + return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + var ( + selection = &batch.Values[1].Array + outputLength = getFilterOutputSize(selection, ctx.State.(FilterState).NullSelection) + ) + return impl(ctx, batch, outputLength, out, fn) + } +} + +func TakeExec(impl implFn, fn outputFn) exec.ArrayKernelExec { + return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + if ctx.State.(TakeState).BoundsCheck { + if err := checkIndexBounds(&batch.Values[1].Array, uint64(batch.Values[0].Array.Len)); err != nil { + return err + } + } + + return impl(ctx, batch, batch.Values[1].Array.Len, out, fn) + } +} + +func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error { + var ( + values = &batch.Values[0].Array + selection = &batch.Values[1].Array + rawOffsets = exec.GetSpanOffsets[OffsetT](values, 1) + rawData = values.Buffers[2].Buf + offsetBuilder = newBufferBuilder[OffsetT](exec.GetAllocator(ctx.Ctx)) + dataBuilder = 
newBufferBuilder[uint8](exec.GetAllocator(ctx.Ctx)) + ) + + // presize the data builder with a rough estimate of the required data size + if values.Len > 0 { + dataLength := rawOffsets[values.Len] - rawOffsets[0] + meanValueLen := float64(dataLength) / float64(values.Len) + dataBuilder.reserve(int(meanValueLen)) + } + + offsetBuilder.reserve(int(outputLength) + 1) + spaceAvail := dataBuilder.cap() + var offset OffsetT + err := fn(ctx, outputLength, values, selection, out, + func(idx int64) error { + offsetBuilder.unsafeAppend(offset) + valOffset := rawOffsets[idx] + valSize := rawOffsets[idx+1] - valOffset + + offset += valSize + if valSize > OffsetT(spaceAvail) { + dataBuilder.reserve(int(valSize)) + spaceAvail = dataBuilder.cap() - dataBuilder.len() + } + dataBuilder.unsafeAppendSlice(rawData[valOffset : valOffset+valSize]) + spaceAvail -= int(valSize) + return nil + }, func() error { + offsetBuilder.unsafeAppend(offset) + return nil + }) + + if err != nil { + return err + } + + offsetBuilder.unsafeAppend(offset) + out.Buffers[1].WrapBuffer(offsetBuilder.finish()) + out.Buffers[2].WrapBuffer(dataBuilder.finish()) + return nil +} + +func FSBImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error { + var ( + values = &batch.Values[0].Array + selection = &batch.Values[1].Array + valueSize = int64(values.Type.(arrow.FixedWidthDataType).Bytes()) + valueData = values.Buffers[1].Buf[values.Offset*valueSize:] ) out.Buffers[1].WrapBuffer(ctx.Allocate(int(valueSize * outputLength))) buf := out.Buffers[1].Buf - err := filterExec(ctx, outputLength, values, selection, out, + err := fn(ctx, outputLength, values, selection, out, func(idx int64) error { start := idx * int64(valueSize) copy(buf, valueData[start:start+valueSize]) @@ -1076,9 +1228,9 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa filterkernels = []SelectionKernelData{ {In: exec.NewMatchedInput(exec.Primitive()), Exec: PrimitiveFilter}, {In: exec.NewExactInput(arrow.Null), Exec: NullFilter}, - {In: exec.NewIDInput(arrow.DECIMAL128), Exec: FilterFSB}, - {In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterFSB}, - {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterFSB}, + {In: exec.NewIDInput(arrow.DECIMAL128), Exec: FilterExec(FSBImpl, filterExec)}, + {In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterExec(FSBImpl, filterExec)}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterExec(FSBImpl, filterExec)}, {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: FilterBinary}, {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: FilterBinary}, } @@ -1086,6 +1238,11 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa takeKernels = []SelectionKernelData{ {In: exec.NewExactInput(arrow.Null), Exec: NullTake}, {In: exec.NewMatchedInput(exec.Primitive()), Exec: PrimitiveTake}, + {In: exec.NewIDInput(arrow.DECIMAL128), Exec: TakeExec(FSBImpl, takeExec)}, + {In: exec.NewIDInput(arrow.DECIMAL256), Exec: TakeExec(FSBImpl, takeExec)}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: TakeExec(FSBImpl, takeExec)}, + {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: TakeExec(VarBinaryImpl[int32], takeExec)}, + {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: TakeExec(VarBinaryImpl[int64], takeExec)}, } return } diff --git a/go/arrow/compute/vector_selection_test.go b/go/arrow/compute/vector_selection_test.go index 59b0e1d07ba..e5fdfbcb776 100644 --- a/go/arrow/compute/vector_selection_test.go +++ 
b/go/arrow/compute/vector_selection_test.go @@ -663,11 +663,62 @@ func (tk *TakeKernelTestNumeric) TestTakeNumeric() { }) } +type TakeKernelTestFSB struct { + TakeKernelTestTyped +} + +func (tk *TakeKernelTestFSB) SetupSuite() { + tk.dt = &arrow.FixedSizeBinaryType{ByteWidth: 3} +} + +func (tk *TakeKernelTestFSB) TestFixedSizeBinary() { + // YWFh == base64("aaa") + // YmJi == base64("bbb") + // Y2Nj == base64("ccc") + tk.assertTake(`["YWFh", "YmJi", "Y2Nj"]`, `[0, 1, 0]`, `["YWFh", "YmJi", "YWFh"]`) + tk.assertTake(`[null, "YmJi", "Y2Nj"]`, `[0, 1, 0]`, `[null, "YmJi", null]`) + tk.assertTake(`["YWFh", "YmJi", "Y2Nj"]`, `[null, 1, 0]`, `[null, "YmJi", "YWFh"]`) + + tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, `[0, 1, 0]`) + + _, err := tk.takeJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, arrow.PrimitiveTypes.Int8, `[0, 9, 0]`) + tk.ErrorIs(err, arrow.ErrIndex) + _, err = tk.takeJSON(tk.dt, `["YWFh", "YmJi", "Y2Nj"]`, arrow.PrimitiveTypes.Int64, `[2, 5]`) + tk.ErrorIs(err, arrow.ErrIndex) +} + +type TakeKernelTestString struct { + TakeKernelTestTyped +} + +func (tk *TakeKernelTestString) TestTakeString() { + tk.Run(tk.dt.String(), func() { + // base64 encoded so the binary non-utf8 arrays work + // YQ== -> "a" + // Yg== -> "b" + // Yw== -> "c" + tk.assertTake(`["YQ==", "Yg==", "Yw=="]`, `[0, 1, 0]`, `["YQ==", "Yg==", "YQ=="]`) + tk.assertTake(`[null, "Yg==", "Yw=="]`, `[0, 1, 0]`, `[null, "Yg==", null]`) + tk.assertTake(`["YQ==", "Yg==", "Yw=="]`, `[null, 1, 0]`, `[null, "Yg==", "YQ=="]`) + + tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, `[0, 1, 0]`) + + _, err := tk.takeJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, arrow.PrimitiveTypes.Int8, `[0, 9, 0]`) + tk.ErrorIs(err, arrow.ErrIndex) + _, err = tk.takeJSON(tk.dt, `["YQ==", "Yg==", "Yw=="]`, arrow.PrimitiveTypes.Int64, `[2, 5]`) + tk.ErrorIs(err, arrow.ErrIndex) + }) +} + func TestTakeKernels(t *testing.T) { suite.Run(t, new(TakeKernelTest)) for _, dt := range numericTypes { suite.Run(t, &TakeKernelTestNumeric{TakeKernelTestTyped: TakeKernelTestTyped{dt: dt}}) } + suite.Run(t, new(TakeKernelTestFSB)) + for _, dt := range baseBinaryTypes { + suite.Run(t, &TakeKernelTestString{TakeKernelTestTyped: TakeKernelTestTyped{dt: dt}}) + } } func TestFilterKernels(t *testing.T) { From 3ce40143f8a836df058ec5fe1b29d9da5ede169d Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 16 Sep 2022 11:06:22 -0400 Subject: [PATCH 083/133] ARROW-17688: [C++][Java][FlightRPC] Substrait, transaction, cancellation for Flight SQL (#13492) "[VOTE] Substrait for Flight SQL" https://lists.apache.org/thread/3k3np6314dwb0n7n1hrfwony5fcy7kzl Authored-by: David Li Signed-off-by: David Li --- .../flight_integration_test.cc | 4 + .../integration_tests/test_integration.cc | 656 ++++++++++++++++-- cpp/src/arrow/flight/server.cc | 4 +- cpp/src/arrow/flight/sql/CMakeLists.txt | 27 +- cpp/src/arrow/flight/sql/acero_test.cc | 239 +++++++ cpp/src/arrow/flight/sql/client.cc | 355 +++++++++- cpp/src/arrow/flight/sql/client.h | 152 +++- cpp/src/arrow/flight/sql/client_test.cc | 1 + cpp/src/arrow/flight/sql/column_metadata.cc | 8 +- cpp/src/arrow/flight/sql/column_metadata.h | 8 +- .../arrow/flight/sql/example/acero_main.cc | 70 ++ .../arrow/flight/sql/example/acero_server.cc | 306 ++++++++ .../arrow/flight/sql/example/acero_server.h | 37 + .../arrow/flight/sql/example/sqlite_server.cc | 205 ++++-- .../arrow/flight/sql/example/sqlite_server.h | 6 + .../flight/sql/example/sqlite_sql_info.cc | 9 +- 
cpp/src/arrow/flight/sql/server.cc | 424 +++++++++-- cpp/src/arrow/flight/sql/server.h | 171 ++++- cpp/src/arrow/flight/sql/server_test.cc | 83 ++- cpp/src/arrow/flight/sql/types.h | 79 ++- dev/archery/archery/integration/runner.py | 5 + docs/source/status.rst | 14 + format/FlightSql.proto | 298 +++++++- .../org/apache/arrow/flight/FlightClient.java | 7 +- .../apache/arrow/flight/FlightService.java | 2 +- .../tests/FlightSqlExtensionScenario.java | 217 ++++++ .../integration/tests/FlightSqlScenario.java | 52 +- .../tests/FlightSqlScenarioProducer.java | 383 ++++++++-- .../tests/IntegrationAssertions.java | 11 + .../flight/integration/tests/Scenarios.java | 1 + .../integration/tests/IntegrationTest.java | 70 ++ .../arrow/flight/sql/CancelListener.java | 51 ++ .../apache/arrow/flight/sql/CancelResult.java | 45 ++ .../arrow/flight/sql/FlightSqlClient.java | 451 +++++++++++- .../arrow/flight/sql/FlightSqlProducer.java | 198 +++++- .../arrow/flight/sql/FlightSqlUtils.java | 35 + .../arrow/flight/sql/NoResultListener.java | 45 ++ .../arrow/flight/sql/ProtoListener.java | 52 ++ .../arrow/flight/sql/SqlInfoBuilder.java | 41 ++ .../flight/sql/example/FlightSqlExample.java | 3 + 40 files changed, 4445 insertions(+), 380 deletions(-) create mode 100644 cpp/src/arrow/flight/sql/acero_test.cc create mode 100644 cpp/src/arrow/flight/sql/example/acero_main.cc create mode 100644 cpp/src/arrow/flight/sql/example/acero_server.cc create mode 100644 cpp/src/arrow/flight/sql/example/acero_server.h create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java create mode 100644 java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc index 706ac3b7d93..e29a281f327 100644 --- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc +++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc @@ -55,6 +55,10 @@ TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); } TEST(FlightIntegration, FlightSql) { ASSERT_OK(RunScenario("flight_sql")); } +TEST(FlightIntegration, FlightSqlExtension) { + ASSERT_OK(RunScenario("flight_sql:extension")); +} + } // namespace integration_tests } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index b228f9cceba..43c16e0b77a 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -16,25 +16,36 @@ // under the License. 
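// A rough sketch of the client-side surface this commit exercises (the actual
// declarations live in cpp/src/arrow/flight/sql/client.h, which is not shown
// in this hunk, so treat the names and signatures below as illustrative
// rather than authoritative):
//
//   sql::FlightSqlClient client(std::move(flight_client));
//   ARROW_ASSIGN_OR_RAISE(sql::Transaction txn, client.BeginTransaction({}));
//   ARROW_ASSIGN_OR_RAISE(auto info, client.Execute({}, "SELECT ...", txn));
//   ARROW_RETURN_NOT_OK(client.Commit({}, txn));
//
// Substrait plans follow the same shape via ExecuteSubstrait({},
// sql::SubstraitPlan{plan, version}, txn), and a running query can be
// cancelled with CancelQuery({}, *info).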
#include "arrow/flight/integration_tests/test_integration.h" + +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/server_middleware.h" #include "arrow/flight/sql/client.h" #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/types.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/status.h" #include "arrow/testing/gtest_util.h" - -#include -#include -#include -#include -#include +#include "arrow/util/checked_cast.h" namespace arrow { namespace flight { namespace integration_tests { +namespace { + +using arrow::internal::checked_cast; /// \brief The server for the basic auth integration test. class AuthBasicProtoServer : public FlightServerBase { @@ -263,29 +274,56 @@ class MiddlewareScenario : public Scenario { }; /// \brief Schema to be returned for mocking the statement/prepared statement results. +/// /// Must be the same across all languages. -std::shared_ptr GetQuerySchema() { - std::string table_name = "test"; - std::string schema_name = "schema_test"; - std::string catalog_name = "catalog_test"; - std::string type_name = "type_test"; - return arrow::schema({arrow::field("id", int64(), true, - arrow::flight::sql::ColumnMetadata::Builder() - .TableName(table_name) - .IsAutoIncrement(true) - .IsCaseSensitive(false) - .TypeName(type_name) - .SchemaName(schema_name) - .IsSearchable(true) - .CatalogName(catalog_name) - .Precision(100) - .Build() - .metadata_map())}); +const std::shared_ptr& GetQuerySchema() { + static std::shared_ptr kSchema = + schema({field("id", int64(), /*nullable=*/true, + arrow::flight::sql::ColumnMetadata::Builder() + .TableName("test") + .IsAutoIncrement(true) + .IsCaseSensitive(false) + .TypeName("type_test") + .SchemaName("schema_test") + .IsSearchable(true) + .CatalogName("catalog_test") + .Precision(100) + .Build() + .metadata_map())}); + return kSchema; +} + +/// \brief Schema to be returned for queries with transactions. +/// +/// Must be the same across all languages. 
+std::shared_ptr GetQueryWithTransactionSchema() { + static std::shared_ptr kSchema = + schema({field("pkey", int32(), /*nullable=*/true, + arrow::flight::sql::ColumnMetadata::Builder() + .TableName("test") + .IsAutoIncrement(true) + .IsCaseSensitive(false) + .TypeName("type_test") + .SchemaName("schema_test") + .IsSearchable(true) + .CatalogName("catalog_test") + .Precision(100) + .Build() + .metadata_map())}); + return kSchema; } constexpr int64_t kUpdateStatementExpectedRows = 10000L; +constexpr int64_t kUpdateStatementWithTransactionExpectedRows = 15000L; constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L; +constexpr int64_t kUpdatePreparedStatementWithTransactionExpectedRows = 25000L; constexpr char kSelectStatement[] = "SELECT STATEMENT"; +constexpr char kSavepointId[] = "savepoint_id"; +constexpr char kSavepointName[] = "savepoint_name"; +constexpr char kSubstraitPlanText[] = "plan"; +constexpr char kSubstraitVersion[] = "version"; +static const sql::SubstraitPlan kSubstraitPlan{kSubstraitPlanText, kSubstraitVersion}; +constexpr char kTransactionId[] = "transaction_id"; template arrow::Status AssertEq(const T& expected, const T& actual, const std::string& message) { @@ -296,25 +334,83 @@ arrow::Status AssertEq(const T& expected, const T& actual, const std::string& me return Status::OK(); } +template +arrow::Status AssertUnprintableEq(const T& expected, const T& actual, + const std::string& message) { + if (expected != actual) { + return Status::Invalid(message); + } + return Status::OK(); +} + /// \brief The server used for testing Flight SQL, this implements a static Flight SQL /// server which only asserts that commands called during integration tests are being /// parsed correctly and returns the expected schemas to be validated on client. 
class FlightSqlScenarioServer : public sql::FlightSqlServerBase { public: + FlightSqlScenarioServer() : sql::FlightSqlServerBase() { + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SQL, + sql::SqlInfoResult(false)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, + sql::SqlInfoResult(true)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, + sql::SqlInfoResult(std::string("min_version"))); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, + sql::SqlInfoResult(std::string("max_version"))); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction:: + SQL_SUPPORTED_TRANSACTION_SAVEPOINT)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL, + sql::SqlInfoResult(true)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT, + sql::SqlInfoResult(int32_t(42))); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT, + sql::SqlInfoResult(int32_t(7))); + } arrow::Result> GetFlightInfoStatement( const ServerCallContext& context, const sql::StatementQuery& command, const FlightDescriptor& descriptor) override { ARROW_RETURN_NOT_OK( AssertEq(kSelectStatement, command.query, "Unexpected statement in GetFlightInfoStatement")); - - ARROW_ASSIGN_OR_RAISE(auto handle, - sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE")); - + std::string ticket; + Schema* schema; + if (command.transaction_id.empty()) { + ticket = "SELECT STATEMENT HANDLE"; + schema = GetQuerySchema().get(); + } else { + ticket = "SELECT STATEMENT WITH TXN HANDLE"; + schema = GetQueryWithTransactionSchema().get(); + } + ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket(ticket)); std::vector endpoints{FlightEndpoint{{handle}, {}}}; - ARROW_ASSIGN_OR_RAISE( - auto result, FlightInfo::Make(*GetQuerySchema(), descriptor, endpoints, -1, -1)) + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); + return std::unique_ptr(new FlightInfo(result)); + } + arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const sql::StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, command.plan.plan, + "Unexpected plan in GetFlightInfoSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, command.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + std::string ticket; + Schema* schema; + if (command.transaction_id.empty()) { + ticket = "PLAN HANDLE"; + schema = GetQuerySchema().get(); + } else { + ticket = "PLAN WITH TXN HANDLE"; + schema = GetQueryWithTransactionSchema().get(); + } + ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket(ticket)); + std::vector endpoints{FlightEndpoint{{handle}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); return std::unique_ptr(new FlightInfo(result)); } @@ -323,38 +419,84 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const FlightDescriptor& descriptor) override { ARROW_RETURN_NOT_OK(AssertEq( kSelectStatement, command.query, "Unexpected statement in GetSchemaStatement")); - return SchemaResult::Make(*GetQuerySchema()); + if (command.transaction_id.empty()) { + return SchemaResult::Make(*GetQuerySchema()); + } else { + return 
SchemaResult::Make(*GetQueryWithTransactionSchema());
+    }
+  }
+
+  arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaSubstraitPlan(
+      const ServerCallContext& context, const sql::StatementSubstraitPlan& command,
+      const FlightDescriptor& descriptor) override {
+    ARROW_RETURN_NOT_OK(
+        AssertEq<std::string>(kSubstraitPlanText, command.plan.plan,
+                              "Unexpected plan in GetSchemaSubstraitPlan"));
+    ARROW_RETURN_NOT_OK(
+        AssertEq<std::string>(kSubstraitVersion, command.plan.version,
+                              "Unexpected version in GetSchemaSubstraitPlan"));
+    if (command.transaction_id.empty()) {
+      return SchemaResult::Make(*GetQuerySchema());
+    } else {
+      return SchemaResult::Make(*GetQueryWithTransactionSchema());
+    }
   }
 
   arrow::Result<std::unique_ptr<FlightDataStream>> DoGetStatement(
       const ServerCallContext& context,
       const sql::StatementQueryTicket& command) override {
-    return DoGetForTestCase(GetQuerySchema());
+    if (command.statement_handle == "SELECT STATEMENT HANDLE" ||
+        command.statement_handle == "PLAN HANDLE") {
+      return DoGetForTestCase(GetQuerySchema());
+    } else if (command.statement_handle == "SELECT STATEMENT WITH TXN HANDLE" ||
+               command.statement_handle == "PLAN WITH TXN HANDLE") {
+      return DoGetForTestCase(GetQueryWithTransactionSchema());
+    }
+    return Status::Invalid("Unknown handle: ", command.statement_handle);
   }
 
   arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoPreparedStatement(
       const ServerCallContext& context, const sql::PreparedStatementQuery& command,
       const FlightDescriptor& descriptor) override {
-    ARROW_RETURN_NOT_OK(AssertEq<std::string>("SELECT PREPARED STATEMENT HANDLE",
-                                              command.prepared_statement_handle,
-                                              "Unexpected prepared statement handle"));
-
-    return GetFlightInfoForCommand(descriptor, GetQuerySchema());
+    if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" ||
+        command.prepared_statement_handle == "PLAN HANDLE") {
+      return GetFlightInfoForCommand(descriptor, GetQuerySchema());
+    } else if (command.prepared_statement_handle ==
+                   "SELECT PREPARED STATEMENT WITH TXN HANDLE" ||
+               command.prepared_statement_handle == "PLAN WITH TXN HANDLE") {
+      return GetFlightInfoForCommand(descriptor, GetQueryWithTransactionSchema());
+    }
+    return Status::Invalid("Invalid handle for GetFlightInfoForCommand: ",
+                           command.prepared_statement_handle);
   }
 
   arrow::Result<std::unique_ptr<SchemaResult>> GetSchemaPreparedStatement(
       const ServerCallContext& context, const sql::PreparedStatementQuery& command,
       const FlightDescriptor& descriptor) override {
-    ARROW_RETURN_NOT_OK(AssertEq<std::string>("SELECT PREPARED STATEMENT HANDLE",
-                                              command.prepared_statement_handle,
-                                              "Unexpected prepared statement handle"));
-    return SchemaResult::Make(*GetQuerySchema());
+    if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" ||
+        command.prepared_statement_handle == "PLAN HANDLE") {
+      return SchemaResult::Make(*GetQuerySchema());
+    } else if (command.prepared_statement_handle ==
+                   "SELECT PREPARED STATEMENT WITH TXN HANDLE" ||
+               command.prepared_statement_handle == "PLAN WITH TXN HANDLE") {
+      return SchemaResult::Make(*GetQueryWithTransactionSchema());
+    }
+    return Status::Invalid("Invalid handle for GetSchemaPreparedStatement: ",
+                           command.prepared_statement_handle);
   }
 
   arrow::Result<std::unique_ptr<FlightDataStream>> DoGetPreparedStatement(
       const ServerCallContext& context,
       const sql::PreparedStatementQuery& command) override {
-    return DoGetForTestCase(GetQuerySchema());
+    if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" ||
+        command.prepared_statement_handle == "PLAN HANDLE") {
+      return DoGetForTestCase(GetQuerySchema());
+    } else if (command.prepared_statement_handle ==
+                   "SELECT PREPARED STATEMENT WITH TXN HANDLE" ||
command.prepared_statement_handle == "PLAN WITH TXN HANDLE") { + return DoGetForTestCase(GetQueryWithTransactionSchema()); + } + return Status::Invalid("Invalid handle: ", command.prepared_statement_handle); } arrow::Result> GetFlightInfoCatalogs( @@ -381,21 +523,29 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoSqlInfo( const ServerCallContext& context, const sql::GetSqlInfo& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size(), - "Wrong number of SqlInfo values passed")); - ARROW_RETURN_NOT_OK( - AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, - command.info[0], "Unexpected SqlInfo passed")); - ARROW_RETURN_NOT_OK( - AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, - command.info[1], "Unexpected SqlInfo passed")); - - return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + if (command.info.size() == 2) { + // Integration test for the protocol messages + ARROW_RETURN_NOT_OK( + AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + command.info[0], "Unexpected SqlInfo passed")); + ARROW_RETURN_NOT_OK( + AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, + command.info[1], "Unexpected SqlInfo passed")); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + } + // Integration test for the values themselves + return sql::FlightSqlServerBase::GetFlightInfoSqlInfo(context, command, descriptor); } arrow::Result> DoGetSqlInfo( const ServerCallContext& context, const sql::GetSqlInfo& command) override { - return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + if (command.info.size() == 2) { + // Integration test for the protocol messages + return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + } + // Integration test for the values themselves + return sql::FlightSqlServerBase::DoGetSqlInfo(context, command); } arrow::Result> GetFlightInfoSchemas( @@ -539,8 +689,21 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { ARROW_RETURN_NOT_OK( AssertEq("UPDATE STATEMENT", command.query, "Wrong query for DoPutCommandStatementUpdate")); + return command.transaction_id.empty() ? kUpdateStatementExpectedRows + : kUpdateStatementWithTransactionExpectedRows; + } - return kUpdateStatementExpectedRows; + arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, + const sql::StatementSubstraitPlan& command) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, command.plan.plan, + "Wrong plan for DoPutCommandSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, command.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + return command.transaction_id.empty() ? 
kUpdateStatementExpectedRows + : kUpdateStatementWithTransactionExpectedRows; } arrow::Result CreatePreparedStatement( @@ -552,8 +715,26 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { } sql::ActionCreatePreparedStatementResult result; - result.prepared_statement_handle = request.query + " HANDLE"; + result.prepared_statement_handle = request.query; + if (!request.transaction_id.empty()) { + result.prepared_statement_handle += " WITH TXN"; + } + result.prepared_statement_handle += " HANDLE"; + return result; + } + arrow::Result CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const sql::ActionCreatePreparedSubstraitPlanRequest& request) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, request.plan.plan, + "Wrong plan for CreatePreparedSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, request.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + sql::ActionCreatePreparedStatementResult result; + result.prepared_statement_handle = + request.transaction_id.empty() ? "PLAN HANDLE" : "PLAN WITH TXN HANDLE"; return result; } @@ -561,7 +742,13 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const ServerCallContext& context, const sql::ActionClosePreparedStatementRequest& request) override { if (request.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" && - request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE") { + request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE" && + request.prepared_statement_handle != "PLAN HANDLE" && + request.prepared_statement_handle != + "SELECT PREPARED STATEMENT WITH TXN HANDLE" && + request.prepared_statement_handle != + "UPDATE PREPARED STATEMENT WITH TXN HANDLE" && + request.prepared_statement_handle != "PLAN WITH TXN HANDLE") { return Status::Invalid("Invalid handle for ClosePreparedStatement: ", request.prepared_statement_handle); } @@ -572,28 +759,95 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const sql::PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) override { - if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE") { + if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" && + command.prepared_statement_handle != + "SELECT PREPARED STATEMENT WITH TXN HANDLE" && + command.prepared_statement_handle != "PLAN HANDLE" && + command.prepared_statement_handle != "PLAN WITH TXN HANDLE") { return Status::Invalid("Invalid handle for DoPutPreparedStatementQuery: ", command.prepared_statement_handle); } - ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema, "Wrong schema for DoPutPreparedStatementQuery")); - return Status::OK(); } arrow::Result DoPutPreparedStatementUpdate( const ServerCallContext& context, const sql::PreparedStatementUpdate& command, FlightMessageReader* reader) override { - if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE") { + if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE" || + command.prepared_statement_handle == "PLAN HANDLE") { return kUpdatePreparedStatementExpectedRows; + } else if (command.prepared_statement_handle == + "UPDATE PREPARED STATEMENT WITH TXN HANDLE" || + command.prepared_statement_handle == "PLAN WITH TXN HANDLE") { + return kUpdatePreparedStatementWithTransactionExpectedRows; } return Status::Invalid("Invalid handle for 
DoPutPreparedStatementUpdate: ", command.prepared_statement_handle); } + arrow::Result BeginSavepoint( + const ServerCallContext& context, + const sql::ActionBeginSavepointRequest& request) override { + ARROW_RETURN_NOT_OK(AssertEq( + kSavepointName, request.name, "Unexpected savepoint name in BeginSavepoint")); + ARROW_RETURN_NOT_OK( + AssertEq(kTransactionId, request.transaction_id, + "Unexpected transaction ID in BeginSavepoint")); + return sql::ActionBeginSavepointResult{kSavepointId}; + } + + arrow::Result BeginTransaction( + const ServerCallContext& context, + const sql::ActionBeginTransactionRequest& request) override { + return sql::ActionBeginTransactionResult{kTransactionId}; + } + + arrow::Result CancelQuery( + const ServerCallContext& context, + const sql::ActionCancelQueryRequest& request) override { + ARROW_RETURN_NOT_OK(AssertEq(1, request.info->endpoints().size(), + "Expected 1 endpoint for CancelQuery")); + const FlightEndpoint& endpoint = request.info->endpoints()[0]; + ARROW_ASSIGN_OR_RAISE(auto ticket, + sql::StatementQueryTicket::Deserialize(endpoint.ticket.ticket)); + ARROW_RETURN_NOT_OK(AssertEq("PLAN HANDLE", ticket.statement_handle, + "Unexpected ticket in CancelQuery")); + return sql::CancelResult::kCancelled; + } + + Status EndSavepoint(const ServerCallContext& context, + const sql::ActionEndSavepointRequest& request) override { + switch (request.action) { + case sql::ActionEndSavepointRequest::kRelease: + case sql::ActionEndSavepointRequest::kRollback: + ARROW_RETURN_NOT_OK( + AssertEq(kSavepointId, request.savepoint_id, + "Unexpected savepoint ID in EndSavepoint")); + break; + default: + return Status::Invalid("Unknown action ", static_cast(request.action)); + } + return Status::OK(); + } + + Status EndTransaction(const ServerCallContext& context, + const sql::ActionEndTransactionRequest& request) override { + switch (request.action) { + case sql::ActionEndTransactionRequest::kCommit: + case sql::ActionEndTransactionRequest::kRollback: + ARROW_RETURN_NOT_OK( + AssertEq(kTransactionId, request.transaction_id, + "Unexpected transaction ID in EndTransaction")); + break; + default: + return Status::Invalid("Unknown action ", static_cast(request.action)); + } + return Status::OK(); + } + private: arrow::Result> GetFlightInfoForCommand( const FlightDescriptor& descriptor, const std::shared_ptr& schema) { @@ -615,6 +869,7 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { /// implementations. This should ensure that RPC objects are being built and parsed /// correctly for multiple languages and that the Arrow schemas are returned as expected. class FlightSqlScenario : public Scenario { + public: Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) override { server->reset(new FlightSqlScenarioServer()); @@ -785,10 +1040,290 @@ class FlightSqlScenario : public Scenario { AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows, "Wrong number of updated rows for prepared statement ExecuteUpdate")); ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); + return Status::OK(); + } +}; + +/// \brief Integration test scenario for validating the Substrait and +/// transaction extensions to Flight SQL. 
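The handlers above are the server half of the transaction protocol; the client methods added later in this patch (`BeginTransaction`, `BeginSavepoint`, `Commit`, `Rollback`, and `Release` in client.h) drive it. A minimal sketch, assuming an already-connected `sql::FlightSqlClient` and a server that accepts the `"UPDATE STATEMENT"` query used by this scenario:

```cpp
#include <memory>

#include "arrow/flight/sql/client.h"
#include "arrow/result.h"
#include "arrow/status.h"

namespace sql = arrow::flight::sql;

arrow::Status TransactionRoundTrip(sql::FlightSqlClient* client) {
  // Pairs with the BeginTransaction/BeginSavepoint actions handled above.
  ARROW_ASSIGN_OR_RAISE(sql::Transaction txn, client->BeginTransaction({}));
  ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint,
                        client->BeginSavepoint({}, txn, "savepoint_name"));

  // Updates executed with `txn` are associated with the transaction.
  ARROW_ASSIGN_OR_RAISE(int64_t updated_rows,
                        client->ExecuteUpdate({}, "UPDATE STATEMENT", txn));
  (void)updated_rows;

  // Undo work back to the savepoint, then commit what remains; these calls
  // map onto the EndSavepoint/EndTransaction handlers above.
  ARROW_RETURN_NOT_OK(client->Rollback({}, savepoint));
  return client->Commit({}, txn);
}
```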
+class FlightSqlExtensionScenario : public FlightSqlScenario { + public: + Status RunClient(std::unique_ptr client) override { + sql::FlightSqlClient sql_client(std::move(client)); + Status status; + if (!(status = ValidateMetadataRetrieval(&sql_client)).ok()) { + return status.WithMessage("MetadataRetrieval failed: ", status.message()); + } + if (!(status = ValidateStatementExecution(&sql_client)).ok()) { + return status.WithMessage("StatementExecution failed: ", status.message()); + } + if (!(status = ValidatePreparedStatementExecution(&sql_client)).ok()) { + return status.WithMessage("PreparedStatementExecution failed: ", status.message()); + } + if (!(status = ValidateTransactions(&sql_client)).ok()) { + return status.WithMessage("Transactions failed: ", status.message()); + } + return Status::OK(); + } + + Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) { + std::unique_ptr info; + std::vector sql_info = { + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT, + }; + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetSqlInfo({}, sql_info)); + ARROW_ASSIGN_OR_RAISE(auto reader, + sql_client->DoGet({}, info->endpoints()[0].ticket)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + if (!sql::SqlSchema::GetSqlInfoSchema()->Equals(*actual_schema, + /*check_metadata=*/true)) { + return Status::Invalid("Schemas did not match. Expected:\n", + *sql::SqlSchema::GetSqlInfoSchema(), "\nActual:\n", + *actual_schema); + } + + sql::SqlInfoResultMap info_values; + while (true) { + ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next()); + if (!chunk.data) break; + + const auto& info_name = checked_cast(*chunk.data->column(0)); + const auto& value = checked_cast(*chunk.data->column(1)); + + for (int64_t i = 0; i < chunk.data->num_rows(); i++) { + const uint32_t code = info_name.Value(i); + if (info_values.find(code) != info_values.end()) { + return Status::Invalid("Duplicate SqlInfo value ", code); + } + switch (value.type_code(i)) { + case 0: { // string + std::string slot = checked_cast(*value.field(0)) + .GetString(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(std::move(slot)); + break; + } + case 1: { // bool + bool slot = checked_cast(*value.field(1)) + .Value(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(slot); + break; + } + case 2: { // int64_t + int64_t slot = checked_cast(*value.field(2)) + .Value(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(slot); + break; + } + case 3: { // int32_t + int32_t slot = checked_cast(*value.field(3)) + .Value(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(slot); + break; + } + default: + return Status::NotImplemented("Decoding SqlInfoResult of type code ", + value.type_code(i)); + } + } + } + + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL], + sql::SqlInfoResult(false), "FLIGHT_SQL_SERVER_SQL did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT], + sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_SUBSTRAIT did not match")); + 
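+    // Each expected value here and below mirrors a RegisterSqlInfo call in
+    // the FlightSqlScenarioServer constructor.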
ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION], + sql::SqlInfoResult(std::string("min_version")), + "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION], + sql::SqlInfoResult(std::string("max_version")), + "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION], + sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction:: + SQL_SUPPORTED_TRANSACTION_SAVEPOINT), + "FLIGHT_SQL_SERVER_TRANSACTION did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL], + sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_CANCEL did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT], + sql::SqlInfoResult(int32_t(42)), + "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT], + sql::SqlInfoResult(int32_t(7)), + "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT did not match")); + + return Status::OK(); + } + + Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr info, + sql_client->ExecuteSubstrait({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client)); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, + sql_client->GetExecuteSubstraitSchema({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->ExecuteSubstrait({}, kSubstraitPlan)); + ARROW_ASSIGN_OR_RAISE(sql::CancelResult cancel_result, + sql_client->CancelQuery({}, *info)); + ARROW_RETURN_NOT_OK( + AssertEq(sql::CancelResult::kCancelled, cancel_result, "Wrong cancel result")); + + ARROW_ASSIGN_OR_RAISE(const int64_t updated_rows, + sql_client->ExecuteSubstraitUpdate({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdateStatementExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteSubstraitUpdate")); + + return Status::OK(); + } + + Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) { + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr info, + substrait_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, + substrait_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr update_substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan)); + ARROW_ASSIGN_OR_RAISE(const int64_t updated_rows, + update_substrait_prepared_statement->ExecuteUpdate()); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows, + "Wrong number of updated rows for prepared statement ExecuteUpdate")); + 
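+    // No transaction handle was passed to PrepareSubstrait, so the server
+    // reports the non-transactional row count.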
ARROW_RETURN_NOT_OK(update_substrait_prepared_statement->Close()); + + return Status::OK(); + } + + Status ValidateTransactions(sql::FlightSqlClient* sql_client) { + ARROW_ASSIGN_OR_RAISE(sql::Transaction transaction, sql_client->BeginTransaction({})); + ARROW_RETURN_NOT_OK(AssertEq( + kTransactionId, transaction.transaction_id(), "Wrong transaction ID")); + + ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint, + sql_client->BeginSavepoint({}, transaction, kSavepointName)); + ARROW_RETURN_NOT_OK(AssertEq(kSavepointId, savepoint.savepoint_id(), + "Wrong savepoint ID")); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr info, + sql_client->Execute({}, kSelectStatement, transaction)); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + + ARROW_ASSIGN_OR_RAISE(info, + sql_client->ExecuteSubstrait({}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr schema, + sql_client->GetExecuteSchema({}, kSelectStatement, transaction)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE( + schema, sql_client->GetExecuteSubstraitSchema({}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(int64_t updated_rows, + sql_client->ExecuteUpdate({}, "UPDATE STATEMENT", transaction)); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdateStatementWithTransactionExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteUpdate with transaction")); + ARROW_ASSIGN_OR_RAISE(updated_rows, sql_client->ExecuteSubstraitUpdate( + {}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(AssertEq( + kUpdateStatementWithTransactionExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteSubstraitUpdate with transaction")); + + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr select_prepared_statement, + sql_client->Prepare({}, "SELECT PREPARED STATEMENT", transaction)); + ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(info, select_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(schema, select_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(info, substrait_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(schema, substrait_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr update_prepared_statement, + sql_client->Prepare({}, "UPDATE PREPARED STATEMENT", transaction)); + ARROW_ASSIGN_OR_RAISE(updated_rows, update_prepared_statement->ExecuteUpdate()); + ARROW_RETURN_NOT_OK(AssertEq(kUpdatePreparedStatementWithTransactionExpectedRows, + updated_rows, + "Wrong number of updated rows for 
prepared statement "
+                                 "ExecuteUpdate with transaction"));
+    ARROW_RETURN_NOT_OK(update_prepared_statement->Close());
+
+    ARROW_ASSIGN_OR_RAISE(
+        std::shared_ptr<sql::PreparedStatement> update_substrait_prepared_statement,
+        sql_client->PrepareSubstrait({}, kSubstraitPlan, transaction));
+    ARROW_ASSIGN_OR_RAISE(updated_rows,
+                          update_substrait_prepared_statement->ExecuteUpdate());
+    ARROW_RETURN_NOT_OK(AssertEq(kUpdatePreparedStatementWithTransactionExpectedRows,
+                                 updated_rows,
+                                 "Wrong number of updated rows for prepared statement "
+                                 "ExecuteUpdate with transaction"));
+    ARROW_RETURN_NOT_OK(update_substrait_prepared_statement->Close());
+
+    ARROW_RETURN_NOT_OK(sql_client->Rollback({}, savepoint));
+
+    ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint2,
+                          sql_client->BeginSavepoint({}, transaction, kSavepointName));
+    ARROW_RETURN_NOT_OK(AssertEq<std::string>(kSavepointId, savepoint2.savepoint_id(),
+                                              "Wrong savepoint ID"));
+    ARROW_RETURN_NOT_OK(sql_client->Release({}, savepoint2));
+
+    ARROW_RETURN_NOT_OK(sql_client->Commit({}, transaction));
+
+    ARROW_ASSIGN_OR_RAISE(sql::Transaction transaction2,
+                          sql_client->BeginTransaction({}));
+    ARROW_RETURN_NOT_OK(AssertEq<std::string>(
+        kTransactionId, transaction2.transaction_id(), "Wrong transaction ID"));
+    ARROW_RETURN_NOT_OK(sql_client->Rollback({}, transaction2));
     return Status::OK();
   }
 };
+}  // namespace
 
 Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>* out) {
   if (scenario_name == "auth:basic_proto") {
@@ -800,6 +1335,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr<Scenario>* out) {
   } else if (scenario_name == "flight_sql") {
     *out = std::make_shared<FlightSqlScenario>();
     return Status::OK();
+  } else if (scenario_name == "flight_sql:extension") {
+    *out = std::make_shared<FlightSqlExtensionScenario>();
+    return Status::OK();
   }
   return Status::KeyError("Scenario not found: ", scenario_name);
 }
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 1a3b52910c0..e9736b0615e 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -353,7 +353,9 @@ RecordBatchStream::RecordBatchStream(const std::shared_ptr<RecordBatchReader>& reader,
   impl_.reset(new RecordBatchStreamImpl(reader, options));
 }
 
-RecordBatchStream::~RecordBatchStream() {}
+RecordBatchStream::~RecordBatchStream() {
+  ARROW_WARN_NOT_OK(impl_->Close(), "Failed to close FlightDataStream");
+}
 
 Status RecordBatchStream::Close() { return impl_->Close(); }
 
diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt
index f7312de23a9..14503069dd0 100644
--- a/cpp/src/arrow/flight/sql/CMakeLists.txt
+++ b/cpp/src/arrow/flight/sql/CMakeLists.txt
@@ -89,7 +89,11 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
                 example/sqlite_statement_batch_reader.cc
                 example/sqlite_server.cc
                 example/sqlite_tables_schema_batch_reader.cc)
+  set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc)
+  set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES})
+  set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc)
+
   if(NOT MSVC AND NOT MINGW)
     # ARROW-16902: getting Protobuf generated code to have all the
     # proper dllexport/dllimport declarations is difficult, since
@@ -98,13 +102,34 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
     list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS client_test.cc)
   endif()
 
+  if(ARROW_COMPUTE
+     AND ARROW_PARQUET
+     AND ARROW_SUBSTRAIT)
+    list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS ${ARROW_FLIGHT_SQL_ACERO_SRCS} acero_test.cc)
+    if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+      list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_static)
+    else()
+      list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared)
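+      # (arrow_substrait is a separate target from arrow_flight_sql, so the
+      #  Substrait test dependency must mirror ARROW_FLIGHT_TEST_LINKAGE.)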
endif() + + if(ARROW_BUILD_EXAMPLES) + add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS} + example/acero_main.cc) + target_link_libraries(acero-flight-sql-server + PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES}) + endif() + endif() + add_arrow_test(flight_sql_test SOURCES ${ARROW_FLIGHT_SQL_TEST_SRCS} ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS} STATIC_LINK_LIBS ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} - ${SQLite3_LIBRARIES} + ${ARROW_FLIGHT_SQL_TEST_LIBS} + EXTRA_INCLUDES + "${CMAKE_CURRENT_BINARY_DIR}/../" LABELS "arrow_flight_sql") diff --git a/cpp/src/arrow/flight/sql/acero_test.cc b/cpp/src/arrow/flight/sql/acero_test.cc new file mode 100644 index 00000000000..fd3c52e74f3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/acero_test.cc @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Integration test using the Acero backend + +#include +#include + +#include +#include + +#include "arrow/array.h" +#include "arrow/engine/substrait/util.h" +#include "arrow/flight/server.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/example/acero_server.h" +#include "arrow/flight/sql/types.h" +#include "arrow/flight/types.h" +#include "arrow/stl_iterator.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { +namespace flight { +namespace sql { + +using arrow::internal::checked_cast; + +class TestAcero : public ::testing::Test { + public: + void SetUp() override { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); + flight::FlightServerOptions options(location); + + ASSERT_OK_AND_ASSIGN(server_, acero_example::MakeAceroServer()); + ASSERT_OK(server_->Init(options)); + + ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server_->location())); + client_.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); + } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +arrow::Result> MakeSubstraitPlan() { + ARROW_ASSIGN_OR_RAISE(std::string dir_string, + arrow::internal::GetEnvVar("PARQUET_TEST_DATA")); + ARROW_ASSIGN_OR_RAISE(auto dir, + arrow::internal::PlatformFilename::FromString(dir_string)); + ARROW_ASSIGN_OR_RAISE(auto filename, dir.Join("binary.parquet")); + std::string uri = std::string("file://") + filename.ToString(); + + // TODO(ARROW-17229): we should use a RootRel here + std::string json_plan = R"({ + "relations": [ + { + "rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"binary": {}} + ] + }, + "names": [ + "foo" + ] + }, + "local_files": { + "items": [ + { + "uri_file": "URI_PLACEHOLDER", + 
"parquet": {} + } + ] + } + } + } + } + ] +})"; + std::string uri_placeholder = "URI_PLACEHOLDER"; + json_plan.replace(json_plan.find(uri_placeholder), uri_placeholder.size(), uri); + return engine::SerializeJsonPlan(json_plan); +} + +TEST_F(TestAcero, GetSqlInfo) { + FlightCallOptions call_options; + std::vector sql_info_codes = { + SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, + SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + }; + ASSERT_OK_AND_ASSIGN(auto flight_info, + client_->GetSqlInfo(call_options, sql_info_codes)); + ASSERT_OK_AND_ASSIGN(auto reader, + client_->DoGet(call_options, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto results, reader->ToTable()); + ASSERT_OK_AND_ASSIGN(auto batch, results->CombineChunksToBatch()); + ASSERT_EQ(2, results->num_rows()); + std::vector> info; + const auto& ids = checked_cast(*batch->column(0)); + const auto& values = checked_cast(*batch->column(1)); + for (int64_t i = 0; i < batch->num_rows(); i++) { + ASSERT_OK_AND_ASSIGN(auto scalar, values.GetScalar(i)); + if (ids.Value(i) == SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT) { + ASSERT_EQ(*checked_cast(*scalar).value, + BooleanScalar(true)); + } else if (ids.Value(i) == SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION) { + ASSERT_EQ( + *checked_cast(*scalar).value, + Int32Scalar( + SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE)); + } else { + FAIL() << "Unexpected info value: " << ids.Value(i); + } + } +} + +TEST_F(TestAcero, Scan) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + ASSERT_OK_AND_ASSIGN(std::unique_ptr info, + client_->ExecuteSubstrait(call_options, plan)); + ipc::DictionaryMemo memo; + ASSERT_OK_AND_ASSIGN(auto schema, info->GetSchema(&memo)); + // TODO(ARROW-17229): the scanner "special" fields are still included, strip them + // manually + auto fixed_schema = arrow::schema({schema->fields()[0]}); + ASSERT_NO_FATAL_FAILURE( + AssertSchemaEqual(fixed_schema, arrow::schema({field("foo", binary())}))); + + ASSERT_EQ(1, info->endpoints().size()); + ASSERT_EQ(0, info->endpoints()[0].locations.size()); + ASSERT_OK_AND_ASSIGN(auto reader, + client_->DoGet(call_options, info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto reader_schema, reader->GetSchema()); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(schema, reader_schema)); + ASSERT_OK_AND_ASSIGN(auto table, reader->ToTable()); + ASSERT_GT(table->num_rows(), 0); +} + +TEST_F(TestAcero, Update) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Updates are unsupported"), + client_->ExecuteSubstraitUpdate(call_options, plan)); +} + +TEST_F(TestAcero, Prepare) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + ASSERT_OK_AND_ASSIGN(auto prepared_statement, + client_->PrepareSubstrait(call_options, plan)); 
+ ASSERT_NE(prepared_statement->dataset_schema(), nullptr); + ASSERT_EQ(prepared_statement->parameter_schema(), nullptr); + + auto fixed_schema = arrow::schema({prepared_statement->dataset_schema()->fields()[0]}); + ASSERT_NO_FATAL_FAILURE( + AssertSchemaEqual(fixed_schema, arrow::schema({field("foo", binary())}))); + + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Updates are unsupported"), + prepared_statement->ExecuteUpdate()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr info, prepared_statement->Execute()); + ASSERT_EQ(1, info->endpoints().size()); + ASSERT_EQ(0, info->endpoints()[0].locations.size()); + ASSERT_OK_AND_ASSIGN(auto reader, + client_->DoGet(call_options, info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto reader_schema, reader->GetSchema()); + ASSERT_NO_FATAL_FAILURE( + AssertSchemaEqual(prepared_statement->dataset_schema(), reader_schema)); + ASSERT_OK_AND_ASSIGN(auto table, reader->ToTable()); + ASSERT_GT(table->num_rows(), 0); + + ASSERT_OK(prepared_statement->Close()); +} + +TEST_F(TestAcero, Transactions) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + Transaction handle("fake-id"); + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Transactions are unsupported"), + client_->ExecuteSubstrait(call_options, plan, handle)); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Transactions are unsupported"), + client_->PrepareSubstrait(call_options, plan, handle)); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index e299b7ceb11..521cf9e8cd6 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -66,8 +66,65 @@ arrow::Result> GetSchemaForCommand( GetFlightDescriptorForCommand(command)); return client->GetSchema(options, descriptor); } + +::arrow::Result PackAction(const std::string& action_type, + const google::protobuf::Message& message) { + google::protobuf::Any any; + if (!any.PackFrom(message)) { + return Status::SerializationError("Could not pack ", message.GetTypeName(), + " into Any"); + } + + std::string buffer; + if (!any.SerializeToString(&buffer)) { + return Status::SerializationError("Could not serialize packed ", + message.GetTypeName()); + } + + Action action; + action.type = action_type; + action.body = Buffer::FromString(std::move(buffer)); + return action; +} + +void SetPlan(const SubstraitPlan& plan, flight_sql_pb::SubstraitPlan* pb_plan) { + pb_plan->set_plan(plan.plan); + pb_plan->set_version(plan.version); +} + +Status ReadResult(ResultStream* results, google::protobuf::Message* message) { + ARROW_ASSIGN_OR_RAISE(auto result, results->Next()); + if (!result) { + return Status::IOError("Server did not return a result for ", message->GetTypeName()); + } + + google::protobuf::Any container; + if (!container.ParseFromArray(result->body->data(), + static_cast(result->body->size()))) { + return Status::IOError("Unable to parse Any (expecting ", message->GetTypeName(), + ")"); + } + if (!container.UnpackTo(message)) { + return Status::IOError("Unable to unpack Any (expecting ", message->GetTypeName(), + ")"); + } + return Status::OK(); +} + +Status DrainResultStream(ResultStream* results) { + while (true) { + 
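+    // Consume and discard whatever remains in the stream so the call
+    // terminates cleanly even when the caller ignores the payload.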
+    ARROW_ASSIGN_OR_RAISE(auto result, results->Next());
+    if (!result) break;
+  }
+  return Status::OK();
+}
 }  // namespace
 
+const Transaction& no_transaction() {
+  static Transaction kInvalidTransaction("");
+  return kInvalidTransaction;
+}
+
 FlightSqlClient::FlightSqlClient(std::shared_ptr<FlightClient> client)
     : impl_(std::move(client)) {}
 
@@ -90,25 +147,59 @@ PreparedStatement::~PreparedStatement() {
 }
 
 arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::Execute(
-    const FlightCallOptions& options, const std::string& query) {
+    const FlightCallOptions& options, const std::string& query,
+    const Transaction& transaction) {
   flight_sql_pb::CommandStatementQuery command;
   command.set_query(query);
+  if (transaction.is_valid()) {
+    command.set_transaction_id(transaction.transaction_id());
+  }
 
   return GetFlightInfoForCommand(this, options, command);
 }
 
 arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetExecuteSchema(
-    const FlightCallOptions& options, const std::string& query) {
+    const FlightCallOptions& options, const std::string& query,
+    const Transaction& transaction) {
   flight_sql_pb::CommandStatementQuery command;
   command.set_query(query);
+  if (transaction.is_valid()) {
+    command.set_transaction_id(transaction.transaction_id());
+  }
+  return GetSchemaForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<FlightInfo>> FlightSqlClient::ExecuteSubstrait(
+    const FlightCallOptions& options, const SubstraitPlan& plan,
+    const Transaction& transaction) {
+  flight_sql_pb::CommandStatementSubstraitPlan command;
+  SetPlan(plan, command.mutable_plan());
+  if (transaction.is_valid()) {
+    command.set_transaction_id(transaction.transaction_id());
+  }
+
+  return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result<std::unique_ptr<SchemaResult>> FlightSqlClient::GetExecuteSubstraitSchema(
+    const FlightCallOptions& options, const SubstraitPlan& plan,
+    const Transaction& transaction) {
+  flight_sql_pb::CommandStatementSubstraitPlan command;
+  SetPlan(plan, command.mutable_plan());
+  if (transaction.is_valid()) {
+    command.set_transaction_id(transaction.transaction_id());
+  }
   return GetSchemaForCommand(this, options, command);
 }
 
 arrow::Result<int64_t> FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options,
-                                                      const std::string& query) {
+                                                      const std::string& query,
+                                                      const Transaction& transaction) {
   flight_sql_pb::CommandStatementUpdate command;
   command.set_query(query);
+  if (transaction.is_valid()) {
+    command.set_transaction_id(transaction.transaction_id());
+  }
 
   ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
                         GetFlightDescriptorForCommand(command));
@@ -119,14 +210,41 @@ arrow::Result<int64_t> FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options,
   ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader));
 
   std::shared_ptr<Buffer> metadata;
   ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
+  ARROW_RETURN_NOT_OK(writer->Close());
+
+  flight_sql_pb::DoPutUpdateResult result;
+  if (!result.ParseFromArray(metadata->data(), static_cast<int>(metadata->size()))) {
+    return Status::Invalid("Unable to parse DoPutUpdateResult");
+  }
+
+  return result.record_count();
+}
+
+arrow::Result<int64_t> FlightSqlClient::ExecuteSubstraitUpdate(
+    const FlightCallOptions& options, const SubstraitPlan& plan,
+    const Transaction& transaction) {
+  flight_sql_pb::CommandStatementSubstraitPlan command;
+  SetPlan(plan, command.mutable_plan());
+  if (transaction.is_valid()) {
+    command.set_transaction_id(transaction.transaction_id());
+  }
+
+  ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+                        GetFlightDescriptorForCommand(command));
+
+  std::unique_ptr<FlightStreamWriter> writer;
+  std::unique_ptr<FlightMetadataReader> reader;
+
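+  // As in ExecuteUpdate above: no data rows are written, the empty schema
+  // just opens the stream, and the affected-row count comes back via the
+  // metadata reader.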
ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader)); - flight_sql_pb::DoPutUpdateResult doPutUpdateResult; + std::shared_ptr metadata; + ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); + ARROW_RETURN_NOT_OK(writer->Close()); flight_sql_pb::DoPutUpdateResult result; if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { - return Status::Invalid("Unable to parse DoPutUpdateResult object."); + return Status::Invalid("Unable to parse DoPutUpdateResult"); } return result.record_count(); @@ -357,35 +475,41 @@ arrow::Result> FlightSqlClient::DoGet( } arrow::Result> FlightSqlClient::Prepare( - const FlightCallOptions& options, const std::string& query) { - google::protobuf::Any command; + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction) { flight_sql_pb::ActionCreatePreparedStatementRequest request; request.set_query(query); - command.PackFrom(request); - - Action action; - action.type = "CreatePreparedStatement"; - action.body = Buffer::FromString(command.SerializeAsString()); + if (transaction.is_valid()) { + request.set_transaction_id(transaction.transaction_id()); + } std::unique_ptr results; - + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CreatePreparedStatement", request)); ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr result, results->Next()); - - google::protobuf::Any prepared_result; + return PreparedStatement::ParseResponse(this, std::move(results)); +} - std::shared_ptr message = std::move(result->body); - if (!prepared_result.ParseFromArray(message->data(), - static_cast(message->size()))) { - return Status::Invalid("Unable to parse packed ActionCreatePreparedStatementResult"); +arrow::Result> FlightSqlClient::PrepareSubstrait( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction) { + flight_sql_pb::ActionCreatePreparedSubstraitPlanRequest request; + SetPlan(plan, request.mutable_plan()); + if (transaction.is_valid()) { + request.set_transaction_id(transaction.transaction_id()); } - flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result; + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CreatePreparedSubstraitPlan", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - if (!prepared_result.UnpackTo(&prepared_statement_result)) { - return Status::Invalid("Unable to unpack ActionCreatePreparedStatementResult"); - } + return PreparedStatement::ParseResponse(this, std::move(results)); +} + +arrow::Result> PreparedStatement::ParseResponse( + FlightSqlClient* client, std::unique_ptr results) { + flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &prepared_statement_result)); const std::string& serialized_dataset_schema = prepared_statement_result.dataset_schema(); @@ -407,14 +531,14 @@ arrow::Result> FlightSqlClient::Prepare( } auto handle = prepared_statement_result.prepared_statement_handle(); - return std::make_shared(this, handle, dataset_schema, + return std::make_shared(client, handle, dataset_schema, parameter_schema); } arrow::Result> PreparedStatement::Execute( const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed."); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } flight_sql_pb::CommandPreparedStatementQuery command; @@ -433,6 +557,7 @@ arrow::Result> 
PreparedStatement::Execute( // Wait for the server to ack the result std::shared_ptr buffer; ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer)); + ARROW_RETURN_NOT_OK(writer->Close()); } ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor)); @@ -442,7 +567,7 @@ arrow::Result> PreparedStatement::Execute( arrow::Result PreparedStatement::ExecuteUpdate( const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed."); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } flight_sql_pb::CommandPreparedStatementUpdate command; @@ -496,7 +621,7 @@ std::shared_ptr PreparedStatement::parameter_schema() const { arrow::Result> PreparedStatement::GetSchema( const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed"); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } flight_sql_pb::CommandPreparedStatementQuery command; @@ -508,29 +633,185 @@ arrow::Result> PreparedStatement::GetSchema( Status PreparedStatement::Close(const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed."); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } - google::protobuf::Any command; + flight_sql_pb::ActionClosePreparedStatementRequest request; request.set_prepared_statement_handle(handle_); - command.PackFrom(request); + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("ClosePreparedStatement", request)); + ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); - Action action; - action.type = "ClosePreparedStatement"; - action.body = Buffer::FromString(command.SerializeAsString()); + is_closed_ = true; + return Status::OK(); +} + +::arrow::Result FlightSqlClient::BeginTransaction( + const FlightCallOptions& options) { + flight_sql_pb::ActionBeginTransactionRequest request; std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("BeginTransaction", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results)); + flight_sql_pb::ActionBeginTransactionResult transaction; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &transaction)); + if (transaction.transaction_id().empty()) { + return Status::Invalid("Server returned an empty transaction ID"); + } - is_closed_ = true; + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Transaction(transaction.transaction_id()); +} + +::arrow::Result FlightSqlClient::BeginSavepoint( + const FlightCallOptions& options, const Transaction& transaction, + const std::string& name) { + flight_sql_pb::ActionBeginSavepointRequest request; + + if (!transaction.is_valid()) { + return Status::Invalid("Must provide an active transaction"); + } + request.set_transaction_id(transaction.transaction_id()); + request.set_name(name); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("BeginSavepoint", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + flight_sql_pb::ActionBeginSavepointResult savepoint; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &savepoint)); + if (savepoint.savepoint_id().empty()) { + return Status::Invalid("Server returned an empty savepoint ID"); + } + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Savepoint(savepoint.savepoint_id()); +} + +Status 
FlightSqlClient::Commit(const FlightCallOptions& options, + const Transaction& transaction) { + flight_sql_pb::ActionEndTransactionRequest request; + + if (!transaction.is_valid()) { + return Status::Invalid("Must provide an active transaction"); + } + request.set_transaction_id(transaction.transaction_id()); + request.set_action(flight_sql_pb::ActionEndTransactionRequest::END_TRANSACTION_COMMIT); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndTransaction", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Status::OK(); +} +Status FlightSqlClient::Release(const FlightCallOptions& options, + const Savepoint& savepoint) { + flight_sql_pb::ActionEndSavepointRequest request; + + if (!savepoint.is_valid()) { + return Status::Invalid("Must provide an active savepoint"); + } + request.set_savepoint_id(savepoint.savepoint_id()); + request.set_action(flight_sql_pb::ActionEndSavepointRequest::END_SAVEPOINT_RELEASE); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndSavepoint", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Status::OK(); +} + +Status FlightSqlClient::Rollback(const FlightCallOptions& options, + const Transaction& transaction) { + flight_sql_pb::ActionEndTransactionRequest request; + + if (!transaction.is_valid()) { + return Status::Invalid("Must provide an active transaction"); + } + request.set_transaction_id(transaction.transaction_id()); + request.set_action( + flight_sql_pb::ActionEndTransactionRequest::END_TRANSACTION_ROLLBACK); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndTransaction", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Status::OK(); +} + +Status FlightSqlClient::Rollback(const FlightCallOptions& options, + const Savepoint& savepoint) { + flight_sql_pb::ActionEndSavepointRequest request; + + if (!savepoint.is_valid()) { + return Status::Invalid("Must provide an active savepoint"); + } + request.set_savepoint_id(savepoint.savepoint_id()); + request.set_action(flight_sql_pb::ActionEndSavepointRequest::END_SAVEPOINT_ROLLBACK); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndSavepoint", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); return Status::OK(); } +::arrow::Result FlightSqlClient::CancelQuery( + const FlightCallOptions& options, const FlightInfo& info) { + flight_sql_pb::ActionCancelQueryRequest request; + ARROW_ASSIGN_OR_RAISE(auto serialized_info, info.SerializeToString()); + request.set_info(std::move(serialized_info)); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CancelQuery", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + flight_sql_pb::ActionCancelQueryResult result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &result)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + switch (result.result()) { + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_UNSPECIFIED: + return CancelResult::kUnspecified; + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_CANCELLED: + return CancelResult::kCancelled; + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_CANCELLING: + return 
CancelResult::kCancelling; + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_NOT_CANCELLABLE: + return CancelResult::kNotCancellable; + default: + break; + } + return Status::IOError("Server returned unknown result ", result.result()); +} + Status FlightSqlClient::Close() { return impl_->Close(); } +std::ostream& operator<<(std::ostream& os, CancelResult result) { + switch (result) { + case CancelResult::kUnspecified: + os << "CancelResult::kUnspecified"; + break; + case CancelResult::kCancelled: + os << "CancelResult::kCancelled"; + break; + case CancelResult::kCancelling: + os << "CancelResult::kCancelling"; + break; + case CancelResult::kNotCancellable: + os << "CancelResult::kNotCancellable"; + break; + } + return os; +} + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 26315e0d234..db168847ed6 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -32,6 +33,13 @@ namespace flight { namespace sql { class PreparedStatement; +class Transaction; +class Savepoint; + +/// \brief A default transaction to use when the default behavior +/// (auto-commit) is desired. +ARROW_FLIGHT_SQL_EXPORT +const Transaction& no_transaction(); /// \brief Flight client with Flight SQL semantics. /// @@ -47,23 +55,51 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { virtual ~FlightSqlClient() = default; - /// \brief Execute a query on the server. + /// \brief Execute a SQL query on the server. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] query The UTF8-encoded SQL query to be executed. + /// \param[in] transaction A transaction to associate this query with. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> Execute( + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction = no_transaction()); + + /// \brief Execute a Substrait plan that returns a result set on the server. /// \param[in] options RPC-layer hints for this call. - /// \param[in] query The query to be executed in the UTF-8 format. + /// \param[in] plan The plan to be executed. + /// \param[in] transaction A transaction to associate this query with. /// \return The FlightInfo describing where to access the dataset. - arrow::Result> Execute(const FlightCallOptions& options, - const std::string& query); + arrow::Result> ExecuteSubstrait( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction = no_transaction()); /// \brief Get the result set schema from the server. arrow::Result> GetExecuteSchema( - const FlightCallOptions& options, const std::string& query); + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction = no_transaction()); + + /// \brief Get the result set schema from the server. + arrow::Result> GetExecuteSubstraitSchema( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction = no_transaction()); /// \brief Execute an update query on the server. /// \param[in] options RPC-layer hints for this call. - /// \param[in] query The query to be executed in the UTF-8 format. + /// \param[in] query The UTF8-encoded SQL query to be executed. + /// \param[in] transaction A transaction to associate this query with. /// \return The quantity of rows affected by the operation. 
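+  /// \note The default no_transaction() argument requests auto-commit behavior.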
  arrow::Result<int64_t> ExecuteUpdate(const FlightCallOptions& options,
-                                       const std::string& query);
+                                       const std::string& query,
+                                       const Transaction& transaction = no_transaction());
+
+  /// \brief Execute a Substrait plan that does not return a result set on the server.
+  /// \param[in] options RPC-layer hints for this call.
+  /// \param[in] plan The plan to be executed.
+  /// \param[in] transaction A transaction to associate this query with.
+  /// \return The quantity of rows affected by the operation.
+  arrow::Result<int64_t> ExecuteSubstraitUpdate(
+      const FlightCallOptions& options, const SubstraitPlan& plan,
+      const Transaction& transaction = no_transaction());
 
   /// \brief Request a list of catalogs.
   /// \param[in] options RPC-layer hints for this call.
@@ -215,9 +251,20 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
   /// \brief Create a prepared statement object.
   /// \param[in] options RPC-layer hints for this call.
   /// \param[in] query The query that will be executed.
+  /// \param[in] transaction A transaction to associate this query with.
   /// \return The created prepared statement.
   arrow::Result<std::shared_ptr<PreparedStatement>> Prepare(
-      const FlightCallOptions& options, const std::string& query);
+      const FlightCallOptions& options, const std::string& query,
+      const Transaction& transaction = no_transaction());
+
+  /// \brief Create a prepared statement object.
+  /// \param[in] options RPC-layer hints for this call.
+  /// \param[in] plan The Substrait plan that will be executed.
+  /// \param[in] transaction A transaction to associate this query with.
+  /// \return The created prepared statement.
+  arrow::Result<std::shared_ptr<PreparedStatement>> PrepareSubstrait(
+      const FlightCallOptions& options, const SubstraitPlan& plan,
+      const Transaction& transaction = no_transaction());
 
   /// \brief Call the underlying Flight client's GetFlightInfo.
   virtual arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfo(
@@ -231,6 +278,58 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
     return impl_->GetSchema(options, descriptor);
   }
 
+  /// \brief Begin a new transaction.
+  ::arrow::Result<Transaction> BeginTransaction(const FlightCallOptions& options);
+
+  /// \brief Create a new savepoint within a transaction.
+  /// \param[in] options RPC-layer hints for this call.
+  /// \param[in] transaction The parent transaction.
+  /// \param[in] name A friendly name for the savepoint.
+  ::arrow::Result<Savepoint> BeginSavepoint(const FlightCallOptions& options,
+                                            const Transaction& transaction,
+                                            const std::string& name);
+
+  /// \brief Commit a transaction.
+  ///
+  /// After this, the transaction and all associated savepoints will
+  /// be invalidated.
+  ///
+  /// \param[in] options RPC-layer hints for this call.
+  /// \param[in] transaction The transaction.
+  Status Commit(const FlightCallOptions& options, const Transaction& transaction);
+
+  /// \brief Release a savepoint.
+  ///
+  /// After this, the savepoint (and all savepoints created after it)
+  /// will be invalidated.
+  ///
+  /// \param[in] options RPC-layer hints for this call.
+  /// \param[in] savepoint The savepoint.
+  Status Release(const FlightCallOptions& options, const Savepoint& savepoint);
+
+  /// \brief Roll back a transaction.
+  ///
+  /// After this, the transaction and all associated savepoints will be invalidated.
+  ///
+  /// \param[in] options RPC-layer hints for this call.
+  /// \param[in] transaction The transaction.
+  Status Rollback(const FlightCallOptions& options, const Transaction& transaction);
+
+  /// \brief Roll back a savepoint.
+  ///
+  /// After this, the savepoint will still be valid, but all
+  /// savepoints created after it will be invalidated.
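+  ///
+  /// Illustrative sketch (assumes a connected client and an active
+  /// transaction `txn`):
+  /// \code
+  /// ARROW_ASSIGN_OR_RAISE(auto sp, client.BeginSavepoint(options, txn, "sp1"));
+  /// ARROW_RETURN_NOT_OK(client.Rollback(options, sp));  // sp is still valid
+  /// ARROW_RETURN_NOT_OK(client.Release(options, sp));   // sp is now invalid
+  /// \endcode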
+ /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] savepoint The savepoint. + Status Rollback(const FlightCallOptions& options, const Savepoint& savepoint); + + /// \brief Explicitly cancel a query. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] info The FlightInfo of the query to cancel. + ::arrow::Result CancelQuery(const FlightCallOptions& options, + const FlightInfo& info); + /// \brief Explicitly shut down and clean up the client. Status Close(); @@ -278,6 +377,10 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// errors can't be caught. ~PreparedStatement(); + /// \brief Create a PreparedStatement by parsing the server response. + static arrow::Result> ParseResponse( + FlightSqlClient* client, std::unique_ptr results); + /// \brief Executes the prepared statement query on the server. /// \return A FlightInfo object representing the stream(s) to fetch. arrow::Result> Execute( @@ -295,8 +398,8 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// \return The ResultSet schema from the query. std::shared_ptr dataset_schema() const; - /// \brief Set a RecordBatch that contains the parameters that will be bind. - /// \param parameter_binding The parameters that will be bind. + /// \brief Set a RecordBatch that contains the parameters that will be bound. + /// \param parameter_binding The parameters that will be bound. /// \return Status. Status SetParameters(std::shared_ptr parameter_binding); @@ -305,9 +408,9 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { arrow::Result> GetSchema( const FlightCallOptions& options = {}); - /// \brief Close the prepared statement, so that this PreparedStatement can not used - /// anymore and server can free up any resources. - /// \return Status. + /// \brief Close the prepared statement so the server can free up any resources. + /// + /// After this, the prepared statement may not be used anymore. Status Close(const FlightCallOptions& options = {}); /// \brief Check if the prepared statement is closed. @@ -323,6 +426,29 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { bool is_closed_; }; +/// \brief A handle for a server-side savepoint. +class ARROW_FLIGHT_SQL_EXPORT Savepoint { + public: + explicit Savepoint(std::string savepoint_id) : savepoint_id_(std::move(savepoint_id)) {} + const std::string& savepoint_id() const { return savepoint_id_; } + bool is_valid() const { return !savepoint_id_.empty(); } + + private: + std::string savepoint_id_; +}; + +/// \brief A handle for a server-side transaction. 
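+///
+/// The handle is a thin value wrapper around the server-generated
+/// transaction ID: is_valid() only checks that the ID is non-empty, and an
+/// empty ID (as presumably carried by no_transaction()) means auto-commit.
+/// Illustrative sketch:
+/// \code
+/// ARROW_ASSIGN_OR_RAISE(auto txn, client.BeginTransaction(options));
+/// assert(txn.is_valid());  // the server issued a non-empty ID
+/// \endcode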
+class ARROW_FLIGHT_SQL_EXPORT Transaction { + public: + explicit Transaction(std::string transaction_id) + : transaction_id_(std::move(transaction_id)) {} + const std::string& transaction_id() const { return transaction_id_; } + bool is_valid() const { return !transaction_id_.empty(); } + + private: + std::string transaction_id_; +}; + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc index acd078a8477..984bf454816 100644 --- a/cpp/src/arrow/flight/sql/client_test.cc +++ b/cpp/src/arrow/flight/sql/client_test.cc @@ -410,6 +410,7 @@ TEST_F(TestFlightSqlClient, TestExecuteUpdate) { std::unique_ptr* writer, std::unique_ptr* reader) { reader->reset(new FlightMetadataReaderMock(&buffer_ptr)); + writer->reset(new FlightStreamWriterMock()); return Status::OK(); }); diff --git a/cpp/src/arrow/flight/sql/column_metadata.cc b/cpp/src/arrow/flight/sql/column_metadata.cc index 30ef240105c..adfe81f1730 100644 --- a/cpp/src/arrow/flight/sql/column_metadata.cc +++ b/cpp/src/arrow/flight/sql/column_metadata.cc @@ -118,25 +118,25 @@ const std::shared_ptr& ColumnMetadata::metadata_m } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::CatalogName( - std::string& catalog_name) { + const std::string& catalog_name) { metadata_map_->Append(ColumnMetadata::kCatalogName, catalog_name); return *this; } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::SchemaName( - std::string& schema_name) { + const std::string& schema_name) { metadata_map_->Append(ColumnMetadata::kSchemaName, schema_name); return *this; } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TableName( - std::string& table_name) { + const std::string& table_name) { metadata_map_->Append(ColumnMetadata::kTableName, table_name); return *this; } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TypeName( - std::string& type_name) { + const std::string& type_name) { metadata_map_->Append(ColumnMetadata::kTypeName, type_name); return *this; } diff --git a/cpp/src/arrow/flight/sql/column_metadata.h b/cpp/src/arrow/flight/sql/column_metadata.h index 15b139ec580..0eb53f3e0bb 100644 --- a/cpp/src/arrow/flight/sql/column_metadata.h +++ b/cpp/src/arrow/flight/sql/column_metadata.h @@ -122,22 +122,22 @@ class ARROW_FLIGHT_SQL_EXPORT ColumnMetadata { /// \brief Set the catalog name in the KeyValueMetadata object. /// \param[in] catalog_name The catalog name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& CatalogName(std::string& catalog_name); + ColumnMetadataBuilder& CatalogName(const std::string& catalog_name); /// \brief Set the schema_name in the KeyValueMetadata object. /// \param[in] schema_name The schema_name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& SchemaName(std::string& schema_name); + ColumnMetadataBuilder& SchemaName(const std::string& schema_name); /// \brief Set the table name in the KeyValueMetadata object. /// \param[in] table_name The table name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& TableName(std::string& table_name); + ColumnMetadataBuilder& TableName(const std::string& table_name); /// \brief Set the type name in the KeyValueMetadata object. /// \param[in] type_name The type name. /// \return A ColumnMetadataBuilder. 
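  ///
  /// With the const-reference overload below, temporaries and string
  /// literals bind directly, e.g. builder.TypeName("INTEGER") (where
  /// `builder` is an already-constructed ColumnMetadataBuilder).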
- ColumnMetadataBuilder& TypeName(std::string& type_name); + ColumnMetadataBuilder& TypeName(const std::string& type_name); /// \brief Set the precision in the KeyValueMetadata object. /// \param[in] precision The precision. diff --git a/cpp/src/arrow/flight/sql/example/acero_main.cc b/cpp/src/arrow/flight/sql/example/acero_main.cc new file mode 100644 index 00000000000..111bebcbf0f --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/acero_main.cc @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Example Flight SQL server backed by Acero. + +#include + +#include +#include +#include +#include + +#include + +#include "arrow/flight/sql/example/acero_server.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +namespace flight = arrow::flight; +namespace sql = arrow::flight::sql; + +DEFINE_string(location, "grpc://localhost:12345", "Location to listen on"); + +arrow::Status RunMain(const std::string& location_str) { + ARROW_ASSIGN_OR_RAISE(flight::Location location, flight::Location::Parse(location_str)); + flight::FlightServerOptions options(location); + + std::unique_ptr server; + ARROW_ASSIGN_OR_RAISE(server, sql::acero_example::MakeAceroServer()); + ARROW_RETURN_NOT_OK(server->Init(options)); + + ARROW_RETURN_NOT_OK(server->SetShutdownOnSignals({SIGTERM})); + + ARROW_LOG(INFO) << "Listening on " << location.ToString(); + + ARROW_RETURN_NOT_OK(server->Serve()); + return arrow::Status::OK(); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + arrow::util::ArrowLog::StartArrowLog("acero-flight-sql-server", + arrow::util::ArrowLogLevel::ARROW_INFO); + arrow::util::ArrowLog::InstallFailureSignalHandler(); + + arrow::Status st = RunMain(FLAGS_location); + + arrow::util::ArrowLog::ShutDownArrowLog(); + + if (!st.ok()) { + std::cerr << st << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +} diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc new file mode 100644 index 00000000000..ce1483cb8c3 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/acero_server.cc @@ -0,0 +1,306 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/example/acero_server.h" + +#include +#include +#include +#include + +#include "arrow/engine/substrait/serde.h" +#include "arrow/flight/sql/types.h" +#include "arrow/type.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace acero_example { + +namespace { +/// \brief A SinkNodeConsumer that saves the schema as given to it by +/// the ExecPlan. Used to retrieve the schema of a Substrait plan to +/// fulfill the Flight SQL API contract. +class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer { + public: + Status Init(const std::shared_ptr& schema, + compute::BackpressureControl*) override { + schema_ = schema; + return Status::OK(); + } + Status Consume(compute::ExecBatch exec_batch) override { return Status::OK(); } + Future<> Finish() override { return Status::OK(); } + + const std::shared_ptr& schema() const { return schema_; } + + private: + std::shared_ptr schema_; +}; + +/// \brief A SinkNodeConsumer that internally saves batches into a +/// queue, so that it can be read from a RecordBatchReader. In other +/// words, this bridges a push-based interface (ExecPlan) to a +/// pull-based interface (RecordBatchReader). +class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer { + public: + QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {} + + Status Init(const std::shared_ptr& schema, + compute::BackpressureControl*) override { + schema_ = schema; + return Status::OK(); + } + + Status Consume(compute::ExecBatch exec_batch) override { + { + std::lock_guard guard(mutex_); + batches_.push_back(std::move(exec_batch)); + batches_added_.notify_all(); + } + + return Status::OK(); + } + + Future<> Finish() override { + { + std::lock_guard guard(mutex_); + finished_ = true; + batches_added_.notify_all(); + } + + return Status::OK(); + } + + const std::shared_ptr& schema() const { return schema_; } + + arrow::Result> Next() { + compute::ExecBatch batch; + { + std::unique_lock guard(mutex_); + batches_added_.wait(guard, [this] { return !batches_.empty() || finished_; }); + + if (finished_ && batches_.empty()) { + return nullptr; + } + batch = std::move(batches_.front()); + batches_.pop_front(); + } + + return batch.ToRecordBatch(schema_); + } + + private: + std::mutex mutex_; + std::condition_variable batches_added_; + std::deque batches_; + std::shared_ptr schema_; + bool finished_; +}; + +/// \brief A RecordBatchReader that pulls from the +/// QueuingSinkNodeConsumer above, blocking until results are +/// available as necessary. 
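+///
+/// The pattern generalizes beyond Acero: Consume()/Finish() push batches
+/// under a mutex and notify a condition variable, while Next() blocks until
+/// a batch or end-of-stream arrives. Reader-side sketch (illustrative):
+/// \code
+/// std::shared_ptr<RecordBatch> batch;
+/// while (true) {
+///   ARROW_ASSIGN_OR_RAISE(batch, consumer->Next());
+///   if (batch == nullptr) break;  // plan finished and queue fully drained
+///   // ... process batch ...
+/// }
+/// \endcode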
+class ConsumerBasedRecordBatchReader : public RecordBatchReader { + public: + explicit ConsumerBasedRecordBatchReader( + std::shared_ptr plan, + std::shared_ptr consumer) + : plan_(std::move(plan)), consumer_(std::move(consumer)) {} + + std::shared_ptr schema() const override { return consumer_->schema(); } + + Status ReadNext(std::shared_ptr* batch) override { + return consumer_->Next().Value(batch); + } + + // TODO(ARROW-17242): FlightDataStream needs to call Close() + Status Close() override { return plan_->finished().status(); } + + private: + std::shared_ptr plan_; + std::shared_ptr consumer_; +}; + +/// \brief An implementation of a Flight SQL service backed by Acero. +class AceroFlightSqlServer : public FlightSqlServerBase { + public: + AceroFlightSqlServer() { + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, + SqlInfoResult(true)); + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, + SqlInfoResult(std::string("0.6.0"))); + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, + SqlInfoResult(std::string("0.6.0"))); + RegisterSqlInfo( + SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + SqlInfoResult( + SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE)); + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL, + SqlInfoResult(false)); + } + + arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) override { + if (!command.transaction_id.empty()) { + return Status::NotImplemented("Transactions are unsupported"); + } + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + GetPlanSchema(command.plan.plan)); + + ARROW_LOG(INFO) << "GetFlightInfoSubstraitPlan: preparing plan with output schema " + << *output_schema; + + return MakeFlightInfo(command.plan.plan, descriptor, *output_schema); + } + + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override { + std::shared_ptr plan; + { + std::lock_guard guard(mutex_); + auto it = prepared_.find(command.prepared_statement_handle); + if (it == prepared_.end()) { + return Status::KeyError("Prepared statement not found"); + } + plan = it->second; + } + + return MakeFlightInfo(plan->ToString(), descriptor, Schema({})); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override { + // GetFlightInfoSubstraitPlan encodes the plan into the ticket + std::shared_ptr serialized_plan = + Buffer::FromString(command.statement_handle); + std::shared_ptr consumer = + std::make_shared(); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, + engine::DeserializePlan(*serialized_plan, consumer)); + + ARROW_LOG(INFO) << "DoGetStatement: executing plan " << plan->ToString(); + + ARROW_RETURN_NOT_OK(plan->StartProducing()); + + auto reader = std::make_shared(std::move(plan), + std::move(consumer)); + return std::unique_ptr(new RecordBatchStream(reader)); + } + + arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command) override { + return Status::NotImplemented("Updates are unsupported"); + } + + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) override { + return Status::NotImplemented("NYI"); 
+ } + + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader) override { + return Status::NotImplemented("Updates are unsupported"); + } + + arrow::Result CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const ActionCreatePreparedSubstraitPlanRequest& request) override { + if (!request.transaction_id.empty()) { + return Status::NotImplemented("Transactions are unsupported"); + } + // There's not any real point to precompiling the plan, since the + // consumer has to be provided here. So this is effectively the + // same as a non-prepared plan. + ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema, + GetPlanSchema(request.plan.plan)); + + std::string handle; + { + std::lock_guard guard(mutex_); + handle = std::to_string(counter_++); + prepared_[handle] = Buffer::FromString(request.plan.plan); + } + + return ActionCreatePreparedStatementResult{ + /*dataset_schema=*/std::move(schema), + /*parameter_schema=*/nullptr, + handle, + }; + } + + Status ClosePreparedStatement( + const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request) override { + std::lock_guard guard(mutex_); + prepared_.erase(request.prepared_statement_handle); + return Status::OK(); + } + + private: + arrow::Result> GetPlanSchema( + const std::string& serialized_plan) { + std::shared_ptr plan_buf = Buffer::FromString(serialized_plan); + std::shared_ptr consumer = + std::make_shared(); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, + engine::DeserializePlan(*plan_buf, consumer)); + std::shared_ptr output_schema; + for (compute::ExecNode* sink : plan->sinks()) { + // Force SinkNodeConsumer::Init to be called + ARROW_RETURN_NOT_OK(sink->StartProducing()); + output_schema = consumer->schema(); + break; + } + if (!output_schema) { + return Status::Invalid("Could not infer output schema"); + } + return output_schema; + } + + arrow::Result> MakeFlightInfo( + const std::string& plan, const FlightDescriptor& descriptor, const Schema& schema) { + ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan)); + std::vector endpoints{ + FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}}; + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(schema, descriptor, std::move(endpoints), + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } + + std::mutex mutex_; + std::unordered_map> prepared_; + int64_t counter_; +}; + +} // namespace + +arrow::Result> MakeAceroServer() { + return std::unique_ptr(new AceroFlightSqlServer()); +} + +} // namespace acero_example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/acero_server.h b/cpp/src/arrow/flight/sql/example/acero_server.h new file mode 100644 index 00000000000..2e82fd3d3b6 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/acero_server.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/visibility.h" +#include "arrow/result.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace acero_example { + +/// \brief Make a Flight SQL server backed by the Acero query engine. +arrow::Result> MakeAceroServer(); + +} // namespace acero_example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index 35fa05468ba..0d0a7c1ea0e 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -20,11 +20,12 @@ #include #include -#include +#include #include #include +#include +#include -#include "arrow/api.h" #include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/example/sqlite_statement.h" #include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" @@ -39,18 +40,6 @@ namespace example { namespace { -/// \brief Gets a SqliteStatement by given handle -arrow::Result> GetStatementByHandle( - const std::map>& prepared_statements, - const std::string& handle) { - auto search = prepared_statements.find(handle); - if (search == prepared_statements.end()) { - return Status::Invalid("Prepared statement not found"); - } - - return search->second; -} - std::string PrepareQueryForGetTables(const GetTables& command) { std::stringstream table_query; @@ -237,23 +226,83 @@ int32_t GetSqlTypeFromTypeName(const char* sqlite_type) { } class SQLiteFlightSqlServer::Impl { + private: sqlite3* db_; - std::map> prepared_statements_; + const std::string db_uri_; + std::mutex mutex_; + std::unordered_map> prepared_statements_; + std::unordered_map open_transactions_; std::default_random_engine gen_; + arrow::Result> GetStatementByHandle( + const std::string& handle) { + std::lock_guard guard(mutex_); + auto search = prepared_statements_.find(handle); + if (search == prepared_statements_.end()) { + return Status::KeyError("Prepared statement not found"); + } + return search->second; + } + + arrow::Result GetConnection(const std::string& transaction_id) { + if (transaction_id.empty()) return db_; + + std::lock_guard guard(mutex_); + auto it = open_transactions_.find(transaction_id); + if (it == open_transactions_.end()) { + return Status::KeyError("Unknown transaction ID: ", transaction_id); + } + return it->second; + } + + // Create a Ticket that combines a query and a transaction ID. 
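+  // The ticket body is "<transaction_id>:<query>". The split on the first
+  // ':' is unambiguous because GenerateRandomString() below only emits
+  // characters in [0-9A-Z], never ':'. For example,
+  // "AB12CD34EF56GH78:SELECT 1" decodes to ("SELECT 1", "AB12CD34EF56GH78"),
+  // and ":SELECT 1" (no transaction) decodes to an empty transaction ID.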
+ arrow::Result EncodeTransactionQuery(const std::string& query, + const std::string& transaction_id) { + std::string transaction_query = transaction_id; + transaction_query += ':'; + transaction_query += query; + ARROW_ASSIGN_OR_RAISE(auto ticket_string, + CreateStatementQueryTicket(transaction_query)); + return Ticket{std::move(ticket_string)}; + } + + arrow::Result> DecodeTransactionQuery( + const std::string& ticket) { + auto divider = ticket.find(':'); + if (divider == std::string::npos) { + return Status::Invalid("Malformed ticket"); + } + std::string transaction_id = ticket.substr(0, divider); + std::string query = ticket.substr(divider + 1); + return std::make_pair(std::move(query), std::move(transaction_id)); + } + public: - explicit Impl(sqlite3* db) : db_(db) {} + explicit Impl(sqlite3* db, std::string uri) : db_(db), db_uri_(std::move(uri)) {} - ~Impl() { sqlite3_close(db_); } + ~Impl() { + sqlite3_close(db_); + for (const auto& pair : open_transactions_) { + sqlite3_close(pair.second); + } + } std::string GenerateRandomString() { uint32_t length = 16; // MSVC doesn't support char types here std::uniform_int_distribution dist(static_cast('0'), - static_cast('z')); + static_cast('Z')); std::string ret(length, 0); - auto get_random_char = [&]() { return static_cast(dist(gen_)); }; + // Don't generate symbols to simplify parsing in DecodeTransactionQuery + auto get_random_char = [&]() { + char res; + while (true) { + res = static_cast(dist(gen_)); + if (res <= '9' || res >= 'A') break; + } + return res; + }; std::generate_n(ret.begin(), length, get_random_char); return ret; } @@ -262,13 +311,12 @@ class SQLiteFlightSqlServer::Impl { const ServerCallContext& context, const StatementQuery& command, const FlightDescriptor& descriptor) { const std::string& query = command.query; - - ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, query)); - + ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)); + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, query)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); - - ARROW_ASSIGN_OR_RAISE(auto ticket_string, CreateStatementQueryTicket(query)); - std::vector endpoints{FlightEndpoint{{ticket_string}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto ticket, + EncodeTransactionQuery(query, command.transaction_id)); + std::vector endpoints{FlightEndpoint{std::move(ticket), {}}}; ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) @@ -277,10 +325,13 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetStatement( const ServerCallContext& context, const StatementQueryTicket& command) { - const std::string& sql = command.statement_handle; + ARROW_ASSIGN_OR_RAISE(auto pair, DecodeTransactionQuery(command.statement_handle)); + const std::string& sql = pair.first; + const std::string transaction_id = pair.second; + ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(transaction_id)); std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, sql)); + ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db, sql)); std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement)); @@ -375,15 +426,15 @@ class SQLiteFlightSqlServer::Impl { arrow::Result DoPutCommandStatementUpdate(const ServerCallContext& context, const StatementUpdate& command) { const std::string& sql = command.query; - - ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, sql)); - + ARROW_ASSIGN_OR_RAISE(auto db, 
GetConnection(command.transaction_id)); + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, sql)); return statement->ExecuteUpdate(); } arrow::Result CreatePreparedStatement( const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request) { + std::lock_guard guard(mutex_); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query)); std::string handle = GenerateRandomString(); @@ -419,6 +470,7 @@ class SQLiteFlightSqlServer::Impl { Status ClosePreparedStatement(const ServerCallContext& context, const ActionClosePreparedStatementRequest& request) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = request.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -434,6 +486,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> GetFlightInfoPreparedStatement( const ServerCallContext& context, const PreparedStatementQuery& command, const FlightDescriptor& descriptor) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -450,6 +503,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetPreparedStatement( const ServerCallContext& context, const PreparedStatementQuery& command) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -470,9 +524,8 @@ class SQLiteFlightSqlServer::Impl { FlightMessageReader* reader, FlightMetadataWriter* writer) { const std::string& prepared_statement_handle = command.prepared_statement_handle; - ARROW_ASSIGN_OR_RAISE( - auto statement, - GetStatementByHandle(prepared_statements_, prepared_statement_handle)); + ARROW_ASSIGN_OR_RAISE(auto statement, + GetStatementByHandle(prepared_statement_handle)); sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); @@ -484,9 +537,8 @@ class SQLiteFlightSqlServer::Impl { const ServerCallContext& context, const PreparedStatementUpdate& command, FlightMessageReader* reader) { const std::string& prepared_statement_handle = command.prepared_statement_handle; - ARROW_ASSIGN_OR_RAISE( - auto statement, - GetStatementByHandle(prepared_statements_, prepared_statement_handle)); + ARROW_ASSIGN_OR_RAISE(auto statement, + GetStatementByHandle(prepared_statement_handle)); sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); @@ -627,28 +679,85 @@ class SQLiteFlightSqlServer::Impl { return DoGetSQLiteQuery(db_, query, SqlSchema::GetCrossReferenceSchema()); } - Status ExecuteSql(const std::string& sql) { + Status ExecuteSql(const std::string& sql) { return ExecuteSql(db_, sql); } + + Status ExecuteSql(sqlite3* db, const std::string& sql) { char* err_msg = nullptr; - int rc = sqlite3_exec(db_, sql.c_str(), nullptr, nullptr, &err_msg); + int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &err_msg); if (rc != SQLITE_OK) { std::string error_msg; if (err_msg != nullptr) { error_msg = err_msg; + sqlite3_free(err_msg); } - sqlite3_free(err_msg); - return Status::ExecutionError(error_msg); + return Status::IOError(error_msg); } + if (err_msg) sqlite3_free(err_msg); return Status::OK(); } + + arrow::Result BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& 
request) { + std::string handle = GenerateRandomString(); + sqlite3* new_db = nullptr; + if (sqlite3_open_v2(db_uri_.c_str(), &new_db, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, + /*zVfs=*/nullptr) != SQLITE_OK) { + std::string error_message = "Can't open new connection: "; + if (new_db) { + error_message += sqlite3_errmsg(new_db); + sqlite3_close(new_db); + } + return Status::Invalid(error_message); + } + + ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION")); + + std::lock_guard guard(mutex_); + open_transactions_[handle] = new_db; + return ActionBeginTransactionResult{std::move(handle)}; + } + + Status EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) { + Status status; + sqlite3* transaction = nullptr; + { + std::lock_guard guard(mutex_); + auto it = open_transactions_.find(request.transaction_id); + if (it == open_transactions_.end()) { + return Status::KeyError("Unknown transaction ID: ", request.transaction_id); + } + + if (request.action == ActionEndTransactionRequest::kCommit) { + status = ExecuteSql(it->second, "COMMIT"); + } else { + status = ExecuteSql(it->second, "ROLLBACK"); + } + transaction = it->second; + open_transactions_.erase(it); + } + sqlite3_close(transaction); + return status; + } }; +// Give each server instance its own in-memory DB +std::atomic kDbCounter(0); + SQLiteFlightSqlServer::SQLiteFlightSqlServer(std::shared_ptr impl) : impl_(std::move(impl)) {} arrow::Result> SQLiteFlightSqlServer::Create() { sqlite3* db = nullptr; - if (sqlite3_open(":memory:", &db)) { + // All sqlite3* instances created from this URI will share data + std::string uri = "file:memorydb"; + uri += std::to_string(kDbCounter++); + uri += "?mode=memory&cache=shared"; + if (sqlite3_open_v2(uri.c_str(), &db, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, + /*zVfs=*/nullptr)) { std::string err_msg = "Can't open database: "; if (db != nullptr) { err_msg += sqlite3_errmsg(db); @@ -660,9 +769,10 @@ arrow::Result> SQLiteFlightSqlServer::Cre return Status::Invalid(err_msg); } - std::shared_ptr impl = std::make_shared(db); + std::shared_ptr impl = std::make_shared(db, std::move(uri)); - std::shared_ptr result(new SQLiteFlightSqlServer(impl)); + std::shared_ptr result( + new SQLiteFlightSqlServer(std::move(impl))); for (const auto& id_to_result : GetSqlInfoResultMap()) { result->RegisterSqlInfo(id_to_result.first, id_to_result.second); } @@ -855,6 +965,15 @@ SQLiteFlightSqlServer::DoGetCrossReference(const ServerCallContext& context, return impl_->DoGetCrossReference(context, command); } +arrow::Result SQLiteFlightSqlServer::BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& request) { + return impl_->BeginTransaction(context, request); +} +Status SQLiteFlightSqlServer::EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) { + return impl_->EndTransaction(context, request); +} + } // namespace example } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.h b/cpp/src/arrow/flight/sql/example/sqlite_server.h index 744ed068d0b..389a2d921bb 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.h @@ -141,6 +141,12 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase { arrow::Result> DoGetPrimaryKeys( const ServerCallContext& context, const GetPrimaryKeys& command) override; + arrow::Result BeginTransaction( + const 
ServerCallContext& context, + const ActionBeginTransactionRequest& request) override; + Status EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) override; + private: class Impl; std::shared_ptr impl_; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc index 94f25b39017..9737b5a3090 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc @@ -18,6 +18,7 @@ #include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/types.h" +#include "arrow/util/config.h" namespace arrow { namespace flight { @@ -33,8 +34,14 @@ SqlInfoResultMap GetSqlInfoResultMap() { {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_VERSION, SqlInfoResult(std::string("sqlite 3"))}, {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION, - SqlInfoResult(std::string("7.0.0-SNAPSHOT" /* Only an example */))}, + SqlInfoResult(std::string(ARROW_VERSION_STRING))}, {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, SqlInfoResult(false)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SQL, SqlInfoResult(true)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, SqlInfoResult(false)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + SqlInfoResult(SqlInfoOptions::SqlSupportedTransaction:: + SQL_SUPPORTED_TRANSACTION_TRANSACTION)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL, SqlInfoResult(false)}, {SqlInfoOptions::SqlInfo::SQL_DDL_CATALOG, SqlInfoResult(false /* SQLite 3 does not support catalogs */)}, {SqlInfoOptions::SqlInfo::SQL_DDL_SCHEMA, diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index a8f3ed8a80c..1905b117d61 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -149,11 +149,29 @@ arrow::Result ParseCommandStatementQuery( const google::protobuf::Any& any) { pb::sql::CommandStatementQuery command; if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack CommandStatementQuery."); + return Status::Invalid("Unable to unpack CommandStatementQuery"); } StatementQuery result; result.query = command.query(); + result.transaction_id = command.transaction_id(); + return result; +} + +SubstraitPlan ParseStatementSubstraitPlan(const pb::sql::SubstraitPlan& pb_plan) { + return {pb_plan.plan(), pb_plan.version()}; +} + +arrow::Result ParseCommandStatementSubstraitPlan( + const google::protobuf::Any& any) { + pb::sql::CommandStatementSubstraitPlan command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandStatementSubstraitPlan"); + } + + StatementSubstraitPlan result; + result.plan = ParseStatementSubstraitPlan(command.plan()); + result.transaction_id = command.transaction_id(); return result; } @@ -190,18 +208,6 @@ arrow::Result ParseCommandGetTables(const google::protobuf::Any& any) return result; } -arrow::Result ParseStatementQueryTicket( - const google::protobuf::Any& any) { - pb::sql::TicketStatementQuery command; - if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack TicketStatementQuery."); - } - - StatementQueryTicket result; - result.statement_handle = command.statement_handle(); - return result; -} - arrow::Result ParseCommandStatementUpdate( const google::protobuf::Any& any) { pb::sql::CommandStatementUpdate command; @@ -211,6 +217,7 @@ arrow::Result ParseCommandStatementUpdate( StatementUpdate result; result.query = command.query(); + result.transaction_id 
= command.transaction_id(); return result; } @@ -226,15 +233,65 @@ arrow::Result ParseCommandPreparedStatementUpdate( return result; } +arrow::Result ParseActionBeginSavepointRequest( + const google::protobuf::Any& any) { + pb::sql::ActionBeginSavepointRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionBeginSavepointRequest"); + } + + ActionBeginSavepointRequest result; + result.transaction_id = command.transaction_id(); + result.name = command.name(); + return result; +} + +arrow::Result ParseActionBeginTransactionRequest( + const google::protobuf::Any& any) { + pb::sql::ActionBeginTransactionRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionBeginTransactionRequest"); + } + + ActionBeginTransactionRequest result; + return result; +} + +arrow::Result ParseActionCancelQueryRequest( + const google::protobuf::Any& any) { + pb::sql::ActionCancelQueryRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionCancelQueryRequest"); + } + + ActionCancelQueryRequest result; + ARROW_ASSIGN_OR_RAISE(result.info, FlightInfo::Deserialize(command.info())); + return result; +} + arrow::Result ParseActionCreatePreparedStatementRequest(const google::protobuf::Any& any) { pb::sql::ActionCreatePreparedStatementRequest command; if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack ActionCreatePreparedStatementRequest."); + return Status::Invalid("Unable to unpack ActionCreatePreparedStatementRequest"); } ActionCreatePreparedStatementRequest result; result.query = command.query(); + result.transaction_id = command.transaction_id(); + return result; +} + +arrow::Result +ParseActionCreatePreparedSubstraitPlanRequest(const google::protobuf::Any& any) { + pb::sql::ActionCreatePreparedSubstraitPlanRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionCreatePreparedSubstraitPlanRequest"); + } + + ActionCreatePreparedSubstraitPlanRequest result; + result.plan = ParseStatementSubstraitPlan(command.plan()); + result.transaction_id = command.transaction_id(); return result; } @@ -242,7 +299,7 @@ arrow::Result ParseActionClosePreparedStatementRequest(const google::protobuf::Any& any) { pb::sql::ActionClosePreparedStatementRequest command; if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack ActionClosePreparedStatementRequest."); + return Status::Invalid("Unable to unpack ActionClosePreparedStatementRequest"); } ActionClosePreparedStatementRequest result; @@ -250,8 +307,139 @@ ParseActionClosePreparedStatementRequest(const google::protobuf::Any& any) { return result; } +arrow::Result ParseActionEndSavepointRequest( + const google::protobuf::Any& any) { + pb::sql::ActionEndSavepointRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionEndSavepointRequest"); + } + + ActionEndSavepointRequest result; + result.savepoint_id = command.savepoint_id(); + switch (command.action()) { + case pb::sql::ActionEndSavepointRequest::END_SAVEPOINT_UNSPECIFIED: + return Status::Invalid( + "ActionEndSavepointRequest.action was END_SAVEPOINT_UNSPECIFIED"); + case pb::sql::ActionEndSavepointRequest::END_SAVEPOINT_RELEASE: + result.action = ActionEndSavepointRequest::kRelease; + break; + case pb::sql::ActionEndSavepointRequest::END_SAVEPOINT_ROLLBACK: + result.action = ActionEndSavepointRequest::kRollback; + break; + default: + return Status::Invalid("Unknown value for 
ActionEndSavepointRequest.action: ", + command.action()); + } + return result; +} + +arrow::Result ParseActionEndTransactionRequest( + const google::protobuf::Any& any) { + pb::sql::ActionEndTransactionRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionEndTransactionRequest"); + } + + ActionEndTransactionRequest result; + result.transaction_id = command.transaction_id(); + switch (command.action()) { + case pb::sql::ActionEndTransactionRequest::END_TRANSACTION_UNSPECIFIED: + return Status::Invalid( + "ActionEndTransactionRequest.action was END_TRANSACTION_UNSPECIFIED"); + case pb::sql::ActionEndTransactionRequest::END_TRANSACTION_COMMIT: + result.action = ActionEndTransactionRequest::kCommit; + break; + case pb::sql::ActionEndTransactionRequest::END_TRANSACTION_ROLLBACK: + result.action = ActionEndTransactionRequest::kRollback; + break; + default: + return Status::Invalid("Unknown value for ActionEndTransactionRequest.action: ", + command.action()); + } + return result; +} + +arrow::Result PackActionResult(const google::protobuf::Message& message) { + google::protobuf::Any any; + if (!any.PackFrom(message)) { + return Status::IOError("Failed to pack ", message.GetTypeName()); + } + + std::string buffer; + if (!any.SerializeToString(&buffer)) { + return Status::IOError("Failed to serialize packed ", message.GetTypeName()); + } + return Result{Buffer::FromString(std::move(buffer))}; +} + +arrow::Result PackActionResult(ActionBeginSavepointResult result) { + pb::sql::ActionBeginSavepointResult pb_result; + pb_result.set_savepoint_id(std::move(result.savepoint_id)); + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(ActionBeginTransactionResult result) { + pb::sql::ActionBeginTransactionResult pb_result; + pb_result.set_transaction_id(std::move(result.transaction_id)); + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(CancelResult result) { + pb::sql::ActionCancelQueryResult pb_result; + switch (result) { + case CancelResult::kUnspecified: + pb_result.set_result(pb::sql::ActionCancelQueryResult::CANCEL_RESULT_UNSPECIFIED); + break; + case CancelResult::kCancelled: + pb_result.set_result(pb::sql::ActionCancelQueryResult::CANCEL_RESULT_CANCELLED); + break; + case CancelResult::kCancelling: + pb_result.set_result(pb::sql::ActionCancelQueryResult::CANCEL_RESULT_CANCELLING); + break; + case CancelResult::kNotCancellable: + pb_result.set_result( + pb::sql::ActionCancelQueryResult::CANCEL_RESULT_NOT_CANCELLABLE); + break; + } + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(ActionCreatePreparedStatementResult result) { + pb::sql::ActionCreatePreparedStatementResult pb_result; + pb_result.set_prepared_statement_handle(std::move(result.prepared_statement_handle)); + if (result.dataset_schema != nullptr) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr serialized, + ipc::SerializeSchema(*result.dataset_schema)); + pb_result.set_dataset_schema(reinterpret_cast(serialized->data()), + serialized->size()); + } + if (result.parameter_schema != nullptr) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr serialized, + ipc::SerializeSchema(*result.parameter_schema)); + pb_result.set_parameter_schema(reinterpret_cast(serialized->data()), + serialized->size()); + } + + return PackActionResult(pb_result); +} + } // namespace +arrow::Result StatementQueryTicket::Deserialize( + std::string_view serialized) { + pb::sql::TicketStatementQuery command; + google::protobuf::Any any; + if 
(!any.ParseFromArray(serialized.data(), static_cast(serialized.size()))) { + return Status::Invalid("Unable to parse ticket"); + } + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack TicketStatementQuery"); + } + StatementQueryTicket result; + result.statement_handle = command.statement_handle(); + return result; +} + arrow::Result CreateStatementQueryTicket( const std::string& statement_handle) { protocol::sql::TicketStatementQuery ticket_statement_query; @@ -282,6 +470,12 @@ Status FlightSqlServerBase::GetFlightInfo(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(*info, GetFlightInfoStatement(context, internal_command, request)); return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementSubstraitPlan internal_command, + ParseCommandStatementSubstraitPlan(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoSubstraitPlan(context, internal_command, request)); + return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, ParseCommandPreparedStatementQuery(any)); @@ -358,6 +552,12 @@ Status FlightSqlServerBase::GetSchema(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(*schema, GetSchemaStatement(context, internal_command, request)); return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementSubstraitPlan internal_command, + ParseCommandStatementSubstraitPlan(any)); + ARROW_ASSIGN_OR_RAISE(*schema, + GetSchemaSubstraitPlan(context, internal_command, request)); + return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, ParseCommandPreparedStatementQuery(any)); @@ -413,15 +613,19 @@ Status FlightSqlServerBase::GetSchema(const ServerCallContext& context, Status FlightSqlServerBase::DoGet(const ServerCallContext& context, const Ticket& request, std::unique_ptr* stream) { google::protobuf::Any any; - if (!any.ParseFromArray(request.ticket.data(), static_cast(request.ticket.size()))) { - return Status::Invalid("Unable to parse ticket."); + return Status::Invalid("Unable to parse ticket"); } if (any.Is()) { - ARROW_ASSIGN_OR_RAISE(StatementQueryTicket command, ParseStatementQueryTicket(any)); - ARROW_ASSIGN_OR_RAISE(*stream, DoGetStatement(context, command)); + pb::sql::TicketStatementQuery command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack TicketStatementQuery"); + } + StatementQueryTicket result; + result.statement_handle = command.statement_handle(); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetStatement(context, result)); return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, @@ -483,7 +687,7 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, google::protobuf::Any any; if (!any.ParseFromArray(request.cmd.data(), static_cast(request.cmd.size()))) { - return Status::Invalid("Unable to parse command."); + return Status::Invalid("Unable to parse command"); } if (any.Is()) { @@ -498,6 +702,18 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, const auto buffer = Buffer::FromString(result.SerializeAsString()); ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementSubstraitPlan internal_command, + ParseCommandStatementSubstraitPlan(any)); + ARROW_ASSIGN_OR_RAISE(auto record_count, + DoPutCommandSubstraitPlan(context, internal_command)); + + pb::sql::DoPutUpdateResult result; + result.set_record_count(record_count); + + const auto 
buffer = Buffer::FromString(result.SerializeAsString()); + ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, @@ -507,78 +723,104 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementUpdate internal_command, ParseCommandPreparedStatementUpdate(any)); - ARROW_ASSIGN_OR_RAISE(auto record_count, DoPutPreparedStatementUpdate( - context, internal_command, reader.get())) + ARROW_ASSIGN_OR_RAISE( + auto record_count, + DoPutPreparedStatementUpdate(context, internal_command, reader.get())); pb::sql::DoPutUpdateResult result; result.set_record_count(record_count); const auto buffer = Buffer::FromString(result.SerializeAsString()); ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); - return Status::OK(); } - return Status::Invalid("The defined request is invalid."); + return Status::NotImplemented("Command not recognized: ", any.type_url()); } Status FlightSqlServerBase::ListActions(const ServerCallContext& context, std::vector* actions) { - *actions = {FlightSqlServerBase::kCreatePreparedStatementActionType, - FlightSqlServerBase::kClosePreparedStatementActionType}; + *actions = { + FlightSqlServerBase::kBeginSavepointActionType, + FlightSqlServerBase::kBeginTransactionActionType, + FlightSqlServerBase::kCancelQueryActionType, + FlightSqlServerBase::kCreatePreparedStatementActionType, + FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType, + FlightSqlServerBase::kClosePreparedStatementActionType, + FlightSqlServerBase::kEndSavepointActionType, + FlightSqlServerBase::kEndTransactionActionType, + }; return Status::OK(); } Status FlightSqlServerBase::DoAction(const ServerCallContext& context, const Action& action, std::unique_ptr* result_stream) { - if (action.type == FlightSqlServerBase::kCreatePreparedStatementActionType.type) { - google::protobuf::Any any_command; - if (!any_command.ParseFromArray(action.body->data(), - static_cast(action.body->size()))) { - return Status::Invalid("Unable to parse action."); - } + google::protobuf::Any any; + if (!any.ParseFromArray(action.body->data(), static_cast(action.body->size()))) { + return Status::Invalid("Unable to parse action"); + } + std::vector results; + if (action.type == FlightSqlServerBase::kBeginSavepointActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionBeginSavepointRequest internal_command, + ParseActionBeginSavepointRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionBeginSavepointResult result, + BeginSavepoint(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == FlightSqlServerBase::kBeginTransactionActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionBeginTransactionRequest internal_command, + ParseActionBeginTransactionRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionBeginTransactionResult result, + BeginTransaction(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == FlightSqlServerBase::kCancelQueryActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionCancelQueryRequest internal_command, + ParseActionCancelQueryRequest(any)); + ARROW_ASSIGN_OR_RAISE(CancelResult result, CancelQuery(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(result)); + + 
results.push_back(std::move(packed_result)); + } else if (action.type == + FlightSqlServerBase::kCreatePreparedStatementActionType.type) { ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementRequest internal_command, - ParseActionCreatePreparedStatementRequest(any_command)); - ARROW_ASSIGN_OR_RAISE(auto result, CreatePreparedStatement(context, internal_command)) - - pb::sql::ActionCreatePreparedStatementResult action_result; - action_result.set_prepared_statement_handle(result.prepared_statement_handle); - if (result.dataset_schema != nullptr) { - ARROW_ASSIGN_OR_RAISE(auto serialized_dataset_schema, - ipc::SerializeSchema(*result.dataset_schema)) - action_result.set_dataset_schema(serialized_dataset_schema->ToString()); - } - if (result.parameter_schema != nullptr) { - ARROW_ASSIGN_OR_RAISE(auto serialized_parameter_schema, - ipc::SerializeSchema(*result.parameter_schema)) - action_result.set_parameter_schema(serialized_parameter_schema->ToString()); - } - - google::protobuf::Any any; - any.PackFrom(action_result); - - auto buf = Buffer::FromString(any.SerializeAsString()); - *result_stream = std::unique_ptr(new SimpleResultStream({Result{buf}})); - - return Status::OK(); + ParseActionCreatePreparedStatementRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementResult result, + CreatePreparedStatement(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == + FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedSubstraitPlanRequest internal_command, + ParseActionCreatePreparedSubstraitPlanRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementResult result, + CreatePreparedSubstraitPlan(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); } else if (action.type == FlightSqlServerBase::kClosePreparedStatementActionType.type) { - google::protobuf::Any any; - if (!any.ParseFromArray(action.body->data(), static_cast(action.body->size()))) { - return Status::Invalid("Unable to parse action."); - } - ARROW_ASSIGN_OR_RAISE(ActionClosePreparedStatementRequest internal_command, ParseActionClosePreparedStatementRequest(any)); - ARROW_RETURN_NOT_OK(ClosePreparedStatement(context, internal_command)); - - // Need to instantiate a ResultStream, otherwise clients can not wait for completion. 
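+  // Every branch above follows the same shape: parse the Any body into a
+  // typed request, dispatch on action.type, and pack any result back via
+  // PackActionResult. A branch for a hypothetical new action would read:
+  //
+  //   } else if (action.type == kMyActionType.type) {
+  //     ARROW_ASSIGN_OR_RAISE(MyRequest req, ParseMyRequest(any));
+  //     ARROW_ASSIGN_OR_RAISE(MyResult res, MyAction(context, req));
+  //     ARROW_ASSIGN_OR_RAISE(Result packed, PackActionResult(std::move(res)));
+  //     results.push_back(std::move(packed));
+  //   }
+  //
+  // (kMyActionType, MyRequest, ParseMyRequest, MyAction, and MyResult are
+  // illustrative names, not part of the API.)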
- *result_stream = std::unique_ptr(new SimpleResultStream({})); - return Status::OK(); + } else if (action.type == FlightSqlServerBase::kEndSavepointActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionEndSavepointRequest internal_command, + ParseActionEndSavepointRequest(any)); + ARROW_RETURN_NOT_OK(EndSavepoint(context, internal_command)); + } else if (action.type == FlightSqlServerBase::kEndTransactionActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionEndTransactionRequest internal_command, + ParseActionEndTransactionRequest(any)); + ARROW_RETURN_NOT_OK(EndTransaction(context, internal_command)); + } else { + return Status::NotImplemented("Action not implemented: ", action.type); } - return Status::Invalid("The defined request is invalid."); + *result_stream = + std::unique_ptr(new SimpleResultStream(std::move(results))); + return Status::OK(); } arrow::Result> FlightSqlServerBase::GetFlightInfoCatalogs( @@ -603,6 +845,19 @@ arrow::Result> FlightSqlServerBase::GetSchemaState return Status::NotImplemented("GetSchemaStatement not implemented"); } +arrow::Result> +FlightSqlServerBase::GetFlightInfoSubstraitPlan(const ServerCallContext& context, + const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoSubstraitPlan not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetSchemaSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetSchemaSubstraitPlan not implemented"); +} + arrow::Result> FlightSqlServerBase::DoGetStatement( const ServerCallContext& context, const StatementQueryTicket& command) { return Status::NotImplemented("DoGetStatement not implemented"); @@ -773,6 +1028,21 @@ arrow::Result> FlightSqlServerBase::DoGetCross return Status::NotImplemented("DoGetCrossReference not implemented"); } +arrow::Result FlightSqlServerBase::BeginSavepoint( + const ServerCallContext& context, const ActionBeginSavepointRequest& request) { + return Status::NotImplemented("BeginSavepoint not implemented"); +} + +arrow::Result FlightSqlServerBase::BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& request) { + return Status::NotImplemented("BeginTransaction not implemented"); +} + +arrow::Result FlightSqlServerBase::CancelQuery( + const ServerCallContext& context, const ActionCancelQueryRequest& request) { + return Status::NotImplemented("CancelQuery not implemented"); +} + arrow::Result FlightSqlServerBase::CreatePreparedStatement( const ServerCallContext& context, @@ -780,12 +1050,29 @@ FlightSqlServerBase::CreatePreparedStatement( return Status::NotImplemented("CreatePreparedStatement not implemented"); } +arrow::Result +FlightSqlServerBase::CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const ActionCreatePreparedSubstraitPlanRequest& request) { + return Status::NotImplemented("CreatePreparedSubstraitPlan not implemented"); +} + Status FlightSqlServerBase::ClosePreparedStatement( const ServerCallContext& context, const ActionClosePreparedStatementRequest& request) { return Status::NotImplemented("ClosePreparedStatement not implemented"); } +Status FlightSqlServerBase::EndSavepoint(const ServerCallContext& context, + const ActionEndSavepointRequest& request) { + return Status::NotImplemented("EndSavepoint not implemented"); +} + +Status FlightSqlServerBase::EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) { + return 
Status::NotImplemented("EndTransaction not implemented"); +} + Status FlightSqlServerBase::DoPutPreparedStatementQuery( const ServerCallContext& context, const PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) { @@ -803,6 +1090,11 @@ arrow::Result FlightSqlServerBase::DoPutCommandStatementUpdate( return Status::NotImplemented("DoPutCommandStatementUpdate not implemented"); } +arrow::Result FlightSqlServerBase::DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command) { + return Status::NotImplemented("DoPutCommandSubstraitPlan not implemented"); +} + std::shared_ptr SqlSchema::GetCatalogsSchema() { return arrow::schema({field("catalog_name", utf8(), false)}); } diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 91dad98843f..0fc8b714865 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/flight/server.h" @@ -44,18 +45,32 @@ namespace sql { struct ARROW_FLIGHT_SQL_EXPORT StatementQuery { /// \brief The SQL query. std::string query; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; +}; + +/// \brief A Substrait plan to execute. +struct ARROW_FLIGHT_SQL_EXPORT StatementSubstraitPlan { + /// \brief The Substrait plan. + SubstraitPlan plan; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; }; /// \brief A SQL update query. struct ARROW_FLIGHT_SQL_EXPORT StatementUpdate { /// \brief The SQL query. std::string query; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; }; /// \brief A request to execute a query. struct ARROW_FLIGHT_SQL_EXPORT StatementQueryTicket { /// \brief The server-generated opaque identifier for the query. std::string statement_handle; + + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief A prepared query statement. @@ -132,10 +147,66 @@ struct ARROW_FLIGHT_SQL_EXPORT GetCrossReference { TableRef fk_table_ref; }; +/// \brief A request to start a new transaction. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginTransactionRequest {}; + +/// \brief A request to create a new savepoint. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginSavepointRequest { + std::string transaction_id; + std::string name; +}; + +/// \brief The result of starting a new savepoint. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginSavepointResult { + std::string savepoint_id; +}; + +/// \brief The result of starting a new transaction. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginTransactionResult { + std::string transaction_id; +}; + +/// \brief A request to end a savepoint. +struct ARROW_FLIGHT_SQL_EXPORT ActionEndSavepointRequest { + enum EndSavepoint { + kRelease, + kRollback, + }; + + std::string savepoint_id; + EndSavepoint action; +}; + +/// \brief A request to end a transaction. +struct ARROW_FLIGHT_SQL_EXPORT ActionEndTransactionRequest { + enum EndTransaction { + kCommit, + kRollback, + }; + + std::string transaction_id; + EndTransaction action; +}; + +/// \brief An explicit request to cancel a running query. +struct ARROW_FLIGHT_SQL_EXPORT ActionCancelQueryRequest { + std::unique_ptr info; +}; + /// \brief A request to create a new prepared statement. struct ARROW_FLIGHT_SQL_EXPORT ActionCreatePreparedStatementRequest { /// \brief The SQL query. 
std::string query; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; +}; + +/// \brief A request to create a new prepared statement with a Substrait plan. +struct ARROW_FLIGHT_SQL_EXPORT ActionCreatePreparedSubstraitPlanRequest { + /// \brief The serialized Substrait plan. + SubstraitPlan plan; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; }; /// \brief A request to close a prepared statement. @@ -189,6 +260,15 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const StatementQuery& command, const FlightDescriptor& descriptor); + /// \brief Get a FlightInfo for executing a Substrait plan. + /// \param[in] context Per-call context. + /// \param[in] command The StatementSubstraitPlan object containing the plan. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. + virtual arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor); + /// \brief Get a FlightDataStream containing the query results. /// \param[in] context Per-call context. /// \param[in] command The StatementQueryTicket containing the statement handle. @@ -231,6 +311,15 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const StatementQuery& command, const FlightDescriptor& descriptor); + /// \brief Get the schema of the result set of a Substrait plan. + /// \param[in] context Per-call context. + /// \param[in] command The StatementQuery containing the plan. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The schema of the result set. + virtual arrow::Result> GetSchemaSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor); + /// \brief Get the schema of the result set of a prepared statement. /// \param[in] context Per-call context. /// \param[in] command The PreparedStatementQuery containing the @@ -423,7 +512,14 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { virtual arrow::Result DoPutCommandStatementUpdate( const ServerCallContext& context, const StatementUpdate& command); - /// \brief Create a prepared statement from given SQL statement. + /// \brief Execute an update Substrait plan. + /// \param[in] context The call context. + /// \param[in] command The StatementSubstraitPlan object containing the plan. + /// \return The changed record count. + virtual arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command); + + /// \brief Create a prepared statement from a given SQL statement. /// \param[in] context The call context. /// \param[in] request The ActionCreatePreparedStatementRequest object containing the /// SQL statement. @@ -433,6 +529,16 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request); + /// \brief Create a prepared statement from a Substrait plan. + /// \param[in] context The call context. + /// \param[in] request The ActionCreatePreparedSubstraitPlanRequest object containing + /// the Substrait plan. 
+  /// \return An ActionCreatePreparedStatementResult containing the dataset
+  /// and parameter schemas and a handle for the created statement.
+  virtual arrow::Result<ActionCreatePreparedStatementResult> CreatePreparedSubstraitPlan(
+      const ServerCallContext& context,
+      const ActionCreatePreparedSubstraitPlanRequest& request);
+
   /// \brief Close a prepared statement.
   /// \param[in] context The call context.
   /// \param[in] request The ActionClosePreparedStatementRequest object containing the
@@ -462,6 +568,39 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase {
       const ServerCallContext& context, const PreparedStatementUpdate& command,
       FlightMessageReader* reader);
 
+  /// \brief Begin a new transaction.
+  /// \param[in] context The call context.
+  /// \param[in] request Request parameters.
+  /// \return The transaction ID.
+  virtual arrow::Result<ActionBeginTransactionResult> BeginTransaction(
+      const ServerCallContext& context, const ActionBeginTransactionRequest& request);
+
+  /// \brief Create a new savepoint.
+  /// \param[in] context The call context.
+  /// \param[in] request Request parameters.
+  /// \return The savepoint ID.
+  virtual arrow::Result<ActionBeginSavepointResult> BeginSavepoint(
+      const ServerCallContext& context, const ActionBeginSavepointRequest& request);
+
+  /// \brief Release/rollback a savepoint.
+  /// \param[in] context The call context.
+  /// \param[in] request The savepoint.
+  virtual Status EndSavepoint(const ServerCallContext& context,
+                              const ActionEndSavepointRequest& request);
+
+  /// \brief Commit/rollback a transaction.
+  /// \param[in] context The call context.
+  /// \param[in] request The transaction.
+  virtual Status EndTransaction(const ServerCallContext& context,
+                                const ActionEndTransactionRequest& request);
+
+  /// \brief Attempt to explicitly cancel a query.
+  /// \param[in] context The call context.
+  /// \param[in] request The query to cancel.
+  /// \return The cancellation result.
+  virtual arrow::Result<CancelResult> CancelQuery(
+      const ServerCallContext& context, const ActionCancelQueryRequest& request);
+
   /// @}

   /// \name Utility methods
@@ -492,16 +631,46 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase {
                 std::unique_ptr<FlightMessageReader> reader, std::unique_ptr<FlightMetadataWriter> writer) final;

+  const ActionType kBeginSavepointActionType =
+      ActionType{"BeginSavepoint",
+                 "Create a new savepoint.\n"
+                 "Request Message: ActionBeginSavepointRequest\n"
+                 "Response Message: ActionBeginSavepointResult"};
+  const ActionType kBeginTransactionActionType =
+      ActionType{"BeginTransaction",
+                 "Start a new transaction.\n"
+                 "Request Message: ActionBeginTransactionRequest\n"
+                 "Response Message: ActionBeginTransactionResult"};
   const ActionType kCreatePreparedStatementActionType =
       ActionType{"CreatePreparedStatement",
                  "Creates a reusable prepared statement resource on the server.\n"
                  "Request Message: ActionCreatePreparedStatementRequest\n"
                  "Response Message: ActionCreatePreparedStatementResult"};
+  const ActionType kCreatePreparedSubstraitPlanActionType =
+      ActionType{"CreatePreparedSubstraitPlan",
+                 "Creates a reusable prepared Substrait plan resource on the server.\n"
+                 "Request Message: ActionCreatePreparedSubstraitPlanRequest\n"
+                 "Response Message: ActionCreatePreparedStatementResult"};
+  const ActionType kCancelQueryActionType =
+      ActionType{"CancelQuery",
+                 "Explicitly cancel a running query.\n"
+                 "Request Message: ActionCancelQueryRequest\n"
+                 "Response Message: ActionCancelQueryResult"};
   const ActionType kClosePreparedStatementActionType =
       ActionType{"ClosePreparedStatement",
                  "Closes a reusable prepared statement resource on the server.\n"
                  "Request Message: ActionClosePreparedStatementRequest\n"
                  "Response Message: N/A"};
+  const ActionType kEndSavepointActionType =
+      ActionType{"EndSavepoint",
+                 "End a savepoint.\n"
+                 "Request Message: ActionEndSavepointRequest\n"
+                 "Response Message: N/A"};
+  const ActionType kEndTransactionActionType =
+      ActionType{"EndTransaction",
+                 "End a transaction.\n"
+                 "Request Message: ActionEndTransactionRequest\n"
+                 "Response Message: N/A"};

   Status ListActions(const ServerCallContext& context,
                      std::vector<ActionType>* actions) final;
diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc
index 7ba3ca4a243..785f45551fc 100644
--- a/cpp/src/arrow/flight/sql/server_test.cc
+++ b/cpp/src/arrow/flight/sql/server_test.cc
@@ -24,9 +24,6 @@
 #include
 #include
-#include
-#include
-
 #include "arrow/flight/api.h"
 #include "arrow/flight/sql/api.h"
 #include "arrow/flight/sql/column_metadata.h"
@@ -153,17 +150,12 @@ class TestFlightSqlServer : public ::testing::Test {
  protected:
   void SetUp() override {
-    port = GetListenPort();
-    server_thread.reset(new std::thread([&]() { RunServer(); }));
-
-    std::unique_lock<std::mutex> lk(server_ready_m);
-    server_ready_cv.wait(lk);
-
-    std::stringstream ss;
-    ss << "grpc://localhost:" << port;
-    std::string uri = ss.str();
+    ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("0.0.0.0", 0));
+    arrow::flight::FlightServerOptions options(location);
+    ASSERT_OK_AND_ASSIGN(server, example::SQLiteFlightSqlServer::Create());
+    ASSERT_OK(server->Init(options));

-    ASSERT_OK_AND_ASSIGN(auto location, Location::Parse(uri));
+    ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", server->port()));
     ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location));
     sql_client.reset(new FlightSqlClient(std::move(client)));
@@ -174,30 +166,10 @@
     sql_client.reset();
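
Each of the new virtuals defaults to NotImplemented in server.cc, so servers opt in per feature. As a rough sketch of the transaction hooks, assuming an in-memory registry (the class name, ID scheme, and locking here are illustrative, not part of the patch):

    #include <cstdint>
    #include <mutex>
    #include <string>
    #include <unordered_set>

    #include "arrow/flight/sql/server.h"
    #include "arrow/result.h"
    #include "arrow/status.h"

    namespace flightsql = arrow::flight::sql;

    // Illustrative subclass: track transaction handles in memory. A real server
    // would tie handles to database sessions and enforce the advertised timeout.
    class InMemoryTxnServer : public flightsql::FlightSqlServerBase {
     public:
      arrow::Result<flightsql::ActionBeginTransactionResult> BeginTransaction(
          const arrow::flight::ServerCallContext& context,
          const flightsql::ActionBeginTransactionRequest& request) override {
        std::lock_guard<std::mutex> guard(mutex_);
        std::string id = "txn-" + std::to_string(next_id_++);
        open_.insert(id);
        return flightsql::ActionBeginTransactionResult{id};
      }

      arrow::Status EndTransaction(
          const arrow::flight::ServerCallContext& context,
          const flightsql::ActionEndTransactionRequest& request) override {
        std::lock_guard<std::mutex> guard(mutex_);
        // Commit (kCommit) or roll back (kRollback), then invalidate the handle.
        if (open_.erase(request.transaction_id) == 0) {
          return arrow::Status::KeyError("Unknown transaction: ",
                                         request.transaction_id);
        }
        return arrow::Status::OK();
      }

     private:
      std::mutex mutex_;
      int64_t next_id_ = 0;
      std::unordered_set<std::string> open_;
    };
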
ASSERT_OK(server->Shutdown()); - server_thread->join(); - server_thread.reset(); } private: - int port; std::shared_ptr server; - std::unique_ptr server_thread; - std::condition_variable server_ready_cv; - std::mutex server_ready_m; - - void RunServer() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", port)); - arrow::flight::FlightServerOptions options(location); - - ARROW_CHECK_OK(example::SQLiteFlightSqlServer::Create().Value(&server)); - - ARROW_CHECK_OK(server->Init(options)); - // Exit with a clean error code (0) on SIGTERM - ARROW_CHECK_OK(server->SetShutdownOnSignals({SIGTERM})); - - server_ready_cv.notify_all(); - ARROW_CHECK_OK(server->Serve()); - } }; TEST_F(TestFlightSqlServer, TestCommandStatementQuery) { @@ -802,6 +774,51 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfoNoInfo) { sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); } +TEST_F(TestFlightSqlServer, CancelQuery) { + // Not supported + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetSqlInfo({}, {})); + ASSERT_RAISES(NotImplemented, sql_client->CancelQuery({}, *flight_info)); +} + +TEST_F(TestFlightSqlServer, Transactions) { + ASSERT_OK_AND_ASSIGN(auto handle, sql_client->BeginTransaction({})); + ASSERT_TRUE(handle.is_valid()); + ASSERT_NE(handle.transaction_id(), ""); + ASSERT_RAISES(NotImplemented, sql_client->BeginSavepoint({}, handle, "savepoint")); + + ASSERT_OK_AND_ASSIGN(auto flight_info, + sql_client->Execute({}, "SELECT * FROM intTable", handle)); + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); + int64_t row_count = table->num_rows(); + + int64_t result; + ASSERT_OK_AND_ASSIGN(result, + sql_client->ExecuteUpdate( + {}, + "INSERT INTO intTable (keyName, value) VALUES " + "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)", + handle)); + ASSERT_EQ(3, result); + + ASSERT_OK_AND_ASSIGN(flight_info, + sql_client->Execute({}, "SELECT * FROM intTable", handle)); + ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(table, stream->ToTable()); + ASSERT_EQ(table->num_rows(), row_count + 3); + + ASSERT_OK(sql_client->Rollback({}, handle)); + // Commit/rollback invalidate the handle + ASSERT_RAISES(KeyError, sql_client->Rollback({}, handle)); + ASSERT_RAISES(KeyError, sql_client->Commit({}, handle)); + + ASSERT_OK_AND_ASSIGN(flight_info, sql_client->Execute({}, "SELECT * FROM intTable")); + ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(table, stream->ToTable()); + ASSERT_EQ(table->num_rows(), row_count); +} + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h index 20c7952d8d7..8b28ed18bdd 100644 --- a/cpp/src/arrow/flight/sql/types.h +++ b/cpp/src/arrow/flight/sql/types.h @@ -18,7 +18,7 @@ #pragma once #include -#include +#include #include #include #include @@ -70,6 +70,54 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions { /// - true: if read only FLIGHT_SQL_SERVER_READ_ONLY = 3, + /// Retrieves a boolean value indicating whether the Flight SQL Server + /// supports executing SQL queries. + /// + /// Note that the absence of this info (as opposed to a false + /// value) does not necessarily mean that SQL is not supported, as + /// this property was not originally defined. 
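
The Transactions test above exercises the client surface end to end. Condensed into a plain function, and assuming only a connected FlightSqlClient (the SQL text is illustrative), the begin/execute/commit flow it verifies reads:

    #include <string>

    #include "arrow/flight/sql/client.h"
    #include "arrow/result.h"
    #include "arrow/status.h"

    // Minimal sketch: run one update inside an explicit transaction, rolling
    // back on failure. Mirrors the calls made by the Transactions test above.
    arrow::Status UpdateInTransaction(arrow::flight::sql::FlightSqlClient& client,
                                      const std::string& sql) {
      ARROW_ASSIGN_OR_RAISE(auto txn, client.BeginTransaction({}));
      auto result = client.ExecuteUpdate({}, sql, txn);
      if (!result.ok()) {
        // The handle stays valid until explicitly committed or rolled back.
        ARROW_RETURN_NOT_OK(client.Rollback({}, txn));
        return result.status();
      }
      // Commit invalidates the handle; reusing it afterwards raises KeyError,
      // which is exactly what the test asserts.
      return client.Commit({}, txn);
    }
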
+    FLIGHT_SQL_SERVER_SQL = 4,
+
+    /// Retrieves a boolean value indicating whether the Flight SQL Server
+    /// supports executing Substrait plans.
+    FLIGHT_SQL_SERVER_SUBSTRAIT = 5,
+
+    /// Retrieves a string value indicating the minimum supported
+    /// Substrait version, or null if Substrait is not supported.
+    FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION = 6,
+
+    /// Retrieves a string value indicating the maximum supported
+    /// Substrait version, or null if Substrait is not supported.
+    FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION = 7,
+
+    /// Retrieves an int32 indicating whether the Flight SQL Server
+    /// supports the BeginTransaction, EndTransaction, BeginSavepoint,
+    /// and EndSavepoint actions.
+    ///
+    /// Even if this is not supported, the database may still support
+    /// explicit "BEGIN TRANSACTION"/"COMMIT" SQL statements (see
+    /// SQL_TRANSACTIONS_SUPPORTED); this property is only about
+    /// whether the server implements the Flight SQL API endpoints.
+    ///
+    /// The possible values are listed in `SqlSupportedTransaction`.
+    FLIGHT_SQL_SERVER_TRANSACTION = 8,
+
+    /// Retrieves a boolean value indicating whether the Flight SQL Server
+    /// supports explicit query cancellation (the CancelQuery action).
+    FLIGHT_SQL_SERVER_CANCEL = 9,
+
+    /// Retrieves an int32 value indicating the timeout (in milliseconds) for
+    /// prepared statement handles.
+    ///
+    /// If 0, there is no timeout.
+    FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT = 100,
+
+    /// Retrieves an int32 value indicating the timeout (in milliseconds) for
+    /// transactions, since transactions are not tied to a connection.
+    ///
+    /// If 0, there is no timeout.
+    FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT = 101,
+
     /// @}

     /// \name SQL Syntax Information
@@ -795,6 +843,16 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions {
     /// @}
   };

+  /// The level of support for Flight SQL transaction RPCs.
+  enum SqlSupportedTransaction {
+    /// Unknown/not indicated/no support
+    SQL_SUPPORTED_TRANSACTION_NONE = 0,
+    /// Transactions, but not savepoints.
+    SQL_SUPPORTED_TRANSACTION_TRANSACTION = 1,
+    /// Transactions and savepoints.
+    SQL_SUPPORTED_TRANSACTION_SAVEPOINT = 2,
+  };
+
   /// Indicate whether something (e.g. an identifier) is case-sensitive.
   enum SqlSupportedCaseSensitivity {
     SQL_CASE_SENSITIVITY_UNKNOWN = 0,
@@ -845,6 +903,25 @@ struct ARROW_FLIGHT_SQL_EXPORT TableRef {
   std::string table;
 };

+/// \brief A Substrait plan to be executed, along with associated metadata.
+struct ARROW_FLIGHT_SQL_EXPORT SubstraitPlan {
+  /// \brief The serialized plan.
+  std::string plan;
+  /// \brief The Substrait release, e.g. "0.12.0".
+  std::string version;
+};
+
+/// \brief The result of cancelling a query.
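
The SubstraitPlan wrapper above carries the Substrait release alongside the serialized bytes because the version is not recoverable from the plan itself. Building one is plain aggregate construction; the version string here is illustrative (the CancelResult enum it sits next to follows below):

    #include <string>
    #include <utility>

    #include "arrow/flight/sql/types.h"

    // Sketch: pair serialized plan bytes with the Substrait release that
    // produced them; consumers use the version to decide whether they can
    // handle the plan at all.
    arrow::flight::sql::SubstraitPlan MakePlan(std::string serialized_plan) {
      return arrow::flight::sql::SubstraitPlan{std::move(serialized_plan), "0.12.0"};
    }
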
+enum class CancelResult : int8_t { + kUnspecified, + kCancelled, + kCancelling, + kNotCancellable, +}; + +ARROW_FLIGHT_SQL_EXPORT +std::ostream& operator<<(std::ostream& os, CancelResult result); + /// @} } // namespace sql diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 05f945cb824..887cbf92fed 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -435,6 +435,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Ensure Flight SQL protocol is working as expected.", skip={"Rust"} ), + Scenario( + "flight_sql:extension", + description="Ensure Flight SQL extensions work as expected.", + skip={"Rust", "Go"} + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) diff --git a/docs/source/status.rst b/docs/source/status.rst index a5dd47c0b81..fc637872255 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -232,10 +232,22 @@ support/not support individual features. +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | +============================================+=======+=======+=======+============+=======+=======+=======+ +| BeginSavepoint | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| BeginTransaction | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| CancelQuery | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | ClosePreparedStatement | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | CreatePreparedStatement | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| CreatePreparedSubstraitPlan | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| EndSavepoint | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| EndTransaction | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | GetCatalogs | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | GetCrossReference | ✓ | ✓ | ✓ | | | | | @@ -260,6 +272,8 @@ support/not support individual features. 
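
types.h above declares a stream operator for CancelResult. One plausible definition (the patch's actual .cc body is not shown in this hunk, so this is a sketch) is a simple switch over the four states:

    #include <ostream>

    #include "arrow/flight/sql/types.h"

    namespace arrow {
    namespace flight {
    namespace sql {

    // Illustrative implementation of the operator<< declared in types.h above.
    std::ostream& operator<<(std::ostream& os, CancelResult result) {
      switch (result) {
        case CancelResult::kUnspecified:
          return os << "CancelResult::kUnspecified";
        case CancelResult::kCancelled:
          return os << "CancelResult::kCancelled";
        case CancelResult::kCancelling:
          return os << "CancelResult::kCancelling";
        case CancelResult::kNotCancellable:
          return os << "CancelResult::kNotCancellable";
      }
      return os;
    }

    }  // namespace sql
    }  // namespace flight
    }  // namespace arrow
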
+--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | PreparedStatementUpdate | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| StatementSubstraitPlan | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | StatementQuery | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | StatementUpdate | ✓ | ✓ | ✓ | | | | | diff --git a/format/FlightSql.proto b/format/FlightSql.proto index 859427b6880..d8a6cb5bfdb 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -90,6 +90,64 @@ enum SqlInfo { */ FLIGHT_SQL_SERVER_READ_ONLY = 3; + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports executing + * SQL queries. + * + * Note that the absence of this info (as opposed to a false value) does not necessarily + * mean that SQL is not supported, as this property was not originally defined. + */ + FLIGHT_SQL_SERVER_SQL = 4; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports executing + * Substrait plans. + */ + FLIGHT_SQL_SERVER_SUBSTRAIT = 5; + + /* + * Retrieves a string value indicating the minimum supported Substrait version, or null + * if Substrait is not supported. + */ + FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION = 6; + + /* + * Retrieves a string value indicating the maximum supported Substrait version, or null + * if Substrait is not supported. + */ + FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION = 7; + + /* + * Retrieves an int32 indicating whether the Flight SQL Server supports the + * BeginTransaction/EndTransaction/BeginSavepoint/EndSavepoint actions. + * + * Even if this is not supported, the database may still support explicit "BEGIN + * TRANSACTION"/"COMMIT" SQL statements (see SQL_TRANSACTIONS_SUPPORTED); this property + * is only about whether the server implements the Flight SQL API endpoints. + * + * The possible values are listed in `SqlSupportedTransaction`. + */ + FLIGHT_SQL_SERVER_TRANSACTION = 8; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports explicit + * query cancellation (the CancelQuery action). + */ + FLIGHT_SQL_SERVER_CANCEL = 9; + + /* + * Retrieves an int32 indicating the timeout (in milliseconds) for prepared statement handles. + * + * If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + */ + FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT = 100; + + /* + * Retrieves an int32 indicating the timeout (in milliseconds) for transactions, since transactions are not tied to a connection. + * + * If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + */ + FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT = 101; // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. @@ -761,6 +819,18 @@ enum SqlInfo { SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = 576; } +// The level of support for Flight SQL transaction RPCs. +enum SqlSupportedTransaction { + // Unknown/not indicated/no support + SQL_SUPPORTED_TRANSACTION_NONE = 0; + // Transactions, but not savepoints. + // A savepoint is a mark within a transaction that can be individually + // rolled back to. Not all databases support savepoints. 
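
FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT and FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT above tell servers to reset the timeout whenever a handle is used in a command. A rough server-side sketch of that bookkeeping (every name below is illustrative, not part of the patch):

    #include <chrono>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    // Expire handles after `timeout` milliseconds of disuse, resetting the
    // clock each time a handle appears in a command.
    class HandleRegistry {
     public:
      explicit HandleRegistry(std::chrono::milliseconds timeout)
          : timeout_(timeout) {}

      void Register(const std::string& handle) {
        std::lock_guard<std::mutex> guard(mutex_);
        last_used_[handle] = std::chrono::steady_clock::now();
      }

      // Returns false if the handle is unknown or has expired.
      bool Touch(const std::string& handle) {
        std::lock_guard<std::mutex> guard(mutex_);
        const auto now = std::chrono::steady_clock::now();
        auto it = last_used_.find(handle);
        if (it == last_used_.end()) return false;
        if (now - it->second > timeout_) {
          last_used_.erase(it);
          return false;
        }
        it->second = now;  // Using the handle resets its timeout.
        return true;
      }

     private:
      const std::chrono::milliseconds timeout_;
      std::mutex mutex_;
      std::unordered_map<std::string, std::chrono::steady_clock::time_point>
          last_used_;
    };
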
+ SQL_SUPPORTED_TRANSACTION_TRANSACTION = 1; + // Transactions and savepoints + SQL_SUPPORTED_TRANSACTION_SAVEPOINT = 2; +} + enum SqlSupportedCaseSensitivity { SQL_CASE_SENSITIVITY_UNKNOWN = 0; SQL_CASE_SENSITIVITY_CASE_INSENSITIVE = 1; @@ -1406,7 +1476,7 @@ message CommandGetCrossReference { string fk_table = 6; } -// SQL Execution Action Messages +// Query Execution Action Messages /* * Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. @@ -1416,14 +1486,49 @@ message ActionCreatePreparedStatementRequest { // The valid SQL string to create a prepared statement for. string query = 1; + // Create/execute the prepared statement as part of this transaction (if + // unset, executions of the prepared statement will be auto-committed). + optional bytes transaction_id = 2; } /* - * Wrap the result of a "GetPreparedStatement" action. + * An embedded message describing a Substrait plan to execute. + */ +message SubstraitPlan { + option (experimental) = true; + + // The serialized substrait.Plan to create a prepared statement for. + // XXX(ARROW-16902): this is bytes instead of an embedded message + // because Protobuf does not really support one DLL using Protobuf + // definitions from another DLL. + bytes plan = 1; + // The Substrait release, e.g. "0.12.0". This information is not + // tracked in the plan itself, so this is the only way for consumers + // to potentially know if they can handle the plan. + string version = 2; +} + +/* + * Request message for the "CreatePreparedSubstraitPlan" action on a Flight SQL enabled backend. + */ +message ActionCreatePreparedSubstraitPlanRequest { + option (experimental) = true; + + // The serialized substrait.Plan to create a prepared statement for. + SubstraitPlan plan = 1; + // Create/execute the prepared statement as part of this transaction (if + // unset, executions of the prepared statement will be auto-committed). + optional bytes transaction_id = 2; +} + +/* + * Wrap the result of a "CreatePreparedStatement" or "CreatePreparedSubstraitPlan" action. * * The resultant PreparedStatement can be closed either: * - Manually, through the "ClosePreparedStatement" action; * - Automatically, by a server timeout. + * + * The result should be wrapped in a google.protobuf.Any message. */ message ActionCreatePreparedStatementResult { option (experimental) = true; @@ -1451,8 +1556,113 @@ message ActionClosePreparedStatementRequest { bytes prepared_statement_handle = 1; } +/* + * Request message for the "BeginTransaction" action. + * Begins a transaction. + */ +message ActionBeginTransactionRequest { + option (experimental) = true; +} + +/* + * Request message for the "BeginSavepoint" action. + * Creates a savepoint within a transaction. + * + * Only supported if FLIGHT_SQL_TRANSACTION is + * FLIGHT_SQL_TRANSACTION_SUPPORT_SAVEPOINT. + */ +message ActionBeginSavepointRequest { + option (experimental) = true; + + // The transaction to which a savepoint belongs. + bytes transaction_id = 1; + // Name for the savepoint. + string name = 2; +} + +/* + * The result of a "BeginTransaction" action. + * + * The transaction can be manipulated with the "EndTransaction" action, or + * automatically via server timeout. If the transaction times out, then it is + * automatically rolled back. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionBeginTransactionResult { + option (experimental) = true; + + // Opaque handle for the transaction on the server. 
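
As the comments above note, action results travel wrapped in a google.protobuf.Any, so the receiving side decodes in two steps: parse the Result body into an Any, then unpack the expected message. A hedged sketch using the pb alias convention seen in server.cc (the header name for the generated protos is an assumption):

    #include <google/protobuf/any.pb.h>

    #include "arrow/flight/sql/protocol_internal.h"  // assumed generated-proto header
    #include "arrow/flight/types.h"
    #include "arrow/status.h"

    namespace pb = arrow::flight::protocol;

    // Sketch: decode an ActionBeginTransactionResult out of a Flight Result body.
    arrow::Status UnpackBeginTransaction(
        const arrow::flight::Result& result,
        pb::sql::ActionBeginTransactionResult* out) {
      google::protobuf::Any any;
      if (!any.ParseFromArray(result.body->data(),
                              static_cast<int>(result.body->size()))) {
        return arrow::Status::Invalid("Malformed action result");
      }
      if (!any.UnpackTo(out)) {
        return arrow::Status::Invalid("Expected ActionBeginTransactionResult, got ",
                                      any.type_url());
      }
      return arrow::Status::OK();
    }
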
+ bytes transaction_id = 1; +} + +/* + * The result of a "BeginSavepoint" action. + * + * The transaction can be manipulated with the "EndSavepoint" action. + * If the associated transaction is committed, rolled back, or times + * out, then the savepoint is also invalidated. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionBeginSavepointResult { + option (experimental) = true; + + // Opaque handle for the savepoint on the server. + bytes savepoint_id = 1; +} + +/* + * Request message for the "EndTransaction" action. + * + * Commit (COMMIT) or rollback (ROLLBACK) the transaction. + * + * If the action completes successfully, the transaction handle is + * invalidated, as are all associated savepoints. + */ +message ActionEndTransactionRequest { + option (experimental) = true; -// SQL Execution Messages. + enum EndTransaction { + END_TRANSACTION_UNSPECIFIED = 0; + // Commit the transaction. + END_TRANSACTION_COMMIT = 1; + // Roll back the transaction. + END_TRANSACTION_ROLLBACK = 2; + } + // Opaque handle for the transaction on the server. + bytes transaction_id = 1; + // Whether to commit/rollback the given transaction. + EndTransaction action = 2; +} + +/* + * Request message for the "EndSavepoint" action. + * + * Release (RELEASE) the savepoint or rollback (ROLLBACK) to the + * savepoint. + * + * Releasing a savepoint invalidates that savepoint. Rolling back to + * a savepoint does not invalidate the savepoint, but invalidates all + * savepoints created after the current savepoint. + */ +message ActionEndSavepointRequest { + option (experimental) = true; + + enum EndSavepoint { + END_SAVEPOINT_UNSPECIFIED = 0; + // Release the savepoint. + END_SAVEPOINT_RELEASE = 1; + // Roll back to a savepoint. + END_SAVEPOINT_ROLLBACK = 2; + } + // Opaque handle for the savepoint on the server. + bytes savepoint_id = 1; + // Whether to rollback/release the given savepoint. + EndSavepoint action = 2; +} + +// Query Execution Messages. /* * Represents a SQL query. Used in the command member of FlightDescriptor @@ -1476,6 +1686,35 @@ message CommandStatementQuery { // The SQL syntax. string query = 1; + // Include the query as part of this transaction (if unset, the query is auto-committed). + optional bytes transaction_id = 2; +} + +/* + * Represents a Substrait plan. Used in the command member of FlightDescriptor + * for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * Fields on this schema may contain the following metadata: + * - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name + * - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name + * - ARROW:FLIGHT:SQL:TABLE_NAME - Table name + * - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. + * - ARROW:FLIGHT:SQL:PRECISION - Column precision/size + * - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable + * - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. + * - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case sensitive, "0" otherwise. + * - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. + * - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. + * - GetFlightInfo: execute the query. + * - DoPut: execute the query. 
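
The comment below documents CommandStatementSubstraitPlan, which makes a Substrait plan a first-class descriptor command: GetSchema, GetFlightInfo, and DoPut all accept it. From the client side, assuming a C++ entry point analogous to the Java executeSubstrait used in the integration scenario later in this patch (the method name is an assumption here):

    #include <memory>

    #include "arrow/flight/sql/client.h"
    #include "arrow/flight/sql/types.h"
    #include "arrow/result.h"

    // Sketch: execute a Substrait plan like a query; the returned FlightInfo's
    // endpoints are then fetched with DoGet exactly as for SQL text.
    arrow::Result<std::unique_ptr<arrow::flight::FlightInfo>> RunPlan(
        arrow::flight::sql::FlightSqlClient& client,
        const arrow::flight::sql::SubstraitPlan& plan) {
      return client.ExecuteSubstrait({}, plan);
    }
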
+ */ +message CommandStatementSubstraitPlan { + option (experimental) = true; + + // A serialized substrait.Plan + SubstraitPlan plan = 1; + // Include the query as part of this transaction (if unset, the query is auto-committed). + optional bytes transaction_id = 2; } /** @@ -1523,6 +1762,8 @@ message CommandStatementUpdate { // The SQL syntax. string query = 1; + // Include the query as part of this transaction (if unset, the query is auto-committed). + optional bytes transaction_id = 2; } /* @@ -1550,6 +1791,57 @@ message DoPutUpdateResult { int64 record_count = 1; } +/* + * Request message for the "CancelQuery" action. + * + * Explicitly cancel a running query. + * + * This lets a single client explicitly cancel work, no matter how many clients + * are involved/whether the query is distributed or not, given server support. + * The transaction/statement is not rolled back; it is the application's job to + * commit or rollback as appropriate. This only indicates the client no longer + * wishes to read the remainder of the query results or continue submitting + * data. + * + * This command is idempotent. + */ +message ActionCancelQueryRequest { + option (experimental) = true; + + // The result of the GetFlightInfo RPC that initiated the query. + // XXX(ARROW-16902): this must be a serialized FlightInfo, but is + // rendered as bytes because Protobuf does not really support one + // DLL using Protobuf definitions from another DLL. + bytes info = 1; +} + +/* + * The result of cancelling a query. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionCancelQueryResult { + option (experimental) = true; + + enum CancelResult { + // The cancellation status is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + CANCEL_RESULT_UNSPECIFIED = 0; + // The cancellation request is complete. Subsequent requests with + // the same payload may return CANCELLED or a NOT_FOUND error. + CANCEL_RESULT_CANCELLED = 1; + // The cancellation request is in progress. The client may retry + // the cancellation request. + CANCEL_RESULT_CANCELLING = 2; + // The query is not cancellable. The client should not retry the + // cancellation request. + CANCEL_RESULT_NOT_CANCELLABLE = 3; + } + + CancelResult result = 1; +} + extend google.protobuf.MessageOptions { bool experimental = 1000; } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 762b37859b9..1f50f50a293 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -292,7 +292,12 @@ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { * @param options RPC-layer hints for this call. */ public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... 
options) { - return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options).getSchema(descriptor.toProtocol())); + try { + return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options) + .getSchema(descriptor.toProtocol())); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } } /** diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java index 4fb0dea2cba..29a4f2bbd19 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -231,7 +231,7 @@ public StreamObserver doPutCustom(final StreamObserver { try { producer.acceptPut(makeContext(responseObserver), fs, ackStream).run(); - } catch (Exception ex) { + } catch (Throwable ex) { ackStream.onError(ex); } finally { // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java new file mode 100644 index 00000000000..cd20ae4f46f --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.CancelResult; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Integration test scenario for validating Flight SQL specs across multiple implementations. + * This should ensure that RPC objects are being built and parsed correctly for multiple languages + * and that the Arrow schemas are returned as expected. 
+ */ +public class FlightSqlExtensionScenario extends FlightSqlScenario { + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) { + validateMetadataRetrieval(sqlClient); + validateStatementExecution(sqlClient); + validatePreparedStatementExecution(allocator, sqlClient); + validateTransactions(allocator, sqlClient); + } + } + + private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { + FlightInfo info = sqlClient.getSqlInfo(); + Ticket ticket = info.getEndpoints().get(0).getTicket(); + + Map infoValues = new HashMap<>(); + try (FlightStream stream = sqlClient.getStream(ticket)) { + Schema actualSchema = stream.getSchema(); + IntegrationAssertions.assertEquals(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, actualSchema); + + while (stream.next()) { + UInt4Vector infoName = (UInt4Vector) stream.getRoot().getVector(0); + DenseUnionVector value = (DenseUnionVector) stream.getRoot().getVector(1); + + for (int i = 0; i < stream.getRoot().getRowCount(); i++) { + final int code = infoName.get(i); + if (infoValues.containsKey(code)) { + throw new AssertionError("Duplicate SqlInfo value: " + code); + } + Object object; + byte typeId = value.getTypeId(i); + switch (typeId) { + case 0: // string + object = Preconditions.checkNotNull(value.getVarCharVector(typeId) + .getObject(value.getOffset(i))) + .toString(); + break; + case 1: // bool + object = value.getBitVector(typeId).getObject(value.getOffset(i)); + break; + case 2: // int64 + object = value.getBigIntVector(typeId).getObject(value.getOffset(i)); + break; + case 3: // int32 + object = value.getIntVector(typeId).getObject(value.getOffset(i)); + break; + default: + throw new AssertionError("Decoding SqlInfo of type code " + typeId); + } + infoValues.put(code, object); + } + } + } + + IntegrationAssertions.assertEquals(Boolean.FALSE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE)); + IntegrationAssertions.assertEquals(Boolean.TRUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE)); + IntegrationAssertions.assertEquals("min_version", + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE)); + IntegrationAssertions.assertEquals("max_version", + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE)); + IntegrationAssertions.assertEquals(FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE)); + IntegrationAssertions.assertEquals(Boolean.TRUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE)); + IntegrationAssertions.assertEquals(42, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE)); + IntegrationAssertions.assertEquals(7, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE)); + } + + private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { + FlightInfo info = sqlClient.executeSubstrait(SUBSTRAIT_PLAN); + validate(FlightSqlScenarioProducer.getQuerySchema(), info, sqlClient); + + SchemaResult result = sqlClient.getExecuteSubstraitSchema(SUBSTRAIT_PLAN); + validateSchema(FlightSqlScenarioProducer.getQuerySchema(), result); + + IntegrationAssertions.assertEquals(CancelResult.CANCELLED, sqlClient.cancelQuery(info)); + + IntegrationAssertions.assertEquals(sqlClient.executeSubstraitUpdate(SUBSTRAIT_PLAN), + 
UPDATE_STATEMENT_EXPECTED_ROWS); + } + + private void validatePreparedStatementExecution(BufferAllocator allocator, + FlightSqlClient sqlClient) throws Exception { + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(SUBSTRAIT_PLAN); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(), sqlClient); + validateSchema(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.fetchSchema()); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(SUBSTRAIT_PLAN)) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), + UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } + } + + private void validateTransactions(BufferAllocator allocator, FlightSqlClient sqlClient) throws Exception { + final FlightSqlClient.Transaction transaction = sqlClient.beginTransaction(); + IntegrationAssertions.assertEquals(TRANSACTION_ID, transaction.getTransactionId()); + + final FlightSqlClient.Savepoint savepoint = sqlClient.beginSavepoint(transaction, SAVEPOINT_NAME); + IntegrationAssertions.assertEquals(SAVEPOINT_ID, savepoint.getSavepointId()); + + FlightInfo info = sqlClient.execute("SELECT STATEMENT", transaction); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), info, sqlClient); + + info = sqlClient.executeSubstrait(SUBSTRAIT_PLAN, transaction); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), info, sqlClient); + + SchemaResult schema = sqlClient.getExecuteSchema("SELECT STATEMENT", transaction); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + + schema = sqlClient.getExecuteSubstraitSchema(SUBSTRAIT_PLAN, transaction); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", transaction), + UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + IntegrationAssertions.assertEquals(sqlClient.executeSubstraitUpdate(SUBSTRAIT_PLAN, transaction), + UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "SELECT PREPARED STATEMENT", transaction); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), preparedStatement.execute(), sqlClient); + schema = preparedStatement.fetchSchema(); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(SUBSTRAIT_PLAN, transaction); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), preparedStatement.execute(), sqlClient); + schema = preparedStatement.fetchSchema(); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare("UPDATE PREPARED STATEMENT", transaction)) { + 
IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), + UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare(SUBSTRAIT_PLAN, transaction)) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), + UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } + + sqlClient.rollback(savepoint); + + final FlightSqlClient.Savepoint savepoint2 = sqlClient.beginSavepoint(transaction, SAVEPOINT_NAME); + IntegrationAssertions.assertEquals(SAVEPOINT_ID, savepoint2.getSavepointId()); + sqlClient.release(savepoint); + + sqlClient.commit(transaction); + + final FlightSqlClient.Transaction transaction2 = sqlClient.beginTransaction(); + IntegrationAssertions.assertEquals(TRANSACTION_ID, transaction2.getTransactionId()); + sqlClient.rollback(transaction); + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java index 19c1378cfe6..71f1f741d58 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight.integration.tests; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.apache.arrow.flight.CallOption; @@ -42,9 +43,17 @@ * and that the Arrow schemas are returned as expected. */ public class FlightSqlScenario implements Scenario { - public static final long UPDATE_STATEMENT_EXPECTED_ROWS = 10000L; + public static final long UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS = 15000L; public static final long UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS = 20000L; + public static final long UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS = 25000L; + public static final byte[] SAVEPOINT_ID = "savepoint_id".getBytes(StandardCharsets.UTF_8); + public static final String SAVEPOINT_NAME = "savepoint_name"; + public static final byte[] SUBSTRAIT_PLAN_TEXT = "plan".getBytes(StandardCharsets.UTF_8); + public static final String SUBSTRAIT_VERSION = "version"; + public static final FlightSqlClient.SubstraitPlan SUBSTRAIT_PLAN = + new FlightSqlClient.SubstraitPlan(SUBSTRAIT_PLAN_TEXT, SUBSTRAIT_VERSION); + public static final byte[] TRANSACTION_ID = "transaction_id".getBytes(StandardCharsets.UTF_8); @Override public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { @@ -59,13 +68,11 @@ public void buildServer(FlightServer.Builder builder) throws Exception { @Override public void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception { - final FlightSqlClient sqlClient = new FlightSqlClient(client); - - validateMetadataRetrieval(sqlClient); - - validateStatementExecution(sqlClient); - - validatePreparedStatementExecution(sqlClient, allocator); + try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) { + validateMetadataRetrieval(sqlClient); + validateStatementExecution(sqlClient); + validatePreparedStatementExecution(allocator, sqlClient); + } } private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { @@ -122,40 +129,35 @@ private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Excepti } private void 
validateStatementExecution(FlightSqlClient sqlClient) throws Exception { - final CallOption[] options = new CallOption[0]; - - validate(FlightSqlScenarioProducer.getQuerySchema(), - sqlClient.execute("SELECT STATEMENT", options), sqlClient); + FlightInfo info = sqlClient.execute("SELECT STATEMENT"); + validate(FlightSqlScenarioProducer.getQuerySchema(), info, sqlClient); validateSchema(FlightSqlScenarioProducer.getQuerySchema(), - sqlClient.getExecuteSchema("SELECT STATEMENT", options)); + sqlClient.getExecuteSchema("SELECT STATEMENT")); - IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT"), UPDATE_STATEMENT_EXPECTED_ROWS); } - private void validatePreparedStatementExecution(FlightSqlClient sqlClient, - BufferAllocator allocator) throws Exception { - final CallOption[] options = new CallOption[0]; + private void validatePreparedStatementExecution(BufferAllocator allocator, + FlightSqlClient sqlClient) throws Exception { try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( "SELECT PREPARED STATEMENT"); VectorSchemaRoot parameters = VectorSchemaRoot.create( FlightSqlScenarioProducer.getQuerySchema(), allocator)) { parameters.setRowCount(1); preparedStatement.setParameters(parameters); - - validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), - sqlClient); + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(), sqlClient); validateSchema(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.fetchSchema()); } - try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( - "UPDATE PREPARED STATEMENT")) { - IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), + try (FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare("UPDATE PREPARED STATEMENT")) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); } } - private void validate(Schema expectedSchema, FlightInfo flightInfo, + protected void validate(Schema expectedSchema, FlightInfo flightInfo, FlightSqlClient sqlClient) throws Exception { Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); try (FlightStream stream = sqlClient.getStream(ticket)) { @@ -164,7 +166,7 @@ private void validate(Schema expectedSchema, FlightInfo flightInfo, } } - private void validateSchema(Schema expected, SchemaResult actual) { + protected void validateSchema(Schema expected, SchemaResult actual) { IntegrationAssertions.assertEquals(expected, actual.getSchema()); } } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java index 33d62b650e1..4ed9a3df0fc 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -17,11 +17,11 @@ package org.apache.arrow.flight.integration.tests; -import static com.google.protobuf.Any.pack; -import static java.util.Collections.singletonList; - +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import org.apache.arrow.flight.CallStatus; import 
org.apache.arrow.flight.Criteria; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; @@ -31,8 +31,10 @@ import org.apache.arrow.flight.Result; import org.apache.arrow.flight.SchemaResult; import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.CancelResult; import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.SqlInfoBuilder; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -42,7 +44,9 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; /** @@ -61,7 +65,7 @@ public FlightSqlScenarioProducer(BufferAllocator allocator) { */ static Schema getQuerySchema() { return new Schema( - singletonList( + Collections.singletonList( new Field("id", new FieldType(true, new ArrowType.Int(64, true), null, new FlightSqlColumnMetadata.Builder() .tableName("test") @@ -77,6 +81,94 @@ static Schema getQuerySchema() { ); } + /** + * The expected schema for queries with transactions. + *
+ * Must be the same across all languages. + */ + static Schema getQueryWithTransactionSchema() { + return new Schema( + Collections.singletonList( + new Field("pkey", new FieldType(true, new ArrowType.Int(32, true), + null, new FlightSqlColumnMetadata.Builder() + .tableName("test") + .isAutoIncrement(true) + .isCaseSensitive(false) + .typeName("type_test") + .schemaName("schema_test") + .isSearchable(true) + .catalogName("catalog_test") + .precision(100) + .build().getMetadataMap()), null) + ) + ); + } + + @Override + public void beginSavepoint(FlightSql.ActionBeginSavepointRequest request, CallContext context, + StreamListener listener) { + if (!request.getName().equals(FlightSqlScenario.SAVEPOINT_NAME)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected name '%s', not '%s'", + FlightSqlScenario.SAVEPOINT_NAME, request.getName())) + .toRuntimeException()); + return; + } + if (!Arrays.equals(request.getTransactionId().toByteArray(), FlightSqlScenario.TRANSACTION_ID)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected transaction ID '%s', not '%s'", + Arrays.toString(FlightSqlScenario.TRANSACTION_ID), + Arrays.toString(request.getTransactionId().toByteArray()))) + .toRuntimeException()); + return; + } + listener.onNext(FlightSql.ActionBeginSavepointResult.newBuilder() + .setSavepointId(ByteString.copyFrom(FlightSqlScenario.SAVEPOINT_ID)) + .build()); + listener.onCompleted(); + } + + @Override + public void beginTransaction(FlightSql.ActionBeginTransactionRequest request, CallContext context, + StreamListener listener) { + listener.onNext(FlightSql.ActionBeginTransactionResult.newBuilder() + .setTransactionId(ByteString.copyFrom(FlightSqlScenario.TRANSACTION_ID)) + .build()); + listener.onCompleted(); + } + + @Override + public void cancelQuery(FlightInfo info, CallContext context, StreamListener listener) { + final String expectedTicket = "PLAN HANDLE"; + if (info.getEndpoints().size() != 1) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected 1 endpoint, got %d", info.getEndpoints().size())) + .toRuntimeException()); + } + final FlightEndpoint endpoint = info.getEndpoints().get(0); + try { + final Any any = Any.parseFrom(endpoint.getTicket().getBytes()); + if (!any.is(FlightSql.TicketStatementQuery.class)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected TicketStatementQuery, found '%s'", any.getTypeUrl())) + .toRuntimeException()); + } + final FlightSql.TicketStatementQuery ticket = any.unpack(FlightSql.TicketStatementQuery.class); + if (!ticket.getStatementHandle().toStringUtf8().equals(expectedTicket)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected ticket '%s'", expectedTicket)) + .toRuntimeException()); + } + listener.onNext(CancelResult.CANCELLED); + listener.onCompleted(); + } catch (InvalidProtocolBufferException e) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Invalid Protobuf:" + e) + .withCause(e) + .toRuntimeException()); + } + } + @Override public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, CallContext context, StreamListener listener) { @@ -84,21 +176,106 @@ public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementReque request.getQuery().equals("SELECT PREPARED STATEMENT") || request.getQuery().equals("UPDATE PREPARED STATEMENT")); + String text = request.getQuery(); + if 
(!request.getTransactionId().isEmpty()) { + text += " WITH TXN"; + } + text += " HANDLE"; final FlightSql.ActionCreatePreparedStatementResult result = FlightSql.ActionCreatePreparedStatementResult.newBuilder() - .setPreparedStatementHandle(ByteString.copyFromUtf8(request.getQuery() + " HANDLE")) + .setPreparedStatementHandle(ByteString.copyFromUtf8(text)) .build(); - listener.onNext(new Result(pack(result).toByteArray())); + listener.onNext(new Result(Any.pack(result).toByteArray())); + listener.onCompleted(); + } + + @Override + public void createPreparedSubstraitPlan(FlightSql.ActionCreatePreparedSubstraitPlanRequest request, + CallContext context, + StreamListener listener) { + if (!Arrays.equals(request.getPlan().getPlan().toByteArray(), FlightSqlScenario.SUBSTRAIT_PLAN_TEXT)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected plan '%s', not '%s'", + Arrays.toString(FlightSqlScenario.SUBSTRAIT_PLAN_TEXT), + Arrays.toString(request.getPlan().getPlan().toByteArray()))) + .toRuntimeException()); + return; + } + if (!FlightSqlScenario.SUBSTRAIT_VERSION.equals(request.getPlan().getVersion())) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected version '%s', not '%s'", + FlightSqlScenario.SUBSTRAIT_VERSION, + request.getPlan().getVersion())) + .toRuntimeException()); + return; + } + final String handle = request.getTransactionId().isEmpty() ? + "PREPARED PLAN HANDLE" : "PREPARED PLAN WITH TXN HANDLE"; + final FlightSql.ActionCreatePreparedStatementResult result = + FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.copyFromUtf8(handle)) + .build(); + listener.onNext(result); listener.onCompleted(); } @Override public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, CallContext context, StreamListener listener) { - IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", - request.getPreparedStatementHandle().toStringUtf8().equals("SELECT PREPARED STATEMENT HANDLE") || - request.getPreparedStatementHandle().toStringUtf8().equals("UPDATE PREPARED STATEMENT HANDLE")); + final String handle = request.getPreparedStatementHandle().toStringUtf8(); + IntegrationAssertions.assertTrue("Expect to be one of the queries used on tests", + handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("UPDATE PREPARED STATEMENT HANDLE") || + handle.equals("UPDATE PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")); + listener.onCompleted(); + } + @Override + public void endSavepoint(FlightSql.ActionEndSavepointRequest request, CallContext context, + StreamListener listener) { + switch (request.getAction()) { + case END_SAVEPOINT_RELEASE: + case END_SAVEPOINT_ROLLBACK: + if (!Arrays.equals(request.getSavepointId().toByteArray(), FlightSqlScenario.SAVEPOINT_ID)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unexpected ID: " + Arrays.toString(request.getSavepointId().toByteArray())) + .toRuntimeException()); + } + break; + case UNRECOGNIZED: + default: { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown action: " + request.getAction()) + .toRuntimeException()); + return; + } + } + listener.onCompleted(); + } + + @Override + public void endTransaction(FlightSql.ActionEndTransactionRequest request, CallContext context, + 
StreamListener listener) { + switch (request.getAction()) { + case END_TRANSACTION_COMMIT: + case END_TRANSACTION_ROLLBACK: + if (!Arrays.equals(request.getTransactionId().toByteArray(), FlightSqlScenario.TRANSACTION_ID)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unexpected ID: " + Arrays.toString(request.getTransactionId().toByteArray())) + .toRuntimeException()); + } + break; + case UNRECOGNIZED: + default: { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown action: " + request.getAction()) + .toRuntimeException()); + return; + } + } listener.onCompleted(); } @@ -106,11 +283,31 @@ public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + if (command.getTransactionId().isEmpty()) { + String handle = "SELECT STATEMENT HANDLE"; + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(ByteString.copyFromUtf8(handle)) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); + } else { + String handle = "SELECT STATEMENT WITH TXN HANDLE"; + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(ByteString.copyFromUtf8(handle)) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQueryWithTransactionSchema()); + } + } - ByteString handle = ByteString.copyFromUtf8("SELECT STATEMENT HANDLE"); - + @Override + public FlightInfo getFlightInfoSubstraitPlan(FlightSql.CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getPlan().getPlan().toByteArray(), + FlightSqlScenario.SUBSTRAIT_PLAN_TEXT); + IntegrationAssertions.assertEquals(command.getPlan().getVersion(), FlightSqlScenario.SUBSTRAIT_VERSION); + String handle = command.getTransactionId().isEmpty() ? 
+ "PLAN HANDLE" : "PLAN WITH TXN HANDLE"; FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() - .setStatementHandle(handle) + .setStatementHandle(ByteString.copyFromUtf8(handle)) .build(); return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); } @@ -119,37 +316,91 @@ public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "SELECT PREPARED STATEMENT HANDLE"); + String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("PREPARED PLAN HANDLE")) { + return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + } else if (handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + return getFlightInfoForSchema(command, descriptor, getQueryWithTransactionSchema()); + } + throw CallStatus.INVALID_ARGUMENT.withDescription("Unknown handle: " + handle).toRuntimeException(); + } - return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + @Override + public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + if (command.getTransactionId().isEmpty()) { + return new SchemaResult(getQuerySchema()); + } + return new SchemaResult(getQueryWithTransactionSchema()); } @Override public SchemaResult getSchemaPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "SELECT PREPARED STATEMENT HANDLE"); - return new SchemaResult(getQuerySchema()); + String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("PREPARED PLAN HANDLE")) { + return new SchemaResult(getQuerySchema()); + } else if (handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + return new SchemaResult(getQueryWithTransactionSchema()); + } + throw CallStatus.INVALID_ARGUMENT.withDescription("Unknown handle: " + handle).toRuntimeException(); } @Override - public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, - CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); - return new SchemaResult(getQuerySchema()); + public SchemaResult getSchemaSubstraitPlan(FlightSql.CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + if (!Arrays.equals(command.getPlan().getPlan().toByteArray(), FlightSqlScenario.SUBSTRAIT_PLAN_TEXT)) { + throw CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected plan '%s', not '%s'", + Arrays.toString(FlightSqlScenario.SUBSTRAIT_PLAN_TEXT), + Arrays.toString(command.getPlan().getPlan().toByteArray()))) + .toRuntimeException(); + } + if (!FlightSqlScenario.SUBSTRAIT_VERSION.equals(command.getPlan().getVersion())) { + throw CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected version '%s', not '%s'", + FlightSqlScenario.SUBSTRAIT_VERSION, + 
command.getPlan().getVersion())) + .toRuntimeException(); + } + if (command.getTransactionId().isEmpty()) { + return new SchemaResult(getQuerySchema()); + } + return new SchemaResult(getQueryWithTransactionSchema()); } @Override public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, ServerStreamListener listener) { - putEmptyBatchToStreamListener(listener, getQuerySchema()); + final String handle = ticket.getStatementHandle().toStringUtf8(); + if (handle.equals("SELECT STATEMENT HANDLE") || handle.equals("PLAN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } else if (handle.equals("SELECT STATEMENT WITH TXN HANDLE") || handle.equals("PLAN WITH TXN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQueryWithTransactionSchema()); + } else { + listener.error(CallStatus.INVALID_ARGUMENT.withDescription("Unknown handle: " + handle).toRuntimeException()); + } } @Override public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, ServerStreamListener listener) { - putEmptyBatchToStreamListener(listener, getQuerySchema()); + String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || handle.equals("PREPARED PLAN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } else if (handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQueryWithTransactionSchema()); + } else { + listener.error(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown handle: " + handle) + .toRuntimeException()); + } } private Runnable acceptPutReturnConstant(StreamListener ackStream, long value) { @@ -170,48 +421,92 @@ public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, Cal FlightStream flightStream, StreamListener ackStream) { IntegrationAssertions.assertEquals(command.getQuery(), "UPDATE STATEMENT"); + return acceptPutReturnConstant(ackStream, + command.getTransactionId().isEmpty() ? FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS : + FlightSqlScenario.UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } - return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS); + @Override + public Runnable acceptPutSubstraitPlan(FlightSql.CommandStatementSubstraitPlan command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPlan().getPlan().toByteArray(), + FlightSqlScenario.SUBSTRAIT_PLAN_TEXT); + IntegrationAssertions.assertEquals(command.getPlan().getVersion(), FlightSqlScenario.SUBSTRAIT_VERSION); + return acceptPutReturnConstant(ackStream, + command.getTransactionId().isEmpty() ? 
FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS : + FlightSqlScenario.UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); } @Override public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate command, CallContext context, FlightStream flightStream, StreamListener ackStream) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "UPDATE PREPARED STATEMENT HANDLE"); - - return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + final String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("UPDATE PREPARED STATEMENT HANDLE") || + handle.equals("PREPARED PLAN HANDLE")) { + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } else if (handle.equals("UPDATE PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + return acceptPutReturnConstant( + ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } + return () -> { + ackStream.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown handle: " + handle) + .toRuntimeException()); + }; } @Override public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightStream flightStream, StreamListener ackStream) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "SELECT PREPARED STATEMENT HANDLE"); - - IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); - - return ackStream::onCompleted; + final String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); + return ackStream::onCompleted; + } + return () -> { + ackStream.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown handle: " + handle) + .toRuntimeException()); + }; } @Override public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(request.getInfoCount(), 2); - IntegrationAssertions.assertEquals(request.getInfo(0), - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); - IntegrationAssertions.assertEquals(request.getInfo(1), - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); - + if (request.getInfoCount() == 2) { + // Integration test for the protocol messages + IntegrationAssertions.assertEquals(request.getInfo(0), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); + IntegrationAssertions.assertEquals(request.getInfo(1), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); + } return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); } @Override public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, ServerStreamListener listener) { - putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + if (command.getInfoCount() == 2) { + // Integration test for the protocol messages + putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + return; + } + SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder() + .withFlightSqlServerSql(false) + .withFlightSqlServerSubstrait(true) + 
.withFlightSqlServerSubstraitMinVersion("min_version") + .withFlightSqlServerSubstraitMaxVersion("max_version") + .withFlightSqlServerTransaction(FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT) + .withFlightSqlServerCancel(true) + .withFlightSqlServerStatementTimeout(42) + .withFlightSqlServerTransactionTimeout(7); + sqlInfoBuilder.send(command.getInfoList(), listener); } @Override @@ -373,8 +668,8 @@ public void listFlights(CallContext context, Criteria criteria, private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, final Schema schema) { - final Ticket ticket = new Ticket(pack(request).toByteArray()); - final List endpoints = singletonList(new FlightEndpoint(ticket)); + final Ticket ticket = new Ticket(Any.pack(request).toByteArray()); + final List endpoints = Collections.singletonList(new FlightEndpoint(ticket)); return new FlightInfo(schema, descriptor, endpoints, -1, -1); } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java index 76f846a8b73..a60efcbb78d 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java @@ -19,6 +19,7 @@ import java.io.PrintWriter; import java.io.StringWriter; +import java.util.Arrays; import java.util.Objects; import org.apache.arrow.flight.CallStatus; @@ -59,6 +60,16 @@ static void assertEquals(Object expected, Object actual) { } } + /** + * Assert that the two arrays are equal. + */ + static void assertEquals(byte[] expected, byte[] actual) { + if (!Arrays.equals(expected, actual)) { + throw new AssertionError( + String.format("Expected:\n%s\nbut got:\n%s", Arrays.toString(expected), Arrays.toString(actual))); + } + } + /** * Assert that the value is false, using the given message as an error otherwise. */ diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index 16cc856daf5..77f7ab0006d 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -42,6 +42,7 @@ private Scenarios() { scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); scenarios.put("middleware", MiddlewareScenario::new); scenarios.put("flight_sql", FlightSqlScenario::new); + scenarios.put("flight_sql:extension", FlightSqlExtensionScenario::new); } private static Scenarios getInstance() { diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java new file mode 100644 index 00000000000..0751e1d7a89 --- /dev/null +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.Test; + +/** + * Run the integration test scenarios in-process. + */ +class IntegrationTest { + @Test + void authBasicProto() throws Exception { + testScenario("auth:basic_proto"); + } + + @Test + void middleware() throws Exception { + testScenario("middleware"); + } + + @Test + void flightSql() throws Exception { + testScenario("flight_sql"); + } + + @Test + void flightSqlExtension() throws Exception { + testScenario("flight_sql:extension"); + } + + void testScenario(String scenarioName) throws Exception { + try (final BufferAllocator allocator = new RootAllocator()) { + final FlightServer.Builder builder = FlightServer.builder() + .allocator(allocator) + .location(Location.forGrpcInsecure("0.0.0.0", 0)); + final Scenario scenario = Scenarios.getScenario(scenarioName); + scenario.buildServer(builder); + builder.producer(scenario.producer(allocator, Location.forGrpcInsecure("0.0.0.0", 0))); + + try (final FlightServer server = builder.build()) { + server.start(); + + final Location location = Location.forGrpcInsecure("localhost", server.getPort()); + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + scenario.client(allocator, location, client); + } + } + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java new file mode 100644 index 00000000000..3438f788dcf --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.sql.impl.FlightSql; + +import com.google.protobuf.Any; + +/** Typed StreamListener for cancelQuery. */ +class CancelListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + CancelListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(CancelResult val) { + FlightSql.ActionCancelQueryResult result = FlightSql.ActionCancelQueryResult.newBuilder() + .setResult(val.toProtocol()) + .build(); + listener.onNext(new Result(Any.pack(result).toByteArray())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java new file mode 100644 index 00000000000..d1ae4178310 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.sql.impl.FlightSql; + +/** + * The result of cancelling a query. 
+ */ +public enum CancelResult { + UNSPECIFIED, + CANCELLED, + CANCELLING, + NOT_CANCELLABLE, + ; + + FlightSql.ActionCancelQueryResult.CancelResult toProtocol() { + switch (this) { + default: + case UNSPECIFIED: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_UNSPECIFIED; + case CANCELLED: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_CANCELLED; + case CANCELLING: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_CANCELLING; + case NOT_CANCELLABLE: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_NOT_CANCELLABLE; + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index f1f07a1588f..922495a18e0 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -17,8 +17,17 @@ package org.apache.arrow.flight.sql; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCancelQueryRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCancelQueryResult; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedSubstraitPlanRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndTransactionRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; @@ -31,6 +40,7 @@ import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetXdbcTypeInfo; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementSubstraitPlan; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import static org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; import static org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; @@ -91,8 +101,51 @@ public FlightSqlClient(final FlightClient client) { * @return a FlightInfo object representing the stream(s) to fetch. */ public FlightInfo execute(final String query, final CallOption... options) { - final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder(); - builder.setQuery(query); + return execute(query, /*transaction*/ null, options); + } + + /** + * Execute a query on the server. + * + * @param query The query to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. 
+ */ + public FlightInfo execute(final String query, Transaction transaction, final CallOption... options) { + final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Execute a Substrait plan on the server. + * + * @param plan The Substrait plan to execute. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo executeSubstrait(SubstraitPlan plan, CallOption... options) { + return executeSubstrait(plan, /*transaction*/ null, options); + } + + /** + * Execute a Substrait plan on the server. + * + * @param plan The Substrait plan to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo executeSubstrait(SubstraitPlan plan, Transaction transaction, CallOption... options) { + final CommandStatementSubstraitPlan.Builder builder = CommandStatementSubstraitPlan.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); return client.getInfo(descriptor, options); } @@ -100,13 +153,44 @@ public FlightInfo execute(final String query, final CallOption... options) { /** * Get the schema of the result set of a query. */ - public SchemaResult getExecuteSchema(final String query, final CallOption... options) { + public SchemaResult getExecuteSchema(String query, Transaction transaction, CallOption... options) { final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder(); builder.setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); return client.getSchema(descriptor, options); } + /** + * Get the schema of the result set of a query. + */ + public SchemaResult getExecuteSchema(String query, CallOption... options) { + return getExecuteSchema(query, /*transaction*/null, options); + } + + /** + * Get the schema of the result set of a Substrait plan. + */ + public SchemaResult getExecuteSubstraitSchema(SubstraitPlan plan, Transaction transaction, + final CallOption... options) { + final CommandStatementSubstraitPlan.Builder builder = CommandStatementSubstraitPlan.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getSchema(descriptor, options); + } + + /** + * Get the schema of the result set of a Substrait plan. + */ + public SchemaResult getExecuteSubstraitSchema(SubstraitPlan substraitPlan, final CallOption... 
options) { + return getExecuteSubstraitSchema(substraitPlan, /*transaction*/null, options); + } + /** * Execute an update query on the server. * @@ -115,18 +199,77 @@ public SchemaResult getExecuteSchema(final String query, final CallOption... opt * @return the number of rows affected. */ public long executeUpdate(final String query, final CallOption... options) { - final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder(); - builder.setQuery(query); + return executeUpdate(query, /*transaction*/ null, options); + } + + /** + * Execute an update query on the server. + * + * @param query The query to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeUpdate(final String query, Transaction transaction, final CallOption... options) { + final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); - final SyncPutListener putListener = new SyncPutListener(); - client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final SyncPutListener putListener = new SyncPutListener()) { + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final PutResult result = putListener.read()) { + final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( + result.getApplicationMetadata().nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } finally { + listener.getResult(); + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } - try { - final PutResult read = putListener.read(); - try (final ArrowBuf metadata = read.getApplicationMetadata()) { - final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom(metadata.nioBuffer()); + /** + * Execute an update query on the server. + * + * @param plan The Substrait plan to execute. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeSubstraitUpdate(SubstraitPlan plan, CallOption... options) { + return executeSubstraitUpdate(plan, /*transaction*/ null, options); + } + + /** + * Execute an update query on the server. + * + * @param plan The Substrait plan to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeSubstraitUpdate(SubstraitPlan plan, Transaction transaction, CallOption... 
options) { + final CommandStatementSubstraitPlan.Builder builder = CommandStatementSubstraitPlan.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + try (final SyncPutListener putListener = new SyncPutListener()) { + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final PutResult result = putListener.read()) { + final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( + result.getApplicationMetadata().nioBuffer()); return doPutUpdateResult.getRecordCount(); + } finally { + listener.getResult(); } } catch (final InterruptedException | ExecutionException e) { throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); @@ -551,14 +694,198 @@ public SchemaResult getTableTypesSchema(final CallOption... options) { } /** - * Create a prepared statement on the server. + * Create a prepared statement for a SQL query on the server. * * @param query The query to prepare. * @param options RPC-layer hints for this call. * @return The representation of the prepared statement which exists on the server. */ - public PreparedStatement prepare(final String query, final CallOption... options) { - return new PreparedStatement(client, query, options); + public PreparedStatement prepare(String query, CallOption... options) { + return prepare(query, /*transaction*/ null, options); + } + + /** + * Create a prepared statement for a SQL query on the server. + * + * @param query The query to prepare. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(String query, Transaction transaction, CallOption... options) { + ActionCreatePreparedStatementRequest.Builder builder = + ActionCreatePreparedStatementRequest.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + return new PreparedStatement(client, + new Action( + FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), + Any.pack(builder.build()).toByteArray()), + options); + } + + /** + * Create a prepared statement for a Substrait plan on the server. + * + * @param plan The query to prepare. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(SubstraitPlan plan, CallOption... options) { + return prepare(plan, /*transaction*/ null, options); + } + + /** + * Create a prepared statement for a Substrait plan on the server. + * + * @param plan The query to prepare. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(SubstraitPlan plan, Transaction transaction, CallOption... 
options) {
+    ActionCreatePreparedSubstraitPlanRequest.Builder builder =
+        ActionCreatePreparedSubstraitPlanRequest.newBuilder();
+    builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion());
+    if (transaction != null) {
+      builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId()));
+    }
+    return new PreparedStatement(client,
+        new Action(
+            FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN.getType(),
+            Any.pack(builder.build()).toByteArray()),
+        options);
+  }
+
+  /** Begin a transaction. */
+  public Transaction beginTransaction(CallOption... options) {
+    final Action action = new Action(
+        FlightSqlUtils.FLIGHT_SQL_BEGIN_TRANSACTION.getType(),
+        Any.pack(ActionBeginTransactionRequest.getDefaultInstance()).toByteArray());
+    final Iterator<Result> preparedStatementResults = client.doAction(action, options);
+    final ActionBeginTransactionResult result = FlightSqlUtils.unpackAndParseOrThrow(
+        preparedStatementResults.next().getBody(),
+        ActionBeginTransactionResult.class);
+    preparedStatementResults.forEachRemaining((ignored) -> { });
+    if (result.getTransactionId().isEmpty()) {
+      throw CallStatus.INTERNAL.withDescription("Server returned an empty transaction ID").toRuntimeException();
+    }
+    return new Transaction(result.getTransactionId().toByteArray());
+  }
+
+  /** Create a savepoint within a transaction. */
+  public Savepoint beginSavepoint(Transaction transaction, String name, CallOption... options) {
+    Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized");
+    ActionBeginSavepointRequest request = ActionBeginSavepointRequest.newBuilder()
+        .setTransactionId(ByteString.copyFrom(transaction.getTransactionId()))
+        .setName(name)
+        .build();
+    final Action action = new Action(
+        FlightSqlUtils.FLIGHT_SQL_BEGIN_SAVEPOINT.getType(),
+        Any.pack(request).toByteArray());
+    final Iterator<Result> preparedStatementResults = client.doAction(action, options);
+    final ActionBeginSavepointResult result = FlightSqlUtils.unpackAndParseOrThrow(
+        preparedStatementResults.next().getBody(),
+        ActionBeginSavepointResult.class);
+    preparedStatementResults.forEachRemaining((ignored) -> { });
+    if (result.getSavepointId().isEmpty()) {
+      throw CallStatus.INTERNAL.withDescription("Server returned an empty savepoint ID").toRuntimeException();
+    }
+    return new Savepoint(result.getSavepointId().toByteArray());
+  }
+
+  /** Commit a transaction. */
+  public void commit(Transaction transaction, CallOption... options) {
+    Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized");
+    ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder()
+        .setTransactionId(ByteString.copyFrom(transaction.getTransactionId()))
+        .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_COMMIT.getNumber())
+        .build();
+    final Action action = new Action(
+        FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(),
+        Any.pack(request).toByteArray());
+    final Iterator<Result> preparedStatementResults = client.doAction(action, options);
+    preparedStatementResults.forEachRemaining((ignored) -> { });
+  }
+
+  /** Release a savepoint. */
+  public void release(Savepoint savepoint, CallOption... options) {
+    Preconditions.checkArgument(savepoint.getSavepointId().length != 0, "Savepoint must be initialized");
+    ActionEndSavepointRequest request = ActionEndSavepointRequest.newBuilder()
+        .setSavepointId(ByteString.copyFrom(savepoint.getSavepointId()))
+        .setActionValue(ActionEndSavepointRequest.EndSavepoint.END_SAVEPOINT_RELEASE.getNumber())
+        .build();
+    final Action action = new Action(
+        FlightSqlUtils.FLIGHT_SQL_END_SAVEPOINT.getType(),
+        Any.pack(request).toByteArray());
+    final Iterator<Result> preparedStatementResults = client.doAction(action, options);
+    preparedStatementResults.forEachRemaining((ignored) -> { });
+  }
+
+  /** Roll back a transaction. */
+  public void rollback(Transaction transaction, CallOption... options) {
+    Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized");
+    ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder()
+        .setTransactionId(ByteString.copyFrom(transaction.getTransactionId()))
+        .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_ROLLBACK.getNumber())
+        .build();
+    final Action action = new Action(
+        FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(),
+        Any.pack(request).toByteArray());
+    final Iterator<Result> preparedStatementResults = client.doAction(action, options);
+    preparedStatementResults.forEachRemaining((ignored) -> { });
+  }
+
+  /** Roll back to a savepoint. */
+  public void rollback(Savepoint savepoint, CallOption... options) {
+    Preconditions.checkArgument(savepoint.getSavepointId().length != 0, "Savepoint must be initialized");
+    ActionEndSavepointRequest request = ActionEndSavepointRequest.newBuilder()
+        .setSavepointId(ByteString.copyFrom(savepoint.getSavepointId()))
+        .setActionValue(ActionEndSavepointRequest.EndSavepoint.END_SAVEPOINT_ROLLBACK.getNumber())
+        .build();
+    final Action action = new Action(
+        FlightSqlUtils.FLIGHT_SQL_END_SAVEPOINT.getType(),
+        Any.pack(request).toByteArray());
+    final Iterator<Result> preparedStatementResults = client.doAction(action, options);
+    preparedStatementResults.forEachRemaining((ignored) -> { });
+  }
+
+  /**
+   * Explicitly cancel a running query.
+   *
+ * This lets a single client explicitly cancel work, no matter how many clients + * are involved/whether the query is distributed or not, given server support. + * The transaction/statement is not rolled back; it is the application's job to + * commit or rollback as appropriate. This only indicates the client no longer + * wishes to read the remainder of the query results or continue submitting + * data. + */ + public CancelResult cancelQuery(FlightInfo info, CallOption... options) { + ActionCancelQueryRequest request = ActionCancelQueryRequest.newBuilder() + .setInfo(ByteString.copyFrom(info.serialize())) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_CANCEL_QUERY.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + final ActionCancelQueryResult result = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionCancelQueryResult.class); + preparedStatementResults.forEachRemaining((ignored) -> { }); + switch (result.getResult()) { + case CANCEL_RESULT_UNSPECIFIED: + return CancelResult.UNSPECIFIED; + case CANCEL_RESULT_CANCELLED: + return CancelResult.CANCELLED; + case CANCEL_RESULT_CANCELLING: + return CancelResult.CANCELLING; + case CANCEL_RESULT_NOT_CANCELLABLE: + return CancelResult.NOT_CANCELLABLE; + case UNRECOGNIZED: + default: + throw CallStatus.INTERNAL.withDescription("Unknown result: " + result.getResult()).toRuntimeException(); + } } @Override @@ -577,28 +904,13 @@ public static class PreparedStatement implements AutoCloseable { private Schema resultSetSchema; private Schema parameterSchema; - /** - * Constructor. - * - * @param client The client. PreparedStatement does not maintain this resource. - * @param sql The query. - * @param options RPC-layer hints for this call. - */ - public PreparedStatement(final FlightClient client, final String sql, final CallOption... options) { + PreparedStatement(FlightClient client, Action action, CallOption... options) { this.client = client; - final Action action = new Action( - FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), - Any.pack(ActionCreatePreparedStatementRequest - .newBuilder() - .setQuery(sql) - .build()) - .toByteArray()); - final Iterator preparedStatementResults = client.doAction(action, options); + final Iterator preparedStatementResults = client.doAction(action, options); preparedStatementResult = FlightSqlUtils.unpackAndParseOrThrow( preparedStatementResults.next().getBody(), ActionCreatePreparedStatementResult.class); - isClosed = false; } @@ -790,4 +1102,81 @@ public boolean isClosed() { return isClosed; } } + + /** A handle for an active savepoint. */ + public static class Savepoint { + private final byte[] transactionId; + + public Savepoint(byte[] transactionId) { + this.transactionId = transactionId; + } + + public byte[] getSavepointId() { + return transactionId; + } + } + + /** A handle for an active transaction. */ + public static class Transaction { + private final byte[] transactionId; + + public Transaction(byte[] transactionId) { + this.transactionId = transactionId; + } + + public byte[] getTransactionId() { + return transactionId; + } + } + + /** A wrapper around a Substrait plan and a Substrait version. 
*/ + public static final class SubstraitPlan { + private final byte[] plan; + private final String version; + + public SubstraitPlan(byte[] plan, String version) { + this.plan = Preconditions.checkNotNull(plan); + this.version = Preconditions.checkNotNull(version); + } + + public byte[] getPlan() { + return plan; + } + + public String getVersion() { + return version; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SubstraitPlan that = (SubstraitPlan) o; + + if (!Arrays.equals(getPlan(), that.getPlan())) { + return false; + } + return getVersion().equals(that.getVersion()); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(getPlan()); + result = 31 * result + getVersion().hashCode(); + return result; + } + + @Override + public String toString() { + return "SubstraitPlan{" + + "plan=" + Arrays.toString(plan) + + ", version='" + version + '\'' + + '}'; + } + } } diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index 4226ec9e228..00a83667990 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -20,12 +20,21 @@ import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static java.util.stream.IntStream.range; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCancelQueryRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedSubstraitPlanRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndTransactionRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetXdbcTypeInfo; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementSubstraitPlan; import static org.apache.arrow.vector.complex.MapVector.DATA_VECTOR_NAME; import static org.apache.arrow.vector.complex.MapVector.KEY_NAME; import static org.apache.arrow.vector.complex.MapVector.VALUE_NAME; @@ -37,6 +46,8 @@ import static org.apache.arrow.vector.types.Types.MinorType.UINT4; import static org.apache.arrow.vector.types.Types.MinorType.VARCHAR; +import java.io.IOException; +import java.net.URISyntaxException; import java.util.List; import org.apache.arrow.flight.Action; @@ -95,6 +106,9 @@ default FlightInfo getFlightInfo(CallContext context, FlightDescriptor descripto if (command.is(CommandStatementQuery.class)) { return getFlightInfoStatement( 
FlightSqlUtils.unpackOrThrow(command, CommandStatementQuery.class), context, descriptor); + } else if (command.is(CommandStatementSubstraitPlan.class)) { + return getFlightInfoSubstraitPlan( + FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), context, descriptor); } else if (command.is(CommandPreparedStatementQuery.class)) { return getFlightInfoPreparedStatement( FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor); @@ -130,7 +144,9 @@ default FlightInfo getFlightInfo(CallContext context, FlightDescriptor descripto FlightSqlUtils.unpackOrThrow(command, CommandGetXdbcTypeInfo.class), context, descriptor); } - throw CallStatus.INVALID_ARGUMENT.withDescription("The defined request is invalid.").toRuntimeException(); + throw CallStatus.INVALID_ARGUMENT + .withDescription("Unrecognized request: " + command.getTypeUrl()) + .toRuntimeException(); } /** @@ -150,6 +166,9 @@ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) } else if (command.is(CommandPreparedStatementQuery.class)) { return getSchemaPreparedStatement( FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor); + } else if (command.is(CommandStatementSubstraitPlan.class)) { + return getSchemaSubstraitPlan( + FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), context, descriptor); } else if (command.is(CommandGetCatalogs.class)) { return new SchemaResult(Schemas.GET_CATALOGS_SCHEMA); } else if (command.is(CommandGetCrossReference.class)) { @@ -175,7 +194,9 @@ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) return new SchemaResult(Schemas.GET_TYPE_INFO_SCHEMA); } - throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid command provided.").toRuntimeException(); + throw CallStatus.INVALID_ARGUMENT + .withDescription("Unrecognized request: " + command.getTypeUrl()) + .toRuntimeException(); } /** @@ -249,6 +270,10 @@ default Runnable acceptPut(CallContext context, FlightStream flightStream, Strea return acceptPutStatement( FlightSqlUtils.unpackOrThrow(command, CommandStatementUpdate.class), context, flightStream, ackStream); + } else if (command.is(CommandStatementSubstraitPlan.class)) { + return acceptPutSubstraitPlan( + FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), + context, flightStream, ackStream); } else if (command.is(CommandPreparedStatementUpdate.class)) { return acceptPutPreparedStatementUpdate( FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementUpdate.class), @@ -284,19 +309,91 @@ default void listActions(CallContext context, StreamListener listene @Override default void doAction(CallContext context, Action action, StreamListener listener) { final String actionType = action.getType(); - if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType())) { + + if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_BEGIN_SAVEPOINT.getType())) { + final ActionBeginSavepointRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionBeginSavepointRequest.class); + beginSavepoint(request, context, new ProtoListener<>(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_BEGIN_TRANSACTION.getType())) { + final ActionBeginTransactionRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionBeginTransactionRequest.class); + beginTransaction(request, context, new ProtoListener<>(listener)); + } else if 
(actionType.equals(FlightSqlUtils.FLIGHT_SQL_CANCEL_QUERY.getType())) { + final ActionCancelQueryRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionCancelQueryRequest.class); + final FlightInfo info; + try { + info = FlightInfo.deserialize(request.getInfo().asReadOnlyByteBuffer()); + } catch (IOException | URISyntaxException e) { + listener.onError(CallStatus.INTERNAL + .withDescription("Could not unpack FlightInfo: " + e) + .withCause(e) + .toRuntimeException()); + return; + } + cancelQuery(info, context, new CancelListener(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType())) { final ActionCreatePreparedStatementRequest request = FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionCreatePreparedStatementRequest.class); createPreparedStatement(request, context, listener); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN.getType())) { + final ActionCreatePreparedSubstraitPlanRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionCreatePreparedSubstraitPlanRequest.class); + createPreparedSubstraitPlan(request, context, new ProtoListener<>(listener)); } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType())) { - final ActionClosePreparedStatementRequest request = FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), - ActionClosePreparedStatementRequest.class); - closePreparedStatement(request, context, listener); + final ActionClosePreparedStatementRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionClosePreparedStatementRequest.class); + closePreparedStatement(request, context, new NoResultListener(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_END_SAVEPOINT.getType())) { + ActionEndSavepointRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionEndSavepointRequest.class); + endSavepoint(request, context, new NoResultListener(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType())) { + ActionEndTransactionRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionEndTransactionRequest.class); + endTransaction(request, context, new NoResultListener(listener)); } else { - throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid action provided.").toRuntimeException(); + throw CallStatus.INVALID_ARGUMENT + .withDescription("Unrecognized request: " + action.getType()) + .toRuntimeException(); } } + /** + * Create a savepoint within a transaction. + * + * @param request The savepoint request. + * @param context Per-call context. + * @param listener The newly created savepoint ID. + */ + default void beginSavepoint(ActionBeginSavepointRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Begin a transaction. + * + * @param request The transaction request. + * @param context Per-call context. + * @param listener The newly created transaction ID. + */ + default void beginTransaction(ActionBeginTransactionRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Explicitly cancel a query. + * + * @param info The FlightInfo of the query to cancel. + * @param context Per-call context. + * @param listener Whether cancellation succeeded. 
+ */ + default void cancelQuery(FlightInfo info, CallContext context, StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + /** * Creates a prepared statement on the server and returns a handle and metadata in an * {@link ActionCreatePreparedStatementResult} object, wrapped in a {@link Result} @@ -309,6 +406,17 @@ default void doAction(CallContext context, Action action, StreamListener void createPreparedStatement(ActionCreatePreparedStatementRequest request, CallContext context, StreamListener listener); + /** + * Pre-compile a Substrait plan. + * @param request The plan. + * @param context Per-call context. + * @param listener The resulting prepared statement. + */ + default void createPreparedSubstraitPlan(ActionCreatePreparedSubstraitPlanRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + /** * Closes a prepared statement on the server. No result is expected. * @@ -320,9 +428,35 @@ void closePreparedStatement(ActionClosePreparedStatementRequest request, CallCon StreamListener listener); /** - * Gets information about a particular SQL query based data stream. + * Release or roll back to a savepoint. * - * @param command The sql command to generate the data stream. + * @param request The savepoint, and whether to release/rollback. + * @param context Per-call context. + * @param listener Call {@link StreamListener#onCompleted()} or + * {@link StreamListener#onError(Throwable)} when done; do not send a result. + */ + default void endSavepoint(ActionEndSavepointRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Commit or roll back a transaction. + * + * @param request The transaction, and whether to commit/rollback. + * @param context Per-call context. + * @param listener Call {@link StreamListener#onCompleted()} or + * {@link StreamListener#onError(Throwable)} when done; do not send a result. + */ + default void endTransaction(ActionEndTransactionRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Evaluate a SQL query. + * + * @param command The SQL query. * @param context Per-call context. * @param descriptor The descriptor identifying the data stream. * @return Metadata about the stream. @@ -330,6 +464,19 @@ void closePreparedStatement(ActionClosePreparedStatementRequest request, CallCon FlightInfo getFlightInfoStatement(CommandStatementQuery command, CallContext context, FlightDescriptor descriptor); + /** + * Evaluate a Substrait plan. + * + * @param command The Substrait plan. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + default FlightInfo getFlightInfoSubstraitPlan(CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + /** * Gets information about a particular prepared statement data stream. * @@ -342,7 +489,7 @@ FlightInfo getFlightInfoPreparedStatement(CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor); /** - * Get the schema of the result set of a query. + * Get the result schema for a SQL query. * * @param command The SQL query. * @param context Per-call context.
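As an illustration of how a server might opt in to the new transaction hooks, here is a minimal sketch (not part of this change; the TransactionState class and its commit/rollback methods are hypothetical server-side bookkeeping, e.g. a JDBC connection, while the message types, CallStatus codes, and listener wiring are the ones added in this patch):

    import java.util.UUID;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ConcurrentMap;

    import org.apache.arrow.flight.CallStatus;
    import org.apache.arrow.flight.Result;
    import org.apache.arrow.flight.sql.FlightSqlProducer;
    import org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest;
    import org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionResult;
    import org.apache.arrow.flight.sql.impl.FlightSql.ActionEndTransactionRequest;

    import com.google.protobuf.ByteString;

    abstract class TransactionalProducer implements FlightSqlProducer {
      /** Hypothetical per-transaction server state, e.g. a JDBC connection. */
      static final class TransactionState {
        void commit() { /* hypothetical */ }
        void rollback() { /* hypothetical */ }
      }

      private final ConcurrentMap<String, TransactionState> transactions = new ConcurrentHashMap<>();

      @Override
      public void beginTransaction(ActionBeginTransactionRequest request, CallContext context,
          StreamListener<ActionBeginTransactionResult> listener) {
        final String handle = UUID.randomUUID().toString();
        transactions.put(handle, new TransactionState());
        // doAction wraps this listener in ProtoListener, which packs the message
        // into a Result before forwarding it to the client.
        listener.onNext(ActionBeginTransactionResult.newBuilder()
            .setTransactionId(ByteString.copyFromUtf8(handle)).build());
        listener.onCompleted();
      }

      @Override
      public void endTransaction(ActionEndTransactionRequest request, CallContext context,
          StreamListener<Result> listener) {
        final TransactionState state = transactions.remove(request.getTransactionId().toStringUtf8());
        if (state == null) {
          listener.onError(CallStatus.NOT_FOUND
              .withDescription("Unknown transaction").toRuntimeException());
          return;
        }
        if (request.getAction() == ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_COMMIT) {
          state.commit();
        } else {
          state.rollback();
        }
        // EndTransaction sends no result; NoResultListener enforces this.
        listener.onCompleted();
      }
    }

A client drives these methods through the BeginTransaction and EndTransaction action types registered in FlightSqlUtils below.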
@@ -367,6 +514,19 @@ default SchemaResult getSchemaPreparedStatement(CommandPreparedStatementQuery co .toRuntimeException(); } + /** + * Get the result schema for a Substrait plan. + * + * @param command The Substrait plan. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Schema for the stream. + */ + default SchemaResult getSchemaSubstraitPlan(CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + /** * Returns data for a SQL query based data stream. * @param ticket Ticket message containing the statement handle. @@ -399,6 +559,22 @@ void getStreamPreparedStatement(CommandPreparedStatementQuery command, CallConte Runnable acceptPutStatement(CommandStatementUpdate command, CallContext context, FlightStream flightStream, StreamListener ackStream); + /** + * Handle a Substrait plan with uploaded data. + * + * @param command The Substrait plan to evaluate. + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The result data stream. + * @return A runnable to process the stream. + */ + default Runnable acceptPutSubstraitPlan(CommandStatementSubstraitPlan command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + return () -> { + ackStream.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + }; + } + /** * Accepts uploaded data for a particular prepared statement data stream. *

`PutResult`s must be in the form of a {@link DoPutUpdateResult}. @@ -450,7 +626,7 @@ FlightInfo getFlightInfoSqlInfo(CommandGetSqlInfo request, CallContext context, /** * Returns a description of all the data types supported by source. - * + * @param request request filter parameters. * @param descriptor The descriptor identifying the data stream. * @return Metadata about the stream. diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java index e461515c40e..532921a8ac6 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -31,6 +31,18 @@ * Utilities to work with Flight SQL semantics. */ public final class FlightSqlUtils { + + public static final ActionType FLIGHT_SQL_BEGIN_SAVEPOINT = + new ActionType("BeginSavepoint", + "Create a new savepoint.\n" + + "Request Message: ActionBeginSavepointRequest\n" + + "Response Message: ActionBeginSavepointResult"); + + public static final ActionType FLIGHT_SQL_BEGIN_TRANSACTION = + new ActionType("BeginTransaction", + "Start a new transaction.\n" + + "Request Message: ActionBeginTransactionRequest\n" + + "Response Message: ActionBeginTransactionResult"); public static final ActionType FLIGHT_SQL_CREATE_PREPARED_STATEMENT = new ActionType("CreatePreparedStatement", "Creates a reusable prepared statement resource on the server. \n" + "Request Message: ActionCreatePreparedStatementRequest\n" + @@ -41,6 +53,29 @@ public final class FlightSqlUtils { "Request Message: ActionClosePreparedStatementRequest\n" + "Response Message: N/A"); + public static final ActionType FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN = + new ActionType("CreatePreparedSubstraitPlan", + "Creates a reusable prepared statement resource on the server from a Substrait plan.\n" + + "Request Message: ActionCreatePreparedSubstraitPlanRequest\n" + + "Response Message: ActionCreatePreparedStatementResult"); + + public static final ActionType FLIGHT_SQL_CANCEL_QUERY = + new ActionType("CancelQuery", + "Explicitly cancel a running query.\n" + + "Request Message: ActionCancelQueryRequest\n" + + "Response Message: ActionCancelQueryResult"); + + public static final ActionType FLIGHT_SQL_END_SAVEPOINT = + new ActionType("EndSavepoint", + "End a savepoint.\n" + + "Request Message: ActionEndSavepointRequest\n" + + "Response Message: N/A"); + public static final ActionType FLIGHT_SQL_END_TRANSACTION = + new ActionType("EndTransaction", + "End a transaction.\n" + + "Request Message: ActionEndTransactionRequest\n" + + "Response Message: N/A"); + public static final List FLIGHT_SQL_ACTIONS = ImmutableList.of( FLIGHT_SQL_CREATE_PREPARED_STATEMENT, FLIGHT_SQL_CLOSE_PREPARED_STATEMENT diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java new file mode 100644 index 00000000000..2c80076a8f5 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; + +/** A StreamListener for actions that do not return results. */ +class NoResultListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + NoResultListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(Result val) { + throw new UnsupportedOperationException("Do not call onNext on this listener."); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java new file mode 100644 index 00000000000..fd5fd048962 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; + +import com.google.protobuf.Any; +import com.google.protobuf.Message; + +/** + * A StreamListener that accepts a particular type. + * + * @param The message type to accept. 
+ */ +class ProtoListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + ProtoListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(T val) { + listener.onNext(new Result(Any.pack(val).toByteArray())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java index 3866cb89b1f..18793f9b905 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java @@ -20,6 +20,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.IntStream.range; import static org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedTransaction; import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.createBitmaskFromEnums; import java.nio.charset.StandardCharsets; @@ -118,6 +119,46 @@ public SqlInfoBuilder withFlightSqlServerArrowVersion(final String value) { return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE, value); } + /** Set a value for SQL support. */ + public SqlInfoBuilder withFlightSqlServerSql(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE, value); + } + + /** Set a value for Substrait support. */ + public SqlInfoBuilder withFlightSqlServerSubstrait(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE, value); + } + + /** Set a value for Substrait minimum version support. */ + public SqlInfoBuilder withFlightSqlServerSubstraitMinVersion(String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE, value); + } + + /** Set a value for Substrait maximum version support. */ + public SqlInfoBuilder withFlightSqlServerSubstraitMaxVersion(String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE, value); + } + + /** Set a value for transaction support. */ + public SqlInfoBuilder withFlightSqlServerTransaction(SqlSupportedTransaction value) { + return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE, value.getNumber()); + } + + /** Set a value for query cancellation support. */ + public SqlInfoBuilder withFlightSqlServerCancel(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE, value); + } + + /** Set a value for statement timeouts. */ + public SqlInfoBuilder withFlightSqlServerStatementTimeout(int value) { + return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE, value); + } + + /** Set a value for transaction timeouts. */ + public SqlInfoBuilder withFlightSqlServerTransactionTimeout(int value) { + return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE, value); + } + /** * Sets a value for {@link SqlInfo#SQL_IDENTIFIER_QUOTE_CHAR} in the builder. 
* diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index d66b8df9283..fe1e1445afc 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -217,6 +217,9 @@ public FlightSqlExample(final Location location) { .withFlightSqlServerVersion(metaData.getDatabaseProductVersion()) .withFlightSqlServerArrowVersion(metaData.getDriverVersion()) .withFlightSqlServerReadOnly(metaData.isReadOnly()) + .withFlightSqlServerSql(true) + .withFlightSqlServerSubstrait(false) + .withFlightSqlServerTransaction(SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_NONE) .withSqlIdentifierQuoteChar(metaData.getIdentifierQuoteString()) .withSqlDdlCatalog(metaData.supportsCatalogsInDataManipulation()) .withSqlDdlSchema( metaData.supportsSchemasInDataManipulation()) From de1ada364dd0affd5b075e40fe5ffbb53f1301f9 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 16 Sep 2022 11:20:34 -0400 Subject: [PATCH 084/133] ARROW-17734: [Go] Implement Take for Lists and Dense Union (#14130) Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/array/bufferbuilder.go | 3 + go/arrow/compute/executor.go | 5 + go/arrow/compute/go.mod | 11 +- go/arrow/compute/go.sum | 280 ++---------------- go/arrow/compute/internal/exec/span.go | 2 + .../internal/kernels/vector_selection.go | 175 ++++++++++- go/arrow/compute/selection.go | 90 +++++- go/arrow/compute/vector_selection_test.go | 117 ++++++++ 8 files changed, 398 insertions(+), 285 deletions(-) diff --git a/go/arrow/array/bufferbuilder.go b/go/arrow/array/bufferbuilder.go index 6a91031c22b..f8a6bed255e 100644 --- a/go/arrow/array/bufferbuilder.go +++ b/go/arrow/array/bufferbuilder.go @@ -131,6 +131,9 @@ func (b *bufferBuilder) Finish() (buffer *memory.Buffer) { buffer = b.buffer b.buffer = nil b.Reset() + if buffer == nil { + buffer = memory.NewBufferBytes(nil) + } return } diff --git a/go/arrow/compute/executor.go b/go/arrow/compute/executor.go index 0e09e4afc0e..fce290c092b 100644 --- a/go/arrow/compute/executor.go +++ b/go/arrow/compute/executor.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math" + "runtime" "sync" "github.com/apache/arrow/go/v10/arrow" @@ -44,6 +45,7 @@ type ExecCtx struct { PreallocContiguous bool Registry FunctionRegistry ExecChannelSize int + NP int } type ctxExecKey struct{} @@ -70,6 +72,9 @@ func init() { defaultExecCtx.PreallocContiguous = true defaultExecCtx.Registry = GetFunctionRegistry() defaultExecCtx.ExecChannelSize = 10 + // default level of parallelism + // set to 1 to disable parallelization + defaultExecCtx.NP = runtime.NumCPU() } // SetExecCtx returns a new child context containing the passed in ExecCtx diff --git a/go/arrow/compute/go.mod b/go/arrow/compute/go.mod index 224d56faff4..9aa0379f689 100644 --- a/go/arrow/compute/go.mod +++ b/go/arrow/compute/go.mod @@ -23,8 +23,9 @@ replace github.com/apache/arrow/go/v10 => ../../ require ( github.com/apache/arrow/go/v10 v10.0.0-00010101000000-000000000000 github.com/stretchr/testify v1.8.0 - golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e - golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 + golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 + golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde + golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 golang.org/x/xerrors 
v0.0.0-20220609144429-65e65417b02f ) @@ -32,16 +33,18 @@ require ( github.com/andybalholm/brotli v1.0.4 // indirect github.com/apache/thrift v0.16.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/goccy/go-json v0.9.10 // indirect + github.com/goccy/go-json v0.9.11 // indirect github.com/golang/snappy v0.0.4 // indirect - github.com/google/flatbuffers v2.0.6+incompatible // indirect + github.com/google/flatbuffers v2.0.8+incompatible // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.15.9 // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/kr/text v0.2.0 // indirect github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 // indirect github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 // indirect github.com/pierrec/lz4/v4 v4.1.15 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rogpeppe/go-internal v1.9.0 // indirect github.com/zeebo/xxh3 v1.0.2 // indirect golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect golang.org/x/tools v0.1.12 // indirect diff --git a/go/arrow/compute/go.sum b/go/arrow/compute/go.sum index b05bdd419c7..cc95f335c21 100644 --- a/go/arrow/compute/go.sum +++ b/go/arrow/compute/go.sum @@ -1,320 +1,74 @@ -cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= -git.sr.ht/~sbinet/gg v0.3.1/go.mod h1:KGYtlADtqsqANL9ueOFkWymvzUvLMQllU5Ixo+8v3pc= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c h1:RGWPOewvKIROun94nF7v2cua9qP+thov/7M50KEoeSU= -github.com/JohnCGriffin/overflow v0.0.0-20211019200055-46fa312c352c/go.mod h1:X0CRv0ky0k6m906ixxpzmDRLvX58TFUKS2eePweuyxk= -github.com/ajstarks/deck v0.0.0-20200831202436-30c9fc6549a9/go.mod h1:JynElWSGnm/4RlzPXRlREEwqTHAN3T56Bv2ITsFT3gY= -github.com/ajstarks/deck/generate v0.0.0-20210309230005-c3f852c02e19/go.mod h1:T13YZdzov6OU0A1+RfKZiZN9ca6VeKdBdyDV+BY97Tk= -github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= -github.com/ajstarks/svgo v0.0.0-20211024235047-1546f124cd8b/go.mod h1:1KcenG0jGWcpt8ov532z81sp/kMMUG485J2InIOyADM= github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= -github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/apache/thrift v0.16.0 h1:qEy6UW60iVOlUy+b9ZR0d5WzUWYGOo4HfopoyBaNmoY= github.com/apache/thrift v0.16.0/go.mod h1:PHK3hniurgQaNMZYaCLEqXKsYK8upmhPbmdP2FXSqgU= -github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/boombuler/barcode v1.0.1/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/client9/misspell v0.3.4/go.mod 
h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= -github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= -github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= -github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= -github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= -github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-fonts/dejavu v0.1.0/go.mod h1:4Wt4I4OU2Nq9asgDCteaAaWZOV24E+0/Pwo0gppep4g= -github.com/go-fonts/latin-modern v0.2.0/go.mod h1:rQVLdDMK+mK1xscDwsqM5J8U2jrRa3T0ecnM9pNujks= -github.com/go-fonts/liberation v0.1.1/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -github.com/go-fonts/liberation v0.2.0/go.mod h1:K6qoJYypsmfVjWg8KOVDQhLc8UDgIK2HYqyqAO9z7GY= -github.com/go-fonts/stix v0.1.0/go.mod h1:w/c1f0ldAUlJmLBvlbkvVXLAD+tAMqobIIQpmnUIzUY= -github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= -github.com/go-latex/latex v0.0.0-20210118124228-b3d85cf34e07/go.mod h1:CO1AlKB2CSIqUrmQPqA0gdRIlnLEY0gK5JGjh37zN5U= -github.com/go-latex/latex v0.0.0-20210823091927-c0d11ff05a81/go.mod h1:SX0U8uGpxhq9o2S/CELCSUxEWWAuoCUcVCQWv7G2OCk= -github.com/go-pdf/fpdf v0.5.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= -github.com/go-pdf/fpdf v0.6.0/go.mod h1:HzcnA+A23uwogo0tp9yU+l3V+KXhiESpt1PMayhOh5M= -github.com/goccy/go-json v0.9.10 h1:hCeNmprSNLB8B8vQKWl6DpuH0t60oEs+TAk9a7CScKc= -github.com/goccy/go-json v0.9.10/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= -github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod 
h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= +github.com/goccy/go-json v0.9.11 h1:/pAaQDLHEoCq/5FFmSKBswWmK6H0e8g4159Kc/X/nqk= +github.com/goccy/go-json v0.9.11/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang/mock v1.5.0/go.mod h1:CWnOUgYIOo4TcNZ0wHX3YZCqsaM1I1Jvs6v3mP3KVu8= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= -github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= -github.com/golang/protobuf v1.4.0-rc.2/go.mod h1:LlEzMj4AhA7rCAGe4KMBDvJI+AwstrUpVNzEA03Pprs= -github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:WU3c8KckQ9AFe+yFwt9sWVRKCVIyN9cPHBJSNnbL67w= -github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= -github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= -github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/google/flatbuffers v2.0.6+incompatible h1:XHFReMv7nFFusa+CEokzWbzaYocKXI6C7hdU5Kgh9Lw= -github.com/google/flatbuffers v2.0.6+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= -github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= -github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= -github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= -github.com/jung-kurt/gofpdf v1.0.0/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= -github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes= -github.com/kballard/go-shellquote 
v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/google/flatbuffers v2.0.8+incompatible h1:ivUb1cGomAB101ZM1T0nOiWz9pSrTMoa9+EiY7igmkM= +github.com/google/flatbuffers v2.0.8+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= -github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= -github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8 h1:AMFGa4R4MiIpspGNG7Z948v4n35fFGB3RR3G/ry4FWs= github.com/minio/asm2plan9s v0.0.0-20200509001527-cdd76441f9d8/go.mod h1:mC1jAcsrzbxHt8iiaC+zU4b1ylILSosueou12R++wfY= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3 h1:+n/aFZefKZp7spd8DFdX7uMikMLXX4oubIzJF4kv/wI= github.com/minio/c2goasm v0.0.0-20190812172519-36a3d3bbc4f3/go.mod h1:RagcQ7I8IeTMnF8JTXieKnO4Z6JCsikNEzj0DwauVzE= -github.com/phpdave11/gofpdf v1.4.2/go.mod h1:zpO6xFn9yxo3YLyMvW8HcKWVdbNqgIfOOp2dXMnm1mY= -github.com/phpdave11/gofpdi v1.0.12/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= -github.com/phpdave11/gofpdi v1.0.13/go.mod h1:vBmVV0Do6hSBHC8uKUQ71JGW+ZGQq74llk/7bXwjDoI= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= -github.com/ruudk/golang-pdf417 v0.0.0-20181029194003-1af4ab5afa58/go.mod h1:6lfFZQK844Gfx8o5WFuvpxWRwnSoipWe/p622j1v06w= 
-github.com/ruudk/golang-pdf417 v0.0.0-20201230142125-a7e3863a1245/go.mod h1:pQAZKsJ8yyVxGRWYNEm9oFB8ieLgKFnamEyDmSA0BRk= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zeebo/assert v1.3.0 h1:g7C04CbJuIDKNPFHmsk4hwZDO5O+kntRxzaUoNXj+IQ= -github.com/zeebo/assert v1.3.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= -go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE= -golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e h1:+WEEuIdZHnUeJJmEUjyYC2gfUMj69yZXw17EnHg/otA= -golang.org/x/exp v0.0.0-20220722155223-a9213eeb770e/go.mod h1:Kr81I6Kryrl9sr8s2FK3vxD90NdsKWRuOIl2O4CvYbA= -golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= -golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= -golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20190910094157-69e4b8554b2a/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= 
-golang.org/x/image v0.0.0-20200119044424-58c23975cae1/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20200430140353-33d19683fad8/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20200618115811-c13761719519/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20201208152932-35266b937fa6/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20210216034530-4410531fe030/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= -golang.org/x/image v0.0.0-20210607152325-775e3b0c77b9/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/image v0.0.0-20210628002857-a66eb6448b8d/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/image v0.0.0-20220302094943-723b81ca9867/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= -golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY= +golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 h1:tnebWN09GYg9OLPss1KXj8txwZc6X6uMr6VFdcGNbHw= +golang.org/x/exp v0.0.0-20220827204233-334a2380cb91/go.mod h1:cyybsKvd6eL0RnXn6p/Grxp8F5bW7iYuBgsNCOHpMYE= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= -golang.org/x/mod v0.6.0-dev.0.20220106191415-9b9b3d81d5e3/go.mod h1:3p9vT2HGsQu2K1YbXdKPJLVgG5VJdoTa1poYQBtP1AY= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 h1:6zppjxzCulZykYSLyVDYbneBfbaBIQPYMevg0bEwv2s= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/oauth2 
v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde h1:ejfdSekXMDxDLbRrJMwUk6KnSLZ2McaUCVcIKM+N6jc= +golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210304124612-50617c2ba197/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs= -golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/sys v0.0.0-20220829200755-d48e67d00261 h1:v6hYoSR9T5oet+pMXwUWkbiVqx/63mlHjefrHmxwfeY= +golang.org/x/sys 
v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190206041539-40960b6deb8e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190927191325-030b2cf1153e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20201124115921-2c860bdd6e78/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= -golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= -golang.org/x/tools v0.1.10/go.mod h1:Uh6Zz+xoGYZom868N8YTex3t7RhtHDBrE8Gzo9bV56E= golang.org/x/tools v0.1.12 h1:VveCTK38A2rkS8ZqFY25HIDFscX5X9OoEhJd3quQmXU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f h1:uF6paiQQebLeSXkrTqHqz0MXhXXS1KgF41eUdBNvxK0= golang.org/x/xerrors v0.0.0-20220609144429-65e65417b02f/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= -gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJDxZx5Dpnq+c3wtXuadVZAcxbbBo= -gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0= -gonum.org/v1/gonum v0.9.3/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0= gonum.org/v1/gonum v0.11.0 h1:f1IJhK4Km5tBJmaiJXtk/PkL4cdVX6J+tGiM187uT5E= gonum.org/v1/gonum v0.11.0/go.mod h1:fSG4YDCxxUZQJ7rKsQrj0gMOg00Il0Z96/qMA4bVQhA= -gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw= -gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc= -gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY= -gonum.org/v1/plot v0.10.1/go.mod h1:VZW5OlhkL1mysU9vaqNHnsy86inf6Ot+jB3r+BczCEo= -google.golang.org/appengine v1.1.0/go.mod 
h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= -google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= -google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= -google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= -google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= -google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= -google.golang.org/protobuf v1.20.1-0.20200309200217-e05f789c0967/go.mod h1:A+miEFZTKqfCUM6K7xSMQL9OKL/b6hQv+e19PK+JZNE= -google.golang.org/protobuf v1.21.0/go.mod h1:47Nbq4nVaFHyn7ilMalzfO3qCViNmqZ2kzikPIcrTAo= -google.golang.org/protobuf v1.22.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= -google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools 
v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.1.3/go.mod h1:NgwopIslSNH47DimFoV78dnkksY2EFtX0ajyb3K/las= -lukechampine.com/uint128 v1.1.1/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= -lukechampine.com/uint128 v1.2.0/go.mod h1:c4eWIwlEGaxC/+H1VguhU4PHXNWDCDMUlWdIWl2j1gk= -modernc.org/cc/v3 v3.36.0/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= -modernc.org/cc/v3 v3.36.1/go.mod h1:NFUHyPn4ekoC/JHeZFfZurN6ixxawE1BnVonP/oahEI= -modernc.org/ccgo/v3 v3.0.0-20220428102840-41399a37e894/go.mod h1:eI31LL8EwEBKPpNpA4bU1/i+sKOwOrQy8D87zWUcRZc= -modernc.org/ccgo/v3 v3.0.0-20220430103911-bc99d88307be/go.mod h1:bwdAnOoaIt8Ax9YdWGjxWsdkPcZyRPHqrOvJxaKAKGw= -modernc.org/ccgo/v3 v3.16.4/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= -modernc.org/ccgo/v3 v3.16.6/go.mod h1:tGtX0gE9Jn7hdZFeU88slbTh1UtCYKusWOoCJuvkWsQ= -modernc.org/ccgo/v3 v3.16.8/go.mod h1:zNjwkizS+fIFDrDjIAgBSCLkWbJuHF+ar3QRn+Z9aws= -modernc.org/ccorpus v1.11.6/go.mod h1:2gEUTrWqdpH2pXsmTM1ZkjeSrUWDpjMu2T6m29L/ErQ= -modernc.org/httpfs v1.0.6/go.mod h1:7dosgurJGp0sPaRanU53W4xZYKh14wfzX420oZADeHM= -modernc.org/libc v0.0.0-20220428101251-2d5f3daf273b/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= -modernc.org/libc v1.16.0/go.mod h1:N4LD6DBE9cf+Dzf9buBlzVJndKr/iJHG97vGLHYnb5A= -modernc.org/libc v1.16.1/go.mod h1:JjJE0eu4yeK7tab2n4S1w8tlWd9MxXLRzheaRnAKymU= -modernc.org/libc v1.16.7/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= -modernc.org/libc v1.16.17/go.mod h1:hYIV5VZczAmGZAnG15Vdngn5HSF5cSkbvfz2B7GRuVU= -modernc.org/libc v1.16.19/go.mod h1:p7Mg4+koNjc8jkqwcoFBJx7tXkpj00G77X7A72jXPXA= -modernc.org/mathutil v1.2.2/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= -modernc.org/mathutil v1.4.1/go.mod h1:mZW8CKdRPY1v87qxC/wUdX5O1qDzXMP5TH3wjfpga6E= -modernc.org/memory v1.1.1/go.mod h1:/0wo5ibyrQiaoUoH7f9D8dnglAmILJ5/cxZlRECf+Nw= -modernc.org/opt v0.1.1/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= -modernc.org/sqlite v1.18.0/go.mod h1:B9fRWZacNxJBHoCJZQr1R54zhVn3fjfl0aszflrTSxY= -modernc.org/strutil v1.1.1/go.mod h1:DE+MQQ/hjKBZS2zNInV5hhcipt5rLPWkmpbGeW5mmdw= -modernc.org/strutil v1.1.2/go.mod h1:OYajnUAcI/MX+XD/Wx7v1bbdvcQSvxgtb0gC+u3d3eg= -modernc.org/tcl v1.13.1/go.mod h1:XOLfOwzhkljL4itZkK6T72ckMgvj0BDsnKNdZVUOecw= -modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= -modernc.org/z v1.5.1/go.mod h1:eWFB510QWW5Th9YGZT81s+LwvaAs3Q2yr4sP0rmLkv8= -rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/go/arrow/compute/internal/exec/span.go b/go/arrow/compute/internal/exec/span.go index 17a4e5378aa..2e1775985f8 100644 --- a/go/arrow/compute/internal/exec/span.go +++ b/go/arrow/compute/internal/exec/span.go @@ -160,6 +160,8 @@ func (a *ArraySpan) MakeData() arrow.ArrayData { defer dict.Release() result.SetDictionary(dict) return result + } else if dt.ID() == arrow.DENSE_UNION || dt.ID() == arrow.SPARSE_UNION { + bufs[0] = nil } if len(a.Children) > 0 { diff --git a/go/arrow/compute/internal/kernels/vector_selection.go b/go/arrow/compute/internal/kernels/vector_selection.go index c4bfcca8bcd..e2b4e4d351b 100644 --- a/go/arrow/compute/internal/kernels/vector_selection.go +++ b/go/arrow/compute/internal/kernels/vector_selection.go @@ -1070,20 +1070,20 @@ func takeExec(ctx *exec.KernelCtx, outputLen int64, values, indices *exec.ArrayS } } -type outputFn func(*exec.KernelCtx, int64, 
*exec.ArraySpan, *exec.ArraySpan, *exec.ExecResult, func(int64) error, func() error) error -type implFn func(*exec.KernelCtx, *exec.ExecSpan, int64, *exec.ExecResult, outputFn) error +type selectionOutputFn func(*exec.KernelCtx, int64, *exec.ArraySpan, *exec.ArraySpan, *exec.ExecResult, func(int64) error, func() error) error +type selectionImplFn func(*exec.KernelCtx, *exec.ExecSpan, int64, *exec.ExecResult, selectionOutputFn) error -func FilterExec(impl implFn, fn outputFn) exec.ArrayKernelExec { +func FilterExec(impl selectionImplFn) exec.ArrayKernelExec { return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { var ( selection = &batch.Values[1].Array outputLength = getFilterOutputSize(selection, ctx.State.(FilterState).NullSelection) ) - return impl(ctx, batch, outputLength, out, fn) + return impl(ctx, batch, outputLength, out, filterExec) } } -func TakeExec(impl implFn, fn outputFn) exec.ArrayKernelExec { +func TakeExec(impl selectionImplFn) exec.ArrayKernelExec { return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { if ctx.State.(TakeState).BoundsCheck { if err := checkIndexBounds(&batch.Values[1].Array, uint64(batch.Values[0].Array.Len)); err != nil { @@ -1091,11 +1091,11 @@ func TakeExec(impl implFn, fn outputFn) exec.ArrayKernelExec { } } - return impl(ctx, batch, batch.Values[1].Array.Len, out, fn) + return impl(ctx, batch, batch.Values[1].Array.Len, out, takeExec) } } -func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error { +func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn selectionOutputFn) error { var ( values = &batch.Values[0].Array selection = &batch.Values[1].Array @@ -1144,7 +1144,7 @@ func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecS return nil } -func FSBImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn outputFn) error { +func FSBImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn selectionOutputFn) error { var ( values = &batch.Values[0].Array selection = &batch.Values[1].Array @@ -1177,6 +1177,149 @@ func FSBImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out return nil } +func ListImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn selectionOutputFn) error { + var ( + values = &batch.Values[0].Array + selection = &batch.Values[1].Array + + rawOffsets = exec.GetSpanOffsets[OffsetT](values, 1) + mem = exec.GetAllocator(ctx.Ctx) + offsetBuilder = newBufferBuilder[OffsetT](mem) + childIdxBuilder = newBufferBuilder[OffsetT](mem) + ) + + if values.Len > 0 { + dataLength := rawOffsets[values.Len] - rawOffsets[0] + meanListLen := float64(dataLength) / float64(values.Len) + childIdxBuilder.reserve(int(meanListLen)) + } + + offsetBuilder.reserve(int(outputLength) + 1) + var offset OffsetT + err := fn(ctx, outputLength, values, selection, out, + func(idx int64) error { + offsetBuilder.unsafeAppend(offset) + valueOffset := rawOffsets[idx] + valueLength := rawOffsets[idx+1] - valueOffset + offset += valueLength + childIdxBuilder.reserve(int(valueLength)) + for j := valueOffset; j < valueOffset+valueLength; j++ { + childIdxBuilder.unsafeAppend(j) + } + return nil + }, func() error { + offsetBuilder.unsafeAppend(offset) + return nil + }) + + if err != nil { + 
return err + } + + offsetBuilder.unsafeAppend(offset) + out.Buffers[1].WrapBuffer(offsetBuilder.finish()) + + out.Children = make([]exec.ArraySpan, 1) + out.Children[0].Type = exec.GetDataType[OffsetT]() + out.Children[0].Len = int64(childIdxBuilder.len()) + out.Children[0].Buffers[1].WrapBuffer(childIdxBuilder.finish()) + + return nil +} + +func FSLImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn selectionOutputFn) error { + var ( + values = &batch.Values[0].Array + selection = &batch.Values[1].Array + + listSize = values.Type.(*arrow.FixedSizeListType).Len() + baseOffset = values.Offset + + childIdxBuilder = array.NewInt64Builder(exec.GetAllocator(ctx.Ctx)) + ) + + // we need to take listSize elements even for null elements of indices + childIdxBuilder.Reserve(int(outputLength) * int(listSize)) + err := fn(ctx, outputLength, values, selection, out, + func(idx int64) error { + offset := (baseOffset + idx) * int64(listSize) + for j := offset; j < (offset + int64(listSize)); j++ { + childIdxBuilder.UnsafeAppend(j) + } + return nil + }, func() error { + for n := int32(0); n < listSize; n++ { + childIdxBuilder.AppendNull() + } + return nil + }) + + if err != nil { + return err + } + + arr := childIdxBuilder.NewArray() + defer arr.Release() + out.Children = make([]exec.ArraySpan, 1) + out.Children[0].TakeOwnership(arr.Data()) + return nil +} + +func DenseUnionImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn selectionOutputFn) error { + var ( + values = &batch.Values[0].Array + selection = &batch.Values[1].Array + + mem = exec.GetAllocator(ctx.Ctx) + valueOffsetBldr = newBufferBuilder[int32](mem) + childIdBldr = newBufferBuilder[int8](mem) + typeCodes = values.Type.(arrow.UnionType).TypeCodes() + childIndicesBldrs = make([]*array.Int32Builder, len(typeCodes)) + ) + + for i := range childIndicesBldrs { + childIndicesBldrs[i] = array.NewInt32Builder(mem) + } + + childIdBldr.reserve(int(outputLength)) + valueOffsetBldr.reserve(int(outputLength)) + + typedValues := values.MakeArray().(*array.DenseUnion) + defer typedValues.Release() + + err := fn(ctx, outputLength, values, selection, out, + func(idx int64) error { + childID := typedValues.ChildID(int(idx)) + childIdBldr.unsafeAppend(typeCodes[childID]) + valueOffset := typedValues.ValueOffset(int(idx)) + valueOffsetBldr.unsafeAppend(int32(childIndicesBldrs[childID].Len())) + childIndicesBldrs[childID].Append(valueOffset) + return nil + }, func() error { + childID := 0 + childIdBldr.unsafeAppend(typeCodes[childID]) + valueOffsetBldr.unsafeAppend(int32(childIndicesBldrs[childID].Len())) + childIndicesBldrs[childID].AppendNull() + return nil + }) + if err != nil { + return err + } + + out.Type = typedValues.DataType() + out.Buffers[1].WrapBuffer(childIdBldr.finish()) + out.Buffers[2].WrapBuffer(valueOffsetBldr.finish()) + + out.Children = make([]exec.ArraySpan, len(childIndicesBldrs)) + for i, b := range childIndicesBldrs { + arr := b.NewArray() + out.Children[i].TakeOwnership(arr.Data()) + arr.Release() + b.Release() + } + return nil +} + func FilterBinary(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { var ( nullSelect = ctx.State.(FilterState).NullSelection @@ -1228,9 +1371,9 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa filterkernels = []SelectionKernelData{ {In: exec.NewMatchedInput(exec.Primitive()), Exec: PrimitiveFilter}, {In: exec.NewExactInput(arrow.Null), Exec: NullFilter}, - {In: 
exec.NewIDInput(arrow.DECIMAL128), Exec: FilterExec(FSBImpl, filterExec)}, - {In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterExec(FSBImpl, filterExec)}, - {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterExec(FSBImpl, filterExec)}, + {In: exec.NewIDInput(arrow.DECIMAL128), Exec: FilterExec(FSBImpl)}, + {In: exec.NewIDInput(arrow.DECIMAL256), Exec: FilterExec(FSBImpl)}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: FilterExec(FSBImpl)}, {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: FilterBinary}, {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: FilterBinary}, } @@ -1238,11 +1381,11 @@ func GetVectorSelectionKernels() (filterkernels, takeKernels []SelectionKernelDa takeKernels = []SelectionKernelData{ {In: exec.NewExactInput(arrow.Null), Exec: NullTake}, {In: exec.NewMatchedInput(exec.Primitive()), Exec: PrimitiveTake}, - {In: exec.NewIDInput(arrow.DECIMAL128), Exec: TakeExec(FSBImpl, takeExec)}, - {In: exec.NewIDInput(arrow.DECIMAL256), Exec: TakeExec(FSBImpl, takeExec)}, - {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: TakeExec(FSBImpl, takeExec)}, - {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: TakeExec(VarBinaryImpl[int32], takeExec)}, - {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: TakeExec(VarBinaryImpl[int64], takeExec)}, + {In: exec.NewIDInput(arrow.DECIMAL128), Exec: TakeExec(FSBImpl)}, + {In: exec.NewIDInput(arrow.DECIMAL256), Exec: TakeExec(FSBImpl)}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_BINARY), Exec: TakeExec(FSBImpl)}, + {In: exec.NewMatchedInput(exec.BinaryLike()), Exec: TakeExec(VarBinaryImpl[int32])}, + {In: exec.NewMatchedInput(exec.LargeBinaryLike()), Exec: TakeExec(VarBinaryImpl[int64])}, } return } diff --git a/go/arrow/compute/selection.go b/go/arrow/compute/selection.go index c2e78fc48c7..88d3daa74ef 100644 --- a/go/arrow/compute/selection.go +++ b/go/arrow/compute/selection.go @@ -21,8 +21,10 @@ import ( "fmt" "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/array" "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" "github.com/apache/arrow/go/v10/arrow/compute/internal/kernels" + "golang.org/x/sync/errgroup" ) var ( @@ -64,8 +66,8 @@ var ( }) ) -func Take(ctx context.Context, opts *TakeOptions, values, indices Datum) (Datum, error) { - return CallFunction(ctx, "array_take", opts, values, indices) +func Take(ctx context.Context, opts TakeOptions, values, indices Datum) (Datum, error) { + return CallFunction(ctx, "array_take", &opts, values, indices) } func TakeArray(ctx context.Context, values, indices arrow.Array) (arrow.Array, error) { @@ -83,6 +85,82 @@ func TakeArray(ctx context.Context, values, indices arrow.Array) (arrow.Array, e return out.(*ArrayDatum).MakeArray(), nil } +func TakeArrayOpts(ctx context.Context, values, indices arrow.Array, opts TakeOptions) (arrow.Array, error) { + v := NewDatum(values) + idx := NewDatum(indices) + defer v.Release() + defer idx.Release() + + out, err := CallFunction(ctx, "array_take", &opts, v, idx) + if err != nil { + return nil, err + } + defer out.Release() + + return out.(*ArrayDatum).MakeArray(), nil +} + +type listArr interface { + arrow.Array + ListValues() arrow.Array +} + +func takeListImpl(fn exec.ArrayKernelExec) exec.ArrayKernelExec { + return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + if err := fn(ctx, batch, out); err != nil { + return err + } + + // out.Children[0] contains the child indexes of values that we + // want to take after processing. 
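+		// For example, taking indices [2, 0] from the list values
+		// [[1, 2], [], [3, 4, 5]] (offsets [0, 2, 2, 5]) produces the child
+		// indices [2, 3, 4, 0, 1]; the recursive Take below then resolves
+		// those against the flat child array to yield [3, 4, 5, 1, 2],
+		// i.e. the lists [[3, 4, 5], [1, 2]].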
+ values := batch.Values[0].Array.MakeArray().(listArr) + defer values.Release() + + childIndices := out.Children[0].MakeArray() + defer childIndices.Release() + + takenChild, err := TakeArrayOpts(ctx.Ctx, values.ListValues(), childIndices, kernels.TakeOptions{BoundsCheck: false}) + if err != nil { + return err + } + defer takenChild.Release() + + out.Children[0].TakeOwnership(takenChild.Data()) + return nil + } +} + +func denseUnionImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + ex := kernels.TakeExec(kernels.DenseUnionImpl) + if err := ex(ctx, batch, out); err != nil { + return err + } + + typedValues := batch.Values[0].Array.MakeArray().(*array.DenseUnion) + defer typedValues.Release() + + eg, cctx := errgroup.WithContext(ctx.Ctx) + eg.SetLimit(GetExecCtx(ctx.Ctx).NP) + + for i := 0; i < typedValues.NumFields(); i++ { + i := i + eg.Go(func() error { + arr := typedValues.Field(i) + childIndices := out.Children[i].MakeArray() + defer childIndices.Release() + taken, err := TakeArrayOpts(cctx, arr, childIndices, kernels.TakeOptions{}) + if err != nil { + return err + } + defer taken.Release() + out.Children[i].TakeOwnership(taken.Data()) + return nil + }) + } + + return eg.Wait() +} + // RegisterVectorSelection registers functions that select specific // values from arrays such as Take and Filter func RegisterVectorSelection(reg FunctionRegistry) { @@ -92,6 +170,13 @@ func RegisterVectorSelection(reg FunctionRegistry) { reg.AddFunction(takeMetaFunc, false) filterKernels, takeKernels := kernels.GetVectorSelectionKernels() + takeKernels = append(takeKernels, []kernels.SelectionKernelData{ + {In: exec.NewIDInput(arrow.LIST), Exec: takeListImpl(kernels.TakeExec(kernels.ListImpl[int32]))}, + {In: exec.NewIDInput(arrow.LARGE_LIST), Exec: takeListImpl(kernels.TakeExec(kernels.ListImpl[int64]))}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_LIST), Exec: takeListImpl(kernels.TakeExec(kernels.FSLImpl))}, + {In: exec.NewIDInput(arrow.DENSE_UNION), Exec: denseUnionImpl}, + }...) 
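+	// With these entries registered below, kernels dispatch on the input
+	// type ID, so a plain TakeArray call on LIST, LARGE_LIST,
+	// FIXED_SIZE_LIST, or DENSE_UNION input routes through the wrappers
+	// above. A sketch of the expected use (TakeArray is defined earlier in
+	// this file):
+	//
+	//	taken, err := TakeArray(ctx, listValues, indices)
+	//
+	// where listValues and indices are hypothetical arrow.Array values of
+	// list and integer type respectively.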
+ vfunc := NewVectorFunction("array_filter", Binary(), EmptyFuncDoc) vfunc.defaultOpts = &kernels.FilterOptions{} @@ -118,6 +203,7 @@ func RegisterVectorSelection(reg FunctionRegistry) { InputTypes: []exec.InputType{kd.In, selectionType}, OutType: kernels.OutputFirstType, } + basekernel.ExecFn = kd.Exec vfunc.AddKernel(basekernel) } diff --git a/go/arrow/compute/vector_selection_test.go b/go/arrow/compute/vector_selection_test.go index e5fdfbcb776..9e23896fe4a 100644 --- a/go/arrow/compute/vector_selection_test.go +++ b/go/arrow/compute/vector_selection_test.go @@ -710,6 +710,121 @@ func (tk *TakeKernelTestString) TestTakeString() { }) } +type TakeKernelLists struct { + TakeKernelTestTyped +} + +func (tk *TakeKernelLists) TestListInt32() { + tk.dt = arrow.ListOf(arrow.PrimitiveTypes.Int32) + + listJSON := `[[], [1, 2], null, [3]]` + tk.checkTake(tk.dt, listJSON, `[]`, `[]`) + tk.checkTake(tk.dt, listJSON, `[3, 2, 1]`, `[[3], null, [1,2]]`) + tk.checkTake(tk.dt, listJSON, `[null, 3, 0]`, `[null, [3], []]`) + tk.checkTake(tk.dt, listJSON, `[null, null]`, `[null, null]`) + tk.checkTake(tk.dt, listJSON, `[3, 0, 0, 3]`, `[[3], [], [], [3]]`) + tk.checkTake(tk.dt, listJSON, `[0, 1, 2, 3]`, listJSON) + tk.checkTake(tk.dt, listJSON, `[0, 0, 0, 0, 0, 0, 1]`, `[[], [], [], [], [], [], [1, 2]]`) + + tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `[[], [1, 2], [3]]`, `[0, 1, 0]`) +} + +func (tk *TakeKernelLists) TestListListInt32() { + tk.dt = arrow.ListOf(arrow.ListOf(arrow.PrimitiveTypes.Int32)) + + listJSON := `[ + [], + [[1], [2, null, 2], []], + null, + [[3, null], null] + ]` + tk.checkTake(tk.dt, listJSON, `[]`, `[]`) + tk.checkTake(tk.dt, listJSON, `[3, 2, 1]`, `[ + [[3, null], null], + null, + [[1], [2, null, 2], []] + ]`) + tk.checkTake(tk.dt, listJSON, `[null, 3, 0]`, `[ + null, + [[3, null], null], + [] + ]`) + tk.checkTake(tk.dt, listJSON, `[null, null]`, `[null, null]`) + tk.checkTake(tk.dt, listJSON, `[3, 0, 0, 3]`, `[[[3, null], null], [], [], [[3, null], null]]`) + tk.checkTake(tk.dt, listJSON, `[0, 1, 2, 3]`, listJSON) + tk.checkTake(tk.dt, listJSON, `[0, 0, 0, 0, 0, 0, 1]`, + `[[], [], [], [], [], [], [[1], [2, null, 2], []]]`) + + tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `[[[1], [2, null, 2], []], [[3, null]]]`, `[0, 1, 0]`) +} + +func (tk *TakeKernelLists) TestLargeListInt32() { + tk.dt = arrow.LargeListOf(arrow.PrimitiveTypes.Int32) + listJSON := `[[], [1, 2], null, [3]]` + tk.checkTake(tk.dt, listJSON, `[]`, `[]`) + tk.checkTake(tk.dt, listJSON, `[null, 1, 2, 0]`, `[null, [1, 2], null, []]`) +} + +func (tk *TakeKernelLists) TestFixedSizeListInt32() { + tk.dt = arrow.FixedSizeListOf(3, arrow.PrimitiveTypes.Int32) + listJSON := `[null, [1, null, 3], [4, 5, 6], [7, 8, null]]` + tk.checkTake(tk.dt, listJSON, `[]`, `[]`) + tk.checkTake(tk.dt, listJSON, `[3, 2, 1]`, `[[7, 8, null], [4, 5, 6], [1, null, 3]]`) + tk.checkTake(tk.dt, listJSON, `[null, 2, 0]`, `[null, [4, 5, 6], null]`) + tk.checkTake(tk.dt, listJSON, `[null, null]`, `[null, null]`) + tk.checkTake(tk.dt, listJSON, `[3, 0, 0, 3]`, `[[7, 8, null], null, null, [7, 8, null]]`) + tk.checkTake(tk.dt, listJSON, `[0, 1, 2, 3]`, listJSON) + tk.checkTake(tk.dt, listJSON, `[2, 2, 2, 2, 2, 2, 1]`, + `[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, null, 3]]`) + + tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `[[1, null, 3], [4, 5, 6], [7, 8, null]]`, `[0, 1, 0]`) +} + +type TakeKernelDenseUnion struct { + TakeKernelTestTyped +} + +func (tk *TakeKernelDenseUnion) TestTakeUnion() { + tk.dt 
= arrow.DenseUnionOf([]arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "b", Type: arrow.BinaryTypes.String, Nullable: true}, + }, []arrow.UnionTypeCode{2, 5}) + + unionJSON := `[ + [2, null], + [2, 222], + [5, "hello"], + [5, "eh"], + [2, null], + [2, 111], + [5, null] + ]` + tk.checkTake(tk.dt, unionJSON, `[]`, `[]`) + tk.checkTake(tk.dt, unionJSON, `[3, 1, 3, 1, 3]`, `[ + [5, "eh"], + [2, 222], + [5, "eh"], + [2, 222], + [5, "eh"] + ]`) + tk.checkTake(tk.dt, unionJSON, `[4, 2, 1, 6]`, `[ + [2, null], + [5, "hello"], + [2, 222], + [5, null] + ]`) + tk.checkTake(tk.dt, unionJSON, `[0, 1, 2, 3, 4, 5, 6]`, unionJSON) + tk.checkTake(tk.dt, unionJSON, `[0, 2, 2, 2, 2, 2, 2]`, `[ + [2, null], + [5, "hello"], + [5, "hello"], + [5, "hello"], + [5, "hello"], + [5, "hello"], + [5, "hello"] + ]`) +} + func TestTakeKernels(t *testing.T) { suite.Run(t, new(TakeKernelTest)) for _, dt := range numericTypes { @@ -719,6 +834,8 @@ func TestTakeKernels(t *testing.T) { for _, dt := range baseBinaryTypes { suite.Run(t, &TakeKernelTestString{TakeKernelTestTyped: TakeKernelTestTyped{dt: dt}}) } + suite.Run(t, new(TakeKernelLists)) + suite.Run(t, new(TakeKernelDenseUnion)) } func TestFilterKernels(t *testing.T) { From b31c9ac329ea286b295472efdfa9fbadccdb0262 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Fri, 16 Sep 2022 11:26:29 -0400 Subject: [PATCH 085/133] ARROW-17749: [Go] Implement Filter and Take for Structs (#14145) Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/array/concat.go | 7 +- go/arrow/compute/internal/exec/span.go | 1 + go/arrow/compute/internal/kernels/helpers.go | 4 + .../internal/kernels/vector_selection.go | 162 +++++++++++ go/arrow/compute/selection.go | 138 ++++++++-- go/arrow/compute/vector_selection_test.go | 251 ++++++++++++++++++ 6 files changed, 547 insertions(+), 16 deletions(-) diff --git a/go/arrow/array/concat.go b/go/arrow/array/concat.go index 22885f569ab..db574de4eaa 100644 --- a/go/arrow/array/concat.go +++ b/go/arrow/array/concat.go @@ -382,7 +382,12 @@ func concat(data []arrow.ArrayData, mem memory.Allocator) (arrow.ArrayData, erro out.buffers[0] = bm } - switch dt := out.dtype.(type) { + dt := out.dtype + if dt.ID() == arrow.EXTENSION { + dt = dt.(arrow.ExtensionType).StorageType() + } + + switch dt := dt.(type) { case *arrow.NullType: case *arrow.BooleanType: bm, err := concatBitmaps(gatherBitmaps(data, 1), mem) diff --git a/go/arrow/compute/internal/exec/span.go b/go/arrow/compute/internal/exec/span.go index 2e1775985f8..ca6caf436b9 100644 --- a/go/arrow/compute/internal/exec/span.go +++ b/go/arrow/compute/internal/exec/span.go @@ -162,6 +162,7 @@ func (a *ArraySpan) MakeData() arrow.ArrayData { return result } else if dt.ID() == arrow.DENSE_UNION || dt.ID() == arrow.SPARSE_UNION { bufs[0] = nil + nulls = 0 } if len(a.Children) > 0 { diff --git a/go/arrow/compute/internal/kernels/helpers.go b/go/arrow/compute/internal/kernels/helpers.go index 72e4c870d5b..4b4b70929b6 100644 --- a/go/arrow/compute/internal/kernels/helpers.go +++ b/go/arrow/compute/internal/kernels/helpers.go @@ -579,6 +579,10 @@ func (bldr *execBufBuilder) unsafeAdvance(n int) { } func (bldr *execBufBuilder) finish() (buf *memory.Buffer) { + if bldr.buffer == nil { + buf = memory.NewBufferBytes(nil) + return + } bldr.buffer.Resize(bldr.sz) buf = bldr.buffer bldr.buffer, bldr.sz = nil, 0 diff --git a/go/arrow/compute/internal/kernels/vector_selection.go b/go/arrow/compute/internal/kernels/vector_selection.go index e2b4e4d351b..310904dd959 100644 --- 
a/go/arrow/compute/internal/kernels/vector_selection.go +++ b/go/arrow/compute/internal/kernels/vector_selection.go @@ -18,11 +18,14 @@ package kernels import ( "fmt" + "math" "github.com/apache/arrow/go/v10/arrow" "github.com/apache/arrow/go/v10/arrow/array" "github.com/apache/arrow/go/v10/arrow/bitutil" "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v10/arrow/internal/debug" + "github.com/apache/arrow/go/v10/arrow/memory" "github.com/apache/arrow/go/v10/internal/bitutils" ) @@ -87,6 +90,149 @@ func preallocateData(ctx *exec.KernelCtx, length int64, bitWidth int, allocateVa } } +type builder[T any] interface { + array.Builder + Append(T) + UnsafeAppend(T) + UnsafeAppendBoolToBitmap(bool) +} + +func getTakeIndices[T exec.IntTypes | exec.UintTypes](mem memory.Allocator, filter *exec.ArraySpan, nullSelect NullSelectionBehavior) arrow.ArrayData { + var ( + filterData = filter.Buffers[1].Buf + haveFilterNulls = filter.MayHaveNulls() + filterIsValid = filter.Buffers[0].Buf + idxType = exec.GetDataType[T]() + ) + + if haveFilterNulls && nullSelect == EmitNulls { + // Most complex case: the filter may have nulls and we don't drop them. + // The logic is ternary: + // - filter is null: emit null + // - filter is valid and true: emit index + // - filter is valid and false: don't emit anything + + bldr := array.NewBuilder(mem, idxType).(builder[T]) + defer bldr.Release() + + // position relative to start of filter + var pos T + // current position taking the filter offset into account + posWithOffset := filter.Offset + + // to count blocks where filterData[i] || !filterIsValid[i] + filterCounter := bitutils.NewBinaryBitBlockCounter(filterData, filterIsValid, filter.Offset, filter.Offset, filter.Len) + isValidCounter := bitutils.NewBitBlockCounter(filterIsValid, filter.Offset, filter.Len) + for int64(pos) < filter.Len { + // true OR NOT valid + selectedOrNullBlock := filterCounter.NextOrNotWord() + if selectedOrNullBlock.NoneSet() { + pos += T(selectedOrNullBlock.Len) + posWithOffset += int64(selectedOrNullBlock.Len) + continue + } + bldr.Reserve(int(selectedOrNullBlock.Popcnt)) + + // if the values are all valid and the selectedOrNullBlock + // is full, then we can infer that all the values are true + // and skip the bit checking + isValidBlock := isValidCounter.NextWord() + if selectedOrNullBlock.AllSet() && isValidBlock.AllSet() { + // all the values are selected and non-null + for i := 0; i < int(selectedOrNullBlock.Len); i++ { + bldr.UnsafeAppend(pos) + pos++ + } + posWithOffset += int64(selectedOrNullBlock.Len) + } else { + // some of the values are false or null + for i := 0; i < int(selectedOrNullBlock.Len); i++ { + if bitutil.BitIsSet(filterIsValid, int(posWithOffset)) { + if bitutil.BitIsSet(filterData, int(posWithOffset)) { + bldr.UnsafeAppend(pos) + } + } else { + // null slot, append null + bldr.UnsafeAppendBoolToBitmap(false) + } + pos++ + posWithOffset++ + } + } + } + + result := bldr.NewArray() + defer result.Release() + result.Data().Retain() + return result.Data() + } + + bldr := newBufferBuilder[T](mem) + if haveFilterNulls { + // the filter may have nulls, so we scan the validity bitmap + // and the filter data bitmap together + debug.Assert(nullSelect == DropNulls, "incorrect nullselect logic") + + // position relative to start of the filter + var pos T + // current position taking the filter offset into account + posWithOffset := filter.Offset + + filterCounter := bitutils.NewBinaryBitBlockCounter(filterData, filterIsValid, 
filter.Offset, filter.Offset, filter.Len) + for int64(pos) < filter.Len { + andBlock := filterCounter.NextAndWord() + bldr.reserve(int(andBlock.Popcnt)) + if andBlock.AllSet() { + // all the values are selected and non-null + for i := 0; i < int(andBlock.Len); i++ { + bldr.unsafeAppend(pos) + pos++ + } + posWithOffset += int64(andBlock.Len) + } else if !andBlock.NoneSet() { + // some values are false or null + for i := 0; i < int(andBlock.Len); i++ { + if bitutil.BitIsSet(filterIsValid, int(posWithOffset)) && bitutil.BitIsSet(filterData, int(posWithOffset)) { + bldr.unsafeAppend(pos) + } + pos++ + posWithOffset++ + } + } else { + pos += T(andBlock.Len) + posWithOffset += int64(andBlock.Len) + } + } + } else { + // filter has no nulls, so we only need to look for true values + bitutils.VisitSetBitRuns(filterData, filter.Offset, filter.Len, + func(pos, length int64) error { + // append consecutive run of indices + bldr.reserve(int(length)) + for i := int64(0); i < length; i++ { + bldr.unsafeAppend(T(pos + i)) + } + return nil + }) + } + + length := bldr.len() + outBuf := bldr.finish() + defer outBuf.Release() + return array.NewData(idxType, length, []*memory.Buffer{nil, outBuf}, nil, 0, 0) +} + +func GetTakeIndices(mem memory.Allocator, filter *exec.ArraySpan, nullSelect NullSelectionBehavior) (arrow.ArrayData, error) { + debug.Assert(filter.Type.ID() == arrow.BOOL, "filter should be a boolean array") + if filter.Len < math.MaxUint16 { + return getTakeIndices[uint16](mem, filter, nullSelect), nil + } else if filter.Len < math.MaxUint32 { + return getTakeIndices[uint32](mem, filter, nullSelect), nil + } + return nil, fmt.Errorf("%w: filter length exceeds UINT32_MAX, consider a different strategy for selecting elements", + arrow.ErrNotImplemented) +} + type writeFiltered interface { OutPos() int WriteValue(int64) @@ -1121,6 +1267,9 @@ func VarBinaryImpl[OffsetT int32 | int64](ctx *exec.KernelCtx, batch *exec.ExecS valOffset := rawOffsets[idx] valSize := rawOffsets[idx+1] - valOffset + if valSize == 0 { + return nil + } offset += valSize if valSize > OffsetT(spaceAvail) { dataBuilder.reserve(int(valSize)) @@ -1362,6 +1511,19 @@ func FilterBinary(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResul return fmt.Errorf("%w: invalid type for binary filter", arrow.ErrInvalid) } +func visitNoop() error { return nil } +func visitIdxNoop(int64) error { return nil } + +func StructImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, outputLength int64, out *exec.ExecResult, fn selectionOutputFn) error { + var ( + values = &batch.Values[0].Array + selection = &batch.Values[1].Array + ) + + // nothing we need to do other than generate the validity bitmap + return fn(ctx, outputLength, values, selection, out, visitIdxNoop, visitNoop) +} + type SelectionKernelData struct { In exec.InputType Exec exec.ArrayKernelExec diff --git a/go/arrow/compute/selection.go b/go/arrow/compute/selection.go index 88d3daa74ef..262e1ddeed4 100644 --- a/go/arrow/compute/selection.go +++ b/go/arrow/compute/selection.go @@ -105,7 +105,7 @@ type listArr interface { ListValues() arrow.Array } -func takeListImpl(fn exec.ArrayKernelExec) exec.ArrayKernelExec { +func selectListImpl(fn exec.ArrayKernelExec) exec.ArrayKernelExec { return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { if err := fn(ctx, batch, out); err != nil { return err @@ -130,29 +130,126 @@ func takeListImpl(fn exec.ArrayKernelExec) exec.ArrayKernelExec { } } -func denseUnionImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, out 
*exec.ExecResult) error { - ex := kernels.TakeExec(kernels.DenseUnionImpl) - if err := ex(ctx, batch, out); err != nil { +func denseUnionImpl(fn exec.ArrayKernelExec) exec.ArrayKernelExec { + return func(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + if err := fn(ctx, batch, out); err != nil { + return err + } + + typedValues := batch.Values[0].Array.MakeArray().(*array.DenseUnion) + defer typedValues.Release() + + eg, cctx := errgroup.WithContext(ctx.Ctx) + eg.SetLimit(GetExecCtx(ctx.Ctx).NP) + + for i := 0; i < typedValues.NumFields(); i++ { + i := i + eg.Go(func() error { + arr := typedValues.Field(i) + childIndices := out.Children[i].MakeArray() + defer childIndices.Release() + taken, err := TakeArrayOpts(cctx, arr, childIndices, kernels.TakeOptions{}) + if err != nil { + return err + } + defer taken.Release() + out.Children[i].TakeOwnership(taken.Data()) + return nil + }) + } + + return eg.Wait() + } +} + +func extensionFilterImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + extArray := batch.Values[0].Array.MakeArray().(array.ExtensionArray) + defer extArray.Release() + + selection := batch.Values[1].Array.MakeArray() + defer selection.Release() + result, err := FilterArray(ctx.Ctx, extArray.Storage(), selection, FilterOptions(ctx.State.(kernels.FilterState))) + if err != nil { + return err + } + defer result.Release() + + out.TakeOwnership(result.Data()) + out.Type = extArray.DataType() + return nil +} + +func extensionTakeImpl(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + extArray := batch.Values[0].Array.MakeArray().(array.ExtensionArray) + defer extArray.Release() + + selection := batch.Values[1].Array.MakeArray() + defer selection.Release() + result, err := TakeArrayOpts(ctx.Ctx, extArray.Storage(), selection, TakeOptions(ctx.State.(kernels.TakeState))) + if err != nil { return err } + defer result.Release() - typedValues := batch.Values[0].Array.MakeArray().(*array.DenseUnion) - defer typedValues.Release() + out.TakeOwnership(result.Data()) + out.Type = extArray.DataType() + return nil +} +func structFilter(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + // transform filter to selection indices and use take + indices, err := kernels.GetTakeIndices(exec.GetAllocator(ctx.Ctx), + &batch.Values[1].Array, ctx.State.(kernels.FilterState).NullSelection) + if err != nil { + return err + } + defer indices.Release() + + filter := NewDatum(indices) + defer filter.Release() + + valData := batch.Values[0].Array.MakeData() + defer valData.Release() + + vals := NewDatum(valData) + defer vals.Release() + + result, err := Take(ctx.Ctx, kernels.TakeOptions{BoundsCheck: false}, vals, filter) + if err != nil { + return err + } + defer result.Release() + + out.TakeOwnership(result.(*ArrayDatum).Value) + return nil +} + +func structTake(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) error { + // generate top level validity bitmap + if err := kernels.TakeExec(kernels.StructImpl)(ctx, batch, out); err != nil { + return err + } + + values := batch.Values[0].Array.MakeArray().(*array.Struct) + defer values.Release() + + // select from children without bounds checking + out.Children = make([]exec.ArraySpan, values.NumField()) eg, cctx := errgroup.WithContext(ctx.Ctx) eg.SetLimit(GetExecCtx(ctx.Ctx).NP) - for i := 0; i < typedValues.NumFields(); i++ { + selection := batch.Values[1].Array.MakeArray() + defer selection.Release() + + for i := range out.Children { i := i 
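+			// the i := i above rebinds the loop variable so that each
+			// goroutine closure captures its own copy (needed before Go
+			// 1.22's per-iteration loop variable scoping)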
eg.Go(func() error { - arr := typedValues.Field(i) - childIndices := out.Children[i].MakeArray() - defer childIndices.Release() - taken, err := TakeArrayOpts(cctx, arr, childIndices, kernels.TakeOptions{}) + taken, err := TakeArrayOpts(cctx, values.Field(i), selection, kernels.TakeOptions{BoundsCheck: false}) if err != nil { return err } defer taken.Release() + out.Children[i].TakeOwnership(taken.Data()) return nil }) @@ -170,11 +267,22 @@ func RegisterVectorSelection(reg FunctionRegistry) { reg.AddFunction(takeMetaFunc, false) filterKernels, takeKernels := kernels.GetVectorSelectionKernels() + filterKernels = append(filterKernels, []kernels.SelectionKernelData{ + {In: exec.NewIDInput(arrow.LIST), Exec: selectListImpl(kernels.FilterExec(kernels.ListImpl[int32]))}, + {In: exec.NewIDInput(arrow.LARGE_LIST), Exec: selectListImpl(kernels.FilterExec(kernels.ListImpl[int64]))}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_LIST), Exec: selectListImpl(kernels.FilterExec(kernels.FSLImpl))}, + {In: exec.NewIDInput(arrow.DENSE_UNION), Exec: denseUnionImpl(kernels.FilterExec(kernels.DenseUnionImpl))}, + {In: exec.NewIDInput(arrow.EXTENSION), Exec: extensionFilterImpl}, + {In: exec.NewIDInput(arrow.STRUCT), Exec: structFilter}, + }...) + takeKernels = append(takeKernels, []kernels.SelectionKernelData{ - {In: exec.NewIDInput(arrow.LIST), Exec: takeListImpl(kernels.TakeExec(kernels.ListImpl[int32]))}, - {In: exec.NewIDInput(arrow.LARGE_LIST), Exec: takeListImpl(kernels.TakeExec(kernels.ListImpl[int64]))}, - {In: exec.NewIDInput(arrow.FIXED_SIZE_LIST), Exec: takeListImpl(kernels.TakeExec(kernels.FSLImpl))}, - {In: exec.NewIDInput(arrow.DENSE_UNION), Exec: denseUnionImpl}, + {In: exec.NewIDInput(arrow.LIST), Exec: selectListImpl(kernels.TakeExec(kernels.ListImpl[int32]))}, + {In: exec.NewIDInput(arrow.LARGE_LIST), Exec: selectListImpl(kernels.TakeExec(kernels.ListImpl[int64]))}, + {In: exec.NewIDInput(arrow.FIXED_SIZE_LIST), Exec: selectListImpl(kernels.TakeExec(kernels.FSLImpl))}, + {In: exec.NewIDInput(arrow.DENSE_UNION), Exec: denseUnionImpl(kernels.TakeExec(kernels.DenseUnionImpl))}, + {In: exec.NewIDInput(arrow.EXTENSION), Exec: extensionTakeImpl}, + {In: exec.NewIDInput(arrow.STRUCT), Exec: structTake}, }...) 
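+	// Extension types get no kernels of their own: extensionFilterImpl and
+	// extensionTakeImpl above select on the underlying storage array, then
+	// restore the extension type on the result. A minimal sketch of the
+	// effect (ext being any extension-typed arrow.Array):
+	//
+	//	filtered, err := FilterArray(ctx, ext, mask, FilterOptions{})
+	//	// on success, filtered.DataType() is still the extension type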
vfunc := NewVectorFunction("array_filter", Binary(), EmptyFuncDoc) diff --git a/go/arrow/compute/vector_selection_test.go b/go/arrow/compute/vector_selection_test.go index 9e23896fe4a..6b01a11b74d 100644 --- a/go/arrow/compute/vector_selection_test.go +++ b/go/arrow/compute/vector_selection_test.go @@ -28,6 +28,7 @@ import ( "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" "github.com/apache/arrow/go/v10/arrow/compute/internal/kernels" "github.com/apache/arrow/go/v10/arrow/internal/testing/gen" + "github.com/apache/arrow/go/v10/arrow/internal/testing/types" "github.com/apache/arrow/go/v10/arrow/memory" "github.com/apache/arrow/go/v10/arrow/scalar" "github.com/stretchr/testify/assert" @@ -370,6 +371,49 @@ func (f *FilterKernelWithBoolean) TestDefaultOptions() { assertDatumsEqual(f.T(), defOpts, noOpts) } +type FilterKernelExtension struct { + FilterKernelTestSuite +} + +func (f *FilterKernelExtension) TestExtension() { + dt := types.NewSmallintType() + arrow.RegisterExtensionType(dt) + defer arrow.UnregisterExtensionType(dt.ExtensionName()) + + f.assertFilterJSON(dt, `[]`, `[]`, `[]`) + f.assertFilterJSON(dt, `[9]`, `[false]`, `[]`) + f.assertFilterJSON(dt, `[9]`, `[true]`, `[9]`) + f.assertFilterJSON(dt, `[9]`, `[null]`, `[null]`) + f.assertFilterJSON(dt, `[null]`, `[false]`, `[]`) + f.assertFilterJSON(dt, `[null]`, `[true]`, `[null]`) + f.assertFilterJSON(dt, `[null]`, `[null]`, `[null]`) + + f.assertFilterJSON(dt, `[7, 8, 9]`, `[false, true, false]`, `[8]`) + f.assertFilterJSON(dt, `[7, 8, 9]`, `[true, false, true]`, `[7, 9]`) + f.assertFilterJSON(dt, `[null, 8, 9]`, `[false, true, false]`, `[8]`) + f.assertFilterJSON(dt, `[7, 8, 9]`, `[null, true, false]`, `[null, 8]`) + f.assertFilterJSON(dt, `[7, 8, 9]`, `[true, null, true]`, `[7, null, 9]`) + + val := f.getArr(dt, `[7, 8, 9]`) + defer val.Release() + filter := f.getArr(arrow.FixedWidthTypes.Boolean, `[false, true, true, true, false, true]`) + defer filter.Release() + filter = array.NewSlice(filter, 3, 6) + defer filter.Release() + exp := f.getArr(dt, `[7, 9]`) + defer exp.Release() + + f.assertFilter(val, filter, exp) + + invalidFilter := f.getArr(arrow.FixedWidthTypes.Boolean, `[]`) + defer invalidFilter.Release() + + _, err := compute.FilterArray(context.TODO(), val, invalidFilter, f.emitNulls) + f.ErrorIs(err, arrow.ErrInvalid) + _, err = compute.FilterArray(context.TODO(), val, invalidFilter, f.dropOpts) + f.ErrorIs(err, arrow.ErrInvalid) +} + type FilterKernelNumeric struct { FilterKernelTestSuite @@ -633,6 +677,143 @@ func (f *FilterKernelWithString) TestFilterString() { }) } +type FilterKernelWithList struct { + FilterKernelTestSuite +} + +func (f *FilterKernelWithList) TestListInt32() { + dt := arrow.ListOf(arrow.PrimitiveTypes.Int32) + listJSON := `[[], [1, 2], null, [3]]` + f.assertFilterJSON(dt, listJSON, `[false, false, false, false]`, `[]`) + f.assertFilterJSON(dt, listJSON, `[false, true, true, null]`, `[[1, 2], null, null]`) + f.assertFilterJSON(dt, listJSON, `[false, false, true, null]`, `[null, null]`) + f.assertFilterJSON(dt, listJSON, `[true, false, false, true]`, `[[], [3]]`) + f.assertFilterJSON(dt, listJSON, `[true, true, true, true]`, listJSON) + f.assertFilterJSON(dt, listJSON, `[false, true, false, true]`, `[[1, 2], [3]]`) +} + +func (f *FilterKernelWithList) TestListListInt32() { + dt := arrow.ListOf(arrow.ListOf(arrow.PrimitiveTypes.Int32)) + listJSON := `[ + [], + [[1], [2, null, 2], []], + null, + [[3, null], null] + ]` + + f.assertFilterJSON(dt, listJSON, `[false, false, false, false]`, `[]`) 
+ f.assertFilterJSON(dt, listJSON, `[false, true, true, null]`, `[ + [[1], [2, null, 2], []], + null, + null + ]`) + f.assertFilterJSON(dt, listJSON, `[false, false, true, null]`, `[null, null]`) + f.assertFilterJSON(dt, listJSON, `[true, false, false, true]`, `[ + [], + [[3, null], null] + ]`) + f.assertFilterJSON(dt, listJSON, `[true, true, true, true]`, listJSON) + f.assertFilterJSON(dt, listJSON, `[false, true, false, true]`, `[ + [[1], [2, null, 2], []], + [[3, null], null] + ]`) +} + +func (f *FilterKernelWithList) TestLargeListInt32() { + dt := arrow.LargeListOf(arrow.PrimitiveTypes.Int32) + listJSON := `[[], [1, 2], null, [3]]` + f.assertFilterJSON(dt, listJSON, `[false, false, false, false]`, `[]`) + f.assertFilterJSON(dt, listJSON, `[false, true, true, null]`, `[[1, 2], null, null]`) +} + +func (f *FilterKernelWithList) TestFixedSizeListInt32() { + dt := arrow.FixedSizeListOf(3, arrow.PrimitiveTypes.Int32) + listJSON := `[null, [1, null, 3], [4, 5, 6], [7, 8, null]]` + f.assertFilterJSON(dt, listJSON, `[false, false, false, false]`, `[]`) + f.assertFilterJSON(dt, listJSON, `[false, true, true, null]`, `[[1, null, 3], [4, 5, 6], null]`) + f.assertFilterJSON(dt, listJSON, `[false, false, true, null]`, `[[4, 5, 6], null]`) + f.assertFilterJSON(dt, listJSON, `[true, true, true, true]`, listJSON) + f.assertFilterJSON(dt, listJSON, `[false, true, false, true]`, `[[1, null, 3], [7, 8, null]]`) +} + +type FilterKernelWithUnion struct { + FilterKernelTestSuite +} + +func (f *FilterKernelWithUnion) TestDenseUnion() { + dt := arrow.DenseUnionOf([]arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "b", Type: arrow.BinaryTypes.String, Nullable: true}, + }, []arrow.UnionTypeCode{2, 5}) + + unionJSON := `[ + [2, null], + [2, 222], + [5, "hello"], + [5, "eh"], + [2, null], + [2, 111], + [5, null] + ]` + + f.assertFilterJSON(dt, unionJSON, `[false, false, false, false, false, false, false]`, `[]`) + f.assertFilterJSON(dt, unionJSON, `[false, true, true, null, false, true, true]`, `[ + [2, 222], + [5, "hello"], + [2, null], + [2, 111], + [5, null] + ]`) + f.assertFilterJSON(dt, unionJSON, `[true, false, true, false, true, false, false]`, `[ + [2, null], + [5, "hello"], + [2, null] + ]`) + f.assertFilterJSON(dt, unionJSON, `[true, true, true, true, true, true, true]`, unionJSON) + + // sliced + // (check this manually as concat of dense unions isn't supported) + unionArr, _, _ := array.FromJSON(f.mem, dt, strings.NewReader(unionJSON)) + defer unionArr.Release() + + filterArr, _, _ := array.FromJSON(f.mem, arrow.FixedWidthTypes.Boolean, strings.NewReader(`[false, true, true, null, false, true, true]`)) + defer filterArr.Release() + + expected, _, _ := array.FromJSON(f.mem, dt, strings.NewReader(`[[5, "hello"], [2, null], [2, 111]]`)) + defer expected.Release() + + values := array.NewSlice(unionArr, 2, 6) + defer values.Release() + filter := array.NewSlice(filterArr, 2, 6) + defer filter.Release() + f.assertFilter(values, filter, expected) +} + +type FilterKernelWithStruct struct { + FilterKernelTestSuite +} + +func (f *FilterKernelWithStruct) TestStruct() { + dt := arrow.StructOf(arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + arrow.Field{Name: "b", Type: arrow.BinaryTypes.String, Nullable: true}) + + structJSON := `[ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ]` + + f.assertFilterJSON(dt, structJSON, `[false, false, false, false]`, `[]`) + f.assertFilterJSON(dt, structJSON, `[false, true, true, 
null]`, `[ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + null + ]`) + f.assertFilterJSON(dt, structJSON, `[true, true, true, true]`, structJSON) + f.assertFilterJSON(dt, structJSON, `[true, false, true, false]`, `[null, {"a": 2, "b": "hello"}]`) +} + type TakeKernelTestTyped struct { TakeKernelTestSuite @@ -663,6 +844,28 @@ func (tk *TakeKernelTestNumeric) TestTakeNumeric() { }) } +type TakeKernelTestExtension struct { + TakeKernelTestTyped +} + +func (tk *TakeKernelTestExtension) TestTakeExtension() { + tk.dt = types.NewSmallintType() + arrow.RegisterExtensionType(tk.dt.(arrow.ExtensionType)) + defer arrow.UnregisterExtensionType("smallint") + + tk.assertTake(`[7, 8, 9]`, `[]`, `[]`) + tk.assertTake(`[7, 8, 9]`, `[0, 1, 0]`, `[7, 8, 7]`) + tk.assertTake(`[null, 8, 9]`, `[0, 1, 0]`, `[null, 8, null]`) + tk.assertTake(`[7, 8, 9]`, `[null, 1, 0]`, `[null, 8, 7]`) + tk.assertTake(`[null, 8, 9]`, `[]`, `[]`) + tk.assertTake(`[7, 8, 9]`, `[0, 0, 0, 0, 0, 0, 2]`, `[7, 7, 7, 7, 7, 7, 9]`) + + _, err := tk.takeJSON(tk.dt, `[7, 8, 9]`, arrow.PrimitiveTypes.Int8, `[0, 9, 0]`) + tk.ErrorIs(err, arrow.ErrIndex) + _, err = tk.takeJSON(tk.dt, `[7, 8, 9]`, arrow.PrimitiveTypes.Int8, `[0, -1, 0]`) + tk.ErrorIs(err, arrow.ErrIndex) +} + type TakeKernelTestFSB struct { TakeKernelTestTyped } @@ -825,6 +1028,48 @@ func (tk *TakeKernelDenseUnion) TestTakeUnion() { ]`) } +type TakeKernelStruct struct { + TakeKernelTestTyped +} + +func (tk *TakeKernelStruct) TestStruct() { + tk.dt = arrow.StructOf(arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + arrow.Field{Name: "b", Type: arrow.BinaryTypes.String, Nullable: true}) + + structJSON := `[ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ]` + + tk.checkTake(tk.dt, structJSON, `[]`, `[]`) + tk.checkTake(tk.dt, structJSON, `[3, 1, 3, 1, 3]`, `[ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"} + ]`) + tk.checkTake(tk.dt, structJSON, `[3, 1, 0]`, `[ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + null + ]`) + tk.checkTake(tk.dt, structJSON, `[0, 1, 2, 3]`, structJSON) + tk.checkTake(tk.dt, structJSON, `[0, 2, 2, 2, 2, 2, 2]`, `[ + null, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"} + ]`) + + tk.assertNoValidityBitmapUnknownNullCountJSON(tk.dt, `[{"a": 1}, {"a": 2, "b": "hello"}]`, `[0, 1, 0]`) +} + func TestTakeKernels(t *testing.T) { suite.Run(t, new(TakeKernelTest)) for _, dt := range numericTypes { @@ -836,6 +1081,8 @@ func TestTakeKernels(t *testing.T) { } suite.Run(t, new(TakeKernelLists)) suite.Run(t, new(TakeKernelDenseUnion)) + suite.Run(t, new(TakeKernelTestExtension)) + suite.Run(t, new(TakeKernelStruct)) } func TestFilterKernels(t *testing.T) { @@ -850,4 +1097,8 @@ func TestFilterKernels(t *testing.T) { for _, dt := range baseBinaryTypes { suite.Run(t, &FilterKernelWithString{dt: dt}) } + suite.Run(t, new(FilterKernelWithList)) + suite.Run(t, new(FilterKernelWithUnion)) + suite.Run(t, new(FilterKernelExtension)) + suite.Run(t, new(FilterKernelWithStruct)) } From a87b47d47df7fe5217a330806731e4bd74c50162 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 16 Sep 2022 11:57:48 -0400 Subject: [PATCH 086/133] ARROW-17741: [Packaging] Include JDBC driver in java-jars artifacts (#14139) Authored-by: David Li Signed-off-by: David Li --- dev/tasks/tasks.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/dev/tasks/tasks.yml 
b/dev/tasks/tasks.yml
index 3792dc14905..c00d101f43e 100644
--- a/dev/tasks/tasks.yml
+++ b/dev/tasks/tasks.yml
@@ -927,6 +927,11 @@ tasks:
       - flight-sql-{no_rc_snapshot_version}-tests.jar
       - flight-sql-{no_rc_snapshot_version}.jar
       - flight-sql-{no_rc_snapshot_version}.pom
+      - flight-sql-jdbc-driver-{no_rc_snapshot_version}-javadoc.jar
+      - flight-sql-jdbc-driver-{no_rc_snapshot_version}-sources.jar
+      - flight-sql-jdbc-driver-{no_rc_snapshot_version}-tests.jar
+      - flight-sql-jdbc-driver-{no_rc_snapshot_version}.jar
+      - flight-sql-jdbc-driver-{no_rc_snapshot_version}.pom

 ############################## NuGet packages ###############################

From e3be15f3eef5822402f1bc687fe23cfc0df8e130 Mon Sep 17 00:00:00 2001
From: Matt Topol
Date: Fri, 16 Sep 2022 12:04:42 -0400
Subject: [PATCH 087/133] ARROW-17677: [Go] Filter functions for list and extension types (#14141)

Authored-by: Matt Topol
Signed-off-by: Matt Topol
---
 go/arrow/compute/executor.go  | 36 ++++++++++++++++++++++++++++++-----
 go/arrow/compute/selection.go |  4 ++--
 2 files changed, 33 insertions(+), 7 deletions(-)

diff --git a/go/arrow/compute/executor.go b/go/arrow/compute/executor.go
index fce290c092b..c20bd9e4684 100644
--- a/go/arrow/compute/executor.go
+++ b/go/arrow/compute/executor.go
@@ -41,11 +41,23 @@ import (
 // An ExecCtx should be placed into a context.Context by using
 // SetExecCtx and GetExecCtx to pass it along for execution.
 type ExecCtx struct {
-	ChunkSize          int64
+	// ChunkSize is the size used when iterating batches for execution.
+	// ChunkSize elements will be operated on at a time, unless an argument
+	// is a ChunkedArray with a chunk that is smaller.
+	ChunkSize int64
+	// PreallocContiguous determines whether preallocating memory for
+	// execution of compute attempts to preallocate a full contiguous
+	// buffer for all of the chunks beforehand.
 	PreallocContiguous bool
-	Registry           FunctionRegistry
-	ExecChannelSize    int
-	NP                 int
+	// Registry allows specifying the FunctionRegistry to utilize
+	// when searching for kernel implementations.
+	Registry FunctionRegistry
+	// ExecChannelSize is the size of the channel used for passing
+	// exec results to the WrapResults function.
+	ExecChannelSize int
+	// NumParallel determines the number of parallel goroutines
+	// allowed for parallel executions.
+	NumParallel int
 }

 type ctxExecKey struct{}

@@ -67,6 +79,20 @@ var (
 	GetAllocator = exec.GetAllocator
 )

+// DefaultExecCtx returns the default exec context which will be used
+// if there is no ExecCtx set into the context for execution.
+//
+// This can be called to get a copy of the default values which can
+// then be modified to set into a context.
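+//
+// A sketch of typical use (assuming SetExecCtx, documented below, takes
+// the context and an ExecCtx value):
+//
+//	ec := compute.DefaultExecCtx()
+//	ec.NumParallel = 1 // e.g. disable parallel execution
+//	ctx := compute.SetExecCtx(context.Background(), ec)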
+// +// The default exec context uses the following values: +// - ChunkSize = DefaultMaxChunkSize (MaxInt64) +// - PreallocContiguous = true +// - Registry = GetFunctionRegistry() +// - ExecChannelSize = 10 +// - NumParallel = runtime.NumCPU() +func DefaultExecCtx() ExecCtx { return defaultExecCtx } + func init() { defaultExecCtx.ChunkSize = DefaultMaxChunkSize defaultExecCtx.PreallocContiguous = true @@ -74,7 +100,7 @@ func init() { defaultExecCtx.ExecChannelSize = 10 // default level of parallelism // set to 1 to disable parallelization - defaultExecCtx.NP = runtime.NumCPU() + defaultExecCtx.NumParallel = runtime.NumCPU() } // SetExecCtx returns a new child context containing the passed in ExecCtx diff --git a/go/arrow/compute/selection.go b/go/arrow/compute/selection.go index 262e1ddeed4..ac5b4e4c653 100644 --- a/go/arrow/compute/selection.go +++ b/go/arrow/compute/selection.go @@ -140,7 +140,7 @@ func denseUnionImpl(fn exec.ArrayKernelExec) exec.ArrayKernelExec { defer typedValues.Release() eg, cctx := errgroup.WithContext(ctx.Ctx) - eg.SetLimit(GetExecCtx(ctx.Ctx).NP) + eg.SetLimit(GetExecCtx(ctx.Ctx).NumParallel) for i := 0; i < typedValues.NumFields(); i++ { i := i @@ -236,7 +236,7 @@ func structTake(ctx *exec.KernelCtx, batch *exec.ExecSpan, out *exec.ExecResult) // select from children without bounds checking out.Children = make([]exec.ArraySpan, values.NumField()) eg, cctx := errgroup.WithContext(ctx.Ctx) - eg.SetLimit(GetExecCtx(ctx.Ctx).NP) + eg.SetLimit(GetExecCtx(ctx.Ctx).NumParallel) selection := batch.Values[1].Array.MakeArray() defer selection.Release() From 6bc2e010d9fb4e50d8a9490ec5fa092f2f8783b4 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Fri, 16 Sep 2022 13:18:38 -0400 Subject: [PATCH 088/133] MINOR: [R] Fix lint warnings and run styler over everything (#14153) Authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/DESCRIPTION | 2 +- r/R/arrowExports.R | 1 - r/R/dplyr-datetime-helpers.R | 8 +- r/R/dplyr-funcs-doc.R | 28 ++--- r/R/dplyr.R | 3 +- r/data-raw/docgen.R | 12 +- r/man/acero.Rd | 4 +- r/man/show_exec_plan.Rd | 2 +- r/tests/testthat/test-Table.R | 1 - r/tests/testthat/test-compute.R | 2 +- r/tests/testthat/test-dataset-dplyr.R | 40 +++---- r/tests/testthat/test-dataset.R | 4 +- r/tests/testthat/test-dplyr-across.R | 1 - r/tests/testthat/test-dplyr-funcs-datetime.R | 109 +++++++------------ r/tests/testthat/test-dplyr-funcs-math.R | 3 +- r/tests/testthat/test-dplyr-funcs-string.R | 3 +- r/tests/testthat/test-dplyr-funcs-type.R | 2 +- r/tools/winlibs.R | 2 +- 18 files changed, 100 insertions(+), 127 deletions(-) diff --git a/r/DESCRIPTION b/r/DESCRIPTION index 7b60f0c510a..90e84d34bc2 100644 --- a/r/DESCRIPTION +++ b/r/DESCRIPTION @@ -41,7 +41,7 @@ Imports: utils, vctrs Roxygen: list(markdown = TRUE, r6 = FALSE, load = "source") -RoxygenNote: 7.2.0 +RoxygenNote: 7.2.1 Config/testthat/edition: 3 VignetteBuilder: knitr Suggests: diff --git a/r/R/arrowExports.R b/r/R/arrowExports.R index 6e76cd64687..35c73e547c9 100644 --- a/r/R/arrowExports.R +++ b/r/R/arrowExports.R @@ -2043,4 +2043,3 @@ SetIOThreadPoolCapacity <- function(threads) { Array__infer_type <- function(x) { .Call(`_arrow_Array__infer_type`, x) } - diff --git a/r/R/dplyr-datetime-helpers.R b/r/R/dplyr-datetime-helpers.R index 4c9a8d1bf05..ba9bb0d5437 100644 --- a/r/R/dplyr-datetime-helpers.R +++ b/r/R/dplyr-datetime-helpers.R @@ -442,8 +442,10 @@ parse_period_unit <- function(x) { str_unit <- substr(x, capture_start[[2]], capture_end[[2]]) str_multiple <- substr(x, capture_start[[1]], 
capture_end[[1]]) - known_units <- c("nanosecond", "microsecond", "millisecond", "second", - "minute", "hour", "day", "week", "month", "quarter", "year") + known_units <- c( + "nanosecond", "microsecond", "millisecond", "second", + "minute", "hour", "day", "week", "month", "quarter", "year" + ) # match the period unit str_unit_start <- substr(str_unit, 1, 3) @@ -464,7 +466,7 @@ parse_period_unit <- function(x) { if (capture_length[[1]] == 0) { multiple <- 1L - # otherwise parse the multiple + # otherwise parse the multiple } else { multiple <- as.numeric(str_multiple) diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index cac0310f49b..cbfe475232b 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -88,12 +88,12 @@ #' as `arrow_ascii_is_decimal`. #' #' ## arrow -#' +#' #' * [`add_filename()`][arrow::add_filename()] #' * [`cast()`][arrow::cast()] #' #' ## base -#' +#' #' * [`-`][-()] #' * [`!`][!()] #' * [`!=`][!=()] @@ -179,13 +179,15 @@ #' * [`trunc()`][base::trunc()] #' #' ## bit64 -#' +#' #' * [`as.integer64()`][bit64::as.integer64()] #' * [`is.integer64()`][bit64::is.integer64()] #' #' ## dplyr -#' -#' * [`across()`][dplyr::across()]: only supported inside `mutate()`, `summarize()`, and `arrange()`; purrr-style lambda functions and use of `where()` selection helper not yet supported +#' +#' * [`across()`][dplyr::across()]: supported inside `mutate()`, `summarize()`, `group_by()`, and `arrange()`; +#' purrr-style lambda functions +#' and use of `where()` selection helper not yet supported #' * [`between()`][dplyr::between()] #' * [`case_when()`][dplyr::case_when()] #' * [`coalesce()`][dplyr::coalesce()] @@ -195,7 +197,7 @@ #' * [`n_distinct()`][dplyr::n_distinct()] #' #' ## lubridate -#' +#' #' * [`am()`][lubridate::am()] #' * [`as_date()`][lubridate::as_date()] #' * [`as_datetime()`][lubridate::as_datetime()] @@ -270,11 +272,11 @@ #' * [`yq()`][lubridate::yq()] #' #' ## methods -#' +#' #' * [`is()`][methods::is()] #' #' ## rlang -#' +#' #' * [`is_character()`][rlang::is_character()] #' * [`is_double()`][rlang::is_double()] #' * [`is_integer()`][rlang::is_integer()] @@ -282,18 +284,18 @@ #' * [`is_logical()`][rlang::is_logical()] #' #' ## stats -#' +#' #' * [`median()`][stats::median()] #' * [`quantile()`][stats::quantile()] #' * [`sd()`][stats::sd()] #' * [`var()`][stats::var()] #' #' ## stringi -#' +#' #' * [`stri_reverse()`][stringi::stri_reverse()] #' #' ## stringr -#' +#' #' * [`str_c()`][stringr::str_c()]: the `collapse` argument is not yet supported #' * [`str_count()`][stringr::str_count()] #' * [`str_detect()`][stringr::str_detect()] @@ -313,11 +315,11 @@ #' * [`str_trim()`][stringr::str_trim()] #' #' ## tibble -#' +#' #' * [`tibble()`][tibble::tibble()] #' #' ## tidyselect -#' +#' #' * [`all_of()`][tidyselect::all_of()] #' * [`contains()`][tidyselect::contains()] #' * [`ends_with()`][tidyselect::ends_with()] diff --git a/r/R/dplyr.R b/r/R/dplyr.R index 86132d8ae4a..a96678f9757 100644 --- a/r/R/dplyr.R +++ b/r/R/dplyr.R @@ -189,7 +189,6 @@ dim.arrow_dplyr_query <- function(x) { #' @export unique.arrow_dplyr_query <- function(x, incomparables = FALSE, fromLast = FALSE, ...) { - if (isTRUE(incomparables)) { arrow_not_supported("`unique()` with `incomparables = TRUE`") } @@ -262,7 +261,7 @@ tail.arrow_dplyr_query <- function(x, n = 6L, ...) 
{ #' mtcars %>% #' arrow_table() %>% #' filter(mpg > 20) %>% -#' mutate(x = gear/carb) %>% +#' mutate(x = gear / carb) %>% #' show_exec_plan() show_exec_plan <- function(x) { adq <- as_adq(x) diff --git a/r/data-raw/docgen.R b/r/data-raw/docgen.R index ef39bec272f..9f839cfd123 100644 --- a/r/data-raw/docgen.R +++ b/r/data-raw/docgen.R @@ -116,8 +116,7 @@ render_pkg <- function(df, pkg) { pull() # Add header bullets <- c( - paste("##", pkg), - "", + paste0("## ", pkg, "\n#'"), bullets ) paste("#'", bullets, collapse = "\n") @@ -129,10 +128,11 @@ docs <- arrow:::.cache$docs # across() is handled by manipulating the quosures, not by nse_funcs docs[["dplyr::across"]] <- c( - # TODO(ARROW-17387, ARROW-17389, ARROW-17390) - "only supported inside `mutate()`, `summarize()`, and `arrange()`;", - # TODO(ARROW-17366) + # TODO(ARROW-17387, ARROW-17389, ARROW-17390): other verbs + "supported inside `mutate()`, `summarize()`, `group_by()`, and `arrange()`;", + # TODO(ARROW-17366): do ~ "purrr-style lambda functions", + # TODO(ARROW-17384): implement where "and use of `where()` selection helper not yet supported" ) # desc() is a special helper handled inside of arrange() @@ -154,7 +154,7 @@ fun_df <- tibble::tibble( # We will list operators under "base" (everything else must be pkg::fun) pkg = if_else(has_pkg, pkg, "base"), # Flatten notes to a single string - notes = map_chr(notes, ~ paste(., collapse = " ")) + notes = map_chr(notes, ~ paste(., collapse = "\n#' ")) ) %>% arrange(pkg, fun) diff --git a/r/man/acero.Rd b/r/man/acero.Rd index 5b5920f386e..5d4859edcb5 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -175,7 +175,9 @@ as \code{arrow_ascii_is_decimal}. \subsection{dplyr}{ \itemize{ -\item \code{\link[dplyr:across]{across()}}: only supported inside \code{mutate()}, \code{summarize()}, and \code{arrange()}; purrr-style lambda functions and use of \code{where()} selection helper not yet supported +\item \code{\link[dplyr:across]{across()}}: supported inside \code{mutate()}, \code{summarize()}, \code{group_by()}, and \code{arrange()}; +purrr-style lambda functions +and use of \code{where()} selection helper not yet supported \item \code{\link[dplyr:between]{between()}} \item \code{\link[dplyr:case_when]{case_when()}} \item \code{\link[dplyr:coalesce]{coalesce()}} diff --git a/r/man/show_exec_plan.Rd b/r/man/show_exec_plan.Rd index c020838b2ed..d6eb2298f22 100644 --- a/r/man/show_exec_plan.Rd +++ b/r/man/show_exec_plan.Rd @@ -25,7 +25,7 @@ library(dplyr) mtcars \%>\% arrow_table() \%>\% filter(mpg > 20) \%>\% - mutate(x = gear/carb) \%>\% + mutate(x = gear / carb) \%>\% show_exec_plan() \dontshow{\}) # examplesIf} } diff --git a/r/tests/testthat/test-Table.R b/r/tests/testthat/test-Table.R index bafd183108a..d2818943823 100644 --- a/r/tests/testthat/test-Table.R +++ b/r/tests/testthat/test-Table.R @@ -692,5 +692,4 @@ test_that("num_rows method not susceptible to integer overflow", { expect_type(big_table$num_rows, "double") expect_identical(big_string_array$data()$buffers[[3]]$size, 2148007936) - }) diff --git a/r/tests/testthat/test-compute.R b/r/tests/testthat/test-compute.R index 11c37519ae5..882eea03e52 100644 --- a/r/tests/testthat/test-compute.R +++ b/r/tests/testthat/test-compute.R @@ -273,7 +273,7 @@ test_that("nested exec plans can contain user-defined functions", { on.exit(unregister_binding("times_32", update_cache = TRUE)) stream_plan_with_udf <- function() { - record_batch(a = 1:1000) %>% + record_batch(a = 1:1000) %>% dplyr::mutate(b = times_32(a)) %>% as_record_batch_reader() %>% 
as_arrow_table() diff --git a/r/tests/testthat/test-dataset-dplyr.R b/r/tests/testthat/test-dataset-dplyr.R index b09b549d590..f3ec858c6ce 100644 --- a/r/tests/testthat/test-dataset-dplyr.R +++ b/r/tests/testthat/test-dataset-dplyr.R @@ -356,9 +356,9 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", { ds %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "ProjectNode.*", # output columns - "SourceNode" # entry point + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ProjectNode.*", # output columns + "SourceNode" # entry point ) ) @@ -369,11 +369,11 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", { filter(integer > 6L & part == 1) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "ProjectNode.*", # output columns - "FilterNode.*", # filter node - "int > 6.*cast.*", # filtering expressions + auto-casting of part - "SourceNode" # entry point + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ProjectNode.*", # output columns + "FilterNode.*", # filter node + "int > 6.*cast.*", # filtering expressions + auto-casting of part + "SourceNode" # entry point ) ) @@ -384,13 +384,13 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", { summarise(avg = mean(int)) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan - "ProjectNode.*", # output columns - "GroupByNode.*", # group by node - "keys=.*part.*", # key for aggregations - "aggregates=.*hash_mean.*", # aggregations - "ProjectNode.*", # input columns - "SourceNode" # entry point + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ProjectNode.*", # output columns + "GroupByNode.*", # group by node + "keys=.*part.*", # key for aggregations + "aggregates=.*hash_mean.*", # aggregations + "ProjectNode.*", # input columns + "SourceNode" # entry point ) ) @@ -401,12 +401,12 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", { arrange(chr) %>% show_exec_plan(), regexp = paste0( - "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan + "ExecPlan with .* nodes:.*", # boiler plate for ExecPlan "OrderBySinkNode.*chr.*ASC.*", # arrange goes via the OrderBy sink node - "ProjectNode.*", # output columns - "FilterNode.*", # filter node - "filter=lgl.*", # filtering expression - "SourceNode" # entry point + "ProjectNode.*", # output columns + "FilterNode.*", # filter node + "filter=lgl.*", # filtering expression + "SourceNode" # entry point ) ) diff --git a/r/tests/testthat/test-dataset.R b/r/tests/testthat/test-dataset.R index d9512ef94f3..bf38aaec3f2 100644 --- a/r/tests/testthat/test-dataset.R +++ b/r/tests/testthat/test-dataset.R @@ -692,7 +692,7 @@ test_that("map_batches", { test_that("map_batches with explicit schema", { fun_with_dots <- function(batch, first_col, first_col_val) { record_batch( - !! first_col := first_col_val, + !!first_col := first_col_val, b = batch$a$cast(float64()) ) } @@ -736,7 +736,7 @@ test_that("map_batches with explicit schema", { test_that("map_batches without explicit schema", { fun_with_dots <- function(batch, first_col, first_col_val) { record_batch( - !! 
first_col := first_col_val, + !!first_col := first_col_val, b = batch$a$cast(float64()) ) } diff --git a/r/tests/testthat/test-dplyr-across.R b/r/tests/testthat/test-dplyr-across.R index 8945c2a5f3b..cc419dbbe54 100644 --- a/r/tests/testthat/test-dplyr-across.R +++ b/r/tests/testthat/test-dplyr-across.R @@ -222,5 +222,4 @@ test_that("expand_across correctly expands quosures", { regexp = "`.names` specification must produce (number of columns * number of functions) names.", fixed = TRUE ) - }) diff --git a/r/tests/testthat/test-dplyr-funcs-datetime.R b/r/tests/testthat/test-dplyr-funcs-datetime.R index b9ba877d6f6..b8f8da11ccc 100644 --- a/r/tests/testthat/test-dplyr-funcs-datetime.R +++ b/r/tests/testthat/test-dplyr-funcs-datetime.R @@ -186,7 +186,6 @@ test_that("strptime", { collect(), tibble::tibble(string_1 = c("2022-02-11-12:23:45", NA)) ) - }) test_that("strptime works for individual formats", { @@ -200,21 +199,21 @@ test_that("strptime works for individual formats", { expect_equal( strptime_test_df %>% arrow_table() %>% - mutate( - parsed_H = strptime(string_H, format = "%Y-%m-%d-%H"), - parsed_I = strptime(string_I, format = "%Y-%m-%d-%I"), - parsed_j = strptime(string_j, format = "%Y-%m-%d-%j"), - parsed_M = strptime(string_M, format = "%Y-%m-%d-%M"), - parsed_S = strptime(string_S, format = "%Y-%m-%d-%S"), - parsed_U = strptime(string_U, format = "%Y-%m-%d-%U"), - parsed_w = strptime(string_w, format = "%Y-%m-%d-%w"), - parsed_W = strptime(string_W, format = "%Y-%m-%d-%W"), - parsed_y = strptime(string_y, format = "%y-%m-%d"), - parsed_Y = strptime(string_Y, format = "%Y-%m-%d"), - parsed_R = strptime(string_R, format = "%Y-%m-%d-%R"), - parsed_T = strptime(string_T, format = "%Y-%m-%d-%T") - ) %>% - collect(), + mutate( + parsed_H = strptime(string_H, format = "%Y-%m-%d-%H"), + parsed_I = strptime(string_I, format = "%Y-%m-%d-%I"), + parsed_j = strptime(string_j, format = "%Y-%m-%d-%j"), + parsed_M = strptime(string_M, format = "%Y-%m-%d-%M"), + parsed_S = strptime(string_S, format = "%Y-%m-%d-%S"), + parsed_U = strptime(string_U, format = "%Y-%m-%d-%U"), + parsed_w = strptime(string_w, format = "%Y-%m-%d-%w"), + parsed_W = strptime(string_W, format = "%Y-%m-%d-%W"), + parsed_y = strptime(string_y, format = "%y-%m-%d"), + parsed_Y = strptime(string_Y, format = "%Y-%m-%d"), + parsed_R = strptime(string_R, format = "%Y-%m-%d-%R"), + parsed_T = strptime(string_T, format = "%Y-%m-%d-%T") + ) %>% + collect(), strptime_test_df %>% mutate( parsed_H = as.POSIXct(strptime(string_H, format = "%Y-%m-%d-%H")), @@ -238,15 +237,15 @@ test_that("strptime works for individual formats", { expect_equal( strptime_test_df %>% arrow_table() %>% - mutate( - parsed_a = strptime(string_a, format = "%Y-%m-%d-%a"), - parsed_A = strptime(string_A, format = "%Y-%m-%d-%A"), - parsed_b = strptime(string_b, format = "%Y-%m-%d-%b"), - parsed_B = strptime(string_B, format = "%Y-%m-%d-%B"), - parsed_p = strptime(string_p, format = "%Y-%m-%d-%p"), - parsed_r = strptime(string_r, format = "%Y-%m-%d-%r") - ) %>% - collect(), + mutate( + parsed_a = strptime(string_a, format = "%Y-%m-%d-%a"), + parsed_A = strptime(string_A, format = "%Y-%m-%d-%A"), + parsed_b = strptime(string_b, format = "%Y-%m-%d-%b"), + parsed_B = strptime(string_B, format = "%Y-%m-%d-%B"), + parsed_p = strptime(string_p, format = "%Y-%m-%d-%p"), + parsed_r = strptime(string_r, format = "%Y-%m-%d-%r") + ) %>% + collect(), strptime_test_df %>% mutate( parsed_a = as.POSIXct(strptime(string_a, format = "%Y-%m-%d-%a")), @@ -258,7 +257,6 @@ 
test_that("strptime works for individual formats", { ) %>% collect() ) - }) test_that("timestamp round trip correctly via strftime and strptime", { @@ -297,8 +295,8 @@ test_that("timestamp round trip correctly via strftime and strptime", { expect_equal( test_df %>% arrow_table() %>% - mutate(!!fmt := strptime(x, format = fmt)) %>% - collect(), + mutate(!!fmt := strptime(x, format = fmt)) %>% + collect(), test_df %>% mutate(!!fmt := as.POSIXct(strptime(x, format = fmt))) %>% collect() @@ -312,14 +310,13 @@ test_that("timestamp round trip correctly via strftime and strptime", { expect_equal( test_df %>% arrow_table() %>% - mutate(!!fmt := strptime(x, format = fmt2)) %>% - collect(), + mutate(!!fmt := strptime(x, format = fmt2)) %>% + collect(), test_df %>% mutate(!!fmt := as.POSIXct(strptime(x, format = fmt2))) %>% collect() ) } - }) test_that("strptime returns NA when format doesn't match the data", { @@ -994,8 +991,8 @@ test_that("extract qday from date", { compare_dplyr_binding( .input %>% - mutate(y = qday(as.Date("2022-06-29"))) %>% - collect(), + mutate(y = qday(as.Date("2022-06-29"))) %>% + collect(), test_df ) }) @@ -2299,7 +2296,6 @@ test_that("parse_date_time's other formats", { tibble::tibble(string_1 = c("2022-Feb-11-12:23:45", NA)) ) } - }) test_that("lubridate's fast_strptime", { @@ -2835,7 +2831,6 @@ test_that("parse_date_time with `exact = TRUE`, and with regular R objects", { }) test_that("build_formats() and build_format_from_order()", { - ymd_formats <- c( "%y-%m-%d", "%Y-%m-%d", "%y-%B-%d", "%Y-%B-%d", "%y-%b-%d", "%Y-%b-%d", "%y%m%d", "%Y%m%d", "%y%B%d", "%Y%B%d", "%y%b%d", "%Y%b%d" @@ -3027,7 +3022,7 @@ boundary_times <- tibble::tibble( "2022-03-10 00:00:01", # boundary for second, millisecond "2022-03-10 00:01:00", # boundary for second, millisecond, minute "2022-03-10 01:00:00", # boundary for second, millisecond, minute, hour - "2022-01-01 00:00:00" # boundary for year + "2022-01-01 00:00:00" # boundary for year ), tz = "UTC", format = "%F %T")), date = as.Date(datetime) ) @@ -3053,14 +3048,13 @@ datestrings <- c( ) tz_times <- tibble::tibble( utc_time = as.POSIXct(datestrings, tz = "UTC"), - syd_time = as.POSIXct(datestrings, tz = "Australia/Sydney"), # UTC +10 (UTC +11 with DST) + syd_time = as.POSIXct(datestrings, tz = "Australia/Sydney"), # UTC +10 (UTC +11 with DST) adl_time = as.POSIXct(datestrings, tz = "Australia/Adelaide"), # UTC +9:30 (UTC +10:30 with DST) - mar_time = as.POSIXct(datestrings, tz = "Pacific/Marquesas"), # UTC -9:30 (no DST) - kat_time = as.POSIXct(datestrings, tz = "Asia/Kathmandu") # UTC +5:45 (no DST) + mar_time = as.POSIXct(datestrings, tz = "Pacific/Marquesas"), # UTC -9:30 (no DST) + kat_time = as.POSIXct(datestrings, tz = "Asia/Kathmandu") # UTC +5:45 (no DST) ) test_that("timestamp round/floor/ceiling works for a minimal test", { - compare_dplyr_binding( .input %>% mutate( @@ -3103,7 +3097,6 @@ test_that("timestamp round/floor/ceiling accepts period unit abbreviation", { }) test_that("temporal round/floor/ceiling accepts periods with multiple units", { - check_multiple_unit_period <- function(unit, multiplier) { unit_string <- paste(multiplier, unit) compare_dplyr_binding( @@ -3152,7 +3145,6 @@ check_date_rounding <- function(data, unit, lubridate_unit = unit, ...) { } check_timestamp_rounding <- function(data, unit, lubridate_unit = unit, ...) { - expect_equal( data %>% arrow_table() %>% @@ -3173,16 +3165,13 @@ check_timestamp_rounding <- function(data, unit, lubridate_unit = unit, ...) 
{ } test_that("date round/floor/ceil works for units of 1 day or less", { - test_df %>% check_date_rounding("1 millisecond", lubridate_unit = ".001 second") test_df %>% check_date_rounding("1 day") test_df %>% check_date_rounding("1 second") test_df %>% check_date_rounding("1 hour") - }) test_that("timestamp round/floor/ceil works for units of 1 day or less", { - test_df %>% check_timestamp_rounding("second") test_df %>% check_timestamp_rounding("minute") test_df %>% check_timestamp_rounding("hour") @@ -3195,15 +3184,12 @@ test_that("timestamp round/floor/ceil works for units of 1 day or less", { test_df %>% check_timestamp_rounding("1 millisecond", lubridate_unit = ".001 second") test_df %>% check_timestamp_rounding("1 microsecond", lubridate_unit = ".000001 second") test_df %>% check_timestamp_rounding("1 nanosecond", lubridate_unit = ".000000001 second") - }) test_that("timestamp round/floor/ceil works for units: month/quarter/year", { - year_of_dates %>% check_timestamp_rounding("month", ignore_attr = TRUE) year_of_dates %>% check_timestamp_rounding("quarter", ignore_attr = TRUE) year_of_dates %>% check_timestamp_rounding("year", ignore_attr = TRUE) - }) # check helper invoked when we need to avoid the lubridate rounding bug @@ -3249,7 +3235,6 @@ test_that("date round/floor/ceil works for units: month/quarter/year", { check_date_rounding_1051_bypass(year_of_dates, "month", ignore_attr = TRUE) check_date_rounding_1051_bypass(year_of_dates, "quarter", ignore_attr = TRUE) check_date_rounding_1051_bypass(year_of_dates, "year", ignore_attr = TRUE) - }) check_date_week_rounding <- function(data, week_start, ignore_attr = TRUE, ...) { @@ -3289,20 +3274,16 @@ check_timestamp_week_rounding <- function(data, week_start, ignore_attr = TRUE, } test_that("timestamp round/floor/ceil works for week units (standard week_start)", { - fortnight %>% check_timestamp_week_rounding(week_start = 1) # Monday fortnight %>% check_timestamp_week_rounding(week_start = 7) # Sunday - }) test_that("timestamp round/floor/ceil works for week units (non-standard week_start)", { - fortnight %>% check_timestamp_week_rounding(week_start = 2) # Tuesday fortnight %>% check_timestamp_week_rounding(week_start = 3) # Wednesday fortnight %>% check_timestamp_week_rounding(week_start = 4) # Thursday fortnight %>% check_timestamp_week_rounding(week_start = 5) # Friday fortnight %>% check_timestamp_week_rounding(week_start = 6) # Saturday - }) check_date_week_rounding <- function(data, week_start, ignore_attr = TRUE, ...) { @@ -3337,20 +3318,16 @@ check_date_week_rounding <- function(data, week_start, ignore_attr = TRUE, ...) } test_that("date round/floor/ceil works for week units (standard week_start)", { - check_date_week_rounding(fortnight, week_start = 1) # Monday check_date_week_rounding(fortnight, week_start = 7) # Sunday - }) test_that("date round/floor/ceil works for week units (non-standard week_start)", { - check_date_week_rounding(fortnight, week_start = 2) # Tuesday check_date_week_rounding(fortnight, week_start = 3) # Wednesday check_date_week_rounding(fortnight, week_start = 4) # Thursday check_date_week_rounding(fortnight, week_start = 5) # Friday check_date_week_rounding(fortnight, week_start = 6) # Saturday - }) # Test helper used to check that the change_on_boundary argument to @@ -3389,8 +3366,6 @@ check_boundary_with_unit <- function(unit, ...) { ), ... 
) - - } test_that("ceiling_date() applies change_on_boundary correctly", { @@ -3405,7 +3380,6 @@ test_that("ceiling_date() applies change_on_boundary correctly", { # exceeded. Checks that arrow mimics this behaviour and throws an identically # worded error message test_that("temporal round/floor/ceil period unit maxima are enforced", { - expect_error( call_binding("round_date", Expression$scalar(Sys.time()), "61 seconds"), "Rounding with second > 60 is not supported" @@ -3422,7 +3396,6 @@ test_that("temporal round/floor/ceil period unit maxima are enforced", { call_binding("round_date", Expression$scalar(Sys.Date()), "25 hours"), "Rounding with hour > 24 is not supported" ) - }) # one method to test that temporal rounding takes place in local time is to @@ -3464,7 +3437,6 @@ check_timezone_rounding_vs_lubridate <- function(data, unit) { collect(), data ) - } # another method to check that temporal rounding takes place in local @@ -3474,7 +3446,6 @@ check_timezone_rounding_vs_lubridate <- function(data, unit) { # for UTC test. this test isn't useful for subsecond resolution but avoids # dependency on lubridate check_timezone_rounding_for_consistency <- function(data, unit) { - shifted_times <- data %>% arrow_table() %>% mutate( @@ -3498,11 +3469,11 @@ check_timezone_rounding_for_consistency <- function(data, unit) { compare_local_times <- function(time1, time2) { all(year(time1) == year(time1) & - month(time1) == month(time2) & - day(time1) == day(time2) & - hour(time1) == hour(time2) & - minute(time1) == minute(time2) & - second(time1) == second(time1)) + month(time1) == month(time2) & + day(time1) == day(time2) & + hour(time1) == hour(time2) & + minute(time1) == minute(time2) & + second(time1) == second(time1)) } base <- shifted_times$utc_rounded @@ -3525,7 +3496,6 @@ check_timezone_rounding_for_consistency <- function(data, unit) { } test_that("timestamp rounding takes place in local time", { - tz_times %>% check_timezone_rounding_vs_lubridate(".001 second") tz_times %>% check_timezone_rounding_vs_lubridate("second") tz_times %>% check_timezone_rounding_vs_lubridate("minute") @@ -3556,5 +3526,4 @@ test_that("timestamp rounding takes place in local time", { tz_times %>% check_timezone_rounding_for_consistency("13 hours") tz_times %>% check_timezone_rounding_for_consistency("13 months") tz_times %>% check_timezone_rounding_for_consistency("13 years") - }) diff --git a/r/tests/testthat/test-dplyr-funcs-math.R b/r/tests/testthat/test-dplyr-funcs-math.R index b9a6a3707d4..66b3a510f9c 100644 --- a/r/tests/testthat/test-dplyr-funcs-math.R +++ b/r/tests/testthat/test-dplyr-funcs-math.R @@ -25,7 +25,8 @@ test_that("abs()", { .input %>% transmute( abs = abs(x), - abs2 = base::abs(x)) %>% + abs2 = base::abs(x) + ) %>% collect(), df ) diff --git a/r/tests/testthat/test-dplyr-funcs-string.R b/r/tests/testthat/test-dplyr-funcs-string.R index 229347372ae..57f8532ea83 100644 --- a/r/tests/testthat/test-dplyr-funcs-string.R +++ b/r/tests/testthat/test-dplyr-funcs-string.R @@ -61,7 +61,8 @@ test_that("paste, paste0, and str_c", { .input %>% transmute( a = paste0(v, w), - a2 = base::paste0(v, w)) %>% + a2 = base::paste0(v, w) + ) %>% collect(), df ) diff --git a/r/tests/testthat/test-dplyr-funcs-type.R b/r/tests/testthat/test-dplyr-funcs-type.R index 285d86520f3..ccf16dd4db4 100644 --- a/r/tests/testthat/test-dplyr-funcs-type.R +++ b/r/tests/testthat/test-dplyr-funcs-type.R @@ -289,7 +289,7 @@ test_that("type checks with is() giving Arrow types", { str_is_dec256 = is(str, decimal256(3, 2)), str_is_i64 = 
is(str, float64()), str_is_str = is(str, string()) - ) %>% + ) %>% collect() %>% t() %>% as.vector(), diff --git a/r/tools/winlibs.R b/r/tools/winlibs.R index 5aeea2e417e..165c98da5ea 100644 --- a/r/tools/winlibs.R +++ b/r/tools/winlibs.R @@ -48,7 +48,7 @@ if (!file.exists(sprintf("windows/arrow-%s/include/arrow/api.h", VERSION))) { rwinlib <- "https://github.com/rwinlib/arrow/archive/v%s.zip" dev_version <- package_version(VERSION)[1, 4] - + # Small dev versions are added for R-only changes during CRAN submission. if (is.na(dev_version) || dev_version < 100) { VERSION <- package_version(VERSION)[1, 1:3] From 9131724e451e00feaf6acb944d7ffbb6f07c1add Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Sat, 17 Sep 2022 04:14:20 +0800 Subject: [PATCH 089/133] ARROW-17728: [C++][Gandiva] Accept LLVM 15.0 (#14125) Lead-authored-by: Jin Shang Co-authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- .github/workflows/ruby.yml | 3 +- ci/scripts/cpp_test.sh | 30 +++++++++---------- cpp/CMakeLists.txt | 1 + .../{Findzstd.cmake => FindzstdAlt.cmake} | 27 ++++++++++++++--- cpp/cmake_modules/ThirdpartyToolchain.cmake | 18 ++++++----- 5 files changed, 50 insertions(+), 29 deletions(-) rename cpp/cmake_modules/{Findzstd.cmake => FindzstdAlt.cmake} (82%) diff --git a/.github/workflows/ruby.yml b/.github/workflows/ruby.yml index 3f877b4aa30..4dd61befab6 100644 --- a/.github/workflows/ruby.yml +++ b/.github/workflows/ruby.yml @@ -195,7 +195,8 @@ jobs: ARROW_BUILD_TYPE: release ARROW_FLIGHT: ON ARROW_FLIGHT_SQL: ON - ARROW_GANDIVA: ON + # ARROW-17728: SEGV on MinGW + ARROW_GANDIVA: OFF ARROW_GCS: ON ARROW_HDFS: OFF ARROW_HOME: /ucrt${{ matrix.mingw-n-bits }} diff --git a/ci/scripts/cpp_test.sh b/ci/scripts/cpp_test.sh index 2bd7db8b2c4..06b7d0fe413 100755 --- a/ci/scripts/cpp_test.sh +++ b/ci/scripts/cpp_test.sh @@ -55,22 +55,20 @@ case "$(uname)" in exclude_tests="gandiva-internals-test" exclude_tests="${exclude_tests}|gandiva-projector-test" exclude_tests="${exclude_tests}|gandiva-utf8-test" - if [ "${MSYSTEM}" = "MINGW32" ]; then - exclude_tests="${exclude_tests}|gandiva-binary-test" - exclude_tests="${exclude_tests}|gandiva-boolean-expr-test" - exclude_tests="${exclude_tests}|gandiva-date-time-test" - exclude_tests="${exclude_tests}|gandiva-decimal-single-test" - exclude_tests="${exclude_tests}|gandiva-decimal-test" - exclude_tests="${exclude_tests}|gandiva-filter-project-test" - exclude_tests="${exclude_tests}|gandiva-filter-test" - exclude_tests="${exclude_tests}|gandiva-hash-test" - exclude_tests="${exclude_tests}|gandiva-if-expr-test" - exclude_tests="${exclude_tests}|gandiva-in-expr-test" - exclude_tests="${exclude_tests}|gandiva-literal-test" - exclude_tests="${exclude_tests}|gandiva-null-validity-test" - exclude_tests="${exclude_tests}|gandiva-precompiled-test" - exclude_tests="${exclude_tests}|gandiva-projector-test" - fi + exclude_tests="${exclude_tests}|gandiva-binary-test" + exclude_tests="${exclude_tests}|gandiva-boolean-expr-test" + exclude_tests="${exclude_tests}|gandiva-date-time-test" + exclude_tests="${exclude_tests}|gandiva-decimal-single-test" + exclude_tests="${exclude_tests}|gandiva-decimal-test" + exclude_tests="${exclude_tests}|gandiva-filter-project-test" + exclude_tests="${exclude_tests}|gandiva-filter-test" + exclude_tests="${exclude_tests}|gandiva-hash-test" + exclude_tests="${exclude_tests}|gandiva-if-expr-test" + exclude_tests="${exclude_tests}|gandiva-in-expr-test" + exclude_tests="${exclude_tests}|gandiva-literal-test" + 
exclude_tests="${exclude_tests}|gandiva-null-validity-test" + exclude_tests="${exclude_tests}|gandiva-precompiled-test" + exclude_tests="${exclude_tests}|gandiva-projector-test" ctest_options+=(--exclude-regex "${exclude_tests}") ;; *) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 14584e17fb8..d67142e0569 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -128,6 +128,7 @@ set(ARROW_DOC_DIR "share/doc/${PROJECT_NAME}") set(BUILD_SUPPORT_DIR "${CMAKE_SOURCE_DIR}/build-support") set(ARROW_LLVM_VERSIONS + "15.0" "14.0" "13.0" "12.0" diff --git a/cpp/cmake_modules/Findzstd.cmake b/cpp/cmake_modules/FindzstdAlt.cmake similarity index 82% rename from cpp/cmake_modules/Findzstd.cmake rename to cpp/cmake_modules/FindzstdAlt.cmake index d907467f947..476006a768b 100644 --- a/cpp/cmake_modules/Findzstd.cmake +++ b/cpp/cmake_modules/FindzstdAlt.cmake @@ -15,7 +15,20 @@ # specific language governing permissions and limitations # under the License. +if(zstdAlt_FOUND) + return() +endif() + +set(find_package_args) +if(zstdAlt_FIND_VERSION) + list(APPEND find_package_args ${zstdAlt_FIND_VERSION}) +endif() +if(zstdAlt_FIND_QUIETLY) + list(APPEND find_package_args QUIET) +endif() +find_package(zstd ${find_package_args}) if(zstd_FOUND) + set(zstdAlt_FOUND TRUE) return() endif() @@ -83,11 +96,17 @@ else() endif() endif() -find_package_handle_standard_args(zstd REQUIRED_VARS ZSTD_LIB ZSTD_INCLUDE_DIR) +find_package_handle_standard_args(zstdAlt REQUIRED_VARS ZSTD_LIB ZSTD_INCLUDE_DIR) -if(zstd_FOUND) - add_library(zstd::libzstd UNKNOWN IMPORTED) - set_target_properties(zstd::libzstd +if(zstdAlt_FOUND) + if(ARROW_ZSTD_USE_SHARED) + set(zstd_TARGET zstd::libzstd_shared) + add_library(${zstd_TARGET} SHARED IMPORTED) + else() + set(zstd_TARGET zstd::libzstd_static) + add_library(${zstd_TARGET} STATIC IMPORTED) + endif() + set_target_properties(${zstd_TARGET} PROPERTIES IMPORTED_LOCATION "${ZSTD_LIB}" INTERFACE_INCLUDE_DIRECTORIES "${ZSTD_INCLUDE_DIR}") endif() diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 52847a99f9b..a45ae61015e 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -2397,30 +2397,32 @@ macro(build_zstd) file(MAKE_DIRECTORY "${ZSTD_PREFIX}/include") - add_library(zstd::libzstd STATIC IMPORTED) - set_target_properties(zstd::libzstd + add_library(zstd::libzstd_static STATIC IMPORTED) + set_target_properties(zstd::libzstd_static PROPERTIES IMPORTED_LOCATION "${ZSTD_STATIC_LIB}" INTERFACE_INCLUDE_DIRECTORIES "${ZSTD_PREFIX}/include") add_dependencies(toolchain zstd_ep) - add_dependencies(zstd::libzstd zstd_ep) + add_dependencies(zstd::libzstd_static zstd_ep) - list(APPEND ARROW_BUNDLED_STATIC_LIBS zstd::libzstd) + list(APPEND ARROW_BUNDLED_STATIC_LIBS zstd::libzstd_static) + + set(ZSTD_VENDORED TRUE) endmacro() if(ARROW_WITH_ZSTD) # ARROW-13384: ZSTD_minCLevel was added in v1.4.0, required by ARROW-13091 resolve_dependency(zstd + HAVE_ALT + TRUE PC_PACKAGE_NAMES libzstd REQUIRED_VERSION 1.4.0) - if(TARGET zstd::libzstd) - set(ARROW_ZSTD_LIBZSTD zstd::libzstd) + if(ZSTD_VENDORED) + set(ARROW_ZSTD_LIBZSTD zstd::libzstd_static) else() - # "SYSTEM" source will prioritize cmake config, which exports - # zstd::libzstd_{static,shared} if(ARROW_ZSTD_USE_SHARED) if(TARGET zstd::libzstd_shared) set(ARROW_ZSTD_LIBZSTD zstd::libzstd_shared) From decddbbf1afbf620710f95e2c82b4b6f59a752da Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Fri, 16 Sep 2022 22:20:21 +0200 Subject: 
[PATCH 090/133] ARROW-17695: [C++] Remove Variant class (#14136) Use std::variant instead. Authored-by: Antoine Pitrou Signed-off-by: Sutou Kouhei --- c_glib/gandiva-glib/node.cpp | 2 +- cpp/gdb_arrow.py | 32 +- cpp/src/arrow/compute/exec.h | 6 +- cpp/src/arrow/compute/exec/exec_plan.cc | 6 +- cpp/src/arrow/compute/exec/exec_plan.h | 2 +- cpp/src/arrow/compute/exec/expression.cc | 6 +- cpp/src/arrow/compute/exec/expression.h | 4 +- cpp/src/arrow/compute/exec/test_util.cc | 2 +- cpp/src/arrow/dataset/discovery.h | 2 +- cpp/src/arrow/dataset/file_base.cc | 8 +- cpp/src/arrow/datum.cc | 35 +- cpp/src/arrow/datum.h | 18 +- .../engine/substrait/relation_internal.cc | 6 +- cpp/src/arrow/engine/substrait/serde_test.cc | 49 +- cpp/src/arrow/filesystem/mockfs.cc | 12 +- cpp/src/arrow/flight/client.h | 4 +- cpp/src/arrow/flight/sql/server.cc | 2 +- cpp/src/arrow/flight/sql/server_test.cc | 2 +- cpp/src/arrow/flight/sql/types.h | 6 +- .../flight/transport/grpc/grpc_client.cc | 8 +- cpp/src/arrow/type.cc | 27 +- cpp/src/arrow/type.h | 18 +- cpp/src/arrow/util/CMakeLists.txt | 4 +- cpp/src/arrow/util/variant.h | 443 ------------------ cpp/src/arrow/util/variant_benchmark.cc | 248 ---------- cpp/src/arrow/util/variant_test.cc | 345 -------------- cpp/src/gandiva/interval_holder.h | 2 +- cpp/src/gandiva/literal_holder.cc | 2 +- cpp/src/gandiva/literal_holder.h | 7 +- cpp/src/gandiva/llvm_generator.cc | 26 +- cpp/src/gandiva/node.h | 4 +- cpp/src/gandiva/random_generator_holder.cc | 2 +- cpp/src/gandiva/regex_functions_holder.cc | 14 +- cpp/src/gandiva/to_date_holder.cc | 4 +- cpp/src/parquet/arrow/path_internal.cc | 66 +-- cpp/src/parquet/encryption/key_metadata.h | 9 +- python/pyarrow/includes/libarrow_flight.pxd | 4 +- python/pyarrow/src/gdb.cc | 8 - python/pyarrow/tests/test_gdb.py | 13 - 39 files changed, 191 insertions(+), 1267 deletions(-) delete mode 100644 cpp/src/arrow/util/variant.h delete mode 100644 cpp/src/arrow/util/variant_benchmark.cc delete mode 100644 cpp/src/arrow/util/variant_test.cc diff --git a/c_glib/gandiva-glib/node.cpp b/c_glib/gandiva-glib/node.cpp index d42d4801b7e..1ced7754a70 100644 --- a/c_glib/gandiva-glib/node.cpp +++ b/c_glib/gandiva-glib/node.cpp @@ -29,7 +29,7 @@ ggandiva_literal_node_get(GGandivaLiteralNode *node) { auto gandiva_literal_node = std::static_pointer_cast(ggandiva_node_get_raw(GGANDIVA_NODE(node))); - return arrow::util::get(gandiva_literal_node->holder()); + return std::get(gandiva_literal_node->holder()); } G_BEGIN_DECLS diff --git a/cpp/gdb_arrow.py b/cpp/gdb_arrow.py index 5421b4ffb15..5d7e1719afe 100644 --- a/cpp/gdb_arrow.py +++ b/cpp/gdb_arrow.py @@ -426,12 +426,17 @@ def value(self): class Variant: """ - A arrow::util::Variant<...>. + A `std::variant<...>`. """ def __init__(self, val): self.val = val - self.index = int(self.val['index_']) + try: + # libstdc++ internals + self.index = val['_M_index'] + except gdb.error: + # fallback for other C++ standard libraries + self.index = gdb.parse_and_eval(f"{for_evaluation(val)}.index()") try: self.value_type = self.val.type.template_argument(self.index) except RuntimeError: @@ -2175,28 +2180,6 @@ def to_string(self): return f"arrow::util::string_view of size {size}, {data}" -class VariantPrinter: - """ - Pretty-printer for arrow::util::Variant. 
- """ - - def __init__(self, name, val): - self.val = val - self.variant = Variant(val) - - def to_string(self): - if self.variant.value_type is None: - return "arrow::util::Variant (uninitialized or corrupt)" - type_desc = (f"arrow::util::Variant of index {self.variant.index} " - f"(actual type {self.variant.value_type})") - - value = self.variant.value - if value is None: - return (f"{type_desc}, unavailable value") - else: - return (f"{type_desc}, value {value}") - - class FieldPrinter: """ Pretty-printer for arrow::Field. @@ -2415,7 +2398,6 @@ def to_string(self): "arrow::Status": StatusPrinter, "arrow::Table": TablePrinter, "arrow::util::string_view": StringViewPrinter, - "arrow::util::Variant": VariantPrinter, "nonstd::sv_lite::basic_string_view": StringViewPrinter, } diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index d03b073bb88..12cce42038d 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -313,7 +313,7 @@ struct ExecValue { struct ARROW_EXPORT ExecResult { // The default value of the variant is ArraySpan - util::Variant> value; + std::variant> value; int64_t length() const { if (this->is_array_span()) { @@ -332,12 +332,12 @@ struct ARROW_EXPORT ExecResult { } ArraySpan* array_span() const { - return const_cast(&util::get(this->value)); + return const_cast(&std::get(this->value)); } bool is_array_span() const { return this->value.index() == 0; } const std::shared_ptr& array_data() const { - return util::get>(this->value); + return std::get>(this->value); } bool is_array_data() const { return this->value.index() == 1; } diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 322b5f5e456..057e1ace5cd 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -612,12 +612,12 @@ Result Declaration::AddToPlan(ExecPlan* plan, size_t i = 0; for (const Input& input : this->inputs) { - if (auto node = util::get_if(&input)) { + if (auto node = std::get_if(&input)) { inputs[i++] = *node; continue; } ARROW_ASSIGN_OR_RAISE(inputs[i++], - util::get(input).AddToPlan(plan, registry)); + std::get(input).AddToPlan(plan, registry)); } ARROW_ASSIGN_OR_RAISE( @@ -638,7 +638,7 @@ Declaration Declaration::Sequence(std::vector decls) { decls.pop_back(); receiver->inputs.emplace_back(std::move(input)); - receiver = &util::get(receiver->inputs.front()); + receiver = &std::get(receiver->inputs.front()); } return out; } diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 93d06551241..3ff2340856f 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -449,7 +449,7 @@ inline Result MakeExecNode( /// inputs may also be Declarations). The node can be constructed and added to a plan /// with Declaration::AddToPlan, which will recursively construct any inputs as necessary. 
struct ARROW_EXPORT Declaration { - using Input = util::Variant; + using Input = std::variant; Declaration() {} diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 16942a0f80f..d23838303f7 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -76,10 +76,10 @@ Expression call(std::string function, std::vector arguments, return Expression(std::move(call)); } -const Datum* Expression::literal() const { return util::get_if(impl_.get()); } +const Datum* Expression::literal() const { return std::get_if(impl_.get()); } const Expression::Parameter* Expression::parameter() const { - return util::get_if(impl_.get()); + return std::get_if(impl_.get()); } const FieldRef* Expression::field_ref() const { @@ -90,7 +90,7 @@ const FieldRef* Expression::field_ref() const { } const Expression::Call* Expression::call() const { - return util::get_if(impl_.get()); + return std::get_if(impl_.get()); } const DataType* Expression::type() const { diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index a872e799597..d49fe5c893e 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -22,13 +22,13 @@ #include #include #include +#include #include #include "arrow/compute/type_fwd.h" #include "arrow/datum.h" #include "arrow/type_fwd.h" #include "arrow/util/small_vector.h" -#include "arrow/util/variant.h" namespace arrow { namespace compute { @@ -127,7 +127,7 @@ class ARROW_EXPORT Expression { explicit Expression(Parameter parameter); private: - using Impl = util::Variant; + using Impl = std::variant; std::shared_ptr impl_; ARROW_EXPORT friend bool Identical(const Expression& l, const Expression& r); diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 8c8c3f6b3b2..2abe6e9e029 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -424,7 +424,7 @@ void PrintTo(const Declaration& decl, std::ostream* os) { *os << "{"; for (const auto& input : decl.inputs) { - if (auto decl = util::get_if(&input)) { + if (auto decl = std::get_if(&input)) { PrintTo(*decl, os); } } diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 40c02051955..238b33e40fe 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -25,6 +25,7 @@ #include #include +#include #include #include "arrow/dataset/partition.h" @@ -33,7 +34,6 @@ #include "arrow/filesystem/type_fwd.h" #include "arrow/result.h" #include "arrow/util/macros.h" -#include "arrow/util/variant.h" namespace arrow { namespace dataset { diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 64daf08fd03..23f4d09a9d2 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include "arrow/compute/api_scalar.h" @@ -43,7 +44,6 @@ #include "arrow/util/string.h" #include "arrow/util/task_group.h" #include "arrow/util/tracing_internal.h" -#include "arrow/util/variant.h" namespace arrow { @@ -154,7 +154,7 @@ struct FileSystemDataset::FragmentSubtrees { // Forest for skipping fragments based on extracted subtree expressions compute::Forest forest; // fragment indices and subtree expressions in forest order - std::vector> fragments_and_subtrees; + std::vector> fragments_and_subtrees; }; Result> FileSystemDataset::Make( @@ -242,13 +242,13 
@@ Result FileSystemDataset::GetFragmentsImpl( RETURN_NOT_OK(subtrees_->forest.Visit( [&](compute::Forest::Ref ref) -> Result { if (auto fragment_index = - util::get_if(&subtrees_->fragments_and_subtrees[ref.i])) { + std::get_if(&subtrees_->fragments_and_subtrees[ref.i])) { fragment_indices.push_back(*fragment_index); return false; } const auto& subtree_expr = - util::get(subtrees_->fragments_and_subtrees[ref.i]); + std::get(subtrees_->fragments_and_subtrees[ref.i]); ARROW_ASSIGN_OR_RAISE(auto simplified, SimplifyWithGuarantee(predicates.back(), subtree_expr)); diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index f06e97a20ec..d0b5cf62c61 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ -69,18 +69,18 @@ Datum::Datum(const RecordBatch& value) std::shared_ptr Datum::make_array() const { DCHECK_EQ(Datum::ARRAY, this->kind()); - return MakeArray(util::get>(this->value)); + return MakeArray(std::get>(this->value)); } const std::shared_ptr& Datum::type() const { if (this->kind() == Datum::ARRAY) { - return util::get>(this->value)->type; + return std::get>(this->value)->type; } if (this->kind() == Datum::CHUNKED_ARRAY) { - return util::get>(this->value)->type(); + return std::get>(this->value)->type(); } if (this->kind() == Datum::SCALAR) { - return util::get>(this->value)->type; + return std::get>(this->value)->type; } static std::shared_ptr no_type; return no_type; @@ -88,10 +88,10 @@ const std::shared_ptr& Datum::type() const { const std::shared_ptr& Datum::schema() const { if (this->kind() == Datum::RECORD_BATCH) { - return util::get>(this->value)->schema(); + return std::get>(this->value)->schema(); } if (this->kind() == Datum::TABLE) { - return util::get>(this->value)->schema(); + return std::get>(this->value)->schema(); } static std::shared_ptr no_schema; return no_schema; @@ -100,13 +100,13 @@ const std::shared_ptr& Datum::schema() const { int64_t Datum::length() const { switch (this->kind()) { case Datum::ARRAY: - return util::get>(this->value)->length; + return std::get>(this->value)->length; case Datum::CHUNKED_ARRAY: - return util::get>(this->value)->length(); + return std::get>(this->value)->length(); case Datum::RECORD_BATCH: - return util::get>(this->value)->num_rows(); + return std::get>(this->value)->num_rows(); case Datum::TABLE: - return util::get>(this->value)->num_rows(); + return std::get>(this->value)->num_rows(); case Datum::SCALAR: return 1; default: @@ -117,14 +117,13 @@ int64_t Datum::length() const { int64_t Datum::TotalBufferSize() const { switch (this->kind()) { case Datum::ARRAY: - return util::TotalBufferSize(*util::get>(this->value)); + return util::TotalBufferSize(*std::get>(this->value)); case Datum::CHUNKED_ARRAY: - return util::TotalBufferSize( - *util::get>(this->value)); + return util::TotalBufferSize(*std::get>(this->value)); case Datum::RECORD_BATCH: - return util::TotalBufferSize(*util::get>(this->value)); + return util::TotalBufferSize(*std::get>(this->value)); case Datum::TABLE: - return util::TotalBufferSize(*util::get>(this->value)); + return util::TotalBufferSize(*std::get>(this->value)); case Datum::SCALAR: return 0; default: @@ -135,11 +134,11 @@ int64_t Datum::TotalBufferSize() const { int64_t Datum::null_count() const { if (this->kind() == Datum::ARRAY) { - return util::get>(this->value)->GetNullCount(); + return std::get>(this->value)->GetNullCount(); } else if (this->kind() == Datum::CHUNKED_ARRAY) { - return util::get>(this->value)->null_count(); + return std::get>(this->value)->null_count(); } else if 
(this->kind() == Datum::SCALAR) { - const auto& val = *util::get>(this->value); + const auto& val = *std::get>(this->value); return val.is_valid ? 0 : 1; } else { DCHECK(false) << "This function only valid for array-like values"; diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index d4aaff22ce3..ffea7800ecf 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "arrow/array/data.h" @@ -30,7 +31,6 @@ #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" #include "arrow/util/macros.h" -#include "arrow/util/variant.h" // IWYU pragma: export #include "arrow/util/visibility.h" namespace arrow { @@ -51,9 +51,9 @@ struct ARROW_EXPORT Datum { // current variant does not have a length. static constexpr int64_t kUnknownLength = -1; - util::Variant, std::shared_ptr, - std::shared_ptr, std::shared_ptr, - std::shared_ptr
> + std::variant, std::shared_ptr, + std::shared_ptr, std::shared_ptr, + std::shared_ptr
> value; /// \brief Empty datum, to be populated elsewhere @@ -136,7 +136,7 @@ struct ARROW_EXPORT Datum { } const std::shared_ptr& array() const { - return util::get>(this->value); + return std::get>(this->value); } /// \brief The sum of bytes in each buffer referenced by the datum @@ -149,19 +149,19 @@ struct ARROW_EXPORT Datum { std::shared_ptr make_array() const; const std::shared_ptr& chunked_array() const { - return util::get>(this->value); + return std::get>(this->value); } const std::shared_ptr& record_batch() const { - return util::get>(this->value); + return std::get>(this->value); } const std::shared_ptr
& table() const { - return util::get>(this->value); + return std::get>(this->value); } const std::shared_ptr& scalar() const { - return util::get>(this->value); + return std::get>(this->value); } template diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 3911373b7b7..00427fc4c9e 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -558,7 +558,7 @@ Result> ExtractSchemaToBind(const compute::Declaration& const auto& opts = checked_cast(*(declr.options)); bind_schema = opts.dataset->schema(); } else if (declr.factory_name == "filter") { - auto input_declr = util::get(declr.inputs[0]); + auto input_declr = std::get(declr.inputs[0]); ARROW_ASSIGN_OR_RAISE(bind_schema, ExtractSchemaToBind(input_declr)); } else if (declr.factory_name == "sink") { // Note that the sink has no output_schema @@ -633,7 +633,7 @@ Result> FilterRelationConverter( auto declr_input = declaration.inputs[0]; ARROW_ASSIGN_OR_RAISE( auto input_rel, - ToProto(util::get(declr_input), ext_set, conversion_options)); + ToProto(std::get(declr_input), ext_set, conversion_options)); filter_rel->set_allocated_input(input_rel.release()); ARROW_ASSIGN_OR_RAISE(auto subs_expr, @@ -667,7 +667,7 @@ Status SerializeAndCombineRelations(const compute::Declaration& declaration, // Generally when a plan is deserialized the declaration will be a sink declaration. // Since there is no Sink relation in substrait, this function would be recursively // called on the input of the Sink declaration. - auto sink_input_decl = util::get(declaration.inputs[0]); + auto sink_input_decl = std::get(declaration.inputs[0]); RETURN_NOT_OK( SerializeAndCombineRelations(sink_input_decl, ext_set, rel, conversion_options)); } else { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index cb1eadcdbdf..afa676a4095 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -185,10 +185,10 @@ void CheckRoundTripResult(const std::shared_ptr output_schema, ASSERT_OK_AND_ASSIGN(auto sink_decls, DeserializePlans( *buf, [] { return kNullConsumer; }, ext_id_reg, &ext_set, conversion_options)); - auto other_declrs = sink_decls[0].inputs[0].get(); + auto& other_declrs = std::get(sink_decls[0].inputs[0]); ASSERT_OK_AND_ASSIGN(auto output_table, - GetTableFromPlan(*other_declrs, exec_context, output_schema)); + GetTableFromPlan(other_declrs, exec_context, output_schema)); if (!include_columns.empty()) { ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns(include_columns)); } @@ -1045,7 +1045,7 @@ TEST(Substrait, DeserializeWithWriteOptionsFactory) { compute::Declaration* decl = &declarations[0]; ASSERT_EQ(decl->factory_name, "write"); ASSERT_EQ(decl->inputs.size(), 1); - decl = util::get_if(&decl->inputs[0]); + decl = std::get_if(&decl->inputs[0]); ASSERT_NE(decl, nullptr); ASSERT_EQ(decl->factory_name, "scan"); ASSERT_OK_AND_ASSIGN(auto plan, compute::ExecPlan::Make()); @@ -1216,21 +1216,21 @@ TEST(Substrait, JoinPlanBasic) { auto join_decl = sink_decls[0].inputs[0]; - const auto& join_rel = join_decl.get(); + const auto& join_rel = std::get(join_decl); const auto& join_options = - checked_cast(*join_rel->options); + checked_cast(*join_rel.options); - EXPECT_EQ(join_rel->factory_name, "hashjoin"); + EXPECT_EQ(join_rel.factory_name, "hashjoin"); EXPECT_EQ(join_options.join_type, compute::JoinType::INNER); - const 
auto& left_rel = join_rel->inputs[0].get(); - const auto& right_rel = join_rel->inputs[1].get(); + const auto& left_rel = std::get(join_rel.inputs[0]); + const auto& right_rel = std::get(join_rel.inputs[1]); const auto& l_options = - checked_cast(*left_rel->options); + checked_cast(*left_rel.options); const auto& r_options = - checked_cast(*right_rel->options); + checked_cast(*right_rel.options); AssertSchemaEqual( l_options.dataset->schema(), @@ -1582,12 +1582,12 @@ TEST(Substrait, AggregateBasic) { DeserializePlans(*buf, [] { return kNullConsumer; })); auto agg_decl = sink_decls[0].inputs[0]; - const auto& agg_rel = agg_decl.get(); + const auto& agg_rel = std::get(agg_decl); const auto& agg_options = - checked_cast(*agg_rel->options); + checked_cast(*agg_rel.options); - EXPECT_EQ(agg_rel->factory_name, "aggregate"); + EXPECT_EQ(agg_rel.factory_name, "aggregate"); EXPECT_EQ(agg_options.aggregates[0].name, ""); EXPECT_EQ(agg_options.aggregates[0].function, "hash_sum"); } @@ -1968,9 +1968,10 @@ TEST(Substrait, BasicPlanRoundTripping) { DeserializePlans( *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); // filter declaration - auto roundtripped_filter = sink_decls[0].inputs[0].get(); + const auto& roundtripped_filter = + std::get(sink_decls[0].inputs[0]); const auto& filter_opts = - checked_cast(*(roundtripped_filter->options)); + checked_cast(*(roundtripped_filter.options)); auto roundtripped_expr = filter_opts.filter_expression; if (auto* call = roundtripped_expr.call()) { @@ -1982,9 +1983,10 @@ TEST(Substrait, BasicPlanRoundTripping) { EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); } // scan declaration - auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& roundtripped_scan = + std::get(roundtripped_filter.inputs[0]); const auto& dataset_opts = - checked_cast(*(roundtripped_scan->options)); + checked_cast(*(roundtripped_scan.options)); const auto& roundripped_ds = dataset_opts.dataset; EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); @@ -2080,9 +2082,9 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { DeserializePlans( *serialized_plan, [] { return kNullConsumer; }, ext_id_reg, &ext_set)); // filter declaration - auto roundtripped_filter = sink_decls[0].inputs[0].get(); + auto& roundtripped_filter = std::get(sink_decls[0].inputs[0]); const auto& filter_opts = - checked_cast(*(roundtripped_filter->options)); + checked_cast(*(roundtripped_filter.options)); auto roundtripped_expr = filter_opts.filter_expression; if (auto* call = roundtripped_expr.call()) { @@ -2094,9 +2096,10 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { EXPECT_EQ(dummy_schema->field_names()[right_index], filter_col_right); } // scan declaration - auto roundtripped_scan = roundtripped_filter->inputs[0].get(); + const auto& roundtripped_scan = + std::get(roundtripped_filter.inputs[0]); const auto& dataset_opts = - checked_cast(*(roundtripped_scan->options)); + checked_cast(*(roundtripped_scan.options)); const auto& roundripped_ds = dataset_opts.dataset; EXPECT_TRUE(roundripped_ds->schema()->Equals(*dummy_schema)); ASSERT_OK_AND_ASSIGN(auto roundtripped_frgs, roundripped_ds->GetFragments()); @@ -2112,8 +2115,8 @@ TEST(Substrait, BasicPlanRoundTrippingEndToEnd) { checked_cast(roundtrip_frg_vec[idx++].get()); EXPECT_TRUE(l_frag->Equals(*r_frag)); } - ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, GetTableFromPlan(*roundtripped_filter, - exec_context, dummy_schema)); 
+ ASSERT_OK_AND_ASSIGN(auto rnd_trp_table, + GetTableFromPlan(roundtripped_filter, exec_context, dummy_schema)); EXPECT_TRUE(expected_table->Equals(*rnd_trp_table)); } diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc index 69e49b32043..d8302bed471 100644 --- a/cpp/src/arrow/filesystem/mockfs.cc +++ b/cpp/src/arrow/filesystem/mockfs.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "arrow/buffer.h" @@ -35,7 +36,6 @@ #include "arrow/util/future.h" #include "arrow/util/logging.h" #include "arrow/util/string_view.h" -#include "arrow/util/variant.h" #include "arrow/util/windows_fixup.h" namespace arrow { @@ -120,7 +120,7 @@ struct Directory { }; // A filesystem entry -using EntryBase = util::Variant; +using EntryBase = std::variant; class Entry : public EntryBase { public: @@ -129,13 +129,13 @@ class Entry : public EntryBase { explicit Entry(Directory&& v) : EntryBase(std::move(v)) {} explicit Entry(File&& v) : EntryBase(std::move(v)) {} - bool is_dir() const { return util::holds_alternative(*this); } + bool is_dir() const { return std::holds_alternative(*this); } - bool is_file() const { return util::holds_alternative(*this); } + bool is_file() const { return std::holds_alternative(*this); } - Directory& as_dir() { return util::get(*this); } + Directory& as_dir() { return std::get(*this); } - File& as_file() { return util::get(*this); } + File& as_file() { return std::get(*this); } // Get info for this entry. Note the path() property isn't set. FileInfo GetInfo() { diff --git a/cpp/src/arrow/flight/client.h b/cpp/src/arrow/flight/client.h index 0298abe366d..61fa6e9d0c4 100644 --- a/cpp/src/arrow/flight/client.h +++ b/cpp/src/arrow/flight/client.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include "arrow/ipc/options.h" @@ -32,7 +33,6 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/cancel.h" -#include "arrow/util/variant.h" #include "arrow/flight/type_fwd.h" #include "arrow/flight/types.h" // IWYU pragma: keep @@ -118,7 +118,7 @@ struct ARROW_FLIGHT_EXPORT FlightClientOptions { /// \brief Generic connection options, passed to the underlying /// transport; interpretation is implementation-dependent. - std::vector>> generic_options; + std::vector>> generic_options; /// \brief Use TLS without validating the server certificate. Use with caution. 
bool disable_server_verification = false; diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index 1905b117d61..c303873f549 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -935,7 +935,7 @@ arrow::Result> FlightSqlServerBase::DoGetSqlIn return Status::KeyError("No information for SQL info number ", info); } ARROW_RETURN_NOT_OK(name_field_builder.Append(info)); - ARROW_RETURN_NOT_OK(arrow::util::visit(sql_info_result_appender, it->second)); + ARROW_RETURN_NOT_OK(std::visit(sql_info_result_appender, it->second)); } std::shared_ptr name; diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 785f45551fc..dc59b4a2c1c 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -760,7 +760,7 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfo) { reinterpret_cast(*scalar)); const auto& expected_result = sql_info_expected_results.at(col_name_chunk_data[row]); - arrow::util::visit(validator, expected_result); + std::visit(validator, expected_result); } } } diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h index 8b28ed18bdd..79c34fa5581 100644 --- a/cpp/src/arrow/flight/sql/types.h +++ b/cpp/src/arrow/flight/sql/types.h @@ -21,11 +21,11 @@ #include #include #include +#include #include #include "arrow/flight/sql/visibility.h" #include "arrow/type_fwd.h" -#include "arrow/util/variant.h" namespace arrow { namespace flight { @@ -37,8 +37,8 @@ namespace sql { /// \brief Variant supporting all possible types on SQL info. using SqlInfoResult = - arrow::util::Variant, - std::unordered_map>>; + std::variant, + std::unordered_map>>; /// \brief Map SQL info identifier to its value. using SqlInfoResultMap = std::unordered_map; diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc index 8fe1e1bae79..e6f50169607 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc @@ -631,10 +631,10 @@ class GrpcClientImpl : public internal::ClientTransport { // Allow setting generic gRPC options. 
for (const auto& arg : options.generic_options) { - if (util::holds_alternative(arg.second)) { - default_args[arg.first] = util::get(arg.second); - } else if (util::holds_alternative(arg.second)) { - args.SetString(arg.first, util::get(arg.second)); + if (std::holds_alternative(arg.second)) { + default_args[arg.first] = std::get(arg.second); + } else if (std::holds_alternative(arg.second)) { + args.SetString(arg.first, std::get(arg.second)); } // Otherwise unimplemented } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index efff07db667..c91fa234e0a 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -1110,28 +1110,29 @@ Result> FieldPath::Get(const ArrayData& data) const { } FieldRef::FieldRef(FieldPath indices) : impl_(std::move(indices)) { - DCHECK_GT(util::get(impl_).indices().size(), 0); + DCHECK_GT(std::get(impl_).indices().size(), 0); } void FieldRef::Flatten(std::vector children) { // flatten children struct Visitor { - void operator()(std::string* name) { *out++ = FieldRef(std::move(*name)); } + void operator()(std::string&& name) { out->push_back(FieldRef(std::move(name))); } - void operator()(FieldPath* indices) { *out++ = FieldRef(std::move(*indices)); } + void operator()(FieldPath&& indices) { out->push_back(FieldRef(std::move(indices))); } - void operator()(std::vector* children) { - for (auto& child : *children) { - util::visit(*this, &child.impl_); + void operator()(std::vector&& children) { + out->reserve(out->size() + children.size()); + for (auto&& child : children) { + std::visit(*this, std::move(child.impl_)); } } - std::back_insert_iterator> out; + std::vector* out; }; std::vector out; - Visitor visitor{std::back_inserter(out)}; - visitor(&children); + Visitor visitor{&out}; + visitor(std::move(children)); DCHECK(!out.empty()); DCHECK(std::none_of(out.begin(), out.end(), @@ -1237,7 +1238,7 @@ std::string FieldRef::ToDotPath() const { } }; - return util::visit(Visitor{}, impl_); + return std::visit(Visitor{}, impl_); } size_t FieldRef::hash() const { @@ -1257,7 +1258,7 @@ size_t FieldRef::hash() const { } }; - return util::visit(Visitor{}, impl_); + return std::visit(Visitor{}, impl_); } std::string FieldRef::ToString() const { @@ -1277,7 +1278,7 @@ std::string FieldRef::ToString() const { } }; - return "FieldRef." + util::visit(Visitor{}, impl_); + return "FieldRef." 
+ std::visit(Visitor{}, impl_); } std::vector FieldRef::FindAll(const Schema& schema) const { @@ -1379,7 +1380,7 @@ std::vector FieldRef::FindAll(const FieldVector& fields) const { const FieldVector& fields_; }; - return util::visit(Visitor{fields}, impl_); + return std::visit(Visitor{fields}, impl_); } std::vector FieldRef::FindAll(const ArrayData& array) const { diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 209c5e52e12..663c4765127 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include "arrow/result.h" @@ -32,7 +33,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/endian.h" #include "arrow/util/macros.h" -#include "arrow/util/variant.h" #include "arrow/util/visibility.h" #include "arrow/visitor.h" // IWYU pragma: keep @@ -1735,23 +1735,23 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable { explicit operator bool() const { return Equals(FieldPath{}); } bool operator!() const { return !Equals(FieldPath{}); } - bool IsFieldPath() const { return util::holds_alternative(impl_); } - bool IsName() const { return util::holds_alternative(impl_); } + bool IsFieldPath() const { return std::holds_alternative(impl_); } + bool IsName() const { return std::holds_alternative(impl_); } bool IsNested() const { if (IsName()) return false; - if (IsFieldPath()) return util::get(impl_).indices().size() > 1; + if (IsFieldPath()) return std::get(impl_).indices().size() > 1; return true; } const FieldPath* field_path() const { - return IsFieldPath() ? &util::get(impl_) : NULLPTR; + return IsFieldPath() ? &std::get(impl_) : NULLPTR; } const std::string* name() const { - return IsName() ? &util::get(impl_) : NULLPTR; + return IsName() ? &std::get(impl_) : NULLPTR; } const std::vector* nested_refs() const { - return util::holds_alternative>(impl_) - ? &util::get>(impl_) + return std::holds_alternative>(impl_) + ? &std::get>(impl_) : NULLPTR; } @@ -1843,7 +1843,7 @@ class ARROW_EXPORT FieldRef : public util::EqualityComparable { private: void Flatten(std::vector children); - util::Variant> impl_; + std::variant> impl_; ARROW_EXPORT friend void PrintTo(const FieldRef& ref, std::ostream* os); }; diff --git a/cpp/src/arrow/util/CMakeLists.txt b/cpp/src/arrow/util/CMakeLists.txt index cd1a8967eeb..051138a002b 100644 --- a/cpp/src/arrow/util/CMakeLists.txt +++ b/cpp/src/arrow/util/CMakeLists.txt @@ -71,8 +71,7 @@ add_arrow_test(utility-test trie_test.cc uri_test.cc utf8_util_test.cc - value_parsing_test.cc - variant_test.cc) + value_parsing_test.cc) add_arrow_test(threading-utility-test SOURCES @@ -100,4 +99,3 @@ add_arrow_benchmark(thread_pool_benchmark) add_arrow_benchmark(trie_benchmark) add_arrow_benchmark(utf8_util_benchmark) add_arrow_benchmark(value_parsing_benchmark) -add_arrow_benchmark(variant_benchmark) diff --git a/cpp/src/arrow/util/variant.h b/cpp/src/arrow/util/variant.h deleted file mode 100644 index 8bbce525178..00000000000 --- a/cpp/src/arrow/util/variant.h +++ /dev/null @@ -1,443 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#pragma once - -#include -#include -#include -#include - -#include "arrow/util/macros.h" -#include "arrow/util/type_traits.h" - -namespace arrow { -namespace util { - -/// \brief a std::variant-like discriminated union -/// -/// Simplifications from std::variant: -/// -/// - Strictly defaultable. The first type of T... should be nothrow default constructible -/// and it will be used for default Variants. -/// -/// - Never valueless_by_exception. std::variant supports a state outside those specified -/// by T... to which it can return in the event that a constructor throws. If a Variant -/// would become valueless_by_exception it will instead return to its default state. -/// -/// - Strictly nothrow move constructible and assignable -/// -/// - Less sophisticated type deduction. std::variant("hello") will -/// intelligently construct std::string while Variant("hello") will -/// construct bool. -/// -/// - Either both copy constructible and assignable or neither (std::variant independently -/// enables copy construction and copy assignment). Variant is copy constructible if -/// each of T... is copy constructible and assignable. -/// -/// - Slimmer interface; several members of std::variant are omitted. -/// -/// - Throws no exceptions; if a bad_variant_access would be thrown Variant will instead -/// segfault (nullptr dereference). -/// -/// - Mutable visit takes a pointer instead of mutable reference or rvalue reference, -/// which is more conformant with our code style. 
-template -class Variant; - -namespace detail { - -template -struct is_equality_comparable : std::false_type {}; - -template -struct is_equality_comparable< - T, typename std::enable_if() == std::declval()), bool>::value>::type> - : std::true_type {}; - -template -using conditional_t = typename std::conditional::type; - -template -struct type_constant { - using type = T; -}; - -template -struct first; - -template -struct first { - using type = H; -}; - -template -using decay_t = typename std::decay::type; - -template -struct all : std::true_type {}; - -template -struct all : conditional_t, std::false_type> {}; - -struct delete_copy_constructor { - template - struct type { - type() = default; - type(const type& other) = delete; - type& operator=(const type& other) = delete; - }; -}; - -struct explicit_copy_constructor { - template - struct type { - type() = default; - type(const type& other) { static_cast(other).copy_to(this); } - type& operator=(const type& other) { - static_cast(this)->destroy(); - static_cast(other).copy_to(this); - return *this; - } - }; -}; - -template -struct VariantStorage { - VariantStorage() = default; - VariantStorage(const VariantStorage&) {} - VariantStorage& operator=(const VariantStorage&) { return *this; } - VariantStorage(VariantStorage&&) noexcept {} - VariantStorage& operator=(VariantStorage&&) noexcept { return *this; } - ~VariantStorage() { - static_assert(offsetof(VariantStorage, data_) == 0, - "(void*)&VariantStorage::data_ == (void*)this"); - } - - typename arrow::internal::aligned_union<0, T...>::type data_; - uint8_t index_ = 0; -}; - -template -struct VariantImpl; - -template -struct VariantImpl> : VariantStorage { - static void index_of() noexcept {} - void destroy() noexcept {} - void move_to(...) noexcept {} - void copy_to(...) const {} - - template - [[noreturn]] R visit_const(Visitor&& visitor) const { - std::terminate(); - } - template - [[noreturn]] R visit_mutable(Visitor&& visitor) { - std::terminate(); - } -}; - -template -struct VariantImpl, H, T...> : VariantImpl, T...> { - using VariantType = Variant; - using Impl = VariantImpl; - - static constexpr uint8_t kIndex = sizeof...(M) - sizeof...(T) - 1; - - VariantImpl() = default; - - using VariantImpl::VariantImpl; - using Impl::operator=; - using Impl::index_of; - - explicit VariantImpl(H value) { - new (this) H(std::move(value)); - this->index_ = kIndex; - } - - VariantImpl& operator=(H value) { - static_cast(this)->destroy(); - new (this) H(std::move(value)); - this->index_ = kIndex; - return *this; - } - - H& cast_this() { return *reinterpret_cast(this); } - const H& cast_this() const { return *reinterpret_cast(this); } - - void move_to(VariantType* target) noexcept { - if (this->index_ == kIndex) { - new (target) H(std::move(cast_this())); - target->index_ = kIndex; - } else { - Impl::move_to(target); - } - } - - // Templated to avoid instantiation in case H is not copy constructible - template - void copy_to(Void* generic_target) const { - const auto target = static_cast(generic_target); - try { - if (this->index_ == kIndex) { - new (target) H(cast_this()); - target->index_ = kIndex; - } else { - Impl::copy_to(target); - } - } catch (...) 
{ - target->construct_default(); - throw; - } - } - - void destroy() noexcept { - if (this->index_ == kIndex) { - if (!std::is_trivially_destructible::value) { - cast_this().~H(); - } - } else { - Impl::destroy(); - } - } - - static constexpr std::integral_constant index_of( - const type_constant&) { - return {}; - } - - template - R visit_const(Visitor&& visitor) const { - if (this->index_ == kIndex) { - return std::forward(visitor)(cast_this()); - } - return Impl::template visit_const(std::forward(visitor)); - } - - template - R visit_mutable(Visitor&& visitor) { - if (this->index_ == kIndex) { - return std::forward(visitor)(&cast_this()); - } - return Impl::template visit_mutable(std::forward(visitor)); - } -}; - -} // namespace detail - -template -class Variant : detail::VariantImpl, T...>, - detail::conditional_t< - detail::all<(std::is_copy_constructible::value && - std::is_copy_assignable::value)...>::value, - detail::explicit_copy_constructor, - detail::delete_copy_constructor>::template type> { - template - static constexpr uint8_t index_of() { - return Impl::index_of(detail::type_constant{}); - } - - using Impl = detail::VariantImpl, T...>; - - public: - using default_type = typename util::detail::first::type; - - Variant() noexcept { construct_default(); } - - Variant(const Variant& other) = default; - Variant& operator=(const Variant& other) = default; - Variant& operator=(Variant&& other) noexcept { - this->destroy(); - other.move_to(this); - return *this; - } - - using Impl::Impl; - using Impl::operator=; - - Variant(Variant&& other) noexcept { other.move_to(this); } - - ~Variant() { - static_assert(offsetof(Variant, data_) == 0, "(void*)&Variant::data_ == (void*)this"); - this->destroy(); - } - - /// \brief Return the zero-based type index of the value held by the variant - uint8_t index() const noexcept { return this->index_; } - - /// \brief Get a const pointer to the value held by the variant - /// - /// If the type given as template argument doesn't match, a null pointer is returned. - template ()> - const U* get() const noexcept { - return index() == I ? reinterpret_cast(this) : NULLPTR; - } - - /// \brief Get a pointer to the value held by the variant - /// - /// If the type given as template argument doesn't match, a null pointer is returned. - template ()> - U* get() noexcept { - return index() == I ? reinterpret_cast(this) : NULLPTR; - } - - /// \brief Replace the value held by the variant - /// - /// The intended type must be given as a template argument. - /// The value is constructed in-place using the given function arguments. - template ()> - void emplace(A&&... args) { - try { - this->destroy(); - new (this) U(std::forward(args)...); - this->index_ = I; - } catch (...) { - construct_default(); - throw; - } - } - - template ()> - void emplace(std::initializer_list il, A&&... args) { - try { - this->destroy(); - new (this) U(il, std::forward(args)...); - this->index_ = I; - } catch (...) 
{ - construct_default(); - throw; - } - } - - /// \brief Swap with another variant's contents - void swap(Variant& other) noexcept { // NOLINT google-runtime-references - Variant tmp = std::move(other); - other = std::move(*this); - *this = std::move(tmp); - } - - using Impl::visit_const; - using Impl::visit_mutable; - - private: - void construct_default() noexcept { - new (this) default_type(); - this->index_ = 0; - } - - template - friend struct detail::explicit_copy_constructor::type; - - template - friend struct detail::VariantImpl; -}; - -/// \brief Call polymorphic visitor on a const variant's value -/// -/// The visitor will receive a const reference to the value held by the variant. -/// It must define overloads for each possible variant type. -/// The overloads should all return the same type (no attempt -/// is made to find a generalized return type). -template ()( - std::declval::default_type&>()))> -R visit(Visitor&& visitor, const util::Variant& v) { - return v.template visit_const(std::forward(visitor)); -} - -/// \brief Call polymorphic visitor on a non-const variant's value -/// -/// The visitor will receive a pointer to the value held by the variant. -/// It must define overloads for each possible variant type. -/// The overloads should all return the same type (no attempt -/// is made to find a generalized return type). -template ()( - std::declval::default_type*>()))> -R visit(Visitor&& visitor, util::Variant* v) { - return v->template visit_mutable(std::forward(visitor)); -} - -/// \brief Get a const reference to the value held by the variant -/// -/// If the type given as template argument doesn't match, behavior is undefined -/// (a null pointer will be dereferenced). -template -const U& get(const Variant& v) { - return *v.template get(); -} - -/// \brief Get a reference to the value held by the variant -/// -/// If the type given as template argument doesn't match, behavior is undefined -/// (a null pointer will be dereferenced). -template -U& get(Variant& v) { - return *v.template get(); -} - -/// \brief Get a const pointer to the value held by the variant -/// -/// If the type given as template argument doesn't match, a nullptr is returned. -template -const U* get_if(const Variant* v) { - return v->template get(); -} - -/// \brief Get a pointer to the value held by the variant -/// -/// If the type given as template argument doesn't match, a nullptr is returned. -template -U* get_if(Variant* v) { - return v->template get(); -} - -namespace detail { - -template -struct VariantsEqual { - template - bool operator()(const U& r) const { - return get(l_) == r; - } - const Variant& l_; -}; - -} // namespace detail - -template ::value...>::value>> -bool operator==(const Variant& l, const Variant& r) { - if (l.index() != r.index()) return false; - return visit(detail::VariantsEqual{l}, r); -} - -template -auto operator!=(const Variant& l, const Variant& r) -> decltype(l == r) { - return !(l == r); -} - -/// \brief Return whether the variant holds a value of the given type -template -bool holds_alternative(const Variant& v) { - return v.template get(); -} - -} // namespace util -} // namespace arrow diff --git a/cpp/src/arrow/util/variant_benchmark.cc b/cpp/src/arrow/util/variant_benchmark.cc deleted file mode 100644 index af3fafb8b0e..00000000000 --- a/cpp/src/arrow/util/variant_benchmark.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "benchmark/benchmark.h" - -#include -#include -#include -#include -#include - -#include "arrow/array.h" -#include "arrow/chunked_array.h" -#include "arrow/datum.h" -#include "arrow/status.h" -#include "arrow/testing/gtest_util.h" -#include "arrow/testing/random.h" -#include "arrow/type.h" -#include "arrow/util/checked_cast.h" -#include "arrow/util/variant.h" - -namespace arrow { - -using internal::checked_pointer_cast; - -namespace util { - -using TrivialVariant = arrow::util::Variant; - -using NonTrivialVariant = arrow::util::Variant; - -std::vector MakeInts(int64_t nitems) { - auto rng = arrow::random::RandomArrayGenerator(42); - auto array = checked_pointer_cast(rng.Int32(nitems, 0, 1 << 30)); - std::vector items(nitems); - for (int64_t i = 0; i < nitems; ++i) { - items[i] = array->Value(i); - } - return items; -} - -std::vector MakeFloats(int64_t nitems) { - auto rng = arrow::random::RandomArrayGenerator(42); - auto array = checked_pointer_cast(rng.Float32(nitems, 0.0, 1.0)); - std::vector items(nitems); - for (int64_t i = 0; i < nitems; ++i) { - items[i] = array->Value(i); - } - return items; -} - -std::vector MakeStrings(int64_t nitems) { - auto rng = arrow::random::RandomArrayGenerator(42); - // Some std::string's will use short string optimization, but not all... - auto array = checked_pointer_cast(rng.String(nitems, 5, 40)); - std::vector items(nitems); - for (int64_t i = 0; i < nitems; ++i) { - items[i] = array->GetString(i); - } - return items; -} - -static void ConstructTrivialVariant(benchmark::State& state) { - const int64_t N = 10000; - const auto ints = MakeInts(N); - const auto floats = MakeFloats(N); - - for (auto _ : state) { - for (int64_t i = 0; i < N; ++i) { - // About type selection: we ensure 50% of each type, but try to avoid - // branch mispredictions by creating runs of the same type. 
- if (i & 0x10) { - TrivialVariant v{ints[i]}; - const int32_t* val = &arrow::util::get(v); - benchmark::DoNotOptimize(val); - } else { - TrivialVariant v{floats[i]}; - const float* val = &arrow::util::get(v); - benchmark::DoNotOptimize(val); - } - } - } - - state.SetItemsProcessed(state.iterations() * N); -} - -static void ConstructNonTrivialVariant(benchmark::State& state) { - const int64_t N = 10000; - const auto ints = MakeInts(N); - const auto strings = MakeStrings(N); - - for (auto _ : state) { - for (int64_t i = 0; i < N; ++i) { - if (i & 0x10) { - NonTrivialVariant v{ints[i]}; - const int32_t* val = &arrow::util::get(v); - benchmark::DoNotOptimize(val); - } else { - NonTrivialVariant v{strings[i]}; - const std::string* val = &arrow::util::get(v); - benchmark::DoNotOptimize(val); - } - } - } - - state.SetItemsProcessed(state.iterations() * N); -} - -struct VariantVisitor { - int64_t total = 0; - - void operator()(const int32_t& v) { total += v; } - void operator()(const float& v) { - // Avoid potentially costly float-to-int conversion - int32_t x; - memcpy(&x, &v, 4); - total += x; - } - void operator()(const std::string& v) { total += static_cast(v.length()); } -}; - -template -static void VisitVariant(benchmark::State& state, - const std::vector& variants) { - for (auto _ : state) { - VariantVisitor visitor; - for (const auto& v : variants) { - visit(visitor, v); - } - benchmark::DoNotOptimize(visitor.total); - } - - state.SetItemsProcessed(state.iterations() * variants.size()); -} - -static void VisitTrivialVariant(benchmark::State& state) { - const int64_t N = 10000; - const auto ints = MakeInts(N); - const auto floats = MakeFloats(N); - - std::vector variants; - variants.reserve(N); - for (int64_t i = 0; i < N; ++i) { - if (i & 0x10) { - variants.emplace_back(ints[i]); - } else { - variants.emplace_back(floats[i]); - } - } - - VisitVariant(state, variants); -} - -static void VisitNonTrivialVariant(benchmark::State& state) { - const int64_t N = 10000; - const auto ints = MakeInts(N); - const auto strings = MakeStrings(N); - - std::vector variants; - variants.reserve(N); - for (int64_t i = 0; i < N; ++i) { - if (i & 0x10) { - variants.emplace_back(ints[i]); - } else { - variants.emplace_back(strings[i]); - } - } - - VisitVariant(state, variants); -} - -static void ConstructDatum(benchmark::State& state) { - const int64_t N = 10000; - auto array = *MakeArrayOfNull(int8(), 100); - auto chunked_array = std::make_shared(ArrayVector{array, array}); - - for (auto _ : state) { - for (int64_t i = 0; i < N; ++i) { - if (i & 0x10) { - Datum datum{array}; - const ArrayData* val = datum.array().get(); - benchmark::DoNotOptimize(val); - } else { - Datum datum{chunked_array}; - const ChunkedArray* val = datum.chunked_array().get(); - benchmark::DoNotOptimize(val); - } - } - } - - state.SetItemsProcessed(state.iterations() * N); -} - -static void VisitDatum(benchmark::State& state) { - const int64_t N = 10000; - auto array = *MakeArrayOfNull(int8(), 100); - auto chunked_array = std::make_shared(ArrayVector{array, array}); - - std::vector datums; - datums.reserve(N); - for (int64_t i = 0; i < N; ++i) { - if (i & 0x10) { - datums.emplace_back(array); - } else { - datums.emplace_back(chunked_array); - } - } - - for (auto _ : state) { - int64_t total = 0; - for (const auto& datum : datums) { - // The .is_XXX() methods are the usual idiom when visiting a Datum, - // rather than the visit() function. 
- if (datum.is_array()) { - total += datum.array()->length; - } else { - total += datum.chunked_array()->length(); - } - } - benchmark::DoNotOptimize(total); - } - - state.SetItemsProcessed(state.iterations() * datums.size()); -} - -BENCHMARK(ConstructTrivialVariant); -BENCHMARK(ConstructNonTrivialVariant); -BENCHMARK(VisitTrivialVariant); -BENCHMARK(VisitNonTrivialVariant); -BENCHMARK(ConstructDatum); -BENCHMARK(VisitDatum); - -} // namespace util -} // namespace arrow diff --git a/cpp/src/arrow/util/variant_test.cc b/cpp/src/arrow/util/variant_test.cc deleted file mode 100644 index f94d1b6ccf8..00000000000 --- a/cpp/src/arrow/util/variant_test.cc +++ /dev/null @@ -1,345 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/util/variant.h" - -#include -#include -#include -#include -#include - -#include -#include - -#include "arrow/testing/gtest_compat.h" - -namespace arrow { - -namespace util { -namespace { - -using ::testing::Eq; - -template -void AssertDefaultConstruction() { - using variant_type = Variant; - - static_assert(std::is_nothrow_default_constructible::value, ""); - - variant_type v; - EXPECT_EQ(v.index(), 0); - EXPECT_EQ(get(v), H{}); -} - -TEST(Variant, DefaultConstruction) { - AssertDefaultConstruction(); - AssertDefaultConstruction(); - AssertDefaultConstruction(); - AssertDefaultConstruction>(); - AssertDefaultConstruction, int>(); - AssertDefaultConstruction, void*, - std::true_type>(); - AssertDefaultConstruction, void*, bool, - std::string, std::true_type>(); -} - -template -struct AssertCopyConstructionOne { - void operator()(uint8_t index) { - V v{member_}; - EXPECT_EQ(v.index(), index); - EXPECT_EQ(get(v), member_); - - V copy{v}; - EXPECT_EQ(copy.index(), v.index()); - EXPECT_EQ(get(copy), get(v)); - EXPECT_EQ(copy, v); - - V assigned; - assigned = member_; - EXPECT_EQ(assigned.index(), index); - EXPECT_EQ(get(assigned), member_); - - assigned = v; - EXPECT_EQ(assigned.index(), v.index()); - EXPECT_EQ(get(assigned), get(v)); - EXPECT_EQ(assigned, v); - } - - const T& member_; -}; - -template -void AssertCopyConstruction(T... 
member) { - uint8_t index = 0; - for (auto Assert : {std::function( - AssertCopyConstructionOne, T>{member})...}) { - Assert(index++); - } -} - -template -void AssertCopyConstructionDisabled() { - static_assert(!std::is_copy_constructible>::value, - "copy construction was not disabled"); -} - -TEST(Variant, CopyConstruction) { - // if any member is not copy constructible then Variant is not copy constructible - AssertCopyConstructionDisabled>(); - AssertCopyConstructionDisabled, std::string>(); - AssertCopyConstructionDisabled>(); - - AssertCopyConstruction(32, std::string("hello"), true); - AssertCopyConstruction(std::string("world"), false, 53); - AssertCopyConstruction(nullptr, std::true_type{}, std::string("!")); - AssertCopyConstruction(std::vector{1, 3, 3, 7}, "C string"); - - // copy assignment operator is not used - struct CopyAssignThrows { - CopyAssignThrows() = default; - CopyAssignThrows(const CopyAssignThrows&) = default; - - CopyAssignThrows& operator=(const CopyAssignThrows&) { throw 42; } - - CopyAssignThrows(CopyAssignThrows&&) = default; - CopyAssignThrows& operator=(CopyAssignThrows&&) = default; - - bool operator==(const CopyAssignThrows&) const { return true; } - }; - EXPECT_NO_THROW(AssertCopyConstruction(CopyAssignThrows{})); -} - -TEST(Variant, Emplace) { - using variant_type = Variant, int>; - variant_type v; - - v.emplace(); - EXPECT_EQ(v, variant_type{int{}}); - - v.emplace("hello"); - EXPECT_EQ(v, variant_type{std::string("hello")}); - - v.emplace>({1, 3, 3, 7}); - EXPECT_EQ(v, variant_type{std::vector({1, 3, 3, 7})}); -} - -TEST(Variant, MoveConstruction) { - struct noop_delete { - void operator()(...) const {} - }; - using ptr = std::unique_ptr; - static_assert(!std::is_copy_constructible::value, ""); - - using variant_type = Variant; - - int tag = 42; - auto ExpectIsTag = [&](const variant_type& v) { - EXPECT_EQ(v.index(), 1); - EXPECT_EQ(get(v).get(), &tag); - }; - - ptr p; - - // move construction from member - p.reset(&tag); - variant_type v0{std::move(p)}; - ExpectIsTag(v0); - - // move assignment from member - p.reset(&tag); - v0 = std::move(p); - ExpectIsTag(v0); - - // move construction from other variant - variant_type v1{std::move(v0)}; - ExpectIsTag(v1); - - // move assignment from other variant - p.reset(&tag); - variant_type v2{std::move(p)}; - v1 = std::move(v2); - ExpectIsTag(v1); - - // type changing move assignment from member - variant_type v3; - EXPECT_NE(v3.index(), 1); - p.reset(&tag); - v3 = std::move(p); - ExpectIsTag(v3); - - // type changing move assignment from other variant - variant_type v4; - EXPECT_NE(v4.index(), 1); - v4 = std::move(v3); - ExpectIsTag(v4); -} - -TEST(Variant, ExceptionSafety) { - struct { - } actually_throw; - - struct { - } dont_throw; - - struct ConstructorThrows { - explicit ConstructorThrows(decltype(actually_throw)) { throw 42; } - explicit ConstructorThrows(decltype(dont_throw)) {} - - ConstructorThrows(const ConstructorThrows&) { throw 42; } - - ConstructorThrows& operator=(const ConstructorThrows&) = default; - ConstructorThrows(ConstructorThrows&&) = default; - ConstructorThrows& operator=(ConstructorThrows&&) = default; - }; - - Variant v; - - // constructor throws during emplacement - EXPECT_THROW(v.emplace(actually_throw), int); - // safely returned to the default state - EXPECT_EQ(v.index(), 0); - - // constructor throws during copy assignment from member - EXPECT_THROW( - { - const ConstructorThrows throws(dont_throw); - v = throws; - }, - int); - // safely returned to the default state - 
EXPECT_EQ(v.index(), 0); -} - -// XXX GTest 1.11 exposes a `using std::visit` in its headers which -// somehow gets preferred to `arrow::util::visit`, even if there is -// a using clause (perhaps because of macros such as EXPECT_EQ). -template -void DoVisit(Args&&... args) { - return ::arrow::util::visit(std::forward(args)...); -} - -template -void AssertVisitedEquals(const T& expected, Args&&... args) { - const auto actual = ::arrow::util::visit(std::forward(args)...); - EXPECT_EQ(expected, actual); -} - -template -struct AssertVisitOne { - void operator()(const T& actual) { EXPECT_EQ(&actual, expected_); } - - void operator()(T* actual) { EXPECT_EQ(actual, expected_); } - - template - void operator()(const U&) { - FAIL() << "the expected type was not visited."; - } - - template - void operator()(U*) { - FAIL() << "the expected type was not visited."; - } - - explicit AssertVisitOne(T member) : member_(std::move(member)) {} - - void operator()() { - V v{member_}; - expected_ = &get(v); - DoVisit(*this, v); - DoVisit(*this, &v); - } - - T member_; - const T* expected_; -}; - -// Try visiting all alternatives on a Variant -template -void AssertVisitAll(T... member) { - for (auto Assert : - {std::function(AssertVisitOne, T>{member})...}) { - Assert(); - } -} - -TEST(VariantTest, Visit) { - AssertVisitAll(32, std::string("hello"), true); - AssertVisitAll(std::string("world"), false, 53); - AssertVisitAll(nullptr, std::true_type{}, std::string("!")); - AssertVisitAll(std::vector{1, 3, 3, 7}, "C string"); - - using int_or_string = Variant; - int_or_string v; - - // value returning visit: - struct { - int_or_string operator()(int i) { return int_or_string{i * 2}; } - int_or_string operator()(const std::string& s) { return int_or_string{s + s}; } - } Double; - - v = 7; - AssertVisitedEquals(int_or_string{14}, Double, v); - - v = "lolol"; - AssertVisitedEquals(int_or_string{"lolollolol"}, Double, v); - - // mutating visit: - struct { - void operator()(int* i) { *i *= 2; } - void operator()(std::string* s) { *s += *s; } - } DoubleInplace; - - v = 7; - DoVisit(DoubleInplace, &v); - EXPECT_EQ(v, int_or_string{14}); - - v = "lolol"; - DoVisit(DoubleInplace, &v); - EXPECT_EQ(v, int_or_string{"lolollolol"}); -} - -TEST(VariantTest, Equality) { - using int_or_double = Variant; - - auto eq = [](const int_or_double& a, const int_or_double& b) { - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - }; - auto ne = [](const int_or_double& a, const int_or_double& b) { - EXPECT_TRUE(a != b); - EXPECT_FALSE(a == b); - }; - - int_or_double u, v; - u.emplace(1); - v.emplace(1); - eq(u, v); - v.emplace(2); - ne(u, v); - v.emplace(1.0); - ne(u, v); - u.emplace(1.0); - eq(u, v); - u.emplace(2.0); - ne(u, v); -} - -} // namespace -} // namespace util -} // namespace arrow diff --git a/cpp/src/gandiva/interval_holder.h b/cpp/src/gandiva/interval_holder.h index 43f34019328..e1a50bcf683 100644 --- a/cpp/src/gandiva/interval_holder.h +++ b/cpp/src/gandiva/interval_holder.h @@ -60,7 +60,7 @@ class GANDIVA_EXPORT IntervalHolder : public FunctionHolder { " function needs to be an integer literal to indicate " "whether to suppress the error"); } - suppress_errors = arrow::util::get(literal_suppress_errors->holder()); + suppress_errors = std::get(literal_suppress_errors->holder()); } return Make(suppress_errors, holder); diff --git a/cpp/src/gandiva/literal_holder.cc b/cpp/src/gandiva/literal_holder.cc index beed8119cb1..a77140b727e 100644 --- a/cpp/src/gandiva/literal_holder.cc +++ b/cpp/src/gandiva/literal_holder.cc @@ -38,7 
+38,7 @@ struct LiteralToStream { std::string ToString(const LiteralHolder& holder) { std::stringstream ss; LiteralToStream visitor{ss}; - ::arrow::util::visit(visitor, holder); + ::std::visit(visitor, holder); return ss.str(); } diff --git a/cpp/src/gandiva/literal_holder.h b/cpp/src/gandiva/literal_holder.h index c4712aafc4b..40faf39e1f5 100644 --- a/cpp/src/gandiva/literal_holder.h +++ b/cpp/src/gandiva/literal_holder.h @@ -18,8 +18,7 @@ #pragma once #include - -#include +#include #include #include "gandiva/decimal_scalar.h" @@ -28,8 +27,8 @@ namespace gandiva { using LiteralHolder = - arrow::util::Variant; + std::variant; GANDIVA_EXPORT std::string ToString(const LiteralHolder& holder); diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 42a05a7dff0..58efef9676f 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -664,44 +664,44 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { switch (dex.type()->id()) { case arrow::Type::BOOL: - value = types->i1_constant(arrow::util::get(dex.holder())); + value = types->i1_constant(std::get(dex.holder())); break; case arrow::Type::UINT8: - value = types->i8_constant(arrow::util::get(dex.holder())); + value = types->i8_constant(std::get(dex.holder())); break; case arrow::Type::UINT16: - value = types->i16_constant(arrow::util::get(dex.holder())); + value = types->i16_constant(std::get(dex.holder())); break; case arrow::Type::UINT32: - value = types->i32_constant(arrow::util::get(dex.holder())); + value = types->i32_constant(std::get(dex.holder())); break; case arrow::Type::UINT64: - value = types->i64_constant(arrow::util::get(dex.holder())); + value = types->i64_constant(std::get(dex.holder())); break; case arrow::Type::INT8: - value = types->i8_constant(arrow::util::get(dex.holder())); + value = types->i8_constant(std::get(dex.holder())); break; case arrow::Type::INT16: - value = types->i16_constant(arrow::util::get(dex.holder())); + value = types->i16_constant(std::get(dex.holder())); break; case arrow::Type::FLOAT: - value = types->float_constant(arrow::util::get(dex.holder())); + value = types->float_constant(std::get(dex.holder())); break; case arrow::Type::DOUBLE: - value = types->double_constant(arrow::util::get(dex.holder())); + value = types->double_constant(std::get(dex.holder())); break; case arrow::Type::STRING: case arrow::Type::BINARY: { - const std::string& str = arrow::util::get(dex.holder()); + const std::string& str = std::get(dex.holder()); value = ir_builder()->CreateGlobalStringPtr(str.c_str()); len = types->i32_constant(static_cast(str.length())); @@ -712,7 +712,7 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { case arrow::Type::DATE32: case arrow::Type::TIME32: case arrow::Type::INTERVAL_MONTHS: - value = types->i32_constant(arrow::util::get(dex.holder())); + value = types->i32_constant(std::get(dex.holder())); break; case arrow::Type::INT64: @@ -720,12 +720,12 @@ void LLVMGenerator::Visitor::Visit(const LiteralDex& dex) { case arrow::Type::TIME64: case arrow::Type::TIMESTAMP: case arrow::Type::INTERVAL_DAY_TIME: - value = types->i64_constant(arrow::util::get(dex.holder())); + value = types->i64_constant(std::get(dex.holder())); break; case arrow::Type::DECIMAL: { // build code for struct - auto scalar = arrow::util::get(dex.holder()); + auto scalar = std::get(dex.holder()); // ConstantInt doesn't have a get method that takes int128 or a pair of int64. so, // passing the string representation instead. 
auto int128_value = diff --git a/cpp/src/gandiva/node.h b/cpp/src/gandiva/node.h index 60bc38d367b..858c6570489 100644 --- a/cpp/src/gandiva/node.h +++ b/cpp/src/gandiva/node.h @@ -86,12 +86,12 @@ class GANDIVA_EXPORT LiteralNode : public Node { // The default formatter prints in decimal can cause a loss in precision. so, // print in hex. Can't use hexfloat since gcc 4.9 doesn't support it. if (return_type()->id() == arrow::Type::DOUBLE) { - double dvalue = arrow::util::get(holder_); + double dvalue = std::get(holder_); uint64_t bits; memcpy(&bits, &dvalue, sizeof(bits)); ss << " raw(" << std::hex << bits << ")"; } else if (return_type()->id() == arrow::Type::FLOAT) { - float fvalue = arrow::util::get(holder_); + float fvalue = std::get(holder_); uint32_t bits; memcpy(&bits, &fvalue, sizeof(bits)); ss << " raw(" << std::hex << bits << ")"; diff --git a/cpp/src/gandiva/random_generator_holder.cc b/cpp/src/gandiva/random_generator_holder.cc index 3471c87d92b..3d395741d70 100644 --- a/cpp/src/gandiva/random_generator_holder.cc +++ b/cpp/src/gandiva/random_generator_holder.cc @@ -39,7 +39,7 @@ Status RandomGeneratorHolder::Make(const FunctionNode& node, Status::Invalid("'random' function requires an int32 literal as parameter")); *holder = std::shared_ptr(new RandomGeneratorHolder( - literal->is_null() ? 0 : arrow::util::get(literal->holder()))); + literal->is_null() ? 0 : std::get(literal->holder()))); return Status::OK(); } } // namespace gandiva diff --git a/cpp/src/gandiva/regex_functions_holder.cc b/cpp/src/gandiva/regex_functions_holder.cc index b1e2e59cb2a..d70c73f233a 100644 --- a/cpp/src/gandiva/regex_functions_holder.cc +++ b/cpp/src/gandiva/regex_functions_holder.cc @@ -30,7 +30,7 @@ std::string& RemovePatternEscapeChars(const FunctionNode& node, std::string& pat if (node.children().size() != 2) { auto escape_char = dynamic_cast(node.children().at(2).get()); pattern.erase(std::remove(pattern.begin(), pattern.end(), - arrow::util::get(escape_char->holder()).at(0)), + std::get(escape_char->holder()).at(0)), pattern.end()); // remove escape chars } else { pattern.erase(std::remove(pattern.begin(), pattern.end(), '\\'), pattern.end()); @@ -95,10 +95,10 @@ Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr* h if (node.descriptor()->name() == "ilike") { regex_op.set_case_sensitive(false); // set case-insensitive for ilike function. 
- return Make(arrow::util::get(literal->holder()), holder, regex_op); + return Make(std::get(literal->holder()), holder, regex_op); } if (node.children().size() == 2) { - return Make(arrow::util::get(literal->holder()), holder); + return Make(std::get(literal->holder()), holder); } else { auto escape_char = dynamic_cast(node.children().at(2).get()); ARROW_RETURN_IF( @@ -110,8 +110,8 @@ Status LikeHolder::Make(const FunctionNode& node, std::shared_ptr* h !IsArrowStringLiteral(escape_char_type), Status::Invalid( "'like' function requires a string literal as the third parameter")); - return Make(arrow::util::get(literal->holder()), - arrow::util::get(escape_char->holder()), holder); + return Make(std::get(literal->holder()), + std::get(escape_char->holder()), holder); } } @@ -180,7 +180,7 @@ Status ReplaceHolder::Make(const FunctionNode& node, Status::Invalid( "'replace' function requires a string literal as the second parameter")); - return Make(arrow::util::get(literal->holder()), holder); + return Make(std::get(literal->holder()), holder); } Status ReplaceHolder::Make(const std::string& sql_pattern, @@ -210,7 +210,7 @@ Status ExtractHolder::Make(const FunctionNode& node, literal == nullptr || !IsArrowStringLiteral(literal->return_type()->id()), Status::Invalid("'extract' function requires a literal as the second parameter")); - return ExtractHolder::Make(arrow::util::get(literal->holder()), holder); + return ExtractHolder::Make(std::get(literal->holder()), holder); } Status ExtractHolder::Make(const std::string& sql_pattern, diff --git a/cpp/src/gandiva/to_date_holder.cc b/cpp/src/gandiva/to_date_holder.cc index 1b7e2864f60..27a16d17799 100644 --- a/cpp/src/gandiva/to_date_holder.cc +++ b/cpp/src/gandiva/to_date_holder.cc @@ -45,7 +45,7 @@ Status ToDateHolder::Make(const FunctionNode& node, return Status::Invalid( "'to_date' function requires a string literal as the second parameter"); } - auto pattern = arrow::util::get(literal_pattern->holder()); + auto pattern = std::get(literal_pattern->holder()); int suppress_errors = 0; if (node.children().size() == 3) { @@ -63,7 +63,7 @@ Status ToDateHolder::Make(const FunctionNode& node, "The (optional) third parameter to 'to_date' function needs to an integer " "literal to indicate whether to suppress the error"); } - suppress_errors = arrow::util::get(literal_suppress_errors->holder()); + suppress_errors = std::get(literal_suppress_errors->holder()); } return Make(pattern, suppress_errors, holder); diff --git a/cpp/src/parquet/arrow/path_internal.cc b/cpp/src/parquet/arrow/path_internal.cc index 8002f13e799..0ef9eea1dab 100644 --- a/cpp/src/parquet/arrow/path_internal.cc +++ b/cpp/src/parquet/arrow/path_internal.cc @@ -89,6 +89,7 @@ #include #include #include +#include #include #include "arrow/array.h" @@ -104,7 +105,6 @@ #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/make_unique.h" -#include "arrow/util/variant.h" #include "arrow/visit_array_inline.h" #include "parquet/properties.h" @@ -519,9 +519,9 @@ struct PathInfo { // The vectors are expected to the same length info. // Note index order matters here. 
- using Node = ::arrow::util::Variant; + using Node = + std::variant; std::vector path; std::shared_ptr primitive_array; @@ -578,32 +578,32 @@ Status WritePath(ElementRange root_range, PathInfo* path_info, while (stack_position >= stack_base) { PathInfo::Node& node = path_info->path[stack_position - stack_base]; struct { - IterationResult operator()(NullableNode* node) { - return node->Run(stack_position, stack_position + 1, context); + IterationResult operator()(NullableNode& node) { + return node.Run(stack_position, stack_position + 1, context); } - IterationResult operator()(ListNode* node) { - return node->Run(stack_position, stack_position + 1, context); + IterationResult operator()(ListNode& node) { + return node.Run(stack_position, stack_position + 1, context); } - IterationResult operator()(NullableTerminalNode* node) { - return node->Run(*stack_position, context); + IterationResult operator()(NullableTerminalNode& node) { + return node.Run(*stack_position, context); } - IterationResult operator()(FixedSizeListNode* node) { - return node->Run(stack_position, stack_position + 1, context); + IterationResult operator()(FixedSizeListNode& node) { + return node.Run(stack_position, stack_position + 1, context); } - IterationResult operator()(AllPresentTerminalNode* node) { - return node->Run(*stack_position, context); + IterationResult operator()(AllPresentTerminalNode& node) { + return node.Run(*stack_position, context); } - IterationResult operator()(AllNullsTerminalNode* node) { - return node->Run(*stack_position, context); + IterationResult operator()(AllNullsTerminalNode& node) { + return node.Run(*stack_position, context); } - IterationResult operator()(LargeListNode* node) { - return node->Run(stack_position, stack_position + 1, context); + IterationResult operator()(LargeListNode& node) { + return node.Run(stack_position, stack_position + 1, context); } ElementRange* stack_position; PathWriteContext* context; } visitor = {stack_position, &context}; - IterationResult result = ::arrow::util::visit(visitor, &node); + IterationResult result = std::visit(visitor, node); if (ARROW_PREDICT_FALSE(result == kError)) { DCHECK(!context.last_status.ok()); @@ -640,39 +640,39 @@ struct FixupVisitor { int16_t rep_level_if_null = kLevelNotSet; template - void HandleListNode(T* arg) { - if (arg->rep_level() == max_rep_level) { - arg->SetLast(); + void HandleListNode(T& arg) { + if (arg.rep_level() == max_rep_level) { + arg.SetLast(); // after the last list node we don't need to fill // rep levels on null. rep_level_if_null = kLevelNotSet; } else { - rep_level_if_null = arg->rep_level(); + rep_level_if_null = arg.rep_level(); } } - void operator()(ListNode* node) { HandleListNode(node); } - void operator()(LargeListNode* node) { HandleListNode(node); } - void operator()(FixedSizeListNode* node) { HandleListNode(node); } + void operator()(ListNode& node) { HandleListNode(node); } + void operator()(LargeListNode& node) { HandleListNode(node); } + void operator()(FixedSizeListNode& node) { HandleListNode(node); } // For non-list intermediate nodes. 
template - void HandleIntermediateNode(T* arg) { + void HandleIntermediateNode(T& arg) { if (rep_level_if_null != kLevelNotSet) { - arg->SetRepLevelIfNull(rep_level_if_null); + arg.SetRepLevelIfNull(rep_level_if_null); } } - void operator()(NullableNode* arg) { HandleIntermediateNode(arg); } + void operator()(NullableNode& arg) { HandleIntermediateNode(arg); } - void operator()(AllNullsTerminalNode* arg) { + void operator()(AllNullsTerminalNode& arg) { // Even though no processing happens past this point we // still need to adjust it if a list occurred after an // all null array. HandleIntermediateNode(arg); } - void operator()(NullableTerminalNode*) {} - void operator()(AllPresentTerminalNode*) {} + void operator()(NullableTerminalNode&) {} + void operator()(AllPresentTerminalNode&) {} }; PathInfo Fixup(PathInfo info) { @@ -687,7 +687,7 @@ PathInfo Fixup(PathInfo info) { visitor.rep_level_if_null = 0; } for (size_t x = 0; x < info.path.size(); x++) { - ::arrow::util::visit(visitor, &info.path[x]); + std::visit(visitor, info.path[x]); } return info; } diff --git a/cpp/src/parquet/encryption/key_metadata.h b/cpp/src/parquet/encryption/key_metadata.h index 2281b96e60e..b6dc349f19b 100644 --- a/cpp/src/parquet/encryption/key_metadata.h +++ b/cpp/src/parquet/encryption/key_metadata.h @@ -18,8 +18,7 @@ #pragma once #include - -#include "arrow/util/variant.h" +#include #include "parquet/encryption/key_material.h" #include "parquet/exception.h" @@ -70,14 +69,14 @@ class PARQUET_EXPORT KeyMetadata { if (!is_internal_storage_) { throw ParquetException("key material is stored externally."); } - return ::arrow::util::get(key_material_or_reference_); + return ::std::get(key_material_or_reference_); } const std::string& key_reference() const { if (is_internal_storage_) { throw ParquetException("key material is stored internally."); } - return ::arrow::util::get(key_material_or_reference_); + return ::std::get(key_material_or_reference_); } private: @@ -87,7 +86,7 @@ class PARQUET_EXPORT KeyMetadata { bool is_internal_storage_; /// If is_internal_storage_ is true, KeyMaterial is set, /// else a string referencing to an outside "key material" is set. 
- ::arrow::util::Variant<KeyMaterial, std::string> key_material_or_reference_; + ::std::variant<KeyMaterial, std::string> key_material_or_reference_; }; } // namespace encryption diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 3b9ac54fe9d..6377459404c 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -577,8 +577,8 @@ cdef extern from "arrow/python/flight.h" namespace "arrow::py::flight" nogil: unique_ptr[CSchemaResult]* out) -cdef extern from "arrow/util/variant.h" namespace "arrow" nogil: - cdef cppclass CIntStringVariant" arrow::util::Variant<int, c_string>": +cdef extern from "<variant>" namespace "std" nogil: + cdef cppclass CIntStringVariant" std::variant<int, c_string>": CIntStringVariant() CIntStringVariant(int) CIntStringVariant(c_string) diff --git a/python/pyarrow/src/gdb.cc b/python/pyarrow/src/gdb.cc index 7541e524609..c681dfe9caa 100644 --- a/python/pyarrow/src/gdb.cc +++ b/python/pyarrow/src/gdb.cc @@ -35,7 +35,6 @@ #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/string_view.h" -#include "arrow/util/variant.h" namespace arrow { @@ -121,13 +120,6 @@ void TestSession() { auto error_result = Result(error_status); auto error_detail_result = Result(error_detail_status); - // Variants - using VariantType = util::Variant<int, bool, std::string>; - - VariantType int_variant{42}; - VariantType bool_variant{false}; - VariantType string_variant{std::string("hello")}; - // String views util::string_view string_view_empty{}; util::string_view string_view_abc{"abc"}; diff --git a/python/pyarrow/tests/test_gdb.py b/python/pyarrow/tests/test_gdb.py index 6b76d9b626e..3056cb4326d 100644 --- a/python/pyarrow/tests/test_gdb.py +++ b/python/pyarrow/tests/test_gdb.py @@ -297,19 +297,6 @@ def test_buffer_heap(gdb_arrow): 'arrow::Buffer of size 3, mutable, "abc"') -def test_variants(gdb_arrow): - check_stack_repr( - gdb_arrow, "int_variant", - "arrow::util::Variant of index 0 (actual type int), value 42") - check_stack_repr( - gdb_arrow, "bool_variant", - "arrow::util::Variant of index 1 (actual type bool), value false") - check_stack_repr( - gdb_arrow, "string_variant", - re.compile(r'^arrow::util::Variant of index 2 \(actual type ' - r'std::.*string.*\), value .*"hello".*')) - - def test_decimals(gdb_arrow): v128 = "98765432109876543210987654321098765432" check_stack_repr(gdb_arrow, "decimal128_zero", "arrow::Decimal128(0)") From c58e6a3550a7eb887e2e7911079a3aa48a1f5bbf Mon Sep 17 00:00:00 2001 From: Jacob Wujciak-Jens Date: Fri, 16 Sep 2022 22:22:29 +0200 Subject: [PATCH 091/133] ARROW-17021: [C++][R][CI] Enable use of sccache in crossbow (#13556) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This is a first stab at enabling sccache; I used the R nightly job as an example, as it compiles arrow on macOS, Windows, and Linux. I want to open this for feedback, specifically with an eye towards the cmake changes I implemented, which enable this as long as the correct envvars (AWS creds + bucket) are set and sccache is available. My hope was to keep the changes required to activate sccache for existing jobs as minimal as possible. While working on this I also noticed that `ThirdpartyToolchain.cmake` is not built using (s)ccache; the flags had to be set explicitly. That fix alone should speed up CI across the board (in arrow and crossbow).
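For local experimentation, here is a minimal sketch of what the "correct envvars" for the S3 backend look like; the variable names come from the `macros.jinja` and `docker-compose.yml` changes in this patch, while the credential and bucket values are placeholders:

```sh
# Placeholders -- substitute real credentials and a real bucket.
export AWS_ACCESS_KEY_ID=...
export AWS_SECRET_ACCESS_KEY=...
export SCCACHE_BUCKET=my-sccache-bucket
export SCCACHE_S3_KEY_PREFIX=sccache   # matches the docker-compose.yml default

# With sccache on PATH and a storage backend configured, the new CMake logic
# sets CMAKE_{C,CXX}_COMPILER_LAUNCHER automatically (ARROW_USE_SCCACHE is ON
# by default). Cache effectiveness can be checked after a build, as
# cpp_build.sh now does:
sccache --show-stats
```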
For the R task this results in a build time reduction of 40-70% 🚀 Here is a run of the same task without caching: https://github.com/ursacomputing/crossbow/actions/runs/2563558583 cc: @kszucs @raulcd @kou Thanks for your input! Authored-by: Jacob Wujciak-Jens Signed-off-by: Sutou Kouhei --- ci/docker/centos-7-cpp.dockerfile | 15 +++-- ci/docker/conda-cpp.dockerfile | 3 + ci/docker/conda.dockerfile | 2 +- ci/docker/debian-10-cpp.dockerfile | 4 ++ ci/docker/debian-11-cpp.dockerfile | 4 ++ ci/docker/fedora-35-cpp.dockerfile | 4 ++ ci/docker/ubuntu-18.04-cpp.dockerfile | 5 ++ ci/docker/ubuntu-20.04-cpp-minimal.dockerfile | 4 ++ ci/docker/ubuntu-20.04-cpp.dockerfile | 4 ++ ci/docker/ubuntu-22.04-cpp-minimal.dockerfile | 4 ++ ci/docker/ubuntu-22.04-cpp.dockerfile | 4 ++ ci/scripts/PKGBUILD | 3 + ci/scripts/cpp_build.sh | 5 ++ ci/scripts/download_tz_database.sh | 0 ci/scripts/install_sccache.sh | 53 +++++++++++++++++ cpp/CMakeLists.txt | 31 +++++++++- cpp/cmake_modules/DefineOptions.cmake | 3 + dev/tasks/docker-tests/github.linux.yml | 2 + dev/tasks/macros.jinja | 13 ++++ dev/tasks/r/github.packages.yml | 26 ++++++-- dev/tasks/tasks.yml | 4 +- docker-compose.yml | 59 +++++++++++-------- r/inst/build_arrow_static.sh | 5 ++ 23 files changed, 219 insertions(+), 38 deletions(-) mode change 100644 => 100755 ci/scripts/download_tz_database.sh create mode 100755 ci/scripts/install_sccache.sh diff --git a/ci/docker/centos-7-cpp.dockerfile b/ci/docker/centos-7-cpp.dockerfile index 09a3234e3f8..2c8f867333c 100644 --- a/ci/docker/centos-7-cpp.dockerfile +++ b/ci/docker/centos-7-cpp.dockerfile @@ -18,6 +18,7 @@ FROM centos:centos7 RUN yum install -y \ + curl \ diffutils \ gcc-c++ \ libcurl-devel \ @@ -31,8 +32,12 @@ ARG cmake=3.23.1 RUN mkdir /opt/cmake-${cmake} RUN wget -nv -O - https://github.com/Kitware/CMake/releases/download/v${cmake}/cmake-${cmake}-Linux-x86_64.tar.gz | \ tar -xzf - --strip-components=1 -C /opt/cmake-${cmake} -ENV PATH=/opt/cmake-${cmake}/bin:$PATH -ENV CC=/usr/bin/gcc -ENV CXX=/usr/bin/g++ -ENV EXTRA_CMAKE_FLAGS="-DCMAKE_C_COMPILER=$CC -DCMAKE_CXX_COMPILER=$CXX" -ENV ARROW_R_DEV=TRUE + +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN bash /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + +ENV PATH=/opt/cmake-${cmake}/bin:$PATH \ + CC=/usr/bin/gcc \ + CXX=/usr/bin/g++ \ + EXTRA_CMAKE_FLAGS="-DCMAKE_C_COMPILER=$CC -DCMAKE_CXX_COMPILER=$CXX" \ + ARROW_R_DEV=TRUE \ \ No newline at end of file diff --git a/ci/docker/conda-cpp.dockerfile b/ci/docker/conda-cpp.dockerfile index 72a839cf57c..688c7a4997c 100644 --- a/ci/docker/conda-cpp.dockerfile +++ b/ci/docker/conda-cpp.dockerfile @@ -38,6 +38,9 @@ RUN mamba install -q -y \ COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts RUN /arrow/ci/scripts/install_gcs_testbench.sh default +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_DEPENDENCY_SOURCE=CONDA \ diff --git a/ci/docker/conda.dockerfile b/ci/docker/conda.dockerfile index d0545e3bf84..af7a2eceab9 100644 --- a/ci/docker/conda.dockerfile +++ b/ci/docker/conda.dockerfile @@ -21,7 +21,7 @@ FROM ${arch}/ubuntu:18.04 # install build essentials RUN export DEBIAN_FRONTEND=noninteractive && \ apt-get update -y -q && \ - apt-get install -y -q wget tzdata libc6-dbg gdb \ + apt-get install -y -q curl wget tzdata libc6-dbg gdb \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* diff --git a/ci/docker/debian-10-cpp.dockerfile
b/ci/docker/debian-10-cpp.dockerfile index a0872928c57..57acc77a9e1 100644 --- a/ci/docker/debian-10-cpp.dockerfile +++ b/ci/docker/debian-10-cpp.dockerfile @@ -40,6 +40,7 @@ RUN apt-get update -y -q && \ ccache \ clang-${llvm} \ cmake \ + curl \ g++ \ gcc \ gdb \ @@ -76,6 +77,9 @@ RUN apt-get update -y -q && \ COPY ci/scripts/install_minio.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_minio.sh latest /usr/local +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV absl_SOURCE=BUNDLED \ ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ diff --git a/ci/docker/debian-11-cpp.dockerfile b/ci/docker/debian-11-cpp.dockerfile index a403df2368f..5051ae7f003 100644 --- a/ci/docker/debian-11-cpp.dockerfile +++ b/ci/docker/debian-11-cpp.dockerfile @@ -37,6 +37,7 @@ RUN apt-get update -y -q && \ ccache \ clang-${llvm} \ cmake \ + curl \ g++ \ gcc \ gdb \ @@ -78,6 +79,9 @@ RUN /arrow/ci/scripts/install_minio.sh latest /usr/local COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_gcs_testbench.sh default +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV absl_SOURCE=BUNDLED \ ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ diff --git a/ci/docker/fedora-35-cpp.dockerfile b/ci/docker/fedora-35-cpp.dockerfile index ce9c8857c85..23844580b1a 100644 --- a/ci/docker/fedora-35-cpp.dockerfile +++ b/ci/docker/fedora-35-cpp.dockerfile @@ -30,6 +30,7 @@ RUN dnf update -y && \ ccache \ clang-devel \ cmake \ + curl \ curl-devel \ flatbuffers-devel \ gcc \ @@ -71,6 +72,9 @@ RUN /arrow/ci/scripts/install_minio.sh latest /usr/local COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_gcs_testbench.sh default +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV absl_SOURCE=BUNDLED \ ARROW_BUILD_TESTS=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ diff --git a/ci/docker/ubuntu-18.04-cpp.dockerfile b/ci/docker/ubuntu-18.04-cpp.dockerfile index 0e20b7c6a83..2a056eda539 100644 --- a/ci/docker/ubuntu-18.04-cpp.dockerfile +++ b/ci/docker/ubuntu-18.04-cpp.dockerfile @@ -60,6 +60,7 @@ RUN apt-get update -y -q && \ ca-certificates \ ccache \ cmake \ + curl \ g++ \ gcc \ gdb \ @@ -100,6 +101,10 @@ RUN apt-get update -y -q && \ # - s3 tests would require boost-asio that is included since Boost 1.66.0 # ARROW-17051: this build uses static Protobuf, so we must also use # static Arrow to run Flight/Flight SQL tests + +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV ARROW_BUILD_STATIC=ON \ ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ diff --git a/ci/docker/ubuntu-20.04-cpp-minimal.dockerfile b/ci/docker/ubuntu-20.04-cpp-minimal.dockerfile index f77ff40e5fb..ca2be2873d6 100644 --- a/ci/docker/ubuntu-20.04-cpp-minimal.dockerfile +++ b/ci/docker/ubuntu-20.04-cpp-minimal.dockerfile @@ -28,6 +28,7 @@ RUN apt-get update -y -q && \ build-essential \ ccache \ cmake \ + curl \ git \ libssl-dev \ libcurl4-openssl-dev \ @@ -70,6 +71,9 @@ RUN /arrow/ci/scripts/install_minio.sh latest /usr/local COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_gcs_testbench.sh default +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV 
ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_FLIGHT=ON \ diff --git a/ci/docker/ubuntu-20.04-cpp.dockerfile b/ci/docker/ubuntu-20.04-cpp.dockerfile index dd36aff84c5..1cd0581aa44 100644 --- a/ci/docker/ubuntu-20.04-cpp.dockerfile +++ b/ci/docker/ubuntu-20.04-cpp.dockerfile @@ -68,6 +68,7 @@ RUN apt-get update -y -q && \ ca-certificates \ ccache \ cmake \ + curl \ g++ \ gcc \ gdb \ @@ -116,6 +117,9 @@ RUN /arrow/ci/scripts/install_gcs_testbench.sh default COPY ci/scripts/install_ceph.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_ceph.sh +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + # Prioritize system packages and local installation # The following dependencies will be downloaded due to missing/invalid packages # provided by the distribution: diff --git a/ci/docker/ubuntu-22.04-cpp-minimal.dockerfile b/ci/docker/ubuntu-22.04-cpp-minimal.dockerfile index 8bc5ab3e484..f0dc76c65f9 100644 --- a/ci/docker/ubuntu-22.04-cpp-minimal.dockerfile +++ b/ci/docker/ubuntu-22.04-cpp-minimal.dockerfile @@ -28,6 +28,7 @@ RUN apt-get update -y -q && \ build-essential \ ccache \ cmake \ + curl \ git \ libssl-dev \ libcurl4-openssl-dev \ @@ -70,6 +71,9 @@ RUN /arrow/ci/scripts/install_minio.sh latest /usr/local COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_gcs_testbench.sh default +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + ENV ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_FLIGHT=ON \ diff --git a/ci/docker/ubuntu-22.04-cpp.dockerfile b/ci/docker/ubuntu-22.04-cpp.dockerfile index 514e314c402..4bbb5c2b317 100644 --- a/ci/docker/ubuntu-22.04-cpp.dockerfile +++ b/ci/docker/ubuntu-22.04-cpp.dockerfile @@ -68,6 +68,7 @@ RUN apt-get update -y -q && \ ca-certificates \ ccache \ cmake \ + curl \ gdb \ git \ libbenchmark-dev \ @@ -143,6 +144,9 @@ RUN /arrow/ci/scripts/install_minio.sh latest /usr/local COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ RUN /arrow/ci/scripts/install_gcs_testbench.sh default +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + # Prioritize system packages and local installation # The following dependencies will be downloaded due to missing/invalid packages # provided by the distribution: diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 81822cc4eb4..72173d040aa 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -73,6 +73,9 @@ build() { # set the appropriate compiler definition. 
export CPPFLAGS="-DUTF8PROC_STATIC" + # CMAKE_UNITY_BUILD is set to OFF as otherwise some compute functionality + # segfaults in tests + MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ ${MINGW_PREFIX}/bin/cmake.exe \ ${ARROW_CPP_DIR} \ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 68437938b5e..5f29381fd1b 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -181,6 +181,11 @@ if [ "${ARROW_USE_CCACHE}" == "ON" ]; then ccache -s fi +if command -v sccache &> /dev/null; then + echo "=== sccache stats after the build ===" + sccache --show-stats +fi + if [ "${BUILD_DOCS_CPP}" == "ON" ]; then pushd ${source_dir}/apidoc doxygen diff --git a/ci/scripts/download_tz_database.sh b/ci/scripts/download_tz_database.sh old mode 100644 new mode 100755 diff --git a/ci/scripts/install_sccache.sh b/ci/scripts/install_sccache.sh new file mode 100755 index 00000000000..e5af790084b --- /dev/null +++ b/ci/scripts/install_sccache.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e + +if [ "$#" -lt 1 -o "$#" -gt 3 ]; then + echo "Usage: $0 " + echo "Will default to arch=x86_64 and version=0.3.0 " + exit 1 +fi + +BUILD=$1 +PREFIX=$2 +ARCH=${3:-x86_64} +VERSION=${4:-0.3.0} + +SCCACHE_URL="https://github.com/mozilla/sccache/releases/download/v$VERSION/sccache-v$VERSION-$ARCH-$BUILD.tar.gz" +SCCACHE_ARCHIVE=sccache.tar.gz + +# Download archive and checksum +curl -L $SCCACHE_URL --output $SCCACHE_ARCHIVE +curl -L $SCCACHE_URL.sha256 --output $SCCACHE_ARCHIVE.sha256 + +echo "$(cat $SCCACHE_ARCHIVE.sha256) $SCCACHE_ARCHIVE" | sha256sum --check --status + +if [ ! -d $PREFIX ]; then + mkdir -p $PREFIX +fi + +tar -xzvf $SCCACHE_ARCHIVE --strip-component=1 --directory $PREFIX --wildcards sccache*/sccache* +chmod u+x $PREFIX/sccache + +if [ "${GITHUB_ACTIONS}" = "true" ]; then + echo "$PREFIX" >> $GITHUB_PATH + # Add executable for windows as mingw workaround. 
+ echo "SCCACHE_PATH=$PREFIX/sccache.exe" >> $GITHUB_ENV +fi diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d67142e0569..a804ee11f46 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -192,10 +192,39 @@ else() set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) endif() +if(ARROW_USE_SCCACHE + AND NOT CMAKE_C_COMPILER_LAUNCHER + AND NOT CMAKE_CXX_COMPILER_LAUNCHER) + + find_program(SCCACHE_FOUND sccache) + + if(NOT SCCACHE_FOUND AND DEFINED ENV{SCCACHE_PATH}) + # cmake has problems finding sccache from within mingw + message(STATUS "Did not find sccache, using envvar fallback.") + set(SCCACHE_FOUND $ENV{SCCACHE_PATH}) + endif() + + # Only use sccache if a storage backend is configured + if(SCCACHE_FOUND + AND (DEFINED ENV{SCCACHE_AZURE_BLOB_CONTAINER} + OR DEFINED ENV{SCCACHE_BUCKET} + OR DEFINED ENV{SCCACHE_DIR} + OR DEFINED ENV{SCCACHE_GCS_BUCKET} + OR DEFINED ENV{SCCACHE_MEMCACHED} + OR DEFINED ENV{SCCACHE_REDIS} + )) + message(STATUS "Using sccache: ${SCCACHE_FOUND}") + set(CMAKE_C_COMPILER_LAUNCHER ${SCCACHE_FOUND}) + set(CMAKE_CXX_COMPILER_LAUNCHER ${SCCACHE_FOUND}) + endif() +endif() + if(ARROW_USE_CCACHE AND NOT CMAKE_C_COMPILER_LAUNCHER AND NOT CMAKE_CXX_COMPILER_LAUNCHER) + find_program(CCACHE_FOUND ccache) + if(CCACHE_FOUND) message(STATUS "Using ccache: ${CCACHE_FOUND}") set(CMAKE_C_COMPILER_LAUNCHER ${CCACHE_FOUND}) @@ -203,7 +232,7 @@ if(ARROW_USE_CCACHE # ARROW-3985: let ccache preserve C++ comments, because some of them may be # meaningful to the compiler set(ENV{CCACHE_COMMENTS} "1") - endif(CCACHE_FOUND) + endif() endif() if(ARROW_USE_PRECOMPILED_HEADERS AND ${CMAKE_VERSION} VERSION_LESS "3.16.0") diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 0dbf4cb843e..ffa1b13c904 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -114,6 +114,9 @@ if(ARROW_DEFINE_OPTIONS) define_option(ARROW_USE_CCACHE "Use ccache when compiling (if available)" ON) + define_option(ARROW_USE_SCCACHE "Use sccache when compiling (if available),;\ +takes precedence over ccache if a storage backend is configured" ON) + define_option(ARROW_USE_LD_GOLD "Use ld.gold for linking on Linux (if available)" OFF) define_option(ARROW_USE_PRECOMPILED_HEADERS "Use precompiled headers when compiling" diff --git a/dev/tasks/docker-tests/github.linux.yml b/dev/tasks/docker-tests/github.linux.yml index 638d846e410..30a6814895f 100644 --- a/dev/tasks/docker-tests/github.linux.yml +++ b/dev/tasks/docker-tests/github.linux.yml @@ -31,6 +31,8 @@ jobs: - name: Execute Docker Build shell: bash + env: + {{ macros.github_set_sccache_envvars()|indent(8) }} run: | archery docker run \ -e SETUPTOOLS_SCM_PRETEND_VERSION="{{ arrow.no_rc_version }}" \ diff --git a/dev/tasks/macros.jinja b/dev/tasks/macros.jinja index fd46555b28b..f68aa15d543 100644 --- a/dev/tasks/macros.jinja +++ b/dev/tasks/macros.jinja @@ -365,3 +365,16 @@ on: {% endfor %} {% endif %} {%- endmacro -%} + +{% macro github_set_sccache_envvars(sccache_key_prefix = "sccache") %} + {% set sccache_vars = { + "AWS_SECRET_ACCESS_KEY": '${{ secrets.AWS_SECRET_ACCESS_KEY }}', + "AWS_ACCESS_KEY_ID": '${{ secrets.AWS_ACCESS_KEY_ID }}', + "SCCACHE_BUCKET": '${{ secrets.SCCACHE_BUCKET }}', + "SCCACHE_S3_KEY_PREFIX": sccache_key_prefix + } + %} + {% for key, value in sccache_vars.items() %} + {{ key }}: "{{ value }}" + {% endfor %} +{% endmacro %} diff --git a/dev/tasks/r/github.packages.yml b/dev/tasks/r/github.packages.yml index 49ca49567a9..e7c3418da91 100644 --- 
a/dev/tasks/r/github.packages.yml +++ b/dev/tasks/r/github.packages.yml @@ -78,10 +78,12 @@ jobs: {{ macros.github_checkout_arrow()|indent }} {{ macros.github_change_r_pkg_version(is_fork, '${{ needs.source.outputs.pkg_version }}')|indent }} {{ macros.github_install_archery()|indent }} + - name: Build libarrow shell: bash env: - UBUNTU: {{ '${{ matrix.config.version}}' }} + UBUNTU: {{ '"${{ matrix.config.version }}"' }} + {{ macros.github_set_sccache_envvars()|indent(8) }} run: | sudo sysctl -w kernel.core_pattern="core.%e.%p" ulimit -c unlimited @@ -110,19 +112,20 @@ jobs: - run: git config --global core.autocrlf false {{ macros.github_checkout_arrow()|indent }} {{ macros.github_change_r_pkg_version(is_fork, '${{ needs.source.outputs.pkg_version }}')|indent }} - - uses: r-lib/actions/setup-r@v2 with: rtools-version: 40 r-version: "4.0" Ncpus: 2 - + - name: Install sccache + shell: bash + run: arrow/ci/scripts/install_sccache.sh pc-windows-msvc $(pwd)/sccache - name: Build Arrow C++ with rtools40 shell: bash env: ARROW_HOME: "arrow" + {{ macros.github_set_sccache_envvars()|indent(8) }} run: arrow/ci/scripts/r_windows_build.sh - - name: Upload binary artifact uses: actions/upload-artifact@v3 with: @@ -174,12 +177,17 @@ jobs: with: working-directory: 'arrow' extra-packages: cpp11 + - name: Install sccache + if: startsWith(matrix.platform, 'macos') + run: brew install sccache - name: Build Binary id: build shell: Rscript {0} env: NOT_CRAN: "true" # actions/setup-r sets this implicitly ARROW_R_DEV: TRUE + # sccache for macos + {{ macros.github_set_sccache_envvars()|indent(8) }} run: | on_windows <- tolower(Sys.info()[["sysname"]]) == "windows" @@ -295,10 +303,18 @@ jobs: with: install-r: false {{ macros.github_setup_local_r_repo(false, false)|indent }} + - name: Install sccache + shell: bash + run: | + curl -s \ + https://raw.githubusercontent.com/{{ arrow.github_repo }}/{{ arrow.head }}/ci/scripts/install_sccache.sh | \ + bash -s unknown-linux-musl /usr/local/bin + - run: sudo apt update && sudo apt install libcurl4-openssl-dev - name: Install arrow from nightly repo env: # Test source build so be sure not to download a binary LIBARROW_BINARY: "FALSE" + {{ macros.github_set_sccache_envvars()|indent(8) }} shell: Rscript {0} run: | {{ macros.github_test_r_src_pkg()|indent(8) }} @@ -308,6 +324,8 @@ jobs: env: LIBARROW_BINARY: "FALSE" ARROW_R_DEV: "TRUE" + {{ macros.github_set_sccache_envvars()|indent(8) }} + shell: Rscript {0} run: | {{ macros.github_test_r_src_pkg()|indent(8) }} diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml index c00d101f43e..aee50ef9ce9 100644 --- a/dev/tasks/tasks.yml +++ b/dev/tasks/tasks.yml @@ -956,7 +956,9 @@ tasks: custom_version: Unset artifacts: - r-lib__libarrow__bin__windows__arrow-[0-9\.]+\.zip - - r-lib__libarrow__bin__centos-7__arrow-[0-9\.]+\.zip + # The centos job is currently disabled due to the + # change to C++ 17 + #- r-lib__libarrow__bin__centos-7__arrow-[0-9\.]+\.zip - r-lib__libarrow__bin__ubuntu-18.04__arrow-[0-9\.]+\.zip - r-lib__libarrow__bin__ubuntu-22.04__arrow-[0-9\.]+\.zip - r-pkg__bin__windows__contrib__4.1__arrow_[0-9\.]+\.zip diff --git a/docker-compose.yml b/docker-compose.yml index 67dfd87512e..a022757939b 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -62,6 +62,12 @@ x-ccache: &ccache CCACHE_MAXSIZE: 1G CCACHE_DIR: /ccache +x-sccache: &sccache + AWS_ACCESS_KEY_ID: + AWS_SECRET_ACCESS_KEY: + SCCACHE_BUCKET: + SCCACHE_S3_KEY_PREFIX: ${SCCACHE_S3_KEY_PREFIX:-sccache} + # CPU/memory limit presets to pass to Docker. 
# # Usage: archery docker run --resource-limit=github @@ -264,7 +270,7 @@ services: shm_size: *shm-size ulimits: *ulimits environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_BUILD_BENCHMARKS: "ON" ARROW_BUILD_EXAMPLES: "ON" ARROW_ENABLE_TIMING_TESTS: # inherit @@ -276,8 +282,9 @@ services: volumes: &conda-volumes - .:/arrow:delegated - ${DOCKER_VOLUME_PREFIX}conda-ccache:/ccache:delegated - command: - ["/arrow/ci/scripts/cpp_build.sh /arrow /build && + command: &conda-cpp-command + [" + /arrow/ci/scripts/cpp_build.sh /arrow /build && /arrow/ci/scripts/cpp_test.sh /arrow /build"] conda-cpp-valgrind: @@ -298,7 +305,7 @@ services: arch: ${ARCH} shm_size: *shm-size environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_CXXFLAGS: "-Og" # Shrink test runtime by enabling minimal optimizations ARROW_ENABLE_TIMING_TESTS: # inherit ARROW_FLIGHT: "OFF" @@ -311,9 +318,7 @@ services: ARROW_USE_LD_GOLD: "ON" BUILD_WARNING_LEVEL: "PRODUCTION" volumes: *conda-volumes - command: - ["/arrow/ci/scripts/cpp_build.sh /arrow /build && - /arrow/ci/scripts/cpp_test.sh /arrow /build"] + command: *conda-cpp-command debian-cpp: # Usage: @@ -334,7 +339,7 @@ services: shm_size: *shm-size ulimits: *ulimits environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_ENABLE_TIMING_TESTS: # inherit ARROW_MIMALLOC: "ON" volumes: &debian-volumes @@ -375,7 +380,7 @@ services: - apparmor:unconfined ulimits: *ulimits environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_ENABLE_TIMING_TESTS: # inherit ARROW_MIMALLOC: "ON" volumes: &ubuntu-volumes @@ -411,14 +416,16 @@ services: - apparmor:unconfined ulimits: *ulimits environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_HOME: /arrow ARROW_DEPENDENCY_SOURCE: BUNDLED LIBARROW_MINIMAL: "false" ARROW_MIMALLOC: "ON" volumes: *ubuntu-volumes - command: /bin/bash -c " - cd /arrow && r/inst/build_arrow_static.sh" + command: &cpp-static-command + /bin/bash -c " + cd /arrow && + r/inst/build_arrow_static.sh" centos-cpp-static: image: ${REPO}:centos-7-cpp-static @@ -431,7 +438,7 @@ services: volumes: - .:/arrow:delegated environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_DEPENDENCY_SOURCE: BUNDLED ARROW_HOME: /arrow LIBARROW_MINIMAL: "false" @@ -439,9 +446,7 @@ services: ARROW_GCS: "OFF" ARROW_MIMALLOC: "OFF" ARROW_S3: "OFF" - command: > - /bin/bash -c " - cd /arrow && r/inst/build_arrow_static.sh" + command: *cpp-static-command ubuntu-cpp-bundled: # Arrow build with BUNDLED dependencies @@ -458,7 +463,7 @@ services: shm_size: *shm-size ulimits: *ulimits environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_DEPENDENCY_SOURCE: BUNDLED CMAKE_GENERATOR: "Unix Makefiles" volumes: *ubuntu-volumes @@ -556,7 +561,7 @@ services: shm_size: *shm-size volumes: *ubuntu-volumes environment: - <<: *ccache + <<: [*ccache, *sccache] CC: clang-${CLANG_TOOLS} CXX: clang++-${CLANG_TOOLS} ARROW_BUILD_STATIC: "OFF" @@ -587,7 +592,7 @@ services: shm_size: *shm-size ulimits: *ulimits environment: - <<: *ccache + <<: [*ccache, *sccache] ARROW_ENABLE_TIMING_TESTS: # inherit ARROW_MIMALLOC: "ON" Protobuf_SOURCE: "BUNDLED" # Need Protobuf >= 3.15 @@ -606,12 +611,16 @@ services: # See https://github.com/conan-io/conan-docker-tools#readme for # available images. 
image: conanio/${CONAN} + user: root:root shm_size: *shm-size ulimits: *ulimits + environment: + <<: *sccache volumes: - .:/arrow:delegated command: >- /bin/bash -c " + /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin && /arrow/ci/scripts/conan_setup.sh && /arrow/ci/scripts/conan_build.sh /arrow /build" @@ -751,10 +760,11 @@ services: python: ${PYTHON} shm_size: *shm-size environment: - <<: *ccache + <<: [*ccache, *sccache] volumes: *conda-volumes command: &python-conda-command - ["/arrow/ci/scripts/cpp_build.sh /arrow /build && + [" + /arrow/ci/scripts/cpp_build.sh /arrow /build && /arrow/ci/scripts/python_build.sh /arrow /build && /arrow/ci/scripts/python_test.sh /arrow"] @@ -1034,16 +1044,13 @@ services: pandas: ${PANDAS} shm_size: *shm-size environment: - <<: *ccache + <<: [*ccache, *sccache] PARQUET_REQUIRE_ENCRYPTION: # inherit PYTEST_ARGS: # inherit HYPOTHESIS_PROFILE: # inherit PYARROW_TEST_HYPOTHESIS: # inherit volumes: *conda-volumes - command: - ["/arrow/ci/scripts/cpp_build.sh /arrow /build && - /arrow/ci/scripts/python_build.sh /arrow /build && - /arrow/ci/scripts/python_test.sh /arrow"] + command: *python-conda-command conda-python-docs: # Usage: diff --git a/r/inst/build_arrow_static.sh b/r/inst/build_arrow_static.sh index 3e6b0546b1c..96df9f7c068 100755 --- a/r/inst/build_arrow_static.sh +++ b/r/inst/build_arrow_static.sh @@ -87,4 +87,9 @@ ${CMAKE} -DARROW_BOOST_USE_SHARED=OFF \ ${CMAKE} --build . --target install +if command -v sccache &> /dev/null; then + echo "=== sccache stats after the build ===" + sccache --show-stats +fi + popd From edf03109ec519dcb1fce815d9d6898cce1ca21a6 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sat, 17 Sep 2022 13:34:19 +0900 Subject: [PATCH 092/133] ARROW-17501: [Python][wheel] Use old AWS SDK C++ (#14157) Because the latest AWS SDK C++ has a problem: https://github.com/aws/aws-sdk-cpp/issues/1809 Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/docker/python-wheel-manylinux-201x.dockerfile | 3 +-- ci/scripts/python_wheel_manylinux_build.sh | 4 ++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ci/docker/python-wheel-manylinux-201x.dockerfile b/ci/docker/python-wheel-manylinux-201x.dockerfile index 4f74b8b1c59..adab10da623 100644 --- a/ci/docker/python-wheel-manylinux-201x.dockerfile +++ b/ci/docker/python-wheel-manylinux-201x.dockerfile @@ -75,8 +75,7 @@ RUN vcpkg install \ --x-feature=flight \ --x-feature=gcs \ --x-feature=json \ - --x-feature=parquet \ - --x-feature=s3 + --x-feature=parquet ARG python=3.8 ENV PYTHON_VERSION=${python} diff --git a/ci/scripts/python_wheel_manylinux_build.sh b/ci/scripts/python_wheel_manylinux_build.sh index 6fe26134fdb..4953da57905 100755 --- a/ci/scripts/python_wheel_manylinux_build.sh +++ b/ci/scripts/python_wheel_manylinux_build.sh @@ -85,6 +85,9 @@ fi mkdir /tmp/arrow-build pushd /tmp/arrow-build +# ARROW-17501: We can remove -DAWSSDK_SOURCE=BUNDLED once +# https://github.com/aws/aws-sdk-cpp/issues/1809 is fixed and vcpkg +# ships the fix. 
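For readers unfamiliar with Arrow's dependency-source switches: AWSSDK_SOURCE=BUNDLED tells the CMake build to compile its own copy of the AWS SDK rather than take the (currently broken) one from vcpkg. The same pin can be exercised outside the wheel scripts; a minimal standalone sketch, with illustrative paths that are not part of this patch:

    # Configure an S3-enabled Arrow build that bundles the AWS SDK itself
    cmake -S cpp -B /tmp/arrow-build -DARROW_S3=ON -DAWSSDK_SOURCE=BUNDLED
    cmake --build /tmp/arrow-build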
cmake \ -DARROW_BROTLI_USE_SHARED=OFF \ -DARROW_BUILD_SHARED=ON \ @@ -117,6 +120,7 @@ cmake \ -DARROW_WITH_SNAPPY=${ARROW_WITH_SNAPPY} \ -DARROW_WITH_ZLIB=${ARROW_WITH_ZLIB} \ -DARROW_WITH_ZSTD=${ARROW_WITH_ZSTD} \ + -DAWSSDK_SOURCE=BUNDLED \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=/tmp/arrow-dist \ From 02b24f72920174d554e61c67bdf33a5ee4764bd6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= Date: Sat, 17 Sep 2022 06:36:39 +0200 Subject: [PATCH 093/133] ARROW-17715: [CI][C++][Python] Temporarily allow failures for s390x on travis (#14138) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The s390x builds currently time out on Travis-CI while compiling Arrow C++. I'll create a new ticket to follow this up, but in the meantime we probably should allow them to fail. Authored-by: Raúl Cumplido Signed-off-by: Sutou Kouhei --- .travis.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.travis.yml b/.travis.yml index 5038f66181a..a15822b4a74 100644 --- a/.travis.yml +++ b/.travis.yml @@ -161,6 +161,8 @@ jobs: allow_failures: - name: "Java on s390x" + - name: "C++ on s390x" + - name: "Python on s390x" before_install: - eval "$(python ci/detect-changes.py)" From 80676df0ed1da98afc89b34b722164fde40d560e Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 18 Sep 2022 09:45:18 +0900 Subject: [PATCH 094/133] ARROW-17764: [CI][C++] "#include <optional>" is missing (#14161) https://github.com/ursacomputing/crossbow/actions/runs/3073442125/jobs/4965573117#step:5:8420 FAILED: src/arrow/flight/sql/CMakeFiles/arrow_flight_sql_objlib.dir/sql_info_internal.cc.o /usr/bin/c++ .../arrow/cpp/src/arrow/flight/sql/sql_info_internal.cc In file included from /arrow/cpp/src/arrow/flight/sql/sql_info_internal.h:20, from /arrow/cpp/src/arrow/flight/sql/sql_info_internal.cc:18: /arrow/cpp/src/arrow/flight/sql/types.h:899:8: error: 'optional' in namespace 'std' does not name a template type 899 | std::optional<std::string> catalog; | ^~~~~~~~ /arrow/cpp/src/arrow/flight/sql/types.h:29:1: note: 'std::optional' is defined in header '<optional>'; did you forget to '#include <optional>'? 28 | #include "arrow/type_fwd.h" +++ |+#include <optional> 29 | /arrow/cpp/src/arrow/flight/sql/types.h:901:8: error: 'optional' in namespace 'std' does not name a template type 901 | std::optional<std::string> db_schema; | ^~~~~~~~ /arrow/cpp/src/arrow/flight/sql/types.h:901:3: note: 'std::optional' is defined in header '<optional>'; did you forget to '#include <optional>'?
901 | std::optional<std::string> db_schema; | ^~~ Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- cpp/src/arrow/flight/sql/types.h | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h index 79c34fa5581..293b1d5579e 100644 --- a/cpp/src/arrow/flight/sql/types.h +++ b/cpp/src/arrow/flight/sql/types.h @@ -19,6 +19,7 @@ #include <cstdint> #include <memory> +#include <optional> #include <string> #include <utility> #include <variant> From 99179603e6db36b65750bf2d1fbc70545baaffeb Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Sun, 18 Sep 2022 14:38:08 +0900 Subject: [PATCH 095/133] ARROW-17560: [Java][Gandiva] Move JNI build configuration from cpp/ to java/ (#14159) Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/docker/fedora-35-cpp.dockerfile | 1 - ci/docker/java-jni-manylinux-201x.dockerfile | 3 +- ci/docker/linux-apt-jni.dockerfile | 1 - ci/scripts/cpp_build.sh | 1 - ci/scripts/java_build.sh | 2 +- ci/scripts/java_jni_build.sh | 1 + ci/scripts/java_jni_macos_build.sh | 3 - ci/scripts/java_jni_manylinux_build.sh | 3 - cpp/build-support/lint_exclusions.txt | 5 +- cpp/cmake_modules/DefineOptions.cmake | 2 - cpp/cmake_modules/ThirdpartyToolchain.cmake | 10 +- cpp/src/arrow/ArrowConfig.cmake.in | 12 +- cpp/src/gandiva/CMakeLists.txt | 8 +- cpp/src/gandiva/GandivaConfig.cmake.in | 15 +++ cpp/src/gandiva/jni/CMakeLists.txt | 109 ------------------ docs/source/developers/java/building.rst | 41 ++----- java/CMakeLists.txt | 4 + java/gandiva/CMakeLists.txt | 107 +++++++++++------ java/gandiva/pom.xml | 2 +- {cpp/src => java}/gandiva/proto/Types.proto | 0 .../gandiva/src/main/cpp}/config_builder.cc | 9 +- .../gandiva/src/main/cpp}/config_holder.cc | 4 +- .../gandiva/src/main/cpp}/config_holder.h | 2 +- .../gandiva/src/main/cpp}/env_helper.h | 0 .../main/cpp}/expression_registry_helper.cc | 10 +- .../gandiva/src/main/cpp}/id_to_module_map.h | 0 .../gandiva/src/main/cpp}/jni_common.cc | 22 ++-- .../gandiva/src/main/cpp}/module_holder.h | 2 +- .../gandiva/src/main/cpp}/symbols.map | 0 java/pom.xml | 98 ++-------------- 30 files changed, 160 insertions(+), 317 deletions(-) delete mode 100644 cpp/src/gandiva/jni/CMakeLists.txt rename {cpp/src => java}/gandiva/proto/Types.proto (100%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/config_builder.cc (90%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/config_holder.cc (96%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/config_holder.h (98%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/env_helper.h (100%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/expression_registry_helper.cc (97%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/id_to_module_map.h (100%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/jni_common.cc (98%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/module_holder.h (98%) rename {cpp/src/gandiva/jni => java/gandiva/src/main/cpp}/symbols.map (100%) diff --git a/ci/docker/fedora-35-cpp.dockerfile b/ci/docker/fedora-35-cpp.dockerfile index 23844580b1a..aeb7c5b7951 100644 --- a/ci/docker/fedora-35-cpp.dockerfile +++ b/ci/docker/fedora-35-cpp.dockerfile @@ -80,7 +80,6 @@ ENV absl_SOURCE=BUNDLED \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_DATASET=ON \ ARROW_FLIGHT=ON \ - ARROW_GANDIVA_JAVA=ON \ ARROW_GANDIVA=ON \ ARROW_GCS=ON \ ARROW_HOME=/usr/local \ diff --git a/ci/docker/java-jni-manylinux-201x.dockerfile b/ci/docker/java-jni-manylinux-201x.dockerfile index c77ec63df74..3fdbfb7088e 100644 ---
+++ b/ci/docker/java-jni-manylinux-201x.dockerfile @@ -38,8 +38,7 @@ RUN yum install -y java-$java-openjdk-devel rh-maven35 && yum clean all ENV JAVA_HOME=/usr/lib/jvm/java-$java-openjdk/ # For ci/scripts/{cpp,java}_*.sh -ENV ARROW_GANDIVA_JAVA=ON \ - ARROW_HOME=/tmp/local \ +ENV ARROW_HOME=/tmp/local \ ARROW_JAVA_CDATA=ON \ ARROW_JNI=ON \ ARROW_PLASMA=ON \ diff --git a/ci/docker/linux-apt-jni.dockerfile b/ci/docker/linux-apt-jni.dockerfile index 92b6cf9a9fc..71826db0487 100644 --- a/ci/docker/linux-apt-jni.dockerfile +++ b/ci/docker/linux-apt-jni.dockerfile @@ -73,7 +73,6 @@ ENV PATH=/opt/cmake-${cmake}-Linux-x86_64/bin:$PATH ENV ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_FLIGHT=OFF \ - ARROW_GANDIVA_JAVA=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=/usr/local \ ARROW_JAVA_CDATA=ON \ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 5f29381fd1b..9aea4af8fb5 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -91,7 +91,6 @@ cmake \ -DARROW_FLIGHT=${ARROW_FLIGHT:-OFF} \ -DARROW_FLIGHT_SQL=${ARROW_FLIGHT_SQL:-OFF} \ -DARROW_FUZZING=${ARROW_FUZZING:-OFF} \ - -DARROW_GANDIVA_JAVA=${ARROW_GANDIVA_JAVA:-OFF} \ -DARROW_GANDIVA_PC_CXX_FLAGS=${ARROW_GANDIVA_PC_CXX_FLAGS:-} \ -DARROW_GANDIVA=${ARROW_GANDIVA:-OFF} \ -DARROW_GCS=${ARROW_GCS:-OFF} \ diff --git a/ci/scripts/java_build.sh b/ci/scripts/java_build.sh index ac252f55b37..220d979b7e6 100755 --- a/ci/scripts/java_build.sh +++ b/ci/scripts/java_build.sh @@ -87,7 +87,7 @@ if [ "${ARROW_JAVA_CDATA}" = "ON" ]; then ${mvn} -Darrow.c.jni.dist.dir=${java_jni_dist_dir} -Parrow-c-data install fi -if [ "${ARROW_GANDIVA_JAVA}" = "ON" ]; then +if [ "${ARROW_JNI}" = "ON" ]; then ${mvn} -Darrow.cpp.build.dir=${java_jni_dist_dir} -Parrow-jni install fi diff --git a/ci/scripts/java_jni_build.sh b/ci/scripts/java_jni_build.sh index c68b52d77ef..881cd47604c 100755 --- a/ci/scripts/java_jni_build.sh +++ b/ci/scripts/java_jni_build.sh @@ -49,6 +49,7 @@ esac : ${CMAKE_BUILD_TYPE:=release} cmake \ -DARROW_JAVA_JNI_ENABLE_DATASET=${ARROW_DATASET:-ON} \ + -DARROW_JAVA_JNI_ENABLE_GANDIVA=${ARROW_GANDIVA:-ON} \ -DBUILD_TESTING=${ARROW_JAVA_BUILD_TESTS} \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_PREFIX_PATH=${arrow_install_dir} \ diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index 342bc2d1188..dd15e79f578 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -33,7 +33,6 @@ install_dir=${build_dir}/cpp-install : ${ARROW_BUILD_TESTS:=ON} : ${ARROW_DATASET:=ON} : ${ARROW_FILESYSTEM:=ON} -: ${ARROW_GANDIVA_JAVA:=ON} : ${ARROW_GANDIVA:=ON} : ${ARROW_ORC:=ON} : ${ARROW_PARQUET:=ON} @@ -65,7 +64,6 @@ cmake \ -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=${ARROW_FILESYSTEM} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ - -DARROW_GANDIVA_JAVA=${ARROW_GANDIVA_JAVA} \ -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \ -DARROW_JNI=ON \ -DARROW_ORC=${ARROW_ORC} \ @@ -117,7 +115,6 @@ fi echo "=== Copying libraries to the distribution folder ===" mkdir -p "${dist_dir}" cp -L ${install_dir}/lib/libarrow_orc_jni.dylib ${dist_dir} -cp -L ${install_dir}/lib/libgandiva_jni.dylib ${dist_dir} cp -L ${build_dir}/cpp/*/libplasma_java.dylib ${dist_dir} echo "=== Checking shared dependencies for libraries ===" diff --git a/ci/scripts/java_jni_manylinux_build.sh b/ci/scripts/java_jni_manylinux_build.sh index 6669c4fdaa6..5aca27e2a1d 100755 --- a/ci/scripts/java_jni_manylinux_build.sh +++ b/ci/scripts/java_jni_manylinux_build.sh @@ -35,7 +35,6 @@ 
devtoolset_include_cpp="/opt/rh/devtoolset-${devtoolset_version}/root/usr/includ : ${ARROW_BUILD_TESTS:=ON} : ${ARROW_DATASET:=ON} : ${ARROW_GANDIVA:=ON} -: ${ARROW_GANDIVA_JAVA:=ON} : ${ARROW_FILESYSTEM:=ON} : ${ARROW_JEMALLOC:=ON} : ${ARROW_RPATH_ORIGIN:=ON} @@ -73,7 +72,6 @@ cmake \ -DARROW_DEPENDENCY_SOURCE="VCPKG" \ -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=${ARROW_FILESYSTEM} \ - -DARROW_GANDIVA_JAVA=${ARROW_GANDIVA_JAVA} \ -DARROW_GANDIVA_PC_CXX_FLAGS=${GANDIVA_CXX_FLAGS} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ -DARROW_JEMALLOC=${ARROW_JEMALLOC} \ @@ -137,7 +135,6 @@ fi echo "=== Copying libraries to the distribution folder ===" cp -L ${ARROW_HOME}/lib/libarrow_orc_jni.so ${dist_dir} -cp -L ${ARROW_HOME}/lib/libgandiva_jni.so ${dist_dir} cp -L ${build_dir}/cpp/*/libplasma_java.so ${dist_dir} echo "=== Checking shared dependencies for libraries ===" diff --git a/cpp/build-support/lint_exclusions.txt b/cpp/build-support/lint_exclusions.txt index 73cbd884f44..195c3dee36a 100644 --- a/cpp/build-support/lint_exclusions.txt +++ b/cpp/build-support/lint_exclusions.txt @@ -1,5 +1,7 @@ -*_generated* *.grpc.fb.* +*.pb.* +*RcppExports.cpp* +*_generated* *arrowExports.cpp* *parquet_constants.* *parquet_types.* @@ -7,7 +9,6 @@ *pyarrow_lib.h *python/config.h *python/platform.h -*RcppExports.cpp* *thirdparty/* *vendored/* *windows_compatibility.h diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index ffa1b13c904..34ee8f3456a 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -470,8 +470,6 @@ Always OFF if building binaries" OFF) #---------------------------------------------------------------------- set_option_category("Gandiva") - define_option(ARROW_GANDIVA_JAVA "Build the Gandiva JNI wrappers" OFF) - # ARROW-3860: Temporary workaround define_option(ARROW_GANDIVA_STATIC_LIBSTDCPP "Include -static-libstdc++ -static-libgcc when linking with;Gandiva static libraries" diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index a45ae61015e..5d1e50ffa25 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -210,11 +210,13 @@ endmacro() # Find modules are needed by the consumer in case of a static build, or if the # linkage is PUBLIC or INTERFACE. 
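The refactoring below is easiest to see from the consumer side: a downstream project that links static Arrow ends up exercising the Find modules that provide_find_module() installs. A sketch, with invented project and file names:

    # Downstream CMakeLists.txt (sketch)
    cmake_minimum_required(VERSION 3.16)
    project(my_app CXX)
    find_package(Arrow REQUIRED)   # loads the installed ArrowConfig.cmake
    add_executable(my_app main.cc)
    # For the static target, ArrowConfig.cmake temporarily points
    # CMAKE_MODULE_PATH at its own install directory so that
    # find_dependency() can locate the installed Find*.cmake files.
    target_link_libraries(my_app PRIVATE Arrow::arrow_static)

Passing the CMake package name through explicitly lets Gandiva install its own modules (such as FindLLVMAlt.cmake) next to GandivaConfig.cmake instead of under Arrow's directory.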
-macro(provide_find_module PACKAGE_NAME) +macro(provide_find_module PACKAGE_NAME ARROW_CMAKE_PACKAGE_NAME) set(module_ "${CMAKE_SOURCE_DIR}/cmake_modules/Find${PACKAGE_NAME}.cmake") if(EXISTS "${module_}") - message(STATUS "Providing CMake module for ${PACKAGE_NAME}") - install(FILES "${module_}" DESTINATION "${ARROW_CMAKE_DIR}/Arrow") + message(STATUS "Providing CMake module for ${PACKAGE_NAME} as part of ${ARROW_CMAKE_PACKAGE_NAME} CMake package" + ) + install(FILES "${module_}" + DESTINATION "${ARROW_CMAKE_DIR}/${ARROW_CMAKE_PACKAGE_NAME}") endif() unset(module_) endmacro() @@ -283,7 +285,7 @@ macro(resolve_dependency DEPENDENCY_NAME) endif() endif() if(${DEPENDENCY_NAME}_SOURCE STREQUAL "SYSTEM" AND ARG_IS_RUNTIME_DEPENDENCY) - provide_find_module(${PACKAGE_NAME}) + provide_find_module(${PACKAGE_NAME} "Arrow") list(APPEND ARROW_SYSTEM_DEPENDENCIES ${PACKAGE_NAME}) find_package(PkgConfig QUIET) foreach(ARG_PC_PACKAGE_NAME ${ARG_PC_PACKAGE_NAMES}) diff --git a/cpp/src/arrow/ArrowConfig.cmake.in b/cpp/src/arrow/ArrowConfig.cmake.in index 8386bcd7280..b2ee43707bf 100644 --- a/cpp/src/arrow/ArrowConfig.cmake.in +++ b/cpp/src/arrow/ArrowConfig.cmake.in @@ -37,8 +37,6 @@ set(ARROW_FULL_SO_VERSION "@ARROW_FULL_SO_VERSION@") set(ARROW_BUNDLED_STATIC_LIBS "@ARROW_BUNDLED_STATIC_LIBS@") set(ARROW_INCLUDE_PATH_SUFFIXES "@ARROW_INCLUDE_PATH_SUFFIXES@") set(ARROW_LIBRARY_PATH_SUFFIXES "@ARROW_LIBRARY_PATH_SUFFIXES@") -set(ARROW_LLVM_VERSIONS "@ARROW_LLVM_VERSIONS@") -set(ARROW_LLVM_VERSION_PRIMARY_MAJOR "@ARROW_LLVM_VERSION_PRIMARY_MAJOR@") set(ARROW_SYSTEM_DEPENDENCIES "@ARROW_SYSTEM_DEPENDENCIES@") include("${CMAKE_CURRENT_LIST_DIR}/ArrowOptions.cmake") @@ -51,7 +49,9 @@ if(ARROW_BUILD_STATIC) find_dependency(Threads) if(DEFINED CMAKE_MODULE_PATH) - set(_CMAKE_MODULE_PATH_OLD ${CMAKE_MODULE_PATH}) + set(ARROW_CMAKE_MODULE_PATH_OLD ${CMAKE_MODULE_PATH}) + else() + unset(ARROW_CMAKE_MODULE_PATH_OLD) endif() set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}") @@ -84,9 +84,9 @@ if(ARROW_BUILD_STATIC) endif() endforeach() - if(DEFINED _CMAKE_MODULE_PATH_OLD) - set(CMAKE_MODULE_PATH ${_CMAKE_MODULE_PATH_OLD}) - unset(_CMAKE_MODULE_PATH_OLD) + if(DEFINED ARROW_CMAKE_MODULE_PATH_OLD) + set(CMAKE_MODULE_PATH ${ARROW_CMAKE_MODULE_PATH_OLD}) + unset(ARROW_CMAKE_MODULE_PATH_OLD) else() unset(CMAKE_MODULE_PATH) endif() diff --git a/cpp/src/gandiva/CMakeLists.txt b/cpp/src/gandiva/CMakeLists.txt index 12e8b301f47..312ab84f65e 100644 --- a/cpp/src/gandiva/CMakeLists.txt +++ b/cpp/src/gandiva/CMakeLists.txt @@ -26,6 +26,7 @@ add_custom_target(gandiva-benchmarks) add_dependencies(gandiva-all gandiva gandiva-tests gandiva-benchmarks) find_package(LLVMAlt REQUIRED) +provide_find_module(LLVMAlt "Gandiva") add_definitions(-DGANDIVA_LLVM_VERSION=${LLVM_VERSION_MAJOR}) @@ -155,7 +156,8 @@ add_arrow_lib(gandiva ${GANDIVA_STATIC_LINK_LIBS} STATIC_INSTALL_INTERFACE_LIBS Arrow::arrow_static - LLVM::LLVM_HEADERS) + LLVM::LLVM_HEADERS + LLVM::LLVM_LIBS) foreach(LIB_TARGET ${GANDIVA_LIBRARIES}) target_compile_definitions(${LIB_TARGET} PRIVATE GANDIVA_EXPORTING) @@ -246,9 +248,5 @@ add_gandiva_test(internals-test gdv_function_stubs_test.cc interval_holder_test.cc) -if(ARROW_GANDIVA_JAVA) - add_subdirectory(jni) -endif() - add_subdirectory(precompiled) add_subdirectory(tests) diff --git a/cpp/src/gandiva/GandivaConfig.cmake.in b/cpp/src/gandiva/GandivaConfig.cmake.in index c6d7cef73d7..20cdc75acb6 100644 --- a/cpp/src/gandiva/GandivaConfig.cmake.in +++ b/cpp/src/gandiva/GandivaConfig.cmake.in @@ -26,9 +26,24 @@ @PACKAGE_INIT@ 
+set(ARROW_LLVM_VERSIONS "@ARROW_LLVM_VERSIONS@") +set(ARROW_LLVM_VERSION_PRIMARY_MAJOR "@ARROW_LLVM_VERSION_PRIMARY_MAJOR@") + include(CMakeFindDependencyMacro) find_dependency(Arrow) +if(DEFINED CMAKE_MODULE_PATH) + set(GANDIVA_CMAKE_MODULE_PATH_OLD ${CMAKE_MODULE_PATH}) +else() + unset(GANDIVA_CMAKE_MODULE_PATH_OLD) +endif() +set(CMAKE_MODULE_PATH "${CMAKE_CURRENT_LIST_DIR}") find_dependency(LLVMAlt) +if(DEFINED GANDIVA_CMAKE_MODULE_PATH_OLD) + set(CMAKE_MODULE_PATH ${GANDIVA_CMAKE_MODULE_PATH_OLD}) + unset(GANDIVA_CMAKE_MODULE_PATH_OLD) +else() + unset(CMAKE_MODULE_PATH) +endif() include("${CMAKE_CURRENT_LIST_DIR}/GandivaTargets.cmake") diff --git a/cpp/src/gandiva/jni/CMakeLists.txt b/cpp/src/gandiva/jni/CMakeLists.txt deleted file mode 100644 index b89356121dc..00000000000 --- a/cpp/src/gandiva/jni/CMakeLists.txt +++ /dev/null @@ -1,109 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -if(CMAKE_VERSION VERSION_LESS 3.11) - message(FATAL_ERROR "Building the Gandiva JNI bindings requires CMake version >= 3.11") -endif() - -if(MSVC) - add_definitions(-DPROTOBUF_USE_DLLS) -endif() - -# Find JNI -find_package(JNI REQUIRED) - -set(PROTO_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) -set(PROTO_OUTPUT_FILES "${PROTO_OUTPUT_DIR}/Types.pb.cc") -set(PROTO_OUTPUT_FILES ${PROTO_OUTPUT_FILES} "${PROTO_OUTPUT_DIR}/Types.pb.h") - -set_source_files_properties(${PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE) - -get_filename_component(ABS_GANDIVA_PROTO - ${CMAKE_SOURCE_DIR}/src/gandiva/proto/Types.proto ABSOLUTE) - -add_custom_command(OUTPUT ${PROTO_OUTPUT_FILES} - COMMAND ${ARROW_PROTOBUF_PROTOC} --proto_path - ${CMAKE_SOURCE_DIR}/src/gandiva/proto --cpp_out - ${PROTO_OUTPUT_DIR} - ${CMAKE_SOURCE_DIR}/src/gandiva/proto/Types.proto - DEPENDS ${ABS_GANDIVA_PROTO} ${ARROW_PROTOBUF_LIBPROTOBUF} - COMMENT "Running PROTO compiler on Types.proto" - VERBATIM) - -add_custom_target(gandiva_jni_proto ALL DEPENDS ${PROTO_OUTPUT_FILES}) -set(PROTO_SRCS "${PROTO_OUTPUT_DIR}/Types.pb.cc") -set(PROTO_HDRS "${PROTO_OUTPUT_DIR}/Types.pb.h") - -# Create the jni header file (from the java class). 
-set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/java") -add_subdirectory(../../../../java/gandiva ./java/gandiva) - -set(GANDIVA_LINK_LIBS ${ARROW_PROTOBUF_LIBPROTOBUF}) -if(ARROW_BUILD_STATIC) - list(APPEND GANDIVA_LINK_LIBS gandiva_static) -else() - list(APPEND GANDIVA_LINK_LIBS gandiva_shared) -endif() - -set(GANDIVA_JNI_SOURCES - config_builder.cc - config_holder.cc - expression_registry_helper.cc - jni_common.cc - ${PROTO_SRCS}) - -# For users of gandiva_jni library (including integ tests), include-dir is : -# /usr/**/include dir after install, -# cpp/include during build -# For building gandiva_jni library itself, include-dir (in addition to above) is : -# cpp/src -add_arrow_lib(gandiva_jni - SOURCES - ${GANDIVA_JNI_SOURCES} - OUTPUTS - GANDIVA_JNI_LIBRARIES - BUILD_SHARED - ON - BUILD_STATIC - OFF - SHARED_LINK_LIBS - ${GANDIVA_LINK_LIBS} - DEPENDENCIES - ${GANDIVA_LINK_LIBS} - gandiva_java - gandiva_jni_headers - gandiva_jni_proto - EXTRA_INCLUDES - $ - $ - $ - PRIVATE_INCLUDES - ${JNI_INCLUDE_DIRS} - ${CMAKE_CURRENT_BINARY_DIR}) - -add_dependencies(gandiva ${GANDIVA_JNI_LIBRARIES}) - -if(ARROW_BUILD_SHARED) - # filter out everything that is not needed for the jni bridge - # statically linked stdc++ has conflicts with stdc++ loaded by other libraries. - if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT) - set_target_properties(gandiva_jni_shared - PROPERTIES LINK_FLAGS - "-Wl,--version-script=${CMAKE_SOURCE_DIR}/src/gandiva/jni/symbols.map" - ) - endif() -endif() diff --git a/docs/source/developers/java/building.rst b/docs/source/developers/java/building.rst index 42faadb8d21..846277faa34 100644 --- a/docs/source/developers/java/building.rst +++ b/docs/source/developers/java/building.rst @@ -111,28 +111,19 @@ Maven $ ls -latr ../java-dist/lib |__ libarrow_cdata_jni.dylib -- To build JNI ORC & JNI Gandiva libraries: +- To build all JNI libraries except the JNI C Data Interface library: .. code-block:: $ cd arrow/java $ export JAVA_HOME= $ java --version - $ mvn clean generate-resources -Pgenerate-jnicpp-dylib_so -N - $ ls -latr java-dist/lib - |__ libarrow_orc_jni.dylib - |__ libgandiva_jni.dylib - -- To build only the JNI Dataset library: - - .. code-block:: - - $ cd arrow/java - $ export JAVA_HOME= - $ java --version - $ mvn clean generate-resources -Pgenerate-dataset-dylib_so -N + $ mvn clean generate-resources -Pgenerate-jni-dylib_so -N $ ls -latr java-dist/lib |__ libarrow_dataset_jni.dylib + |__ libarrow_orc_jni.dylib + |__ libgandiva_jni.dylib + |__ libplasma_java.dylib CMake ~~~~~ @@ -142,7 +133,7 @@ CMake .. code-block:: $ cd arrow - $ mkdir -p java-dist java-jni + $ mkdir -p java-dist java-cdata $ cmake \ -S java \ -B java-jni \ @@ -155,7 +146,7 @@ CMake $ ls -latr java-dist/lib |__ libarrow_cdata_jni.dylib -- To build JNI ORC & Gandiva libraries: +- To build all JNI libraries except the JNI C Data Interface library: .. code-block:: @@ -176,7 +167,6 @@ CMake -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=ON \ -DARROW_GANDIVA=ON \ - -DARROW_GANDIVA_JAVA=ON \ -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \ -DARROW_JNI=ON \ -DARROW_ORC=ON \ @@ -190,21 +180,11 @@ CMake -DCMAKE_INSTALL_PREFIX=java-dist \ -DCMAKE_UNITY_BUILD=ON $ cmake --build cpp-jni --target install --config Release - $ ls -latr java-dist/lib - |__ libarrow_orc_jni.dylib - |__ libgandiva_jni.dylib - -- To build only the JNI Dataset library: - - .. 
code-block:: - - $ cd arrow - $ mkdir -p java-dist java-jni $ cmake \ -S java \ -B java-jni \ - -DARROW_JAVA_JNI_ENABLE_DATASET=ON \ - -DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF \ + -DARROW_JAVA_JNI_ENABLE_C=OFF \ + -DARROW_JAVA_JNI_ENABLE_DEFAULT=ON \ -DBUILD_TESTING=OFF \ -DCMAKE_BUILD_TYPE=Release \ -DCMAKE_INSTALL_PREFIX=java-dist/lib \ @@ -212,6 +192,9 @@ CMake $ cmake --build java-jni --target install --config Release $ ls -latr java-dist/lib |__ libarrow_dataset_jni.dylib + |__ libarrow_orc_jni.dylib + |__ libgandiva_jni.dylib + |__ libplasma_java.dylib Archery ~~~~~~~ diff --git a/java/CMakeLists.txt b/java/CMakeLists.txt index 4778f030c25..69e3f3940d5 100644 --- a/java/CMakeLists.txt +++ b/java/CMakeLists.txt @@ -29,6 +29,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON) option(ARROW_JAVA_JNI_ENABLE_DEFAULT "Whether enable components by default or not" ON) option(ARROW_JAVA_JNI_ENABLE_C "Enable C data interface" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) option(ARROW_JAVA_JNI_ENABLE_DATASET "Enable dataset" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) +option(ARROW_JAVA_JNI_ENABLE_GANDIVA "Enable Gandiva" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) # ccache option(ARROW_JAVA_JNI_USE_CCACHE "Use ccache when compiling (if available)" ON) @@ -70,3 +71,6 @@ endif() if(ARROW_JAVA_JNI_ENABLE_DATASET) add_subdirectory(dataset) endif() +if(ARROW_JAVA_JNI_ENABLE_GANDIVA) + add_subdirectory(gandiva) +endif() diff --git a/java/gandiva/CMakeLists.txt b/java/gandiva/CMakeLists.txt index 5010daf7996..805253da922 100644 --- a/java/gandiva/CMakeLists.txt +++ b/java/gandiva/CMakeLists.txt @@ -15,41 +15,76 @@ # specific language governing permissions and limitations # under the License. -project(gandiva_java) - -# Find java/jni -include(FindJava) -include(UseJava) -include(FindJNI) - -message("generating headers to ${JNI_HEADERS_DIR}/jni") - -# generate_native_headers is available only from java8 -# centos5 does not have java8 images, so supporting java 7 too. -# unfortunately create_javah does not work in java8 correctly. 
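The replacement side of this hunk standardizes on the GENERATE_NATIVE_HEADERS option of UseJava's add_jar(), i.e. the javac -h flow available since Java 8, instead of the legacy create_javah(). In isolation the pattern looks like this; a sketch with invented target and file names:

    find_package(Java REQUIRED)
    find_package(JNI REQUIRED)
    include(UseJava)
    add_jar(my_jni_jar Foo.java
            GENERATE_NATIVE_HEADERS my_jni_headers)  # emits Foo's JNI header
    add_library(my_native_lib SHARED foo_jni.cc)
    # the generated INTERFACE target carries the header's include directory
    target_link_libraries(my_native_lib PRIVATE my_jni_headers ${JNI_LIBRARIES})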
-if(ARROW_GANDIVA_JAVA7) - add_jar(gandiva_java - src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java - src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java - src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java - src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java) - - create_javah(TARGET gandiva_jni_headers - CLASSES org.apache.arrow.gandiva.evaluator.ConfigurationBuilder - org.apache.arrow.gandiva.evaluator.JniWrapper - org.apache.arrow.gandiva.evaluator.ExpressionRegistryJniHelper - org.apache.arrow.gandiva.exceptions.GandivaException - DEPENDS gandiva_java - CLASSPATH gandiva_java - OUTPUT_DIR ${JNI_HEADERS_DIR}/jni) +find_package(Gandiva REQUIRED) + +include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR} + ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR}) + +add_jar(arrow_java_jni_gandiva_jar + src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java + src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java + src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java + src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java + GENERATE_NATIVE_HEADERS + arrow_java_jni_gandiva_headers) + +set(GANDIVA_PROTO_OUTPUT_DIR ${CMAKE_CURRENT_BINARY_DIR}) +set(GANDIVA_PROTO_OUTPUT_FILES "${GANDIVA_PROTO_OUTPUT_DIR}/Types.pb.cc" + "${GANDIVA_PROTO_OUTPUT_DIR}/Types.pb.h") + +set_source_files_properties(${GANDIVA_PROTO_OUTPUT_FILES} PROPERTIES GENERATED TRUE) + +set(GANDIVA_PROTO_DIR ${CMAKE_CURRENT_SOURCE_DIR}/proto) +get_filename_component(GANDIVA_PROTO_FILE_ABSOLUTE ${GANDIVA_PROTO_DIR}/Types.proto + ABSOLUTE) + +find_package(Protobuf REQUIRED) +if(MSVC) + add_definitions(-DPROTOBUF_USE_DLLS) +endif() +add_custom_command(OUTPUT ${GANDIVA_PROTO_OUTPUT_FILES} + COMMAND protobuf::protoc --proto_path ${GANDIVA_PROTO_DIR} --cpp_out + ${GANDIVA_PROTO_OUTPUT_DIR} ${GANDIVA_PROTO_FILE_ABSOLUTE} + DEPENDS ${GANDIVA_PROTO_FILE_ABSOLUTE} + COMMENT "Running Protobuf compiler on Types.proto" + VERBATIM) + +add_custom_target(garrow_java_jni_gandiva_proto ALL DEPENDS ${GANDIVA_PROTO_OUTPUT_FILES}) +add_library(arrow_java_jni_gandiva SHARED + src/main/cpp/config_builder.cc + src/main/cpp/config_holder.cc + src/main/cpp/expression_registry_helper.cc + src/main/cpp/jni_common.cc + ${GANDIVA_PROTO_OUTPUT_FILES}) +set_property(TARGET arrow_java_jni_gandiva PROPERTY OUTPUT_NAME "gandiva_jni") +target_link_libraries(arrow_java_jni_gandiva + arrow_java_jni_gandiva_headers + jni + protobuf::libprotobuf + Gandiva::gandiva_static) + +# Localize thirdparty symbols using a linker version script. This hides them +# from the client application. The OS X linker does not support the +# version-script option. 
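For reference, a linker version script of the kind this comment refers to has the following general shape; a generic sketch, not the verbatim contents of src/main/cpp/symbols.map:

    {
      /* keep only the JNI entry points visible */
      global:
        Java_*;
        JNI_OnLoad;
        JNI_OnUnload;
      /* hide everything else, including statically linked thirdparty */
      local:
        *;
    };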
+if(CMAKE_VERSION VERSION_LESS 3.18) + if(APPLE OR WIN32) + set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT FALSE) + else() + set(CXX_LINKER_SUPPORTS_VERSION_SCRIPT TRUE) + endif() else() - add_jar(gandiva_java - src/main/java/org/apache/arrow/gandiva/evaluator/ConfigurationBuilder.java - src/main/java/org/apache/arrow/gandiva/evaluator/JniWrapper.java - src/main/java/org/apache/arrow/gandiva/evaluator/ExpressionRegistryJniHelper.java - src/main/java/org/apache/arrow/gandiva/exceptions/GandivaException.java - GENERATE_NATIVE_HEADERS - gandiva_jni_headers - DESTINATION - ${JNI_HEADERS_DIR}/jni) + include(CheckLinkerFlag) + check_linker_flag(CXX + "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/src/main/cpp/symbols.map" + CXX_LINKER_SUPPORTS_VERSION_SCRIPT) endif() +# filter out everything that is not needed for the jni bridge +# statically linked stdc++ has conflicts with stdc++ loaded by other libraries. +if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT) + set_target_properties(arrow_java_jni_gandiva + PROPERTIES LINK_FLAGS + "-Wl,--version-script=${CMAKE_CURRENT_SOURCE_DIR}/src/main/cpp/symbols.map" + ) +endif() + +install(TARGETS arrow_java_jni_gandiva DESTINATION ${CMAKE_INSTALL_PREFIX}) diff --git a/java/gandiva/pom.xml b/java/gandiva/pom.xml index 3c6ba7b12c5..60c9dc8f35a 100644 --- a/java/gandiva/pom.xml +++ b/java/gandiva/pom.xml @@ -136,7 +136,7 @@ com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} - ../../cpp/src/gandiva/proto + proto diff --git a/cpp/src/gandiva/proto/Types.proto b/java/gandiva/proto/Types.proto similarity index 100% rename from cpp/src/gandiva/proto/Types.proto rename to java/gandiva/proto/Types.proto diff --git a/cpp/src/gandiva/jni/config_builder.cc b/java/gandiva/src/main/cpp/config_builder.cc similarity index 90% rename from cpp/src/gandiva/jni/config_builder.cc rename to java/gandiva/src/main/cpp/config_builder.cc index b115210cefe..85c661ee943 100644 --- a/cpp/src/gandiva/jni/config_builder.cc +++ b/java/gandiva/src/main/cpp/config_builder.cc @@ -17,10 +17,11 @@ #include -#include "gandiva/configuration.h" -#include "gandiva/jni/config_holder.h" -#include "gandiva/jni/env_helper.h" -#include "jni/org_apache_arrow_gandiva_evaluator_ConfigurationBuilder.h" +#include <gandiva/configuration.h> + +#include "config_holder.h" +#include "env_helper.h" +#include "org_apache_arrow_gandiva_evaluator_ConfigurationBuilder.h" using gandiva::ConfigHolder; using gandiva::Configuration; diff --git a/cpp/src/gandiva/jni/config_holder.cc b/java/gandiva/src/main/cpp/config_holder.cc similarity index 96% rename from cpp/src/gandiva/jni/config_holder.cc rename to java/gandiva/src/main/cpp/config_holder.cc index 11d305c819c..dfa6afce199 100644 --- a/cpp/src/gandiva/jni/config_holder.cc +++ b/java/gandiva/src/main/cpp/config_holder.cc @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. 
-#include "gandiva/jni/config_holder.h" - #include +#include "config_holder.h" + namespace gandiva { int64_t ConfigHolder::config_id_ = 1; diff --git a/cpp/src/gandiva/jni/config_holder.h b/java/gandiva/src/main/cpp/config_holder.h similarity index 98% rename from cpp/src/gandiva/jni/config_holder.h rename to java/gandiva/src/main/cpp/config_holder.h index 3fdb7a01d55..ae031495ab2 100644 --- a/cpp/src/gandiva/jni/config_holder.h +++ b/java/gandiva/src/main/cpp/config_holder.h @@ -22,7 +22,7 @@ #include #include -#include "gandiva/configuration.h" +#include namespace gandiva { diff --git a/cpp/src/gandiva/jni/env_helper.h b/java/gandiva/src/main/cpp/env_helper.h similarity index 100% rename from cpp/src/gandiva/jni/env_helper.h rename to java/gandiva/src/main/cpp/env_helper.h diff --git a/cpp/src/gandiva/jni/expression_registry_helper.cc b/java/gandiva/src/main/cpp/expression_registry_helper.cc similarity index 97% rename from cpp/src/gandiva/jni/expression_registry_helper.cc rename to java/gandiva/src/main/cpp/expression_registry_helper.cc index 338290618d8..6765df3b972 100644 --- a/cpp/src/gandiva/jni/expression_registry_helper.cc +++ b/java/gandiva/src/main/cpp/expression_registry_helper.cc @@ -15,14 +15,14 @@ // specific language governing permissions and limitations // under the License. -#include "jni/org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper.h" - #include +#include +#include +#include + #include "Types.pb.h" -#include "arrow/util/logging.h" -#include "gandiva/arrow.h" -#include "gandiva/expression_registry.h" +#include "org_apache_arrow_gandiva_evaluator_ExpressionRegistryJniHelper.h" using gandiva::DataTypePtr; using gandiva::ExpressionRegistry; diff --git a/cpp/src/gandiva/jni/id_to_module_map.h b/java/gandiva/src/main/cpp/id_to_module_map.h similarity index 100% rename from cpp/src/gandiva/jni/id_to_module_map.h rename to java/gandiva/src/main/cpp/id_to_module_map.h diff --git a/cpp/src/gandiva/jni/jni_common.cc b/java/gandiva/src/main/cpp/jni_common.cc similarity index 98% rename from cpp/src/gandiva/jni/jni_common.cc rename to java/gandiva/src/main/cpp/jni_common.cc index 3940b9c81a9..ba0af1106b1 100644 --- a/cpp/src/gandiva/jni/jni_common.cc +++ b/java/gandiva/src/main/cpp/jni_common.cc @@ -28,19 +28,19 @@ #include #include #include +#include +#include +#include +#include +#include +#include #include "Types.pb.h" -#include "gandiva/configuration.h" -#include "gandiva/decimal_scalar.h" -#include "gandiva/filter.h" -#include "gandiva/jni/config_holder.h" -#include "gandiva/jni/env_helper.h" -#include "gandiva/jni/id_to_module_map.h" -#include "gandiva/jni/module_holder.h" -#include "gandiva/projector.h" -#include "gandiva/selection_vector.h" -#include "gandiva/tree_expr_builder.h" -#include "jni/org_apache_arrow_gandiva_evaluator_JniWrapper.h" +#include "config_holder.h" +#include "env_helper.h" +#include "id_to_module_map.h" +#include "module_holder.h" +#include "org_apache_arrow_gandiva_evaluator_JniWrapper.h" using gandiva::ConditionPtr; using gandiva::DataTypePtr; diff --git a/cpp/src/gandiva/jni/module_holder.h b/java/gandiva/src/main/cpp/module_holder.h similarity index 98% rename from cpp/src/gandiva/jni/module_holder.h rename to java/gandiva/src/main/cpp/module_holder.h index 929c64231f2..74bad29e68c 100644 --- a/cpp/src/gandiva/jni/module_holder.h +++ b/java/gandiva/src/main/cpp/module_holder.h @@ -20,7 +20,7 @@ #include #include -#include "gandiva/arrow.h" +#include namespace gandiva { diff --git a/cpp/src/gandiva/jni/symbols.map 
b/java/gandiva/src/main/cpp/symbols.map similarity index 100% rename from cpp/src/gandiva/jni/symbols.map rename to java/gandiva/src/main/cpp/symbols.map diff --git a/java/pom.xml b/java/pom.xml index ea5c46334a3..b5db145cc5c 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -930,7 +930,7 @@ 3.1.0 - cdatadir + cdata-dir generate-resources exec @@ -942,7 +942,7 @@ - cdatadefine + cdata-cmake generate-resources exec @@ -962,7 +962,7 @@ - cdatabuild + cdata-build generate-resources exec @@ -982,7 +982,7 @@ - generate-dataset-dylib_so + generate-jni-dylib_so java-dist/lib false @@ -995,7 +995,7 @@ 3.1.0 - datasetdir + jni-dir generate-resources exec @@ -1007,7 +1007,7 @@ - datasetarrowdependency + jni-cpp-cmake generate-resources exec @@ -1023,7 +1023,6 @@ -DARROW_DEPENDENCY_USE_SHARED=OFF -DARROW_FILESYSTEM=ON -DARROW_GANDIVA=ON - -DARROW_GANDIVA_JAVA=ON -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON -DARROW_JNI=ON -DARROW_ORC=ON @@ -1042,7 +1041,7 @@ - datasetarrowdependencybuild + jni-cpp-build generate-resources exec @@ -1057,7 +1056,7 @@ - datasetdefine + jni-cmake generate-resources exec @@ -1067,8 +1066,8 @@ -S java -B java-jni - -DARROW_JAVA_JNI_ENABLE_DATASET=ON - -DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF + -DARROW_JAVA_JNI_ENABLE_C=OFF + -DARROW_JAVA_JNI_ENABLE_DEFAULT=ON -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=${arrow.c.jni.dist.dir} @@ -1078,7 +1077,7 @@ - datasetbuild + jni-build generate-resources exec @@ -1097,81 +1096,6 @@ - - generate-jnicpp-dylib_so - - - - org.codehaus.mojo - exec-maven-plugin - 3.1.0 - - - jnidir - generate-resources - - exec - - - mkdir - -p java-dist cpp-jni - ../ - - - - jnidefine - generate-resources - - exec - - - cmake - - -S cpp - -B cpp-jni - -DARROW_CSV=ON - -DARROW_DATASET=ON - -DARROW_DEPENDENCY_SOURCE=BUNDLED - -DARROW_DEPENDENCY_USE_SHARED=OFF - -DARROW_FILESYSTEM=ON - -DARROW_GANDIVA=ON - -DARROW_GANDIVA_JAVA=ON - -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON - -DARROW_JNI=ON - -DARROW_ORC=ON - -DARROW_PARQUET=ON - -DARROW_PLASMA=ON - -DARROW_PLASMA_JAVA_CLIENT=ON - -DARROW_S3=ON - -DARROW_USE_CCACHE=ON - -DCMAKE_BUILD_TYPE=Release - -DCMAKE_INSTALL_LIBDIR=lib - -DCMAKE_INSTALL_PREFIX=java-dist - -DCMAKE_UNITY_BUILD=ON - - ../ - - - - jnidefinebuild - generate-resources - - exec - - - cmake - - --build cpp-jni --target install --config Release - - ../ - - - - - - - - From a866f2ff820786372ec2e1300499ef33918d17ef Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Mon, 19 Sep 2022 06:09:48 +0900 Subject: [PATCH 096/133] ARROW-17561: [Java][ORC] Move JNI build configuration from cpp/ to java/ (#14162) Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/scripts/java_jni_build.sh | 9 +++- ci/scripts/java_jni_macos_build.sh | 1 - ci/scripts/java_jni_manylinux_build.sh | 1 - cpp/CMakeLists.txt | 6 --- java/CMakeLists.txt | 6 +++ java/adapter/orc/CMakeLists.txt | 27 +++++----- java/adapter/orc/src/main/cpp/CMakeLists.txt | 53 -------------------- java/c/CMakeLists.txt | 2 +- java/dataset/CMakeLists.txt | 2 +- java/gandiva/CMakeLists.txt | 2 +- 10 files changed, 29 insertions(+), 80 deletions(-) delete mode 100644 java/adapter/orc/src/main/cpp/CMakeLists.txt diff --git a/ci/scripts/java_jni_build.sh b/ci/scripts/java_jni_build.sh index 881cd47604c..999506d65cf 100755 --- a/ci/scripts/java_jni_build.sh +++ b/ci/scripts/java_jni_build.sh @@ -25,6 +25,8 @@ build_dir=${3}/java_jni # The directory where the final binaries will be stored when scripts finish dist_dir=${4} +prefix_dir="${build_dir}/java-jni" + echo "=== Clear output directories and 
leftovers ===" # Clear output directories and leftovers rm -rf ${build_dir} @@ -50,10 +52,12 @@ esac cmake \ -DARROW_JAVA_JNI_ENABLE_DATASET=${ARROW_DATASET:-ON} \ -DARROW_JAVA_JNI_ENABLE_GANDIVA=${ARROW_GANDIVA:-ON} \ + -DARROW_JAVA_JNI_ENABLE_ORC=${ARROW_ORC:-ON} \ -DBUILD_TESTING=${ARROW_JAVA_BUILD_TESTS} \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_PREFIX_PATH=${arrow_install_dir} \ - -DCMAKE_INSTALL_PREFIX=${dist_dir} \ + -DCMAKE_INSTALL_PREFIX=${prefix_dir} \ + -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD:-OFF} \ -GNinja \ ${JAVA_JNI_CMAKE_ARGS:-} \ @@ -68,3 +72,6 @@ if [ "${ARROW_JAVA_BUILD_TESTS}" = "ON" ]; then fi cmake --build . --config ${CMAKE_BUILD_TYPE} --target install popd + +mkdir -p ${dist_dir} +mv ${prefix_dir}/lib/* ${dist_dir}/ diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index dd15e79f578..51740d2537b 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -114,7 +114,6 @@ fi echo "=== Copying libraries to the distribution folder ===" mkdir -p "${dist_dir}" -cp -L ${install_dir}/lib/libarrow_orc_jni.dylib ${dist_dir} cp -L ${build_dir}/cpp/*/libplasma_java.dylib ${dist_dir} echo "=== Checking shared dependencies for libraries ===" diff --git a/ci/scripts/java_jni_manylinux_build.sh b/ci/scripts/java_jni_manylinux_build.sh index 5aca27e2a1d..2048ecf04a7 100755 --- a/ci/scripts/java_jni_manylinux_build.sh +++ b/ci/scripts/java_jni_manylinux_build.sh @@ -134,7 +134,6 @@ fi echo "=== Copying libraries to the distribution folder ===" -cp -L ${ARROW_HOME}/lib/libarrow_orc_jni.so ${dist_dir} cp -L ${build_dir}/cpp/*/libplasma_java.so ${dist_dir} echo "=== Checking shared dependencies for libraries ===" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a804ee11f46..a9bed5e5896 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1023,12 +1023,6 @@ if(ARROW_PARQUET) endif() endif() -if(ARROW_JNI) - if(ARROW_ORC) - add_subdirectory(../java/adapter/orc/src/main/cpp ./java/orc/jni) - endif() -endif() - if(ARROW_GANDIVA) add_subdirectory(src/gandiva) endif() diff --git a/java/CMakeLists.txt b/java/CMakeLists.txt index 69e3f3940d5..f184bb0a6f9 100644 --- a/java/CMakeLists.txt +++ b/java/CMakeLists.txt @@ -30,6 +30,9 @@ option(ARROW_JAVA_JNI_ENABLE_DEFAULT "Whether enable components by default or no option(ARROW_JAVA_JNI_ENABLE_C "Enable C data interface" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) option(ARROW_JAVA_JNI_ENABLE_DATASET "Enable dataset" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) option(ARROW_JAVA_JNI_ENABLE_GANDIVA "Enable Gandiva" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) +option(ARROW_JAVA_JNI_ENABLE_ORC "Enable ORC" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) + +include(GNUInstallDirs) # ccache option(ARROW_JAVA_JNI_USE_CCACHE "Use ccache when compiling (if available)" ON) @@ -74,3 +77,6 @@ endif() if(ARROW_JAVA_JNI_ENABLE_GANDIVA) add_subdirectory(gandiva) endif() +if(ARROW_JAVA_JNI_ENABLE_ORC) + add_subdirectory(adapter/orc) +endif() diff --git a/java/adapter/orc/CMakeLists.txt b/java/adapter/orc/CMakeLists.txt index e2d4655d79e..764a73863d2 100644 --- a/java/adapter/orc/CMakeLists.txt +++ b/java/adapter/orc/CMakeLists.txt @@ -15,22 +15,12 @@ # specific language governing permissions and limitations # under the License. 
-# -# arrow_orc_java -# - -# Headers: top level - -project(arrow_orc_java) +find_package(Arrow REQUIRED) -# Find java/jni -include(FindJava) -include(UseJava) -include(FindJNI) +include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR} + ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR}) -message("generating headers to ${JNI_HEADERS_DIR}") - -add_jar(arrow_orc_java +add_jar(arrow_java_jni_orc_jar src/main/java/org/apache/arrow/adapter/orc/OrcReaderJniWrapper.java src/main/java/org/apache/arrow/adapter/orc/OrcStripeReaderJniWrapper.java src/main/java/org/apache/arrow/adapter/orc/OrcMemoryJniWrapper.java @@ -38,6 +28,13 @@ add_jar(arrow_orc_java src/main/java/org/apache/arrow/adapter/orc/OrcRecordBatch.java src/main/java/org/apache/arrow/adapter/orc/OrcFieldNode.java GENERATE_NATIVE_HEADERS - arrow_orc_java-native + arrow_java_jni_orc_headers DESTINATION ${JNI_HEADERS_DIR}) + +add_library(arrow_java_jni_orc SHARED src/main/cpp/jni_wrapper.cpp) +set_property(TARGET arrow_java_jni_orc PROPERTY OUTPUT_NAME "arrow_orc_jni") +target_link_libraries(arrow_java_jni_orc arrow_java_jni_orc_headers jni + Arrow::arrow_static) + +install(TARGETS arrow_java_jni_orc DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/java/adapter/orc/src/main/cpp/CMakeLists.txt b/java/adapter/orc/src/main/cpp/CMakeLists.txt deleted file mode 100644 index 96d5748729e..00000000000 --- a/java/adapter/orc/src/main/cpp/CMakeLists.txt +++ /dev/null @@ -1,53 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -# -# arrow_orc_jni -# - -project(arrow_orc_jni) - -cmake_minimum_required(VERSION 3.11) - -find_package(JNI REQUIRED) - -add_custom_target(arrow_orc_jni) - -set(JNI_HEADERS_DIR "${CMAKE_CURRENT_BINARY_DIR}/generated") - -add_subdirectory(../../../../orc ./java) - -add_arrow_lib(arrow_orc_jni - BUILD_SHARED - ON - BUILD_STATIC - OFF - SOURCES - jni_wrapper.cpp - OUTPUTS - ARROW_ORC_JNI_LIBRARIES - SHARED_PRIVATE_LINK_LIBS - arrow_static - EXTRA_INCLUDES - ${JNI_HEADERS_DIR} - PRIVATE_INCLUDES - ${JNI_INCLUDE_DIRS} - DEPENDENCIES - arrow_static - arrow_orc_java) - -add_dependencies(arrow_orc_jni ${ARROW_ORC_JNI_LIBRARIES}) diff --git a/java/c/CMakeLists.txt b/java/c/CMakeLists.txt index 7510ab233fe..a069a1e8eea 100644 --- a/java/c/CMakeLists.txt +++ b/java/c/CMakeLists.txt @@ -30,4 +30,4 @@ add_library(arrow_java_jni_cdata SHARED src/main/cpp/jni_wrapper.cc) set_property(TARGET arrow_java_jni_cdata PROPERTY OUTPUT_NAME "arrow_cdata_jni") target_link_libraries(arrow_java_jni_cdata arrow_java_jni_cdata_headers jni) -install(TARGETS arrow_java_jni_cdata DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS arrow_java_jni_cdata DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/java/dataset/CMakeLists.txt b/java/dataset/CMakeLists.txt index 3b76b4e03bc..c777e21296d 100644 --- a/java/dataset/CMakeLists.txt +++ b/java/dataset/CMakeLists.txt @@ -42,4 +42,4 @@ if(BUILD_TESTING) add_test(NAME arrow-java-jni-dataset-test COMMAND arrow-java-jni-dataset-test) endif() -install(TARGETS arrow_java_jni_dataset DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS arrow_java_jni_dataset DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/java/gandiva/CMakeLists.txt b/java/gandiva/CMakeLists.txt index 805253da922..954bc799730 100644 --- a/java/gandiva/CMakeLists.txt +++ b/java/gandiva/CMakeLists.txt @@ -87,4 +87,4 @@ if(CXX_LINKER_SUPPORTS_VERSION_SCRIPT) ) endif() -install(TARGETS arrow_java_jni_gandiva DESTINATION ${CMAKE_INSTALL_PREFIX}) +install(TARGETS arrow_java_jni_gandiva DESTINATION ${CMAKE_INSTALL_LIBDIR}) From 59b57287e3e17e84ac41fdf7662fb2be2d4ef9f5 Mon Sep 17 00:00:00 2001 From: Sutou Kouhei Date: Mon, 19 Sep 2022 14:45:40 +0900 Subject: [PATCH 097/133] ARROW-17767: [Java][Plasma] Move JNI build configuration from cpp/ to java/ (#14163) Authored-by: Sutou Kouhei Signed-off-by: Sutou Kouhei --- ci/docker/java-jni-manylinux-201x.dockerfile | 2 +- ci/docker/linux-apt-jni.dockerfile | 2 +- ci/scripts/cpp_build.sh | 2 - ci/scripts/java_jni_build.sh | 3 +- ci/scripts/java_jni_macos_build.sh | 7 - ci/scripts/java_jni_manylinux_build.sh | 7 - ci/scripts/java_test.sh | 36 +++-- cpp/CMakeLists.txt | 4 - cpp/cmake_modules/DefineOptions.cmake | 4 - cpp/cmake_modules/ThirdpartyToolchain.cmake | 4 - cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in | 5 + cpp/src/plasma/CMakeLists.txt | 38 ----- .../org_apache_arrow_plasma_PlasmaClientJNI.h | 141 ------------------ docs/source/developers/java/building.rst | 13 +- java/CMakeLists.txt | 10 ++ java/plasma/CMakeLists.txt | 41 +++++ .../plasma/src/main/cpp/plasma_client.cc | 7 +- java/pom.xml | 5 +- 18 files changed, 98 insertions(+), 233 deletions(-) delete mode 100644 cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h create mode 100644 java/plasma/CMakeLists.txt rename cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc => java/plasma/src/main/cpp/plasma_client.cc (98%) diff --git a/ci/docker/java-jni-manylinux-201x.dockerfile b/ci/docker/java-jni-manylinux-201x.dockerfile index 3fdbfb7088e..b3ecbf00a92 100644 ---
a/ci/docker/java-jni-manylinux-201x.dockerfile +++ b/ci/docker/java-jni-manylinux-201x.dockerfile @@ -40,6 +40,6 @@ ENV JAVA_HOME=/usr/lib/jvm/java-$java-openjdk/ # For ci/scripts/{cpp,java}_*.sh ENV ARROW_HOME=/tmp/local \ ARROW_JAVA_CDATA=ON \ - ARROW_JNI=ON \ + ARROW_JAVA_JNI=ON \ ARROW_PLASMA=ON \ ARROW_USE_CCACHE=ON diff --git a/ci/docker/linux-apt-jni.dockerfile b/ci/docker/linux-apt-jni.dockerfile index 71826db0487..7b3e1b8416b 100644 --- a/ci/docker/linux-apt-jni.dockerfile +++ b/ci/docker/linux-apt-jni.dockerfile @@ -76,7 +76,7 @@ ENV ARROW_BUILD_TESTS=ON \ ARROW_GANDIVA=ON \ ARROW_HOME=/usr/local \ ARROW_JAVA_CDATA=ON \ - ARROW_JNI=ON \ + ARROW_JAVA_JNI=ON \ ARROW_ORC=ON \ ARROW_PARQUET=ON \ ARROW_PLASMA_JAVA_CLIENT=ON \ diff --git a/ci/scripts/cpp_build.sh b/ci/scripts/cpp_build.sh index 9aea4af8fb5..bb3c2b1bf13 100755 --- a/ci/scripts/cpp_build.sh +++ b/ci/scripts/cpp_build.sh @@ -97,14 +97,12 @@ cmake \ -DARROW_HDFS=${ARROW_HDFS:-ON} \ -DARROW_INSTALL_NAME_RPATH=${ARROW_INSTALL_NAME_RPATH:-ON} \ -DARROW_JEMALLOC=${ARROW_JEMALLOC:-ON} \ - -DARROW_JNI=${ARROW_JNI:-OFF} \ -DARROW_JSON=${ARROW_JSON:-ON} \ -DARROW_LARGE_MEMORY_TESTS=${ARROW_LARGE_MEMORY_TESTS:-OFF} \ -DARROW_MIMALLOC=${ARROW_MIMALLOC:-OFF} \ -DARROW_NO_DEPRECATED_API=${ARROW_NO_DEPRECATED_API:-OFF} \ -DARROW_ORC=${ARROW_ORC:-OFF} \ -DARROW_PARQUET=${ARROW_PARQUET:-OFF} \ - -DARROW_PLASMA_JAVA_CLIENT=${ARROW_PLASMA_JAVA_CLIENT:-OFF} \ -DARROW_PLASMA=${ARROW_PLASMA:-OFF} \ -DARROW_PYTHON=${ARROW_PYTHON:-OFF} \ -DARROW_RUNTIME_SIMD_LEVEL=${ARROW_RUNTIME_SIMD_LEVEL:-MAX} \ diff --git a/ci/scripts/java_jni_build.sh b/ci/scripts/java_jni_build.sh index 999506d65cf..3acaac7c4e5 100755 --- a/ci/scripts/java_jni_build.sh +++ b/ci/scripts/java_jni_build.sh @@ -53,11 +53,12 @@ cmake \ -DARROW_JAVA_JNI_ENABLE_DATASET=${ARROW_DATASET:-ON} \ -DARROW_JAVA_JNI_ENABLE_GANDIVA=${ARROW_GANDIVA:-ON} \ -DARROW_JAVA_JNI_ENABLE_ORC=${ARROW_ORC:-ON} \ + -DARROW_JAVA_JNI_ENABLE_PLASMA=${ARROW_PLASMA:-ON} \ -DBUILD_TESTING=${ARROW_JAVA_BUILD_TESTS} \ -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} \ -DCMAKE_PREFIX_PATH=${arrow_install_dir} \ - -DCMAKE_INSTALL_PREFIX=${prefix_dir} \ -DCMAKE_INSTALL_LIBDIR=lib \ + -DCMAKE_INSTALL_PREFIX=${prefix_dir} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD:-OFF} \ -GNinja \ ${JAVA_JNI_CMAKE_ARGS:-} \ diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index 51740d2537b..8923b851042 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -58,18 +58,15 @@ pushd "${build_dir}/cpp" cmake \ -DARROW_BUILD_SHARED=OFF \ -DARROW_BUILD_TESTS=${ARROW_BUILD_TESTS} \ - -DARROW_BUILD_UTILITIES=OFF \ -DARROW_CSV=${ARROW_DATASET} \ -DARROW_DATASET=${ARROW_DATASET} \ -DARROW_DEPENDENCY_USE_SHARED=OFF \ -DARROW_FILESYSTEM=${ARROW_FILESYSTEM} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \ - -DARROW_JNI=ON \ -DARROW_ORC=${ARROW_ORC} \ -DARROW_PARQUET=${ARROW_PARQUET} \ -DARROW_PLASMA=${ARROW_PLASMA} \ - -DARROW_PLASMA_JAVA_CLIENT=${ARROW_PLASMA_JAVA_CLIENT} \ -DARROW_S3=${ARROW_S3} \ -DARROW_USE_CCACHE=${ARROW_USE_CCACHE} \ -DAWSSDK_SOURCE=BUNDLED \ @@ -112,12 +109,8 @@ if [ "${ARROW_USE_CCACHE}" == "ON" ]; then ccache -s fi -echo "=== Copying libraries to the distribution folder ===" -mkdir -p "${dist_dir}" -cp -L ${build_dir}/cpp/*/libplasma_java.dylib ${dist_dir} echo "=== Checking shared dependencies for libraries ===" - pushd ${dist_dir} archery linking check-dependencies \ --allow CoreFoundation \ diff --git 
a/ci/scripts/java_jni_manylinux_build.sh b/ci/scripts/java_jni_manylinux_build.sh index 2048ecf04a7..7f6e89cb7a3 100755 --- a/ci/scripts/java_jni_manylinux_build.sh +++ b/ci/scripts/java_jni_manylinux_build.sh @@ -66,7 +66,6 @@ pushd "${build_dir}/cpp" cmake \ -DARROW_BUILD_SHARED=OFF \ -DARROW_BUILD_TESTS=ON \ - -DARROW_BUILD_UTILITIES=OFF \ -DARROW_CSV=${ARROW_DATASET} \ -DARROW_DATASET=${ARROW_DATASET} \ -DARROW_DEPENDENCY_SOURCE="VCPKG" \ @@ -75,10 +74,8 @@ cmake \ -DARROW_GANDIVA_PC_CXX_FLAGS=${GANDIVA_CXX_FLAGS} \ -DARROW_GANDIVA=${ARROW_GANDIVA} \ -DARROW_JEMALLOC=${ARROW_JEMALLOC} \ - -DARROW_JNI=ON \ -DARROW_ORC=${ARROW_ORC} \ -DARROW_PARQUET=${ARROW_PARQUET} \ - -DARROW_PLASMA_JAVA_CLIENT=${ARROW_PLASMA_JAVA_CLIENT} \ -DARROW_PLASMA=${ARROW_PLASMA} \ -DARROW_RPATH_ORIGIN=${ARROW_RPATH_ORIGIN} \ -DARROW_S3=${ARROW_S3} \ @@ -133,11 +130,7 @@ if [ "${ARROW_USE_CCACHE}" == "ON" ]; then fi -echo "=== Copying libraries to the distribution folder ===" -cp -L ${build_dir}/cpp/*/libplasma_java.so ${dist_dir} echo "=== Checking shared dependencies for libraries ===" pushd ${dist_dir} archery linking check-dependencies \ --allow ld-linux-x86-64 \ diff --git a/ci/scripts/java_test.sh b/ci/scripts/java_test.sh index bb30894d9ef..c062590e07a 100755 --- a/ci/scripts/java_test.sh +++ b/ci/scripts/java_test.sh @@ -38,20 +38,36 @@ pushd ${source_dir} ${mvn} test -if [ "${ARROW_JNI}" = "ON" ]; then - ${mvn} test -Parrow-jni -pl adapter/orc,gandiva,dataset -Darrow.cpp.build.dir=${java_jni_dist_dir} +projects=() +if [ "${ARROW_DATASET}" = "ON" ]; then + projects+=(dataset) fi +if [ "${ARROW_GANDIVA}" = "ON" ]; then + projects+=(gandiva) +fi +if [ "${ARROW_ORC}" = "ON" ]; then + projects+=(adapter/orc) +fi +if [ "${ARROW_PLASMA}" = "ON" ]; then + projects+=(plasma) +fi +if [ "${#projects[@]}" -gt 0 ]; then + ${mvn} test \ + -Parrow-jni \ + -pl $(IFS=,; echo "${projects[*]}") \ + -Darrow.cpp.build.dir=${java_jni_dist_dir} -if [ "${ARROW_JAVA_CDATA}" = "ON" ]; then - ${mvn} test -Parrow-c-data -pl c -Darrow.c.jni.dist.dir=${java_jni_dist_dir} + if [ "${ARROW_PLASMA}" = "ON" ]; then + pushd ${source_dir}/plasma + java -cp target/test-classes:target/classes \ + -Djava.library.path=${java_jni_dist_dir} \ + org.apache.arrow.plasma.PlasmaClientTest + popd + fi fi -if [ "${ARROW_PLASMA}" = "ON" ]; then - pushd ${source_dir}/plasma - java -cp target/test-classes:target/classes \ - -Djava.library.path=${java_jni_dist_dir} \ - org.apache.arrow.plasma.PlasmaClientTest - popd +if [ "${ARROW_JAVA_CDATA}" = "ON" ]; then + ${mvn} test -Parrow-c-data -pl c -Darrow.c.jni.dist.dir=${java_jni_dist_dir} fi popd diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a9bed5e5896..34582f6f072 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -425,10 +425,6 @@ if(MSVC_TOOLCHAIN) set(ARROW_USE_GLOG OFF) endif() -if(ARROW_JNI) - set(ARROW_BUILD_STATIC ON) -endif() - if(ARROW_ORC) set(ARROW_WITH_LZ4 ON) set(ARROW_WITH_SNAPPY ON) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 34ee8f3456a..8a1271dfcdf 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -269,8 +269,6 @@ takes precedence over ccache if a storage backend is configured" ON) define_option(ARROW_JEMALLOC ${ARROW_JEMALLOC_DESCRIPTION} ON) endif() - define_option(ARROW_JNI "Build the Arrow JNI lib" OFF) - define_option(ARROW_JSON "Build Arrow with JSON support (requires RapidJSON)" OFF) define_option(ARROW_MIMALLOC "Build the Arrow mimalloc-based allocator" 
OFF) @@ -281,8 +279,6 @@ takes precedence over ccache if a storage backend is configured" ON) define_option(ARROW_PLASMA "Build the plasma object store along with Arrow" OFF) - define_option(ARROW_PLASMA_JAVA_CLIENT "Build the plasma object store java client" OFF) - define_option(ARROW_PYTHON "Build the Arrow CPython extensions" OFF) define_option(ARROW_S3 "Build Arrow with S3 support (requires the AWS SDK for C++)" OFF) diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 5d1e50ffa25..0ab66c5bf07 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1626,10 +1626,6 @@ if(ARROW_WITH_PROTOBUF) if(ARROW_WITH_GRPC) # FlightSQL uses proto3 optionals, which require 3.15 or later. set(ARROW_PROTOBUF_REQUIRED_VERSION "3.15.0") - elseif(ARROW_GANDIVA_JAVA) - # google::protobuf::MessageLite::ByteSize() is deprecated since - # Protobuf 3.4.0. - set(ARROW_PROTOBUF_REQUIRED_VERSION "3.4.0") elseif(ARROW_SUBSTRAIT) # Substrait protobuf files use proto3 syntax set(ARROW_PROTOBUF_REQUIRED_VERSION "3.0.0") diff --git a/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in b/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in index b251b86f43e..bb36abf2411 100644 --- a/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in +++ b/cpp/src/arrow/gpu/ArrowCUDAConfig.cmake.in @@ -28,6 +28,11 @@ include(CMakeFindDependencyMacro) find_dependency(Arrow) +if(CMAKE_VERSION VERSION_LESS 3.17) + find_package(CUDA REQUIRED) +else() + find_package(CUDAToolkit REQUIRED) +endif() include("${CMAKE_CURRENT_LIST_DIR}/ArrowCUDATargets.cmake") diff --git a/cpp/src/plasma/CMakeLists.txt b/cpp/src/plasma/CMakeLists.txt index c98a66367bc..7f05d5bcdc6 100644 --- a/cpp/src/plasma/CMakeLists.txt +++ b/cpp/src/plasma/CMakeLists.txt @@ -171,44 +171,6 @@ install(TARGETS plasma-store-server ${INSTALL_IS_OPTIONAL} EXPORT plasma_targets DESTINATION ${CMAKE_INSTALL_BINDIR}) -if(ARROW_PLASMA_JAVA_CLIENT) - # Plasma java client support - find_package(JNI REQUIRED) - # add jni support - include_directories(${JAVA_INCLUDE_PATH}) - include_directories(${JAVA_INCLUDE_PATH2}) - if(JNI_FOUND) - message(STATUS "JNI_INCLUDE_DIRS = ${JNI_INCLUDE_DIRS}") - message(STATUS "JNI_LIBRARIES = ${JNI_LIBRARIES}") - else() - message(WARNING "Could not find JNI") - endif() - - add_compile_options("-I$ENV{JAVA_HOME}/include/") - if(WIN32) - add_compile_options("-I$ENV{JAVA_HOME}/include/win32") - elseif(APPLE) - add_compile_options("-I$ENV{JAVA_HOME}/include/darwin") - else() # linux - add_compile_options("-I$ENV{JAVA_HOME}/include/linux") - endif() - - include_directories("${CMAKE_CURRENT_LIST_DIR}/lib/java") - - file(GLOB PLASMA_LIBRARY_EXT_java_SRC lib/java/*.cc lib/*.cc) - add_library(plasma_java SHARED ${PLASMA_LIBRARY_EXT_java_SRC}) - - if(APPLE) - target_link_libraries(plasma_java - plasma_static - ${PLASMA_STATIC_LINK_LIBS} - "-undefined dynamic_lookup" - ${PTHREAD_LIBRARY}) - else() - target_link_libraries(plasma_java plasma_static ${PLASMA_STATIC_LINK_LIBS} - ${PTHREAD_LIBRARY}) - endif() -endif() # # Unit tests # diff --git a/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h b/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h deleted file mode 100644 index 8a18be91deb..00000000000 --- a/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h +++ /dev/null @@ -1,141 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. 
See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -/* DO NOT EDIT THIS FILE - it is machine generated */ -#include -/* Header for class org_apache_arrow_plasma_PlasmaClientJNI */ - -#ifndef _Included_org_apache_arrow_plasma_PlasmaClientJNI -#define _Included_org_apache_arrow_plasma_PlasmaClientJNI -#ifdef __cplusplus -extern "C" { -#endif -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: connect - * Signature: (Ljava/lang/String;Ljava/lang/String;I)J - */ -JNIEXPORT jlong JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_connect( - JNIEnv*, jclass, jstring, jstring, jint); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: disconnect - * Signature: (J)V - */ -JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_disconnect(JNIEnv*, - jclass, - jlong); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: create - * Signature: (J[BI[B)Ljava/nio/ByteBuffer; - */ -JNIEXPORT jobject JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_create( - JNIEnv*, jclass, jlong, jbyteArray, jint, jbyteArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: hash - * Signature: (J[B)[B - */ -JNIEXPORT jbyteArray JNICALL -Java_org_apache_arrow_plasma_PlasmaClientJNI_hash(JNIEnv*, jclass, jlong, jbyteArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: seal - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_seal(JNIEnv*, jclass, - jlong, - jbyteArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: release - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_release(JNIEnv*, - jclass, jlong, - jbyteArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: delete - * Signature: (J[B)V - */ -JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_delete(JNIEnv*, - jclass, jlong, - jbyteArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: get - * Signature: (J[[BI)[[Ljava/nio/ByteBuffer; - */ -JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_get( - JNIEnv*, jclass, jlong, jobjectArray, jint); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: contains - * Signature: (J[B)Z - */ -JNIEXPORT jboolean JNICALL -Java_org_apache_arrow_plasma_PlasmaClientJNI_contains(JNIEnv*, jclass, jlong, jbyteArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: fetch - * Signature: (J[[B)V - */ -JNIEXPORT void JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_fetch(JNIEnv*, jclass, - jlong, - jobjectArray); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: wait - * Signature: (J[[BII)[[B - */ -JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_wait( - JNIEnv*, jclass, jlong, jobjectArray, jint, jint); - 
-/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: evict - * Signature: (JJ)J - */ -JNIEXPORT jlong JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_evict(JNIEnv*, - jclass, jlong, - jlong); - -/* - * Class: org_apache_arrow_plasma_PlasmaClientJNI - * Method: list - * Signature: (J)[[B - */ -JNIEXPORT jobjectArray JNICALL Java_org_apache_arrow_plasma_PlasmaClientJNI_list(JNIEnv*, - jclass, - jlong); - -#ifdef __cplusplus -} -#endif -#endif diff --git a/docs/source/developers/java/building.rst b/docs/source/developers/java/building.rst index 846277faa34..d0f366ca8e4 100644 --- a/docs/source/developers/java/building.rst +++ b/docs/source/developers/java/building.rst @@ -119,7 +119,7 @@ Maven $ export JAVA_HOME= $ java --version $ mvn clean generate-resources -Pgenerate-jni-dylib_so -N - $ ls -latr java-dist/lib + $ ls -latr java-dist/lib/*_{jni,java}.* |__ libarrow_dataset_jni.dylib |__ libarrow_orc_jni.dylib |__ libgandiva_jni.dylib @@ -141,7 +141,8 @@ CMake -DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF \ -DBUILD_TESTING=OFF \ -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=java-dist/lib + -DCMAKE_INSTALL_LIBDIR=lib \ + -DCMAKE_INSTALL_PREFIX=java-dist $ cmake --build java-jni --target install --config Release $ ls -latr java-dist/lib |__ libarrow_cdata_jni.dylib @@ -161,6 +162,7 @@ CMake $ cmake \ -S cpp \ -B cpp-jni \ + -DARROW_BUILD_SHARED=OFF \ -DARROW_CSV=ON \ -DARROW_DATASET=ON \ -DARROW_DEPENDENCY_SOURCE=BUNDLED \ @@ -168,11 +170,9 @@ CMake -DARROW_FILESYSTEM=ON \ -DARROW_GANDIVA=ON \ -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON \ - -DARROW_JNI=ON \ -DARROW_ORC=ON \ -DARROW_PARQUET=ON \ -DARROW_PLASMA=ON \ - -DARROW_PLASMA_JAVA_CLIENT=ON \ -DARROW_S3=ON \ -DARROW_USE_CCACHE=ON \ -DCMAKE_BUILD_TYPE=Release \ @@ -187,10 +187,11 @@ CMake -DARROW_JAVA_JNI_ENABLE_DEFAULT=ON \ -DBUILD_TESTING=OFF \ -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_INSTALL_PREFIX=java-dist/lib \ + -DCMAKE_INSTALL_LIBDIR=lib \ + -DCMAKE_INSTALL_PREFIX=java-dist \ -DCMAKE_PREFIX_PATH=$PWD/java-dist $ cmake --build java-jni --target install --config Release - $ ls -latr java-dist/lib + $ ls -latr java-dist/lib/*_{jni,java}.* |__ libarrow_dataset_jni.dylib |__ libarrow_orc_jni.dylib |__ libgandiva_jni.dylib diff --git a/java/CMakeLists.txt b/java/CMakeLists.txt index f184bb0a6f9..371f3e60758 100644 --- a/java/CMakeLists.txt +++ b/java/CMakeLists.txt @@ -18,6 +18,12 @@ cmake_minimum_required(VERSION 3.11) message(STATUS "Building using CMake version: ${CMAKE_VERSION}") +# find_package() uses _ROOT variables. 
+# https://cmake.org/cmake/help/latest/policy/CMP0074.html +if(POLICY CMP0074) + cmake_policy(SET CMP0074 NEW) +endif() + project(arrow-java-jni) if("${CMAKE_CXX_STANDARD}" STREQUAL "") @@ -31,6 +37,7 @@ option(ARROW_JAVA_JNI_ENABLE_C "Enable C data interface" ${ARROW_JAVA_JNI_ENABLE option(ARROW_JAVA_JNI_ENABLE_DATASET "Enable dataset" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) option(ARROW_JAVA_JNI_ENABLE_GANDIVA "Enable Gandiva" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) option(ARROW_JAVA_JNI_ENABLE_ORC "Enable ORC" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) +option(ARROW_JAVA_JNI_ENABLE_PLASMA "Enable Plasma" ${ARROW_JAVA_JNI_ENABLE_DEFAULT}) include(GNUInstallDirs) @@ -80,3 +87,6 @@ endif() if(ARROW_JAVA_JNI_ENABLE_ORC) add_subdirectory(adapter/orc) endif() +if(ARROW_JAVA_JNI_ENABLE_PLASMA) + add_subdirectory(plasma) +endif() diff --git a/java/plasma/CMakeLists.txt b/java/plasma/CMakeLists.txt new file mode 100644 index 00000000000..4cad01cb44b --- /dev/null +++ b/java/plasma/CMakeLists.txt @@ -0,0 +1,41 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +find_package(Plasma REQUIRED) + +include_directories(${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR} + ${JNI_INCLUDE_DIRS} ${JNI_HEADERS_DIR}) + +add_jar(arrow_java_jni_plasma_jar + src/main/java/org/apache/arrow/plasma/PlasmaClientJNI.java + src/main/java/org/apache/arrow/plasma/exceptions/DuplicateObjectException.java + src/main/java/org/apache/arrow/plasma/exceptions/PlasmaClientException.java + src/main/java/org/apache/arrow/plasma/exceptions/PlasmaOutOfMemoryException.java + GENERATE_NATIVE_HEADERS + arrow_java_jni_plasma_headers) + +add_library(arrow_java_jni_plasma SHARED src/main/cpp/plasma_client.cc) +set_property(TARGET arrow_java_jni_plasma PROPERTY OUTPUT_NAME "plasma_java") +target_link_libraries(arrow_java_jni_plasma arrow_java_jni_plasma_headers jni + Plasma::plasma_static) + +if(APPLE) + set_target_properties(arrow_java_jni_plasma PROPERTIES LINK_FLAGS + "-undefined dynamic_lookup") +endif() + +install(TARGETS arrow_java_jni_plasma DESTINATION ${CMAKE_INSTALL_LIBDIR}) diff --git a/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc b/java/plasma/src/main/cpp/plasma_client.cc similarity index 98% rename from cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc rename to java/plasma/src/main/cpp/plasma_client.cc index 10e0fcb371d..19267ba21e6 100644 --- a/cpp/src/plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.cc +++ b/java/plasma/src/main/cpp/plasma_client.cc @@ -15,9 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-#include "plasma/lib/java/org_apache_arrow_plasma_PlasmaClientJNI.h" - -#include #include #include @@ -28,9 +25,9 @@ #include #include -#include "arrow/util/logging.h" +#include -#include "plasma/client.h" +#include "org_apache_arrow_plasma_PlasmaClientJNI.h" constexpr jsize OBJECT_ID_SIZE = sizeof(plasma::ObjectID) / sizeof(jbyte); diff --git a/java/pom.xml b/java/pom.xml index b5db145cc5c..ec3d1157ff8 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -956,6 +956,7 @@ -DARROW_JAVA_JNI_ENABLE_DEFAULT=OFF -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_INSTALL_PREFIX=${arrow.c.jni.dist.dir} ../ @@ -1017,6 +1018,7 @@ -S cpp -B cpp-jni + -DARROW_BUILD_SHARED=OFF -DARROW_CSV=ON -DARROW_DATASET=ON -DARROW_DEPENDENCY_SOURCE=BUNDLED @@ -1024,11 +1026,9 @@ -DARROW_FILESYSTEM=ON -DARROW_GANDIVA=ON -DARROW_GANDIVA_STATIC_LIBSTDCPP=ON - -DARROW_JNI=ON -DARROW_ORC=ON -DARROW_PARQUET=ON -DARROW_PLASMA=ON - -DARROW_PLASMA_JAVA_CLIENT=ON -DARROW_S3=ON -DARROW_USE_CCACHE=ON -DCMAKE_BUILD_TYPE=Release @@ -1070,6 +1070,7 @@ -DARROW_JAVA_JNI_ENABLE_DEFAULT=ON -DBUILD_TESTING=OFF -DCMAKE_BUILD_TYPE=Release + -DCMAKE_INSTALL_LIBDIR=lib -DCMAKE_INSTALL_PREFIX=${arrow.c.jni.dist.dir} -DCMAKE_PREFIX_PATH=${project.basedir}/../java-dist From 081b70bafeab1091428dcc6950ede3fdafd590c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ra=C3=BAl=20Cumplido?= Date: Mon, 19 Sep 2022 14:52:34 +0200 Subject: [PATCH 098/133] MINOR: [Java][CI] Fix grep for new nightlies versioning (#14166) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I've noticed that the latest Java nightlies where not uploaded https://nightlies.apache.org/arrow/java/org/apache/arrow/arrow-c-data/. I missed to update the grep on this PR: https://github.com/apache/arrow/pull/14135 I have now tested locally the bash script to validate the repo structure will now be generated correctly: ``` repo └── org └── apache └── arrow ├── arrow-algorithm │   ├── 10.0.0-SNAPSHOT │   │   ├── arrow-algorithm-10.0.0-SNAPSHOT.jar │   │   ├── arrow-algorithm-10.0.0-SNAPSHOT-javadoc.jar │   │   ├── arrow-algorithm-10.0.0-SNAPSHOT.pom │   │   ├── arrow-algorithm-10.0.0-SNAPSHOT-sources.jar │   │   └── arrow-algorithm-10.0.0-SNAPSHOT-tests.jar │   └── 2022-09-18 │   ├── arrow-algorithm-10.0.0-SNAPSHOT.jar │   ├── arrow-algorithm-10.0.0-SNAPSHOT-javadoc.jar │   ├── arrow-algorithm-10.0.0-SNAPSHOT.pom │   ├── arrow-algorithm-10.0.0-SNAPSHOT-sources.jar │   └── arrow-algorithm-10.0.0-SNAPSHOT-tests.jar ├── arrow-avro │   ├── 10.0.0-SNAPSHOT │   │   ├── arrow-avro-10.0.0-SNAPSHOT.jar │   │   ├── arrow-avro-10.0.0-SNAPSHOT-javadoc.jar │   │   ├── arrow-avro-10.0.0-SNAPSHOT.pom │   │   ├── arrow-avro-10.0.0-SNAPSHOT-sources.jar │   │   └── arrow-avro-10.0.0-SNAPSHOT-tests.jar │   └── 2022-09-18 │   ├── arrow-avro-10.0.0-SNAPSHOT.jar │   ├── arrow-avro-10.0.0-SNAPSHOT-javadoc.jar │   ├── arrow-avro-10.0.0-SNAPSHOT.pom │   ├── arrow-avro-10.0.0-SNAPSHOT-sources.jar │   └── arrow-avro-10.0.0-SNAPSHOT-tests.jar ... 
``` Authored-by: Raúl Cumplido Signed-off-by: Sutou Kouhei --- .github/workflows/java_nightly.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/java_nightly.yml b/.github/workflows/java_nightly.yml index badb9d94e52..f853947e6db 100644 --- a/.github/workflows/java_nightly.yml +++ b/.github/workflows/java_nightly.yml @@ -103,7 +103,7 @@ jobs: fi PATTERN_TO_GET_LIB_AND_VERSION='([a-z].+)-([0-9]+.[0-9]+.[0-9]+-SNAPSHOT)' mkdir -p repo/org/apache/arrow/ - for LIBRARY in $(ls binaries/$PREFIX/java-jars | grep -E '.jar|.pom' | grep dev); do + for LIBRARY in $(ls binaries/$PREFIX/java-jars | grep -E '.jar|.pom' | grep SNAPSHOT); do [[ $LIBRARY =~ $PATTERN_TO_GET_LIB_AND_VERSION ]] mkdir -p repo/org/apache/arrow/${BASH_REMATCH[1]}/${BASH_REMATCH[2]} mkdir -p repo/org/apache/arrow/${BASH_REMATCH[1]}/${DATE} From 529f653dfa58887522af06028e5c32e8dd1a14ea Mon Sep 17 00:00:00 2001 From: Weston Pace Date: Mon, 19 Sep 2022 08:31:16 -1000 Subject: [PATCH 099/133] ARROW-17517: [C++] Remove internal headers from substrait API (#14131) Lead-authored-by: Weston Pace Co-authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api.h | 3 +- cpp/src/arrow/compute/exec/aggregate.cc | 2 + cpp/src/arrow/compute/exec/benchmark_util.cc | 1 + cpp/src/arrow/compute/exec/exec_plan.cc | 95 +------------ cpp/src/arrow/compute/exec/exec_plan.h | 49 +------ .../compute/exec/expression_benchmark.cc | 2 + cpp/src/arrow/compute/exec/filter_node.cc | 1 + cpp/src/arrow/compute/exec/map_node.cc | 131 ++++++++++++++++++ cpp/src/arrow/compute/exec/map_node.h | 80 +++++++++++ cpp/src/arrow/compute/exec/project_node.cc | 1 + cpp/src/arrow/compute/exec/tpch_node.cc | 1 + cpp/src/arrow/compute/type_fwd.h | 3 + cpp/src/arrow/dataset/file_base.cc | 1 + cpp/src/arrow/dataset/type_fwd.h | 2 + .../arrow/engine/substrait/extension_set.cc | 104 ++++++++------ .../arrow/engine/substrait/extension_set.h | 30 ++-- cpp/src/arrow/engine/substrait/options.h | 7 + cpp/src/arrow/engine/substrait/serde.cc | 10 ++ cpp/src/arrow/engine/substrait/serde.h | 12 +- cpp/src/arrow/engine/substrait/serde_test.cc | 1 + cpp/src/arrow/engine/substrait/type_fwd.h | 31 +++++ cpp/src/arrow/engine/substrait/util.cc | 3 + .../arrow/flight/sql/example/acero_server.cc | 2 + cpp/src/arrow/util/type_fwd.h | 1 + 25 files changed, 371 insertions(+), 203 deletions(-) create mode 100644 cpp/src/arrow/compute/exec/map_node.cc create mode 100644 cpp/src/arrow/compute/exec/map_node.h create mode 100644 cpp/src/arrow/engine/substrait/type_fwd.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 88d72b11832..23f0a7c9f1a 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -397,6 +397,7 @@ if(ARROW_COMPUTE) compute/exec/hash_join_node.cc compute/exec/key_hash.cc compute/exec/key_map.cc + compute/exec/map_node.cc compute/exec/order_by_impl.cc compute/exec/partition_util.cc compute/exec/options.cc diff --git a/cpp/src/arrow/compute/api.h b/cpp/src/arrow/compute/api.h index 3539bab038a..ba8d26da4d5 100644 --- a/cpp/src/arrow/compute/api.h +++ b/cpp/src/arrow/compute/api.h @@ -56,4 +56,5 @@ /// @{ /// @} -#include "arrow/compute/exec.h" // IWYU pragma: export +#include "arrow/compute/exec.h" // IWYU pragma: export +#include "arrow/compute/exec/exec_plan.h" // IWYU pragma: export diff --git a/cpp/src/arrow/compute/exec/aggregate.cc b/cpp/src/arrow/compute/exec/aggregate.cc index cc2c464d42b..95f85317c8d 100644 --- 
a/cpp/src/arrow/compute/exec/aggregate.cc +++ b/cpp/src/arrow/compute/exec/aggregate.cc @@ -18,11 +18,13 @@ #include "arrow/compute/exec/aggregate.h" #include +#include #include "arrow/compute/exec_internal.h" #include "arrow/compute/registry.h" #include "arrow/compute/row/grouper.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/logging.h" #include "arrow/util/task_group.h" namespace arrow { diff --git a/cpp/src/arrow/compute/exec/benchmark_util.cc b/cpp/src/arrow/compute/exec/benchmark_util.cc index d4e14540a54..dcc7ca6e165 100644 --- a/cpp/src/arrow/compute/exec/benchmark_util.cc +++ b/cpp/src/arrow/compute/exec/benchmark_util.cc @@ -24,6 +24,7 @@ #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/task_util.h" +#include "arrow/compute/exec/util.h" #include "arrow/util/macros.h" namespace arrow { diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 057e1ace5cd..8d9b214a720 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -26,6 +26,7 @@ #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/task_util.h" +#include "arrow/compute/exec/util.h" #include "arrow/compute/exec_internal.h" #include "arrow/compute/registry.h" #include "arrow/datum.h" @@ -474,100 +475,6 @@ bool ExecNode::ErrorIfNotOk(Status status) { return true; } -MapNode::MapNode(ExecPlan* plan, std::vector inputs, - std::shared_ptr output_schema, bool async_mode) - : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, - std::move(output_schema), - /*num_outputs=*/1) { - if (async_mode) { - executor_ = plan_->exec_context()->executor(); - } else { - executor_ = nullptr; - } -} - -void MapNode::ErrorReceived(ExecNode* input, Status error) { - DCHECK_EQ(input, inputs_[0]); - EVENT(span_, "ErrorReceived", {{"error.message", error.message()}}); - outputs_[0]->ErrorReceived(this, std::move(error)); -} - -void MapNode::InputFinished(ExecNode* input, int total_batches) { - DCHECK_EQ(input, inputs_[0]); - EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); - outputs_[0]->InputFinished(this, total_batches); - if (input_counter_.SetTotal(total_batches)) { - this->Finish(); - } -} - -Status MapNode::StartProducing() { - START_COMPUTE_SPAN( - span_, std::string(kind_name()) + ":" + label(), - {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", kind_name()}}); - return Status::OK(); -} - -void MapNode::PauseProducing(ExecNode* output, int32_t counter) { - inputs_[0]->PauseProducing(this, counter); -} - -void MapNode::ResumeProducing(ExecNode* output, int32_t counter) { - inputs_[0]->ResumeProducing(this, counter); -} - -void MapNode::StopProducing(ExecNode* output) { - DCHECK_EQ(output, outputs_[0]); - StopProducing(); -} - -void MapNode::StopProducing() { - EVENT(span_, "StopProducing"); - if (executor_) { - this->stop_source_.RequestStop(); - } - if (input_counter_.Cancel()) { - this->Finish(); - } - inputs_[0]->StopProducing(this); -} - -void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, - ExecBatch batch) { - Status status; - // This will be true if the node is stopped early due to an error or manual - // cancellation - if (input_counter_.Completed()) { - return; - } - auto task = [this, map_fn, batch]() { - auto guarantee = batch.guarantee; - auto output_batch = map_fn(std::move(batch)); - if (ErrorIfNotOk(output_batch.status())) { - return 
output_batch.status(); - } - output_batch->guarantee = guarantee; - outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe()); - return Status::OK(); - }; - - status = task(); - if (!status.ok()) { - if (input_counter_.Cancel()) { - this->Finish(status); - } - inputs_[0]->StopProducing(this); - return; - } - if (input_counter_.Increment()) { - this->Finish(); - } -} - -void MapNode::Finish(Status finish_st /*= Status::OK()*/) { - this->finished_.MarkFinished(finish_st); -} - std::shared_ptr MakeGeneratorReader( std::shared_ptr schema, std::function>()> gen, MemoryPool* pool) { diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 3ff2340856f..f82afb604a1 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -17,21 +17,24 @@ #pragma once +#include +#include #include #include #include #include +#include #include #include "arrow/compute/exec.h" -#include "arrow/compute/exec/util.h" #include "arrow/compute/type_fwd.h" #include "arrow/type_fwd.h" -#include "arrow/util/async_util.h" #include "arrow/util/cancel.h" +#include "arrow/util/future.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/macros.h" #include "arrow/util/tracing.h" +#include "arrow/util/type_fwd.h" #include "arrow/util/visibility.h" namespace arrow { @@ -369,48 +372,6 @@ class ARROW_EXPORT ExecNode { util::tracing::Span span_; }; -/// \brief MapNode is an ExecNode type class which process a task like filter/project -/// (See SubmitTask method) to each given ExecBatch object, which have one input, one -/// output, and are pure functions on the input -/// -/// A simple parallel runner is created with a "map_fn" which is just a function that -/// takes a batch in and returns a batch. 
This simple parallel runner also needs an -/// executor (use simple synchronous runner if there is no executor) - -class ARROW_EXPORT MapNode : public ExecNode { - public: - MapNode(ExecPlan* plan, std::vector inputs, - std::shared_ptr output_schema, bool async_mode); - - void ErrorReceived(ExecNode* input, Status error) override; - - void InputFinished(ExecNode* input, int total_batches) override; - - Status StartProducing() override; - - void PauseProducing(ExecNode* output, int32_t counter) override; - - void ResumeProducing(ExecNode* output, int32_t counter) override; - - void StopProducing(ExecNode* output) override; - - void StopProducing() override; - - protected: - void SubmitTask(std::function(ExecBatch)> map_fn, ExecBatch batch); - - virtual void Finish(Status finish_st = Status::OK()); - - protected: - // Counter for the number of batches received - AtomicCounter input_counter_; - - ::arrow::internal::Executor* executor_; - - // Variable used to cancel remaining tasks in the executor - StopSource stop_source_; -}; - /// \brief An extensible registry for factories of ExecNodes class ARROW_EXPORT ExecFactoryRegistry { public: diff --git a/cpp/src/arrow/compute/exec/expression_benchmark.cc b/cpp/src/arrow/compute/exec/expression_benchmark.cc index debd2284980..e431497e45b 100644 --- a/cpp/src/arrow/compute/exec/expression_benchmark.cc +++ b/cpp/src/arrow/compute/exec/expression_benchmark.cc @@ -17,6 +17,8 @@ #include "benchmark/benchmark.h" +#include + #include "arrow/compute/cast.h" #include "arrow/compute/exec/expression.h" #include "arrow/compute/exec/test_util.h" diff --git a/cpp/src/arrow/compute/exec/filter_node.cc b/cpp/src/arrow/compute/exec/filter_node.cc index b424da35f85..280d1e9ae00 100644 --- a/cpp/src/arrow/compute/exec/filter_node.cc +++ b/cpp/src/arrow/compute/exec/filter_node.cc @@ -19,6 +19,7 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" +#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/options.h" #include "arrow/datum.h" #include "arrow/result.h" diff --git a/cpp/src/arrow/compute/exec/map_node.cc b/cpp/src/arrow/compute/exec/map_node.cc new file mode 100644 index 00000000000..b99d0644905 --- /dev/null +++ b/cpp/src/arrow/compute/exec/map_node.cc @@ -0,0 +1,131 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "arrow/compute/exec/map_node.h" + +#include +#include +#include +#include +#include + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/expression.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" +#include "arrow/util/tracing_internal.h" + +namespace arrow { +namespace compute { + +MapNode::MapNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema, bool async_mode) + : ExecNode(plan, std::move(inputs), /*input_labels=*/{"target"}, + std::move(output_schema), + /*num_outputs=*/1) { + if (async_mode) { + executor_ = plan_->exec_context()->executor(); + } else { + executor_ = nullptr; + } +} + +void MapNode::ErrorReceived(ExecNode* input, Status error) { + DCHECK_EQ(input, inputs_[0]); + EVENT(span_, "ErrorReceived", {{"error.message", error.message()}}); + outputs_[0]->ErrorReceived(this, std::move(error)); +} + +void MapNode::InputFinished(ExecNode* input, int total_batches) { + DCHECK_EQ(input, inputs_[0]); + EVENT(span_, "InputFinished", {{"batches.length", total_batches}}); + outputs_[0]->InputFinished(this, total_batches); + if (input_counter_.SetTotal(total_batches)) { + this->Finish(); + } +} + +Status MapNode::StartProducing() { + START_COMPUTE_SPAN( + span_, std::string(kind_name()) + ":" + label(), + {{"node.label", label()}, {"node.detail", ToString()}, {"node.kind", kind_name()}}); + return Status::OK(); +} + +void MapNode::PauseProducing(ExecNode* output, int32_t counter) { + inputs_[0]->PauseProducing(this, counter); +} + +void MapNode::ResumeProducing(ExecNode* output, int32_t counter) { + inputs_[0]->ResumeProducing(this, counter); +} + +void MapNode::StopProducing(ExecNode* output) { + DCHECK_EQ(output, outputs_[0]); + StopProducing(); +} + +void MapNode::StopProducing() { + EVENT(span_, "StopProducing"); + if (executor_) { + this->stop_source_.RequestStop(); + } + if (input_counter_.Cancel()) { + this->Finish(); + } + inputs_[0]->StopProducing(this); +} + +void MapNode::SubmitTask(std::function(ExecBatch)> map_fn, + ExecBatch batch) { + Status status; + // This will be true if the node is stopped early due to an error or manual + // cancellation + if (input_counter_.Completed()) { + return; + } + auto task = [this, map_fn, batch]() { + auto guarantee = batch.guarantee; + auto output_batch = map_fn(std::move(batch)); + if (ErrorIfNotOk(output_batch.status())) { + return output_batch.status(); + } + output_batch->guarantee = guarantee; + outputs_[0]->InputReceived(this, output_batch.MoveValueUnsafe()); + return Status::OK(); + }; + + status = task(); + if (!status.ok()) { + if (input_counter_.Cancel()) { + this->Finish(status); + } + inputs_[0]->StopProducing(this); + return; + } + if (input_counter_.Increment()) { + this->Finish(); + } +} + +void MapNode::Finish(Status finish_st /*= Status::OK()*/) { + this->finished_.MarkFinished(finish_st); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/map_node.h b/cpp/src/arrow/compute/exec/map_node.h new file mode 100644 index 00000000000..63d9db4a782 --- /dev/null +++ b/cpp/src/arrow/compute/exec/map_node.h @@ -0,0 +1,80 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// \brief MapNode is an ExecNode type class which process a task like filter/project +/// (See SubmitTask method) to each given ExecBatch object, which have one input, one +/// output, and are pure functions on the input +/// +/// A simple parallel runner is created with a "map_fn" which is just a function that +/// takes a batch in and returns a batch. This simple parallel runner also needs an +/// executor (use simple synchronous runner if there is no executor) + +#pragma once + +#include +#include +#include +#include + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/util.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/cancel.h" +#include "arrow/util/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +class ARROW_EXPORT MapNode : public ExecNode { + public: + MapNode(ExecPlan* plan, std::vector inputs, + std::shared_ptr output_schema, bool async_mode); + + void ErrorReceived(ExecNode* input, Status error) override; + + void InputFinished(ExecNode* input, int total_batches) override; + + Status StartProducing() override; + + void PauseProducing(ExecNode* output, int32_t counter) override; + + void ResumeProducing(ExecNode* output, int32_t counter) override; + + void StopProducing(ExecNode* output) override; + + void StopProducing() override; + + protected: + void SubmitTask(std::function(ExecBatch)> map_fn, ExecBatch batch); + + virtual void Finish(Status finish_st = Status::OK()); + + protected: + // Counter for the number of batches received + AtomicCounter input_counter_; + + ::arrow::internal::Executor* executor_; + + // Variable used to cancel remaining tasks in the executor + StopSource stop_source_; +}; + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/project_node.cc b/cpp/src/arrow/compute/exec/project_node.cc index 76925eb6139..678925901c4 100644 --- a/cpp/src/arrow/compute/exec/project_node.cc +++ b/cpp/src/arrow/compute/exec/project_node.cc @@ -21,6 +21,7 @@ #include "arrow/compute/exec.h" #include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression.h" +#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/util.h" #include "arrow/datum.h" diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc b/cpp/src/arrow/compute/exec/tpch_node.cc index 40d44dccccf..13fbef2bd5c 100644 --- a/cpp/src/arrow/compute/exec/tpch_node.cc +++ b/cpp/src/arrow/compute/exec/tpch_node.cc @@ -22,6 +22,7 @@ #include "arrow/util/formatting.h" #include "arrow/util/future.h" #include "arrow/util/io_util.h" +#include "arrow/util/logging.h" #include "arrow/util/make_unique.h" #include "arrow/util/pcg_random.h" #include "arrow/util/unreachable.h" diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 62f15c16000..1494d3e1d13 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -40,11 +40,14 @@ struct VectorKernel; struct KernelState; +struct Declaration; class Expression; class 
ExecNode; class ExecPlan; class ExecNodeOptions; class ExecFactoryRegistry; +class SinkNodeConsumer; + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/dataset/file_base.cc b/cpp/src/arrow/dataset/file_base.cc index 23f4d09a9d2..ecf5d106d2e 100644 --- a/cpp/src/arrow/dataset/file_base.cc +++ b/cpp/src/arrow/dataset/file_base.cc @@ -26,6 +26,7 @@ #include "arrow/compute/api_scalar.h" #include "arrow/compute/exec/forest_internal.h" +#include "arrow/compute/exec/map_node.h" #include "arrow/compute/exec/subtree_internal.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/dataset_writer.h" diff --git a/cpp/src/arrow/dataset/type_fwd.h b/cpp/src/arrow/dataset/type_fwd.h index 52fe631f5ac..a7ea8d6ce9e 100644 --- a/cpp/src/arrow/dataset/type_fwd.h +++ b/cpp/src/arrow/dataset/type_fwd.h @@ -51,6 +51,7 @@ class FileWriteOptions; class FileSystemDataset; class FileSystemDatasetFactory; struct FileSystemDatasetWriteOptions; +class WriteNodeOptions; /// \brief Controls what happens if files exist in an output directory during a dataset /// write @@ -92,6 +93,7 @@ struct HivePartitioningOptions; class FilenamePartitioning; struct FilenamePartitioningOptions; +class ScanNodeOptions; struct ScanOptions; class Scanner; diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc index 926fe846fff..f7fcd1e1279 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.cc +++ b/cpp/src/arrow/engine/substrait/extension_set.cc @@ -17,11 +17,14 @@ #include "arrow/engine/substrait/extension_set.h" +#include #include +#include #include "arrow/engine/substrait/expression_internal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" +#include "arrow/util/make_unique.h" #include "arrow/util/string_view.h" namespace arrow { @@ -62,53 +65,66 @@ size_t IdHashEq::operator()(Id id) const { bool IdHashEq::operator()(Id l, Id r) const { return l.uri == r.uri && l.name == r.name; } -Id IdStorage::Emplace(Id id) { - util::string_view owned_uri = EmplaceUri(id.uri); +class IdStorageImpl : public IdStorage { + public: + Id Emplace(Id id) override { + util::string_view owned_uri = EmplaceUri(id.uri); + + util::string_view owned_name; + auto name_itr = names_.find(id.name); + if (name_itr == names_.end()) { + owned_names_.emplace_back(id.name); + owned_name = owned_names_.back(); + names_.insert(owned_name); + } else { + owned_name = *name_itr; + } - util::string_view owned_name; - auto name_itr = names_.find(id.name); - if (name_itr == names_.end()) { - owned_names_.emplace_back(id.name); - owned_name = owned_names_.back(); - names_.insert(owned_name); - } else { - owned_name = *name_itr; + return {owned_uri, owned_name}; } - return {owned_uri, owned_name}; -} + std::optional Find(Id id) const override { + std::optional maybe_owned_uri = FindUri(id.uri); + if (!maybe_owned_uri) { + return std::nullopt; + } -std::optional IdStorage::Find(Id id) const { - std::optional maybe_owned_uri = FindUri(id.uri); - if (!maybe_owned_uri) { - return std::nullopt; + auto name_itr = names_.find(id.name); + if (name_itr == names_.end()) { + return std::nullopt; + } else { + return Id{*maybe_owned_uri, *name_itr}; + } } - auto name_itr = names_.find(id.name); - if (name_itr == names_.end()) { - return std::nullopt; - } else { - return Id{*maybe_owned_uri, *name_itr}; + std::optional FindUri(util::string_view uri) const override { + auto uri_itr = uris_.find(uri); + if (uri_itr == uris_.end()) { + return std::nullopt; + } + return *uri_itr; } 
-} -std::optional IdStorage::FindUri(util::string_view uri) const { - auto uri_itr = uris_.find(uri); - if (uri_itr == uris_.end()) { - return std::nullopt; + util::string_view EmplaceUri(util::string_view uri) override { + auto uri_itr = uris_.find(uri); + if (uri_itr == uris_.end()) { + owned_uris_.emplace_back(uri); + util::string_view owned_uri = owned_uris_.back(); + uris_.insert(owned_uri); + return owned_uri; + } + return *uri_itr; } - return *uri_itr; -} -util::string_view IdStorage::EmplaceUri(util::string_view uri) { - auto uri_itr = uris_.find(uri); - if (uri_itr == uris_.end()) { - owned_uris_.emplace_back(uri); - util::string_view owned_uri = owned_uris_.back(); - uris_.insert(owned_uri); - return owned_uri; - } - return *uri_itr; + private: + std::unordered_set uris_; + std::unordered_set names_; + std::list owned_uris_; + std::list owned_names_; +}; + +std::unique_ptr IdStorage::Make() { + return ::arrow::internal::make_unique(); } Result> SubstraitCall::GetEnumArg(uint32_t index) const { @@ -211,7 +227,7 @@ Result ExtensionSet::Make( "Plan contained a URI that the extension registry is unaware of: ", uri.second); } - set.uris_[uri.first] = set.plan_specific_ids_.EmplaceUri(uri.second); + set.uris_[uri.first] = set.plan_specific_ids_->EmplaceUri(uri.second); } } @@ -242,7 +258,7 @@ Result ExtensionSet::Make( function_id.second.uri, "#", function_id.second.name); } set.functions_[function_id.first] = - set.plan_specific_ids_.Emplace(function_id.second); + set.plan_specific_ids_->Emplace(function_id.second); } } @@ -315,7 +331,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return parent_uri; } } - return ids_.FindUri(uri); + return ids_->FindUri(uri); } std::optional FindId(Id id) const override { @@ -325,7 +341,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { return parent_id; } } - return ids_.Find(id); + return ids_->Find(id); } std::optional GetType(const DataType& type) const override { @@ -368,7 +384,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { ARROW_RETURN_NOT_OK(parent_->CanRegisterType(id, type)); } - Id copied_id = ids_.Emplace(id); + Id copied_id = ids_->Emplace(id); auto index = static_cast(type_ids_.size()); @@ -419,7 +435,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { Id substrait_id, ConverterType conversion_func, std::unordered_map* dest) { // Convert id to view into registry-owned memory - Id copied_id = ids_.Emplace(substrait_id); + Id copied_id = ids_->Emplace(substrait_id); auto add_result = dest->emplace(copied_id, std::move(conversion_func)); if (!add_result.second) { @@ -587,7 +603,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry { const ExtensionIdRegistry* parent_; // owning storage of ids & types - IdStorage ids_; + std::unique_ptr ids_ = IdStorage::Make(); DataTypeVector types_; // There should only be one entry per Arrow function so there is no need // to separate ownership and lookup diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h index e2b20f989ac..46c83b81d16 100644 --- a/cpp/src/arrow/engine/substrait/extension_set.h +++ b/cpp/src/arrow/engine/substrait/extension_set.h @@ -19,19 +19,22 @@ #pragma once -#include +#include +#include +#include +#include #include +#include #include -#include +#include #include -#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/api_aggregate.h" #include "arrow/compute/exec/expression.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/result.h" #include 
"arrow/type_fwd.h" -#include "arrow/util/hash_util.h" -#include "arrow/util/hashing.h" +#include "arrow/util/macros.h" #include "arrow/util/string_view.h" namespace arrow { @@ -75,28 +78,25 @@ struct IdHashEq { /// storage. class IdStorage { public: + virtual ~IdStorage() = default; /// \brief Get an equivalent id pointing into this storage /// /// This operation will copy the ids into storage if they do not already exist - Id Emplace(Id id); + virtual Id Emplace(Id id) = 0; /// \brief Get an equivalent view pointing into this storage for a URI /// /// If no URI is found then the uri will be copied into storage - util::string_view EmplaceUri(util::string_view uri); + virtual util::string_view EmplaceUri(util::string_view uri) = 0; /// \brief Get an equivalent id pointing into this storage /// /// If no id is found then nullopt will be returned - std::optional Find(Id id) const; + virtual std::optional Find(Id id) const = 0; /// \brief Get an equivalent view pointing into this storage for a URI /// /// If no URI is found then nullopt will be returned - std::optional FindUri(util::string_view uri) const; + virtual std::optional FindUri(util::string_view uri) const = 0; - private: - std::unordered_set uris_; - std::unordered_set names_; - std::list owned_uris_; - std::list owned_names_; + static std::unique_ptr Make(); }; /// \brief Describes a Substrait call @@ -404,7 +404,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet { // that we can safely ignore. For example, we can usually safely ignore // extension type variations if we assume the plan is valid. These ignorable // ids are stored here. - IdStorage plan_specific_ids_; + std::unique_ptr plan_specific_ids_ = IdStorage::Make(); // Map from anchor values to URI values referenced by this extension set std::unordered_map uris_; diff --git a/cpp/src/arrow/engine/substrait/options.h b/cpp/src/arrow/engine/substrait/options.h index eace200f0ac..014842f4d8f 100644 --- a/cpp/src/arrow/engine/substrait/options.h +++ b/cpp/src/arrow/engine/substrait/options.h @@ -19,6 +19,13 @@ #pragma once +#include +#include +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/type_fwd.h" + namespace arrow { namespace engine { diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index c6297675492..1e1c61fc322 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -17,10 +17,20 @@ #include "arrow/engine/substrait/serde.h" +#include + +#include "arrow/buffer.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/expression.h" +#include "arrow/compute/exec/options.h" +#include "arrow/dataset/file_base.h" #include "arrow/engine/substrait/expression_internal.h" +#include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/plan_internal.h" #include "arrow/engine/substrait/relation_internal.h" +#include "arrow/engine/substrait/type_fwd.h" #include "arrow/engine/substrait/type_internal.h" +#include "arrow/type.h" #include "arrow/util/string_view.h" #include diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index 2a14ca67570..cc59adb0d25 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ b/cpp/src/arrow/engine/substrait/serde.h @@ -20,17 +20,19 @@ #pragma once #include +#include #include #include -#include "arrow/buffer.h" -#include "arrow/compute/exec/exec_plan.h" -#include "arrow/compute/exec/options.h" -#include "arrow/dataset/file_base.h" -#include 
"arrow/engine/substrait/extension_set.h" +#include "arrow/compute/type_fwd.h" +#include "arrow/dataset/type_fwd.h" #include "arrow/engine/substrait/options.h" +#include "arrow/engine/substrait/type_fwd.h" #include "arrow/engine/substrait/visibility.h" #include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/macros.h" #include "arrow/util/string_view.h" namespace arrow { diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index afa676a4095..7601bcf4370 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -20,6 +20,7 @@ #include #include +#include "arrow/compute/exec/exec_plan.h" #include "arrow/compute/exec/expression_internal.h" #include "arrow/dataset/file_base.h" #include "arrow/dataset/file_ipc.h" diff --git a/cpp/src/arrow/engine/substrait/type_fwd.h b/cpp/src/arrow/engine/substrait/type_fwd.h new file mode 100644 index 00000000000..235d9e82d1b --- /dev/null +++ b/cpp/src/arrow/engine/substrait/type_fwd.h @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +namespace arrow { +namespace engine { + +class ExtensionIdRegistry; +class ExtensionSet; + +struct ConversionOptions; + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 6587ea077dd..0df3420c234 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -16,6 +16,9 @@ // under the License. 
#include "arrow/engine/substrait/util.h" + +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" #include "arrow/util/async_generator.h" #include "arrow/util/async_util.h" diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc index ce1483cb8c3..201234c5263 100644 --- a/cpp/src/arrow/flight/sql/example/acero_server.cc +++ b/cpp/src/arrow/flight/sql/example/acero_server.cc @@ -22,6 +22,8 @@ #include #include +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" #include "arrow/engine/substrait/serde.h" #include "arrow/flight/sql/types.h" #include "arrow/type.h" diff --git a/cpp/src/arrow/util/type_fwd.h b/cpp/src/arrow/util/type_fwd.h index ca107c2c69d..976a22bb0be 100644 --- a/cpp/src/arrow/util/type_fwd.h +++ b/cpp/src/arrow/util/type_fwd.h @@ -54,6 +54,7 @@ struct Compression { }; namespace util { +class AsyncTaskScheduler; class Compressor; class Decompressor; class Codec; From 796916493057cacec45abe37c03b2ef17776a896 Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Mon, 19 Sep 2022 15:55:15 -0400 Subject: [PATCH 100/133] MINOR: [R] Forward compatibility for tidyselect 1.2 (#14170) cc @hadley Authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/tests/testthat/test-dplyr-filter.R | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/r/tests/testthat/test-dplyr-filter.R b/r/tests/testthat/test-dplyr-filter.R index e019a91cac4..f94450a0257 100644 --- a/r/tests/testthat/test-dplyr-filter.R +++ b/r/tests/testthat/test-dplyr-filter.R @@ -375,7 +375,9 @@ test_that("filter() with .data pronoun", { compare_dplyr_binding( .input %>% filter(.data$dbl > 4) %>% - select(.data$chr, .data$int, .data$lgl) %>% + # use "quoted" strings instead of .data pronoun where tidyselect is used + # .data pronoun deprecated in select in tidyselect 1.2 + select("chr", "int", "lgl") %>% collect(), tbl ) @@ -383,7 +385,7 @@ test_that("filter() with .data pronoun", { compare_dplyr_binding( .input %>% filter(is.na(.data$lgl)) %>% - select(.data$chr, .data$int, .data$lgl) %>% + select("chr", "int", "lgl") %>% collect(), tbl ) @@ -393,7 +395,7 @@ test_that("filter() with .data pronoun", { compare_dplyr_binding( .input %>% filter(.data$dbl > .env$chr) %>% - select(.data$chr, .data$int, .data$lgl) %>% + select("chr", "int", "lgl") %>% collect(), tbl ) From b97497aadcb06f03305f534dc5aee8a759b5442f Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Tue, 20 Sep 2022 04:20:55 +0530 Subject: [PATCH 101/133] ARROW-17647: [C++] Using better namespace style when using protobuf with Substrait (#14121) This PR includes minor changes to the namespace usage in Substrait integation. 
Authored-by: Vibhatha Abeykoon Signed-off-by: Weston Pace --- .../engine/substrait/expression_internal.cc | 2 +- .../engine/substrait/relation_internal.cc | 23 ++++++++++--------- .../arrow/engine/substrait/type_internal.cc | 2 +- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 1f9d234bff7..ec0578828a6 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -1011,7 +1011,7 @@ Result> ToProto( if (arguments[0]->has_selection() && arguments[0]->selection().has_direct_reference()) { if (arguments[1]->has_literal() && arguments[1]->literal().literal_type_case() == - substrait::Expression_Literal::kI32) { + substrait::Expression::Literal::kI32) { return MakeListElementReference(std::move(arguments[0]), arguments[1]->literal().i32()); } diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 00427fc4c9e..ed07f75f2b9 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -165,23 +165,23 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& for (const auto& item : read.local_files().items()) { std::string path; if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath) { + substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath) { path = item.uri_path(); } else if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile) { + substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile) { path = item.uri_file(); } else if (item.path_type_case() == - substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder) { + substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder) { path = item.uri_folder(); } else { path = item.uri_path_glob(); } switch (item.file_format_case()) { - case substrait::ReadRel_LocalFiles_FileOrFiles::kParquet: + case substrait::ReadRel::LocalFiles::FileOrFiles::kParquet: format = std::make_shared(); break; - case substrait::ReadRel_LocalFiles_FileOrFiles::kArrow: + case substrait::ReadRel::LocalFiles::FileOrFiles::kArrow: format = std::make_shared(); break; default: @@ -212,7 +212,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& path = path.substr(7); switch (item.path_type_case()) { - case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPath: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPath: { ARROW_ASSIGN_OR_RAISE(auto file, filesystem->GetFileInfo(path)); if (file.type() == fs::FileType::File) { files.push_back(std::move(file)); @@ -226,11 +226,11 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& } break; } - case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFile: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFile: { files.emplace_back(path, fs::FileType::File); break; } - case substrait::ReadRel_LocalFiles_FileOrFiles::kUriFolder: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriFolder: { fs::FileSelector selector; selector.base_dir = path; selector.recursive = true; @@ -240,7 +240,7 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::back_inserter(files)); break; } - case substrait::ReadRel_LocalFiles_FileOrFiles::kUriPathGlob: { + case substrait::ReadRel::LocalFiles::FileOrFiles::kUriPathGlob: { ARROW_ASSIGN_OR_RAISE(auto discovered_files, fs::internal::GlobFiles(filesystem, path)); 
std::move(discovered_files.begin(), discovered_files.end(), @@ -582,15 +582,16 @@ Result> ScanRelationConverter( return Status::Invalid( "Can only convert scan node with FileSystemDataset to a Substrait plan."); } + // set schema ARROW_ASSIGN_OR_RAISE(auto named_struct, ToProto(*dataset->schema(), ext_set, conversion_options)); read_rel->set_allocated_base_schema(named_struct.release()); // set local files - auto read_rel_lfs = make_unique(); + auto read_rel_lfs = make_unique(); for (const auto& file : dataset->files()) { - auto read_rel_lfs_ffs = make_unique(); + auto read_rel_lfs_ffs = make_unique(); read_rel_lfs_ffs->set_uri_path(UriFromAbsolutePath(file)); // set file format auto format_type_name = dataset->format()->type_name(); diff --git a/cpp/src/arrow/engine/substrait/type_internal.cc b/cpp/src/arrow/engine/substrait/type_internal.cc index 310413a8926..a2b5445cdce 100644 --- a/cpp/src/arrow/engine/substrait/type_internal.cc +++ b/cpp/src/arrow/engine/substrait/type_internal.cc @@ -398,7 +398,7 @@ struct DataTypeToProtoImpl { template Status EncodeUserDefined(const T& t) { ARROW_ASSIGN_OR_RAISE(auto anchor, ext_set_->EncodeType(t)); - auto user_defined = internal::make_unique<::substrait::Type_UserDefined>(); + auto user_defined = internal::make_unique<::substrait::Type::UserDefined>(); user_defined->set_type_reference(anchor); user_defined->set_nullability(nullable_ ? ::substrait::Type::NULLABILITY_NULLABLE : ::substrait::Type::NULLABILITY_REQUIRED); From cd67e5195aee78748e2b646a15a4e9f8b791a776 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 20 Sep 2022 09:10:39 +0200 Subject: [PATCH 102/133] ARROW-17517: [C++] Test engine API in public API test (#13965) Also some assorted header inclusion cleanups. Followup to PR #13965. Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/cmake_modules/DefineOptions.cmake | 4 +- cpp/src/arrow/compute/exec.h | 4 +- cpp/src/arrow/compute/exec/aggregate.cc | 1 + cpp/src/arrow/compute/exec/exec_plan.cc | 1 + cpp/src/arrow/compute/exec/exec_plan.h | 3 -- cpp/src/arrow/compute/exec/tpch_node.cc | 25 +++++++----- cpp/src/arrow/compute/type_fwd.h | 4 ++ cpp/src/arrow/public_api_test.cc | 52 ++++++++++++++++++------- cpp/src/arrow/util/config.h.cmake | 2 + r/src/compute-exec.cpp | 1 + 10 files changed, 65 insertions(+), 32 deletions(-) diff --git a/cpp/cmake_modules/DefineOptions.cmake b/cpp/cmake_modules/DefineOptions.cmake index 8a1271dfcdf..6b77e36db77 100644 --- a/cpp/cmake_modules/DefineOptions.cmake +++ b/cpp/cmake_modules/DefineOptions.cmake @@ -239,8 +239,6 @@ takes precedence over ccache if a storage backend is configured" ON) define_option(ARROW_DATASET "Build the Arrow Dataset Modules" OFF) - define_option(ARROW_SUBSTRAIT "Build the Arrow Substrait Consumer Module" OFF) - define_option(ARROW_FILESYSTEM "Build the Arrow Filesystem Layer" OFF) define_option(ARROW_FLIGHT @@ -285,6 +283,8 @@ takes precedence over ccache if a storage backend is configured" ON) define_option(ARROW_SKYHOOK "Build the Skyhook libraries" OFF) + define_option(ARROW_SUBSTRAIT "Build the Arrow Substrait Consumer Module" OFF) + define_option(ARROW_TENSORFLOW "Build Arrow with TensorFlow support enabled" OFF) define_option(ARROW_TESTING "Build the Arrow testing libraries" OFF) diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 12cce42038d..2731c39abaa 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -30,8 +30,8 @@ #include "arrow/array/data.h" #include "arrow/compute/exec/expression.h" 
+#include "arrow/compute/type_fwd.h" #include "arrow/datum.h" -#include "arrow/memory_pool.h" #include "arrow/result.h" #include "arrow/type_fwd.h" #include "arrow/util/macros.h" @@ -127,8 +127,6 @@ class ARROW_EXPORT ExecContext { bool use_threads_ = true; }; -ARROW_EXPORT ExecContext* default_exec_context(); - // TODO: Consider standardizing on uint16 selection vectors and only use them // when we can ensure that each value is 64K length or smaller diff --git a/cpp/src/arrow/compute/exec/aggregate.cc b/cpp/src/arrow/compute/exec/aggregate.cc index 95f85317c8d..180cfacf3d5 100644 --- a/cpp/src/arrow/compute/exec/aggregate.cc +++ b/cpp/src/arrow/compute/exec/aggregate.cc @@ -19,6 +19,7 @@ #include #include +#include #include "arrow/compute/exec_internal.h" #include "arrow/compute/registry.h" diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 8d9b214a720..6b02b76916c 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -34,6 +34,7 @@ #include "arrow/result.h" #include "arrow/util/async_generator.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/tracing_internal.h" diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index f82afb604a1..e9af46be261 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -26,12 +26,9 @@ #include #include -#include "arrow/compute/exec.h" #include "arrow/compute/type_fwd.h" #include "arrow/type_fwd.h" -#include "arrow/util/cancel.h" #include "arrow/util/future.h" -#include "arrow/util/key_value_metadata.h" #include "arrow/util/macros.h" #include "arrow/util/tracing.h" #include "arrow/util/type_fwd.h" diff --git a/cpp/src/arrow/compute/exec/tpch_node.cc b/cpp/src/arrow/compute/exec/tpch_node.cc index 13fbef2bd5c..0cd33313b03 100644 --- a/cpp/src/arrow/compute/exec/tpch_node.cc +++ b/cpp/src/arrow/compute/exec/tpch_node.cc @@ -16,16 +16,6 @@ // under the License. 
#include "arrow/compute/exec/tpch_node.h" -#include "arrow/buffer.h" -#include "arrow/compute/exec/exec_plan.h" -#include "arrow/util/async_util.h" -#include "arrow/util/formatting.h" -#include "arrow/util/future.h" -#include "arrow/util/io_util.h" -#include "arrow/util/logging.h" -#include "arrow/util/make_unique.h" -#include "arrow/util/pcg_random.h" -#include "arrow/util/unreachable.h" #include #include @@ -34,10 +24,25 @@ #include #include #include +#include #include #include +#include "arrow/buffer.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/datum.h" +#include "arrow/util/async_util.h" +#include "arrow/util/formatting.h" +#include "arrow/util/future.h" +#include "arrow/util/io_util.h" +#include "arrow/util/logging.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/pcg_random.h" +#include "arrow/util/unreachable.h" + namespace arrow { + using internal::checked_cast; using internal::GetRandomSeed; diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 1494d3e1d13..11c45fde091 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -17,6 +17,8 @@ #pragma once +#include "arrow/util/visibility.h" + namespace arrow { struct Datum; @@ -49,5 +51,7 @@ class ExecFactoryRegistry; class SinkNodeConsumer; +ARROW_EXPORT ExecContext* default_exec_context(); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/public_api_test.cc b/cpp/src/arrow/public_api_test.cc index a611dd7920c..9abff229508 100644 --- a/cpp/src/arrow/public_api_test.cc +++ b/cpp/src/arrow/public_api_test.cc @@ -46,45 +46,69 @@ #include "arrow/flight/api.h" // IWYU pragma: keep #endif +#ifdef ARROW_FLIGHT_SQL +#include "arrow/flight/sql/api.h" // IWYU pragma: keep +#endif + #ifdef ARROW_JSON #include "arrow/json/api.h" // IWYU pragma: keep #endif +#ifdef ARROW_SUBSTRAIT +#include "arrow/engine/api.h" // IWYU pragma: keep +#include "arrow/engine/substrait/api.h" // IWYU pragma: keep +#endif + +#include +#include + +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/util.h" + +namespace arrow { + +TEST(InternalHeaders, DCheckExposed) { #ifdef DCHECK -#error "DCHECK should not be visible from Arrow public headers." + FAIL() << "DCHECK should not be visible from Arrow public headers."; #endif +} +TEST(InternalHeaders, AssignOrRaiseExposed) { #ifdef ASSIGN_OR_RAISE -#error "ASSIGN_OR_RAISE should not be visible from Arrow public headers." + FAIL() << "ASSIGN_OR_RAISE should not be visible from Arrow public headers."; #endif +} +TEST(InternalDependencies, OpenTelemetryExposed) { #ifdef OPENTELEMETRY_VERSION -#error "OpenTelemetry should not be visible from Arrow public headers." + FAIL() << "OpenTelemetry should not be visible from Arrow public headers."; #endif +} +TEST(InternalDependencies, XSimdExposed) { #ifdef XSIMD_VERSION_MAJOR -#error "xsimd should not be visible from Arrow public headers." + FAIL() << "xsimd should not be visible from Arrow public headers."; #endif +} +TEST(InternalDependencies, DateLibraryExposed) { #ifdef HAS_CHRONO_ROUNDING -#error "arrow::vendored::date should not be visible from Arrow public headers." + FAIL() << "arrow::vendored::date should not be visible from Arrow public headers."; #endif +} +TEST(InternalDependencies, ProtobufExposed) { #ifdef PROTOBUF_EXPORT -#error "Protocol Buffers should not be visible from Arrow public headers." 
+ FAIL() << "Protocol Buffers should not be visible from Arrow public headers."; #endif +} +TEST(TransitiveDependencies, WindowsHeadersExposed) { #if defined(SendMessage) || defined(GetObject) || defined(ERROR_INVALID_HANDLE) || \ defined(FILE_SHARE_READ) || defined(WAIT_TIMEOUT) -#error "Windows.h should not be included by Arrow public headers" + FAIL() << "Windows.h should not be included by Arrow public headers"; #endif - -#include -#include -#include "arrow/testing/gtest_util.h" -#include "arrow/testing/util.h" - -namespace arrow { +} TEST(Misc, BuildInfo) { const auto& info = GetBuildInfo(); diff --git a/cpp/src/arrow/util/config.h.cmake b/cpp/src/arrow/util/config.h.cmake index 9948c1e3587..f6fad2016a2 100644 --- a/cpp/src/arrow/util/config.h.cmake +++ b/cpp/src/arrow/util/config.h.cmake @@ -42,12 +42,14 @@ #cmakedefine ARROW_DATASET #cmakedefine ARROW_FILESYSTEM #cmakedefine ARROW_FLIGHT +#cmakedefine ARROW_FLIGHT_SQL #cmakedefine ARROW_IPC #cmakedefine ARROW_JEMALLOC #cmakedefine ARROW_JEMALLOC_VENDORED #cmakedefine ARROW_JSON #cmakedefine ARROW_ORC #cmakedefine ARROW_PARQUET +#cmakedefine ARROW_SUBSTRAIT #cmakedefine ARROW_GCS #cmakedefine ARROW_S3 diff --git a/r/src/compute-exec.cpp b/r/src/compute-exec.cpp index 71dc6d8b2e1..5af6450050e 100644 --- a/r/src/compute-exec.cpp +++ b/r/src/compute-exec.cpp @@ -18,6 +18,7 @@ #include "./arrow_types.h" #include "./safe-call-into-r.h" +#include #include #include #include From ab71673ce0955798645ae9178018f562a82ed7f2 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Tue, 20 Sep 2022 01:48:40 -0700 Subject: [PATCH 103/133] ARROW-13454: [C++][Docs] Tables vs Record Batches (#14008) Adds a little more explanation of the difference between tables and record batches, as well as a diagram representation. Authored-by: Will Jones Signed-off-by: Antoine Pitrou --- .../cpp/tables-versus-record-batches.svg | 102 ++++++++++++++++++ docs/source/cpp/tables.rst | 12 +++ docs/source/format/Glossary.rst | 8 +- 3 files changed, 120 insertions(+), 2 deletions(-) create mode 100644 docs/source/cpp/tables-versus-record-batches.svg diff --git a/docs/source/cpp/tables-versus-record-batches.svg b/docs/source/cpp/tables-versus-record-batches.svg new file mode 100644 index 00000000000..d793b1de2bf --- /dev/null +++ b/docs/source/cpp/tables-versus-record-batches.svg @@ -0,0 +1,102 @@ + + + + + + Arrow Table versus Record Batch + + + + Arrow Table + + Schema + + + + + Field + + + + + + Chunked + Array + + + + + + + + Array + + + + + A Table is a C++ data structure, + allowing for a mixed chunking structure and very large arrays. + + + + Arrow Record Batch + + Schema + + + + + Field + + + + + + Array + + + + + A Record Batch is a common Arrow data structure which is recognized by all implementations. + + + \ No newline at end of file diff --git a/docs/source/cpp/tables.rst b/docs/source/cpp/tables.rst index ea9198771cf..b28a9fc1e13 100644 --- a/docs/source/cpp/tables.rst +++ b/docs/source/cpp/tables.rst @@ -77,6 +77,18 @@ has a schema which must match its arrays' datatypes. Record batches are a convenient unit of work for various serialization and computation functions, possibly incremental. +.. image:: tables-versus-record-batches.svg + :alt: A graphical representation of an Arrow Table and a Record Batch, with + structure as described in text above. + +Record batches can be sent between implementations, such as via +:ref:`IPC ` or +via the :doc:`C Data Interface <../format/CDataInterface>`. 
Tables and +chunked arrays, on the other hand, are concepts in the C++ implementation, +not in the Arrow format itself, so they aren't directly portable. + +However, a table can be converted to and built from a sequence of record +batches easily without needing to copy the underlying array buffers. A table can be streamed as an arbitrary number of record batches using a :class:`arrow::TableBatchReader`. Conversely, a logical sequence of record batches can be assembled to form a table using one of the diff --git a/docs/source/format/Glossary.rst b/docs/source/format/Glossary.rst index 423ebf85783..5944d7c18cf 100644 --- a/docs/source/format/Glossary.rst +++ b/docs/source/format/Glossary.rst @@ -196,7 +196,11 @@ Glossary different buffers for different indices. Not part of the columnar format; this term is specific to - certain language implementations of Arrow (primarily C++ and - its bindings). + certain language implementations of Arrow (for example C++ and + its bindings, and Go). + + .. image:: ../cpp/tables-versus-record-batches.svg + :alt: A graphical representation of an Arrow Table and a + Record Batch, with structure as described in text above. .. seealso:: :term:`chunked array`, :term:`record batch` From 4f31bfc2ffed603089c8bcd3e44ae0950f171126 Mon Sep 17 00:00:00 2001 From: Pavel Solodovnikov Date: Tue, 20 Sep 2022 11:57:08 +0300 Subject: [PATCH 104/133] ARROW-17318: [C++][Dataset] Support async streaming interface for getting fragments in Dataset (#13804) Add `GetFragmentsAsync()` and `GetFragmentsAsyncImpl()` functions to the generic `Dataset` interface, which allow fragments to be produced in a streamed fashion. This is one of the prerequisites for making `FileSystemDataset` support lazy fragment processing, which, in turn, can be used to start scan operations without waiting for the entire dataset to be discovered. To aid the transition to an async implementation in `Dataset`/`AsyncScanner` code, a default implementation for `GetFragmentsAsyncImpl()` is provided (yielding a VectorGenerator over the fragments vector, which is stored by every implementation of the Dataset interface at the moment). Tests: unit(release) Signed-off-by: Pavel Solodovnikov Authored-by: Pavel Solodovnikov Signed-off-by: Antoine Pitrou --- cpp/src/arrow/dataset/dataset.cc | 32 ++++++++++- cpp/src/arrow/dataset/dataset.h | 24 ++++++++ cpp/src/arrow/dataset/dataset_test.cc | 28 ++++++++++ cpp/src/arrow/dataset/test_util.h | 18 +++++- cpp/src/arrow/util/async_generator.h | 4 +- cpp/src/arrow/util/async_generator_fwd.h | 71 ++++++++++++++++++++++++ 6 files changed, 171 insertions(+), 6 deletions(-) create mode 100644 cpp/src/arrow/util/async_generator_fwd.h diff --git a/cpp/src/arrow/dataset/dataset.cc b/cpp/src/arrow/dataset/dataset.cc index 6faaa953bb3..eb307681e91 100644 --- a/cpp/src/arrow/dataset/dataset.cc +++ b/cpp/src/arrow/dataset/dataset.cc @@ -15,18 +15,19 @@ // specific language governing permissions and limitations // under the License.
-#include "arrow/dataset/dataset.h" - #include #include +#include "arrow/dataset/dataset.h" #include "arrow/dataset/dataset_internal.h" #include "arrow/dataset/scanner.h" #include "arrow/table.h" +#include "arrow/util/async_generator.h" #include "arrow/util/bit_util.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" namespace arrow { @@ -160,6 +161,33 @@ Result Dataset::GetFragments(compute::Expression predicate) { : MakeEmptyIterator>(); } +Result Dataset::GetFragmentsAsync() { + return GetFragmentsAsync(compute::literal(true)); +} + +Result Dataset::GetFragmentsAsync(compute::Expression predicate) { + ARROW_ASSIGN_OR_RAISE( + predicate, SimplifyWithGuarantee(std::move(predicate), partition_expression_)); + return predicate.IsSatisfiable() + ? GetFragmentsAsyncImpl(std::move(predicate), + arrow::internal::GetCpuThreadPool()) + : MakeEmptyGenerator>(); +} + +// Default impl delegating the work to `GetFragmentsImpl` and wrapping it into +// BackgroundGenerator/TransferredGenerator, which offloads potentially +// IO-intensive work to the default IO thread pool and then transfers the control +// back to the specified executor. +Result Dataset::GetFragmentsAsyncImpl( + compute::Expression predicate, arrow::internal::Executor* executor) { + ARROW_ASSIGN_OR_RAISE(auto iter, GetFragmentsImpl(std::move(predicate))); + ARROW_ASSIGN_OR_RAISE( + auto background_gen, + MakeBackgroundGenerator(std::move(iter), io::default_io_context().executor())); + auto transferred_gen = MakeTransferredGenerator(std::move(background_gen), executor); + return transferred_gen; +} + struct VectorRecordBatchGenerator : InMemoryDataset::RecordBatchGenerator { explicit VectorRecordBatchGenerator(RecordBatchVector batches) : batches_(std::move(batches)) {} diff --git a/cpp/src/arrow/dataset/dataset.h b/cpp/src/arrow/dataset/dataset.h index 62181b60ba4..3a5030b6be8 100644 --- a/cpp/src/arrow/dataset/dataset.h +++ b/cpp/src/arrow/dataset/dataset.h @@ -29,10 +29,16 @@ #include "arrow/compute/exec/expression.h" #include "arrow/dataset/type_fwd.h" #include "arrow/dataset/visibility.h" +#include "arrow/util/async_generator_fwd.h" #include "arrow/util/macros.h" #include "arrow/util/mutex.h" namespace arrow { + +namespace internal { +class Executor; +} // namespace internal + namespace dataset { using RecordBatchGenerator = std::function>()>; @@ -134,6 +140,8 @@ class ARROW_DS_EXPORT InMemoryFragment : public Fragment { /// @} +using FragmentGenerator = AsyncGenerator>; + /// \brief A container of zero or more Fragments. /// /// A Dataset acts as a union of Fragments, e.g. files deeply nested in a @@ -148,6 +156,10 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this { Result GetFragments(compute::Expression predicate); Result GetFragments(); + /// \brief Async versions of `GetFragments`. + Result GetFragmentsAsync(compute::Expression predicate); + Result GetFragmentsAsync(); + const std::shared_ptr& schema() const { return schema_; } /// \brief An expression which evaluates to true for all data viewed by this Dataset. 
@@ -174,6 +186,18 @@ class ARROW_DS_EXPORT Dataset : public std::enable_shared_from_this<Dataset> { Dataset(std::shared_ptr<Schema> schema, compute::Expression partition_expression); virtual Result<FragmentIterator> GetFragmentsImpl(compute::Expression predicate) = 0; + /// \brief Default implementation for the virtual `GetFragmentsAsyncImpl` + /// method, which creates a fragment generator for + /// the dataset, possibly filtering results with a predicate (forwarding to + /// the synchronous `GetFragmentsImpl` method and moving the computations + /// to the background, using the IO thread pool). + /// + /// Currently, `executor` is always the same as `internal::GetCpuThreadPool()`, + /// which means the results from the underlying fragment generator will be + /// transferred to the default CPU thread pool. The generator itself is + /// offloaded to run on the default IO thread pool. + virtual Result<FragmentGenerator> GetFragmentsAsyncImpl( + compute::Expression predicate, arrow::internal::Executor* executor); std::shared_ptr<Schema> schema_; compute::Expression partition_expression_ = compute::literal(true); diff --git a/cpp/src/arrow/dataset/dataset_test.cc b/cpp/src/arrow/dataset/dataset_test.cc index cb155d7b962..5d199823474 100644 --- a/cpp/src/arrow/dataset/dataset_test.cc +++ b/cpp/src/arrow/dataset/dataset_test.cc @@ -146,6 +146,34 @@ TEST_F(TestInMemoryDataset, HandlesDifferingSchemas) { scanner->ToTable()); } +TEST_F(TestInMemoryDataset, GetFragmentsSync) { + constexpr int64_t kBatchSize = 1024; + constexpr int64_t kNumberBatches = 16; + + SetSchema({field("i32", int32()), field("f64", float64())}); + auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch); + + auto dataset = std::make_shared<InMemoryDataset>( + schema_, RecordBatchVector{static_cast<size_t>(kNumberBatches), batch}); + + AssertDatasetFragmentsEqual(reader.get(), dataset.get()); +} + +TEST_F(TestInMemoryDataset, GetFragmentsAsync) { + constexpr int64_t kBatchSize = 1024; + constexpr int64_t kNumberBatches = 16; + + SetSchema({field("i32", int32()), field("f64", float64())}); + auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_); + auto reader = ConstantArrayGenerator::Repeat(kNumberBatches, batch); + + auto dataset = std::make_shared<InMemoryDataset>( + schema_, RecordBatchVector{static_cast<size_t>(kNumberBatches), batch}); + + AssertDatasetAsyncFragmentsEqual(reader.get(), dataset.get()); +} + class TestUnionDataset : public DatasetFixtureMixin {}; TEST_F(TestUnionDataset, ReplaceSchema) { diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h index 05a98693896..fb54dc3a91a 100644 --- a/cpp/src/arrow/dataset/test_util.h +++ b/cpp/src/arrow/dataset/test_util.h @@ -167,7 +167,7 @@ class DatasetFixtureMixin : public ::testing::Test { void AssertFragmentEquals(RecordBatchReader* expected, Fragment* fragment, bool ensure_drained = true) { ASSERT_OK_AND_ASSIGN(auto batch_gen, fragment->ScanBatchesAsync(options_)); - AssertScanTaskEquals(expected, batch_gen); + AssertScanTaskEquals(expected, batch_gen, ensure_drained); if (ensure_drained) { EnsureRecordBatchReaderDrained(expected); @@ -191,6 +191,22 @@ } } + void AssertDatasetAsyncFragmentsEqual(RecordBatchReader* expected, Dataset* dataset, + bool ensure_drained = true) { + ASSERT_OK_AND_ASSIGN(auto predicate, options_->filter.Bind(*dataset->schema())); + ASSERT_OK_AND_ASSIGN(auto gen, dataset->GetFragmentsAsync(predicate)); + + ASSERT_FINISHES_OK(VisitAsyncGenerator( + std::move(gen), [this,
expected](const std::shared_ptr& f) { + AssertFragmentEquals(expected, f.get(), false /*ensure_drained*/); + return Status::OK(); + })); + + if (ensure_drained) { + EnsureRecordBatchReaderDrained(expected); + } + } + /// \brief Ensure that record batches found in reader are equals to the /// record batches yielded by a scanner. void AssertScannerEquals(RecordBatchReader* expected, Scanner* scanner, diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h index d4a9c2829a7..0d51208ac72 100644 --- a/cpp/src/arrow/util/async_generator.h +++ b/cpp/src/arrow/util/async_generator.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/util/async_generator_fwd.h" #include "arrow/util/async_util.h" #include "arrow/util/functional.h" #include "arrow/util/future.h" @@ -66,9 +67,6 @@ namespace arrow { // until all outstanding futures have completed. Generators that spawn multiple // concurrent futures may need to hold onto an error while other concurrent futures wrap // up. -template -using AsyncGenerator = std::function()>; - template struct IterationTraits> { /// \brief by default when iterating through a sequence of AsyncGenerator, diff --git a/cpp/src/arrow/util/async_generator_fwd.h b/cpp/src/arrow/util/async_generator_fwd.h new file mode 100644 index 00000000000..f3c5bf9ef6f --- /dev/null +++ b/cpp/src/arrow/util/async_generator_fwd.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/type_fwd.h" + +namespace arrow { + +template +using AsyncGenerator = std::function()>; + +template +class MappingGenerator; + +template +class SequencingGenerator; + +template +class TransformingGenerator; + +template +class SerialReadaheadGenerator; + +template +class ReadaheadGenerator; + +template +class PushGenerator; + +template +class MergedGenerator; + +template +struct Enumerated; + +template +class EnumeratingGenerator; + +template +class TransferringGenerator; + +template +class BackgroundGenerator; + +template +class GeneratorIterator; + +template +struct CancellableGenerator; + +template +class DefaultIfEmptyGenerator; + +} // namespace arrow From 2577ac1a10270b4e8e7fe9d2240b36216407319f Mon Sep 17 00:00:00 2001 From: Neal Richardson Date: Tue, 20 Sep 2022 06:06:44 -0400 Subject: [PATCH 105/133] ARROW-17690: [R] Implement dplyr::across() inside distinct() (#14154) The feature was actually done by ARROW-17689; this PR just adds the test and updates the docs. 
cc @eitsupi Authored-by: Neal Richardson Signed-off-by: Neal Richardson --- r/R/dplyr-funcs-doc.R | 2 +- r/data-raw/docgen.R | 4 ++-- r/man/acero.Rd | 2 +- r/tests/testthat/test-dplyr-distinct.R | 10 ++++++++++ 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/r/R/dplyr-funcs-doc.R b/r/R/dplyr-funcs-doc.R index cbfe475232b..5360c7fad66 100644 --- a/r/R/dplyr-funcs-doc.R +++ b/r/R/dplyr-funcs-doc.R @@ -185,7 +185,7 @@ #' #' ## dplyr #' -#' * [`across()`][dplyr::across()]: supported inside `mutate()`, `summarize()`, `group_by()`, and `arrange()`; +#' * [`across()`][dplyr::across()]: not yet supported inside `filter()`; #' purrr-style lambda functions #' and use of `where()` selection helper not yet supported #' * [`between()`][dplyr::between()] diff --git a/r/data-raw/docgen.R b/r/data-raw/docgen.R index 9f839cfd123..267a85bdabe 100644 --- a/r/data-raw/docgen.R +++ b/r/data-raw/docgen.R @@ -128,8 +128,8 @@ docs <- arrow:::.cache$docs # across() is handled by manipulating the quosures, not by nse_funcs docs[["dplyr::across"]] <- c( - # TODO(ARROW-17387, ARROW-17389, ARROW-17390): other verbs - "supported inside `mutate()`, `summarize()`, `group_by()`, and `arrange()`;", + # TODO(ARROW-17387): do filter + "not yet supported inside `filter()`;", # TODO(ARROW-17366): do ~ "purrr-style lambda functions", # TODO(ARROW-17384): implement where diff --git a/r/man/acero.Rd b/r/man/acero.Rd index 5d4859edcb5..76f1b13fe3a 100644 --- a/r/man/acero.Rd +++ b/r/man/acero.Rd @@ -175,7 +175,7 @@ as \code{arrow_ascii_is_decimal}. \subsection{dplyr}{ \itemize{ -\item \code{\link[dplyr:across]{across()}}: supported inside \code{mutate()}, \code{summarize()}, \code{group_by()}, and \code{arrange()}; +\item \code{\link[dplyr:across]{across()}}: not yet supported inside \code{filter()}; purrr-style lambda functions and use of \code{where()} selection helper not yet supported \item \code{\link[dplyr:between]{between()}} diff --git a/r/tests/testthat/test-dplyr-distinct.R b/r/tests/testthat/test-dplyr-distinct.R index c679794d419..b598d058901 100644 --- a/r/tests/testthat/test-dplyr-distinct.R +++ b/r/tests/testthat/test-dplyr-distinct.R @@ -90,6 +90,16 @@ test_that("distinct() can contain expressions", { ) }) +test_that("across() works in distinct()", { + compare_dplyr_binding( + .input %>% + distinct(across(starts_with("d"))) %>% + collect() %>% + arrange(dbl, dbl2), + tbl + ) +}) + test_that("distinct() can return all columns", { skip("ARROW-14045") compare_dplyr_binding( From 2629f208cf8be638c7e251946947009bc181bb31 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Tue, 20 Sep 2022 14:50:01 +0200 Subject: [PATCH 106/133] ARROW-13055: [Doc] Create canonical extension types document (#14167) Vote result at https://lists.apache.org/thread/sxd5fhc42hb6svs79t3fd79gkqj83pfh Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- docs/source/format/CanonicalExtensions.rst | 75 ++++++++++++++++++++++ docs/source/format/Columnar.rst | 9 +++ docs/source/format/Glossary.rst | 16 +++-- docs/source/index.rst | 1 + 4 files changed, 97 insertions(+), 4 deletions(-) create mode 100644 docs/source/format/CanonicalExtensions.rst diff --git a/docs/source/format/CanonicalExtensions.rst b/docs/source/format/CanonicalExtensions.rst new file mode 100644 index 00000000000..3ede97ef7dc --- /dev/null +++ b/docs/source/format/CanonicalExtensions.rst @@ -0,0 +1,75 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. 
distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. _format_canonical_extensions: + +************************* +Canonical Extension Types +************************* + +============ +Introduction +============ + +The Arrow Columnar Format allows defining +:ref:`extension types ` so as to extend +standard Arrow data types with custom semantics. Often these semantics +will be specific to a system or application. However, it is beneficial +to share the definitions of well-known extension types so as to improve +interoperability between different systems integrating Arrow columnar data. + +Standardization +=============== + +These rules must be followed for the standardization of canonical extension +types: + +* Canonical extension types are described and maintained below in this document. + +* Each canonical extension type requires a distinct discussion and vote + on the `Arrow development mailing-list `__. + +* The specification text to be added *must* follow these requirements: + + 1) It *must* define a well-defined extension name starting with "``arrow.``". + + 2) Its parameters, if any, *must* be described in the proposal. + + 3) Its serialization *must* be described in the proposal and should + not require undue implementation work or unusual software dependencies + (for example, a trivial custom text format or JSON would be acceptable). + + 4) Its expected semantics *should* be described as well and any + potential ambiguities or pain points addressed or at least mentioned. + +* The extension type *should* have one implementation submitted; + preferably two if non-trivial (for example if parameterized). + +Making Modifications +==================== + +Like standard Arrow data types, canonical extension types should be considered +stable once standardized. Modifying a canonical extension type (for example +to expand the set of parameters) should be an exceptional event, follow the +same rules as laid out above, and provide backwards compatibility guarantees. + + +============= +Official List +============= + +No canonical extension types have been standardized yet. diff --git a/docs/source/format/Columnar.rst b/docs/source/format/Columnar.rst index 109b81e2b9d..5f9537384c0 100644 --- a/docs/source/format/Columnar.rst +++ b/docs/source/format/Columnar.rst @@ -1167,6 +1167,11 @@ structure. These extension keys are: * ``'ARROW:extension:metadata'`` for a serialized representation of the ``ExtensionType`` necessary to reconstruct the custom type +.. note:: + Extension names beginning with ``arrow.`` are reserved for + :ref:`canonical extension types <format_canonical_extensions>`; + they should not be used for third-party extension types. + This extension metadata can annotate any of the built-in Arrow logical types. The intent is that an implementation that does not support an extension type can still handle the underlying data.
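Because the two extension keys are plain field metadata, the behaviour described here can be shown with the Go implementation: a reader that does not recognize the extension name simply sees the storage type. A hand-rolled sketch (the extension name is hypothetical, and this is not how registered extension types are normally attached):

    package main

    import (
        "fmt"

        "github.com/apache/arrow/go/v10/arrow"
    )

    func main() {
        // Annotate a 16-byte storage type as a third-party UUID extension.
        md := arrow.NewMetadata(
            []string{"ARROW:extension:name", "ARROW:extension:metadata"},
            []string{"myorg.uuid", ""},
        )
        f := arrow.Field{
            Name:     "id",
            Type:     &arrow.FixedSizeBinaryType{ByteWidth: 16},
            Nullable: true,
            Metadata: md,
        }
        // An implementation that knows nothing about "myorg.uuid" still
        // sees an ordinary fixed_size_binary(16) column and can process it.
        fmt.Println(arrow.NewSchema([]arrow.Field{f}, nil))
    }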
For example a @@ -1190,6 +1195,10 @@ extension types: metadata indicating the market trading calendar the data corresponds to +.. seealso:: + :ref:`format_canonical_extensions` + + Implementation guidelines ========================= diff --git a/docs/source/format/Glossary.rst b/docs/source/format/Glossary.rst index 5944d7c18cf..ac18c1618bc 100644 --- a/docs/source/format/Glossary.rst +++ b/docs/source/format/Glossary.rst @@ -52,6 +52,14 @@ Glossary device (e.g. GPU) memory, etc., though not all Arrow implementations support all of these possibilities. + canonical extension type + An :term:`extension type` that has been standardized by the + Arrow community so as to improve interoperability between + implementations. + + .. seealso:: + :ref:`format_canonical_extensions`. + child array parent array In an array of a :term:`nested type`, the parent array @@ -112,10 +120,10 @@ Glossary extension type storage type - A user-defined :term:`data type` that adds additional semantics - to an existing data type. This allows implementations that do - not support a particular extension type to still handle the - underlying data type (the "storage type"). + An extension type is a user-defined :term:`data type` that adds + additional semantics to an existing data type. This allows + implementations that do not support a particular extension type to + still handle the underlying data type (the "storage type"). For example, a UUID can be represented as a 16-byte fixed-size binary type. diff --git a/docs/source/index.rst b/docs/source/index.rst index b261474c6fa..60879993e45 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -74,6 +74,7 @@ target environment.** format/Versioning format/Columnar + format/CanonicalExtensions format/Flight format/FlightSql format/Integration From ace85889fed1e26c9125c6582a531029036389b6 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 20 Sep 2022 10:28:24 -0400 Subject: [PATCH 107/133] ARROW-17778: [Go][CSV] Simple CSV Reader Schema and type inference (#14171) Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/compare.go | 3 + go/arrow/csv/common.go | 37 ++++ go/arrow/csv/reader.go | 327 ++++++++++++++++++++++++++++-------- go/arrow/csv/reader_test.go | 146 +++++++++++++--- 4 files changed, 428 insertions(+), 85 deletions(-) diff --git a/go/arrow/compare.go b/go/arrow/compare.go index 6cc01bc9a22..511abe22389 100644 --- a/go/arrow/compare.go +++ b/go/arrow/compare.go @@ -118,6 +118,9 @@ func TypeEqual(left, right DataType, opts ...TypeEqualOption) bool { } } return true + case *TimestampType: + r := right.(*TimestampType) + return l.Unit == r.Unit && l.TimeZone == r.TimeZone default: return reflect.DeepEqual(left, right) } diff --git a/go/arrow/csv/common.go b/go/arrow/csv/common.go index 326c7c6f019..f27ba4c0a98 100644 --- a/go/arrow/csv/common.go +++ b/go/arrow/csv/common.go @@ -174,6 +174,43 @@ func WithBoolWriter(fmtr func(bool) string) Option { } } +// WithColumnTypes allows specifying optional per-column types (disabling +// type inference on those columns). +// +// Will panic if used in conjunction with an explicit schema.
+func WithColumnTypes(types map[string]arrow.DataType) Option { + return func(cfg config) { + switch cfg := cfg.(type) { + case *Reader: + if cfg.schema != nil { + panic(fmt.Errorf("%w: cannot use WithColumnTypes with explicit schema", arrow.ErrInvalid)) + } + cfg.columnTypes = types + default: + panic(fmt.Errorf("%w: WithColumnTypes only allowed for csv reader", arrow.ErrInvalid)) + } + } +} + +// WithIncludeColumns indicates the names of the columns from the CSV file +// that should actually be read and converted (in the slice's order). +// If set and non-empty, columns not in this slice will be ignored. +// +// Will panic if used in conjunction with an explicit schema. +func WithIncludeColumns(cols []string) Option { + return func(cfg config) { + switch cfg := cfg.(type) { + case *Reader: + if cfg.schema != nil { + panic(fmt.Errorf("%w: cannot use WithIncludeColumns with explicit schema", arrow.ErrInvalid)) + } + cfg.columnFilter = cols + default: + panic(fmt.Errorf("%w: WithIncludeColumns only allowed on csv Reader", arrow.ErrInvalid)) + } + } +} + func validate(schema *arrow.Schema) { for i, f := range schema.Fields() { switch ft := f.Type.(type) { diff --git a/go/arrow/csv/reader.go b/go/arrow/csv/reader.go index 091aa85e960..cd3902affca 100644 --- a/go/arrow/csv/reader.go +++ b/go/arrow/csv/reader.go @@ -17,6 +17,7 @@ package csv import ( + "encoding/base64" "encoding/csv" "errors" "fmt" @@ -24,6 +25,8 @@ import ( "strconv" "sync" "sync/atomic" + "time" + "unicode/utf8" "github.com/apache/arrow/go/v10/arrow" "github.com/apache/arrow/go/v10/arrow/array" @@ -50,12 +53,48 @@ type Reader struct { header bool once sync.Once - fieldConverter []func(field array.Builder, val string) + fieldConverter []func(val string) + columnFilter []string + columnTypes map[string]arrow.DataType + conversions []conversionColumn stringsCanBeNull bool nulls []string } +// NewInferringReader creates a CSV reader that attempts to infer the types +// and column names from the data in the first row of the CSV file. +// +// This can be further customized using the WithColumnTypes and +// WithIncludeColumns options. +func NewInferringReader(r io.Reader, opts ...Option) *Reader { + rr := &Reader{ + r: csv.NewReader(r), + refs: 1, + chunk: 1, + stringsCanBeNull: false, + } + rr.r.ReuseRecord = true + for _, opt := range opts { + opt(rr) + } + + if rr.mem == nil { + rr.mem = memory.DefaultAllocator + } + + switch { + case rr.chunk < 0: + rr.next = rr.nextall + case rr.chunk > 1: + rr.next = rr.nextn + default: + rr.next = rr.next1 + } + + return rr +} + // NewReader returns a reader that reads from the CSV file and creates // arrow.Records from the given schema. // @@ -91,36 +130,85 @@ func NewReader(r io.Reader, schema *arrow.Schema, opts ...Option) *Reader { rr.next = rr.next1 } - // Create a table of functions that will parse columns. This optimization - // allows us to specialize the implementation of each column's decoding - // and hoist type-based branches outside the inner loop. 
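Putting the new reader and options together, a usage sketch (CSV content invented for illustration): the header supplies the column names, the first data row drives type inference, `WithColumnTypes` pins one column's type, and `WithIncludeColumns` drops another column entirely.

    package main

    import (
        "fmt"
        "strings"

        "github.com/apache/arrow/go/v10/arrow"
        "github.com/apache/arrow/go/v10/arrow/csv"
    )

    func main() {
        data := "id,price,note\n1,1.5,foo\n2,2.5,bar\n"

        rdr := csv.NewInferringReader(strings.NewReader(data),
            csv.WithHeader(true),
            csv.WithChunk(-1), // one record for the whole file
            csv.WithColumnTypes(map[string]arrow.DataType{
                "id": arrow.PrimitiveTypes.Int32, // pinned, not inferred
            }),
            csv.WithIncludeColumns([]string{"id", "price"}),
        )
        defer rdr.Release()

        // The schema is only known once the first chunk has been read:
        // id: int32 (pinned), price: float64 (inferred), note dropped.
        for rdr.Next() {
            fmt.Println(rdr.Record())
        }
    }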
- rr.fieldConverter = make([]func(array.Builder, string), len(schema.Fields())) - for idx, field := range schema.Fields() { - rr.fieldConverter[idx] = rr.initFieldConverter(&field) - } - return rr } func (r *Reader) readHeader() error { + // if we have an explicit schema and we want to skip the header + // then just return and do everything normally + if r.schema != nil && !r.header { + return nil + } + + // either we need this first line for the header line + // or we are going to need this line to infer types records, err := r.r.Read() if err != nil { return fmt.Errorf("arrow/csv: could not read header from file: %w", err) } - if len(records) != len(r.schema.Fields()) { - return ErrMismatchFields - } + // if we have an explicit schema, then r.header must be true otherwise + // we would have skipped this via the first line of this func + if r.schema != nil { + if len(records) != len(r.schema.Fields()) { + return ErrMismatchFields + } + + fields := make([]arrow.Field, len(records)) + for idx, name := range records { + fields[idx] = r.schema.Field(idx) + fields[idx].Name = name + } - fields := make([]arrow.Field, len(records)) - for idx, name := range records { - fields[idx] = r.schema.Field(idx) - fields[idx].Name = name + meta := r.schema.Metadata() + r.schema = arrow.NewSchema(fields, &meta) + r.bld = array.NewRecordBuilder(r.mem, r.schema) + return nil } - meta := r.schema.Metadata() - r.schema = arrow.NewSchema(fields, &meta) - r.bld = array.NewRecordBuilder(r.mem, r.schema) + // we're going to need to infer some column types + r.conversions = make([]conversionColumn, 0, len(records)) + if len(r.columnFilter) == 0 { + for i, rec := range records { + // if we are skipping the header, autogenerate field names + // using "f" e.g. f0, f1, .... + if !r.header { + rec = fmt.Sprintf("f%d", i) + } + var dt arrow.DataType + if len(r.columnTypes) > 0 { + dt = r.columnTypes[rec] + } + r.conversions = append(r.conversions, conversionColumn{name: rec, index: i, typ: dt}) + } + } else { + // include columns from columnFilter (in that order) + // compute the indices of columns in the csv file + colIndices := make(map[string]int) + for i, n := range records { + // if we are skipping the header, autogenerate field names + // using "f" e.g. f0, f1, .... + if !r.header { + n = fmt.Sprintf("f%d", i) + } + colIndices[n] = i + } + + for _, n := range r.columnFilter { + idx, ok := colIndices[n] + if !ok { + return fmt.Errorf("%w: column '%s' in included columns, but doesn't exist in CSV file", + ErrMismatchFields, n) + } + var dt arrow.DataType + if len(r.columnTypes) > 0 { + dt = r.columnTypes[n] + } + r.conversions = append(r.conversions, conversionColumn{name: n, index: idx, typ: dt}) + } + r.columnFilter = nil + } + r.columnTypes = nil return nil } @@ -143,11 +231,18 @@ func (r *Reader) Record() arrow.Record { return r.cur } // Subsequent calls to Next will return false - The user should check Err() after // each call to Next to check if an error took place. func (r *Reader) Next() bool { - if r.header { - r.once.Do(func() { - r.err = r.readHeader() - }) - } + r.once.Do(func() { + r.err = r.readHeader() + if r.err == nil && r.schema != nil { + // Create a table of functions that will parse columns. This optimization + // allows us to specialize the implementation of each column's decoding + // and hoist type-based branches outside the inner loop. 
+ r.fieldConverter = make([]func(string), len(r.schema.Fields())) + for idx := range r.schema.Fields() { + r.fieldConverter[idx] = r.initFieldConverter(r.bld.Field(idx)) + } + } + }) if r.cur != nil { r.cur.Release() @@ -243,7 +338,29 @@ func (r *Reader) validate(recs []string) { return } - if len(recs) != len(r.schema.Fields()) { + if r.bld == nil { + // initialize the record builder in the case where we're inferring a schema + r.fieldConverter = make([]func(val string), len(recs)) + fieldList := make([]arrow.Field, len(r.conversions)) + for idx, cc := range r.conversions { + fieldList[idx].Name = cc.name + fieldList[idx].Nullable = true + fieldList[idx].Type = cc.inferType(recs[cc.index]) + } + + r.schema = arrow.NewSchema(fieldList, nil) + r.bld = array.NewRecordBuilder(r.mem, r.schema) + for idx, cc := range r.conversions { + r.fieldConverter[cc.index] = r.initFieldConverter(r.bld.Field(idx)) + } + for idx, fc := range r.fieldConverter { + if fc == nil { + r.fieldConverter[idx] = func(string) {} + } + } + } + + if len(recs) != len(r.fieldConverter) { r.err = ErrMismatchFields return } @@ -260,78 +377,78 @@ func (r *Reader) isNull(val string) bool { func (r *Reader) read(recs []string) { for i, str := range recs { - r.fieldConverter[i](r.bld.Field(i), str) + r.fieldConverter[i](str) } } -func (r *Reader) initFieldConverter(field *arrow.Field) func(array.Builder, string) { - switch dt := field.Type.(type) { +func (r *Reader) initFieldConverter(bldr array.Builder) func(string) { + switch dt := bldr.Type().(type) { case *arrow.BooleanType: - return func(field array.Builder, str string) { - r.parseBool(field, str) + return func(str string) { + r.parseBool(bldr, str) } case *arrow.Int8Type: - return func(field array.Builder, str string) { - r.parseInt8(field, str) + return func(str string) { + r.parseInt8(bldr, str) } case *arrow.Int16Type: - return func(field array.Builder, str string) { - r.parseInt16(field, str) + return func(str string) { + r.parseInt16(bldr, str) } case *arrow.Int32Type: - return func(field array.Builder, str string) { - r.parseInt32(field, str) + return func(str string) { + r.parseInt32(bldr, str) } case *arrow.Int64Type: - return func(field array.Builder, str string) { - r.parseInt64(field, str) + return func(str string) { + r.parseInt64(bldr, str) } case *arrow.Uint8Type: - return func(field array.Builder, str string) { - r.parseUint8(field, str) + return func(str string) { + r.parseUint8(bldr, str) } case *arrow.Uint16Type: - return func(field array.Builder, str string) { - r.parseUint16(field, str) + return func(str string) { + r.parseUint16(bldr, str) } case *arrow.Uint32Type: - return func(field array.Builder, str string) { - r.parseUint32(field, str) + return func(str string) { + r.parseUint32(bldr, str) } case *arrow.Uint64Type: - return func(field array.Builder, str string) { - r.parseUint64(field, str) + return func(str string) { + r.parseUint64(bldr, str) } case *arrow.Float32Type: - return func(field array.Builder, str string) { - r.parseFloat32(field, str) + return func(str string) { + r.parseFloat32(bldr, str) } case *arrow.Float64Type: - return func(field array.Builder, str string) { - r.parseFloat64(field, str) + return func(str string) { + r.parseFloat64(bldr, str) } case *arrow.StringType: // specialize the implementation when we know we cannot have nulls if r.stringsCanBeNull { - return func(field array.Builder, str string) { + return func(str string) { if r.isNull(str) { - field.AppendNull() + bldr.AppendNull() } else { - 
field.(*array.StringBuilder).Append(str) + bldr.(*array.StringBuilder).Append(str) } } } else { - return func(field array.Builder, str string) { - field.(*array.StringBuilder).Append(str) + return func(str string) { + bldr.(*array.StringBuilder).Append(str) } } case *arrow.TimestampType: - return func(field array.Builder, str string) { - r.parseTimestamp(field, str, dt.Unit) + return func(str string) { + r.parseTimestamp(bldr, str, dt.Unit) } default: - panic(fmt.Errorf("arrow/csv: unhandled field type %T", field.Type)) + panic(fmt.Errorf("arrow/csv: unhandled field type %T", bldr.Type())) } } @@ -341,14 +458,9 @@ func (r *Reader) parseBool(field array.Builder, str string) { return } - var v bool - switch str { - case "false", "False", "0": - v = false - case "true", "True", "1": - v = true - default: - r.err = fmt.Errorf("unrecognized boolean: %s", str) + v, err := strconv.ParseBool(str) + if err != nil { + r.err = fmt.Errorf("%w: unrecognized boolean: %s", err, str) field.AppendNull() return } @@ -551,6 +663,89 @@ func (r *Reader) Release() { } } +type conversionColumn struct { + name string + index int + typ arrow.DataType +} + +func (c conversionColumn) inferType(v string) arrow.DataType { + if c.typ != nil { + return c.typ + } + + var err error + c.typ = arrow.PrimitiveTypes.Int64 + for { + // attempt to parse + if err = tryParse(v, c.typ); err == nil { + return c.typ + } + + switch dt := c.typ.(type) { + case *arrow.Int64Type: + c.typ = arrow.FixedWidthTypes.Boolean + case *arrow.BooleanType: + c.typ = arrow.FixedWidthTypes.Date32 + case *arrow.Date32Type: + c.typ = arrow.FixedWidthTypes.Time32s + case *arrow.Time32Type: + c.typ = &arrow.TimestampType{Unit: arrow.Second} + case *arrow.TimestampType: + if dt.TimeZone == "" { + if dt.Unit == arrow.Second { + c.typ = &arrow.TimestampType{Unit: arrow.Nanosecond} + } else { + c.typ = &arrow.TimestampType{Unit: arrow.Second, TimeZone: "UTC"} + } + } else { + if dt.Unit == arrow.Second { + c.typ = &arrow.TimestampType{Unit: arrow.Nanosecond, TimeZone: "UTC"} + } else { + c.typ = arrow.PrimitiveTypes.Float64 + } + } + case *arrow.Float64Type: + c.typ = arrow.BinaryTypes.String + case *arrow.StringType: + // binary is the fallback type + return arrow.BinaryTypes.Binary + } + } +} + +func tryParse(val string, dt arrow.DataType) error { + switch dt := dt.(type) { + case *arrow.Int64Type: + _, err := strconv.ParseInt(val, 10, 64) + return err + case *arrow.BooleanType: + _, err := strconv.ParseBool(val) + return err + case *arrow.Date32Type: + _, err := time.Parse("2006-01-02", val) + return err + case *arrow.Time32Type: + _, err := arrow.Time32FromString(val, dt.Unit) + return err + case *arrow.TimestampType: + _, err := arrow.TimestampFromString(val, dt.Unit) + return err + case *arrow.Float64Type: + _, err := strconv.ParseFloat(val, 64) + return err + case *arrow.StringType: + if !utf8.ValidString(val) { + return arrow.ErrInvalid + } + return nil + case *arrow.BinaryType: + _, err := base64.RawStdEncoding.DecodeString(val) + return err + } + panic("shouldn't end up here") +} + var ( _ array.RecordReader = (*Reader)(nil) ) diff --git a/go/arrow/csv/reader_test.go b/go/arrow/csv/reader_test.go index 9b735ba6ddc..4096084b076 100644 --- a/go/arrow/csv/reader_test.go +++ b/go/arrow/csv/reader_test.go @@ -18,14 +18,20 @@ package csv_test import ( "bytes" + stdcsv "encoding/csv" "fmt" "io/ioutil" "log" + "os" + "strings" "testing" "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/array" 
"github.com/apache/arrow/go/v10/arrow/csv" "github.com/apache/arrow/go/v10/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Example() { @@ -257,19 +263,19 @@ func testCSVReader(t *testing.T, filepath string, withHeader bool) { schema := arrow.NewSchema( []arrow.Field{ - arrow.Field{Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, - arrow.Field{Name: "i8", Type: arrow.PrimitiveTypes.Int8}, - arrow.Field{Name: "i16", Type: arrow.PrimitiveTypes.Int16}, - arrow.Field{Name: "i32", Type: arrow.PrimitiveTypes.Int32}, - arrow.Field{Name: "i64", Type: arrow.PrimitiveTypes.Int64}, - arrow.Field{Name: "u8", Type: arrow.PrimitiveTypes.Uint8}, - arrow.Field{Name: "u16", Type: arrow.PrimitiveTypes.Uint16}, - arrow.Field{Name: "u32", Type: arrow.PrimitiveTypes.Uint32}, - arrow.Field{Name: "u64", Type: arrow.PrimitiveTypes.Uint64}, - arrow.Field{Name: "f32", Type: arrow.PrimitiveTypes.Float32}, - arrow.Field{Name: "f64", Type: arrow.PrimitiveTypes.Float64}, - arrow.Field{Name: "str", Type: arrow.BinaryTypes.String}, - arrow.Field{Name: "ts", Type: arrow.FixedWidthTypes.Timestamp_ms}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean}, + {Name: "i8", Type: arrow.PrimitiveTypes.Int8}, + {Name: "i16", Type: arrow.PrimitiveTypes.Int16}, + {Name: "i32", Type: arrow.PrimitiveTypes.Int32}, + {Name: "i64", Type: arrow.PrimitiveTypes.Int64}, + {Name: "u8", Type: arrow.PrimitiveTypes.Uint8}, + {Name: "u16", Type: arrow.PrimitiveTypes.Uint16}, + {Name: "u32", Type: arrow.PrimitiveTypes.Uint32}, + {Name: "u64", Type: arrow.PrimitiveTypes.Uint64}, + {Name: "f32", Type: arrow.PrimitiveTypes.Float32}, + {Name: "f64", Type: arrow.PrimitiveTypes.Float64}, + {Name: "str", Type: arrow.BinaryTypes.String}, + {Name: "ts", Type: arrow.FixedWidthTypes.Timestamp_ms}, }, nil, ) @@ -379,9 +385,9 @@ func TestCSVReaderWithChunk(t *testing.T) { schema := arrow.NewSchema( []arrow.Field{ - arrow.Field{Name: "i64", Type: arrow.PrimitiveTypes.Int64}, - arrow.Field{Name: "f64", Type: arrow.PrimitiveTypes.Float64}, - arrow.Field{Name: "str", Type: arrow.BinaryTypes.String}, + {Name: "i64", Type: arrow.PrimitiveTypes.Int64}, + {Name: "f64", Type: arrow.PrimitiveTypes.Float64}, + {Name: "str", Type: arrow.BinaryTypes.String}, }, nil, ) @@ -651,9 +657,9 @@ func benchRead(b *testing.B, raw []byte, rows, cols, chunks int) { var fields []arrow.Field for i := 0; i < cols; i++ { fields = append(fields, []arrow.Field{ - arrow.Field{Name: fmt.Sprintf("i64-%d", i), Type: arrow.PrimitiveTypes.Int64}, - arrow.Field{Name: fmt.Sprintf("f64-%d", i), Type: arrow.PrimitiveTypes.Float64}, - arrow.Field{Name: fmt.Sprintf("str-%d", i), Type: arrow.BinaryTypes.String}, + {Name: fmt.Sprintf("i64-%d", i), Type: arrow.PrimitiveTypes.Int64}, + {Name: fmt.Sprintf("f64-%d", i), Type: arrow.PrimitiveTypes.Float64}, + {Name: fmt.Sprintf("str-%d", i), Type: arrow.BinaryTypes.String}, }...) 
} @@ -682,3 +688,105 @@ func benchRead(b *testing.B, raw []byte, rows, cols, chunks int) { } } } + +func TestInferringSchema(t *testing.T) { + var b bytes.Buffer + wr := stdcsv.NewWriter(&b) + wr.WriteAll([][]string{ + {"i64", "f64", "str", "ts", "bool"}, + {"123", "1.23", "foobar", "2022-05-09T00:01:01", "false"}, + {"456", "45.6", "baz", "2022-05-09T23:59:59", "true"}, + {"null", "NULL", "null", "N/A", "null"}, + {"-78", "-1.25", "", "2021-01-01T10:11:12", "TRUE"}, + }) + wr.Flush() + + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + r := csv.NewInferringReader(&b, csv.WithAllocator(mem), csv.WithHeader(true), csv.WithNullReader(true, defaultNullValues...)) + defer r.Release() + + assert.Nil(t, r.Schema()) + assert.True(t, r.Next()) + assert.NoError(t, r.Err()) + + expSchema := arrow.NewSchema([]arrow.Field{ + {Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "str", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "ts", Type: &arrow.TimestampType{Unit: arrow.Second}, Nullable: true}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + }, nil) + + exp, _, _ := array.RecordFromJSON(mem, expSchema, strings.NewReader(`[ + {"i64": 123, "f64": 1.23, "str": "foobar", "ts": "2022-05-09T00:01:01", "bool": false}, + {"i64": 456, "f64": 45.6, "str": "baz", "ts": "2022-05-09T23:59:59", "bool": true}, + {"i64": null, "f64": null, "str": null, "ts": null, "bool": null}, + {"i64": -78, "f64": -1.25, "str": null, "ts": "2021-01-01T10:11:12", "bool": true} + ]`)) + defer exp.Release() + + assertRowEqual := func(expected, actual arrow.Record, row int) { + ex := expected.NewSlice(int64(row), int64(row+1)) + defer ex.Release() + assert.Truef(t, array.RecordEqual(ex, actual), "expected: %s\ngot: %s", ex, actual) + } + + assert.True(t, expSchema.Equal(r.Schema()), expSchema.String(), r.Schema().String()) + // verify first row: + assertRowEqual(exp, r.Record(), 0) + assert.True(t, r.Next()) + assertRowEqual(exp, r.Record(), 1) + assert.True(t, r.Next()) + assertRowEqual(exp, r.Record(), 2) + assert.True(t, r.Next()) + assertRowEqual(exp, r.Record(), 3) + assert.False(t, r.Next()) +} + +func TestInferCSVOptions(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + f, err := os.Open("testdata/header.csv") + require.NoError(t, err) + defer f.Close() + + r := csv.NewInferringReader(f, csv.WithAllocator(mem), + csv.WithComma(';'), csv.WithComment('#'), csv.WithHeader(true), + csv.WithNullReader(true, defaultNullValues...), + csv.WithIncludeColumns([]string{"f64", "i32", "bool", "str", "i64", "u64", "i8"}), + csv.WithColumnTypes(map[string]arrow.DataType{ + "i32": arrow.PrimitiveTypes.Int32, + "i8": arrow.PrimitiveTypes.Int8, + "i16": arrow.PrimitiveTypes.Int16, + "u64": arrow.PrimitiveTypes.Uint64, + }), csv.WithChunk(-1)) + defer r.Release() + + assert.True(t, r.Next()) + rec := r.Record() + rec.Retain() + defer rec.Release() + assert.False(t, r.Next()) + + expSchema := arrow.NewSchema([]arrow.Field{ + {Name: "f64", Type: arrow.PrimitiveTypes.Float64, Nullable: true}, + {Name: "i32", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "bool", Type: arrow.FixedWidthTypes.Boolean, Nullable: true}, + {Name: "str", Type: arrow.BinaryTypes.String, Nullable: true}, + {Name: "i64", Type: arrow.PrimitiveTypes.Int64, Nullable: true}, + {Name: "u64", Type: arrow.PrimitiveTypes.Uint64, Nullable: 
true}, + {Name: "i8", Type: arrow.PrimitiveTypes.Int8, Nullable: true}, + }, nil) + expRec, _, _ := array.RecordFromJSON(mem, expSchema, strings.NewReader(`[ + {"f64": 1.1, "i32": -1, "bool": true, "str": "str-1", "i64": -1, "u64": 1, "i8": -1}, + {"f64": 2.2, "i32": -2, "bool": false, "str": "str-2", "i64": -2, "u64": 2, "i8": -2}, + {"f64": null, "i32": null, "bool": null, "str": null, "i64": null, "u64": null, "i8": null} + ]`)) + defer expRec.Release() + + assert.True(t, expSchema.Equal(r.Schema()), expSchema.String(), r.Schema().String()) + assert.Truef(t, array.RecordEqual(expRec, rec), "expected: %s\ngot: %s", expRec, rec) +} From 40ec95646962cccdcd62032c80e8506d4c275bc6 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Tue, 20 Sep 2022 10:29:32 -0400 Subject: [PATCH 108/133] ARROW-17678: [Go] Filter kernels for Record Batches and Tables (#14156) Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/array/compare.go | 2 +- go/arrow/array/util.go | 34 +++ go/arrow/compute/executor.go | 8 +- go/arrow/compute/internal/exec/utils.go | 55 ++++ go/arrow/compute/internal/exec/utils_test.go | 109 ++++++++ go/arrow/compute/selection.go | 182 ++++++++++++- go/arrow/compute/vector_selection_test.go | 273 +++++++++++++++++++ go/arrow/table.go | 10 +- 8 files changed, 665 insertions(+), 8 deletions(-) create mode 100644 go/arrow/compute/internal/exec/utils_test.go diff --git a/go/arrow/array/compare.go b/go/arrow/array/compare.go index 78075cd0f41..828ed477c72 100644 --- a/go/arrow/array/compare.go +++ b/go/arrow/array/compare.go @@ -117,7 +117,7 @@ func ChunkedEqual(left, right *arrow.Chunked) bool { return false } - var isequal bool + var isequal bool = true chunkedBinaryApply(left, right, func(left arrow.Array, lbeg, lend int64, right arrow.Array, rbeg, rend int64) bool { isequal = SliceEqual(left, lbeg, lend, right, rbeg, rend) return isequal diff --git a/go/arrow/array/util.go b/go/arrow/array/util.go index 53634e4ec5c..a40b3e0ba54 100644 --- a/go/arrow/array/util.go +++ b/go/arrow/array/util.go @@ -237,6 +237,19 @@ func RecordToJSON(rec arrow.Record, w io.Writer) error { return nil } +func TableFromJSON(mem memory.Allocator, sc *arrow.Schema, recJSON []string, opt ...FromJSONOption) (arrow.Table, error) { + batches := make([]arrow.Record, len(recJSON)) + for i, batchJSON := range recJSON { + batch, _, err := RecordFromJSON(mem, sc, strings.NewReader(batchJSON), opt...) + if err != nil { + return nil, err + } + defer batch.Release() + batches[i] = batch + } + return NewTableFromRecords(sc, batches), nil +} + func getDictArrayData(mem memory.Allocator, valueType arrow.DataType, memoTable hashing.MemoTable, startOffset int) (*Data, error) { dictLen := memoTable.Size() - startOffset buffers := []*memory.Buffer{nil, nil} @@ -300,6 +313,27 @@ func DictArrayFromJSON(mem memory.Allocator, dt *arrow.DictionaryType, indicesJS return NewDictionaryArray(dt, indices, dict), nil } +func ChunkedFromJSON(mem memory.Allocator, dt arrow.DataType, chunkStrs []string, opts ...FromJSONOption) (*arrow.Chunked, error) { + chunks := make([]arrow.Array, len(chunkStrs)) + defer func() { + for _, c := range chunks { + if c != nil { + c.Release() + } + } + }() + + var err error + for i, c := range chunkStrs { + chunks[i], _, err = FromJSON(mem, dt, strings.NewReader(c), opts...) 
+ if err != nil { + return nil, err + } + } + + return arrow.NewChunked(dt, chunks), nil +} + func getMaxBufferLen(dt arrow.DataType, length int) int { bufferLen := int(bitutil.BytesForBits(int64(length))) diff --git a/go/arrow/compute/executor.go b/go/arrow/compute/executor.go index c20bd9e4684..11340b295a7 100644 --- a/go/arrow/compute/executor.go +++ b/go/arrow/compute/executor.go @@ -970,7 +970,13 @@ func (v *vectorExecutor) WrapResults(ctx context.Context, out <-chan Datum, hasC ) toChunked := func() { - acc = output.(ArrayLikeDatum).Chunks() + out := output.(ArrayLikeDatum).Chunks() + acc = make([]arrow.Array, 0, len(out)) + for _, o := range out { + if o.Len() > 0 { + acc = append(acc, o) + } + } if output.Kind() != KindChunked { output.Release() } diff --git a/go/arrow/compute/internal/exec/utils.go b/go/arrow/compute/internal/exec/utils.go index 9192b85d4a6..0d6a21ef024 100644 --- a/go/arrow/compute/internal/exec/utils.go +++ b/go/arrow/compute/internal/exec/utils.go @@ -18,6 +18,7 @@ package exec import ( "fmt" + "math" "reflect" "unsafe" @@ -174,3 +175,57 @@ func ArrayFromSlice[T NumericTypes](mem memory.Allocator, data []T) arrow.Array bldr.AppendValues(data, nil) return bldr.NewArray() } + +func RechunkArraysConsistently(groups [][]arrow.Array) [][]arrow.Array { + if len(groups) <= 1 { + return groups + } + + var totalLen int + for _, a := range groups[0] { + totalLen += a.Len() + } + + if totalLen == 0 { + return groups + } + + rechunked := make([][]arrow.Array, len(groups)) + offsets := make([]int, len(groups)) + // scan all array vectors at once, rechunking along the way + var start int64 + for start < int64(totalLen) { + // first compute max possible length for next chunk + chunkLength := math.MaxInt64 + for i, g := range groups { + offset := offsets[i] + // skip any done arrays including 0-length + for offset == g[0].Len() { + g = g[1:] + offset = 0 + } + arr := g[0] + chunkLength = Min(chunkLength, arr.Len()-offset) + + offsets[i] = offset + groups[i] = g + } + + // now slice all the arrays along this chunk size + for i, g := range groups { + offset := offsets[i] + arr := g[0] + if offset == 0 && arr.Len() == chunkLength { + // slice spans entire array + arr.Retain() + rechunked[i] = append(rechunked[i], arr) + } else { + rechunked[i] = append(rechunked[i], array.NewSlice(arr, int64(offset), int64(offset+chunkLength))) + } + offsets[i] += chunkLength + } + + start += int64(chunkLength) + } + return rechunked +} diff --git a/go/arrow/compute/internal/exec/utils_test.go b/go/arrow/compute/internal/exec/utils_test.go new file mode 100644 index 00000000000..1917429f7d1 --- /dev/null +++ b/go/arrow/compute/internal/exec/utils_test.go @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package exec_test + +import ( + "testing" + + "github.com/apache/arrow/go/v10/arrow" + "github.com/apache/arrow/go/v10/arrow/array" + "github.com/apache/arrow/go/v10/arrow/compute/internal/exec" + "github.com/apache/arrow/go/v10/arrow/memory" + "github.com/stretchr/testify/assert" +) + +func TestRechunkConsistentArraysTrivial(t *testing.T) { + var groups [][]arrow.Array + rechunked := exec.RechunkArraysConsistently(groups) + assert.Zero(t, rechunked) + + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + a1 := exec.ArrayFromSlice(mem, []int16{}) + defer a1.Release() + a2 := exec.ArrayFromSlice(mem, []int16{}) + defer a2.Release() + b1 := exec.ArrayFromSlice(mem, []int32{}) + defer b1.Release() + groups = [][]arrow.Array{{a1, a2}, {}, {b1}} + rechunked = exec.RechunkArraysConsistently(groups) + assert.Len(t, rechunked, 3) + + for _, arrvec := range rechunked { + for _, arr := range arrvec { + assert.Zero(t, arr.Len()) + } + } +} + +func assertEqual[T exec.NumericTypes](t *testing.T, mem memory.Allocator, arr arrow.Array, data []T) { + exp := exec.ArrayFromSlice(mem, data) + defer exp.Release() + assert.Truef(t, array.Equal(exp, arr), "expected: %s\ngot: %s", exp, arr) +} + +func TestRechunkArraysConsistentlyPlain(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer mem.AssertSize(t, 0) + + a1 := exec.ArrayFromSlice(mem, []int16{1, 2, 3}) + defer a1.Release() + a2 := exec.ArrayFromSlice(mem, []int16{4, 5}) + defer a2.Release() + a3 := exec.ArrayFromSlice(mem, []int16{6, 7, 8, 9}) + defer a3.Release() + + b1 := exec.ArrayFromSlice(mem, []int32{41, 42}) + defer b1.Release() + b2 := exec.ArrayFromSlice(mem, []int32{43, 44, 45}) + defer b2.Release() + b3 := exec.ArrayFromSlice(mem, []int32{46, 47}) + defer b3.Release() + b4 := exec.ArrayFromSlice(mem, []int32{48, 49}) + defer b4.Release() + + groups := [][]arrow.Array{{a1, a2, a3}, {b1, b2, b3, b4}} + rechunked := exec.RechunkArraysConsistently(groups) + assert.Len(t, rechunked, 2) + ra := rechunked[0] + rb := rechunked[1] + + assert.Len(t, ra, 5) + assertEqual(t, mem, ra[0], []int16{1, 2}) + ra[0].Release() + assertEqual(t, mem, ra[1], []int16{3}) + ra[1].Release() + assertEqual(t, mem, ra[2], []int16{4, 5}) + ra[2].Release() + assertEqual(t, mem, ra[3], []int16{6, 7}) + ra[3].Release() + assertEqual(t, mem, ra[4], []int16{8, 9}) + ra[4].Release() + + assert.Len(t, rb, 5) + assertEqual(t, mem, rb[0], []int32{41, 42}) + rb[0].Release() + assertEqual(t, mem, rb[1], []int32{43}) + rb[1].Release() + assertEqual(t, mem, rb[2], []int32{44, 45}) + rb[2].Release() + assertEqual(t, mem, rb[3], []int32{46, 47}) + rb[3].Release() + assertEqual(t, mem, rb[4], []int32{48, 49}) + rb[4].Release() +} diff --git a/go/arrow/compute/selection.go b/go/arrow/compute/selection.go index ac5b4e4c653..fd7d941184e 100644 --- a/go/arrow/compute/selection.go +++ b/go/arrow/compute/selection.go @@ -32,15 +32,39 @@ var ( filterMetaFunc = NewMetaFunction("filter", Binary(), filterDoc, func(ctx context.Context, opts FunctionOptions, args ...Datum) (Datum, error) { if args[1].(ArrayLikeDatum).Type().ID() != arrow.BOOL { - return nil, fmt.Errorf("%w: fitler argument must be boolean type", + return nil, fmt.Errorf("%w: filter argument must be boolean type", arrow.ErrNotImplemented) } switch args[0].Kind() { case KindRecord: - return nil, fmt.Errorf("%w: record batch filtering", arrow.ErrNotImplemented) + filtOpts, ok := opts.(*FilterOptions) + if !ok { + return nil, fmt.Errorf("%w: invalid options type", 
arrow.ErrInvalid) + } + + if filter, ok := args[1].(*ArrayDatum); ok { + filterArr := filter.MakeArray() + defer filterArr.Release() + rec, err := FilterRecordBatch(ctx, args[0].(*RecordDatum).Value, filterArr, filtOpts) + if err != nil { + return nil, err + } + return &RecordDatum{Value: rec}, nil + } + return nil, fmt.Errorf("%w: record batch filtering only implemented for Array filter", arrow.ErrNotImplemented) case KindTable: - return nil, fmt.Errorf("%w: table filtering", arrow.ErrNotImplemented) + filtOpts, ok := opts.(*FilterOptions) + if !ok { + return nil, fmt.Errorf("%w: invalid options type", arrow.ErrInvalid) + } + + tbl, err := FilterTable(ctx, args[0].(*TableDatum).Value, args[1], filtOpts) + if err != nil { + return nil, err + } + return &TableDatum{Value: tbl}, nil + default: return CallFunction(ctx, "array_filter", opts, args...) } @@ -343,3 +367,155 @@ func FilterArray(ctx context.Context, values, filter arrow.Array, options Filter defer outDatum.Release() return outDatum.(*ArrayDatum).MakeArray(), nil } + +func FilterRecordBatch(ctx context.Context, batch arrow.Record, filter arrow.Array, opts *FilterOptions) (arrow.Record, error) { + if batch.NumRows() != int64(filter.Len()) { + return nil, fmt.Errorf("%w: filter inputs must all be the same length", arrow.ErrInvalid) + } + + var filterSpan exec.ArraySpan + filterSpan.SetMembers(filter.Data()) + + indices, err := kernels.GetTakeIndices(exec.GetAllocator(ctx), &filterSpan, opts.NullSelection) + if err != nil { + return nil, err + } + defer indices.Release() + + indicesArr := array.MakeFromData(indices) + defer indicesArr.Release() + + cols := make([]arrow.Array, batch.NumCols()) + defer func() { + for _, c := range cols { + if c != nil { + c.Release() + } + } + }() + eg, cctx := errgroup.WithContext(ctx) + eg.SetLimit(GetExecCtx(ctx).NumParallel) + for i, col := range batch.Columns() { + i, col := i, col + eg.Go(func() error { + out, err := TakeArrayOpts(cctx, col, indicesArr, kernels.TakeOptions{BoundsCheck: false}) + if err != nil { + return err + } + cols[i] = out + return nil + }) + } + + if err := eg.Wait(); err != nil { + return nil, err + } + + return array.NewRecord(batch.Schema(), cols, int64(indicesArr.Len())), nil +} + +func FilterTable(ctx context.Context, tbl arrow.Table, filter Datum, opts *FilterOptions) (arrow.Table, error) { + if tbl.NumRows() != filter.Len() { + return nil, fmt.Errorf("%w: filter inputs must all be the same length", arrow.ErrInvalid) + } + + if tbl.NumRows() == 0 { + cols := make([]arrow.Column, tbl.NumCols()) + for i := 0; i < int(tbl.NumCols()); i++ { + cols[i] = *tbl.Column(i) + } + return array.NewTable(tbl.Schema(), cols, 0), nil + } + + // last input element will be the filter array + nCols := tbl.NumCols() + inputs := make([][]arrow.Array, nCols+1) + for i := int64(0); i < nCols; i++ { + inputs[i] = tbl.Column(int(i)).Data().Chunks() + } + + switch ft := filter.(type) { + case *ArrayDatum: + inputs[nCols] = ft.Chunks() + defer inputs[nCols][0].Release() + case *ChunkedDatum: + inputs[nCols] = ft.Chunks() + default: + return nil, fmt.Errorf("%w: filter should be array-like", arrow.ErrNotImplemented) + } + + // rechunk inputs to allow consistent iteration over the respective chunks + inputs = exec.RechunkArraysConsistently(inputs) + + // instead of filtering each column with the boolean filter + // (which would be slow if the table has a large number of columns) + // convert each filter chunk to indices and take() the column + mem := GetAllocator(ctx) + outCols := 
make([][]arrow.Array, nCols) + // pre-size the output + nChunks := len(inputs[nCols]) + for i := range outCols { + outCols[i] = make([]arrow.Array, nChunks) + } + var outNumRows int64 + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + + eg, cctx := errgroup.WithContext(ctx) + eg.SetLimit(GetExecCtx(cctx).NumParallel) + + var filterSpan exec.ArraySpan + for i, filterChunk := range inputs[nCols] { + filterSpan.SetMembers(filterChunk.Data()) + indices, err := kernels.GetTakeIndices(mem, &filterSpan, opts.NullSelection) + if err != nil { + return nil, err + } + defer indices.Release() + filterChunk.Release() + if indices.Len() == 0 { + for col := int64(0); col < nCols; col++ { + inputs[col][i].Release() + } + continue + } + + // take from all input columns + outNumRows += int64(indices.Len()) + indicesDatum := NewDatum(indices) + defer indicesDatum.Release() + + for col := int64(0); col < nCols; col++ { + columnChunk := inputs[col][i] + defer columnChunk.Release() + i := i + col := col + eg.Go(func() error { + columnDatum := NewDatum(columnChunk) + defer columnDatum.Release() + out, err := Take(cctx, kernels.TakeOptions{BoundsCheck: false}, columnDatum, indicesDatum) + if err != nil { + return err + } + defer out.Release() + outCols[col][i] = out.(*ArrayDatum).MakeArray() + return nil + }) + } + } + + if err := eg.Wait(); err != nil { + return nil, err + } + + outChunks := make([]arrow.Column, nCols) + for i, chunks := range outCols { + chk := arrow.NewChunked(tbl.Column(i).DataType(), chunks) + outChunks[i] = *arrow.NewColumn(tbl.Schema().Field(i), chk) + defer outChunks[i].Release() + chk.Release() + } + + return array.NewTable(tbl.Schema(), outChunks, outNumRows), nil +} diff --git a/go/arrow/compute/vector_selection_test.go b/go/arrow/compute/vector_selection_test.go index 6b01a11b74d..31ca6b6e640 100644 --- a/go/arrow/compute/vector_selection_test.go +++ b/go/arrow/compute/vector_selection_test.go @@ -814,6 +814,276 @@ func (f *FilterKernelWithStruct) TestStruct() { f.assertFilterJSON(dt, structJSON, `[true, false, true, false]`, `[null, {"a": 2, "b": "hello"}]`) } +type FilterKernelWithRecordBatch struct { + FilterKernelTestSuite +} + +func (f *FilterKernelWithRecordBatch) doFilter(sc *arrow.Schema, batchJSON, selection string, opts compute.FilterOptions) (arrow.Record, error) { + rec, _, err := array.RecordFromJSON(f.mem, sc, strings.NewReader(batchJSON), array.WithUseNumber()) + if err != nil { + return nil, err + } + defer rec.Release() + + batch := compute.NewDatum(rec) + defer batch.Release() + + filter, _, _ := array.FromJSON(f.mem, arrow.FixedWidthTypes.Boolean, strings.NewReader(selection)) + defer filter.Release() + filterDatum := compute.NewDatum(filter) + defer filterDatum.Release() + + outDatum, err := compute.Filter(context.TODO(), batch, filterDatum, opts) + if err != nil { + return nil, err + } + + return outDatum.(*compute.RecordDatum).Value, nil +} + +func (f *FilterKernelWithRecordBatch) assertFilter(sc *arrow.Schema, batchJSON, selection string, opts compute.FilterOptions, expectedBatch string) { + actual, err := f.doFilter(sc, batchJSON, selection, opts) + f.Require().NoError(err) + defer actual.Release() + + expected, _, err := array.RecordFromJSON(f.mem, sc, strings.NewReader(expectedBatch), array.WithUseNumber()) + f.Require().NoError(err) + defer expected.Release() + + f.Truef(array.RecordEqual(expected, actual), "expected: %s\ngot: %s", expected, actual) +} + +func (f *FilterKernelWithRecordBatch) TestFilterRecord() { + 
fields := []arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "b", Type: arrow.BinaryTypes.String, Nullable: true}, + } + sc := arrow.NewSchema(fields, nil) + + batchJSON := `[ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ]` + + for _, opts := range []compute.FilterOptions{f.emitNulls, f.dropOpts} { + f.assertFilter(sc, batchJSON, `[false, false, false, false]`, opts, `[]`) + f.assertFilter(sc, batchJSON, `[true, true, true, true]`, opts, batchJSON) + f.assertFilter(sc, batchJSON, `[true, false, true, false]`, opts, `[ + {"a": null, "b": "yo"}, + {"a": 2, "b": "hello"} + ]`) + } + + f.assertFilter(sc, batchJSON, `[false, true, true, null]`, f.dropOpts, `[ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"} + ]`) + + f.assertFilter(sc, batchJSON, `[false, true, true, null]`, f.emitNulls, `[ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": null, "b": null} + ]`) +} + +type FilterKernelWithChunked struct { + FilterKernelTestSuite +} + +func (f *FilterKernelWithChunked) filterWithArray(dt arrow.DataType, values []string, filterStr string) (*arrow.Chunked, error) { + chk, err := array.ChunkedFromJSON(f.mem, dt, values) + f.Require().NoError(err) + defer chk.Release() + + input := compute.NewDatum(chk) + defer input.Release() + + filter, _, _ := array.FromJSON(f.mem, arrow.FixedWidthTypes.Boolean, strings.NewReader(filterStr)) + defer filter.Release() + + filterDatum := compute.NewDatum(filter) + defer filterDatum.Release() + + out, err := compute.Filter(context.TODO(), input, filterDatum, *compute.DefaultFilterOptions()) + if err != nil { + return nil, err + } + return out.(*compute.ChunkedDatum).Value, nil +} + +func (f *FilterKernelWithChunked) filterWithChunked(dt arrow.DataType, values, filter []string) (*arrow.Chunked, error) { + chk, err := array.ChunkedFromJSON(f.mem, dt, values) + f.Require().NoError(err) + defer chk.Release() + + input := compute.NewDatum(chk) + defer input.Release() + + filtChk, err := array.ChunkedFromJSON(f.mem, arrow.FixedWidthTypes.Boolean, filter) + f.Require().NoError(err) + defer filtChk.Release() + + filtDatum := compute.NewDatum(filtChk) + defer filtDatum.Release() + + out, err := compute.Filter(context.TODO(), input, filtDatum, *compute.DefaultFilterOptions()) + if err != nil { + return nil, err + } + return out.(*compute.ChunkedDatum).Value, nil +} + +func (f *FilterKernelWithChunked) assertFilter(dt arrow.DataType, values []string, filter string, expected []string) { + actual, err := f.filterWithArray(dt, values, filter) + f.Require().NoError(err) + defer actual.Release() + + expectedResult, _ := array.ChunkedFromJSON(f.mem, dt, expected) + defer expectedResult.Release() + if !f.True(array.ChunkedEqual(expectedResult, actual)) { + var s strings.Builder + s.WriteString("expected: \n") + for _, c := range expectedResult.Chunks() { + fmt.Fprintf(&s, "%s\n", c) + } + s.WriteString("actual: \n") + for _, c := range actual.Chunks() { + fmt.Fprintf(&s, "%s\n", c) + } + f.T().Log(s.String()) + } +} + +func (f *FilterKernelWithChunked) assertChunkedFilter(dt arrow.DataType, values, filter, expected []string) { + actual, err := f.filterWithChunked(dt, values, filter) + f.Require().NoError(err) + defer actual.Release() + + expectedResult, _ := array.ChunkedFromJSON(f.mem, dt, expected) + defer expectedResult.Release() + if !f.True(array.ChunkedEqual(expectedResult, actual)) { + var s strings.Builder + s.WriteString("expected: \n") + for _, c := range expectedResult.Chunks() { + 
fmt.Fprintf(&s, "%s\n", c) + } + s.WriteString("actual: \n") + for _, c := range actual.Chunks() { + fmt.Fprintf(&s, "%s\n", c) + } + f.T().Log(s.String()) + } +} + +func (f *FilterKernelWithChunked) TestFilterChunked() { + f.assertFilter(arrow.PrimitiveTypes.Int8, []string{`[]`}, `[]`, []string{}) + f.assertChunkedFilter(arrow.PrimitiveTypes.Int8, []string{`[]`}, []string{`[]`}, []string{}) + + f.assertFilter(arrow.PrimitiveTypes.Int8, []string{`[7]`, `[8, 9]`}, `[false, true, false]`, []string{`[8]`}) + f.assertChunkedFilter(arrow.PrimitiveTypes.Int8, []string{`[7]`, `[8, 9]`}, []string{`[false]`, `[true, false]`}, []string{`[8]`}) + f.assertChunkedFilter(arrow.PrimitiveTypes.Int8, []string{`[7]`, `[8, 9]`}, []string{`[false, true]`, `[false]`}, []string{`[8]`}) + + _, err := f.filterWithArray(arrow.PrimitiveTypes.Int8, []string{`[7]`, `[8, 9]`}, `[false, true, false, true, true]`) + f.ErrorIs(err, arrow.ErrInvalid) + _, err = f.filterWithChunked(arrow.PrimitiveTypes.Int8, []string{`[7]`, `[8, 9]`}, []string{`[ false, true, false]`, `[true, true]`}) + f.ErrorIs(err, arrow.ErrInvalid) +} + +type FilterKernelWithTable struct { + FilterKernelTestSuite +} + +func (f *FilterKernelWithTable) filterWithArray(sc *arrow.Schema, values []string, filter string, opts compute.FilterOptions) (arrow.Table, error) { + tbl, err := array.TableFromJSON(f.mem, sc, values) + if err != nil { + return nil, err + } + defer tbl.Release() + + filterArr, _, _ := array.FromJSON(f.mem, arrow.FixedWidthTypes.Boolean, strings.NewReader(filter)) + defer filterArr.Release() + + out, err := compute.Filter(context.TODO(), &compute.TableDatum{Value: tbl}, &compute.ArrayDatum{Value: filterArr.Data()}, opts) + if err != nil { + return nil, err + } + return out.(*compute.TableDatum).Value, nil +} + +func (f *FilterKernelWithTable) filterWithChunked(sc *arrow.Schema, values, filter []string, opts compute.FilterOptions) (arrow.Table, error) { + tbl, err := array.TableFromJSON(f.mem, sc, values) + if err != nil { + return nil, err + } + defer tbl.Release() + + filtChk, err := array.ChunkedFromJSON(f.mem, arrow.FixedWidthTypes.Boolean, filter) + f.Require().NoError(err) + defer filtChk.Release() + + out, err := compute.Filter(context.TODO(), &compute.TableDatum{Value: tbl}, &compute.ChunkedDatum{Value: filtChk}, opts) + if err != nil { + return nil, err + } + return out.(*compute.TableDatum).Value, nil +} + +func (f *FilterKernelWithTable) assertChunkedFilter(sc *arrow.Schema, tableJSON, filter []string, opts compute.FilterOptions, expTable []string) { + actual, err := f.filterWithChunked(sc, tableJSON, filter, opts) + f.Require().NoError(err) + defer actual.Release() + + expected, err := array.TableFromJSON(f.mem, sc, expTable) + f.Require().NoError(err) + defer expected.Release() + + f.Truef(array.TableEqual(expected, actual), "expected: %s\ngot: %s", expected, actual) +} + +func (f *FilterKernelWithTable) assertFilter(sc *arrow.Schema, tableJSON []string, filter string, opts compute.FilterOptions, expectedTable []string) { + actual, err := f.filterWithArray(sc, tableJSON, filter, opts) + f.Require().NoError(err) + defer actual.Release() + + expected, err := array.TableFromJSON(f.mem, sc, expectedTable) + f.Require().NoError(err) + defer expected.Release() + + f.Truef(array.TableEqual(expected, actual), "expected: %s\ngot: %s", expected, actual) +} + +func (f *FilterKernelWithTable) TestFilterTable() { + fields := []arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Int32, Nullable: true}, + {Name: "b", Type: 
arrow.BinaryTypes.String, Nullable: true},
+	}
+	sc := arrow.NewSchema(fields, nil)
+	tableJSON := []string{`[
+		{"a": null, "b": "yo"},
+		{"a": 1, "b": ""}
+	]`, `[
+		{"a": 2, "b": "hello"},
+		{"a": 4, "b": "eh"}
+	]`}
+
+	for _, opt := range []compute.FilterOptions{f.emitNulls, f.dropOpts} {
+		f.assertFilter(sc, tableJSON, `[false, false, false, false]`, opt, []string{})
+		f.assertChunkedFilter(sc, tableJSON, []string{`[false]`, `[false, false, false]`}, opt, []string{})
+		f.assertFilter(sc, tableJSON, `[true, true, true, true]`, opt, tableJSON)
+		f.assertChunkedFilter(sc, tableJSON, []string{`[true]`, `[true, true, true]`}, opt, tableJSON)
+	}
+
+	expectedEmitNull := []string{`[{"a": 1, "b": ""}]`, `[{"a": 2, "b": "hello"},{"a": null, "b": null}]`}
+	f.assertFilter(sc, tableJSON, `[false, true, true, null]`, f.emitNulls, expectedEmitNull)
+	f.assertChunkedFilter(sc, tableJSON, []string{`[false, true, true]`, `[null]`}, f.emitNulls, expectedEmitNull)
+
+	expectedDrop := []string{`[{"a": 1, "b": ""}]`, `[{"a": 2, "b": "hello"}]`}
+	f.assertFilter(sc, tableJSON, `[false, true, true, null]`, f.dropOpts, expectedDrop)
+	f.assertChunkedFilter(sc, tableJSON, []string{`[false, true, true]`, `[null]`}, f.dropOpts, expectedDrop)
+}
+
 type TakeKernelTestTyped struct {
 	TakeKernelTestSuite
 
@@ -1101,4 +1371,7 @@ func TestFilterKernels(t *testing.T) {
 	suite.Run(t, new(FilterKernelWithUnion))
 	suite.Run(t, new(FilterKernelExtension))
 	suite.Run(t, new(FilterKernelWithStruct))
+	suite.Run(t, new(FilterKernelWithRecordBatch))
+	suite.Run(t, new(FilterKernelWithChunked))
+	suite.Run(t, new(FilterKernelWithTable))
 }
diff --git a/go/arrow/table.go b/go/arrow/table.go
index c4a6351cce2..e0c2caf515c 100644
--- a/go/arrow/table.go
+++ b/go/arrow/table.go
@@ -140,16 +140,20 @@ type Chunked struct {
 // NewChunked panics if the chunks do not have the same data type.
 func NewChunked(dtype DataType, chunks []Array) *Chunked {
 	arr := &Chunked{
-		chunks:   make([]Array, len(chunks)),
+		chunks:   make([]Array, 0, len(chunks)),
 		refCount: 1,
 		dtype:    dtype,
 	}
-	for i, chunk := range chunks {
+	for _, chunk := range chunks {
+		if chunk == nil {
+			continue
+		}
+
 		if !TypeEqual(chunk.DataType(), dtype) {
 			panic("arrow/array: mismatch data type")
 		}
 		chunk.Retain()
-		arr.chunks[i] = chunk
+		arr.chunks = append(arr.chunks, chunk)
 		arr.length += chunk.Len()
 		arr.nulls += chunk.NullN()
 	}

From 8ecb73015560498fc28b9fe498b3568296e3f4ab Mon Sep 17 00:00:00 2001
From: Jacob Wujciak-Jens
Date: Tue, 20 Sep 2022 23:10:04 +0200
Subject: [PATCH 109/133] ARROW-17782: [C++][R] R package not building on macos 10.13 with C++17 std lib (#14178)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The r-binary-package build passes ✔️ [here](https://github.com/ursacomputing/crossbow/actions/runs/3091654213); the failure in "test source" is unrelated.
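For context: libc++ annotates the throwing `std::variant` accessors of C++17 with availability attributes, because `std::bad_variant_access` only ships in the libc++.dylib of macOS 10.14 and later. A minimal sketch of the failure mode, using a generic `std::variant` rather than Arrow's own code, which reproduces the "call to unavailable function 'get'" error shown below when the deployment target is macOS 10.13:

```cpp
#include <variant>

std::variant<int, double> value = 1.5;

double AsDouble() {
  // std::get may throw std::bad_variant_access, whose definition lives in the
  // libc++.dylib introduced with macOS 10.14, so Clang rejects this call when
  // targeting macOS 10.13.
  return std::get<double>(value);
}
```

Defining `_LIBCPP_DISABLE_AVAILABILITY` turns these availability annotations off, which is what this patch adds to the autobrew `PKG_CFLAGS`.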
An error message without `-D_LIBCPP_DISABLE_AVAILABILITY`:
https://github.com/ursacomputing/crossbow/actions/runs/3081084528/jobs/4979666988#step:13:10847

    /Users/voltrondata/tmp/build-apache-arrow/opt/apache-arrow/include/arrow/compute/exec.h:340:12: error: call to unavailable function 'get': introduced in macOS 10.14
        return std::get<std::shared_ptr<ArrayData>>(this->value);
               ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Authored-by: Jacob Wujciak-Jens
Signed-off-by: Sutou Kouhei
---
 r/tools/autobrew | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/r/tools/autobrew b/r/tools/autobrew
index d895b57eaf9..73e6e11a161 100644
--- a/r/tools/autobrew
+++ b/r/tools/autobrew
@@ -74,7 +74,7 @@ for FILE in $BREWDIR/Cellar/*/*/lib/*.a; do
   PKG_LIBS=`echo $PKG_LIBS | sed "s/-l$LIBNAME/-lbrew$LIBNAME/g"`
 done
 
-PKG_CFLAGS="-I$BREWDIR/opt/$PKG_BREW_NAME/include -DARROW_R_WITH_PARQUET -DARROW_R_WITH_DATASET -DARROW_R_WITH_JSON -DARROW_R_WITH_S3 -DARROW_R_WITH_GCS"
+PKG_CFLAGS="-I$BREWDIR/opt/$PKG_BREW_NAME/include -DARROW_R_WITH_PARQUET -DARROW_R_WITH_DATASET -DARROW_R_WITH_JSON -DARROW_R_WITH_S3 -DARROW_R_WITH_GCS -D_LIBCPP_DISABLE_AVAILABILITY"
 
 unset HOMEBREW_NO_ANALYTICS
 unset HOMEBREW_NO_AUTO_UPDATE

From 8ad5e59803dda03ff5b829ae1635bbe301a1a4e4 Mon Sep 17 00:00:00 2001
From: Kae S <31679808+ksuarez1423@users.noreply.github.com>
Date: Tue, 20 Sep 2022 18:27:45 -0400
Subject: [PATCH 110/133] ARROW-17377: [C++][Docs] Adds tutorial for basic Arrow, file access, compute, and datasets (#13859)

I intend for this PR to add a few small tutorial articles to the Arrow documentation, for basic Arrow usage, file access, compute, and dataset functionality.

Right now, this is a draft PR, with just the code for the examples. Before I set it up with comments and prose in Sphinx, I wanted to get it reviewed. Do these examples seem suitable for the tutorials they target?
Authored-by: kaesuarez Signed-off-by: David Li --- cpp/examples/tutorial_examples/CMakeLists.txt | 48 ++ .../tutorial_examples/arrow_example.cc | 163 +++++++ cpp/examples/tutorial_examples/build_arrow.sh | 38 ++ .../tutorial_examples/build_example.sh | 27 ++ .../tutorial_examples/compute_example.cc | 138 ++++++ .../tutorial_examples/dataset_example.cc | 244 ++++++++++ .../tutorial_examples/docker-compose.yml | 29 ++ .../tutorial_examples/file_access_example.cc | 216 +++++++++ cpp/examples/tutorial_examples/run.sh | 51 ++ .../tutorial_examples/tutorial.dockerfile | 27 ++ dev/tasks/tasks.yml | 7 + docs/source/cpp/getting_started.rst | 45 +- docs/source/cpp/index.rst | 1 + docs/source/cpp/tutorials/basic_arrow.rst | 284 +++++++++++ .../source/cpp/tutorials/compute_tutorial.rst | 343 +++++++++++++ .../cpp/tutorials/datasets_tutorial.rst | 453 ++++++++++++++++++ docs/source/cpp/tutorials/io_tutorial.rst | 404 ++++++++++++++++ docs/source/cpp/user_guide.rst | 43 ++ 18 files changed, 2540 insertions(+), 21 deletions(-) create mode 100644 cpp/examples/tutorial_examples/CMakeLists.txt create mode 100644 cpp/examples/tutorial_examples/arrow_example.cc create mode 100755 cpp/examples/tutorial_examples/build_arrow.sh create mode 100755 cpp/examples/tutorial_examples/build_example.sh create mode 100644 cpp/examples/tutorial_examples/compute_example.cc create mode 100644 cpp/examples/tutorial_examples/dataset_example.cc create mode 100644 cpp/examples/tutorial_examples/docker-compose.yml create mode 100644 cpp/examples/tutorial_examples/file_access_example.cc create mode 100755 cpp/examples/tutorial_examples/run.sh create mode 100644 cpp/examples/tutorial_examples/tutorial.dockerfile create mode 100644 docs/source/cpp/tutorials/basic_arrow.rst create mode 100644 docs/source/cpp/tutorials/compute_tutorial.rst create mode 100644 docs/source/cpp/tutorials/datasets_tutorial.rst create mode 100644 docs/source/cpp/tutorials/io_tutorial.rst create mode 100644 docs/source/cpp/user_guide.rst diff --git a/cpp/examples/tutorial_examples/CMakeLists.txt b/cpp/examples/tutorial_examples/CMakeLists.txt new file mode 100644 index 00000000000..ed399edbd60 --- /dev/null +++ b/cpp/examples/tutorial_examples/CMakeLists.txt @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +cmake_minimum_required(VERSION 3.0) + +project(ArrowTutorialExamples) + +find_package(Arrow REQUIRED) + +get_filename_component(ARROW_CONFIG_PATH ${Arrow_CONFIG} DIRECTORY) +find_package(Parquet REQUIRED HINTS ${ARROW_CONFIG_PATH}) +find_package(ArrowDataset REQUIRED HINTS ${ARROW_CONFIG_PATH}) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Werror -Wall -Wextra") + +set(CMAKE_BUILD_TYPE Release) + +message(STATUS "Arrow version: ${ARROW_VERSION}") +message(STATUS "Arrow SO version: ${ARROW_FULL_SO_VERSION}") + +add_executable(arrow_example arrow_example.cc) +target_link_libraries(arrow_example PRIVATE Arrow::arrow_shared) + +add_executable(file_access_example file_access_example.cc) +target_link_libraries(file_access_example PRIVATE Arrow::arrow_shared + Parquet::parquet_shared) + +add_executable(compute_example compute_example.cc) +target_link_libraries(compute_example PRIVATE Arrow::arrow_shared) + +add_executable(dataset_example dataset_example.cc) +target_link_libraries(dataset_example PRIVATE Arrow::arrow_shared Parquet::parquet_shared + ArrowDataset::arrow_dataset_shared) diff --git a/cpp/examples/tutorial_examples/arrow_example.cc b/cpp/examples/tutorial_examples/arrow_example.cc new file mode 100644 index 00000000000..50b8c3033f3 --- /dev/null +++ b/cpp/examples/tutorial_examples/arrow_example.cc @@ -0,0 +1,163 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (Doc section: Basic Example) + +// (Doc section: Includes) +#include + +#include +// (Doc section: Includes) + +// (Doc section: RunMain Start) +arrow::Status RunMain() { + // (Doc section: RunMain Start) + // (Doc section: int8builder 1 Append) + // Builders are the main way to create Arrays in Arrow from existing values that are not + // on-disk. In this case, we'll make a simple array, and feed that in. + // Data types are important as ever, and there is a Builder for each compatible type; + // in this case, int8. + arrow::Int8Builder int8builder; + int8_t days_raw[5] = {1, 12, 17, 23, 28}; + // AppendValues, as called, puts 5 values from days_raw into our Builder object. + ARROW_RETURN_NOT_OK(int8builder.AppendValues(days_raw, 5)); + // (Doc section: int8builder 1 Append) + + // (Doc section: int8builder 1 Finish) + // We only have a Builder though, not an Array -- the following code pushes out the + // built up data into a proper Array. + std::shared_ptr days; + ARROW_ASSIGN_OR_RAISE(days, int8builder.Finish()); + // (Doc section: int8builder 1 Finish) + + // (Doc section: int8builder 2) + // Builders clear their state every time they fill an Array, so if the type is the same, + // we can re-use the builder. We do that here for month values. 
+ int8_t months_raw[5] = {1, 3, 5, 7, 1}; + ARROW_RETURN_NOT_OK(int8builder.AppendValues(months_raw, 5)); + std::shared_ptr months; + ARROW_ASSIGN_OR_RAISE(months, int8builder.Finish()); + // (Doc section: int8builder 2) + + // (Doc section: int16builder) + // Now that we change to int16, we use the Builder for that data type instead. + arrow::Int16Builder int16builder; + int16_t years_raw[5] = {1990, 2000, 1995, 2000, 1995}; + ARROW_RETURN_NOT_OK(int16builder.AppendValues(years_raw, 5)); + std::shared_ptr years; + ARROW_ASSIGN_OR_RAISE(years, int16builder.Finish()); + // (Doc section: int16builder) + + // (Doc section: Schema) + // Now, we want a RecordBatch, which has columns and labels for said columns. + // This gets us to the 2d data structures we want in Arrow. + // These are defined by schema, which have fields -- here we get both those object types + // ready. + std::shared_ptr field_day, field_month, field_year; + std::shared_ptr schema; + + // Every field needs its name and data type. + field_day = arrow::field("Day", arrow::int8()); + field_month = arrow::field("Month", arrow::int8()); + field_year = arrow::field("Year", arrow::int16()); + + // The schema can be built from a vector of fields, and we do so here. + schema = arrow::schema({field_day, field_month, field_year}); + // (Doc section: Schema) + + // (Doc section: RBatch) + // With the schema and Arrays full of data, we can make our RecordBatch! Here, + // each column is internally contiguous. This is in opposition to Tables, which we'll + // see next. + std::shared_ptr rbatch; + // The RecordBatch needs the schema, length for columns, which all must match, + // and the actual data itself. + rbatch = arrow::RecordBatch::Make(schema, days->length(), {days, months, years}); + + std::cout << rbatch->ToString(); + // (Doc section: RBatch) + + // (Doc section: More Arrays) + // Now, let's get some new arrays! It'll be the same datatypes as above, so we re-use + // Builders. + int8_t days_raw2[5] = {6, 12, 3, 30, 22}; + ARROW_RETURN_NOT_OK(int8builder.AppendValues(days_raw2, 5)); + std::shared_ptr days2; + ARROW_ASSIGN_OR_RAISE(days2, int8builder.Finish()); + + int8_t months_raw2[5] = {5, 4, 11, 3, 2}; + ARROW_RETURN_NOT_OK(int8builder.AppendValues(months_raw2, 5)); + std::shared_ptr months2; + ARROW_ASSIGN_OR_RAISE(months2, int8builder.Finish()); + + int16_t years_raw2[5] = {1980, 2001, 1915, 2020, 1996}; + ARROW_RETURN_NOT_OK(int16builder.AppendValues(years_raw2, 5)); + std::shared_ptr years2; + ARROW_ASSIGN_OR_RAISE(years2, int16builder.Finish()); + // (Doc section: More Arrays) + + // (Doc section: ArrayVector) + // ChunkedArrays let us have a list of arrays, which aren't contiguous + // with each other. First, we get a vector of arrays. + arrow::ArrayVector day_vecs{days, days2}; + // (Doc section: ArrayVector) + // (Doc section: ChunkedArray Day) + // Then, we use that to initialize a ChunkedArray, which can be used with other + // functions in Arrow! This is good, since having a normal vector of arrays wouldn't + // get us far. + std::shared_ptr day_chunks = + std::make_shared(day_vecs); + // (Doc section: ChunkedArray Day) + + // (Doc section: ChunkedArray Month Year) + // Repeat for months. + arrow::ArrayVector month_vecs{months, months}; + std::shared_ptr month_chunks = + std::make_shared(month_vecs); + + // Repeat for years. 
+ arrow::ArrayVector year_vecs{years, years2}; + std::shared_ptr year_chunks = + std::make_shared(year_vecs); + // (Doc section: ChunkedArray Month Year) + + // (Doc section: Table) + // A Table is the structure we need for these non-contiguous columns, and keeps them + // all in one place for us so we can use them as if they were normal arrays. + std::shared_ptr table; + table = arrow::Table::Make(schema, {day_chunks, month_chunks, year_chunks}, 10); + + std::cout << table->ToString(); + // (Doc section: Table) + + // (Doc section: Ret) + return arrow::Status::OK(); +} +// (Doc section: Ret) + +// (Doc section: Main) +int main() { + arrow::Status st = RunMain(); + if (!st.ok()) { + std::cerr << st << std::endl; + return 1; + } + return 0; +} + +// (Doc section: Main) +// (Doc section: Basic Example) diff --git a/cpp/examples/tutorial_examples/build_arrow.sh b/cpp/examples/tutorial_examples/build_arrow.sh new file mode 100755 index 00000000000..ec72a288c7b --- /dev/null +++ b/cpp/examples/tutorial_examples/build_arrow.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -ex + +NPROC=$(nproc) + +mkdir -p $ARROW_BUILD_DIR +pushd $ARROW_BUILD_DIR + +# Enable the CSV reader as it's used by the example third-party build +cmake /arrow/cpp \ + -DARROW_CSV=ON \ + -DARROW_DATASET=ON \ + -DARROW_FILESYSTEM=ON \ + -DARROW_PARQUET=ON \ + -DARROW_JEMALLOC=OFF \ + $ARROW_CMAKE_OPTIONS + +make -j$NPROC +make install + +popd diff --git a/cpp/examples/tutorial_examples/build_example.sh b/cpp/examples/tutorial_examples/build_example.sh new file mode 100755 index 00000000000..a315755a597 --- /dev/null +++ b/cpp/examples/tutorial_examples/build_example.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set -ex + +mkdir -p $EXAMPLE_BUILD_DIR +pushd $EXAMPLE_BUILD_DIR + +cmake /io +make + +popd diff --git a/cpp/examples/tutorial_examples/compute_example.cc b/cpp/examples/tutorial_examples/compute_example.cc new file mode 100644 index 00000000000..3a65214c0ef --- /dev/null +++ b/cpp/examples/tutorial_examples/compute_example.cc @@ -0,0 +1,138 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (Doc section: Compute Example) + +// (Doc section: Includes) +#include +#include + +#include +// (Doc section: Includes) + +// (Doc section: RunMain) +arrow::Status RunMain() { + // (Doc section: RunMain) + // (Doc section: Create Tables) + // Create a couple 32-bit integer arrays. + arrow::Int32Builder int32builder; + int32_t some_nums_raw[5] = {34, 624, 2223, 5654, 4356}; + ARROW_RETURN_NOT_OK(int32builder.AppendValues(some_nums_raw, 5)); + std::shared_ptr some_nums; + ARROW_ASSIGN_OR_RAISE(some_nums, int32builder.Finish()); + + int32_t more_nums_raw[5] = {75342, 23, 64, 17, 736}; + ARROW_RETURN_NOT_OK(int32builder.AppendValues(more_nums_raw, 5)); + std::shared_ptr more_nums; + ARROW_ASSIGN_OR_RAISE(more_nums, int32builder.Finish()); + + // Make a table out of our pair of arrays. + std::shared_ptr field_a, field_b; + std::shared_ptr schema; + + field_a = arrow::field("A", arrow::int32()); + field_b = arrow::field("B", arrow::int32()); + + schema = arrow::schema({field_a, field_b}); + + std::shared_ptr table; + table = arrow::Table::Make(schema, {some_nums, more_nums}, 5); + // (Doc section: Create Tables) + + // (Doc section: Sum Datum Declaration) + // The Datum class is what all compute functions output to, and they can take Datums + // as inputs, as well. + arrow::Datum sum; + // (Doc section: Sum Datum Declaration) + // (Doc section: Sum Call) + // Here, we can use arrow::compute::Sum. This is a convenience function, and the next + // computation won't be so simple. However, using these where possible helps + // readability. + ARROW_ASSIGN_OR_RAISE(sum, arrow::compute::Sum({table->GetColumnByName("A")})); + // (Doc section: Sum Call) + // (Doc section: Sum Datum Type) + // Get the kind of Datum and what it holds -- this is a Scalar, with int64. + std::cout << "Datum kind: " << sum.ToString() + << " content type: " << sum.type()->ToString() << std::endl; + // (Doc section: Sum Datum Type) + // (Doc section: Sum Contents) + // Note that we explicitly request a scalar -- the Datum cannot simply give what it is, + // you must ask for the correct type. + std::cout << sum.scalar_as().value << std::endl; + // (Doc section: Sum Contents) + + // (Doc section: Add Datum Declaration) + arrow::Datum element_wise_sum; + // (Doc section: Add Datum Declaration) + // (Doc section: Add Call) + // Get element-wise sum of both columns A and B in our Table. 
Note that here we use + // CallFunction(), which takes the name of the function as the first argument. + ARROW_ASSIGN_OR_RAISE(element_wise_sum, arrow::compute::CallFunction( + "add", {table->GetColumnByName("A"), + table->GetColumnByName("B")})); + // (Doc section: Add Call) + // (Doc section: Add Datum Type) + // Get the kind of Datum and what it holds -- this is a ChunkedArray, with int32. + std::cout << "Datum kind: " << element_wise_sum.ToString() + << " content type: " << element_wise_sum.type()->ToString() << std::endl; + // (Doc section: Add Datum Type) + // (Doc section: Add Contents) + // This time, we get a ChunkedArray, not a scalar. + std::cout << element_wise_sum.chunked_array()->ToString() << std::endl; + // (Doc section: Add Contents) + + // (Doc section: Index Datum Declare) + // Use an options struct to set up searching for 2223 in column A (the third item). + arrow::Datum third_item; + // (Doc section: Index Datum Declare) + // (Doc section: IndexOptions Declare) + // An options struct is used in lieu of passing an arbitrary amount of arguments. + arrow::compute::IndexOptions index_options; + // (Doc section: IndexOptions Declare) + // (Doc section: IndexOptions Assign) + // We need an Arrow Scalar, not a raw value. + index_options.value = arrow::MakeScalar(2223); + // (Doc section: IndexOptions Assign) + // (Doc section: Index Call) + ARROW_ASSIGN_OR_RAISE( + third_item, arrow::compute::CallFunction("index", {table->GetColumnByName("A")}, + &index_options)); + // (Doc section: Index Call) + // (Doc section: Index Inspection) + // Get the kind of Datum and what it holds -- this is a Scalar, with int64 + std::cout << "Datum kind: " << third_item.ToString() + << " content type: " << third_item.type()->ToString() << std::endl; + // We get a scalar -- the location of 2223 in column A, which is 2 in 0-based indexing. + std::cout << third_item.scalar_as().value << std::endl; + // (Doc section: Index Inspection) + // (Doc section: Ret) + return arrow::Status::OK(); +} +// (Doc section: Ret) + +// (Doc section: Main) +int main() { + arrow::Status st = RunMain(); + if (!st.ok()) { + std::cerr << st << std::endl; + return 1; + } + return 0; +} +// (Doc section: Main) + +// (Doc section: Compute Example) diff --git a/cpp/examples/tutorial_examples/dataset_example.cc b/cpp/examples/tutorial_examples/dataset_example.cc new file mode 100644 index 00000000000..005cdc324d0 --- /dev/null +++ b/cpp/examples/tutorial_examples/dataset_example.cc @@ -0,0 +1,244 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (Doc section: Dataset Example) + +// (Doc section: Includes) +#include +#include +// We use Parquet headers for setting up examples; they are not required for using +// datasets. 
+#include +#include + +#include +// (Doc section: Includes) + +// (Doc section: Helper Functions) +// Generate some data for the rest of this example. +arrow::Result> CreateTable() { + // This code should look familiar from the basic Arrow example, and is not the + // focus of this example. However, we need data to work on it, and this makes that! + auto schema = + arrow::schema({arrow::field("a", arrow::int64()), arrow::field("b", arrow::int64()), + arrow::field("c", arrow::int64())}); + std::shared_ptr array_a; + std::shared_ptr array_b; + std::shared_ptr array_c; + arrow::NumericBuilder builder; + ARROW_RETURN_NOT_OK(builder.AppendValues({0, 1, 2, 3, 4, 5, 6, 7, 8, 9})); + ARROW_RETURN_NOT_OK(builder.Finish(&array_a)); + builder.Reset(); + ARROW_RETURN_NOT_OK(builder.AppendValues({9, 8, 7, 6, 5, 4, 3, 2, 1, 0})); + ARROW_RETURN_NOT_OK(builder.Finish(&array_b)); + builder.Reset(); + ARROW_RETURN_NOT_OK(builder.AppendValues({1, 2, 1, 2, 1, 2, 1, 2, 1, 2})); + ARROW_RETURN_NOT_OK(builder.Finish(&array_c)); + return arrow::Table::Make(schema, {array_a, array_b, array_c}); +} + +// Set up a dataset by writing two Parquet files. +arrow::Result CreateExampleParquetDataset( + const std::shared_ptr& filesystem, + const std::string& root_path) { + // Much like CreateTable(), this is utility that gets us the dataset we'll be reading + // from. Don't worry, we also write a dataset in the example proper. + auto base_path = root_path + "parquet_dataset"; + ARROW_RETURN_NOT_OK(filesystem->CreateDir(base_path)); + // Create an Arrow Table + ARROW_ASSIGN_OR_RAISE(auto table, CreateTable()); + // Write it into two Parquet files + ARROW_ASSIGN_OR_RAISE(auto output, + filesystem->OpenOutputStream(base_path + "/data1.parquet")); + ARROW_RETURN_NOT_OK(parquet::arrow::WriteTable( + *table->Slice(0, 5), arrow::default_memory_pool(), output, 2048)); + ARROW_ASSIGN_OR_RAISE(output, + filesystem->OpenOutputStream(base_path + "/data2.parquet")); + ARROW_RETURN_NOT_OK(parquet::arrow::WriteTable( + *table->Slice(5), arrow::default_memory_pool(), output, 2048)); + return base_path; +} + +arrow::Status PrepareEnv() { + // Get our environment prepared for reading, by setting up some quick writing. + ARROW_ASSIGN_OR_RAISE(auto src_table, CreateTable()) + std::shared_ptr setup_fs; + // Note this operates in the directory the executable is built in. + char setup_path[256]; + char* result = getcwd(setup_path, 256); + if (result == NULL) { + return arrow::Status::IOError("Fetching PWD failed."); + } + + ARROW_ASSIGN_OR_RAISE(setup_fs, arrow::fs::FileSystemFromUriOrPath(setup_path)); + ARROW_ASSIGN_OR_RAISE(auto dset_path, CreateExampleParquetDataset(setup_fs, "")); + + return arrow::Status::OK(); +} +// (Doc section: Helper Functions) + +// (Doc section: RunMain) +arrow::Status RunMain() { + // (Doc section: RunMain) + // (Doc section: PrepareEnv) + ARROW_RETURN_NOT_OK(PrepareEnv()); + // (Doc section: PrepareEnv) + + // (Doc section: FileSystem Declare) + // First, we need a filesystem object, which lets us interact with our local + // filesystem starting at a given path. For the sake of simplicity, that'll be + // the current directory. + std::shared_ptr fs; + // (Doc section: FileSystem Declare) + + // (Doc section: FileSystem Init) + // Get the CWD, use it to make the FileSystem object. 
+ char init_path[256]; + char* result = getcwd(init_path, 256); + if (result == NULL) { + return arrow::Status::IOError("Fetching PWD failed."); + } + ARROW_ASSIGN_OR_RAISE(fs, arrow::fs::FileSystemFromUriOrPath(init_path)); + // (Doc section: FileSystem Init) + + // (Doc section: FileSelector Declare) + // A file selector lets us actually traverse a multi-file dataset. + arrow::fs::FileSelector selector; + // (Doc section: FileSelector Declare) + // (Doc section: FileSelector Config) + selector.base_dir = "parquet_dataset"; + // Recursive is a safe bet if you don't know the nesting of your dataset. + selector.recursive = true; + // (Doc section: FileSelector Config) + // (Doc section: FileSystemFactoryOptions) + // Making an options object lets us configure our dataset reading. + arrow::dataset::FileSystemFactoryOptions options; + // We'll use Hive-style partitioning. We'll let Arrow Datasets infer the partition + // schema. We won't set any other options, defaults are fine. + options.partitioning = arrow::dataset::HivePartitioning::MakeFactory(); + // (Doc section: FileSystemFactoryOptions) + // (Doc section: File Format Setup) + auto read_format = std::make_shared(); + // (Doc section: File Format Setup) + // (Doc section: FileSystemDatasetFactory Make) + // Now, we get a factory that will let us get our dataset -- we don't have the + // dataset yet! + ARROW_ASSIGN_OR_RAISE(auto factory, arrow::dataset::FileSystemDatasetFactory::Make( + fs, selector, read_format, options)); + // (Doc section: FileSystemDatasetFactory Make) + // (Doc section: FileSystemDatasetFactory Finish) + // Now we build our dataset from the factory. + ARROW_ASSIGN_OR_RAISE(auto read_dataset, factory->Finish()); + // (Doc section: FileSystemDatasetFactory Finish) + // (Doc section: Dataset Fragments) + // Print out the fragments + ARROW_ASSIGN_OR_RAISE(auto fragments, read_dataset->GetFragments()); + for (const auto& fragment : fragments) { + std::cout << "Found fragment: " << (*fragment)->ToString() << std::endl; + std::cout << "Partition expression: " + << (*fragment)->partition_expression().ToString() << std::endl; + } + // (Doc section: Dataset Fragments) + // (Doc section: Read Scan Builder) + // Scan dataset into a Table -- once this is done, you can do + // normal table things with it, like computation and printing. However, now you're + // also dedicated to being in memory. + ARROW_ASSIGN_OR_RAISE(auto read_scan_builder, read_dataset->NewScan()); + // (Doc section: Read Scan Builder) + // (Doc section: Read Scanner) + ARROW_ASSIGN_OR_RAISE(auto read_scanner, read_scan_builder->Finish()); + // (Doc section: Read Scanner) + // (Doc section: To Table) + ARROW_ASSIGN_OR_RAISE(std::shared_ptr table, read_scanner->ToTable()); + std::cout << table->ToString(); + // (Doc section: To Table) + + // (Doc section: TableBatchReader) + // Now, let's get a table out to disk as a dataset! + // We make a RecordBatchReader from our Table, then set up a scanner, which lets us + // go to a file. + std::shared_ptr write_dataset = + std::make_shared(table); + // (Doc section: TableBatchReader) + // (Doc section: WriteScanner) + auto write_scanner_builder = + arrow::dataset::ScannerBuilder::FromRecordBatchReader(write_dataset); + ARROW_ASSIGN_OR_RAISE(auto write_scanner, write_scanner_builder->Finish()) + // (Doc section: WriteScanner) + // (Doc section: Partition Schema) + // The partition schema determines which fields are used as keys for partitioning. 
+ auto partition_schema = arrow::schema({arrow::field("a", arrow::utf8())}); + // (Doc section: Partition Schema) + // (Doc section: Partition Create) + // We'll use Hive-style partitioning, which creates directories with "key=value" + // pairs. + auto partitioning = + std::make_shared(partition_schema); + // (Doc section: Partition Create) + // (Doc section: Write Format) + // Now, we declare we'll be writing Parquet files. + auto write_format = std::make_shared(); + // (Doc section: Write Format) + // (Doc section: Write Options) + // This time, we make Options for writing, but do much more configuration. + arrow::dataset::FileSystemDatasetWriteOptions write_options; + // Defaults to start. + write_options.file_write_options = write_format->DefaultWriteOptions(); + // (Doc section: Write Options) + // (Doc section: Options FS) + // Use the filesystem we already have. + write_options.filesystem = fs; + // (Doc section: Options FS) + // (Doc section: Options Target) + // Write to the folder "write_dataset" in current directory. + write_options.base_dir = "write_dataset"; + // (Doc section: Options Target) + // (Doc section: Options Partitioning) + // Use the partitioning declared above. + write_options.partitioning = partitioning; + // (Doc section: Options Partitioning) + // (Doc section: Options Name Template) + // Define what the name for the files making up the dataset will be. + write_options.basename_template = "part{i}.parquet"; + // (Doc section: Options Name Template) + // (Doc section: Options File Behavior) + // Set behavior to overwrite existing data -- specifically, this lets this example + // be run more than once, and allows whatever code you have to overwrite what's there. + write_options.existing_data_behavior = + arrow::dataset::ExistingDataBehavior::kOverwriteOrIgnore; + // (Doc section: Options File Behavior) + // (Doc section: Write Dataset) + // Write to disk! + ARROW_RETURN_NOT_OK( + arrow::dataset::FileSystemDataset::Write(write_options, write_scanner)); + // (Doc section: Write Dataset) + // (Doc section: Ret) + return arrow::Status::OK(); +} +// (Doc section: Ret) +// (Doc section: Main) +int main() { + arrow::Status st = RunMain(); + if (!st.ok()) { + std::cerr << st << std::endl; + return 1; + } + return 0; +} +// (Doc section: Main) + +// (Doc section: Dataset Example) diff --git a/cpp/examples/tutorial_examples/docker-compose.yml b/cpp/examples/tutorial_examples/docker-compose.yml new file mode 100644 index 00000000000..90bdbcad3d8 --- /dev/null +++ b/cpp/examples/tutorial_examples/docker-compose.yml @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +version: '3.5' + +services: + tutorial: + build: + context: . 
+ dockerfile: tutorial.dockerfile + volumes: + - ../../../:/arrow:delegated + - .:/io:delegated + command: + - "/io/run.sh" diff --git a/cpp/examples/tutorial_examples/file_access_example.cc b/cpp/examples/tutorial_examples/file_access_example.cc new file mode 100644 index 00000000000..fdc312ff421 --- /dev/null +++ b/cpp/examples/tutorial_examples/file_access_example.cc @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// (Doc section: File I/O) + +// (Doc section: Includes) +#include +#include +#include +#include +#include +#include + +#include +// (Doc section: Includes) + +// (Doc section: GenInitialFile) +arrow::Status GenInitialFile() { + // Make a couple 8-bit integer arrays and a 16-bit integer array -- just like + // basic Arrow example. + arrow::Int8Builder int8builder; + int8_t days_raw[5] = {1, 12, 17, 23, 28}; + ARROW_RETURN_NOT_OK(int8builder.AppendValues(days_raw, 5)); + std::shared_ptr days; + ARROW_ASSIGN_OR_RAISE(days, int8builder.Finish()); + + int8_t months_raw[5] = {1, 3, 5, 7, 1}; + ARROW_RETURN_NOT_OK(int8builder.AppendValues(months_raw, 5)); + std::shared_ptr months; + ARROW_ASSIGN_OR_RAISE(months, int8builder.Finish()); + + arrow::Int16Builder int16builder; + int16_t years_raw[5] = {1990, 2000, 1995, 2000, 1995}; + ARROW_RETURN_NOT_OK(int16builder.AppendValues(years_raw, 5)); + std::shared_ptr years; + ARROW_ASSIGN_OR_RAISE(years, int16builder.Finish()); + + // Get a vector of our Arrays + std::vector> columns = {days, months, years}; + + // Make a schema to initialize the Table with + std::shared_ptr field_day, field_month, field_year; + std::shared_ptr schema; + + field_day = arrow::field("Day", arrow::int8()); + field_month = arrow::field("Month", arrow::int8()); + field_year = arrow::field("Year", arrow::int16()); + + schema = arrow::schema({field_day, field_month, field_year}); + // With the schema and data, create a Table + std::shared_ptr table; + table = arrow::Table::Make(schema, columns); + + // Write out test files in IPC, CSV, and Parquet for the example to use. 
+ std::shared_ptr outfile; + ARROW_ASSIGN_OR_RAISE(outfile, arrow::io::FileOutputStream::Open("test_in.arrow")); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ipc_writer, + arrow::ipc::MakeFileWriter(outfile, schema)); + ARROW_RETURN_NOT_OK(ipc_writer->WriteTable(*table)); + ARROW_RETURN_NOT_OK(ipc_writer->Close()); + + ARROW_ASSIGN_OR_RAISE(outfile, arrow::io::FileOutputStream::Open("test_in.csv")); + ARROW_ASSIGN_OR_RAISE(auto csv_writer, + arrow::csv::MakeCSVWriter(outfile, table->schema())); + ARROW_RETURN_NOT_OK(csv_writer->WriteTable(*table)); + ARROW_RETURN_NOT_OK(csv_writer->Close()); + + ARROW_ASSIGN_OR_RAISE(outfile, arrow::io::FileOutputStream::Open("test_in.parquet")); + PARQUET_THROW_NOT_OK( + parquet::arrow::WriteTable(*table, arrow::default_memory_pool(), outfile, 5)); + + return arrow::Status::OK(); +} +// (Doc section: GenInitialFile) + +// (Doc section: RunMain) +arrow::Status RunMain() { + // (Doc section: RunMain) + // (Doc section: Gen Files) + // Generate initial files for each format with a helper function -- don't worry, + // we'll also write a table in this example. + ARROW_RETURN_NOT_OK(GenInitialFile()); + // (Doc section: Gen Files) + + // (Doc section: ReadableFile Definition) + // First, we have to set up a ReadableFile object, which just lets us point our + // readers to the right data on disk. We'll be reusing this object, and rebinding + // it to multiple files throughout the example. + std::shared_ptr infile; + // (Doc section: ReadableFile Definition) + // (Doc section: Arrow ReadableFile Open) + // Get "test_in.arrow" into our file pointer + ARROW_ASSIGN_OR_RAISE(infile, arrow::io::ReadableFile::Open( + "test_in.arrow", arrow::default_memory_pool())); + // (Doc section: Arrow ReadableFile Open) + // (Doc section: Arrow Read Open) + // Open up the file with the IPC features of the library, gives us a reader object. + ARROW_ASSIGN_OR_RAISE(auto ipc_reader, arrow::ipc::RecordBatchFileReader::Open(infile)); + // (Doc section: Arrow Read Open) + // (Doc section: Arrow Read) + // Using the reader, we can read Record Batches. Note that this is specific to IPC; + // for other formats, we focus on Tables, but here, RecordBatches are used. + std::shared_ptr rbatch; + ARROW_ASSIGN_OR_RAISE(rbatch, ipc_reader->ReadRecordBatch(0)); + // (Doc section: Arrow Read) + + // (Doc section: Arrow Write Open) + // Just like with input, we get an object for the output file. + std::shared_ptr outfile; + // Bind it to "test_out.arrow" + ARROW_ASSIGN_OR_RAISE(outfile, arrow::io::FileOutputStream::Open("test_out.arrow")); + // (Doc section: Arrow Write Open) + // (Doc section: Arrow Writer) + // Set up a writer with the output file -- and the schema! We're defining everything + // here, loading to fire. + ARROW_ASSIGN_OR_RAISE(std::shared_ptr ipc_writer, + arrow::ipc::MakeFileWriter(outfile, rbatch->schema())); + // (Doc section: Arrow Writer) + // (Doc section: Arrow Write) + // Write the record batch. + ARROW_RETURN_NOT_OK(ipc_writer->WriteRecordBatch(*rbatch)); + // (Doc section: Arrow Write) + // (Doc section: Arrow Close) + // Specifically for IPC, the writer needs to be explicitly closed. 
+  ARROW_RETURN_NOT_OK(ipc_writer->Close());
+  // (Doc section: Arrow Close)
+
+  // (Doc section: CSV Read Open)
+  // Bind our input file to "test_in.csv"
+  ARROW_ASSIGN_OR_RAISE(infile, arrow::io::ReadableFile::Open("test_in.csv"));
+  // (Doc section: CSV Read Open)
+  // (Doc section: CSV Table Declare)
+  std::shared_ptr<arrow::Table> csv_table;
+  // (Doc section: CSV Table Declare)
+  // (Doc section: CSV Reader Make)
+  // The CSV reader has several objects for various options. For now, we'll use defaults.
+  ARROW_ASSIGN_OR_RAISE(
+      auto csv_reader,
+      arrow::csv::TableReader::Make(
+          arrow::io::default_io_context(), infile, arrow::csv::ReadOptions::Defaults(),
+          arrow::csv::ParseOptions::Defaults(), arrow::csv::ConvertOptions::Defaults()));
+  // (Doc section: CSV Reader Make)
+  // (Doc section: CSV Read)
+  // Read the table.
+  ARROW_ASSIGN_OR_RAISE(csv_table, csv_reader->Read());
+  // (Doc section: CSV Read)
+
+  // (Doc section: CSV Write)
+  // Bind our output file to "test_out.csv"
+  ARROW_ASSIGN_OR_RAISE(outfile, arrow::io::FileOutputStream::Open("test_out.csv"));
+  // The CSV writer has simpler defaults; review the API documentation for more complex usage.
+  ARROW_ASSIGN_OR_RAISE(auto csv_writer,
+                        arrow::csv::MakeCSVWriter(outfile, csv_table->schema()));
+  ARROW_RETURN_NOT_OK(csv_writer->WriteTable(*csv_table));
+  // Not necessary, but a safe practice.
+  ARROW_RETURN_NOT_OK(csv_writer->Close());
+  // (Doc section: CSV Write)
+
+  // (Doc section: Parquet Read Open)
+  // Bind our input file to "test_in.parquet"
+  ARROW_ASSIGN_OR_RAISE(infile, arrow::io::ReadableFile::Open("test_in.parquet"));
+  // (Doc section: Parquet Read Open)
+  // (Doc section: Parquet FileReader)
+  std::unique_ptr<parquet::arrow::FileReader> reader;
+  // (Doc section: Parquet FileReader)
+  // (Doc section: Parquet OpenFile)
+  // Note that Parquet's OpenFile() takes the reader by reference, rather than returning
+  // a reader.
+  PARQUET_THROW_NOT_OK(
+      parquet::arrow::OpenFile(infile, arrow::default_memory_pool(), &reader));
+  // (Doc section: Parquet OpenFile)
+
+  // (Doc section: Parquet Read)
+  std::shared_ptr<arrow::Table> parquet_table;
+  // Read the table.
+  PARQUET_THROW_NOT_OK(reader->ReadTable(&parquet_table));
+  // (Doc section: Parquet Read)
+
+  // (Doc section: Parquet Write)
+  // Parquet writing does not need a declared writer object. Just get the output
+  // file bound, then pass in the table, memory pool, output, and chunk size for
+  // breaking up the Table on-disk.
+  ARROW_ASSIGN_OR_RAISE(outfile, arrow::io::FileOutputStream::Open("test_out.parquet"));
+  PARQUET_THROW_NOT_OK(parquet::arrow::WriteTable(
+      *parquet_table, arrow::default_memory_pool(), outfile, 5));
+  // (Doc section: Parquet Write)
+  // (Doc section: Return)
+  return arrow::Status::OK();
+}
+// (Doc section: Return)
+
+// (Doc section: Main)
+int main() {
+  arrow::Status st = RunMain();
+  if (!st.ok()) {
+    std::cerr << st << std::endl;
+    return 1;
+  }
+  return 0;
+}
+// (Doc section: Main)
+// (Doc section: File I/O)
diff --git a/cpp/examples/tutorial_examples/run.sh b/cpp/examples/tutorial_examples/run.sh
new file mode 100755
index 00000000000..ed319a9d327
--- /dev/null
+++ b/cpp/examples/tutorial_examples/run.sh
@@ -0,0 +1,51 @@
+#!/usr/bin/env bash
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+cd /io
+
+export ARROW_BUILD_DIR=/build/arrow
+export EXAMPLE_BUILD_DIR=/build/example
+
+echo
+echo "=="
+echo "== Building Arrow C++ library"
+echo "=="
+echo
+
+./build_arrow.sh
+
+echo
+echo "=="
+echo "== Building example project using Arrow C++ library"
+echo "=="
+echo
+
+./build_example.sh
+
+echo
+echo "=="
+echo "== Running example project"
+echo "=="
+echo
+
+${EXAMPLE_BUILD_DIR}/arrow_example
+${EXAMPLE_BUILD_DIR}/compute_example
+${EXAMPLE_BUILD_DIR}/file_access_example
+${EXAMPLE_BUILD_DIR}/dataset_example
diff --git a/cpp/examples/tutorial_examples/tutorial.dockerfile b/cpp/examples/tutorial_examples/tutorial.dockerfile
new file mode 100644
index 00000000000..9361fc5e81d
--- /dev/null
+++ b/cpp/examples/tutorial_examples/tutorial.dockerfile
@@ -0,0 +1,27 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+FROM ubuntu:focal
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+RUN apt-get update -y -q && \
+    apt-get install -y -q --no-install-recommends \
+      build-essential \
+      cmake \
+      pkg-config && \
+    apt-get clean && rm -rf /var/lib/apt/lists*
diff --git a/dev/tasks/tasks.yml b/dev/tasks/tasks.yml
index aee50ef9ce9..3771dfddab1 100644
--- a/dev/tasks/tasks.yml
+++ b/dev/tasks/tasks.yml
@@ -1596,6 +1596,13 @@ tasks:
       run: {{ kind }}
 {% endfor %}
 
+  cpp-tutorial-example:
+    ci: github
+    template: cpp-examples/github.linux.yml
+    params:
+      type: tutorial_examples
+      run: tutorial
+
 ############################## Utility tasks ############################
 
   preview-docs:
     ci: github
diff --git a/docs/source/cpp/getting_started.rst b/docs/source/cpp/getting_started.rst
index 31095cbd217..89bd4559ef1 100644
--- a/docs/source/cpp/getting_started.rst
+++ b/docs/source/cpp/getting_started.rst
@@ -18,28 +18,31 @@
 .. default-domain:: cpp
 .. highlight:: cpp
 
-User Guide
-==========
+Getting Started
+===============
+
+The following articles demonstrate installation, use, and a basic understanding of Arrow.
+These articles will get you set up quickly using Arrow and give you a taste of what the
+library is capable of.
+
+Specifically, this section contains: an installation and linking guide; documentation of conventions used
+in the codebase and suggested for users; and tutorials, including:
+
+* Building Arrow arrays and tabular structures
+* Reading and writing Parquet, Arrow, and CSV files
+* Executing compute kernels on arrays
+* Reading and writing multi-file partitioned datasets
+
+Start here to gain a basic understanding of Arrow, and move on to the :doc:`/cpp/user_guide` to
+explore more specific topics and underlying concepts, or the :doc:`/cpp/api` to explore Arrow's
+API.
 
 .. toctree::
 
-   overview
-   conventions
    build_system
-   memory
-   arrays
-   datatypes
-   tables
-   compute
-   streaming_execution
-   io
-   ipc
-   orc
-   parquet
-   csv
-   json
-   dataset
-   flight
-   gdb
-   threading
-   env_vars
+   conventions
+   tutorials/basic_arrow.rst
+   tutorials/io_tutorial.rst
+   tutorials/compute_tutorial.rst
+   tutorials/datasets_tutorial.rst
+
+
diff --git a/docs/source/cpp/index.rst b/docs/source/cpp/index.rst
index 70329c07233..ab693af2c55 100644
--- a/docs/source/cpp/index.rst
+++ b/docs/source/cpp/index.rst
@@ -22,5 +22,6 @@ C++ Implementation
    :maxdepth: 2
 
    getting_started
+   user_guide
    Examples
    api
diff --git a/docs/source/cpp/tutorials/basic_arrow.rst b/docs/source/cpp/tutorials/basic_arrow.rst
new file mode 100644
index 00000000000..06f5fde32e8
--- /dev/null
+++ b/docs/source/cpp/tutorials/basic_arrow.rst
@@ -0,0 +1,284 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements.  See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership.  The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License.  You may obtain a copy of the License at
+
+..   http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied.  See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+.. default-domain:: cpp
+.. highlight:: cpp
+
+.. cpp:namespace:: arrow
+
+===========================
+Basic Arrow Data Structures
+===========================
+
+Apache Arrow provides fundamental data structures for representing data:
+:class:`Array`, :class:`ChunkedArray`, :class:`RecordBatch`, and :class:`Table`.
+This article shows how to construct these data structures from primitive
+data types; specifically, we will work with integers of varying size
+representing days, months, and years. We will use them to create the following data structures:
+
+#. Arrow :class:`Arrays <Array>`
+#. :class:`ChunkedArrays <ChunkedArray>`
+#. :class:`RecordBatch`, from :class:`Arrays <Array>`
+#. :class:`Table`, from :class:`ChunkedArrays <ChunkedArray>`
+
+Pre-requisites
+--------------
+Before continuing, make sure you have:
+
+#. An Arrow installation, which you can set up here: :doc:`/cpp/build_system`
+#. An understanding of how to use basic C++ data structures
+#. An understanding of basic C++ data types
+
+
+Setup
+-----
+
+Before trying out Arrow, we need to fill in a couple gaps:
+
+1. We need to include necessary headers.
+
+2. A ``main()`` is needed to glue things together.
+
+Includes
+^^^^^^^^
+
+First, as ever, we need some includes. We'll get ``iostream`` for output, then import Arrow's basic
+functionality from ``api.h``, like so:
+
+.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: Includes) + :end-before: (Doc section: Includes) + +Main() +^^^^^^ + +Next, we need a ``main()`` – a common pattern with Arrow looks like the +following: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: Main) + :end-before: (Doc section: Main) + +This allows us to easily use Arrow’s error-handling macros, which will +return back to ``main()`` with a :class:`arrow::Status` object if a failure occurs – and +this ``main()`` will report the error. Note that this means Arrow never +raises exceptions, instead relying upon returning :class:`Status`. For more on +that, read here: :doc:`/cpp/conventions`. + +To accompany this ``main()``, we have a ``RunMain()`` from which any :class:`Status` +objects can return – this is where we’ll write the rest of the program: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: RunMain Start) + :end-before: (Doc section: RunMain Start) + + +Making an Arrow Array +--------------------- + +Building int8 Arrays +^^^^^^^^^^^^^^^^^^^^ + +Given that we have some data in standard C++ arrays, and want to use Arrow, we need to move +the data from said arrays into Arrow arrays. We still guarantee contiguity of memory in an +:class:`Array`, so no worries about a performance loss when using :class:`Array` vs C++ arrays. +The easiest way to construct an :class:`Array` uses an :class:`ArrayBuilder`. + +.. seealso:: :doc:`/cpp/arrays` for more technical details on :class:`Array` + +The following code initializes an :class:`ArrayBuilder` for an :class:`Array` that will hold 8 bit +integers. Specifically, it uses the ``AppendValues()`` method, present in concrete +:class:`arrow::ArrayBuilder` subclasses, to fill the :class:`ArrayBuilder` with the +contents of a standard C++ array. Note the use of :c:macro:`ARROW_RETURN_NOT_OK`. +If ``AppendValues()`` fails, this macro will return to ``main()``, which will +print out the meaning of the failure. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: int8builder 1 Append) + :end-before: (Doc section: int8builder 1 Append) + +Given an :class:`ArrayBuilder` has the values we want in our :class:`Array`, we can use +:func:`ArrayBuilder::Finish` to output the final structure to an :class:`Array` – specifically, +we output to a ``std::shared_ptr``. Note the use of :c:macro:`ARROW_ASSIGN_OR_RAISE` +in the following code. :func:`~ArrayBuilder::Finish` outputs a :class:`arrow::Result` object, which :c:macro:`ARROW_ASSIGN_OR_RAISE` +can process. If the method fails, it will return to ``main()`` with a :class:`Status` +that will explain what went wrong. If it succeeds, then it will assign +the final output to the left-hand variable. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: int8builder 1 Finish) + :end-before: (Doc section: int8builder 1 Finish) + +As soon as :class:`ArrayBuilder` has had its :func:`Finish ` method called, its state resets, so +it can be used again, as if it was fresh. Thus, we repeat the process above for our second array: + +.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: int8builder 2) + :end-before: (Doc section: int8builder 2) + +Building int16 Arrays +^^^^^^^^^^^^^^^^^^^^^ + +An :class:`ArrayBuilder` has its type specified at the time of declaration. +Once this is done, it cannot have its type changed. We have to make a new one when we switch to year data, which +requires a 16-bit integer at the minimum. Of course, there’s an :class:`ArrayBuilder` for that. +It uses the exact same methods, but with the new data type: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: int16builder) + :end-before: (Doc section: int16builder) + +Now, we have three Arrow :class:`Arrays `, with some variance in type. + +Making a RecordBatch +-------------------- + +A columnar data format only really comes into play when you have a table. +So, let’s make one. The first kind we’ll make is the :class:`RecordBatch` – this +uses :class:`Arrays ` internally, which means all data will be contiguous within each +column, but any appending or concatenating will require copying. Making a :class:`RecordBatch` +has two steps, given existing :class:`Arrays `: + +#. Defining a :class:`Schema` +#. Loading the :class:`Schema` and Arrays into the constructor + +Defining a Schema +^^^^^^^^^^^^^^^^^ + +To get started making a :class:`RecordBatch`, we first need to define +characteristics of the columns, each represented by a :class:`Field` instance. +Each :class:`Field` contains a name and datatype for its associated column; then, +a :class:`Schema` groups them together and sets the order of the columns, like +so: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: Schema) + :end-before: (Doc section: Schema) + +Building a RecordBatch +^^^^^^^^^^^^^^^^^^^^^^ + +With data in :class:`Arrays ` from the previous section, and column descriptions in our +:class:`Schema` from the previous step, we can make the :class:`RecordBatch`. Note that the +length of the columns is necessary, and the length is shared by all columns. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: RBatch) + :end-before: (Doc section: RBatch) + +Now, we have our data in a nice tabular form, safely within the :class:`RecordBatch`. +What we can do with this will be discussed in the later tutorials. + +Making a ChunkedArray +--------------------- + +Let’s say that we want an array made up of sub-arrays, because it +can be useful for avoiding data copies when concatenating, for parallelizing work, for fitting each chunk +cutely into cache, or for exceeding the 2,147,483,647 row limit in a +standard Arrow :class:`Array`. For this, Arrow offers :class:`ChunkedArray`, which can be +made up of individual Arrow :class:`Arrays `. In this example, we can reuse the arrays +we made earlier in part of our chunked array, allowing us to extend them without having to copy +data. So, let’s build a few more :class:`Arrays `, +using the same builders for ease of use: + +.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: More Arrays) + :end-before: (Doc section: More Arrays) + +In order to support an arbitrary amount of :class:`Arrays ` in the construction of the +:class:`ChunkedArray`, Arrow supplies :class:`ArrayVector`. This provides a vector for :class:`Arrays `, +and we'll use it here to prepare to make a :class:`ChunkedArray`: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: ArrayVector) + :end-before: (Doc section: ArrayVector) + +In order to leverage Arrow, we do need to take that last step, and move into a :class:`ChunkedArray`: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: ChunkedArray Day) + :end-before: (Doc section: ChunkedArray Day) + +With a :class:`ChunkedArray` for our day values, we now just need to repeat the process +for the month and year data: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc + :language: cpp + :start-after: (Doc section: ChunkedArray Month Year) + :end-before: (Doc section: ChunkedArray Month Year) + +With that, we are left with three :class:`ChunkedArrays `, varying in type. + +Making a Table +-------------- + +One particularly useful thing we can do with the :class:`ChunkedArrays ` from the previous section is creating +:class:`Tables
 <Table>`. Much like a :class:`RecordBatch`, a :class:`Table` stores tabular data. However, a
+:class:`Table` does not guarantee contiguity, due to being made up of :class:`ChunkedArrays <ChunkedArray>`.
+This can be useful for logic, for parallelizing work, for fitting chunks into cache, or for exceeding the
+2,147,483,647 row limit present in :class:`Array` and, thus, :class:`RecordBatch`.
+
+If you read up to :class:`RecordBatch`, you may note that the :class:`Table` constructor in the following code is
+effectively identical; it just happens to put the length of the columns
+in position 3, and makes a :class:`Table`. We re-use the :class:`Schema` from before, and
+make our :class:`Table`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc
+  :language: cpp
+  :start-after: (Doc section: Table)
+  :end-before: (Doc section: Table)
+
+Now, we have our data in a nice tabular form, safely within the :class:`Table`.
+What we can do with this will be discussed in the later tutorials.
+
+Ending Program
+--------------
+
+At the end, we just return :func:`Status::OK()`, so the ``main()`` knows that
+we’re done, and that everything’s okay.
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc
+  :language: cpp
+  :start-after: (Doc section: Ret)
+  :end-before: (Doc section: Ret)
+
+Wrapping Up
+-----------
+
+With that, you’ve created the fundamental data structures in Arrow, and
+can proceed to getting them in and out of a program with file I/O in the next article.
+
+Refer to the below for a copy of the complete code:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/arrow_example.cc
+  :language: cpp
+  :start-after: (Doc section: Basic Example)
+  :end-before: (Doc section: Basic Example)
+  :linenos:
+  :lineno-match:
\ No newline at end of file
diff --git a/docs/source/cpp/tutorials/compute_tutorial.rst b/docs/source/cpp/tutorials/compute_tutorial.rst
new file mode 100644
index 00000000000..bcb87e6a8f9
--- /dev/null
+++ b/docs/source/cpp/tutorials/compute_tutorial.rst
@@ -0,0 +1,343 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements.  See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership.  The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License.  You may obtain a copy of the License at
+
+..   http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied.  See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+.. default-domain:: cpp
+.. highlight:: cpp
+
+.. cpp:namespace:: arrow
+
+=============
+Arrow Compute
+=============
+
+Apache Arrow provides compute functions to facilitate efficient and
+portable data processing. In this article, you will use Arrow’s compute
+functionality to:
+
+1. Calculate a sum over a column
+
+2. Calculate element-wise sums over two columns
+
+3. Search for a value in a column
+
+Pre-requisites
+--------------
+
+Before continuing, make sure you have:
+
+1. An Arrow installation, which you can set up here: :doc:`/cpp/build_system`
+
+2. An understanding of basic Arrow data structures from :doc:`/cpp/tutorials/basic_arrow`
+
+Setup
+-----
+
+Before running some computations, we need to fill in a couple gaps:
+
+1. We need to include necessary headers.
+
+2. A ``main()`` is needed to glue things together.
+
+3. We need data to play with.
+
+Includes
+^^^^^^^^
+
+Before writing C++ code, we need some includes. We'll get ``iostream`` for output, then import Arrow's
+compute functionality:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Includes)
+  :end-before: (Doc section: Includes)
+
+Main()
+^^^^^^
+
+For our glue, we’ll use the ``main()`` pattern from the previous tutorial on
+data structures:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Main)
+  :end-before: (Doc section: Main)
+
+Which, like when we used it before, is paired with a ``RunMain()``:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: RunMain)
+  :end-before: (Doc section: RunMain)
+
+Generating Tables for Computation
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Before we begin, we’ll initialize a :class:`Table` with two columns to play with. We’ll use
+the method from :doc:`/cpp/tutorials/basic_arrow`, so look back
+there if anything’s confusing:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Create Tables)
+  :end-before: (Doc section: Create Tables)
+
+Calculating a Sum over an Array
+-------------------------------
+
+Using a computation function has three general steps, which we separate
+here:
+
+1. Preparing a :class:`Datum` for output
+
+2. Calling :func:`compute::Sum`, a convenience function for summation over an :class:`Array`
+
+3. Retrieving and printing output
+
+Prepare Memory for Output with Datum
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+When computation is done, we need somewhere for our results to go. In
+Arrow, the object for such output is called :class:`Datum`. This object is used
+to pass around inputs and outputs in compute functions, and can contain
+many differently-shaped Arrow data structures. We’ll need it to retrieve
+the output from compute functions.
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Sum Datum Declaration)
+  :end-before: (Doc section: Sum Datum Declaration)
+
+Call Sum()
+^^^^^^^^^^
+
+Here, we’ll get our :class:`Table`, which has columns “A” and “B”, and sum over
+column “A.” For summation, there is a convenience function, called
+:func:`compute::Sum`, which reduces the complexity of the compute interface. We’ll look
+at the more complex version for the next computation. For a given
+function, refer to :doc:`/cpp/api/compute` to see if there is a
+convenience function. :func:`compute::Sum` takes in a given :class:`Array` or :class:`ChunkedArray`
+– here, we use :func:`Table::GetColumnByName` to pass in column A. Then, it outputs to
+a :class:`Datum`. Putting that all together, we get this:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Sum Call)
+  :end-before: (Doc section: Sum Call)
+
+Get Results from Datum
+^^^^^^^^^^^^^^^^^^^^^^
+
+The previous step leaves us with a :class:`Datum` which contains our sum. 
+However, we cannot print it directly – its flexibility in holding +arbitrary Arrow data structures means we have to retrieve our data +carefully. First, to understand what’s in it, we can check which kind of +data structure it is, then what kind of primitive is being held: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Sum Datum Type) + :end-before: (Doc section: Sum Datum Type) + +This should report the :class:`Datum` stores a :class:`Scalar` with a 64-bit integer. Just +to see what the value is, we can print it out like so, which yields +12891: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Sum Contents) + :end-before: (Doc section: Sum Contents) + +Now we’ve used :func:`compute::Sum` and gotten what we want out of it! + +Calculating Element-Wise Array Addition with CallFunction() +----------------------------------------------------------- + +A next layer of complexity uses what :func:`compute::Sum` was helpfully hiding: +:func:`compute::CallFunction`. For this example, we will explore how to use the more +robust :func:`compute::CallFunction` with the “add” compute function. The pattern +remains similar: + +1. Preparing a Datum for output + +2. Calling :func:`compute::CallFunction` with “add” + +3. Retrieving and printing output + +Prepare Memory for Output with Datum +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Once more, we’ll need a Datum for any output we get: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Add Datum Declaration) + :end-before: (Doc section: Add Datum Declaration) + +Use CallFunction() with “add” +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +:func:`compute::CallFunction` takes the name of the desired function as its first +argument, then the data inputs for said function as a vector in its +second argument. Right now, we want an element-wise addition between +columns “A” and “B”. So, we’ll ask for “add,” pass in columns “A and B”, +and output to our :class:`Datum`. Put this all together, and we get: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Add Call) + :end-before: (Doc section: Add Call) + +.. seealso:: :ref:`compute-function-list` for a list of other functions to go with :func:`compute::CallFunction` + +Get Results from Datum +^^^^^^^^^^^^^^^^^^^^^^ + +Again, the :class:`Datum` needs some careful handling. Said handling is much +easier when we know what’s in it. This :class:`Datum` holds a :class:`ChunkedArray` with +32-bit integers, but we can print that to confirm: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Add Datum Type) + :end-before: (Doc section: Add Datum Type) + +Since it’s a :class:`ChunkedArray`, we request that from the :class:`Datum` – :class:`ChunkedArray` +has a :func:`ChunkedArray::ToString` method, so we’ll use that to print out its contents: + +.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Add Contents)
+  :end-before: (Doc section: Add Contents)
+
+The output looks like this::
+
+   Datum kind: ChunkedArray content type: int32
+   [
+     [
+       75376,
+       647,
+       2287,
+       5671,
+       5092
+     ]
+   ]
+
+Now, we’ve used :func:`compute::CallFunction`, instead of a convenience function! This
+enables a much wider range of available computations.
+
+Searching for a Value with CallFunction() and Options
+-----------------------------------------------------
+
+One class of computations remains. :func:`compute::CallFunction` uses a vector for data
+inputs, but computation often needs additional arguments to function. In
+order to supply this, computation functions may be associated with
+structs where their arguments can be defined. You can check a given
+function to see which struct it uses :ref:`here <compute-function-list>`. For this example, we’ll search for a value in column “A” using
+the “index” compute function. This process has four steps, as opposed
+to the three from before:
+
+1. Preparing a :class:`Datum` for output
+
+2. Preparing :class:`compute::IndexOptions`
+
+3. Calling :func:`compute::CallFunction` with “index” and :class:`compute::IndexOptions`
+
+4. Retrieving and printing output
+
+Prepare Memory for Output with Datum
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We’ll need a :class:`Datum` for any output we get:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Index Datum Declare)
+  :end-before: (Doc section: Index Datum Declare)
+
+Configure “index” with IndexOptions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For this exploration, we’ll use the “index” function – this is a
+searching method, which returns the index of an input value. In order to
+pass this input value, we require a :class:`compute::IndexOptions` struct. So, let’s make
+that struct:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: IndexOptions Declare)
+  :end-before: (Doc section: IndexOptions Declare)
+
+A searching function requires a target value. Here, we’ll use
+2223, the third item in column A, and configure our struct accordingly:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: IndexOptions Assign)
+  :end-before: (Doc section: IndexOptions Assign)
+
+Use CallFunction() with “index” and IndexOptions
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To actually run the function, we use :func:`compute::CallFunction` again, this time
+passing our IndexOptions struct by reference as a third argument. As
+before, the first argument is the name of the function, and the second
+our data input:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc
+  :language: cpp
+  :start-after: (Doc section: Index Call)
+  :end-before: (Doc section: Index Call)
+
+Get Results from Datum
+^^^^^^^^^^^^^^^^^^^^^^
+
+One last time, let’s see what our :class:`Datum` has! This will be a :class:`Scalar` with
+a 64-bit integer, and the output will be 2:
+
+.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Index Inspection) + :end-before: (Doc section: Index Inspection) + +Ending Program +-------------- + +At the end, we just return :func:`arrow::Status::OK`, so the ``main()`` knows that +we’re done, and that everything’s okay, just like the preceding +tutorials. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Ret) + :end-before: (Doc section: Ret) + +With that, you’ve used compute functions which fall into the three main +types – with and without convenience functions, then with an Options +struct. Now you can process any :class:`Table` you need to, and solve whatever +data problem you have that fits into memory! + +Which means that now we have to see how we can work with +larger-than-memory datasets, via Arrow Datasets in the next article. + +Refer to the below for a copy of the complete code: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/compute_example.cc + :language: cpp + :start-after: (Doc section: Compute Example) + :end-before: (Doc section: Compute Example) + :linenos: + :lineno-match: \ No newline at end of file diff --git a/docs/source/cpp/tutorials/datasets_tutorial.rst b/docs/source/cpp/tutorials/datasets_tutorial.rst new file mode 100644 index 00000000000..285fc24d8d5 --- /dev/null +++ b/docs/source/cpp/tutorials/datasets_tutorial.rst @@ -0,0 +1,453 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. specific language governing permissions and limitations +.. under the License. + +.. default-domain:: cpp +.. highlight:: cpp + +.. cpp:namespace:: arrow + +============== +Arrow Datasets +============== + +Arrow C++ provides the concept and implementation of :class:`Datasets ` to work +with fragmented data, which can be larger-than-memory, be that due to +generating large amounts, reading in from a stream, or having a large +file on disk. In this article, you will: + +1. read a multi-file partitioned dataset and put it into a Table, + +2. write out a partitioned dataset from a Table. + +Pre-requisites +--------------- + +Before continuing, make sure you have: + +1. An Arrow installation, which you can set up here: :doc:`/cpp/build_system` + +2. An understanding of basic Arrow data structures from :doc:`/cpp/tutorials/basic_arrow` + +To witness the differences, it may be useful to have also read the :doc:`/cpp/tutorials/io_tutorial`. However, it is not required. + +Setup +----- + +Before running some computations, we need to fill in a couple gaps: + +1. We need to include necessary headers. + +2. A ``main()`` is needed to glue things together. + +3. We need data on disk to play with. + +Includes +^^^^^^^^ + +Before writing C++ code, we need some includes. 
We'll get ``iostream`` for output, then import Arrow's
+dataset functionality, as well as support for each file type we'll work with in this article:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: Includes)
+  :end-before: (Doc section: Includes)
+
+Main()
+^^^^^^
+
+For our glue, we’ll use the ``main()`` pattern from the previous tutorial on
+data structures:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: Main)
+  :end-before: (Doc section: Main)
+
+Which, like when we used it before, is paired with a ``RunMain()``:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: RunMain)
+  :end-before: (Doc section: RunMain)
+
+Generating Files for Reading
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We need some files to actually play with. In practice, you’ll likely
+have some input for your own application. Here, however, we want to
+explore without the overhead of supplying or finding a dataset, so let’s
+generate some to make this easy to follow. Feel free to read through
+this, but the concepts will be visited properly in this article – just
+copy it in, for now, and realize it ends with a partitioned dataset on
+disk:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: Helper Functions)
+  :end-before: (Doc section: Helper Functions)
+
+In order to actually have these files, make sure the first thing called
+in ``RunMain()`` is our helper function ``PrepareEnv()``, which will get a
+dataset on disk for us to play with:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: PrepareEnv)
+  :end-before: (Doc section: PrepareEnv)
+
+Reading a Partitioned Dataset
+-----------------------------
+
+Reading a Dataset is a distinct task from reading a single file. The
+task takes more work than reading a single file, due to needing to be
+able to parse multiple files and/or folders. This process can be broken
+up into the following steps:
+
+1. Getting a :class:`fs::FileSystem` object for the local FS
+
+2. Creating a :class:`fs::FileSelector` and using it to prepare a :class:`dataset::FileSystemDatasetFactory`
+
+3. Building a :class:`dataset::Dataset` using the :class:`dataset::FileSystemDatasetFactory`
+
+4. Using a :class:`dataset::Scanner` to read into a :class:`Table`
+
+Preparing a FileSystem Object
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In order to begin, we’ll need to be able to interact with the local
+filesystem. In order to do that, we’ll need an :class:`fs::FileSystem` object.
+A :class:`fs::FileSystem` is an abstraction that lets us use the same interface
+regardless of using Amazon S3, Google Cloud Storage, or local disk – and
+we’ll be using local disk. So, let’s declare it:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: FileSystem Declare)
+  :end-before: (Doc section: FileSystem Declare)
+
+For this example, we’ll have our :class:`FileSystem’s <fs::FileSystem>` base path exist in the
+same directory as the executable. :func:`fs::FileSystemFromUriOrPath` lets us get
+a :class:`fs::FileSystem` object for any of the types of supported filesystems.
+Here, though, we’ll just pass our path:
+
+.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: FileSystem Init) + :end-before: (Doc section: FileSystem Init) + +.. seealso:: :class:`fs::FileSystem` for the other supported filesystems. + +Creating a FileSystemDatasetFactory +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A :class:`fs::FileSystem` stores a lot of metadata, but we need to be able to +traverse it and parse that metadata. In Arrow, we use a :class:`FileSelector` to +do so: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: FileSelector Declare) + :end-before: (Doc section: FileSelector Declare) + +This :class:`fs::FileSelector` isn’t able to do anything yet. In order to use it, we +need to configure it – we’ll have it start any selection in +“parquet_dataset,” which is where the environment preparation process +has left us a dataset, and set recursive to true, which allows for +traversal of folders. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: FileSelector Config) + :end-before: (Doc section: FileSelector Config) + +To get a :class:`dataset::Dataset` from a :class:`fs::FileSystem`, we need to prepare a +:class:`dataset::FileSystemDatasetFactory`. This is a long but descriptive name – it’ll +make us a factory to get data from our :class:`fs::FileSystem`. First, we configure +it by filling a :class:`dataset::FileSystemFactoryOptions` struct: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: FileSystemFactoryOptions) + :end-before: (Doc section: FileSystemFactoryOptions) + +There are many file formats, and we have to pick one that will be +expected when actually reading. Parquet is what we have on disk, so of +course we’ll ask for that when reading: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: File Format Setup) + :end-before: (Doc section: File Format Setup) + +After setting up the :class:`fs::FileSystem`, :class:`fs::FileSelector`, options, and file format, +we can make that :class:`dataset::FileSystemDatasetFactory`. This simply requires passing +in everything we’ve prepared and assigning that to a variable: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: FileSystemDatasetFactory Make) + :end-before: (Doc section: FileSystemDatasetFactory Make) + +Build Dataset using Factory +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +With a :class:`dataset::FileSystemDatasetFactory` set up, we can actually build our +:class:`dataset::Dataset` with :func:`dataset::FileSystemDatasetFactory::Finish`, just +like with an :class:`ArrayBuilder` back in the basic tutorial: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: FileSystemDatasetFactory Finish) + :end-before: (Doc section: FileSystemDatasetFactory Finish) + +Now, we have a :class:`dataset::Dataset` object in memory. This does not mean that the +entire dataset is manifested in memory, but that we now have access to +tools that allow us to explore and use the dataset that is on disk. For +example, we can grab the fragments (files) that make up our whole +dataset, and print those out, along with some small info: + +.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Dataset Fragments) + :end-before: (Doc section: Dataset Fragments) + +Move Dataset into Table +^^^^^^^^^^^^^^^^^^^^^^^ + +One way we can do something with :class:`Datasets ` is getting +them into a :class:`Table`, where we can do anything we’ve learned we can do to +:class:`Tables
 <Table>` to that :class:`Table`.
+
+.. seealso:: :doc:`/cpp/streaming_execution` for execution that avoids manifesting the entire dataset in memory.
+
+In order to move a :class:`Dataset’s <dataset::Dataset>` contents into a :class:`Table`,
+we need a :class:`dataset::Scanner`, which scans the data and outputs it to the :class:`Table`.
+First, we get a :class:`dataset::ScannerBuilder` from the :class:`dataset::Dataset`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: Read Scan Builder)
+  :end-before: (Doc section: Read Scan Builder)
+
+Of course, a Builder’s only use is to get us our :class:`dataset::Scanner`, so let’s use
+:func:`dataset::ScannerBuilder::Finish`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: Read Scanner)
+  :end-before: (Doc section: Read Scanner)
+
+Now that we have a tool to move through our :class:`dataset::Dataset`, let’s use it to get
+our :class:`Table`. :func:`dataset::Scanner::ToTable` offers exactly what we’re looking for,
+and we can print the results:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: To Table)
+  :end-before: (Doc section: To Table)
+
+This leaves us with a normal :class:`Table`. Again, to do things with :class:`Datasets <dataset::Dataset>`
+without moving to a :class:`Table`, consider using Acero.
+
+Writing a Dataset to Disk from Table
+------------------------------------
+
+Writing a :class:`dataset::Dataset` is a distinct task from writing a single file. The
+task takes more work than writing a single file, due to needing to be
+able to handle a partitioning scheme across multiple files and
+folders. This process can be broken up into the following steps:
+
+1. Prepare a :class:`TableBatchReader`
+
+2. Create a :class:`dataset::Scanner` to pull data from the :class:`TableBatchReader`
+
+3. Prepare schema, partitioning, and file format options
+
+4. Set up :class:`dataset::FileSystemDatasetWriteOptions` – a struct that configures our writing functions
+
+5. Write the dataset to disk
+
+Prepare Data from Table for Writing
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+We have a :class:`Table`, and we want to get a :class:`dataset::Dataset` on disk. In fact, for the
+sake of exploration, we’ll use a different partitioning scheme for the
+dataset – instead of just breaking into halves like the original
+fragments, we’ll partition based on each row’s value in the “a” column.
+
+To get started on that, let’s get a :class:`TableBatchReader`! This makes it very
+easy to write to a :class:`Dataset`, and can be used elsewhere whenever a :class:`Table`
+needs to be broken into a stream of :class:`RecordBatches <RecordBatch>`. Here, we can just use
+the :class:`TableBatchReader’s <TableBatchReader>` constructor, with our table:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc
+  :language: cpp
+  :start-after: (Doc section: TableBatchReader)
+  :end-before: (Doc section: TableBatchReader)
+
+Create Scanner for Moving Table Data
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+The process for writing a :class:`dataset::Dataset`, once a source of data is available,
+is similar to the reverse of reading it. Before, we used a :class:`dataset::Scanner` in
+order to scan into a :class:`Table` – now, we need one to read out of our
+:class:`TableBatchReader`. 
To get that :class:`dataset::Scanner`, we’ll make a :class:`dataset::ScannerBuilder` +based on our :class:`TableBatchReader`, then use that Builder to build a :class:`dataset::Scanner`: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: WriteScanner) + :end-before: (Doc section: WriteScanner) + +Prepare Schema, Partitioning, and File Format Variables +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Since we want to partition based on the “a” column, we need to declare +that. When defining our partitioning :class:`Schema`, we’ll just have a single +:class:`Field` that contains “a”: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Partition Schema) + :end-before: (Doc section: Partition Schema) + +This :class:`Schema` determines what the key is for partitioning, but we need to +choose the algorithm that’ll do something with this key. We will use +Hive-style again, this time with our schema passed to it as +configuration: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Partition Create) + :end-before: (Doc section: Partition Create) + +Several file formats are available, but Parquet is commonly used with +Arrow, so we’ll write back out to that: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Write Format) + :end-before: (Doc section: Write Format) + +Configure FileSystemDatasetWriteOptions +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In order to write to disk, we need some configuration. We’ll do so via +setting values in a :class:`dataset::FileSystemDatasetWriteOptions` struct. We’ll +initialize it with defaults where possible: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Write Options) + :end-before: (Doc section: Write Options) + +One important step in writing to file is having a :class:`fs::FileSystem` to target. +Luckily, we have one from when we set it up for reading. This is a +simple variable assignment: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Options FS) + :end-before: (Doc section: Options FS) + +Arrow can make the directory, but it does need a name for said +directory, so let’s give it one, call it “write_dataset”: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Options Target) + :end-before: (Doc section: Options Target) + +We made a partitioning method previously, declaring that we’d use +Hive-style – this is where we actually pass that to our writing +function: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Options Partitioning) + :end-before: (Doc section: Options Partitioning) + +Part of what’ll happen is Arrow will break up files, thus preventing +them from being too large to handle. This is what makes a dataset +fragmented in the first place. In order to set this up, we need a base +name for each fragment in a directory – in this case, we’ll have +“part{i}.parquet”, which means the third file (within the same +directory) will be called “part3.parquet”, for example: + +.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Options Name Template) + :end-before: (Doc section: Options Name Template) + +Sometimes, data will be written to the same location more than once, and +overwriting will be accepted. Since we may want to run this application +more than once, we will set Arrow to overwrite existing data – if we +didn’t, Arrow would abort due to seeing existing data after the first +run of this application: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Options File Behavior) + :end-before: (Doc section: Options File Behavior) + +Write Dataset to Disk +^^^^^^^^^^^^^^^^^^^^^ + +Once the :class:`dataset::FileSystemDatasetWriteOptions` has been configured, and a +:class:`dataset::Scanner` is prepared to parse the data, we can pass the Options and +:class:`dataset::Scanner` to the :func:`dataset::FileSystemDataset::Write` to write out to +disk: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Write Dataset) + :end-before: (Doc section: Write Dataset) + +You can review your disk to see that you’ve written a folder containing +subfolders for every value of “a”, which each have Parquet files! + +Ending Program +-------------- + +At the end, we just return :func:`Status::OK`, so the ``main()`` knows that +we’re done, and that everything’s okay, just like the preceding +tutorials. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Ret) + :end-before: (Doc section: Ret) + +With that, you’ve read and written partitioned datasets! This method, +with some configuration, will work for any supported dataset format. For +an example of such a dataset, the NYC Taxi dataset is a well-known +one, which you can find `here `_. +Now you can get larger-than-memory data mapped for use! + +Which means that now we have to be able to process this data without +pulling it all into memory at once. For this, try Acero. + +.. seealso:: :doc:`/cpp/streaming_execution` for more information on Acero. + +Refer to the below for a copy of the complete code: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/dataset_example.cc + :language: cpp + :start-after: (Doc section: Dataset Example) + :end-before: (Doc section: Dataset Example) + :linenos: + :lineno-match: \ No newline at end of file diff --git a/docs/source/cpp/tutorials/io_tutorial.rst b/docs/source/cpp/tutorials/io_tutorial.rst new file mode 100644 index 00000000000..f981c94b83e --- /dev/null +++ b/docs/source/cpp/tutorials/io_tutorial.rst @@ -0,0 +1,404 @@ +.. Licensed to the Apache Software Foundation (ASF) under one +.. or more contributor license agreements. See the NOTICE file +.. distributed with this work for additional information +.. regarding copyright ownership. The ASF licenses this file +.. to you under the Apache License, Version 2.0 (the +.. "License"); you may not use this file except in compliance +.. with the License. You may obtain a copy of the License at + +.. http://www.apache.org/licenses/LICENSE-2.0 + +.. Unless required by applicable law or agreed to in writing, +.. software distributed under the License is distributed on an +.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +.. KIND, either express or implied. See the License for the +.. 
specific language governing permissions and limitations +.. under the License. + +.. default-domain:: cpp +.. highlight:: cpp + +.. cpp:namespace:: arrow + +============== +Arrow File I/O +============== + +Apache Arrow provides file I/O functions to facilitate use of Arrow from +the start to end of an application. In this article, you will: + +1. Read an Arrow file into a :class:`RecordBatch` and write it back out afterwards + +2. Read a CSV file into a :class:`Table` and write it back out afterwards + +3. Read a Parquet file into a :class:`Table` and write it back out afterwards + +Pre-requisites +--------------- + +Before continuing, make sure you have: + +1. An Arrow installation, which you can set up here: :doc:`/cpp/build_system` + +2. An understanding of basic Arrow data structures from :doc:`/cpp/tutorials/basic_arrow` + +3. A directory to run the final application in – this program will generate some files, so be prepared for that. + +Setup +----- + +Before writing out some file I/O, we need to fill in a couple gaps: + +1. We need to include necessary headers. + +2. A ``main()`` is needed to glue things together. + +3. We need files to play with. + +Includes +^^^^^^^^ + +Before writing C++ code, we need some includes. We'll get ``iostream`` for output, then import Arrow's +I/O functionality for each file type we'll work with in this article: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Includes) + :end-before: (Doc section: Includes) + +Main() +^^^^^^ + +For our glue, we’ll use the ``main()`` pattern from the previous tutorial on +data structures: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Main) + :end-before: (Doc section: Main) + +Which, like when we used it before, is paired with a ``RunMain()``: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: RunMain) + :end-before: (Doc section: RunMain) + +Generating Files for Reading +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We need some files to actually play with. In practice, you’ll likely +have some input for your own application. Here, however, we want to +explore doing I/O for the sake of it, so let’s generate some files to make +this easy to follow. To create those, we’ll define a helper function +that we’ll run first. Feel free to read through this, but the concepts +used will be explained later in this article. Note that we’re using the +day/month/year data from the previous tutorial. For now, just copy the +function in: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: GenInitialFile) + :end-before: (Doc section: GenInitialFile) + +To get the files for the rest of your code to function, make sure to +call ``GenInitialFile()`` as the very first line in ``RunMain()`` to initialize +the environment: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Gen Files) + :end-before: (Doc section: Gen Files) + +I/O with Arrow Files +-------------------- + +We’re going to go through this step by step, reading then writing, as +follows: + +1. Reading a file + + a. Open the file + + b. Bind file to :class:`ipc::RecordBatchFileReader` + + c. Read file to :class:`RecordBatch` + +2. Writing a file + + a. 
Get a :class:`io::FileOutputStream` + + b. Write to file from :class:`RecordBatch` + +Opening a File +^^^^^^^^^^^^^^ + +To actually read a file, we need to get some sort of way to point to it. +In Arrow, that means we’re going to get a :class:`io::ReadableFile` object – much +like an :class:`ArrayBuilder` can clear and make new arrays, we can reassign this +to new files, so we’ll use this instance throughout the examples: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: ReadableFile Definition) + :end-before: (Doc section: ReadableFile Definition) + +A :class:`io::ReadableFile` does little alone – we actually have it bind to a file +with :func:`io::ReadableFile::Open`. For +our purposes here, the default arguments suffice: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Arrow ReadableFile Open) + :end-before: (Doc section: Arrow ReadableFile Open) + +Opening an Arrow file Reader +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An :class:`io::ReadableFile` is too generic to offer all functionality to read an Arrow file. +We need to use it to get an :class:`ipc::RecordBatchFileReader` object. This object implements +all the logic needed to read an Arrow file with correct formatting. We get one through +:func:`ipc::RecordBatchFileReader::Open`: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Arrow Read Open) + :end-before: (Doc section: Arrow Read Open) + +Reading an Open Arrow File to RecordBatch +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +We have to use a :class:`RecordBatch` to read an Arrow file, so we’ll get a +:class:`RecordBatch`. Once we have that, we can actually read the file. Arrow +files can have multiple :class:`RecordBatches `, so we must pass an index. This +file only has one, so pass 0: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Arrow Read) + :end-before: (Doc section: Arrow Read) + +Prepare a FileOutputStream +^^^^^^^^^^^^^^^^^^^^^^^^^^ + +For output, we need a :class:`io::FileOutputStream`. Just like our :class:`io::ReadableFile`, +we’ll be reusing this, so be ready for that. We open files the same way +as when reading: + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Arrow Write Open) + :end-before: (Doc section: Arrow Write Open) + +Write Arrow File from RecordBatch +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Now, we grab our :class:`RecordBatch` we read into previously, and use it, along +with our target file, to create a :class:`ipc::RecordBatchWriter`. The +:class:`ipc::RecordBatchWriter` needs two things: + +1. the target file + +2. the :class:`Schema` for our :class:`RecordBatch` (in case we need to write more :class:`RecordBatches ` of the same format.) + +The :class:`Schema` comes from our existing :class:`RecordBatch` and the target file is +the output stream we just created. + +.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc + :language: cpp + :start-after: (Doc section: Arrow Writer) + :end-before: (Doc section: Arrow Writer) + +We can just call :func:`ipc::RecordBatchWriter::WriteRecordBatch` with our :class:`RecordBatch` to fill up our +file: + +.. 
literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+  :language: cpp
+  :start-after: (Doc section: Arrow Write)
+  :end-before: (Doc section: Arrow Write)
+
+For IPC in particular, the writer has to be closed since it anticipates more than one batch may be written. To do that:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+  :language: cpp
+  :start-after: (Doc section: Arrow Close)
+  :end-before: (Doc section: Arrow Close)
+
+Now we’ve read and written an IPC file!
+
+I/O with CSV
+------------
+
+We’re going to go through this step by step, reading then writing, as
+follows:
+
+1. Reading a file
+
+   a. Open the file
+
+   b. Prepare Table
+
+   c. Read File using :class:`csv::TableReader`
+
+2. Writing a file
+
+   a. Get a :class:`io::FileOutputStream`
+
+   b. Write to file from :class:`Table`
+
+Opening a CSV File
+^^^^^^^^^^^^^^^^^^
+
+For a CSV file, we need to open a :class:`io::ReadableFile`, just like an Arrow file,
+and reuse our :class:`io::ReadableFile` object from before to do so:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+  :language: cpp
+  :start-after: (Doc section: CSV Read Open)
+  :end-before: (Doc section: CSV Read Open)
+
+Preparing a Table
+^^^^^^^^^^^^^^^^^
+
+CSV can be read into a :class:`Table`, so declare a pointer to a :class:`Table`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+  :language: cpp
+  :start-after: (Doc section: CSV Table Declare)
+  :end-before: (Doc section: CSV Table Declare)
+
+Read a CSV File to Table
+^^^^^^^^^^^^^^^^^^^^^^^^
+
+The CSV reader has option structs which need to be passed – luckily,
+there are defaults for these which we can pass directly. For reference
+on the other options, go here: :doc:`/cpp/api/formats`. Our file has no
+special delimiters and is small, so we can make our reader with defaults:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+  :language: cpp
+  :start-after: (Doc section: CSV Reader Make)
+  :end-before: (Doc section: CSV Reader Make)
+
+With the CSV reader primed, we can use its :func:`csv::TableReader::Read` method to fill our
+:class:`Table`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+  :language: cpp
+  :start-after: (Doc section: CSV Read)
+  :end-before: (Doc section: CSV Read)
+
+Write a CSV File from Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+CSV writing to :class:`Table` looks exactly like IPC writing to :class:`RecordBatch`,
+except with our :class:`Table`, and using :func:`ipc::RecordBatchWriter::WriteTable` instead of
+:func:`ipc::RecordBatchWriter::WriteRecordBatch`. Note that the same writer class is used --
+we're writing with :func:`ipc::RecordBatchWriter::WriteTable` because we have a :class:`Table`. We’ll target
+a file, use our :class:`Table’s <Table>
+Write a CSV File from Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Writing CSV from a :class:`Table` looks exactly like writing IPC from a
+:class:`RecordBatch`, except with our :class:`Table`, and using
+:func:`ipc::RecordBatchWriter::WriteTable` instead of
+:func:`ipc::RecordBatchWriter::WriteRecordBatch`. Note that the same writer
+class is used -- we’re writing with :func:`ipc::RecordBatchWriter::WriteTable`
+because we have a :class:`Table`. We’ll target a file, use our
+:class:`Table`’s :class:`Schema`, and then write the :class:`Table`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: CSV Write)
+   :end-before: (Doc section: CSV Write)
+
+Now, we’ve read and written a CSV file!
+
+File I/O with Parquet
+---------------------
+
+We’re going to go through this step by step, reading then writing, as
+follows:
+
+1. Reading a file
+
+   a. Open the file
+
+   b. Prepare :class:`parquet::arrow::FileReader`
+
+   c. Read file to :class:`Table`
+
+2. Writing a file
+
+   a. Write :class:`Table` to file
+
+Opening a Parquet File
+^^^^^^^^^^^^^^^^^^^^^^
+
+Once more, this file format, Parquet, needs a :class:`io::ReadableFile`, which
+we already have, and for the :func:`io::ReadableFile::Open` method to be
+called on a file:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: Parquet Read Open)
+   :end-before: (Doc section: Parquet Read Open)
+
+Setting up a Parquet Reader
+^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+As always, we need a Reader to actually read the file. We’ve been getting
+Readers for each file format from the Arrow namespace. This time, we enter the
+Parquet namespace to get the :class:`parquet::arrow::FileReader`:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: Parquet FileReader)
+   :end-before: (Doc section: Parquet FileReader)
+
+Now, to set up our reader, we call :func:`parquet::arrow::OpenFile`. Yes, this
+is necessary even though we already used :func:`io::ReadableFile::Open`. Note
+that we pass our :class:`parquet::arrow::FileReader` by reference, instead of
+receiving it as output:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: Parquet OpenFile)
+   :end-before: (Doc section: Parquet OpenFile)
+
+Reading a Parquet File to Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+With a prepared :class:`parquet::arrow::FileReader` in hand, we can read to a
+:class:`Table` – here, too, we must pass the :class:`Table` by reference
+instead of receiving it as output:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: Parquet Read)
+   :end-before: (Doc section: Parquet Read)
+
+Writing a Parquet File from Table
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+For single-shot writes, writing a Parquet file does not need a writer object.
+Instead, we give it our table, point to the memory pool it will use for any
+necessary memory consumption, tell it where to write, and the chunk size if it
+needs to break up the file at all:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: Parquet Write)
+   :end-before: (Doc section: Parquet Write)
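+
+The Parquet round trip, condensed into one sketch (the file names are
+placeholders, ``parquet/arrow/reader.h`` and ``parquet/arrow/writer.h`` are
+assumed to be included, and the chunk size of 5 is an arbitrary choice for
+illustration):
+
+.. code-block:: cpp
+
+   // Open the file, as with the other formats.
+   std::shared_ptr<arrow::io::ReadableFile> pq_infile;
+   ARROW_ASSIGN_OR_RAISE(pq_infile,
+                         arrow::io::ReadableFile::Open("test_in.parquet"));
+   // The Parquet reader is filled in by reference rather than returned.
+   std::unique_ptr<parquet::arrow::FileReader> pq_reader;
+   ARROW_RETURN_NOT_OK(parquet::arrow::OpenFile(
+       pq_infile, arrow::default_memory_pool(), &pq_reader));
+   // The Table, too, is passed by reference.
+   std::shared_ptr<arrow::Table> pq_table;
+   ARROW_RETURN_NOT_OK(pq_reader->ReadTable(&pq_table));
+   // Single-shot write: no writer object is needed.
+   std::shared_ptr<arrow::io::FileOutputStream> outfile;
+   ARROW_ASSIGN_OR_RAISE(outfile,
+                         arrow::io::FileOutputStream::Open("test_out.parquet"));
+   ARROW_RETURN_NOT_OK(parquet::arrow::WriteTable(
+       *pq_table, arrow::default_memory_pool(), outfile, /*chunk_size=*/5));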
+
+Ending Program
+--------------
+
+At the end, we just return :func:`Status::OK`, so that ``main()`` knows we’re
+done and that everything’s okay – just like in the first tutorial:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: Return)
+   :end-before: (Doc section: Return)
+
+With that, you’ve read and written IPC, CSV, and Parquet in Arrow, and can
+properly load data and write output! Now, we can move into processing data
+with compute functions in the next article.
+
+Refer to the below for a copy of the complete code:
+
+.. literalinclude:: ../../../../cpp/examples/tutorial_examples/file_access_example.cc
+   :language: cpp
+   :start-after: (Doc section: File I/O)
+   :end-before: (Doc section: File I/O)
+   :linenos:
+   :lineno-match:
\ No newline at end of file
diff --git a/docs/source/cpp/user_guide.rst b/docs/source/cpp/user_guide.rst
new file mode 100644
index 00000000000..375ddb5cfb9
--- /dev/null
+++ b/docs/source/cpp/user_guide.rst
@@ -0,0 +1,43 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+.. default-domain:: cpp
+.. highlight:: cpp
+
+User Guide
+==========
+
+.. toctree::
+
+   overview
+   memory
+   arrays
+   datatypes
+   tables
+   compute
+   streaming_execution
+   io
+   ipc
+   orc
+   parquet
+   csv
+   json
+   dataset
+   flight
+   gdb
+   threading
+   env_vars

From afd3c40a42aa330de61860257a8a09b3c2fe93c8 Mon Sep 17 00:00:00 2001
From: James Bourbeau
Date: Wed, 21 Sep 2022 02:52:53 -0500
Subject: [PATCH 111/133] ARROW-16838: [Python] Improve schema inference for
 pandas indexes with extension dtypes (#14080)

Possible fix for https://issues.apache.org/jira/browse/ARROW-16838.

`pd.Index` objects don't have a `.head` method, while `pd.DataFrame`,
`pd.Series`, and `pd.Index` all support indexing with `[:0]` to return an
empty object of the same type.
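
For illustration, a quick REPL-style sketch of the difference (hypothetical
snippet, not part of the change itself):

```python
import pandas as pd

idx = pd.Index([1, 2, 3])
idx[:0]      # fine: returns an empty Index of the same dtype
idx.head(0)  # AttributeError: 'Index' object has no attribute 'head'
```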
Authored-by: James Bourbeau Signed-off-by: Joris Van den Bossche --- python/pyarrow/pandas_compat.py | 4 +++- python/pyarrow/tests/test_schema.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index 689cbca6b71..9fa7a699efb 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -541,7 +541,9 @@ def dataframe_to_types(df, preserve_index, columns=None): if _pandas_api.is_categorical(values): type_ = pa.array(c, from_pandas=True).type elif _pandas_api.is_extension_array_dtype(values): - type_ = pa.array(c.head(0), from_pandas=True).type + empty = c.head(0) if isinstance( + c, _pandas_api.pd.Series) else c[:0] + type_ = pa.array(empty, from_pandas=True).type else: values, type_ = get_datetimetz_type(values, c.dtype, None) type_ = pa.lib._ndarray_to_arrow_type(values, type_) diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py index f26eaaf5fc1..0547d850d36 100644 --- a/python/pyarrow/tests/test_schema.py +++ b/python/pyarrow/tests/test_schema.py @@ -663,7 +663,7 @@ def test_schema_from_pandas(): if Version(pd.__version__) >= Version('1.0.0'): inputs.append(pd.array([1, 2, None], dtype=pd.Int32Dtype())) for data in inputs: - df = pd.DataFrame({'a': data}) + df = pd.DataFrame({'a': data}, index=data) schema = pa.Schema.from_pandas(df) expected = pa.Table.from_pandas(df).schema assert schema == expected From 91ee6dad722ee154d63eea86ce5644e1e658b53b Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 21 Sep 2022 10:11:50 +0200 Subject: [PATCH 112/133] ARROW-17693: [C++] Remove string_view backport (#14177) Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- .github/workflows/cpp.yml | 3 +- LICENSE.txt | 28 - c_glib/arrow-glib/compute.cpp | 2 +- c_glib/arrow-glib/input-stream.cpp | 7 +- c_glib/arrow-glib/scalar.cpp | 5 +- ci/docker/debian-11-cpp.dockerfile | 1 + ci/scripts/java_jni_macos_build.sh | 1 + ci/scripts/java_jni_manylinux_build.sh | 1 + cpp/cmake_modules/ThirdpartyToolchain.cmake | 3 +- cpp/examples/arrow/join_example.cc | 2 +- cpp/examples/arrow/rapidjson_row_converter.cc | 2 +- .../stream_reader_writer.cc | 6 +- cpp/gdb_arrow.py | 21 +- cpp/src/arrow/adapters/orc/util.cc | 10 +- cpp/src/arrow/array/array_binary.h | 18 +- cpp/src/arrow/array/array_binary_test.cc | 12 +- cpp/src/arrow/array/array_dict_test.cc | 16 +- cpp/src/arrow/array/array_list_test.cc | 6 +- cpp/src/arrow/array/array_test.cc | 4 +- cpp/src/arrow/array/builder_base.cc | 2 +- cpp/src/arrow/array/builder_binary.cc | 6 +- cpp/src/arrow/array/builder_binary.h | 28 +- cpp/src/arrow/array/builder_decimal.cc | 4 +- cpp/src/arrow/array/builder_decimal.h | 4 +- cpp/src/arrow/array/builder_dict.cc | 4 +- cpp/src/arrow/array/builder_dict.h | 16 +- cpp/src/arrow/array/dict_internal.h | 2 +- cpp/src/arrow/array/diff.cc | 2 +- cpp/src/arrow/array/validate.cc | 4 +- cpp/src/arrow/buffer.h | 13 +- cpp/src/arrow/buffer_test.cc | 4 +- cpp/src/arrow/builder_benchmark.cc | 4 +- cpp/src/arrow/c/bridge.cc | 20 +- cpp/src/arrow/c/bridge_test.cc | 2 +- cpp/src/arrow/compute/exec/asof_join_node.cc | 8 +- .../arrow/compute/exec/asof_join_node_test.cc | 28 +- cpp/src/arrow/compute/exec/expression.cc | 8 +- cpp/src/arrow/compute/exec/hash_join_dict.cc | 2 +- .../arrow/compute/exec/hash_join_node_test.cc | 2 +- cpp/src/arrow/compute/exec/subtree_test.cc | 12 +- cpp/src/arrow/compute/exec/test_util.cc | 11 +- cpp/src/arrow/compute/exec/test_util.h | 12 +- 
cpp/src/arrow/compute/exec/tpch_node_test.cc | 42 +- .../arrow/compute/kernels/aggregate_basic.cc | 8 +- .../kernels/aggregate_basic_internal.h | 8 +- .../arrow/compute/kernels/aggregate_test.cc | 4 +- .../arrow/compute/kernels/codegen_internal.h | 24 +- cpp/src/arrow/compute/kernels/common.h | 2 +- .../compute/kernels/copy_data_internal.h | 2 +- .../arrow/compute/kernels/hash_aggregate.cc | 6 +- cpp/src/arrow/compute/kernels/row_encoder.cc | 4 +- cpp/src/arrow/compute/kernels/row_encoder.h | 4 +- .../compute/kernels/scalar_arithmetic_test.cc | 2 +- .../compute/kernels/scalar_cast_string.cc | 8 +- .../arrow/compute/kernels/scalar_compare.cc | 2 +- .../compute/kernels/scalar_compare_test.cc | 4 +- .../arrow/compute/kernels/scalar_if_else.cc | 12 +- .../compute/kernels/scalar_set_lookup_test.cc | 2 +- .../compute/kernels/scalar_string_ascii.cc | 89 +- .../compute/kernels/scalar_string_internal.h | 6 +- .../compute/kernels/scalar_string_utf8.cc | 8 +- .../compute/kernels/scalar_temporal_unary.cc | 4 +- cpp/src/arrow/compute/kernels/vector_hash.cc | 2 +- .../compute/kernels/vector_selection_test.cc | 2 +- cpp/src/arrow/csv/chunker.cc | 12 +- cpp/src/arrow/csv/converter.cc | 12 +- cpp/src/arrow/csv/converter_test.cc | 2 +- cpp/src/arrow/csv/invalid_row.h | 5 +- cpp/src/arrow/csv/parser.cc | 15 +- cpp/src/arrow/csv/parser.h | 10 +- cpp/src/arrow/csv/parser_benchmark.cc | 6 +- cpp/src/arrow/csv/parser_test.cc | 46 +- cpp/src/arrow/csv/reader.cc | 14 +- cpp/src/arrow/csv/test_common.cc | 2 +- cpp/src/arrow/csv/writer.cc | 14 +- cpp/src/arrow/dataset/dataset_writer.cc | 10 +- cpp/src/arrow/dataset/dataset_writer_test.cc | 2 +- cpp/src/arrow/dataset/discovery.cc | 11 +- cpp/src/arrow/dataset/file_csv.cc | 10 +- cpp/src/arrow/dataset/partition.cc | 12 +- cpp/src/arrow/dataset/scanner_test.cc | 5 +- cpp/src/arrow/dataset/test_util.h | 8 +- .../engine/simple_extension_type_internal.h | 20 +- .../engine/substrait/expression_internal.cc | 5 +- cpp/src/arrow/engine/substrait/ext_test.cc | 24 +- .../arrow/engine/substrait/extension_set.cc | 55 +- .../arrow/engine/substrait/extension_set.h | 26 +- .../arrow/engine/substrait/extension_types.cc | 13 +- .../arrow/engine/substrait/extension_types.h | 3 - .../arrow/engine/substrait/plan_internal.cc | 14 +- .../engine/substrait/relation_internal.cc | 6 +- cpp/src/arrow/engine/substrait/serde.cc | 13 +- cpp/src/arrow/engine/substrait/serde.h | 10 +- cpp/src/arrow/engine/substrait/serde_test.cc | 6 +- cpp/src/arrow/engine/substrait/util.cc | 2 +- cpp/src/arrow/filesystem/filesystem.cc | 4 +- cpp/src/arrow/filesystem/gcsfs.cc | 4 +- cpp/src/arrow/filesystem/gcsfs_internal.cc | 2 +- cpp/src/arrow/filesystem/gcsfs_internal.h | 2 +- cpp/src/arrow/filesystem/gcsfs_test.cc | 7 +- cpp/src/arrow/filesystem/localfs.cc | 2 +- cpp/src/arrow/filesystem/localfs_benchmark.cc | 2 +- cpp/src/arrow/filesystem/mockfs.cc | 12 +- cpp/src/arrow/filesystem/mockfs.h | 6 +- cpp/src/arrow/filesystem/path_util.cc | 47 +- cpp/src/arrow/filesystem/path_util.h | 30 +- cpp/src/arrow/filesystem/s3_internal.h | 6 +- cpp/src/arrow/filesystem/s3fs.cc | 4 +- cpp/src/arrow/filesystem/util_internal.cc | 8 +- cpp/src/arrow/filesystem/util_internal.h | 10 +- cpp/src/arrow/flight/cookie_internal.cc | 4 +- cpp/src/arrow/flight/cookie_internal.h | 4 +- cpp/src/arrow/flight/flight_internals_test.cc | 8 +- cpp/src/arrow/flight/flight_test.cc | 8 +- .../integration_tests/test_integration.cc | 4 +- cpp/src/arrow/flight/middleware.h | 4 +- cpp/src/arrow/flight/server.cc | 2 +- 
.../sqlite_tables_schema_batch_reader.cc | 3 +- .../flight/transport/grpc/grpc_client.cc | 27 +- .../flight/transport/grpc/grpc_server.cc | 4 +- .../arrow/flight/transport/ucx/ucx_client.cc | 2 +- .../flight/transport/ucx/ucx_internal.cc | 12 +- .../arrow/flight/transport/ucx/ucx_internal.h | 10 +- .../arrow/flight/transport/ucx/ucx_server.cc | 2 +- cpp/src/arrow/flight/types.cc | 24 +- cpp/src/arrow/flight/types.h | 22 +- cpp/src/arrow/io/buffered.cc | 10 +- cpp/src/arrow/io/buffered.h | 4 +- cpp/src/arrow/io/buffered_test.cc | 8 +- cpp/src/arrow/io/concurrency.h | 12 +- cpp/src/arrow/io/file_test.cc | 4 +- cpp/src/arrow/io/interfaces.cc | 6 +- cpp/src/arrow/io/interfaces.h | 6 +- cpp/src/arrow/io/memory.cc | 8 +- cpp/src/arrow/io/memory.h | 8 +- cpp/src/arrow/io/memory_test.cc | 14 +- cpp/src/arrow/io/slow.cc | 4 +- cpp/src/arrow/io/slow.h | 4 +- cpp/src/arrow/ipc/json_simple.cc | 26 +- cpp/src/arrow/ipc/json_simple.h | 15 +- cpp/src/arrow/ipc/read_write_test.cc | 4 +- cpp/src/arrow/json/chunked_builder_test.cc | 2 +- cpp/src/arrow/json/chunker.cc | 8 +- cpp/src/arrow/json/chunker_test.cc | 12 +- cpp/src/arrow/json/converter.cc | 4 +- cpp/src/arrow/json/object_parser.cc | 4 +- cpp/src/arrow/json/object_parser.h | 4 +- cpp/src/arrow/json/object_writer.cc | 9 +- cpp/src/arrow/json/object_writer.h | 6 +- cpp/src/arrow/json/parser.cc | 6 +- cpp/src/arrow/json/parser_test.cc | 4 +- cpp/src/arrow/json/reader.cc | 4 +- cpp/src/arrow/json/reader_test.cc | 4 +- cpp/src/arrow/json/test_common.h | 8 +- cpp/src/arrow/pretty_print.cc | 12 +- cpp/src/arrow/scalar.cc | 14 +- cpp/src/arrow/scalar.h | 22 +- cpp/src/arrow/scalar_test.cc | 21 +- cpp/src/arrow/stl_iterator_test.cc | 8 +- cpp/src/arrow/testing/gtest_util.cc | 14 +- cpp/src/arrow/testing/gtest_util.h | 28 +- cpp/src/arrow/testing/json_internal.cc | 6 +- cpp/src/arrow/testing/matchers.h | 4 +- cpp/src/arrow/testing/random_test.cc | 2 +- cpp/src/arrow/type.cc | 14 +- cpp/src/arrow/util/base64.h | 6 +- cpp/src/arrow/util/bitmap.h | 5 +- cpp/src/arrow/util/bitmap_reader.h | 1 + cpp/src/arrow/util/bitset_stack.h | 2 +- .../util/{string_view.h => bytes_view.h} | 13 +- cpp/src/arrow/util/decimal.cc | 42 +- cpp/src/arrow/util/decimal.h | 14 +- cpp/src/arrow/util/delimiting.cc | 30 +- cpp/src/arrow/util/delimiting.h | 10 +- cpp/src/arrow/util/formatting.h | 14 +- cpp/src/arrow/util/formatting_util_test.cc | 2 +- cpp/src/arrow/util/hashing.h | 24 +- cpp/src/arrow/util/hashing_test.cc | 2 +- cpp/src/arrow/util/reflection_internal.h | 8 +- cpp/src/arrow/util/reflection_test.cc | 21 +- cpp/src/arrow/util/string.cc | 34 +- cpp/src/arrow/util/string.h | 41 +- cpp/src/arrow/util/string_test.cc | 28 + cpp/src/arrow/util/trie.cc | 6 +- cpp/src/arrow/util/trie.h | 32 +- cpp/src/arrow/util/trie_benchmark.cc | 2 +- cpp/src/arrow/util/trie_test.cc | 12 +- cpp/src/arrow/util/uri.cc | 16 +- cpp/src/arrow/util/uri.h | 8 +- cpp/src/arrow/util/utf8.cc | 2 +- cpp/src/arrow/util/utf8.h | 4 +- cpp/src/arrow/util/utf8_internal.h | 6 +- cpp/src/arrow/util/value_parsing_benchmark.cc | 4 +- cpp/src/arrow/vendored/base64.cpp | 4 +- cpp/src/arrow/vendored/string_view.hpp | 1531 ----------------- cpp/src/arrow/visit_data_inline.h | 19 +- cpp/src/gandiva/gdv_function_stubs.cc | 8 +- cpp/src/gandiva/gdv_string_function_stubs.cc | 6 +- cpp/src/gandiva/in_holder.h | 8 +- cpp/src/parquet/arrow/reader_internal.cc | 2 +- cpp/src/parquet/arrow/schema.cc | 6 +- cpp/src/parquet/encoding.cc | 12 +- cpp/src/parquet/encryption/crypto_factory.cc | 9 +- .../encryption/key_toolkit_internal.cc | 
2 +- cpp/src/parquet/metadata.cc | 6 +- cpp/src/parquet/reader_test.cc | 8 +- cpp/src/parquet/stream_writer.cc | 2 +- cpp/src/parquet/stream_writer.h | 6 +- cpp/src/parquet/types.cc | 2 +- cpp/src/parquet/types.h | 8 +- docs/source/cpp/gdb.rst | 1 - python/pyarrow/src/arrow_to_pandas.cc | 6 +- python/pyarrow/src/common.h | 1 + python/pyarrow/src/datetime.cc | 13 +- python/pyarrow/src/gdb.cc | 10 +- python/pyarrow/tests/test_gdb.py | 14 - r/src/altrep.cpp | 2 +- r/src/array_to_vector.cpp | 6 +- 218 files changed, 1077 insertions(+), 2627 deletions(-) rename cpp/src/arrow/util/{string_view.h => bytes_view.h} (72%) delete mode 100644 cpp/src/arrow/vendored/string_view.hpp diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 8f4702dacb6..2642a6ec1a2 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -146,8 +146,7 @@ jobs: ARROW_WITH_SNAPPY: ON ARROW_WITH_ZLIB: ON ARROW_WITH_ZSTD: ON - # System Abseil installed by Homebrew uses C++ 17 - CMAKE_CXX_STANDARD: 17 + GTest_SOURCE: BUNDLED steps: - name: Checkout Arrow uses: actions/checkout@v3 diff --git a/LICENSE.txt b/LICENSE.txt index 6532b8790c3..86cfaf546ca 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -653,34 +653,6 @@ SOFTWARE. -------------------------------------------------------------------------------- -The file cpp/src/arrow/vendored/string_view.hpp has the following license - -Boost Software License - Version 1.0 - August 17th, 2003 - -Permission is hereby granted, free of charge, to any person or organization -obtaining a copy of the software and accompanying documentation covered by -this license (the "Software") to use, reproduce, display, distribute, -execute, and transmit the Software, and to prepare derivative works of the -Software, and to permit third-parties to whom the Software is furnished to -do so, all subject to the following: - -The copyright notices in the Software and this entire statement, including -the above license grant, this restriction and the following disclaimer, -must be included in all copies of the Software, in whole or in part, and -all derivative works of the Software, unless such copies or derivative -works are solely in the form of machine-executable object code generated by -a source language processor. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT -SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE -FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, -ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER -DEALINGS IN THE SOFTWARE. 
- --------------------------------------------------------------------------------- - The files in cpp/src/arrow/vendored/xxhash/ have the following license (BSD 2-Clause License) diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index 3404af794de..3554fdf1158 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -5121,7 +5121,7 @@ GArrowFunctionOptions * garrow_function_options_new_raw( const arrow::compute::FunctionOptions *arrow_options) { - arrow::util::string_view arrow_type_name(arrow_options->type_name()); + std::string_view arrow_type_name(arrow_options->type_name()); if (arrow_type_name == "CastOptions") { auto arrow_cast_options = static_cast(arrow_options); diff --git a/c_glib/arrow-glib/input-stream.cpp b/c_glib/arrow-glib/input-stream.cpp index e1e46c7df10..844c83d629b 100644 --- a/c_glib/arrow-glib/input-stream.cpp +++ b/c_glib/arrow-glib/input-stream.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include @@ -34,6 +33,7 @@ #include #include +#include G_BEGIN_DECLS @@ -855,7 +855,7 @@ namespace garrow { } } - arrow::Result Peek(int64_t nbytes) override { + arrow::Result Peek(int64_t nbytes) override { if (!G_IS_BUFFERED_INPUT_STREAM(input_stream_)) { std::string message("[gio-input-stream][peek] " "not peekable input stream: <"); @@ -882,8 +882,7 @@ namespace garrow { if (data_size > static_cast(nbytes)) { data_size = nbytes; } - return arrow::util::string_view(static_cast(data), - data_size); + return std::string_view(static_cast(data), data_size); } arrow::Status Seek(int64_t position) override { diff --git a/c_glib/arrow-glib/scalar.cpp b/c_glib/arrow-glib/scalar.cpp index f8699f34eea..ddd35d9ea60 100644 --- a/c_glib/arrow-glib/scalar.cpp +++ b/c_glib/arrow-glib/scalar.cpp @@ -250,9 +250,8 @@ garrow_scalar_parse(GArrowDataType *data_type, GError **error) { const auto arrow_data_type = garrow_data_type_get_raw(data_type); - auto arrow_data = - arrow::util::string_view(reinterpret_cast(data), - size); + auto arrow_data = std::string_view(reinterpret_cast(data), + size); auto arrow_scalar_result = arrow::Scalar::Parse(arrow_data_type, arrow_data); if (garrow::check(error, arrow_scalar_result, "[scalar][parse]")) { auto arrow_scalar = *arrow_scalar_result; diff --git a/ci/docker/debian-11-cpp.dockerfile b/ci/docker/debian-11-cpp.dockerfile index 5051ae7f003..b205f14f6da 100644 --- a/ci/docker/debian-11-cpp.dockerfile +++ b/ci/docker/debian-11-cpp.dockerfile @@ -106,6 +106,7 @@ ENV absl_SOURCE=BUNDLED \ CC=gcc \ CXX=g++ \ google_cloud_cpp_storage_SOURCE=BUNDLED \ + GTest_SOURCE=BUNDLED \ ORC_SOURCE=BUNDLED \ PATH=/usr/lib/ccache/:$PATH \ Protobuf_SOURCE=BUNDLED \ diff --git a/ci/scripts/java_jni_macos_build.sh b/ci/scripts/java_jni_macos_build.sh index 8923b851042..66100e5d0a0 100755 --- a/ci/scripts/java_jni_macos_build.sh +++ b/ci/scripts/java_jni_macos_build.sh @@ -74,6 +74,7 @@ cmake \ -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${install_dir} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ + -DGTest_SOURCE=BUNDLED \ -DPARQUET_BUILD_EXAMPLES=OFF \ -DPARQUET_BUILD_EXECUTABLES=OFF \ -DPARQUET_REQUIRE_ENCRYPTION=OFF \ diff --git a/ci/scripts/java_jni_manylinux_build.sh b/ci/scripts/java_jni_manylinux_build.sh index 7f6e89cb7a3..7cbed90dd2e 100755 --- a/ci/scripts/java_jni_manylinux_build.sh +++ b/ci/scripts/java_jni_manylinux_build.sh @@ -84,6 +84,7 @@ cmake \ -DCMAKE_INSTALL_LIBDIR=lib \ -DCMAKE_INSTALL_PREFIX=${ARROW_HOME} \ -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ + -DGTest_SOURCE=BUNDLED \ 
-DORC_SOURCE=BUNDLED \ -DORC_PROTOBUF_EXECUTABLE=${VCPKG_ROOT}/installed/${VCPKG_TARGET_TRIPLET}/tools/protobuf/protoc \ -DPARQUET_BUILD_EXAMPLES=OFF \ diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 0ab66c5bf07..235369caf2f 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -1974,9 +1974,8 @@ macro(build_gtest) set(dummy ">") set(GTEST_CMAKE_ARGS - ${EP_COMMON_TOOLCHAIN} + ${EP_COMMON_CMAKE_ARGS} -DBUILD_SHARED_LIBS=ON - -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_CXX_FLAGS=${GTEST_CMAKE_CXX_FLAGS} -DCMAKE_CXX_FLAGS_${UPPERCASE_BUILD_TYPE}=${GTEST_CMAKE_CXX_FLAGS} -DCMAKE_INSTALL_LIBDIR=lib diff --git a/cpp/examples/arrow/join_example.cc b/cpp/examples/arrow/join_example.cc index 7bea588e3ad..c29f5e5dbbd 100644 --- a/cpp/examples/arrow/join_example.cc +++ b/cpp/examples/arrow/join_example.cc @@ -63,7 +63,7 @@ arrow::Result> CreateDataSetFromCSVData std::shared_ptr input; std::string csv_data = is_left ? kLeftRelationCsvData : kRightRelationCsvData; std::cout << csv_data << std::endl; - arrow::util::string_view sv = csv_data; + std::string_view sv = csv_data; input = std::make_shared(sv); auto read_options = arrow::csv::ReadOptions::Defaults(); auto parse_options = arrow::csv::ParseOptions::Defaults(); diff --git a/cpp/examples/arrow/rapidjson_row_converter.cc b/cpp/examples/arrow/rapidjson_row_converter.cc index defa6de4610..3907e72121c 100644 --- a/cpp/examples/arrow/rapidjson_row_converter.cc +++ b/cpp/examples/arrow/rapidjson_row_converter.cc @@ -97,7 +97,7 @@ class RowBatchBuilder { for (int64_t i = 0; i < array.length(); ++i) { if (!array.IsNull(i)) { rapidjson::Value str_key(field_->name(), rows_[i].GetAllocator()); - arrow::util::string_view value_view = array.Value(i); + std::string_view value_view = array.Value(i); rapidjson::Value value; value.SetString(value_view.data(), static_cast(value_view.size()), diff --git a/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc b/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc index 64ab7af4962..1f7246b7816 100644 --- a/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc +++ b/cpp/examples/parquet/parquet_stream_api/stream_reader_writer.cc @@ -135,10 +135,10 @@ struct TestData { if (i % 2 == 0) return {}; return "Str #" + std::to_string(i); } - static arrow::util::string_view GetStringView(const int i) { + static std::string_view GetStringView(const int i) { static std::string string; string = "StringView #" + std::to_string(i); - return arrow::util::string_view(string); + return std::string_view(string); } static const char* GetCharPtr(const int i) { static std::string string; @@ -190,7 +190,7 @@ void WriteParquetFile() { os.SetMaxRowGroupSize(1000); for (auto i = 0; i < TestData::num_rows; ++i) { - // Output string using 3 different types: std::string, arrow::util::string_view and + // Output string using 3 different types: std::string, std::string_view and // const char *. switch (i % 3) { case 0: diff --git a/cpp/gdb_arrow.py b/cpp/gdb_arrow.py index 5d7e1719afe..6c3af1680bd 100644 --- a/cpp/gdb_arrow.py +++ b/cpp/gdb_arrow.py @@ -456,7 +456,7 @@ def value(self): class StdString: """ - A `std::string` (or possibly `string_view`) value. + A `std::string` (or possibly `std::string_view`) value. 
""" def __init__(self, val): @@ -2163,23 +2163,6 @@ def to_string(self): return f"arrow::Result<{data_type}>({inner})" -class StringViewPrinter: - """ - Pretty-printer for arrow::util::string_view. - """ - - def __init__(self, name, val): - self.val = val - - def to_string(self): - size = int(self.val['size_']) - if size == 0: - return f"arrow::util::string_view of size 0" - else: - data = bytes_literal(self.val['data_'], size) - return f"arrow::util::string_view of size {size}, {data}" - - class FieldPrinter: """ Pretty-printer for arrow::Field. @@ -2397,8 +2380,6 @@ def to_string(self): "arrow::SimpleTable": TablePrinter, "arrow::Status": StatusPrinter, "arrow::Table": TablePrinter, - "arrow::util::string_view": StringViewPrinter, - "nonstd::sv_lite::basic_string_view": StringViewPrinter, } diff --git a/cpp/src/arrow/adapters/orc/util.cc b/cpp/src/arrow/adapters/orc/util.cc index dbdb110fb46..170aaa18155 100644 --- a/cpp/src/arrow/adapters/orc/util.cc +++ b/cpp/src/arrow/adapters/orc/util.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "arrow/array/builder_base.h" @@ -30,7 +31,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" #include "arrow/util/range.h" -#include "arrow/util/string_view.h" #include "arrow/visit_data_inline.h" #include "orc/Exceptions.hh" @@ -462,7 +462,7 @@ struct Appender { running_arrow_offset++; return Status::OK(); } - Status VisitValue(util::string_view v) { + Status VisitValue(std::string_view v) { batch->notNull[running_orc_offset] = true; COffsetType data_length = 0; batch->data[running_orc_offset] = reinterpret_cast( @@ -486,7 +486,7 @@ struct Appender { running_arrow_offset++; return Status::OK(); } - Status VisitValue(util::string_view v) { + Status VisitValue(std::string_view v) { batch->notNull[running_orc_offset] = true; const Decimal128 dec_value(array.GetValue(running_arrow_offset)); batch->values[running_orc_offset] = static_cast(dec_value.low_bits()); @@ -507,7 +507,7 @@ struct Appender { running_arrow_offset++; return Status::OK(); } - Status VisitValue(util::string_view v) { + Status VisitValue(std::string_view v) { batch->notNull[running_orc_offset] = true; const Decimal128 dec_value(array.GetValue(running_arrow_offset)); batch->values[running_orc_offset] = @@ -557,7 +557,7 @@ struct FixedSizeBinaryAppender { running_arrow_offset++; return Status::OK(); } - Status VisitValue(util::string_view v) { + Status VisitValue(std::string_view v) { batch->notNull[running_orc_offset] = true; batch->data[running_orc_offset] = reinterpret_cast( const_cast(array.GetValue(running_arrow_offset))); diff --git a/cpp/src/arrow/array/array_binary.h b/cpp/src/arrow/array/array_binary.h index cc04d792002..7e58a96ff84 100644 --- a/cpp/src/arrow/array/array_binary.h +++ b/cpp/src/arrow/array/array_binary.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/array/array_base.h" @@ -32,7 +33,6 @@ #include "arrow/type.h" #include "arrow/util/checked_cast.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" // IWYU pragma: export #include "arrow/util/visibility.h" namespace arrow { @@ -67,15 +67,15 @@ class BaseBinaryArray : public FlatArray { /// /// \param i the value index /// \return the view over the selected value - util::string_view GetView(int64_t i) const { + std::string_view GetView(int64_t i) const { // Account for base offset i += data_->offset; const offset_type pos = raw_value_offsets_[i]; - return util::string_view(reinterpret_cast(raw_data_ + pos), - raw_value_offsets_[i + 
1] - pos); + return std::string_view(reinterpret_cast(raw_data_ + pos), + raw_value_offsets_[i + 1] - pos); } - std::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } @@ -84,7 +84,7 @@ class BaseBinaryArray : public FlatArray { /// /// \param i the value index /// \return the view over the selected value - util::string_view Value(int64_t i) const { return GetView(i); } + std::string_view Value(int64_t i) const { return GetView(i); } /// \brief Get binary value as a std::string /// @@ -236,11 +236,11 @@ class ARROW_EXPORT FixedSizeBinaryArray : public PrimitiveArray { const uint8_t* GetValue(int64_t i) const; const uint8_t* Value(int64_t i) const { return GetValue(i); } - util::string_view GetView(int64_t i) const { - return util::string_view(reinterpret_cast(GetValue(i)), byte_width()); + std::string_view GetView(int64_t i) const { + return std::string_view(reinterpret_cast(GetValue(i)), byte_width()); } - std::optional operator[](int64_t i) const { + std::optional operator[](int64_t i) const { return *IteratorType(*this, i); } diff --git a/cpp/src/arrow/array/array_binary_test.cc b/cpp/src/arrow/array/array_binary_test.cc index b7225eb8b7d..3bc9bb91a02 100644 --- a/cpp/src/arrow/array/array_binary_test.cc +++ b/cpp/src/arrow/array/array_binary_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -37,7 +38,6 @@ #include "arrow/util/bit_util.h" #include "arrow/util/bitmap_builders.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/string_view.h" #include "arrow/visit_data_inline.h" namespace arrow { @@ -63,7 +63,7 @@ void CheckStringArray(const ArrayType& array, const std::vector& st auto view = array.GetView(i); ASSERT_EQ(value_pos, array.value_offset(i)); ASSERT_EQ(strings[j].size(), view.size()); - ASSERT_EQ(util::string_view(strings[j]), view); + ASSERT_EQ(std::string_view(strings[j]), view); value_pos += static_cast(view.size()); } else { ASSERT_TRUE(array.IsNull(i)); @@ -256,7 +256,7 @@ class TestStringArray : public ::testing::Test { } Status ValidateFull(int64_t length, std::vector offsets, - util::string_view data, int64_t offset = 0) { + std::string_view data, int64_t offset = 0) { ArrayType arr(length, Buffer::Wrap(offsets), std::make_shared(data), /*null_bitmap=*/nullptr, /*null_count=*/0, offset); return arr.ValidateFull(); @@ -373,7 +373,7 @@ class TestUTF8Array : public ::testing::Test { using ArrayType = typename TypeTraits::ArrayType; Status ValidateUTF8(int64_t length, std::vector offsets, - util::string_view data, int64_t offset = 0) { + std::string_view data, int64_t offset = 0) { ArrayType arr(length, Buffer::Wrap(offsets), std::make_shared(data), /*null_bitmap=*/nullptr, /*null_count=*/0, offset); return arr.ValidateUTF8(); @@ -867,12 +867,12 @@ struct BinaryAppender { return Status::OK(); } - Status VisitValue(util::string_view v) { + Status VisitValue(std::string_view v) { data.push_back(v); return Status::OK(); } - std::vector data; + std::vector data; }; template diff --git a/cpp/src/arrow/array/array_dict_test.cc b/cpp/src/arrow/array/array_dict_test.cc index 9193e1d21ac..bfa732f165f 100644 --- a/cpp/src/arrow/array/array_dict_test.cc +++ b/cpp/src/arrow/array/array_dict_test.cc @@ -711,7 +711,7 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, ArrayInit) { // Build the dictionary Array auto value_type = fixed_size_binary(4); auto dict_array = ArrayFromJSON(value_type, R"(["abcd", "wxyz"])"); - util::string_view test = "abcd", test2 = "wxyz"; + 
std::string_view test = "abcd", test2 = "wxyz"; DictionaryBuilder builder(dict_array); ASSERT_OK(builder.Append(test)); ASSERT_OK(builder.Append(test2)); @@ -735,7 +735,7 @@ TEST(TestFixedSizeBinaryDictionaryBuilder, MakeBuilder) { std::unique_ptr boxed_builder; ASSERT_OK(MakeBuilder(default_memory_pool(), dict_type, &boxed_builder)); auto& builder = checked_cast&>(*boxed_builder); - util::string_view test = "abcd", test2 = "wxyz"; + std::string_view test = "abcd", test2 = "wxyz"; ASSERT_OK(builder.Append(test)); ASSERT_OK(builder.Append(test2)); ASSERT_OK(builder.Append(test)); @@ -1317,12 +1317,12 @@ TEST(TestDictionary, ListOfDictionary) { ASSERT_OK(list_builder->Append()); std::vector expected; - for (char a : util::string_view("abc")) { - for (char d : util::string_view("def")) { - for (char g : util::string_view("ghi")) { - for (char j : util::string_view("jkl")) { - for (char m : util::string_view("mno")) { - for (char p : util::string_view("pqr")) { + for (char a : std::string_view("abc")) { + for (char d : std::string_view("def")) { + for (char g : std::string_view("ghi")) { + for (char j : std::string_view("jkl")) { + for (char m : std::string_view("mno")) { + for (char p : std::string_view("pqr")) { if ((static_cast(a) + d + g + j + m + p) % 16 == 0) { ASSERT_OK(list_builder->Append()); } diff --git a/cpp/src/arrow/array/array_list_test.cc b/cpp/src/arrow/array/array_list_test.cc index 373b71b85f3..f8c24b71e06 100644 --- a/cpp/src/arrow/array/array_list_test.cc +++ b/cpp/src/arrow/array/array_list_test.cc @@ -647,11 +647,11 @@ TEST_F(TestMapArray, Equality) { std::shared_ptr array, equal_array, unequal_array; std::vector equal_offsets = {0, 1, 2, 5, 6, 7, 8, 10}; - std::vector equal_keys = {"a", "a", "a", "b", "c", - "a", "a", "a", "a", "b"}; + std::vector equal_keys = {"a", "a", "a", "b", "c", + "a", "a", "a", "a", "b"}; std::vector equal_values = {1, 2, 3, 4, 5, 2, 2, 2, 5, 6}; std::vector unequal_offsets = {0, 1, 4, 7}; - std::vector unequal_keys = {"a", "a", "b", "c", "a", "b", "c"}; + std::vector unequal_keys = {"a", "a", "b", "c", "a", "b", "c"}; std::vector unequal_values = {1, 2, 2, 2, 3, 4, 5}; // setup two equal arrays diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 9256d4ad0b7..c00e54ecb80 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -2254,12 +2254,12 @@ struct FWBinaryAppender { return Status::OK(); } - Status VisitValue(util::string_view v) { + Status VisitValue(std::string_view v) { data.push_back(v); return Status::OK(); } - std::vector data; + std::vector data; }; TEST_F(TestFWBinaryArray, ArraySpanVisitor) { diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index ff37cee5ba1..e9d5fb44ac1 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -144,7 +144,7 @@ struct AppendScalarImpl { raw++) { auto scalar = checked_cast::ScalarType*>(raw->get()); if (scalar->is_valid) { - builder->UnsafeAppend(util::string_view{*scalar->value}); + builder->UnsafeAppend(std::string_view{*scalar->value}); } else { builder->UnsafeAppendNull(); } diff --git a/cpp/src/arrow/array/builder_binary.cc b/cpp/src/arrow/array/builder_binary.cc index fd1be179816..88c17454034 100644 --- a/cpp/src/arrow/array/builder_binary.cc +++ b/cpp/src/arrow/array/builder_binary.cc @@ -123,10 +123,10 @@ const uint8_t* FixedSizeBinaryBuilder::GetValue(int64_t i) const { return data_ptr + i * byte_width_; } -util::string_view 
FixedSizeBinaryBuilder::GetView(int64_t i) const { +std::string_view FixedSizeBinaryBuilder::GetView(int64_t i) const { const uint8_t* data_ptr = byte_builder_.data(); - return util::string_view(reinterpret_cast(data_ptr + i * byte_width_), - byte_width_); + return std::string_view(reinterpret_cast(data_ptr + i * byte_width_), + byte_width_); } // ---------------------------------------------------------------------- diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index 25cec5c1e25..274baeca748 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -25,6 +25,7 @@ #include #include #include +#include #include #include "arrow/array/array_base.h" @@ -36,7 +37,6 @@ #include "arrow/status.h" #include "arrow/type.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" // IWYU pragma: export #include "arrow/util/visibility.h" namespace arrow { @@ -77,7 +77,7 @@ class BaseBinaryBuilder : public ArrayBuilder { return Append(reinterpret_cast(value), length); } - Status Append(util::string_view value) { + Status Append(std::string_view value) { return Append(value.data(), static_cast(value.size())); } @@ -93,7 +93,7 @@ class BaseBinaryBuilder : public ArrayBuilder { return Status::OK(); } - Status ExtendCurrent(util::string_view value) { + Status ExtendCurrent(std::string_view value) { return ExtendCurrent(reinterpret_cast(value.data()), static_cast(value.size())); } @@ -150,7 +150,7 @@ class BaseBinaryBuilder : public ArrayBuilder { UnsafeAppend(value.c_str(), static_cast(value.size())); } - void UnsafeAppend(util::string_view value) { + void UnsafeAppend(std::string_view value) { UnsafeAppend(value.data(), static_cast(value.size())); } @@ -159,7 +159,7 @@ class BaseBinaryBuilder : public ArrayBuilder { value_data_builder_.UnsafeAppend(value, length); } - void UnsafeExtendCurrent(util::string_view value) { + void UnsafeExtendCurrent(std::string_view value) { UnsafeExtendCurrent(reinterpret_cast(value.data()), static_cast(value.size())); } @@ -370,10 +370,10 @@ class BaseBinaryBuilder : public ArrayBuilder { /// Temporary access to a value. /// /// This view becomes invalid on the next modifying operation. 
- util::string_view GetView(int64_t i) const { + std::string_view GetView(int64_t i) const { offset_type value_length; const uint8_t* value_data = GetValue(i, &value_length); - return util::string_view(reinterpret_cast(value_data), value_length); + return std::string_view(reinterpret_cast(value_data), value_length); } // Cannot make this a static attribute because of linking issues @@ -476,7 +476,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { return Append(reinterpret_cast(value)); } - Status Append(const util::string_view& view) { + Status Append(const std::string_view& view) { ARROW_RETURN_NOT_OK(Reserve(1)); UnsafeAppend(view); return Status::OK(); @@ -490,7 +490,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { Status Append(const Buffer& s) { ARROW_RETURN_NOT_OK(Reserve(1)); - UnsafeAppend(util::string_view(s)); + UnsafeAppend(std::string_view(s)); return Status::OK(); } @@ -500,7 +500,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { Status Append(const std::array& value) { ARROW_RETURN_NOT_OK(Reserve(1)); UnsafeAppend( - util::string_view(reinterpret_cast(value.data()), value.size())); + std::string_view(reinterpret_cast(value.data()), value.size())); return Status::OK(); } @@ -534,14 +534,14 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { UnsafeAppend(reinterpret_cast(value)); } - void UnsafeAppend(util::string_view value) { + void UnsafeAppend(std::string_view value) { #ifndef NDEBUG CheckValueSize(static_cast(value.size())); #endif UnsafeAppend(reinterpret_cast(value.data())); } - void UnsafeAppend(const Buffer& s) { UnsafeAppend(util::string_view(s)); } + void UnsafeAppend(const Buffer& s) { UnsafeAppend(std::string_view(s)); } void UnsafeAppend(const std::shared_ptr& s) { UnsafeAppend(*s); } @@ -590,7 +590,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { /// Temporary access to a value. /// /// This view becomes invalid on the next modifying operation. 
- util::string_view GetView(int64_t i) const; + std::string_view GetView(int64_t i) const; static constexpr int64_t memory_limit() { return std::numeric_limits::max() - 1; @@ -658,7 +658,7 @@ class ARROW_EXPORT ChunkedBinaryBuilder { return builder_->Append(value, length); } - Status Append(const util::string_view& value) { + Status Append(const std::string_view& value) { return Append(reinterpret_cast(value.data()), static_cast(value.size())); } diff --git a/cpp/src/arrow/array/builder_decimal.cc b/cpp/src/arrow/array/builder_decimal.cc index bd7615a7309..96d6b60932b 100644 --- a/cpp/src/arrow/array/builder_decimal.cc +++ b/cpp/src/arrow/array/builder_decimal.cc @@ -52,7 +52,7 @@ void Decimal128Builder::UnsafeAppend(Decimal128 value) { UnsafeAppendToBitmap(true); } -void Decimal128Builder::UnsafeAppend(util::string_view value) { +void Decimal128Builder::UnsafeAppend(std::string_view value) { FixedSizeBinaryBuilder::UnsafeAppend(value); } @@ -87,7 +87,7 @@ void Decimal256Builder::UnsafeAppend(const Decimal256& value) { UnsafeAppendToBitmap(true); } -void Decimal256Builder::UnsafeAppend(util::string_view value) { +void Decimal256Builder::UnsafeAppend(std::string_view value) { FixedSizeBinaryBuilder::UnsafeAppend(value); } diff --git a/cpp/src/arrow/array/builder_decimal.h b/cpp/src/arrow/array/builder_decimal.h index 3464203dd47..2c8953fdec0 100644 --- a/cpp/src/arrow/array/builder_decimal.h +++ b/cpp/src/arrow/array/builder_decimal.h @@ -47,7 +47,7 @@ class ARROW_EXPORT Decimal128Builder : public FixedSizeBinaryBuilder { Status Append(Decimal128 val); void UnsafeAppend(Decimal128 val); - void UnsafeAppend(util::string_view val); + void UnsafeAppend(std::string_view val); Status FinishInternal(std::shared_ptr* out) override; @@ -77,7 +77,7 @@ class ARROW_EXPORT Decimal256Builder : public FixedSizeBinaryBuilder { Status Append(const Decimal256& val); void UnsafeAppend(const Decimal256& val); - void UnsafeAppend(util::string_view val); + void UnsafeAppend(std::string_view val); Status FinishInternal(std::shared_ptr* out) override; diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index d51dd4c041a..061fb600412 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -188,12 +188,12 @@ GET_OR_INSERT(MonthIntervalType); #undef GET_OR_INSERT -Status DictionaryMemoTable::GetOrInsert(const BinaryType*, util::string_view value, +Status DictionaryMemoTable::GetOrInsert(const BinaryType*, std::string_view value, int32_t* out) { return impl_->GetOrInsert(value, out); } -Status DictionaryMemoTable::GetOrInsert(const LargeBinaryType*, util::string_view value, +Status DictionaryMemoTable::GetOrInsert(const LargeBinaryType*, std::string_view value, int32_t* out) { return impl_->GetOrInsert(value, out); } diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index b720f73d7d2..7bb134ec387 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -54,7 +54,7 @@ struct DictionaryValue { template struct DictionaryValue> { - using type = util::string_view; + using type = std::string_view; using PhysicalType = typename std::conditional::value, BinaryType, LargeBinaryType>::type; @@ -62,7 +62,7 @@ struct DictionaryValue> { template struct DictionaryValue> { - using type = util::string_view; + using type = std::string_view; using PhysicalType = BinaryType; }; @@ -112,8 +112,8 @@ class ARROW_EXPORT DictionaryMemoTable { Status GetOrInsert(const FloatType*, float value, int32_t* 
out); Status GetOrInsert(const DoubleType*, double value, int32_t* out); - Status GetOrInsert(const BinaryType*, util::string_view value, int32_t* out); - Status GetOrInsert(const LargeBinaryType*, util::string_view value, int32_t* out); + Status GetOrInsert(const BinaryType*, std::string_view value, int32_t* out); + Status GetOrInsert(const LargeBinaryType*, std::string_view value, int32_t* out); class DictionaryMemoTableImpl; std::unique_ptr impl_; @@ -257,13 +257,13 @@ class DictionaryBuilderBase : public ArrayBuilder { /// \brief Append a fixed-width string (only for FixedSizeBinaryType) template enable_if_fixed_size_binary Append(const uint8_t* value) { - return Append(util::string_view(reinterpret_cast(value), byte_width_)); + return Append(std::string_view(reinterpret_cast(value), byte_width_)); } /// \brief Append a fixed-width string (only for FixedSizeBinaryType) template enable_if_fixed_size_binary Append(const char* value) { - return Append(util::string_view(value, byte_width_)); + return Append(std::string_view(value, byte_width_)); } /// \brief Append a string (only for binary types) @@ -275,13 +275,13 @@ class DictionaryBuilderBase : public ArrayBuilder { /// \brief Append a string (only for binary types) template enable_if_binary_like Append(const char* value, int32_t length) { - return Append(util::string_view(value, length)); + return Append(std::string_view(value, length)); } /// \brief Append a string (only for string types) template enable_if_string_like Append(const char* value, int32_t length) { - return Append(util::string_view(value, length)); + return Append(std::string_view(value, length)); } /// \brief Append a decimal (only for Decimal128Type) diff --git a/cpp/src/arrow/array/dict_internal.h b/cpp/src/arrow/array/dict_internal.h index a8b69133cfe..5245c8d0ff3 100644 --- a/cpp/src/arrow/array/dict_internal.h +++ b/cpp/src/arrow/array/dict_internal.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/hashing.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" namespace arrow { namespace internal { diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc index 16f4f9c7638..10802939a73 100644 --- a/cpp/src/arrow/array/diff.cc +++ b/cpp/src/arrow/array/diff.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +44,6 @@ #include "arrow/util/logging.h" #include "arrow/util/range.h" #include "arrow/util/string.h" -#include "arrow/util/string_view.h" #include "arrow/vendored/datetime.h" #include "arrow/visit_type_inline.h" diff --git a/cpp/src/arrow/array/validate.cc b/cpp/src/arrow/array/validate.cc index 05155d64b6a..56470ac74b0 100644 --- a/cpp/src/arrow/array/validate.cc +++ b/cpp/src/arrow/array/validate.cc @@ -54,7 +54,7 @@ struct UTF8DataValidator { int64_t i = 0; return VisitArraySpanInline( data, - [&](util::string_view v) { + [&](std::string_view v) { if (ARROW_PREDICT_FALSE(!util::ValidateUTF8(v))) { return Status::Invalid("Invalid UTF8 sequence at string index ", i); } @@ -675,7 +675,7 @@ struct ValidateArrayImpl { const int32_t precision = type.precision(); return VisitArraySpanInline( data, - [&](util::string_view bytes) { + [&](std::string_view bytes) { DCHECK_EQ(bytes.size(), DecimalType::kByteWidth); CType value(reinterpret_cast(bytes.data())); if (!value.FitsInPrecision(precision)) { diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h index 8be10d282b0..584a33fbdeb 
100644 --- a/cpp/src/arrow/buffer.h +++ b/cpp/src/arrow/buffer.h @@ -21,14 +21,15 @@ #include #include #include +#include #include #include #include "arrow/device.h" #include "arrow/status.h" #include "arrow/type_fwd.h" +#include "arrow/util/bytes_view.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -77,7 +78,7 @@ class ARROW_EXPORT Buffer { /// /// \note The memory viewed by data must not be deallocated in the lifetime of the /// Buffer; temporary rvalue strings must be stored in an lvalue somewhere - explicit Buffer(util::string_view data) + explicit Buffer(std::string_view data) : Buffer(reinterpret_cast(data.data()), static_cast(data.size())) {} @@ -159,10 +160,10 @@ class ARROW_EXPORT Buffer { /// \note Can throw std::bad_alloc if buffer is large std::string ToString() const; - /// \brief View buffer contents as a util::string_view - /// \return util::string_view - explicit operator util::string_view() const { - return util::string_view(reinterpret_cast(data_), size_); + /// \brief View buffer contents as a std::string_view + /// \return std::string_view + explicit operator std::string_view() const { + return std::string_view(reinterpret_cast(data_), size_); } /// \brief View buffer contents as a util::bytes_view diff --git a/cpp/src/arrow/buffer_test.cc b/cpp/src/arrow/buffer_test.cc index 724db80eba7..fd159dd9797 100644 --- a/cpp/src/arrow/buffer_test.cc +++ b/cpp/src/arrow/buffer_test.cc @@ -204,8 +204,8 @@ Result> MyMemoryManager::ViewBufferTo( } // Like AssertBufferEqual, but doesn't call Buffer::data() -void AssertMyBufferEqual(const Buffer& buffer, util::string_view expected) { - ASSERT_EQ(util::string_view(buffer), expected); +void AssertMyBufferEqual(const Buffer& buffer, std::string_view expected) { + ASSERT_EQ(std::string_view(buffer), expected); } void AssertIsCPUBuffer(const Buffer& buf) { diff --git a/cpp/src/arrow/builder_benchmark.cc b/cpp/src/arrow/builder_benchmark.cc index 97745d4692e..cf3e7f32d5e 100644 --- a/cpp/src/arrow/builder_benchmark.cc +++ b/cpp/src/arrow/builder_benchmark.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "benchmark/benchmark.h" @@ -30,7 +31,6 @@ #include "arrow/testing/gtest_util.h" #include "arrow/util/bit_util.h" #include "arrow/util/decimal.h" -#include "arrow/util/string_view.h" namespace arrow { @@ -55,7 +55,7 @@ constexpr int64_t kBytesProcessPerRound = kNumberOfElements * sizeof(ValueType); constexpr int64_t kBytesProcessed = kRounds * kBytesProcessPerRound; static const char* kBinaryString = "12345678"; -static arrow::util::string_view kBinaryView(kBinaryString); +static std::string_view kBinaryView(kBinaryString); static void BuildIntArrayNoNulls(benchmark::State& state) { // NOLINT non-const reference for (auto _ : state) { diff --git a/cpp/src/arrow/c/bridge.cc b/cpp/src/arrow/c/bridge.cc index de531dbc607..2a7374fe6f1 100644 --- a/cpp/src/arrow/c/bridge.cc +++ b/cpp/src/arrow/c/bridge.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -40,7 +41,6 @@ #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/small_vector.h" -#include "arrow/util/string_view.h" #include "arrow/util/value_parsing.h" #include "arrow/visit_type_inline.h" @@ -666,7 +666,7 @@ namespace { static constexpr int64_t kMaxImportRecursionLevel = 64; -Status InvalidFormatString(util::string_view v) { +Status InvalidFormatString(std::string_view v) { return Status::Invalid("Invalid or unsupported format 
string: '", v, "'"); } @@ -674,13 +674,13 @@ class FormatStringParser { public: FormatStringParser() {} - explicit FormatStringParser(util::string_view v) : view_(v), index_(0) {} + explicit FormatStringParser(std::string_view v) : view_(v), index_(0) {} bool AtEnd() const { return index_ >= view_.length(); } char Next() { return view_[index_++]; } - util::string_view Rest() { return view_.substr(index_); } + std::string_view Rest() { return view_.substr(index_); } Status CheckNext(char c) { if (AtEnd() || Next() != c) { @@ -704,7 +704,7 @@ class FormatStringParser { } template - Result ParseInt(util::string_view v) { + Result ParseInt(std::string_view v) { using ArrowIntType = typename CTypeTraits::ArrowType; IntType value; if (!internal::ParseValue(v.data(), v.size(), &value)) { @@ -729,13 +729,13 @@ class FormatStringParser { } } - SmallVector Split(util::string_view v, char delim = ',') { - SmallVector parts; + SmallVector Split(std::string_view v, char delim = ',') { + SmallVector parts; size_t start = 0, end; while (true) { end = v.find_first_of(delim, start); parts.push_back(v.substr(start, end - start)); - if (end == util::string_view::npos) { + if (end == std::string_view::npos) { break; } start = end + 1; @@ -744,7 +744,7 @@ class FormatStringParser { } template - Result> ParseInts(util::string_view v) { + Result> ParseInts(std::string_view v) { auto parts = Split(v); std::vector result; result.reserve(parts.size()); @@ -758,7 +758,7 @@ class FormatStringParser { Status Invalid() { return InvalidFormatString(view_); } protected: - util::string_view view_; + std::string_view view_; size_t index_; }; diff --git a/cpp/src/arrow/c/bridge_test.cc b/cpp/src/arrow/c/bridge_test.cc index bb722c52b67..a54da82e10c 100644 --- a/cpp/src/arrow/c/bridge_test.cc +++ b/cpp/src/arrow/c/bridge_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -40,7 +41,6 @@ #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" namespace arrow { diff --git a/cpp/src/arrow/compute/exec/asof_join_node.cc b/cpp/src/arrow/compute/exec/asof_join_node.cc index 35e7b1c6cc6..09ef7d722d9 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -37,7 +38,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/future.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" namespace arrow { namespace compute { @@ -1015,15 +1015,15 @@ class AsofJoinNode : public ExecNode { static inline Result FindColIndex(const Schema& schema, const FieldRef& field_ref, - util::string_view key_kind) { + std::string_view key_kind) { auto match_res = field_ref.FindOne(schema); if (!match_res.ok()) { return Status::Invalid("Bad join key on table : ", match_res.status().message()); } ARROW_ASSIGN_OR_RAISE(auto match, match_res); if (match.indices().size() != 1) { - return Status::Invalid("AsOfJoinNode does not support a nested ", - to_string(key_kind), "-key ", field_ref.ToString()); + return Status::Invalid("AsOfJoinNode does not support a nested ", key_kind, "-key ", + field_ref.ToString()); } return match.indices()[0]; } diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 2e4bb06176a..c8dbd27d7b6 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ 
b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "arrow/api.h" @@ -34,7 +35,6 @@ #include "arrow/testing/random.h" #include "arrow/util/checked_cast.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" #include "arrow/util/thread_pool.h" #define TRACED_TEST(t_class, t_name, t_body) \ @@ -69,7 +69,7 @@ bool is_temporal_primitive(Type::type type_id) { Result MakeBatchesFromNumString( const std::shared_ptr& schema, - const std::vector& json_strings, int multiplicity = 1) { + const std::vector& json_strings, int multiplicity = 1) { FieldVector num_fields; for (auto field : schema->fields()) { num_fields.push_back( @@ -413,12 +413,12 @@ struct BasicTestTypes { }; struct BasicTest { - BasicTest(const std::vector& l_data, - const std::vector& r0_data, - const std::vector& r1_data, - const std::vector& exp_nokey_data, - const std::vector& exp_emptykey_data, - const std::vector& exp_data, int64_t tolerance) + BasicTest(const std::vector& l_data, + const std::vector& r0_data, + const std::vector& r1_data, + const std::vector& exp_nokey_data, + const std::vector& exp_emptykey_data, + const std::vector& exp_data, int64_t tolerance) : l_data(std::move(l_data)), r0_data(std::move(r0_data)), r1_data(std::move(r1_data)), @@ -622,12 +622,12 @@ struct BasicTest { exp_emptykey_batches, exp_batches); } - std::vector l_data; - std::vector r0_data; - std::vector r1_data; - std::vector exp_nokey_data; - std::vector exp_emptykey_data; - std::vector exp_data; + std::vector l_data; + std::vector r0_data; + std::vector r1_data; + std::vector exp_nokey_data; + std::vector exp_emptykey_data; + std::vector exp_data; int64_t tolerance; }; diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index d23838303f7..ff59977b671 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -40,6 +40,7 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; +using internal::EndsWith; namespace compute { @@ -117,8 +118,7 @@ std::string PrintDatum(const Datum& datum) { case Type::STRING: case Type::LARGE_STRING: return '"' + - Escape(util::string_view(*datum.scalar_as().value)) + - '"'; + Escape(std::string_view(*datum.scalar_as().value)) + '"'; case Type::BINARY: case Type::FIXED_SIZE_BINARY: @@ -163,8 +163,8 @@ std::string Expression::ToString() const { return binary(Comparison::GetOp(*cmp)); } - constexpr util::string_view kleene = "_kleene"; - if (util::string_view{call->function_name}.ends_with(kleene)) { + constexpr std::string_view kleene = "_kleene"; + if (EndsWith(call->function_name, kleene)) { auto op = call->function_name.substr(0, call->function_name.size() - kleene.size()); return binary(std::move(op)); } diff --git a/cpp/src/arrow/compute/exec/hash_join_dict.cc b/cpp/src/arrow/compute/exec/hash_join_dict.cc index 560b0ea8d4d..4ce89446d3c 100644 --- a/cpp/src/arrow/compute/exec/hash_join_dict.cc +++ b/cpp/src/arrow/compute/exec/hash_join_dict.cc @@ -127,7 +127,7 @@ static Result> ConvertImp( } else { const auto& scalar = input.scalar_as(); if (scalar.is_valid) { - const util::string_view data = scalar.view(); + const std::string_view data = scalar.view(); DCHECK_EQ(data.size(), sizeof(FROM)); const FROM from = *reinterpret_cast(data.data()); const TO to_value = static_cast(from); diff --git a/cpp/src/arrow/compute/exec/hash_join_node_test.cc b/cpp/src/arrow/compute/exec/hash_join_node_test.cc index 
b45af654450..de3592ab086 100644 --- a/cpp/src/arrow/compute/exec/hash_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/hash_join_node_test.cc @@ -42,7 +42,7 @@ namespace compute { BatchesWithSchema GenerateBatchesFromString( const std::shared_ptr& schema, - const std::vector& json_strings, int multiplicity = 1) { + const std::vector& json_strings, int multiplicity = 1) { BatchesWithSchema out_batches{{}, schema}; std::vector types; diff --git a/cpp/src/arrow/compute/exec/subtree_test.cc b/cpp/src/arrow/compute/exec/subtree_test.cc index 9e6e86dbd4f..908af3be7ef 100644 --- a/cpp/src/arrow/compute/exec/subtree_test.cc +++ b/cpp/src/arrow/compute/exec/subtree_test.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -26,9 +27,12 @@ #include "arrow/compute/exec/forest_internal.h" #include "arrow/compute/exec/subtree_internal.h" #include "arrow/testing/gtest_util.h" -#include "arrow/util/string_view.h" +#include "arrow/util/string.h" namespace arrow { + +using internal::StartsWith; + namespace compute { using testing::ContainerEq; @@ -94,18 +98,18 @@ struct TestPathTree { using PT = TestPathTree; -util::string_view RemoveTrailingSlash(util::string_view key) { +std::string_view RemoveTrailingSlash(std::string_view key) { while (!key.empty() && key.back() == '/') { key.remove_suffix(1); } return key; } -bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) { +bool IsAncestorOf(std::string_view ancestor, std::string_view descendant) { // See filesystem/path_util.h ancestor = RemoveTrailingSlash(ancestor); if (ancestor == "") return true; descendant = RemoveTrailingSlash(descendant); - if (!descendant.starts_with(ancestor)) return false; + if (!StartsWith(descendant, ancestor)) return false; descendant.remove_prefix(ancestor.size()); if (descendant.empty()) return true; return descendant.front() == '/'; diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc index 2abe6e9e029..efb91a708ab 100644 --- a/cpp/src/arrow/compute/exec/test_util.cc +++ b/cpp/src/arrow/compute/exec/test_util.cc @@ -142,8 +142,7 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& types, - util::string_view json) { +ExecBatch ExecBatchFromJSON(const std::vector& types, std::string_view json) { auto fields = ::arrow::internal::MapVector( [](const TypeHolder& th) { return field("", th.GetSharedPtr()); }, types); @@ -153,7 +152,7 @@ ExecBatch ExecBatchFromJSON(const std::vector& types, } ExecBatch ExecBatchFromJSON(const std::vector& types, - const std::vector& shapes, util::string_view json) { + const std::vector& shapes, std::string_view json) { DCHECK_EQ(types.size(), shapes.size()); ExecBatch batch = ExecBatchFromJSON(types, json); @@ -235,9 +234,9 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, return out; } -BatchesWithSchema MakeBatchesFromString( - const std::shared_ptr& schema, - const std::vector& json_strings, int multiplicity) { +BatchesWithSchema MakeBatchesFromString(const std::shared_ptr& schema, + const std::vector& json_strings, + int multiplicity) { BatchesWithSchema out_batches{{}, schema}; std::vector types; diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h index 5b6e8226b7e..ae7eac61e95 100644 --- a/cpp/src/arrow/compute/exec/test_util.h +++ b/cpp/src/arrow/compute/exec/test_util.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/compute/exec.h" @@ -31,7 +32,6 @@ #include "arrow/testing/visibility.h" 
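// NOTE: std::string_view::starts_with/ends_with only exist in C++20, so the
// hunks above route through arrow::internal::StartsWith/EndsWith from
// "arrow/util/string.h" instead. A minimal sketch of equivalent C++17
// helpers (an illustration of the idea, not the library's actual source):
#include <string_view>

constexpr bool StartsWith(std::string_view s, std::string_view prefix) {
  return s.length() >= prefix.length() &&
         s.substr(0, prefix.length()) == prefix;
}

constexpr bool EndsWith(std::string_view s, std::string_view suffix) {
  return s.length() >= suffix.length() &&
         s.substr(s.length() - suffix.length()) == suffix;
}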
#include "arrow/util/async_generator.h" #include "arrow/util/pcg_random.h" -#include "arrow/util/string_view.h" namespace arrow { namespace compute { @@ -45,7 +45,7 @@ ExecNode* MakeDummyNode(ExecPlan* plan, std::string label, std::vector& types, util::string_view json); +ExecBatch ExecBatchFromJSON(const std::vector& types, std::string_view json); /// \brief Shape qualifier for value types. In certain instances /// (e.g. "map_lookup" kernel), an argument may only be a scalar, where in @@ -54,7 +54,7 @@ enum class ArgShape { ANY, ARRAY, SCALAR }; ARROW_TESTING_EXPORT ExecBatch ExecBatchFromJSON(const std::vector& types, - const std::vector& shapes, util::string_view json); + const std::vector& shapes, std::string_view json); struct BatchesWithSchema { std::vector batches; @@ -109,9 +109,9 @@ BatchesWithSchema MakeRandomBatches(const std::shared_ptr& schema, int num_batches = 10, int batch_size = 4); ARROW_TESTING_EXPORT -BatchesWithSchema MakeBatchesFromString( - const std::shared_ptr& schema, - const std::vector& json_strings, int multiplicity = 1); +BatchesWithSchema MakeBatchesFromString(const std::shared_ptr& schema, + const std::vector& json_strings, + int multiplicity = 1); ARROW_TESTING_EXPORT Result> SortTableOnAllFields(const std::shared_ptr
& tab); diff --git a/cpp/src/arrow/compute/exec/tpch_node_test.cc b/cpp/src/arrow/compute/exec/tpch_node_test.cc index 133dbfdf43c..dbc5b341d60 100644 --- a/cpp/src/arrow/compute/exec/tpch_node_test.cc +++ b/cpp/src/arrow/compute/exec/tpch_node_test.cc @@ -17,6 +17,11 @@ #include +#include +#include +#include +#include + #include "arrow/compute/exec/options.h" #include "arrow/compute/exec/test_util.h" #include "arrow/compute/exec/tpch_node.h" @@ -29,14 +34,13 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/make_unique.h" #include "arrow/util/pcg_random.h" +#include "arrow/util/string.h" #include "arrow/util/thread_pool.h" -#include -#include -#include -#include - namespace arrow { + +using internal::StartsWith; + namespace compute { namespace internal { @@ -94,10 +98,10 @@ void VerifyUniqueKey(std::unordered_set* seen, const Datum& d, int32_t } } -void VerifyStringAndNumber_Single(const util::string_view& row, - const util::string_view& prefix, const int64_t i, +void VerifyStringAndNumber_Single(const std::string_view& row, + const std::string_view& prefix, const int64_t i, const int32_t* nums, bool verify_padding) { - ASSERT_TRUE(row.starts_with(prefix)) << row << ", prefix=" << prefix << ", i=" << i; + ASSERT_TRUE(StartsWith(row, prefix)) << row << ", prefix=" << prefix << ", i=" << i; const char* num_str = row.data() + prefix.size(); const char* num_str_end = row.data() + row.size(); int64_t num = 0; @@ -124,7 +128,7 @@ void VerifyStringAndNumber_Single(const util::string_view& row, // corresponding row in numbers. Some TPC-H data is padded to 9 zeros, which this function // can optionally verify as well. This string function verifies fixed width columns. void VerifyStringAndNumber_FixedWidth(const Datum& strings, const Datum& numbers, - int byte_width, const util::string_view& prefix, + int byte_width, const std::string_view& prefix, bool verify_padding = true) { int64_t length = strings.length(); const char* str = reinterpret_cast(strings.array()->buffers[1]->data()); @@ -137,14 +141,14 @@ void VerifyStringAndNumber_FixedWidth(const Datum& strings, const Datum& numbers for (int64_t i = 0; i < length; i++) { const char* row = str + i * byte_width; - util::string_view view(row, byte_width); + std::string_view view(row, byte_width); VerifyStringAndNumber_Single(view, prefix, i, nums, verify_padding); } } // Same as above but for variable length columns void VerifyStringAndNumber_Varlen(const Datum& strings, const Datum& numbers, - const util::string_view& prefix, + const std::string_view& prefix, bool verify_padding = true) { int64_t length = strings.length(); const int32_t* offsets = @@ -160,7 +164,7 @@ void VerifyStringAndNumber_Varlen(const Datum& strings, const Datum& numbers, for (int64_t i = 0; i < length; i++) { int32_t start = offsets[i]; int32_t str_len = offsets[i + 1] - offsets[i]; - util::string_view view(str + start, str_len); + std::string_view view(str + start, str_len); VerifyStringAndNumber_Single(view, prefix, i, nums, verify_padding); } } @@ -253,7 +257,7 @@ void VerifyCorrectNumberOfWords_Varlen(const Datum& d, int num_words) { int32_t start = offsets[i]; int32_t end = offsets[i + 1]; int32_t str_len = end - start; - util::string_view view(str + start, str_len); + std::string_view view(str + start, str_len); bool is_only_alphas_or_spaces = true; for (const char& c : view) { bool is_space = c == ' '; @@ -300,14 +304,14 @@ void VerifyOneOf(const Datum& d, const std::unordered_set& possibilities) // Verifies that each fixed-width row is one of the 
possibilities void VerifyOneOf(const Datum& d, int32_t byte_width, - const std::unordered_set& possibilities) { + const std::unordered_set& possibilities) { int64_t length = d.length(); const char* col = reinterpret_cast(d.array()->buffers[1]->data()); for (int64_t i = 0; i < length; i++) { const char* row = col + i * byte_width; int32_t row_len = 0; while (row[row_len] && row_len < byte_width) row_len++; - util::string_view view(row, row_len); + std::string_view view(row, row_len); ASSERT_TRUE(possibilities.find(view) != possibilities.end()) << view << " is not a valid string."; } @@ -331,10 +335,10 @@ void CountModifiedComments(const Datum& d, int* good_count, int* bad_count) { for (int64_t i = 0; i < length; i++) { const char* row = str + offsets[i]; int32_t row_length = offsets[i + 1] - offsets[i]; - util::string_view view(row, row_length); - bool customer = view.find("Customer") != util::string_view::npos; - bool recommends = view.find("Recommends") != util::string_view::npos; - bool complaints = view.find("Complaints") != util::string_view::npos; + std::string_view view(row, row_length); + bool customer = view.find("Customer") != std::string_view::npos; + bool recommends = view.find("Recommends") != std::string_view::npos; + bool complaints = view.find("Complaints") != std::string_view::npos; if (customer) { ASSERT_TRUE(recommends ^ complaints); if (recommends) *good_count += 1; diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic.cc b/cpp/src/arrow/compute/kernels/aggregate_basic.cc index 400ccbdf9f6..ce8b7e867ec 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_basic.cc @@ -233,11 +233,11 @@ void AddCountDistinctKernels(ScalarAggregateFunction* func) { AddCountDistinctKernel(day_time_interval(), func); AddCountDistinctKernel(month_day_nano_interval(), func); // Binary & String - AddCountDistinctKernel(match::BinaryLike(), func); - AddCountDistinctKernel(match::LargeBinaryLike(), - func); + AddCountDistinctKernel(match::BinaryLike(), func); + AddCountDistinctKernel(match::LargeBinaryLike(), + func); // Fixed binary & Decimal - AddCountDistinctKernel( + AddCountDistinctKernel( match::FixedSizeBinaryLike(), func); } diff --git a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h index bd2fe534608..aa89f8dc3b4 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h +++ b/cpp/src/arrow/compute/kernels/aggregate_basic_internal.h @@ -360,7 +360,7 @@ struct MinMaxState> { return *this; } - void MergeOne(util::string_view value) { + void MergeOne(std::string_view value) { MergeOne(T(reinterpret_cast(value.data()))); } @@ -398,14 +398,14 @@ struct MinMaxStatemin = std::string(value); this->max = std::string(value); } else { - if (value < util::string_view(this->min)) { + if (value < std::string_view(this->min)) { this->min = std::string(value); - } else if (value > util::string_view(this->max)) { + } else if (value > std::string_view(this->max)) { this->max = std::string(value); } } diff --git a/cpp/src/arrow/compute/kernels/aggregate_test.cc b/cpp/src/arrow/compute/kernels/aggregate_test.cc index 8f400b2d249..c7ae70e2108 100644 --- a/cpp/src/arrow/compute/kernels/aggregate_test.cc +++ b/cpp/src/arrow/compute/kernels/aggregate_test.cc @@ -942,12 +942,12 @@ class TestCountDistinctKernel : public ::testing::Test { CheckScalar("count_distinct", {input}, Expected(expected_all), &all); } - void Check(const std::shared_ptr& type, util::string_view json, + 
void Check(const std::shared_ptr& type, std::string_view json, int64_t expected_all, bool has_nulls = true) { Check(ArrayFromJSON(type, json), expected_all, has_nulls); } - void Check(const std::shared_ptr& type, util::string_view json) { + void Check(const std::shared_ptr& type, std::string_view json) { auto input = ScalarFromJSON(type, json); auto zero = ResultWith(Expected(0)); auto one = ResultWith(Expected(1)); diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index a20b4ce1476..b0001832174 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include @@ -47,7 +48,6 @@ #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" #include "arrow/visit_data_inline.h" namespace arrow { @@ -136,7 +136,7 @@ struct GetViewType> { template struct GetViewType::value || is_fixed_size_binary_type::value>> { - using T = util::string_view; + using T = std::string_view; using PhysicalType = T; static T LogicalValue(PhysicalType value) { return value; } @@ -145,7 +145,7 @@ struct GetViewType::value || template <> struct GetViewType { using T = Decimal128; - using PhysicalType = util::string_view; + using PhysicalType = std::string_view; static T LogicalValue(PhysicalType value) { return Decimal128(reinterpret_cast(value.data())); @@ -157,7 +157,7 @@ struct GetViewType { template <> struct GetViewType { using T = Decimal256; - using PhysicalType = util::string_view; + using PhysicalType = std::string_view; static T LogicalValue(PhysicalType value) { return Decimal256(reinterpret_cast(value.data())); @@ -271,9 +271,9 @@ struct ArrayIterator> { data(reinterpret_cast(arr.buffers[2].data)), position(0) {} - util::string_view operator()() { + std::string_view operator()() { offset_type next_offset = offsets[++position]; - auto result = util::string_view(data + cur_offset, next_offset - cur_offset); + auto result = std::string_view(data + cur_offset, next_offset - cur_offset); cur_offset = next_offset; return result; } @@ -292,8 +292,8 @@ struct ArrayIterator { width(arr.type->byte_width()), position(arr.offset) {} - util::string_view operator()() { - auto result = util::string_view(data + position * width, width); + std::string_view operator()() { + auto result = std::string_view(data + position * width, width); position++; return result; } @@ -331,7 +331,7 @@ template struct UnboxScalar> { using T = typename Type::c_type; static T Unbox(const Scalar& val) { - util::string_view view = + std::string_view view = checked_cast(val).view(); DCHECK_EQ(view.size(), sizeof(T)); return *reinterpret_cast(view.data()); @@ -340,9 +340,9 @@ struct UnboxScalar> { template struct UnboxScalar> { - using T = util::string_view; + using T = std::string_view; static T Unbox(const Scalar& val) { - if (!val.is_valid) return util::string_view(); + if (!val.is_valid) return std::string_view(); return checked_cast(val).view(); } }; @@ -401,7 +401,7 @@ struct BoxScalar { }; // A VisitArraySpanInline variant that calls its visitor function with logical -// values, such as Decimal128 rather than util::string_view. +// values, such as Decimal128 rather than std::string_view. 
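// The GetViewType/ArrayIterator machinery above hands binary-like values to
// kernels as std::string_view slices over the array's offsets buffer
// (buffers[1]) and data buffer (buffers[2]). A minimal standalone sketch of
// that access pattern over raw buffers (hypothetical helper, illustration
// only):
#include <cstdint>
#include <string_view>
#include <vector>

std::vector<std::string_view> ViewBinaryValues(const int32_t* offsets,
                                               const char* data,
                                               int64_t length) {
  std::vector<std::string_view> views;
  views.reserve(length);
  for (int64_t i = 0; i < length; ++i) {
    // Value i occupies [offsets[i], offsets[i + 1]) within the data buffer.
    views.emplace_back(data + offsets[i], offsets[i + 1] - offsets[i]);
  }
  return views;
}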
template static typename ::arrow::internal::call_traits::enable_if_return::type diff --git a/cpp/src/arrow/compute/kernels/common.h b/cpp/src/arrow/compute/kernels/common.h index 21244320f38..bf90d114512 100644 --- a/cpp/src/arrow/compute/kernels/common.h +++ b/cpp/src/arrow/compute/kernels/common.h @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -42,7 +43,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" // IWYU pragma: end_exports diff --git a/cpp/src/arrow/compute/kernels/copy_data_internal.h b/cpp/src/arrow/compute/kernels/copy_data_internal.h index 2e13563980c..a4083e7e065 100644 --- a/cpp/src/arrow/compute/kernels/copy_data_internal.h +++ b/cpp/src/arrow/compute/kernels/copy_data_internal.h @@ -58,7 +58,7 @@ struct CopyDataUtils { if (!scalar.is_valid) { std::memset(begin, 0x00, width * length); } else { - const util::string_view buffer = scalar.view(); + const std::string_view buffer = scalar.view(); DCHECK_GE(buffer.size(), static_cast(width)); for (int i = 0; i < length; i++) { std::memcpy(begin, buffer.data(), width); diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 068fcab95e4..f947cc732f7 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -1373,7 +1373,7 @@ struct GroupedMinMaxImpl( batch, - [&](uint32_t g, util::string_view val) { + [&](uint32_t g, std::string_view val) { if (!mins_[g] || val < *mins_[g]) { mins_[g].emplace(val.data(), val.size(), allocator_); } @@ -2092,7 +2092,7 @@ struct GroupedOneImpl::value || Status Consume(const ExecSpan& batch) override { return VisitGroupedValues( batch, - [&](uint32_t g, util::string_view val) -> Status { + [&](uint32_t g, std::string_view val) -> Status { if (!bit_util::GetBit(has_one_.data(), g)) { ones_[g].emplace(val.data(), val.size(), allocator_); bit_util::SetBit(has_one_.mutable_data(), g); @@ -2419,7 +2419,7 @@ struct GroupedListImpl::value || num_args_ += num_values; return VisitGroupedValues( batch, - [&](uint32_t group, util::string_view val) -> Status { + [&](uint32_t group, std::string_view val) -> Status { values_.emplace_back(StringType(val.data(), val.size(), allocator_)); return Status::OK(); }, diff --git a/cpp/src/arrow/compute/kernels/row_encoder.cc b/cpp/src/arrow/compute/kernels/row_encoder.cc index 3ab6fc8c337..a38fa1db205 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.cc +++ b/cpp/src/arrow/compute/kernels/row_encoder.cc @@ -145,7 +145,7 @@ Status FixedWidthKeyEncoder::Encode(const ExecValue& data, int64_t batch_length, viewed.type = view_ty.get(); VisitArraySpanInline( viewed, - [&](util::string_view bytes) { + [&](std::string_view bytes) { auto& encoded_ptr = *encoded_bytes++; *encoded_ptr++ = kValidByte; memcpy(encoded_ptr, bytes.data(), byte_width_); @@ -160,7 +160,7 @@ Status FixedWidthKeyEncoder::Encode(const ExecValue& data, int64_t batch_length, } else { const auto& scalar = data.scalar_as(); if (scalar.is_valid) { - const util::string_view data = scalar.view(); + const std::string_view data = scalar.view(); DCHECK_EQ(data.size(), static_cast(byte_width_)); for (int64_t i = 0; i < batch_length; i++) { auto& encoded_ptr = *encoded_bytes++; diff --git a/cpp/src/arrow/compute/kernels/row_encoder.h b/cpp/src/arrow/compute/kernels/row_encoder.h index 139b1be4197..5fe80e0f506 100644 --- a/cpp/src/arrow/compute/kernels/row_encoder.h +++ 
b/cpp/src/arrow/compute/kernels/row_encoder.h @@ -121,7 +121,7 @@ struct VarLengthKeyEncoder : KeyEncoder { int64_t i = 0; VisitArraySpanInline( data.array, - [&](util::string_view bytes) { + [&](std::string_view bytes) { lengths[i++] += kExtraByteForNull + sizeof(Offset) + static_cast(bytes.size()); }, @@ -146,7 +146,7 @@ struct VarLengthKeyEncoder : KeyEncoder { if (data.is_array()) { VisitArraySpanInline( data.array, - [&](util::string_view bytes) { + [&](std::string_view bytes) { auto& encoded_ptr = *encoded_bytes++; *encoded_ptr++ = kValidByte; util::SafeStore(encoded_ptr, static_cast(bytes.size())); diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index be8a445c74a..7b74e8e5d60 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -452,7 +452,7 @@ template std::string MakeArray(Elements... elements) { std::vector elements_as_strings = {std::to_string(elements)...}; - std::vector elements_as_views(sizeof...(Elements)); + std::vector elements_as_views(sizeof...(Elements)); std::copy(elements_as_strings.begin(), elements_as_strings.end(), elements_as_views.begin()); diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc index 7a77b63e37a..4e547ef6ccf 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_string.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_string.cc @@ -57,7 +57,7 @@ struct NumericToStringCastFunctor { RETURN_NOT_OK(VisitArraySpanInline( input, [&](value_type v) { - return formatter(v, [&](util::string_view v) { return builder.Append(v); }); + return formatter(v, [&](std::string_view v) { return builder.Append(v); }); }, [&]() { return builder.AppendNull(); })); @@ -84,7 +84,7 @@ struct TemporalToStringCastFunctor { RETURN_NOT_OK(VisitArraySpanInline( input, [&](value_type v) { - return formatter(v, [&](util::string_view v) { return builder.Append(v); }); + return formatter(v, [&](std::string_view v) { return builder.Append(v); }); }, [&]() { return builder.AppendNull(); })); @@ -126,7 +126,7 @@ struct TemporalToStringCastFunctor { RETURN_NOT_OK(VisitArraySpanInline( input, [&](value_type v) { - return formatter(v, [&](util::string_view v) { return builder.Append(v); }); + return formatter(v, [&](std::string_view v) { return builder.Append(v); }); }, [&]() { builder.UnsafeAppendNull(); @@ -196,7 +196,7 @@ struct TemporalToStringCastFunctor { struct Utf8Validator { Status VisitNull() { return Status::OK(); } - Status VisitValue(util::string_view str) { + Status VisitValue(std::string_view str) { if (ARROW_PREDICT_FALSE(!ValidateUTF8Inline(str))) { return Status::Invalid("Invalid UTF8 payload"); } diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index 290a0e5df66..bbd57988477 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -29,7 +29,7 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -using util::string_view; +using std::string_view; namespace compute { namespace internal { diff --git a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc index 2b834ee2eb3..48fa780b031 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare_test.cc @@ -42,7 +42,7 @@ using 
internal::BitmapReader; namespace compute { -using util::string_view; +using std::string_view; template static void ValidateCompare(CompareOptions options, const Datum& lhs, const Datum& rhs, @@ -136,7 +136,7 @@ Datum SimpleScalarArrayCompare(CompareOptions options, const Datum& const Datum& rhs) { bool swap = lhs.is_array(); auto array = std::static_pointer_cast((swap ? lhs : rhs).make_array()); - auto value = util::string_view( + auto value = std::string_view( *std::static_pointer_cast((swap ? rhs : lhs).scalar())->value); std::vector bitmap(array->length()); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 8c941934a1e..bb3ac6635e0 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -723,7 +723,7 @@ struct IfElseFunctor> { // ASA static Status Call(KernelContext* ctx, const ArraySpan& cond, const Scalar& left, const ArraySpan& right, ExecResult* out) { - util::string_view left_data = internal::UnboxScalar::Unbox(left); + std::string_view left_data = internal::UnboxScalar::Unbox(left); auto left_size = static_cast(left_data.size()); const auto* right_offsets = right.GetValues(1); @@ -754,7 +754,7 @@ struct IfElseFunctor> { const auto* left_offsets = left.GetValues(1); const uint8_t* left_data = left.buffers[2].data; - util::string_view right_data = internal::UnboxScalar::Unbox(right); + std::string_view right_data = internal::UnboxScalar::Unbox(right); auto right_size = static_cast(right_data.size()); // allocate data buffer conservatively @@ -779,10 +779,10 @@ struct IfElseFunctor> { // ASS static Status Call(KernelContext* ctx, const ArraySpan& cond, const Scalar& left, const Scalar& right, ExecResult* out) { - util::string_view left_data = internal::UnboxScalar::Unbox(left); + std::string_view left_data = internal::UnboxScalar::Unbox(left); auto left_size = static_cast(left_data.size()); - util::string_view right_data = internal::UnboxScalar::Unbox(right); + std::string_view right_data = internal::UnboxScalar::Unbox(right); auto right_size = static_cast(right_data.size()); // allocate data buffer conservatively @@ -2314,9 +2314,9 @@ struct CoalesceFunctor> { } RETURN_NOT_OK(builder.ReserveData(static_cast(data_reserve))); - util::string_view fill_value(*scalar.value); + std::string_view fill_value(*scalar.value); VisitArraySpanInline( - left, [&](util::string_view s) { builder.UnsafeAppend(s); }, + left, [&](std::string_view s) { builder.UnsafeAppend(s); }, [&]() { builder.UnsafeAppend(fill_value); }); ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish()); diff --git a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc index 86b7a5597a0..9d8e33b1d04 100644 --- a/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_set_lookup_test.cc @@ -818,7 +818,7 @@ TEST_F(TestIndexInKernel, BinaryResizeTable) { char buf[kBufSize] = "test"; ASSERT_GE(snprintf(buf + 4, sizeof(buf) - 4, "%d", index), 0); - input_builder.UnsafeAppend(util::string_view(buf)); + input_builder.UnsafeAppend(std::string_view(buf)); expected_builder.UnsafeAppend(index); } diff --git a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc index c362cfa8d99..db5eca79d96 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_ascii.cc @@ -26,9 +26,14 @@ #include 
"arrow/array/builder_nested.h" #include "arrow/compute/kernels/scalar_string_internal.h" +#include "arrow/util/string.h" #include "arrow/util/value_parsing.h" namespace arrow { + +using internal::EndsWith; +using internal::StartsWith; + namespace compute { namespace internal { @@ -38,11 +43,11 @@ namespace { // re2 utilities #ifdef ARROW_WITH_RE2 -util::string_view ToStringView(re2::StringPiece piece) { +std::string_view ToStringView(re2::StringPiece piece) { return {piece.data(), piece.length()}; } -re2::StringPiece ToStringPiece(util::string_view view) { +re2::StringPiece ToStringPiece(std::string_view view) { return {view.data(), view.length()}; } @@ -261,7 +266,7 @@ struct StringBinaryTransformExecBase { // Apply transform RETURN_NOT_OK(VisitArraySpanInline( data1, - [&](util::string_view input_string_view) { + [&](std::string_view input_string_view) { auto input_ncodeunits = static_cast(input_string_view.length()); auto input_string = reinterpret_cast(input_string_view.data()); ARROW_ASSIGN_OR_RAISE( @@ -844,7 +849,7 @@ void AddAsciiStringCaseConversion(FunctionRegistry* registry) { // Binary string length struct BinaryLength { - template + template static OutValue Call(KernelContext*, Arg0Value val, Status*) { return static_cast(val.size()); } @@ -1238,7 +1243,7 @@ struct PlainSubstringMatcher { } } - int64_t Find(util::string_view current) const { + int64_t Find(std::string_view current) const { // Phase 2: Find the prefix in the data const auto pattern_length = options_.pattern.size(); int64_t pattern_pos = 0; @@ -1257,7 +1262,7 @@ struct PlainSubstringMatcher { return -1; } - bool Match(util::string_view current) const { return Find(current) >= 0; } + bool Match(std::string_view current) const { return Find(current) >= 0; } }; struct PlainStartsWithMatcher { @@ -1273,9 +1278,8 @@ struct PlainStartsWithMatcher { return ::arrow::internal::make_unique(options); } - bool Match(util::string_view current) const { - // string_view::starts_with is C++20 - return current.substr(0, options_.pattern.size()) == options_.pattern; + bool Match(std::string_view current) const { + return StartsWith(current, options_.pattern); } }; @@ -1292,11 +1296,8 @@ struct PlainEndsWithMatcher { return ::arrow::internal::make_unique(options); } - bool Match(util::string_view current) const { - // string_view::ends_with is C++20 - return current.size() >= options_.pattern.size() && - current.substr(current.size() - options_.pattern.size(), - options_.pattern.size()) == options_.pattern; + bool Match(std::string_view current) const { + return EndsWith(current, options_.pattern); } }; @@ -1319,7 +1320,7 @@ struct RegexSubstringMatcher { regex_match_(options_.pattern, MakeRE2Options(is_utf8, options.ignore_case, literal)) {} - bool Match(util::string_view current) const { + bool Match(std::string_view current) const { auto piece = re2::StringPiece(current.data(), current.length()); return RE2::PartialMatch(piece, regex_match_); } @@ -1341,7 +1342,7 @@ struct MatchSubstringImpl { for (int64_t i = 0; i < length; ++i) { const char* current_data = reinterpret_cast(data + offsets[i]); int64_t current_length = offsets[i + 1] - offsets[i]; - if (matcher->Match(util::string_view(current_data, current_length))) { + if (matcher->Match(std::string_view(current_data, current_length))) { bitmap_writer.Set(); } bitmap_writer.Next(); @@ -1660,7 +1661,7 @@ struct FindSubstring { explicit FindSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {} template - OutValue Call(KernelContext*, util::string_view val, 
Status*) const { + OutValue Call(KernelContext*, std::string_view val, Status*) const { return static_cast(matcher_.Find(val)); } }; @@ -1680,7 +1681,7 @@ struct FindSubstringRegex { } template - OutValue Call(KernelContext*, util::string_view val, Status*) const { + OutValue Call(KernelContext*, std::string_view val, Status*) const { re2::StringPiece piece(val.data(), val.length()); re2::StringPiece match; if (RE2::PartialMatch(piece, *regex_match_, &match)) { @@ -1781,7 +1782,7 @@ struct CountSubstring { explicit CountSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {} template - OutValue Call(KernelContext*, util::string_view val, Status*) const { + OutValue Call(KernelContext*, std::string_view val, Status*) const { OutValue count = 0; uint64_t start = 0; const auto pattern_size = std::max(1, matcher_.options_.pattern.size()); @@ -1815,7 +1816,7 @@ struct CountSubstringRegex { } template - OutValue Call(KernelContext*, util::string_view val, Status*) const { + OutValue Call(KernelContext*, std::string_view val, Status*) const { OutValue count = 0; re2::StringPiece input(val.data(), val.size()); auto last_size = input.size(); @@ -1950,7 +1951,7 @@ struct ReplaceSubstring { RETURN_NOT_OK(VisitArraySpanInline( batch[0].array, - [&](util::string_view s) { + [&](std::string_view s) { RETURN_NOT_OK(replacer.ReplaceString(s, &value_data_builder)); offset_builder.UnsafeAppend( static_cast(value_data_builder.length())); @@ -1979,9 +1980,13 @@ struct PlainSubstringReplacer { explicit PlainSubstringReplacer(const ReplaceSubstringOptions& options) : options_(options) {} - Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) const { - const char* i = s.begin(); - const char* end = s.end(); + Status ReplaceString(std::string_view s, TypedBufferBuilder* builder) const { + if (s.empty()) { + // Special-case empty input as s.data() may not be a valid pointer + return Status::OK(); + } + const char* i = s.data(); + const char* end = s.data() + s.length(); int64_t max_replacements = options_.max_replacements; while ((i < end) && (max_replacements != 0)) { const char* pos = @@ -2040,11 +2045,15 @@ struct RegexSubstringReplacer { regex_find_("(" + options_.pattern + ")", MakeRE2Options()), regex_replacement_(options_.pattern, MakeRE2Options()) {} - Status ReplaceString(util::string_view s, TypedBufferBuilder* builder) const { + Status ReplaceString(std::string_view s, TypedBufferBuilder* builder) const { + if (s.empty()) { + // Special-case empty input as s.data() may not be a valid pointer + return Status::OK(); + } re2::StringPiece replacement(options_.replacement); if (options_.max_replacements == -1) { - std::string s_copy(s.to_string()); + std::string s_copy(s); RE2::GlobalReplace(&s_copy, regex_replacement_, replacement); return builder->Append(reinterpret_cast(s_copy.data()), s_copy.length()); @@ -2053,8 +2062,8 @@ struct RegexSubstringReplacer { // Since RE2 does not have the concept of max_replacements, we have to do some work // ourselves. // We might do this faster similar to RE2::GlobalReplace using Match and Rewrite - const char* i = s.begin(); - const char* end = s.end(); + const char* i = s.data(); + const char* end = s.data() + s.length(); re2::StringPiece piece(s.data(), s.length()); int64_t max_replacements = options_.max_replacements; @@ -2228,7 +2237,7 @@ struct ExtractRegexBase { args_pointers_start = (group_count > 0) ? 
args_pointers.data() : &null_arg; } - bool Match(util::string_view s) { + bool Match(std::string_view s) { return RE2::PartialMatchN(ToStringPiece(s), *data.regex, args_pointers_start, group_count); } @@ -2266,7 +2275,7 @@ struct ExtractRegex : public ExtractRegexBase { } auto visit_null = [&]() { return struct_builder->AppendNull(); }; - auto visit_value = [&](util::string_view s) { + auto visit_value = [&](std::string_view s) { if (Match(s)) { for (int i = 0; i < group_count; i++) { RETURN_NOT_OK(field_builders[i]->Append(ToStringView(found_values[i]))); @@ -2669,17 +2678,17 @@ struct BinaryJoin { }; struct SeparatorScalarLookup { - const util::string_view separator; + const std::string_view separator; bool IsNull(int64_t i) { return false; } - util::string_view GetView(int64_t i) { return separator; } + std::string_view GetView(int64_t i) { return separator; } }; struct SeparatorArrayLookup { const ArrayType& separators; bool IsNull(int64_t i) { return separators.IsNull(i); } - util::string_view GetView(int64_t i) { return separators.GetView(i); } + std::string_view GetView(int64_t i) { return separators.GetView(i); } }; // Scalar, array -> array @@ -2742,7 +2751,7 @@ struct BinaryJoin { return Status::OK(); } - util::string_view separator(*separator_scalar.value); + std::string_view separator(*separator_scalar.value); const auto& strings = checked_cast(*lists.values()); const auto list_offsets = lists.raw_value_offsets(); @@ -2795,7 +2804,7 @@ struct BinaryJoin { const ArrayType& separators; bool IsNull(int64_t i) { return separators.IsNull(i); } - util::string_view GetView(int64_t i) { return separators.GetView(i); } + std::string_view GetView(int64_t i) { return separators.GetView(i); } }; return JoinStrings(lists.length(), strings, ListArrayOffsetLookup{lists}, SeparatorArrayLookup{separators}, &builder, out); @@ -2868,7 +2877,7 @@ struct BinaryJoinElementWise { RETURN_NOT_OK(builder.Reserve(batch.length)); RETURN_NOT_OK(builder.ReserveData(final_size)); - std::vector valid_cols(batch.num_values()); + std::vector valid_cols(batch.num_values()); for (int64_t row = 0; row < batch.length; row++) { int num_valid = 0; // Not counting separator for (int col = 0; col < batch.num_values(); col++) { @@ -2878,7 +2887,7 @@ struct BinaryJoinElementWise { valid_cols[col] = UnboxScalar::Unbox(scalar); if (col < batch.num_values() - 1) num_valid++; } else { - valid_cols[col] = util::string_view(); + valid_cols[col] = std::string_view(); } } else { const ArraySpan& array = batch[col].array; @@ -2887,11 +2896,11 @@ struct BinaryJoinElementWise { const offset_type* offsets = array.GetValues(1); const uint8_t* data = array.GetValues(2, /*absolute_offset=*/0); const int64_t length = offsets[row + 1] - offsets[row]; - valid_cols[col] = util::string_view( + valid_cols[col] = std::string_view( reinterpret_cast(data + offsets[row]), length); if (col < batch.num_values() - 1) num_valid++; } else { - valid_cols[col] = util::string_view(); + valid_cols[col] = std::string_view(); } } } @@ -2914,7 +2923,7 @@ struct BinaryJoinElementWise { const auto separator = valid_cols.back(); bool first = true; for (int col = 0; col < batch.num_values() - 1; col++) { - util::string_view value = valid_cols[col]; + std::string_view value = valid_cols[col]; if (!value.data()) { switch (options.null_handling) { case JoinOptions::EMIT_NULL: diff --git a/cpp/src/arrow/compute/kernels/scalar_string_internal.h b/cpp/src/arrow/compute/kernels/scalar_string_internal.h index 32731414e08..defd7c37157 100644 --- 
a/cpp/src/arrow/compute/kernels/scalar_string_internal.h +++ b/cpp/src/arrow/compute/kernels/scalar_string_internal.h @@ -224,7 +224,7 @@ struct StringPredicateFunctor { ArraySpan* out_arr = out->array_span(); ::arrow::internal::GenerateBitsUnrolled( out_arr->buffers[1].data, out_arr->offset, input.length, [&]() -> bool { - util::string_view val = input_it(); + std::string_view val = input_it(); return Predicate::Call(ctx, reinterpret_cast(val.data()), val.size(), &st); }); @@ -307,7 +307,7 @@ struct StringSplitExec { using State = OptionsWrapper; // Keep the temporary storage accross individual values, to minimize reallocations - std::vector parts; + std::vector parts; Options options; explicit StringSplitExec(const Options& options) : options(options) {} @@ -351,7 +351,7 @@ struct StringSplitExec { return Status::OK(); } - Status SplitString(const util::string_view& s, SplitFinder* finder, + Status SplitString(const std::string_view& s, SplitFinder* finder, BuilderType* builder) { const uint8_t* begin = reinterpret_cast(s.data()); const uint8_t* end = begin + s.length(); diff --git a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc index 4b3191c825d..fb197e13a68 100644 --- a/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc +++ b/cpp/src/arrow/compute/kernels/scalar_string_utf8.cc @@ -524,7 +524,7 @@ struct Utf8NormalizeBase { // Try to decompose the given UTF8 string into the codepoints space, // returning the number of codepoints output. - Result DecomposeIntoScratch(util::string_view v) { + Result DecomposeIntoScratch(std::string_view v) { auto decompose = [&]() { return utf8proc_decompose(reinterpret_cast(v.data()), v.size(), @@ -544,7 +544,7 @@ struct Utf8NormalizeBase { return res; } - Result Decompose(util::string_view v, BufferBuilder* data_builder) { + Result Decompose(std::string_view v, BufferBuilder* data_builder) { if (::arrow::util::ValidateAscii(v)) { // Fast path: normalization is a no-op RETURN_NOT_OK(data_builder->Append(v.data(), v.size())); @@ -623,7 +623,7 @@ struct Utf8NormalizeExec : public Utf8NormalizeBase { RETURN_NOT_OK(VisitArraySpanInline( array, - [&](util::string_view v) { + [&](std::string_view v) { ARROW_ASSIGN_OR_RAISE(auto n_bytes, exec.Decompose(v, &data_builder)); offset += n_bytes; *out_offsets++ = static_cast(offset); @@ -656,7 +656,7 @@ void AddUtf8StringNormalize(FunctionRegistry* registry) { // String length struct Utf8Length { - template + template static OutValue Call(KernelContext*, Arg0Value val, Status*) { auto str = reinterpret_cast(val.data()); auto strlen = val.size(); diff --git a/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc b/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc index d7c045d84b0..c0dc747e497 100644 --- a/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc +++ b/cpp/src/arrow/compute/kernels/scalar_temporal_unary.cc @@ -1265,7 +1265,7 @@ struct Strptime { out_writer.Next(); null_count++; }; - auto visit_value = [&](util::string_view s) { + auto visit_value = [&](std::string_view s) { int64_t result; if ((*self.parser)(s.data(), s.size(), self.unit, &result)) { *out_data++ = result; @@ -1292,7 +1292,7 @@ struct Strptime { *out_data++ = 0; return Status::OK(); }; - auto visit_value = [&](util::string_view s) { + auto visit_value = [&](std::string_view s) { int64_t result; if ((*self.parser)(s.data(), s.size(), self.unit, &result)) { *out_data++ = result; diff --git a/cpp/src/arrow/compute/kernels/vector_hash.cc 
b/cpp/src/arrow/compute/kernels/vector_hash.cc index c8b5173b8d9..c294992d27f 100644 --- a/cpp/src/arrow/compute/kernels/vector_hash.cc +++ b/cpp/src/arrow/compute/kernels/vector_hash.cc @@ -517,7 +517,7 @@ struct HashKernelTraits> { template struct HashKernelTraits> { - using HashKernel = RegularHashKernel; + using HashKernel = RegularHashKernel; }; template diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index f98af93eef3..a58825abdab 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -37,7 +37,7 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -using util::string_view; +using std::string_view; namespace compute { diff --git a/cpp/src/arrow/csv/chunker.cc b/cpp/src/arrow/csv/chunker.cc index dc863579db0..bc1b69cb8ae 100644 --- a/cpp/src/arrow/csv/chunker.cc +++ b/cpp/src/arrow/csv/chunker.cc @@ -20,13 +20,13 @@ #include #include #include +#include #include #include "arrow/csv/lexing_internal.h" #include "arrow/status.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" namespace arrow { namespace csv { @@ -269,7 +269,7 @@ class LexingBoundaryFinder : public BoundaryFinder { explicit LexingBoundaryFinder(ParseOptions options) : options_(std::move(options)), lexer_(options_) {} - Status FindFirst(util::string_view partial, util::string_view block, + Status FindFirst(std::string_view partial, std::string_view block, int64_t* out_pos) override { lexer_.Reset(); if (lexer_.ShouldUseBulkFilter(block.data(), block.data() + block.size())) { @@ -280,7 +280,7 @@ class LexingBoundaryFinder : public BoundaryFinder { } template - Status FindFirstInternal(util::string_view partial, util::string_view block, + Status FindFirstInternal(std::string_view partial, std::string_view block, int64_t* out_pos) { const char* line_end = lexer_.template ReadLine( partial.data(), partial.data() + partial.size()); @@ -298,7 +298,7 @@ class LexingBoundaryFinder : public BoundaryFinder { return Status::OK(); } - Status FindLast(util::string_view block, int64_t* out_pos) override { + Status FindLast(std::string_view block, int64_t* out_pos) override { lexer_.Reset(); if (lexer_.ShouldUseBulkFilter(block.data(), block.data() + block.size())) { return FindLastInternal(block, out_pos); @@ -308,7 +308,7 @@ class LexingBoundaryFinder : public BoundaryFinder { } template - Status FindLastInternal(util::string_view block, int64_t* out_pos) { + Status FindLastInternal(std::string_view block, int64_t* out_pos) { const char* data = block.data(); const char* const data_end = block.data() + block.size(); @@ -331,7 +331,7 @@ class LexingBoundaryFinder : public BoundaryFinder { return Status::OK(); } - Status FindNth(util::string_view partial, util::string_view block, int64_t count, + Status FindNth(std::string_view partial, std::string_view block, int64_t count, int64_t* out_pos, int64_t* num_found) override { lexer_.Reset(); diff --git a/cpp/src/arrow/csv/converter.cc b/cpp/src/arrow/csv/converter.cc index c07eddffd43..a06686954ed 100644 --- a/cpp/src/arrow/csv/converter.cc +++ b/cpp/src/arrow/csv/converter.cc @@ -125,8 +125,8 @@ struct ValueDecoder { if (quoted && !options_.quoted_strings_can_be_null) { return false; } - return null_trie_.Find( - util::string_view(reinterpret_cast(data), size)) >= 0; + return null_trie_.Find(std::string_view(reinterpret_cast(data), size)) >= + 0; } 
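// Earlier hunks in this patch (PlainSubstringReplacer and
// RegexSubstringReplacer in scalar_string_ascii.cc) now special-case empty
// inputs: unlike util::string_view's begin()/end(), the replacement code does
// pointer arithmetic on data(), and a default-constructed std::string_view
// has data() == nullptr. A minimal sketch of that guard (illustrative
// helper, not the patch's code):
#include <cstring>
#include <string_view>

void CopyView(std::string_view s, char* out) {
  // memcpy with a null source pointer is undefined behavior even when the
  // size is zero, so bail out before touching s.data().
  if (s.empty()) return;
  std::memcpy(out, s.data(), s.size());
}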
protected: @@ -166,7 +166,7 @@ struct FixedSizeBinaryValueDecoder : public ValueDecoder { template struct BinaryValueDecoder : public ValueDecoder { - using value_type = util::string_view; + using value_type = std::string_view; using ValueDecoder::ValueDecoder; @@ -252,12 +252,12 @@ struct BooleanValueDecoder : public ValueDecoder { Status Decode(const uint8_t* data, uint32_t size, bool quoted, value_type* out) { // XXX should quoted values be allowed at all? - if (false_trie_.Find(util::string_view(reinterpret_cast(data), size)) >= + if (false_trie_.Find(std::string_view(reinterpret_cast(data), size)) >= 0) { *out = false; return Status::OK(); } - if (ARROW_PREDICT_TRUE(true_trie_.Find(util::string_view( + if (ARROW_PREDICT_TRUE(true_trie_.Find(std::string_view( reinterpret_cast(data), size)) >= 0)) { *out = true; return Status::OK(); @@ -288,7 +288,7 @@ struct DecimalValueDecoder : public ValueDecoder { TrimWhiteSpace(&data, &size); Decimal128 decimal; int32_t precision, scale; - util::string_view view(reinterpret_cast(data), size); + std::string_view view(reinterpret_cast(data), size); RETURN_NOT_OK(Decimal128::FromString(view, &decimal, &precision, &scale)); if (precision > type_precision_) { return Status::Invalid("Error converting '", view, "' to ", type_->ToString(), diff --git a/cpp/src/arrow/csv/converter_test.cc b/cpp/src/arrow/csv/converter_test.cc index c32b07d2de4..ea4e171d57e 100644 --- a/cpp/src/arrow/csv/converter_test.cc +++ b/cpp/src/arrow/csv/converter_test.cc @@ -655,7 +655,7 @@ TEST(TimestampConversion, UserDefinedParsersWithZone) { AssertConversionError(type, {"01/02/1970,1970-01-03T00:00:00+0000\n"}, {0}, options); } -Decimal128 Dec128(util::string_view value) { +Decimal128 Dec128(std::string_view value) { Decimal128 dec; int32_t scale = 0; int32_t precision = 0; diff --git a/cpp/src/arrow/csv/invalid_row.h b/cpp/src/arrow/csv/invalid_row.h index 8a07b568a35..4360ceaaea6 100644 --- a/cpp/src/arrow/csv/invalid_row.h +++ b/cpp/src/arrow/csv/invalid_row.h @@ -18,8 +18,7 @@ #pragma once #include - -#include "arrow/util/string_view.h" +#include namespace arrow { namespace csv { @@ -36,7 +35,7 @@ struct InvalidRow { /// CSV header rows). int64_t number; /// \brief View of the entire row. Memory will be freed after callback returns - const util::string_view text; + const std::string_view text; }; /// \brief Result returned by an InvalidRowHandler diff --git a/cpp/src/arrow/csv/parser.cc b/cpp/src/arrow/csv/parser.cc index 8b060df2540..da3472a9d9a 100644 --- a/cpp/src/arrow/csv/parser.cc +++ b/cpp/src/arrow/csv/parser.cc @@ -212,7 +212,7 @@ class BlockParserImpl { batch_.num_rows_ + batch_.num_skipped_rows(); InvalidRow row{batch_.num_cols_, num_cols, first_row_ < 0 ? 
-1 : first_row_ + batch_row_including_skipped, - util::string_view(start, end - start)}; + std::string_view(start, end - start)}; if (options_.invalid_row_handler && options_.invalid_row_handler(row) == InvalidRowResult::Skip) { @@ -508,7 +508,7 @@ class BlockParserImpl { } template - Status ParseSpecialized(const std::vector& views, bool is_final, + Status ParseSpecialized(const std::vector& views, bool is_final, uint32_t* out_size) { internal::PreferredBulkFilterType bulk_filter(options_); @@ -604,7 +604,7 @@ class BlockParserImpl { return Status::OK(); } - Status Parse(const std::vector& data, bool is_final, + Status Parse(const std::vector& data, bool is_final, uint32_t* out_size) { if (options_.quoting) { if (options_.escaping) { @@ -651,21 +651,20 @@ BlockParser::BlockParser(MemoryPool* pool, ParseOptions options, int32_t num_col BlockParser::~BlockParser() {} -Status BlockParser::Parse(const std::vector& data, - uint32_t* out_size) { +Status BlockParser::Parse(const std::vector& data, uint32_t* out_size) { return impl_->Parse(data, false /* is_final */, out_size); } -Status BlockParser::ParseFinal(const std::vector& data, +Status BlockParser::ParseFinal(const std::vector& data, uint32_t* out_size) { return impl_->Parse(data, true /* is_final */, out_size); } -Status BlockParser::Parse(util::string_view data, uint32_t* out_size) { +Status BlockParser::Parse(std::string_view data, uint32_t* out_size) { return impl_->Parse({data}, false /* is_final */, out_size); } -Status BlockParser::ParseFinal(util::string_view data, uint32_t* out_size) { +Status BlockParser::ParseFinal(std::string_view data, uint32_t* out_size) { return impl_->Parse({data}, true /* is_final */, out_size); } diff --git a/cpp/src/arrow/csv/parser.h b/cpp/src/arrow/csv/parser.h index fb003faaff6..e257d315e30 100644 --- a/cpp/src/arrow/csv/parser.h +++ b/cpp/src/arrow/csv/parser.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "arrow/buffer.h" @@ -28,7 +29,6 @@ #include "arrow/csv/type_fwd.h" #include "arrow/status.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -169,23 +169,23 @@ class ARROW_EXPORT BlockParser { /// /// Parse a block of CSV data, ingesting up to max_num_rows rows. /// The number of bytes actually parsed is returned in out_size. - Status Parse(util::string_view data, uint32_t* out_size); + Status Parse(std::string_view data, uint32_t* out_size); /// \brief Parse sequential blocks of data /// /// Only the last block is allowed to be truncated. - Status Parse(const std::vector& data, uint32_t* out_size); + Status Parse(const std::vector& data, uint32_t* out_size); /// \brief Parse the final block of data /// /// Like Parse(), but called with the final block in a file. /// The last row may lack a trailing line separator. - Status ParseFinal(util::string_view data, uint32_t* out_size); + Status ParseFinal(std::string_view data, uint32_t* out_size); /// \brief Parse the final sequential blocks of data /// /// Only the last block is allowed to be truncated. 
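// With the std::string_view overloads above, callers can hand the parser any
// contiguous byte range without copying. A minimal usage sketch against these
// declarations (assumes an already-constructed BlockParser; the viewed string
// must stay alive for the duration of the call):
#include <cstdint>
#include <string>
#include <string_view>
#include "arrow/csv/parser.h"
#include "arrow/status.h"

arrow::Status ParseWholeCsv(arrow::csv::BlockParser& parser,
                            const std::string& csv) {
  uint32_t parsed_size = 0;
  // ParseFinal() permits the last row to lack a trailing line separator.
  return parser.ParseFinal(std::string_view(csv), &parsed_size);
}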
- Status ParseFinal(const std::vector& data, uint32_t* out_size); + Status ParseFinal(const std::vector& data, uint32_t* out_size); /// \brief Return the number of parsed rows int32_t num_rows() const { return parsed_batch().num_rows(); } diff --git a/cpp/src/arrow/csv/parser_benchmark.cc b/cpp/src/arrow/csv/parser_benchmark.cc index 84495fc542e..203cfa4ea02 100644 --- a/cpp/src/arrow/csv/parser_benchmark.cc +++ b/cpp/src/arrow/csv/parser_benchmark.cc @@ -20,12 +20,12 @@ #include #include #include +#include #include "arrow/csv/chunker.h" #include "arrow/csv/options.h" #include "arrow/csv/parser.h" #include "arrow/testing/gtest_util.h" -#include "arrow/util/string_view.h" namespace arrow { namespace csv { @@ -77,7 +77,7 @@ static std::string BuildCSVData(const Example& example) { static void BenchmarkCSVChunking(benchmark::State& state, // NOLINT non-const reference const std::string& csv, ParseOptions options) { auto chunker = MakeChunker(options); - auto block = std::make_shared(util::string_view(csv)); + auto block = std::make_shared(std::string_view(csv)); while (state.KeepRunning()) { std::shared_ptr whole, partial; @@ -161,7 +161,7 @@ static void BenchmarkCSVParsing(benchmark::State& state, // NOLINT non-const re while (state.KeepRunning()) { uint32_t parsed_size = 0; - ABORT_NOT_OK(parser.Parse(util::string_view(csv), &parsed_size)); + ABORT_NOT_OK(parser.Parse(std::string_view(csv), &parsed_size)); // Include performance of visiting the parsed values, as that might // vary depending on the parser's internal data structures. diff --git a/cpp/src/arrow/csv/parser_test.cc b/cpp/src/arrow/csv/parser_test.cc index 3fb2f11387d..960a69c59db 100644 --- a/cpp/src/arrow/csv/parser_test.cc +++ b/cpp/src/arrow/csv/parser_test.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
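// The parser tests below add a Views() helper that materializes a
// std::vector<std::string_view> over a std::vector<std::string>. The views
// alias the strings' storage, so the source vector must outlive them and must
// not reallocate while they are in use. A usage sketch (Views() as defined in
// the test file below; the block contents here are hypothetical):
#include <cstdint>
#include <string>
#include <string_view>
#include <vector>
#include "arrow/csv/parser.h"

void ParseTwoBlocks(arrow::csv::BlockParser& parser) {
  std::vector<std::string> blocks = {"ab,cd\n", "ef,gh\n"};
  std::vector<std::string_view> views = Views(blocks);  // views point into blocks
  uint32_t parsed_size = 0;
  arrow::Status st = parser.Parse(views, &parsed_size);
  (void)st;  // sketch only: a real caller would check the Status
}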
+#include <string_view>
 #include <utility>
 #include <vector>
 
 #include <gtest/gtest.h>
@@ -120,7 +121,7 @@ void GetLastRow(const BlockParser& parser, std::vector<std::string>* out,
   }
 }
 
-size_t TotalViewLength(const std::vector<util::string_view>& views) {
+size_t TotalViewLength(const std::vector<std::string_view>& views) {
   size_t total_view_length = 0;
   for (const auto& view : views) {
     total_view_length += view.length();
@@ -128,12 +129,19 @@ size_t TotalViewLength(const std::vector<std::string_view>& views) {
   return total_view_length;
 }
 
+std::vector<std::string_view> Views(const std::vector<std::string>& strings) {
+  std::vector<std::string_view> views(strings.size());
+  std::transform(strings.begin(), strings.end(), views.begin(),
+                 [](const std::string& s) { return std::string_view(s); });
+  return views;
+}
+
 Status Parse(BlockParser& parser, const std::string& str, uint32_t* out_size) {
-  return parser.Parse(util::string_view(str), out_size);
+  return parser.Parse(std::string_view(str), out_size);
 }
 
 Status ParseFinal(BlockParser& parser, const std::string& str, uint32_t* out_size) {
-  return parser.ParseFinal(util::string_view(str), out_size);
+  return parser.ParseFinal(std::string_view(str), out_size);
 }
 
 void AssertParseOk(BlockParser& parser, const std::string& str) {
@@ -142,7 +150,7 @@ void AssertParseOk(BlockParser& parser, const std::string& str) {
   ASSERT_EQ(parsed_size, str.size());
 }
 
-void AssertParseOk(BlockParser& parser, const std::vector<util::string_view>& data) {
+void AssertParseOk(BlockParser& parser, const std::vector<std::string_view>& data) {
   uint32_t parsed_size = static_cast<uint32_t>(-1);
   ASSERT_OK(parser.Parse(data, &parsed_size));
   ASSERT_EQ(parsed_size, TotalViewLength(data));
@@ -154,7 +162,7 @@ void AssertParseFinal(BlockParser& parser, const std::string& str) {
   ASSERT_EQ(parsed_size, str.size());
 }
 
-void AssertParseFinal(BlockParser& parser, const std::vector<util::string_view>& data) {
+void AssertParseFinal(BlockParser& parser, const std::vector<std::string_view>& data) {
   uint32_t parsed_size = static_cast<uint32_t>(-1);
   ASSERT_OK(parser.ParseFinal(data, &parsed_size));
   ASSERT_EQ(parsed_size, TotalViewLength(data));
@@ -167,15 +175,16 @@ void AssertParsePartial(BlockParser& parser, const std::string& str,
   ASSERT_EQ(parsed_size, expected_size);
 }
 
-void AssertLastRowEq(const BlockParser& parser, const std::vector<std::string> expected) {
+void AssertLastRowEq(const BlockParser& parser,
+                     const std::vector<std::string>& expected) {
   std::vector<std::string> values;
   GetLastRow(parser, &values);
   ASSERT_EQ(parser.num_rows(), expected.size());
   ASSERT_EQ(values, expected);
 }
 
-void AssertLastRowEq(const BlockParser& parser, const std::vector<std::string> expected,
-                     const std::vector<bool> expected_quoted) {
+void AssertLastRowEq(const BlockParser& parser, const std::vector<std::string>& expected,
+                     const std::vector<bool>& expected_quoted) {
   std::vector<std::string> values;
   std::vector<bool> quoted;
   GetLastRow(parser, &values, &quoted);
@@ -185,7 +194,7 @@ void AssertLastRowEq(const BlockParser& parser, const std::vector<std::string> e
 }
 
 void AssertColumnEq(const BlockParser& parser, int32_t col_index,
-                    const std::vector<std::string> expected) {
+                    const std::vector<std::string>& expected) {
   std::vector<std::string> values;
   GetColumn(parser, col_index, &values);
   ASSERT_EQ(parser.num_rows(), expected.size());
@@ -193,8 +202,8 @@ void AssertColumnEq(const BlockParser& parser, int32_t col_index,
 }
 
 void AssertColumnEq(const BlockParser& parser, int32_t col_index,
-                    const std::vector<std::string> expected,
-                    const std::vector<bool> expected_quoted) {
+                    const std::vector<std::string>& expected,
+                    const std::vector<bool>& expected_quoted) {
   std::vector<std::string> values;
   std::vector<bool> quoted;
   GetColumn(parser, col_index, &values, &quoted);
@@ -204,7 +213,7 @@ void AssertColumnEq(const BlockParser& parser, int32_t col_index,
 }
 
 void AssertColumnsEq(const BlockParser& parser,
-                     const std::vector<std::vector<std::string>> expected) {
+                     const std::vector<std::vector<std::string>>& expected) {
   ASSERT_EQ(parser.num_cols(), expected.size());
   for (int32_t col_index = 0; col_index < parser.num_cols(); ++col_index) {
     AssertColumnEq(parser, col_index, expected[col_index]);
@@ -212,8 +221,8 @@ void AssertColumnsEq(const BlockParser& parser,
 }
 
 void AssertColumnsEq(const BlockParser& parser,
-                     const std::vector<std::vector<std::string>> expected,
-                     const std::vector<std::vector<bool>> quoted) {
+                     const std::vector<std::vector<std::string>>& expected,
+                     const std::vector<std::vector<bool>>& quoted) {
   ASSERT_EQ(parser.num_cols(), expected.size());
   for (int32_t col_index = 0; col_index < parser.num_cols(); ++col_index) {
     AssertColumnEq(parser, col_index, expected[col_index], quoted[col_index]);
@@ -238,9 +247,9 @@ TEST(BlockParser, Basics) {
   {
     auto csv1 = MakeCSVData({"ab,cd,\n", "ef,,gh\n"});
    auto csv2 = MakeCSVData({",ij,kl\n"});
-    std::vector<util::string_view> csvs = {csv1, csv2};
+    std::vector<std::string_view> csvs = {csv1, csv2};
     BlockParser parser(ParseOptions::Defaults());
-    AssertParseOk(parser, {{csv1}, {csv2}});
+    AssertParseOk(parser, csvs);
     AssertColumnsEq(parser, {{"ab", "ef", ""}, {"cd", "", "ij"}, {"", "gh", "kl"}});
     AssertLastRowEq(parser, {"", "ij", "kl"}, {false, false, false});
   }
@@ -392,7 +401,8 @@ TEST(BlockParser, Final) {
     // Two blocks
     auto csv1 = MakeCSVData({"ab,cd\n"});
     auto csv2 = MakeCSVData({"ef,"});
-    AssertParseFinal(parser, {{csv1}, {csv2}});
+    std::vector<std::string_view> csvs = {csv1, csv2};
+    AssertParseFinal(parser, csvs);
     AssertColumnsEq(parser, {{"ab", "ef"}, {"cd", ""}});
   }
 
@@ -596,7 +606,7 @@ TEST(BlockParser, MismatchingNumColumnsHandler) {
   operator InvalidRowHandler() {
     return [this](const InvalidRow& row) {
       // Copy the row to a string since the array behind the string_view can go away
-      rows.emplace_back(row, row.text.to_string());
+      rows.emplace_back(row, row.text);
       return InvalidRowResult::Skip;
     };
   }
diff --git a/cpp/src/arrow/csv/reader.cc b/cpp/src/arrow/csv/reader.cc
index d770fa734f5..fdc7fcb1380 100644
--- a/cpp/src/arrow/csv/reader.cc
+++ b/cpp/src/arrow/csv/reader.cc
@@ -406,7 +406,7 @@ class BlockParsingOperator {
         io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows);
 
     std::shared_ptr<Buffer> straddling;
-    std::vector<util::string_view> views;
+    std::vector<std::string_view> views;
     if (block.partial->size() != 0 || block.completion->size() != 0) {
       if (block.partial->size() == 0) {
         straddling = block.completion;
@@ -417,9 +417,9 @@ class BlockParsingOperator {
             straddling,
             ConcatenateBuffers({block.partial, block.completion}, io_context_.pool()));
       }
-      views = {util::string_view(*straddling), util::string_view(*block.buffer)};
+      views = {std::string_view(*straddling), std::string_view(*block.buffer)};
     } else {
-      views = {util::string_view(*block.buffer)};
+      views = {std::string_view(*block.buffer)};
     }
     uint32_t parsed_size;
     if (block.is_final) {
@@ -588,7 +588,7 @@ class ReaderMixin {
                             num_rows_seen_, 1);
     uint32_t parsed_size = 0;
     RETURN_NOT_OK(parser.Parse(
-        util::string_view(reinterpret_cast<const char*>(data), data_end - data),
+        std::string_view(reinterpret_cast<const char*>(data), data_end - data),
         &parsed_size));
     if (parser.num_rows() != 1) {
       return Status::Invalid(
@@ -718,7 +718,7 @@ class ReaderMixin {
         io_context_.pool(), parse_options_, num_csv_cols_, num_rows_seen_, max_num_rows);
 
     std::shared_ptr<Buffer> straddling;
-    std::vector<util::string_view> views;
+    std::vector<std::string_view> views;
     if (partial->size() != 0 || completion->size() != 0) {
       if (partial->size() == 0) {
         straddling = completion;
@@ -728,9 +728,9 @@ class ReaderMixin {
         ARROW_ASSIGN_OR_RAISE(
             straddling, ConcatenateBuffers({partial, completion}, io_context_.pool()));
       }
-      views = {util::string_view(*straddling), util::string_view(*block)};
+      views = {std::string_view(*straddling), std::string_view(*block)};
     } else {
-      views = {util::string_view(*block)};
+      views = {std::string_view(*block)};
     }
     uint32_t parsed_size;
     if (is_final) {
diff --git a/cpp/src/arrow/csv/test_common.cc b/cpp/src/arrow/csv/test_common.cc
index 6ba4ff2e3cf..648ad18e3c6 100644
--- a/cpp/src/arrow/csv/test_common.cc
+++ b/cpp/src/arrow/csv/test_common.cc
@@ -35,7 +35,7 @@ void MakeCSVParser(std::vector<std::string> lines, ParseOptions options, int32_t
   auto csv = MakeCSVData(lines);
   auto parser = std::make_shared<BlockParser>(options, num_cols);
   uint32_t out_size;
-  ASSERT_OK(parser->Parse(util::string_view(csv), &out_size));
+  ASSERT_OK(parser->Parse(std::string_view(csv), &out_size));
   ASSERT_EQ(out_size, csv.size()) << "trailing CSV data not parsed";
   *out = parser;
 }
diff --git a/cpp/src/arrow/csv/writer.cc b/cpp/src/arrow/csv/writer.cc
index 95c2e03a10c..bb8d555a789 100644
--- a/cpp/src/arrow/csv/writer.cc
+++ b/cpp/src/arrow/csv/writer.cc
@@ -99,7 +99,7 @@ RecordBatchIterator RecordBatchSliceIterator(const RecordBatch& batch,
 }
 
 // Counts the number of quotes in s.
-int64_t CountQuotes(arrow::util::string_view s) {
+int64_t CountQuotes(std::string_view s) {
   return static_cast<int64_t>(std::count(s.begin(), s.end(), '"'));
 }
 
@@ -155,7 +155,7 @@ class ColumnPopulator {
 
 // Copies the contents of s to out properly escaping any necessary characters.
 // Returns the position next to last copied character.
-char* Escape(arrow::util::string_view s, char* out) {
+char* Escape(std::string_view s, char* out) {
   for (const char c : s) {
     *out++ = c;
     if (c == '"') {
@@ -189,7 +189,7 @@ class UnquotedColumnPopulator : public ColumnPopulator {
     int64_t row_number = 0;
     VisitArraySpanInline<StringType>(
         *casted_array_->data(),
-        [&](arrow::util::string_view s) {
+        [&](std::string_view s) {
          row_lengths[row_number] += static_cast<int64_t>(s.length());
          row_number++;
        },
@@ -202,7 +202,7 @@ class UnquotedColumnPopulator : public ColumnPopulator {
   Status PopulateRows(char* output, int64_t* offsets) const override {
     // Function applied to valid values cast to string.
-    auto valid_function = [&](arrow::util::string_view s) {
+    auto valid_function = [&](std::string_view s) {
       memcpy(output + *offsets, s.data(), s.length());
       CopyEndChars(output + *offsets + s.length(), end_chars_.c_str(), end_chars_.size());
       *offsets += static_cast<int64_t>(s.length() + end_chars_.size());
@@ -290,7 +290,7 @@ class QuotedColumnPopulator : public ColumnPopulator {
       int row_number = 0;
       VisitArraySpanInline<StringType>(
           *input.data(),
-          [&](arrow::util::string_view s) {
+          [&](std::string_view s) {
            row_lengths[row_number] += static_cast<int64_t>(s.length()) + kQuoteCount;
            row_number++;
          },
@@ -302,7 +302,7 @@ class QuotedColumnPopulator : public ColumnPopulator {
       int row_number = 0;
       VisitArraySpanInline<StringType>(
           *input.data(),
-          [&](arrow::util::string_view s) {
+          [&](std::string_view s) {
            // Each quote in the value string needs to be escaped.
            int64_t escaped_count = CountQuotes(s);
            row_needs_escaping_[row_number] = escaped_count > 0;
@@ -322,7 +322,7 @@ class QuotedColumnPopulator : public ColumnPopulator {
     auto needs_escaping = row_needs_escaping_.begin();
     VisitArraySpanInline<StringType>(
         *(casted_array_->data()),
-        [&](arrow::util::string_view s) {
+        [&](std::string_view s) {
          // still needs string content length to be added
          char* row = output + *offsets;
          *row++ = '"';
diff --git a/cpp/src/arrow/dataset/dataset_writer.cc b/cpp/src/arrow/dataset/dataset_writer.cc
index bad363b3818..d8e00054e1c 100644
--- a/cpp/src/arrow/dataset/dataset_writer.cc
+++ b/cpp/src/arrow/dataset/dataset_writer.cc
@@ -38,7 +38,7 @@ namespace internal {
 
 namespace {
 
-constexpr util::string_view kIntegerToken = "{i}";
+constexpr std::string_view kIntegerToken = "{i}";
 
 class Throttle {
  public:
@@ -414,16 +414,16 @@ class DatasetWriterDirectoryQueue {
   uint32_t file_counter_ = 0;
 };
 
-Status ValidateBasenameTemplate(util::string_view basename_template) {
-  if (basename_template.find(fs::internal::kSep) != util::string_view::npos) {
+Status ValidateBasenameTemplate(std::string_view basename_template) {
+  if (basename_template.find(fs::internal::kSep) != std::string_view::npos) {
     return Status::Invalid("basename_template contained '/'");
   }
   size_t token_start = basename_template.find(kIntegerToken);
-  if (token_start == util::string_view::npos) {
+  if (token_start == std::string_view::npos) {
     return Status::Invalid("basename_template did not contain '", kIntegerToken, "'");
   }
   size_t next_token_start = basename_template.find(kIntegerToken, token_start + 1);
-  if (next_token_start != util::string_view::npos) {
+  if (next_token_start != std::string_view::npos) {
     return Status::Invalid("basename_template contained '", kIntegerToken,
                            "' more than once");
   }
diff --git a/cpp/src/arrow/dataset/dataset_writer_test.cc b/cpp/src/arrow/dataset/dataset_writer_test.cc
index edc9bc8bbc1..6c9c2927393 100644
--- a/cpp/src/arrow/dataset/dataset_writer_test.cc
+++ b/cpp/src/arrow/dataset/dataset_writer_test.cc
@@ -130,7 +130,7 @@ class DatasetWriterTestFixture : public testing::Test {
         << "The file " << expected_path << " was not in the list of files visited";
   }
 
-  std::shared_ptr<RecordBatch> ReadAsBatch(util::string_view data, int* num_batches) {
+  std::shared_ptr<RecordBatch> ReadAsBatch(std::string_view data, int* num_batches) {
     std::shared_ptr<io::RandomAccessFile> in_stream =
         std::make_shared<io::BufferReader>(data);
     EXPECT_OK_AND_ASSIGN(std::shared_ptr<ipc::RecordBatchFileReader> reader,
diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc
index 25fa7ff2b70..a38ec00bb91 100644
--- a/cpp/src/arrow/dataset/discovery.cc
+++ b/cpp/src/arrow/dataset/discovery.cc
@@ -30,8 +30,12 @@
 #include "arrow/dataset/type_fwd.h"
 #include "arrow/filesystem/path_util.h"
 #include "arrow/util/logging.h"
+#include "arrow/util/string.h"
 
 namespace arrow {
+
+using internal::StartsWith;
+
 namespace dataset {
 
 DatasetFactory::DatasetFactory() : root_partition_(compute::literal(true)) {}
@@ -158,10 +162,9 @@ bool StartsWithAnyOf(const std::string& path, const std::vector<std::string>& pr
   }
 
   auto parts = fs::internal::SplitAbstractPath(path);
-  return std::any_of(parts.cbegin(), parts.cend(), [&](util::string_view part) {
-    return std::any_of(prefixes.cbegin(), prefixes.cend(), [&](util::string_view prefix) {
-      return util::string_view(part).starts_with(prefix);
-    });
+  return std::any_of(parts.cbegin(), parts.cend(), [&](std::string_view part) {
+    return std::any_of(prefixes.cbegin(), prefixes.cend(),
+                       [&](std::string_view prefix) { return StartsWith(part, prefix); });
   });
 }
diff --git a/cpp/src/arrow/dataset/file_csv.cc b/cpp/src/arrow/dataset/file_csv.cc
index bfc710105ed..be963338b40 100644
--- a/cpp/src/arrow/dataset/file_csv.cc
+++ b/cpp/src/arrow/dataset/file_csv.cc
@@ -56,13 +56,13 @@ using RecordBatchGenerator = std::function<Future<std::shared_ptr<RecordBatch>>(
 
 Result<std::unordered_set<std::string>> GetColumnNames(
     const csv::ReadOptions& read_options, const csv::ParseOptions& parse_options,
-    util::string_view first_block, MemoryPool* pool) {
+    std::string_view first_block, MemoryPool* pool) {
   // Skip BOM when reading column names (ARROW-14644, ARROW-17382)
   auto size = first_block.length();
   const uint8_t* data = reinterpret_cast<const uint8_t*>(first_block.data());
   ARROW_ASSIGN_OR_RAISE(auto data_no_bom, util::SkipUTF8BOM(data, size));
   size = size - static_cast<size_t>(data_no_bom - data);
-  first_block = util::string_view(reinterpret_cast<const char*>(data_no_bom), size);
+  first_block = std::string_view(reinterpret_cast<const char*>(data_no_bom), size);
   if (!read_options.column_names.empty()) {
     std::unordered_set<std::string> column_names;
     for (const auto& s : read_options.column_names) {
@@ -78,7 +78,7 @@ Result<std::unordered_set<std::string>> GetColumnNames(
 
   csv::BlockParser parser(pool, parse_options, /*num_cols=*/-1, /*first_row=*/1,
                           max_num_rows);
-  RETURN_NOT_OK(parser.Parse(util::string_view{first_block}, &parsed_size));
+  RETURN_NOT_OK(parser.Parse(std::string_view{first_block}, &parsed_size));
   if (parser.num_rows() != max_num_rows) {
     return Status::Invalid("Could not read first ", max_num_rows,
@@ -104,7 +104,7 @@ Result<std::unordered_set<std::string>> GetColumnNames(
   RETURN_NOT_OK(
       parser.VisitLastRow([&](const uint8_t* data, uint32_t size, bool quoted) -> Status {
-        util::string_view view{reinterpret_cast<const char*>(data), size};
+        std::string_view view{reinterpret_cast<const char*>(data), size};
         if (column_names.emplace(std::string(view)).second) {
           return Status::OK();
         }
@@ -116,7 +116,7 @@ Result<std::unordered_set<std::string>> GetColumnNames(
 
 static inline Result<csv::ConvertOptions> GetConvertOptions(
     const CsvFileFormat& format, const ScanOptions* scan_options,
-    const util::string_view first_block) {
+    const std::string_view first_block) {
   ARROW_ASSIGN_OR_RAISE(
       auto csv_scan_options,
       GetFragmentScanOptions<CsvFragmentScanOptions>(
diff --git a/cpp/src/arrow/dataset/partition.cc b/cpp/src/arrow/dataset/partition.cc
index a9744d0aabf..48594336878 100644
--- a/cpp/src/arrow/dataset/partition.cc
+++ b/cpp/src/arrow/dataset/partition.cc
@@ -19,6 +19,7 @@
 
 #include <algorithm>
 #include <memory>
+#include <string_view>
 #include <utility>
 #include <vector>
 
@@ -37,7 +38,6 @@
 #include "arrow/util/int_util_overflow.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/make_unique.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/uri.h"
 #include "arrow/util/utf8.h"
 
@@ -45,7 +45,7 @@ namespace arrow {
 
 using internal::checked_cast;
 using internal::checked_pointer_cast;
-using util::string_view;
+using std::string_view;
 
 using internal::DictionaryMemoTable;
 
@@ -53,7 +53,7 @@ namespace dataset {
 namespace {
 
 /// Apply UriUnescape, then ensure the results are valid UTF-8.
-Result<std::string> SafeUriUnescape(util::string_view encoded) {
+Result<std::string> SafeUriUnescape(std::string_view encoded) {
   auto decoded = ::arrow::internal::UriUnescape(encoded);
   if (!util::ValidateUTF8(decoded)) {
     return Status::Invalid("Partition segment was not valid UTF-8 after URL decoding: ",
@@ -482,7 +482,7 @@ class KeyValuePartitioningFactory : public PartitioningFactory {
     }
   }
 
-  Status InsertRepr(int index, util::string_view repr) {
+  Status InsertRepr(int index, std::string_view repr) {
     int dummy;
     return repr_memos_[index]->GetOrInsert(repr, &dummy);
   }
@@ -738,9 +738,9 @@ Result<std::optional<KeyValuePartitioning::Key>> HivePartitioning::ParseKey(
       break;
     }
     case SegmentEncoding::Uri: {
-      auto raw_value = util::string_view(segment).substr(name_end + 1);
+      auto raw_value = std::string_view(segment).substr(name_end + 1);
       ARROW_ASSIGN_OR_RAISE(value, SafeUriUnescape(raw_value));
-      auto raw_key = util::string_view(segment).substr(0, name_end);
+      auto raw_key = std::string_view(segment).substr(0, name_end);
       ARROW_ASSIGN_OR_RAISE(name, SafeUriUnescape(raw_key));
       break;
     }
diff --git a/cpp/src/arrow/dataset/scanner_test.cc b/cpp/src/arrow/dataset/scanner_test.cc
index 0768014b862..6cafe10f78a 100644
--- a/cpp/src/arrow/dataset/scanner_test.cc
+++ b/cpp/src/arrow/dataset/scanner_test.cc
@@ -44,7 +44,6 @@
 #include "arrow/util/vector.h"
 
 using testing::ElementsAre;
-using testing::IsEmpty;
 using testing::UnorderedElementsAreArray;
 
 namespace arrow {
@@ -1265,11 +1264,11 @@ TEST(ScanOptions, TestMaterializedFields) {
   // empty dataset, project nothing = nothing materialized
   opts->dataset_schema = schema({});
   set_projection_from_names({});
-  EXPECT_THAT(opts->MaterializedFields(), IsEmpty());
+  ASSERT_EQ(opts->MaterializedFields().size(), 0);
 
   // non-empty dataset, project nothing = nothing materialized
   opts->dataset_schema = schema({i32, i64});
-  EXPECT_THAT(opts->MaterializedFields(), IsEmpty());
+  ASSERT_EQ(opts->MaterializedFields().size(), 0);
 
   // project nothing, filter on i32 = materialize i32
   opts->filter = equal(field_ref("i32"), literal(10));
diff --git a/cpp/src/arrow/dataset/test_util.h b/cpp/src/arrow/dataset/test_util.h
index fb54dc3a91a..e17afd25a34 100644
--- a/cpp/src/arrow/dataset/test_util.h
+++ b/cpp/src/arrow/dataset/test_util.h
@@ -508,11 +508,11 @@ class FileFormatFixtureMixin : public ::testing::Test {
 
     bool supported = false;
 
-    std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(util::string_view(""));
+    std::shared_ptr<Buffer> buf = std::make_shared<Buffer>(std::string_view(""));
     ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(FileSource(buf)));
     ASSERT_EQ(supported, false);
 
-    buf = std::make_shared<Buffer>(util::string_view("corrupted"));
+    buf = std::make_shared<Buffer>(std::string_view("corrupted"));
     ASSERT_OK_AND_ASSIGN(supported, format_->IsSupported(FileSource(buf)));
     ASSERT_EQ(supported, false);
 
@@ -985,7 +985,7 @@ class JSONRecordBatchFileFormat : public FileFormat {
     ARROW_ASSIGN_OR_RAISE(auto buffer, file->Read(size));
 
     ARROW_ASSIGN_OR_RAISE(auto schema, Inspect(fragment->source()));
-    RecordBatchVector batches{RecordBatchFromJSON(schema, util::string_view{*buffer})};
+    RecordBatchVector batches{RecordBatchFromJSON(schema, std::string_view{*buffer})};
     return MakeVectorGenerator(std::move(batches));
   }
 
@@ -1479,7 +1479,7 @@ class WriteFileSystemDatasetMixin : public MakeFileSystemDatasetMixin {
     }
 
     auto expected_struct = ArrayFromJSON(struct_(expected_physical_schema_->fields()),
-                                         {file_contents->second});
+                                         file_contents->second);
     AssertArraysEqual(*expected_struct, *actual_struct, /*verbose=*/true);
   }
diff --git a/cpp/src/arrow/engine/simple_extension_type_internal.h b/cpp/src/arrow/engine/simple_extension_type_internal.h
index 66d86088a76..c3f0226283d 100644
--- a/cpp/src/arrow/engine/simple_extension_type_internal.h
+++ b/cpp/src/arrow/engine/simple_extension_type_internal.h
@@ -21,6 +21,7 @@
 #include <optional>
 #include <sstream>
 #include <string>
+#include <string_view>
 #include <utility>
 
 #include "arrow/extension_type.h"
@@ -41,7 +42,7 @@ namespace engine {
 
 /// Note: The serialization is a very barebones JSON-like format and
 /// probably shouldn't be hand-edited
-template <const util::string_view& kExtensionName, typename Params,
+template <const std::string_view& kExtensionName, typename Params,
           typename ParamsProperties, const ParamsProperties* kProperties,
           std::shared_ptr<DataType> GetStorage(const Params&)>
 class SimpleExtensionType : public ExtensionType {
@@ -67,7 +68,7 @@ class SimpleExtensionType : public ExtensionType {
     return &::arrow::internal::checked_cast<const SimpleExtensionType&>(type).params_;
   }
 
-  std::string extension_name() const override { return kExtensionName.to_string(); }
+  std::string extension_name() const override { return std::string(kExtensionName); }
 
   std::string ToString() const override { return "extension<" + this->Serialize() + ">"; }
 
@@ -101,16 +102,15 @@ class SimpleExtensionType : public ExtensionType {
   }
 
   struct DeserializeImpl {
-    explicit DeserializeImpl(util::string_view repr) {
+    explicit DeserializeImpl(std::string_view repr) {
       Init(kExtensionName, repr, kProperties->size());
       kProperties->ForEach(*this);
     }
 
     void Fail() { params_ = std::nullopt; }
 
-    void Init(util::string_view class_name, util::string_view repr,
-              size_t num_properties) {
-      if (!repr.starts_with(class_name)) return Fail();
+    void Init(std::string_view class_name, std::string_view repr, size_t num_properties) {
+      if (!::arrow::internal::StartsWith(repr, class_name)) return Fail();
 
       repr = repr.substr(class_name.size());
       if (repr.empty()) return Fail();
@@ -127,7 +127,7 @@ class SimpleExtensionType : public ExtensionType {
       if (!params_) return;
 
       auto first_colon = members_[i].find_first_of(':');
-      if (first_colon == util::string_view::npos) return Fail();
+      if (first_colon == std::string_view::npos) return Fail();
 
       auto name = members_[i].substr(0, first_colon);
       if (name != prop.name()) return Fail();
@@ -135,7 +135,7 @@ class SimpleExtensionType : public ExtensionType {
       auto value_repr = members_[i].substr(first_colon + 1);
       typename Property::Type value;
       try {
-        std::stringstream ss(value_repr.to_string());
+        std::stringstream ss{std::string{value_repr}};
         ss >> value;
         if (!ss.eof()) return Fail();
       } catch (...) {
@@ -145,7 +145,7 @@ class SimpleExtensionType : public ExtensionType {
     }
 
     std::optional<Params> params_;
-    std::vector<util::string_view> members_;
+    std::vector<std::string_view> members_;
   };
 
   Result<std::shared_ptr<DataType>> Deserialize(
       std::shared_ptr<DataType> storage_type,
@@ -179,7 +179,7 @@ class SimpleExtensionType : public ExtensionType {
     }
 
     std::string Finish() {
-      return kExtensionName.to_string() + "{" +
+      return std::string(kExtensionName) + "{" +
             ::arrow::internal::JoinStrings(members_, ",") + "}";
    }
diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc
index ec0578828a6..6f181ac0218 100644
--- a/cpp/src/arrow/engine/substrait/expression_internal.cc
+++ b/cpp/src/arrow/engine/substrait/expression_internal.cc
@@ -906,10 +906,9 @@ Result<std::unique_ptr<substrait::Expression::ScalarFunction>> EncodeSubstraitCa
     substrait::FunctionArgument* arg = scalar_fn->add_arguments();
     if (call.HasEnumArg(i)) {
       auto enum_val = internal::make_unique<substrait::FunctionArgument::Enum>();
-      ARROW_ASSIGN_OR_RAISE(std::optional<util::string_view> enum_arg,
-                            call.GetEnumArg(i));
+      ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg, call.GetEnumArg(i));
       if (enum_arg) {
-        enum_val->set_specified(enum_arg->to_string());
+        enum_val->set_specified(std::string(*enum_arg));
       } else {
         enum_val->set_allocated_unspecified(new google::protobuf::Empty());
       }
diff --git a/cpp/src/arrow/engine/substrait/ext_test.cc b/cpp/src/arrow/engine/substrait/ext_test.cc
index 4b37aa8fcdb..525af7c9464 100644
--- a/cpp/src/arrow/engine/substrait/ext_test.cc
+++ b/cpp/src/arrow/engine/substrait/ext_test.cc
@@ -68,7 +68,7 @@ bool operator!=(const Id& id1, const Id& id2) { return !(id1 == id2); }
 
 struct TypeName {
   std::shared_ptr<DataType> type;
-  util::string_view name;
+  std::string_view name;
 };
 
 static const std::vector<TypeName> kTypeNames = {
@@ -87,7 +87,7 @@ static const std::vector<Id> kFunctionIds = {
     {kSubstraitArithmeticFunctionsUri, "add"},
 };
 
-static const std::vector<util::string_view> kTempFunctionNames = {
+static const std::vector<std::string_view> kTempFunctionNames = {
     "temp_func_1",
     "temp_func_2",
 };
@@ -156,7 +156,7 @@ TEST_P(ExtensionIdRegistryTest, ReregisterFunctions) {
   for (Id function_id : kFunctionIds) {
     ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(function_id));
     ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(
-                               function_id, function_id.name.to_string()));
+                               function_id, std::string(function_id.name)));
   }
 }
 
@@ -206,12 +206,12 @@ TEST(ExtensionIdRegistryTest, RegisterTempFunctions) {
   for (int i = 0; i < rounds; i++) {
     auto registry = MakeExtensionIdRegistry();
 
-    for (util::string_view name : kTempFunctionNames) {
+    for (std::string_view name : kTempFunctionNames) {
       auto id = Id{kArrowExtTypesUri, name};
       ASSERT_OK(registry->CanAddSubstraitCallToArrow(id));
-      ASSERT_OK(registry->AddSubstraitCallToArrow(id, name.to_string()));
+      ASSERT_OK(registry->AddSubstraitCallToArrow(id, std::string(name)));
       ASSERT_RAISES(Invalid, registry->CanAddSubstraitCallToArrow(id));
-      ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(id, name.to_string()));
+      ASSERT_RAISES(Invalid, registry->AddSubstraitCallToArrow(id, std::string(name)));
       ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id));
     }
   }
@@ -248,8 +248,8 @@ TEST(ExtensionIdRegistryTest, RegisterNestedTypes) {
 }
 
 TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) {
-  util::string_view name1 = kTempFunctionNames[0];
-  util::string_view name2 = kTempFunctionNames[1];
+  std::string_view name1 = kTempFunctionNames[0];
+  std::string_view name2 = kTempFunctionNames[1];
 
   auto id1 = Id{kArrowExtTypesUri, name1};
   auto id2 = Id{kArrowExtTypesUri, name2};
@@ -259,20 +259,20 @@ TEST(ExtensionIdRegistryTest, RegisterNestedFunctions) {
     auto registry1 = MakeExtensionIdRegistry();
 
     ASSERT_OK(registry1->CanAddSubstraitCallToArrow(id1));
-    ASSERT_OK(registry1->AddSubstraitCallToArrow(id1, name1.to_string()));
+    ASSERT_OK(registry1->AddSubstraitCallToArrow(id1, std::string(name1)));
 
     for (int j = 0; j < rounds; j++) {
       auto registry2 = MakeExtensionIdRegistry();
 
       ASSERT_OK(registry2->CanAddSubstraitCallToArrow(id2));
-      ASSERT_OK(registry2->AddSubstraitCallToArrow(id2, name2.to_string()));
+      ASSERT_OK(registry2->AddSubstraitCallToArrow(id2, std::string(name2)));
       ASSERT_RAISES(Invalid, registry2->CanAddSubstraitCallToArrow(id2));
-      ASSERT_RAISES(Invalid, registry2->AddSubstraitCallToArrow(id2, name2.to_string()));
+      ASSERT_RAISES(Invalid, registry2->AddSubstraitCallToArrow(id2, std::string(name2)));
       ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id2));
     }
 
     ASSERT_RAISES(Invalid, registry1->CanAddSubstraitCallToArrow(id1));
-    ASSERT_RAISES(Invalid, registry1->AddSubstraitCallToArrow(id1, name1.to_string()));
+    ASSERT_RAISES(Invalid, registry1->AddSubstraitCallToArrow(id1, std::string(name1)));
     ASSERT_OK(default_registry->CanAddSubstraitCallToArrow(id1));
   }
 }
diff --git a/cpp/src/arrow/engine/substrait/extension_set.cc b/cpp/src/arrow/engine/substrait/extension_set.cc
index f7fcd1e1279..2f7c85c9d5c 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.cc
+++ b/cpp/src/arrow/engine/substrait/extension_set.cc
@@ -25,7 +25,6 @@
 #include "arrow/util/hash_util.h"
 #include "arrow/util/hashing.h"
 #include "arrow/util/make_unique.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace engine {
@@ -68,9 +67,9 @@ bool IdHashEq::operator()(Id l, Id r) const { return l.uri == r.uri && l.name ==
 class IdStorageImpl : public IdStorage {
  public:
   Id Emplace(Id id) override {
-    util::string_view owned_uri = EmplaceUri(id.uri);
+    std::string_view owned_uri = EmplaceUri(id.uri);
 
-    util::string_view owned_name;
+    std::string_view owned_name;
     auto name_itr = names_.find(id.name);
     if (name_itr == names_.end()) {
       owned_names_.emplace_back(id.name);
@@ -84,7 +83,7 @@ class IdStorageImpl : public IdStorage {
   }
 
   std::optional<Id> Find(Id id) const override {
-    std::optional<util::string_view> maybe_owned_uri = FindUri(id.uri);
+    std::optional<std::string_view> maybe_owned_uri = FindUri(id.uri);
     if (!maybe_owned_uri) {
       return std::nullopt;
     }
@@ -97,7 +96,7 @@ class IdStorageImpl : public IdStorage {
     }
   }
 
-  std::optional<util::string_view> FindUri(util::string_view uri) const override {
+  std::optional<std::string_view> FindUri(std::string_view uri) const override {
     auto uri_itr = uris_.find(uri);
     if (uri_itr == uris_.end()) {
       return std::nullopt;
@@ -105,11 +104,11 @@ class IdStorageImpl : public IdStorage {
     return *uri_itr;
   }
 
-  util::string_view EmplaceUri(util::string_view uri) override {
+  std::string_view EmplaceUri(std::string_view uri) override {
     auto uri_itr = uris_.find(uri);
     if (uri_itr == uris_.end()) {
       owned_uris_.emplace_back(uri);
-      util::string_view owned_uri = owned_uris_.back();
+      std::string_view owned_uri = owned_uris_.back();
       uris_.insert(owned_uri);
       return owned_uri;
     }
@@ -117,8 +116,8 @@ class IdStorageImpl : public IdStorage {
 
  private:
-  std::unordered_set<util::string_view> uris_;
-  std::unordered_set<util::string_view> names_;
+  std::unordered_set<std::string_view> uris_;
+  std::unordered_set<std::string_view> names_;
   std::list<std::string> owned_uris_;
   std::list<std::string> owned_names_;
 };
@@ -127,7 +126,7 @@ std::unique_ptr<IdStorage> IdStorage::Make() {
   return ::arrow::internal::make_unique<IdStorageImpl>();
 }
 
-Result<std::optional<util::string_view>> SubstraitCall::GetEnumArg(uint32_t index) const {
+Result<std::optional<std::string_view>> SubstraitCall::GetEnumArg(uint32_t index) const {
   if (index >= size_) {
     return Status::Invalid("Expected Substrait call to have an enum argument at index ",
                            index, " but it did not have enough arguments");
@@ -176,10 +175,10 @@ void SubstraitCall::SetValueArg(uint32_t index, compute::Expression value_arg) {
 // a map of what Ids we have seen.
 ExtensionSet::ExtensionSet(const ExtensionIdRegistry* registry) : registry_(registry) {}
 
-Status ExtensionSet::CheckHasUri(util::string_view uri) {
+Status ExtensionSet::CheckHasUri(std::string_view uri) {
   auto it =
       std::find_if(uris_.begin(), uris_.end(),
-                   [&uri](const std::pair<uint32_t, util::string_view>& anchor_uri_pair) {
+                   [&uri](const std::pair<uint32_t, std::string_view>& anchor_uri_pair) {
                     return anchor_uri_pair.second == uri;
                   });
   if (it != uris_.end()) return Status::OK();
@@ -189,10 +188,10 @@ Status ExtensionSet::CheckHasUri(util::string_view uri) {
       " was referenced by an extension but was not declared in the ExtensionSet.");
 }
 
-void ExtensionSet::AddUri(std::pair<uint32_t, util::string_view> uri) {
+void ExtensionSet::AddUri(std::pair<uint32_t, std::string_view> uri) {
   auto it =
       std::find_if(uris_.begin(), uris_.end(),
-                   [&uri](const std::pair<uint32_t, util::string_view>& anchor_uri_pair) {
+                   [&uri](const std::pair<uint32_t, std::string_view>& anchor_uri_pair) {
                     return anchor_uri_pair.second == uri.second;
                   });
   if (it != uris_.end()) return;
@@ -211,14 +210,14 @@ Status ExtensionSet::AddUri(Id id) {
 
 // Creates an extension set from the Substrait plan's top-level extensions block
 Result<ExtensionSet> ExtensionSet::Make(
-    std::unordered_map<uint32_t, util::string_view> uris,
+    std::unordered_map<uint32_t, std::string_view> uris,
     std::unordered_map<uint32_t, Id> type_ids,
     std::unordered_map<uint32_t, Id> function_ids, const ExtensionIdRegistry* registry) {
   ExtensionSet set(default_extension_id_registry());
   set.registry_ = registry;
 
   for (auto& uri : uris) {
-    std::optional<util::string_view> maybe_uri_internal = registry->FindUri(uri.second);
+    std::optional<std::string_view> maybe_uri_internal = registry->FindUri(uri.second);
     if (maybe_uri_internal) {
       set.uris_[uri.first] = *maybe_uri_internal;
     } else {
@@ -324,9 +323,9 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
 
   virtual ~ExtensionIdRegistryImpl() {}
 
-  std::optional<util::string_view> FindUri(util::string_view uri) const override {
+  std::optional<std::string_view> FindUri(std::string_view uri) const override {
     if (parent_) {
-      std::optional<util::string_view> parent_uri = parent_->FindUri(uri);
+      std::optional<std::string_view> parent_uri = parent_->FindUri(uri);
       if (parent_uri) {
         return parent_uri;
       }
@@ -620,7 +619,7 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
 };
 
 template <typename Enum>
-using EnumParser = std::function<Result<Enum>(std::optional<util::string_view>)>;
+using EnumParser = std::function<Result<Enum>(std::optional<std::string_view>)>;
 
 template <typename Enum>
 EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
@@ -628,12 +627,12 @@ EnumParser<Enum> GetEnumParser(const std::vector<std::string>& options) {
   for (std::size_t i = 0; i < options.size(); i++) {
     parse_map[options[i]] = static_cast<Enum>(i + 1);
   }
-  return [parse_map](std::optional<util::string_view> enum_val) -> Result<Enum> {
+  return [parse_map](std::optional<std::string_view> enum_val) -> Result<Enum> {
     if (!enum_val) {
       // Assumes 0 is always kUnspecified in Enum
       return static_cast<Enum>(0);
     }
-    auto maybe_parsed = parse_map.find(enum_val->to_string());
+    auto maybe_parsed = parse_map.find(std::string(*enum_val));
     if (maybe_parsed == parse_map.end()) {
       return Status::Invalid("The value ", *enum_val, " is not an expected enum value");
     }
@@ -655,7 +654,7 @@ static EnumParser<OverflowBehavior> kOverflowParser =
 template <typename Enum>
 Result<Enum> ParseEnumArg(const SubstraitCall& call, uint32_t arg_index,
                           const EnumParser<Enum>& parser) {
-  ARROW_ASSIGN_OR_RAISE(std::optional<util::string_view> enum_arg,
+  ARROW_ASSIGN_OR_RAISE(std::optional<std::string_view> enum_arg,
                         call.GetEnumArg(arg_index));
   return parser(enum_arg);
 }
@@ -808,7 +807,7 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
   // -----------  Extension Types  ----------------------------
   struct TypeName {
     std::shared_ptr<DataType> type;
-    util::string_view name;
+    std::string_view name;
   };
 
   // The type (variation) mappings listed below need to be kept in sync
@@ -847,14 +846,14 @@ struct DefaultExtensionIdRegistry : ExtensionIdRegistryImpl {
     }
 
     // Basic binary mappings
     for (const auto& function_name :
-         std::vector<std::pair<util::string_view, util::string_view>>{
+         std::vector<std::pair<std::string_view, std::string_view>>{
             {kSubstraitBooleanFunctionsUri, "xor"},
             {kSubstraitComparisonFunctionsUri, "equal"},
             {kSubstraitComparisonFunctionsUri, "not_equal"}}) {
-      DCHECK_OK(
-          AddSubstraitCallToArrow({function_name.first, function_name.second},
-                                  DecodeOptionlessBasicMapping(
-                                      function_name.second.to_string(), /*max_args=*/2)));
+      DCHECK_OK(AddSubstraitCallToArrow(
+          {function_name.first, function_name.second},
+          DecodeOptionlessBasicMapping(std::string(function_name.second),
+                                       /*max_args=*/2)));
     }
     for (const auto& uri :
          {kSubstraitComparisonFunctionsUri, kSubstraitDatetimeFunctionsUri}) {
diff --git a/cpp/src/arrow/engine/substrait/extension_set.h b/cpp/src/arrow/engine/substrait/extension_set.h
index 46c83b81d16..4df8952ff9a 100644
--- a/cpp/src/arrow/engine/substrait/extension_set.h
+++ b/cpp/src/arrow/engine/substrait/extension_set.h
@@ -25,6 +25,7 @@
 #include <memory>
 #include <optional>
 #include <string>
+#include <string_view>
 #include <unordered_map>
 #include <utility>
 #include <vector>
@@ -35,7 +36,6 @@
 #include "arrow/result.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace engine {
@@ -60,7 +60,7 @@ constexpr const char* kSubstraitAggregateGenericFunctionsUri =
     "functions_aggregate_generic.yaml";
 
 struct Id {
-  util::string_view uri, name;
+  std::string_view uri, name;
   bool empty() const { return uri.empty() && name.empty(); }
   std::string ToString() const;
 };
@@ -86,7 +86,7 @@ class IdStorage {
   /// \brief Get an equivalent view pointing into this storage for a URI
   ///
   /// If no URI is found then the uri will be copied into storage
-  virtual util::string_view EmplaceUri(util::string_view uri) = 0;
+  virtual std::string_view EmplaceUri(std::string_view uri) = 0;
   /// \brief Get an equivalent id pointing into this storage
   ///
   /// If no id is found then nullopt will be returned
@@ -94,7 +94,7 @@ class IdStorage {
   /// \brief Get an equivalent view pointing into this storage for a URI
   ///
   /// If no URI is found then nullopt will be returned
-  virtual std::optional<util::string_view> FindUri(util::string_view uri) const = 0;
+  virtual std::optional<std::string_view> FindUri(std::string_view uri) const = 0;
 
   static std::unique_ptr<IdStorage> Make();
 };
@@ -119,7 +119,7 @@ class SubstraitCall {
   bool is_hash() const { return is_hash_; }
 
   bool HasEnumArg(uint32_t index) const;
-  Result<std::optional<util::string_view>> GetEnumArg(uint32_t index) const;
+  Result<std::optional<std::string_view>> GetEnumArg(uint32_t index) const;
   void SetEnumArg(uint32_t index, std::optional<std::string> enum_arg);
   Result<compute::Expression> GetValueArg(uint32_t index) const;
   bool HasValueArg(uint32_t index) const;
@@ -174,7 +174,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
   /// \brief Return a uri view owned by this registry
   ///
   /// If the URI has never been emplaced it will return nullopt
-  virtual std::optional<util::string_view> FindUri(util::string_view uri) const = 0;
+  virtual std::optional<std::string_view> FindUri(std::string_view uri) const = 0;
   /// \brief Return a id view owned by this registry
   ///
   /// If the id has never been emplaced it will return nullopt
@@ -255,7 +255,7 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
       Id substrait_function_id) const = 0;
 };
 
-constexpr util::string_view kArrowExtTypesUri =
+constexpr std::string_view kArrowExtTypesUri =
     "https://github.com/apache/arrow/blob/master/format/substrait/"
     "extension_types.yaml";
 
@@ -309,7 +309,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
  public:
   struct FunctionRecord {
     Id id;
-    util::string_view name;
+    std::string_view name;
   };
 
   struct TypeRecord {
@@ -336,12 +336,12 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
   /// An extension set should instead be created using
   /// arrow::engine::GetExtensionSetFromPlan
   static Result<ExtensionSet> Make(
-      std::unordered_map<uint32_t, util::string_view> uris,
+      std::unordered_map<uint32_t, std::string_view> uris,
       std::unordered_map<uint32_t, Id> type_ids,
       std::unordered_map<uint32_t, Id> function_ids,
       const ExtensionIdRegistry* = default_extension_id_registry());
 
-  const std::unordered_map<uint32_t, util::string_view>& uris() const { return uris_; }
+  const std::unordered_map<uint32_t, std::string_view>& uris() const { return uris_; }
 
   /// \brief Returns a data type given an anchor
   ///
@@ -407,7 +407,7 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
   std::unique_ptr<IdStorage> plan_specific_ids_ = IdStorage::Make();
 
   // Map from anchor values to URI values referenced by this extension set
-  std::unordered_map<uint32_t, util::string_view> uris_;
+  std::unordered_map<uint32_t, std::string_view> uris_;
   // Map from anchor values to type definitions, used during Substrait->Arrow
   // and populated from the Substrait extension set
   std::unordered_map<uint32_t, TypeRecord> types_;
@@ -421,8 +421,8 @@ class ARROW_ENGINE_EXPORT ExtensionSet {
   // and built as the plan is created.
   std::unordered_map<Id, uint32_t, IdHashEq, IdHashEq> functions_map_;
 
-  Status CheckHasUri(util::string_view uri);
-  void AddUri(std::pair<uint32_t, util::string_view> uri);
+  Status CheckHasUri(std::string_view uri);
+  void AddUri(std::pair<uint32_t, std::string_view> uri);
   Status AddUri(Id id);
 };
diff --git a/cpp/src/arrow/engine/substrait/extension_types.cc b/cpp/src/arrow/engine/substrait/extension_types.cc
index 2b7211766ee..6a89e3cf98b 100644
--- a/cpp/src/arrow/engine/substrait/extension_types.cc
+++ b/cpp/src/arrow/engine/substrait/extension_types.cc
@@ -17,9 +17,10 @@
 
 #include "arrow/engine/substrait/extension_types.h"
 
+#include <string_view>
+
 #include "arrow/engine/simple_extension_type_internal.h"
 #include "arrow/util/hashing.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 
@@ -29,7 +30,7 @@ using internal::MakeProperties;
 namespace engine {
 namespace {
 
-constexpr util::string_view kUuidExtensionName = "uuid";
+constexpr std::string_view kUuidExtensionName = "uuid";
 struct UuidExtensionParams {};
 std::shared_ptr<DataType> UuidGetStorage(const UuidExtensionParams&) {
   return fixed_size_binary(16);
@@ -40,7 +41,7 @@ using UuidType = SimpleExtensionType<kUuidExtensionName, UuidExtensionParams,
                                      decltype(kUuidExtensionParamsProperties),
                                      &kUuidExtensionParamsProperties, UuidGetStorage>;
 
-constexpr util::string_view kFixedCharExtensionName = "fixed_char";
+constexpr std::string_view kFixedCharExtensionName = "fixed_char";
 struct FixedCharExtensionParams {
   int32_t length;
 };
@@ -55,7 +56,7 @@ using FixedCharType =
                         decltype(kFixedCharExtensionParamsProperties),
                         &kFixedCharExtensionParamsProperties, FixedCharGetStorage>;
 
-constexpr util::string_view kVarCharExtensionName = "varchar";
+constexpr std::string_view kVarCharExtensionName = "varchar";
 struct VarCharExtensionParams {
   int32_t length;
 };
@@ -70,7 +71,7 @@ using VarCharType =
                       decltype(kVarCharExtensionParamsProperties),
                       &kVarCharExtensionParamsProperties, VarCharGetStorage>;
 
-constexpr util::string_view kIntervalYearExtensionName = "interval_year";
+constexpr std::string_view kIntervalYearExtensionName = "interval_year";
 struct IntervalYearExtensionParams {};
 std::shared_ptr<DataType> IntervalYearGetStorage(const IntervalYearExtensionParams&) {
   return fixed_size_list(int32(), 2);
@@ -82,7 +83,7 @@ using IntervalYearType =
                            decltype(kIntervalYearExtensionParamsProperties),
                           &kIntervalYearExtensionParamsProperties, IntervalYearGetStorage>;
 
-constexpr util::string_view kIntervalDayExtensionName = "interval_day";
+constexpr std::string_view kIntervalDayExtensionName = "interval_day";
 struct IntervalDayExtensionParams {};
 std::shared_ptr<DataType> IntervalDayGetStorage(const IntervalDayExtensionParams&) {
   return fixed_size_list(int32(), 2);
diff --git a/cpp/src/arrow/engine/substrait/extension_types.h b/cpp/src/arrow/engine/substrait/extension_types.h
index c623d081b18..3b08084c753 100644
--- a/cpp/src/arrow/engine/substrait/extension_types.h
+++ b/cpp/src/arrow/engine/substrait/extension_types.h
@@ -20,13 +20,10 @@
 #pragma once
 
 #include <memory>
-#include <string>
 
-#include "arrow/buffer.h"
 #include "arrow/compute/function.h"
 #include "arrow/engine/substrait/visibility.h"
 #include "arrow/type_fwd.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace engine {
diff --git a/cpp/src/arrow/engine/substrait/plan_internal.cc b/cpp/src/arrow/engine/substrait/plan_internal.cc
index 1efd4e1a0a9..bd30f043a1b 100644
--- a/cpp/src/arrow/engine/substrait/plan_internal.cc
+++ b/cpp/src/arrow/engine/substrait/plan_internal.cc
@@ -40,7 +40,7 @@ using ::arrow::internal::make_unique;
 
 Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan) {
   plan->clear_extension_uris();
 
-  std::unordered_map<util::string_view, int> map;
+  std::unordered_map<std::string_view, int> map;
 
   auto uris = plan->mutable_extension_uris();
   uris->Reserve(static_cast<int>(ext_set.uris().size()));
@@ -49,7 +49,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
     if (uri.empty()) continue;
 
     auto ext_uri = internal::make_unique<substrait::extensions::SimpleExtensionURI>();
-    ext_uri->set_uri(uri.to_string());
+    ext_uri->set_uri(std::string(uri));
     ext_uri->set_extension_uri_anchor(anchor);
     uris->AddAllocated(ext_uri.release());
 
@@ -70,7 +70,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
     auto type = internal::make_unique<
         substrait::extensions::SimpleExtensionDeclaration::ExtensionType>();
     type->set_extension_uri_reference(map[type_record.id.uri]);
     type->set_type_anchor(anchor);
-    type->set_name(type_record.id.name.to_string());
+    type->set_name(std::string(type_record.id.name));
     ext_decl->set_allocated_extension_type(type.release());
     extensions->AddAllocated(ext_decl.release());
   }
@@ -81,7 +81,7 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)
     auto fn = internal::make_unique<
         substrait::extensions::SimpleExtensionDeclaration::ExtensionFunction>();
     fn->set_extension_uri_reference(map[function_id.uri]);
     fn->set_function_anchor(anchor);
-    fn->set_name(function_id.name.to_string());
+    fn->set_name(std::string(function_id.name));
 
     auto ext_decl = internal::make_unique<substrait::extensions::SimpleExtensionDeclaration>();
     ext_decl->set_allocated_extension_function(fn.release());
@@ -96,7 +96,7 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
   if (registry == NULLPTR) {
     registry = default_extension_id_registry();
   }
-  std::unordered_map<uint32_t, util::string_view> uris;
+  std::unordered_map<uint32_t, std::string_view> uris;
   uris.reserve(plan.extension_uris_size());
   for (const auto& uri : plan.extension_uris()) {
     uris[uri.extension_uri_anchor()] = uri.uri();
@@ -114,14 +114,14 @@ Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
       case substrait::extensions::SimpleExtensionDeclaration::kExtensionType: {
         const auto& type = ext.extension_type();
-        util::string_view uri = uris[type.extension_uri_reference()];
+        std::string_view uri = uris[type.extension_uri_reference()];
         type_ids[type.type_anchor()] = Id{uri, type.name()};
         break;
       }
 
       case substrait::extensions::SimpleExtensionDeclaration::kExtensionFunction: {
         const auto& fn = ext.extension_function();
-        util::string_view uri = uris[fn.extension_uri_reference()];
+        std::string_view uri = uris[fn.extension_uri_reference()];
         function_ids[fn.function_anchor()] = Id{uri, fn.name()};
         break;
       }
diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc
index ed07f75f2b9..c920a1a46d0 100644
--- a/cpp/src/arrow/engine/substrait/relation_internal.cc
+++ b/cpp/src/arrow/engine/substrait/relation_internal.cc
@@ -31,13 +31,15 @@
 #include "arrow/filesystem/util_internal.h"
 #include "arrow/util/checked_cast.h"
 #include "arrow/util/make_unique.h"
+#include "arrow/util/string.h"
 #include "arrow/util/uri.h"
 
 namespace arrow {
 
-using ::arrow::internal::UriFromAbsolutePath;
 using internal::checked_cast;
 using internal::make_unique;
+using internal::StartsWith;
+using internal::UriFromAbsolutePath;
 
 namespace engine {
 
@@ -189,7 +191,7 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
               "unknown substrait::ReadRel::LocalFiles::FileOrFiles::file_format");
         }
 
-        if (!util::string_view{path}.starts_with("file:///")) {
+        if (!StartsWith(path, "file:///")) {
           return Status::NotImplemented("substrait::ReadRel::LocalFiles item (", path,
                                         ") with other than local filesystem "
                                         "(file:///)");
diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc
index 1e1c61fc322..c7792c7c76e 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -31,7 +31,6 @@
 #include "arrow/engine/substrait/type_fwd.h"
 #include "arrow/engine/substrait/type_internal.h"
 #include "arrow/type.h"
-#include "arrow/util/string_view.h"
 
 #include <google/protobuf/descriptor.h>
 #include <google/protobuf/io/zero_copy_stream_impl.h>
@@ -315,7 +314,7 @@ static Status CheckMessagesEquivalent(const Buffer& l_buf, const Buffer& r_buf) {
   return Status::Invalid("Messages were not equivalent: ", out);
 }
 
-Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_buf,
+Status CheckMessagesEquivalent(std::string_view message_name, const Buffer& l_buf,
                                const Buffer& r_buf) {
   if (message_name == "Type") {
     return CheckMessagesEquivalent<substrait::Type>(l_buf, r_buf);
@@ -357,9 +356,9 @@ inline google::protobuf::util::TypeResolver* GetGeneratedTypeResolver() {
   return type_resolver.get();
 }
 
-Result<std::shared_ptr<Buffer>> SubstraitFromJSON(util::string_view type_name,
-                                                  util::string_view json) {
-  std::string type_url = "/substrait." + type_name.to_string();
+Result<std::shared_ptr<Buffer>> SubstraitFromJSON(std::string_view type_name,
+                                                  std::string_view json) {
+  std::string type_url = "/substrait." + std::string(type_name);
 
   google::protobuf::io::ArrayInputStream json_stream{json.data(),
                                                      static_cast<int>(json.size())};
@@ -378,8 +377,8 @@ Result<std::shared_ptr<Buffer>> SubstraitFromJSON(util::string_view type_name,
   return Buffer::FromString(std::move(out));
 }
 
-Result<std::string> SubstraitToJSON(util::string_view type_name, const Buffer& buf) {
-  std::string type_url = "/substrait." + type_name.to_string();
+Result<std::string> SubstraitToJSON(std::string_view type_name, const Buffer& buf) {
+  std::string type_url = "/substrait." + std::string(type_name);
 
   google::protobuf::io::ArrayInputStream buf_stream{buf.data(),
                                                     static_cast<int>(buf.size())};
diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h
index cc59adb0d25..23683dba0c3 100644
--- a/cpp/src/arrow/engine/substrait/serde.h
+++ b/cpp/src/arrow/engine/substrait/serde.h
@@ -22,6 +22,7 @@
 #include <functional>
 #include <memory>
 #include <string>
+#include <string_view>
 #include <vector>
 
 #include "arrow/compute/type_fwd.h"
@@ -33,7 +34,6 @@
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace engine {
@@ -253,7 +253,7 @@ namespace internal {
 /// \param[in] r_buf buffer containing the second protobuf serialization to compare
 /// \return success if equivalent, failure if not
 ARROW_ENGINE_EXPORT
-Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_buf,
+Status CheckMessagesEquivalent(std::string_view message_name, const Buffer& l_buf,
                                const Buffer& r_buf);
 
 /// \brief Utility function to convert a JSON serialization of a Substrait message to
@@ -263,8 +263,8 @@ Status CheckMessagesEquivalent(util::string_view message_name, const Buffer& l_b
 /// \param[in] json the JSON string to convert
 /// \return a buffer filled with the binary protobuf serialization of message
 ARROW_ENGINE_EXPORT
-Result<std::shared_ptr<Buffer>> SubstraitFromJSON(util::string_view type_name,
-                                                  util::string_view json);
+Result<std::shared_ptr<Buffer>> SubstraitFromJSON(std::string_view type_name,
+                                                  std::string_view json);
 
 /// \brief Utility function to convert a binary protobuf serialization of a Substrait
 /// message to JSON
@@ -273,7 +273,7 @@ Result<std::shared_ptr<Buffer>> SubstraitFromJSON(util::string_view type_name,
 /// \param[in] buf the buffer containing the binary protobuf serialization of the message
 /// \return a JSON string representing the message
 ARROW_ENGINE_EXPORT
-Result<std::string> SubstraitToJSON(util::string_view type_name, const Buffer& buf);
+Result<std::string> SubstraitToJSON(std::string_view type_name, const Buffer& buf);
 
 }  // namespace internal
 }  // namespace engine
diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc
index 7601bcf4370..de40a53cbc1 100644
--- a/cpp/src/arrow/engine/substrait/serde_test.cc
+++ b/cpp/src/arrow/engine/substrait/serde_test.cc
@@ -198,7 +198,7 @@ void CheckRoundTripResult(const std::shared_ptr<Schema> output_schema,
 }
 
 TEST(Substrait, SupportedTypes) {
-  auto ExpectEq = [](util::string_view json, std::shared_ptr<DataType> expected_type) {
+  auto ExpectEq = [](std::string_view json, std::shared_ptr<DataType> expected_type) {
     ARROW_SCOPED_TRACE(json);
 
     ExtensionSet empty;
@@ -396,12 +396,12 @@ TEST(Substrait, NoEquivalentSubstraitType) {
 }
 
 TEST(Substrait, SupportedLiterals) {
-  auto ExpectEq = [](util::string_view json, Datum expected_value) {
+  auto ExpectEq = [](std::string_view json, Datum expected_value) {
     ARROW_SCOPED_TRACE(json);
 
     ASSERT_OK_AND_ASSIGN(
         auto buf, internal::SubstraitFromJSON("Expression",
-                                              "{\"literal\":" + json.to_string() + "}"));
+                                              "{\"literal\":" + std::string(json) + "}"));
     ExtensionSet ext_set;
     ASSERT_OK_AND_ASSIGN(auto expr, DeserializeExpression(*buf, ext_set));
diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc
index 0df3420c234..867e33a7cd0 100644
--- a/cpp/src/arrow/engine/substrait/util.cc
+++ b/cpp/src/arrow/engine/substrait/util.cc
@@ -143,7 +143,7 @@ std::shared_ptr<ExtensionIdRegistry> MakeExtensionIdRegistry() {
 }
 
 const std::string& default_extension_types_uri() {
-  static std::string uri = engine::kArrowExtTypesUri.to_string();
+  static std::string uri(engine::kArrowExtTypesUri);
   return uri;
 }
diff --git a/cpp/src/arrow/filesystem/filesystem.cc b/cpp/src/arrow/filesystem/filesystem.cc
index 48b4646bea0..c8fa4d1c377 100644
--- a/cpp/src/arrow/filesystem/filesystem.cc
+++ b/cpp/src/arrow/filesystem/filesystem.cc
@@ -258,7 +258,7 @@ Result<std::shared_ptr<io::OutputStream>> FileSystem::OpenAppendStream(
 
 namespace {
 
-Status ValidateSubPath(util::string_view s) {
+Status ValidateSubPath(std::string_view s) {
   if (internal::IsLikelyUri(s)) {
     return Status::Invalid("Expected a filesystem path, got a URI: '", s, "'");
   }
@@ -639,7 +639,7 @@ Status CopyFiles(const std::shared_ptr<FileSystem>& source_fs,
     }
 
     auto destination_path =
-        internal::ConcatAbstractPath(destination_base_dir, relative->to_string());
+        internal::ConcatAbstractPath(destination_base_dir, std::string(*relative));
 
     if (source_info.IsDirectory()) {
       dirs.push_back(destination_path);
diff --git a/cpp/src/arrow/filesystem/gcsfs.cc b/cpp/src/arrow/filesystem/gcsfs.cc
index da7b856be47..e7d1965d510 100644
--- a/cpp/src/arrow/filesystem/gcsfs.cc
+++ b/cpp/src/arrow/filesystem/gcsfs.cc
@@ -81,7 +81,7 @@ struct GcsPath {
       return Status::Invalid("Path cannot start with a separator ('", s, "')");
     }
     if (first_sep == std::string::npos) {
-      return GcsPath{s, internal::RemoveTrailingSlash(s).to_string(), ""};
+      return GcsPath{s, std::string(internal::RemoveTrailingSlash(s)), ""};
     }
     GcsPath path;
     path.full_path = s;
@@ -412,7 +412,7 @@ class GcsFileSystem::Impl {
   // limitations) using marker objects.  That and listing with prefixes creates the
   // illusion of folders.
   google::cloud::StatusOr<gcs::ObjectMetadata> CreateDirMarker(const std::string& bucket,
-                                                               util::string_view name) {
+                                                               std::string_view name) {
     // Make the name canonical.
     const auto canonical = internal::EnsureTrailingSlash(name);
     google::cloud::StatusOr<gcs::ObjectMetadata> object = client_.InsertObject(
diff --git a/cpp/src/arrow/filesystem/gcsfs_internal.cc b/cpp/src/arrow/filesystem/gcsfs_internal.cc
index b8f0ab80b21..c984fe12f01 100644
--- a/cpp/src/arrow/filesystem/gcsfs_internal.cc
+++ b/cpp/src/arrow/filesystem/gcsfs_internal.cc
@@ -295,7 +295,7 @@ Result<std::shared_ptr<const KeyValueMetadata>> FromObjectMetadata(
   return result;
 }
 
-std::int64_t Depth(arrow::util::string_view path) {
+std::int64_t Depth(std::string_view path) {
   // The last slash is not counted towards depth because it represents a
   // directory.
   bool has_trailing_slash = !path.empty() && path.back() == '/';
diff --git a/cpp/src/arrow/filesystem/gcsfs_internal.h b/cpp/src/arrow/filesystem/gcsfs_internal.h
index 101f7f62df6..c2a0e2921dc 100644
--- a/cpp/src/arrow/filesystem/gcsfs_internal.h
+++ b/cpp/src/arrow/filesystem/gcsfs_internal.h
@@ -51,7 +51,7 @@ ARROW_EXPORT Result<google::cloud::storage::ObjectMetadata> ToObjectMetadata(
 ARROW_EXPORT Result<std::shared_ptr<const KeyValueMetadata>> FromObjectMetadata(
     google::cloud::storage::ObjectMetadata const& m);
 
-ARROW_EXPORT std::int64_t Depth(arrow::util::string_view path);
+ARROW_EXPORT std::int64_t Depth(std::string_view path);
 
 }  // namespace internal
 }  // namespace fs
diff --git a/cpp/src/arrow/filesystem/gcsfs_test.cc b/cpp/src/arrow/filesystem/gcsfs_test.cc
index 50f9a32fa1a..48d56f7b7bb 100644
--- a/cpp/src/arrow/filesystem/gcsfs_test.cc
+++ b/cpp/src/arrow/filesystem/gcsfs_test.cc
@@ -73,7 +73,6 @@ namespace gcs = google::cloud::storage;
 
 using ::testing::Eq;
 using ::testing::HasSubstr;
-using ::testing::IsEmpty;
 using ::testing::Not;
 using ::testing::NotNull;
 using ::testing::Pair;
@@ -171,7 +170,7 @@ class GcsIntegrationTest : public ::testing::Test {
 protected:
  void SetUp() override {
    ASSERT_THAT(Testbench(), NotNull());
-    ASSERT_THAT(Testbench()->error(), IsEmpty());
+    ASSERT_TRUE(Testbench()->error().empty());
    ASSERT_TRUE(Testbench()->running());
 
    // Initialize a PRNG with a small amount of entropy.
@@ -280,7 +279,7 @@ class GcsIntegrationTest : public ::testing::Test {
     std::transform(expected.begin(), expected.end(), expected.begin(),
                    [](FileInfo const& info) {
                      if (!info.IsDirectory()) return info;
-                     return Dir(internal::RemoveTrailingSlash(info.path()).to_string());
+                     return Dir(std::string(internal::RemoveTrailingSlash(info.path())));
                    });
     return expected;
   }
@@ -767,7 +766,7 @@ TEST_F(GcsIntegrationTest, GetFileInfoSelectorNotFoundTrue) {
   selector.allow_not_found = true;
   selector.recursive = true;
   ASSERT_OK_AND_ASSIGN(auto results, fs->GetFileInfo(selector));
-  EXPECT_THAT(results, IsEmpty());
+  EXPECT_EQ(results.size(), 0);
 }
 
 TEST_F(GcsIntegrationTest, GetFileInfoSelectorNotFoundFalse) {
diff --git a/cpp/src/arrow/filesystem/localfs.cc b/cpp/src/arrow/filesystem/localfs.cc
index 585131ecc5e..03b4ad3bc72 100644
--- a/cpp/src/arrow/filesystem/localfs.cc
+++ b/cpp/src/arrow/filesystem/localfs.cc
@@ -85,7 +85,7 @@ bool DetectAbsolutePath(const std::string& s) {
 
 namespace {
 
-Status ValidatePath(util::string_view s) {
+Status ValidatePath(std::string_view s) {
   if (internal::IsLikelyUri(s)) {
     return Status::Invalid("Expected a local filesystem path, got a URI: '", s, "'");
   }
diff --git a/cpp/src/arrow/filesystem/localfs_benchmark.cc b/cpp/src/arrow/filesystem/localfs_benchmark.cc
index 1eb15ccfe23..3c4ded7e537 100644
--- a/cpp/src/arrow/filesystem/localfs_benchmark.cc
+++ b/cpp/src/arrow/filesystem/localfs_benchmark.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include <string>
+#include <string_view>
 
 #include "benchmark/benchmark.h"
@@ -29,7 +30,6 @@
 #include "arrow/util/async_generator.h"
 #include "arrow/util/io_util.h"
 #include "arrow/util/make_unique.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc
index d8302bed471..bb211e23df4 100644
--- a/cpp/src/arrow/filesystem/mockfs.cc
+++ b/cpp/src/arrow/filesystem/mockfs.cc
@@ -21,6 +21,7 @@
 #include <memory>
 #include <sstream>
 #include <string>
+#include <string_view>
 #include <utility>
 #include <variant>
 #include <vector>
@@ -35,7 +36,6 @@
 #include "arrow/util/async_generator.h"
 #include "arrow/util/future.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/windows_fixup.h"
 
 namespace arrow {
@@ -44,7 +44,7 @@ namespace internal {
 
 namespace {
 
-Status ValidatePath(util::string_view s) {
+Status ValidatePath(std::string_view s) {
   if (internal::IsLikelyUri(s)) {
     return Status::Invalid("Expected a filesystem path, got a URI: '", s, "'");
   }
@@ -66,9 +66,9 @@ struct File {
 
   int64_t size() const { return data ? data->size() : 0; }
 
-  explicit operator util::string_view() const {
+  explicit operator std::string_view() const {
     if (data) {
-      return util::string_view(*data);
+      return std::string_view(*data);
     } else {
       return "";
     }
@@ -372,7 +372,7 @@ class MockFileSystem::Impl {
       Entry* child = pair.second.get();
       if (child->is_file()) {
         auto& file = child->as_file();
-        out->push_back({path + file.name, file.mtime, util::string_view(file)});
+        out->push_back({path + file.name, file.mtime, std::string_view(file)});
       } else if (child->is_dir()) {
         DumpFiles(path, child->as_dir(), out);
       }
@@ -752,7 +752,7 @@ std::vector<MockFileInfo> MockFileSystem::AllFiles() {
   return result;
 }
 
-Status MockFileSystem::CreateFile(const std::string& path, util::string_view contents,
+Status MockFileSystem::CreateFile(const std::string& path, std::string_view contents,
                                   bool recursive) {
   RETURN_NOT_OK(ValidatePath(path));
   auto parent = fs::internal::GetAbstractPathParent(path).first;
diff --git a/cpp/src/arrow/filesystem/mockfs.h b/cpp/src/arrow/filesystem/mockfs.h
index 2427d4a3bf7..fe86e19be4e 100644
--- a/cpp/src/arrow/filesystem/mockfs.h
+++ b/cpp/src/arrow/filesystem/mockfs.h
@@ -20,10 +20,10 @@
 #include <iosfwd>
 #include <memory>
 #include <string>
+#include <string_view>
 #include <vector>
 
 #include "arrow/filesystem/filesystem.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/windows_fixup.h"
 
 namespace arrow {
@@ -44,7 +44,7 @@ struct MockDirInfo {
 struct MockFileInfo {
   std::string full_path;
   TimePoint mtime;
-  util::string_view data;
+  std::string_view data;
 
   bool operator==(const MockFileInfo& other) const {
     return mtime == other.mtime && full_path == other.full_path && data == other.data;
@@ -102,7 +102,7 @@ class ARROW_EXPORT MockFileSystem : public FileSystem {
   std::vector<MockFileInfo> AllFiles();
 
   // Create a File with a content from a string.
-  Status CreateFile(const std::string& path, util::string_view content,
+  Status CreateFile(const std::string& path, std::string_view content,
                     bool recursive = true);
 
   // Create a MockFileSystem out of (empty) FileInfo. The content of every
diff --git a/cpp/src/arrow/filesystem/path_util.cc b/cpp/src/arrow/filesystem/path_util.cc
index 2216a4bb258..53cd6103776 100644
--- a/cpp/src/arrow/filesystem/path_util.cc
+++ b/cpp/src/arrow/filesystem/path_util.cc
@@ -23,10 +23,13 @@
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/string_view.h"
+#include "arrow/util/string.h"
 #include "arrow/util/uri.h"
 
 namespace arrow {
+
+using internal::StartsWith;
+
 namespace fs {
 namespace internal {
 
@@ -34,7 +37,7 @@ namespace internal {
 std::vector<std::string> SplitAbstractPath(const std::string& path, char sep) {
   std::vector<std::string> parts;
-  auto v = util::string_view(path);
+  auto v = std::string_view(path);
   // Strip trailing separator
   if (v.length() > 0 && v.back() == sep) {
     v = v.substr(0, v.length() - 1);
   }
@@ -75,13 +78,13 @@ std::pair<std::string, std::string> GetAbstractPathParent(const std::string& s)
 }
 
 std::string GetAbstractPathExtension(const std::string& s) {
-  util::string_view basename(s);
+  std::string_view basename(s);
   auto offset = basename.find_last_of(kSep);
   if (offset != std::string::npos) {
     basename = basename.substr(offset);
   }
   auto dot = basename.find_last_of('.');
-  if (dot == util::string_view::npos) {
+  if (dot == std::string_view::npos) {
     // Empty extension
     return "";
   }
@@ -108,7 +111,7 @@ std::string ConcatAbstractPath(const std::string& base, const std::string& stem)
   return EnsureTrailingSlash(base) + std::string(RemoveLeadingSlash(stem));
 }
 
-std::string EnsureTrailingSlash(util::string_view v) {
+std::string EnsureTrailingSlash(std::string_view v) {
   if (v.length() > 0 && v.back() != kSep) {
     // XXX How about "C:" on Windows?  We probably don't want to turn it into "C:/"...
     // Unless the local filesystem always uses absolute paths
@@ -118,7 +121,7 @@ std::string EnsureTrailingSlash(util::string_view v) {
   }
 }
 
-std::string EnsureLeadingSlash(util::string_view v) {
+std::string EnsureLeadingSlash(std::string_view v) {
   if (v.length() == 0 || v.front() != kSep) {
     // XXX How about "C:" on Windows?  We probably don't want to turn it into "/C:"...
     return kSep + std::string(v);
@@ -126,21 +129,21 @@ std::string EnsureLeadingSlash(util::string_view v) {
     return std::string(v);
   }
 }
 
-util::string_view RemoveTrailingSlash(util::string_view key) {
+std::string_view RemoveTrailingSlash(std::string_view key) {
   while (!key.empty() && key.back() == kSep) {
     key.remove_suffix(1);
   }
   return key;
 }
 
-util::string_view RemoveLeadingSlash(util::string_view key) {
+std::string_view RemoveLeadingSlash(std::string_view key) {
   while (!key.empty() && key.front() == kSep) {
     key.remove_prefix(1);
   }
   return key;
 }
 
-Status AssertNoTrailingSlash(util::string_view key) {
+Status AssertNoTrailingSlash(std::string_view key) {
   if (key.back() == '/') {
     return NotAFile(key);
   }
@@ -154,8 +157,8 @@ Result<std::string> MakeAbstractPathRelative(const std::string& base,
                            base, "'");
   }
   auto b = EnsureLeadingSlash(RemoveTrailingSlash(base));
-  auto p = util::string_view(path);
-  if (p.substr(0, b.size()) != util::string_view(b)) {
+  auto p = std::string_view(path);
+  if (p.substr(0, b.size()) != std::string_view(b)) {
     return Status::Invalid("Path '", path, "' is not relative to '", base, "'");
   }
   p = p.substr(b.size());
@@ -165,7 +168,7 @@ Result<std::string> MakeAbstractPathRelative(const std::string& base,
   return std::string(RemoveLeadingSlash(p));
 }
 
-bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) {
+bool IsAncestorOf(std::string_view ancestor, std::string_view descendant) {
   ancestor = RemoveTrailingSlash(ancestor);
   if (ancestor == "") {
     // everything is a descendant of the root directory
@@ -173,7 +176,7 @@ bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) {
   }
 
   descendant = RemoveTrailingSlash(descendant);
-  if (!descendant.starts_with(ancestor)) {
+  if (!StartsWith(descendant, ancestor)) {
     // an ancestor path is a prefix of descendant paths
     return false;
   }
@@ -186,11 +189,11 @@ bool IsAncestorOf(util::string_view ancestor, util::string_view descendant) {
   }
 
   // "/hello/w" is not an ancestor of "/hello/world"
-  return descendant.starts_with(std::string{kSep});
+  return StartsWith(descendant, std::string{kSep});
 }
 
-std::optional<util::string_view> RemoveAncestor(util::string_view ancestor,
-                                                util::string_view descendant) {
+std::optional<std::string_view> RemoveAncestor(std::string_view ancestor,
+                                               std::string_view descendant) {
   if (!IsAncestorOf(ancestor, descendant)) {
     return std::nullopt;
   }
@@ -199,8 +202,8 @@ std::optional<util::string_view> RemoveAncestor(util::string_view ancestor,
   return RemoveLeadingSlash(relative_to_ancestor);
 }
 
-std::vector<std::string> AncestorsFromBasePath(util::string_view base_path,
-                                               util::string_view descendant) {
+std::vector<std::string> AncestorsFromBasePath(std::string_view base_path,
+                                               std::string_view descendant) {
   std::vector<std::string> ancestry;
   if (auto relative = RemoveAncestor(base_path, descendant)) {
     auto relative_segments = fs::internal::SplitAbstractPath(std::string(*relative));
@@ -245,7 +248,7 @@ std::vector<std::string> MinimalCreateDirSet(std::vector<std::string> dirs) {
   return dirs;
 }
 
-std::string ToBackslashes(util::string_view v) {
+std::string ToBackslashes(std::string_view v) {
   std::string s(v);
   for (auto& c : s) {
     if (c == '/') {
@@ -255,7 +258,7 @@ std::string ToBackslashes(util::string_view v) {
   return s;
 }
 
-std::string ToSlashes(util::string_view v) {
+std::string ToSlashes(std::string_view v) {
   std::string s(v);
 #ifdef _WIN32
   for (auto& c : s) {
@@ -267,7 +270,7 @@ std::string ToSlashes(util::string_view v) {
   return s;
 }
 
-bool IsEmptyPath(util::string_view v) {
+bool IsEmptyPath(std::string_view v) {
   for (const auto c : v) {
     if (c != '/') {
       return false;
     }
@@ -276,7 +279,7 @@ bool IsEmptyPath(util::string_view v) {
   return true;
 }
 
-bool IsLikelyUri(util::string_view v) {
+bool IsLikelyUri(std::string_view v) {
   if (v.empty() || v[0] == '/') {
     return false;
   }
diff --git a/cpp/src/arrow/filesystem/path_util.h b/cpp/src/arrow/filesystem/path_util.h
index ea8e56df5d4..fc1d2d82443 100644
--- a/cpp/src/arrow/filesystem/path_util.h
+++ b/cpp/src/arrow/filesystem/path_util.h
@@ -19,11 +19,11 @@
 
 #include <optional>
 #include <string>
+#include <string_view>
 #include <utility>
 #include <vector>
 
 #include "arrow/type_fwd.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace fs {
@@ -61,34 +61,34 @@ Result<std::string> MakeAbstractPathRelative(const std::string& base,
                                              const std::string& path);
 
 ARROW_EXPORT
-std::string EnsureLeadingSlash(util::string_view s);
+std::string EnsureLeadingSlash(std::string_view s);
 
 ARROW_EXPORT
-util::string_view RemoveLeadingSlash(util::string_view s);
+std::string_view RemoveLeadingSlash(std::string_view s);
 
 ARROW_EXPORT
-std::string EnsureTrailingSlash(util::string_view s);
+std::string EnsureTrailingSlash(std::string_view s);
 
 ARROW_EXPORT
-util::string_view RemoveTrailingSlash(util::string_view s);
+std::string_view RemoveTrailingSlash(std::string_view s);
 
 ARROW_EXPORT
-Status AssertNoTrailingSlash(util::string_view s);
+Status AssertNoTrailingSlash(std::string_view s);
 
 ARROW_EXPORT
-bool IsAncestorOf(util::string_view ancestor, util::string_view descendant);
+bool IsAncestorOf(std::string_view ancestor, std::string_view descendant);
 
 ARROW_EXPORT
-std::optional<util::string_view> RemoveAncestor(util::string_view ancestor,
-                                                util::string_view descendant);
+std::optional<std::string_view> RemoveAncestor(std::string_view ancestor,
+                                               std::string_view descendant);
 
 /// Return a vector of ancestors between a base path and a descendant.
 /// For example,
 ///
 /// AncestorsFromBasePath("a/b", "a/b/c/d/e") -> ["a/b/c", "a/b/c/d"]
 ARROW_EXPORT
-std::vector<std::string> AncestorsFromBasePath(util::string_view base_path,
-                                               util::string_view descendant);
+std::vector<std::string> AncestorsFromBasePath(std::string_view base_path,
+                                               std::string_view descendant);
 
 /// Given a vector of paths of directories which must be created, produce a the minimal
 /// subset for passing to CreateDir(recursive=true) by removing redundant parent
@@ -118,18 +118,18 @@ std::string JoinAbstractPath(const StringRange& range, char sep = kSep) {
 
 /// Convert slashes to backslashes, on all platforms.  Mostly useful for testing.
 ARROW_EXPORT
-std::string ToBackslashes(util::string_view s);
+std::string ToBackslashes(std::string_view s);
 
 /// Ensure a local path is abstract, by converting backslashes to regular slashes
 /// on Windows.  Return the path unchanged on other systems.
ARROW_EXPORT -std::string ToSlashes(util::string_view s); +std::string ToSlashes(std::string_view s); ARROW_EXPORT -bool IsEmptyPath(util::string_view s); +bool IsEmptyPath(std::string_view s); ARROW_EXPORT -bool IsLikelyUri(util::string_view s); +bool IsLikelyUri(std::string_view s); class ARROW_EXPORT Globber { public: diff --git a/cpp/src/arrow/filesystem/s3_internal.h b/cpp/src/arrow/filesystem/s3_internal.h index c6e6349ba2c..00efff166f2 100644 --- a/cpp/src/arrow/filesystem/s3_internal.h +++ b/cpp/src/arrow/filesystem/s3_internal.h @@ -19,6 +19,7 @@ #include #include +#include #include #include @@ -33,7 +34,6 @@ #include "arrow/status.h" #include "arrow/util/logging.h" #include "arrow/util/print.h" -#include "arrow/util/string_view.h" namespace arrow { namespace fs { @@ -46,7 +46,7 @@ enum class S3Backend { Amazon, Minio, Other }; inline S3Backend DetectS3Backend(const Aws::Http::HeaderValueCollection& headers) { const auto it = headers.find("server"); if (it != headers.end()) { - const auto& value = util::string_view(it->second); + const auto& value = std::string_view(it->second); if (value.find("AmazonS3") != std::string::npos) { return S3Backend::Amazon; } @@ -218,7 +218,7 @@ inline Aws::String ToAwsString(const std::string& s) { return Aws::String(s.begin(), s.end()); } -inline util::string_view FromAwsString(const Aws::String& s) { +inline std::string_view FromAwsString(const Aws::String& s) { return {s.data(), s.length()}; } diff --git a/cpp/src/arrow/filesystem/s3fs.cc b/cpp/src/arrow/filesystem/s3fs.cc index db79810f5d7..e75f277034a 100644 --- a/cpp/src/arrow/filesystem/s3fs.cc +++ b/cpp/src/arrow/filesystem/s3fs.cc @@ -954,7 +954,7 @@ std::shared_ptr GetObjectMetadata(const ObjectResult& re auto push = [&](std::string k, const Aws::String& v) { if (!v.empty()) { - md->Append(std::move(k), FromAwsString(v).to_string()); + md->Append(std::move(k), std::string(FromAwsString(v))); } }; auto push_datetime = [&](std::string k, const Aws::Utils::DateTime& v) { @@ -1948,7 +1948,7 @@ class S3FileSystem::Impl : public std::enable_shared_from_this& src, return Status::OK(); } -Status PathNotFound(util::string_view path) { +Status PathNotFound(std::string_view path) { return Status::IOError("Path does not exist '", path, "'") .WithDetail(StatusDetailFromErrno(ENOENT)); } -Status NotADir(util::string_view path) { +Status NotADir(std::string_view path) { return Status::IOError("Not a directory: '", path, "'") .WithDetail(StatusDetailFromErrno(ENOTDIR)); } -Status NotAFile(util::string_view path) { +Status NotAFile(std::string_view path) { return Status::IOError("Not a regular file: '", path, "'"); } -Status InvalidDeleteDirContents(util::string_view path) { +Status InvalidDeleteDirContents(std::string_view path) { return Status::Invalid( "DeleteDirContents called on invalid path '", path, "'. 
", "If you wish to delete the root directory's contents, call DeleteRootDirContents."); diff --git a/cpp/src/arrow/filesystem/util_internal.h b/cpp/src/arrow/filesystem/util_internal.h index 75a2d3a2ef5..cc16dbba106 100644 --- a/cpp/src/arrow/filesystem/util_internal.h +++ b/cpp/src/arrow/filesystem/util_internal.h @@ -19,11 +19,11 @@ #include #include +#include #include "arrow/filesystem/filesystem.h" #include "arrow/io/interfaces.h" #include "arrow/status.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -39,16 +39,16 @@ Status CopyStream(const std::shared_ptr& src, const io::IOContext& io_context); ARROW_EXPORT -Status PathNotFound(util::string_view path); +Status PathNotFound(std::string_view path); ARROW_EXPORT -Status NotADir(util::string_view path); +Status NotADir(std::string_view path); ARROW_EXPORT -Status NotAFile(util::string_view path); +Status NotAFile(std::string_view path); ARROW_EXPORT -Status InvalidDeleteDirContents(util::string_view path); +Status InvalidDeleteDirContents(std::string_view path); /// \brief Return files matching the glob pattern on the filesystem /// diff --git a/cpp/src/arrow/flight/cookie_internal.cc b/cpp/src/arrow/flight/cookie_internal.cc index 380ea56976d..37672d2ecd8 100644 --- a/cpp/src/arrow/flight/cookie_internal.cc +++ b/cpp/src/arrow/flight/cookie_internal.cc @@ -63,7 +63,7 @@ size_t CaseInsensitiveHash::operator()(const std::string& key) const { return std::hash{}(upper_string); } -Cookie Cookie::Parse(const arrow::util::string_view& cookie_header_value) { +Cookie Cookie::Parse(const std::string_view& cookie_header_value) { // Parse the cookie string. If the cookie has an expiration, record it. // If the cookie has a max-age, calculate the current time + max_age and set that as // the expiration. @@ -252,7 +252,7 @@ void CookieCache::UpdateCachedCookies(const CallHeaders& incoming_headers) { const std::lock_guard guard(mutex_); for (auto it = header_values.first; it != header_values.second; ++it) { - const util::string_view& value = it->second; + const std::string_view& value = it->second; Cookie cookie = Cookie::Parse(value); // Cache cookies regardless of whether or not they are expired. The server may have diff --git a/cpp/src/arrow/flight/cookie_internal.h b/cpp/src/arrow/flight/cookie_internal.h index b87c8052266..f2f469b3824 100644 --- a/cpp/src/arrow/flight/cookie_internal.h +++ b/cpp/src/arrow/flight/cookie_internal.h @@ -23,12 +23,12 @@ #include #include #include +#include #include #include #include "arrow/flight/client_middleware.h" #include "arrow/result.h" -#include "arrow/util/string_view.h" namespace arrow { namespace flight { @@ -54,7 +54,7 @@ class ARROW_FLIGHT_EXPORT Cookie { /// \brief Parse function to parse a cookie header value and return a Cookie object. /// /// \return Cookie object based on cookie header value. - static Cookie Parse(const arrow::util::string_view& cookie_header_value); + static Cookie Parse(const std::string_view& cookie_header_value); /// \brief Parse a cookie header string beginning at the given start_pos and identify /// the name and value of an attribute. 
diff --git a/cpp/src/arrow/flight/flight_internals_test.cc b/cpp/src/arrow/flight/flight_internals_test.cc index f315e42a6a6..8a809decd70 100644 --- a/cpp/src/arrow/flight/flight_internals_test.cc +++ b/cpp/src/arrow/flight/flight_internals_test.cc @@ -274,8 +274,8 @@ class TestCookieMiddleware : public ::testing::Test { void AddAndValidate(const std::string& incoming_cookie) { // Add cookie CallHeaders call_headers; - call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"), - arrow::util::string_view(incoming_cookie))); + call_headers.insert(std::make_pair(std::string_view("set-cookie"), + std::string_view(incoming_cookie))); middleware_->ReceivedHeaders(call_headers); expected_cookie_cache_.UpdateCachedCookies(call_headers); @@ -423,8 +423,8 @@ class TestCookieParsing : public ::testing::Test { for (auto& cookie : cookies) { // Add cookie CallHeaders call_headers; - call_headers.insert(std::make_pair(arrow::util::string_view("set-cookie"), - arrow::util::string_view(cookie))); + call_headers.insert( + std::make_pair(std::string_view("set-cookie"), std::string_view(cookie))); cookie_cache.UpdateCachedCookies(call_headers); } const std::string actual_cookies = cookie_cache.GetValidCookiesAsString(); diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index a7c79a6dc5b..54597e54203 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -463,7 +463,7 @@ class TracingServerMiddlewareFactory : public ServerMiddlewareFactory { const std::pair& iter_pair = incoming_headers.equal_range("x-tracing-span-id"); if (iter_pair.first != iter_pair.second) { - const util::string_view& value = (*iter_pair.first).second; + const std::string_view& value = (*iter_pair.first).second; *middleware = std::make_shared(std::string(value)); } return Status::OK(); @@ -484,7 +484,7 @@ std::string FindKeyValPrefixInCallHeaders(const CallHeaders& incoming_headers, if (iter == incoming_headers.end()) { return ""; } - const std::string val = iter->second.to_string(); + const std::string val(iter->second); if (val.size() > prefix.length()) { if (std::equal(val.begin(), val.begin() + prefix.length(), prefix.begin(), char_compare)) { @@ -773,8 +773,8 @@ class TestPropagatingMiddleware : public ::testing::Test { void CheckHeader(const std::string& header, const std::string& value, const CallHeaders::const_iterator& it) { // Construct a string_view before comparison to satisfy MSVC - util::string_view header_view(header.data(), header.length()); - util::string_view value_view(value.data(), value.length()); + std::string_view header_view(header.data(), header.length()); + std::string_view value_view(value.data(), value.length()); ASSERT_EQ(header_view, (*it).first); ASSERT_EQ(value_view, (*it).second); } diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 43c16e0b77a..0b7ddc56ecb 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -150,7 +150,7 @@ class TestServerMiddlewareFactory : public ServerMiddlewareFactory { incoming_headers.equal_range("x-middleware"); std::string received = ""; if (iter_pair.first != iter_pair.second) { - const util::string_view& value = (*iter_pair.first).second; + const std::string_view& value = (*iter_pair.first).second; received = std::string(value); } *middleware = std::make_shared(received); @@ -176,7 +176,7 @@ class 
TestClientMiddleware : public ClientMiddleware {
     const std::pair<CallHeaders::const_iterator, CallHeaders::const_iterator>&
         iter_pair = incoming_headers.equal_range("x-middleware");
     if (iter_pair.first != iter_pair.second) {
-      const util::string_view& value = (*iter_pair.first).second;
+      const std::string_view& value = (*iter_pair.first).second;
       *received_header_ = std::string(value);
     }
   }
diff --git a/cpp/src/arrow/flight/middleware.h b/cpp/src/arrow/flight/middleware.h
index d11ba11477c..b050e9cc6ed 100644
--- a/cpp/src/arrow/flight/middleware.h
+++ b/cpp/src/arrow/flight/middleware.h
@@ -23,11 +23,11 @@
 #include
 #include
 #include
+#include <string_view>
 #include
 
 #include "arrow/flight/visibility.h"  // IWYU pragma: keep
 #include "arrow/status.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 
@@ -36,7 +36,7 @@ namespace flight {
 
 /// \brief Headers sent from the client or server.
 ///
 /// Header values are ordered.
-using CallHeaders = std::multimap<util::string_view, util::string_view>;
+using CallHeaders = std::multimap<std::string_view, std::string_view>;
 
 /// \brief A write-only wrapper around headers for an RPC call.
 class ARROW_FLIGHT_EXPORT AddCallHeaders {
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index e9736b0615e..66185cfeba5 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <string_view>
 #include
 #include
 
@@ -39,7 +40,6 @@
 #include "arrow/status.h"
 #include "arrow/util/io_util.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/uri.h"
 
 namespace arrow {
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc
index 68bde35c718..921dd13182e 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc
+++ b/cpp/src/arrow/flight/sql/example/sqlite_tables_schema_batch_reader.cc
@@ -92,8 +92,7 @@ Status SqliteTablesWithSchemaBatchReader::ReadNext(std::shared_ptr<RecordBatch>*
       ARROW_ASSIGN_OR_RAISE(schema_buffer, value);
 
       column_fields.clear();
-      ARROW_RETURN_NOT_OK(
-          schema_builder.Append(::arrow::util::string_view(*schema_buffer)));
+      ARROW_RETURN_NOT_OK(schema_builder.Append(::std::string_view(*schema_buffer)));
     }
 
     std::shared_ptr<Array> schema_array;
diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
index e6f50169607..34c4ae91627 100644
--- a/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
+++ b/cpp/src/arrow/flight/transport/grpc/grpc_client.cc
@@ -44,6 +44,7 @@
 #include "arrow/util/base64.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/make_unique.h"
+#include "arrow/util/string.h"
 #include "arrow/util/uri.h"
 
 #include "arrow/flight/client.h"
@@ -59,6 +60,8 @@
 
 namespace arrow {
 
+using internal::EndsWith;
+
 namespace flight {
 namespace transport {
 namespace grpc {
@@ -151,8 +154,8 @@ class GrpcClientInterceptorAdapter : public ::grpc::experimental::Interceptor {
     received_headers_ = true;
     CallHeaders headers;
     for (const auto& entry : metadata) {
-      headers.insert({util::string_view(entry.first.data(), entry.first.length()),
-                      util::string_view(entry.second.data(), entry.second.length())});
+      headers.insert({std::string_view(entry.first.data(), entry.first.length()),
+                      std::string_view(entry.second.data(), entry.second.length())});
     }
     for (const auto& middleware : middleware_) {
       middleware->ReceivedHeaders(headers);
@@ -180,24 +183,24 @@ class GrpcClientInterceptorAdapterFactory
     std::vector<std::unique_ptr<ClientMiddleware>> middleware;
     FlightMethod flight_method = FlightMethod::Invalid;
-    util::string_view method(info->method());
-    if
(method.ends_with("/Handshake")) { + std::string_view method(info->method()); + if (EndsWith(method, "/Handshake")) { flight_method = FlightMethod::Handshake; - } else if (method.ends_with("/ListFlights")) { + } else if (EndsWith(method, "/ListFlights")) { flight_method = FlightMethod::ListFlights; - } else if (method.ends_with("/GetFlightInfo")) { + } else if (EndsWith(method, "/GetFlightInfo")) { flight_method = FlightMethod::GetFlightInfo; - } else if (method.ends_with("/GetSchema")) { + } else if (EndsWith(method, "/GetSchema")) { flight_method = FlightMethod::GetSchema; - } else if (method.ends_with("/DoGet")) { + } else if (EndsWith(method, "/DoGet")) { flight_method = FlightMethod::DoGet; - } else if (method.ends_with("/DoPut")) { + } else if (EndsWith(method, "/DoPut")) { flight_method = FlightMethod::DoPut; - } else if (method.ends_with("/DoExchange")) { + } else if (EndsWith(method, "/DoExchange")) { flight_method = FlightMethod::DoExchange; - } else if (method.ends_with("/DoAction")) { + } else if (EndsWith(method, "/DoAction")) { flight_method = FlightMethod::DoAction; - } else if (method.ends_with("/ListActions")) { + } else if (EndsWith(method, "/ListActions")) { flight_method = FlightMethod::ListActions; } else { ARROW_LOG(WARNING) << "Unknown Flight method: " << info->method(); diff --git a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc index 14daaa58765..a643111e3b2 100644 --- a/cpp/src/arrow/flight/transport/grpc/grpc_server.cc +++ b/cpp/src/arrow/flight/transport/grpc/grpc_server.cc @@ -313,8 +313,8 @@ class GrpcServiceHandler final : public FlightService::Service { CallHeaders incoming_headers; for (const auto& entry : context->client_metadata()) { incoming_headers.insert( - {util::string_view(entry.first.data(), entry.first.length()), - util::string_view(entry.second.data(), entry.second.length())}); + {std::string_view(entry.first.data(), entry.first.length()), + std::string_view(entry.second.data(), entry.second.length())}); } GrpcAddServerHeaders outgoing_headers(context); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc index 14b5638adab..80124123d4a 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_client.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_client.cc @@ -581,7 +581,7 @@ class UcxClientImpl : public arrow::flight::internal::ClientTransport { ARROW_ASSIGN_OR_RAISE(auto incoming_message, driver->ReadNextFrame()); if (incoming_message->type == FrameType::kBuffer) { ARROW_ASSIGN_OR_RAISE( - *info, FlightInfo::Deserialize(util::string_view(*incoming_message->buffer))); + *info, FlightInfo::Deserialize(std::string_view(*incoming_message->buffer))); ARROW_ASSIGN_OR_RAISE(incoming_message, driver->ReadNextFrame()); } RETURN_NOT_OK(driver->ExpectFrameType(*incoming_message, FrameType::kHeaders)); diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc index 373333663f8..318f6204ac9 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.cc @@ -180,7 +180,7 @@ arrow::Result HeadersFrame::Parse(std::unique_ptr buffer) return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ", key_length, ", but only ", (end - payload), " bytes remain"); } - const util::string_view key(reinterpret_cast(payload), key_length); + const std::string_view key(reinterpret_cast(payload), key_length); payload += key_length; 
if (ARROW_PREDICT_FALSE((end - payload) < value_length)) {
@@ -188,7 +188,7 @@ arrow::Result<HeadersFrame> HeadersFrame::Parse(std::unique_ptr<Buffer> buffer)
                              " to have length ", value_length, ", but only ",
                              (end - payload), " bytes remain");
     }
-    const util::string_view value(reinterpret_cast<const char*>(payload), value_length);
+    const std::string_view value(reinterpret_cast<const char*>(payload), value_length);
     payload += value_length;
     result.headers_.emplace_back(key, value);
   }
@@ -243,7 +243,7 @@ arrow::Result<HeadersFrame> HeadersFrame::Make(
   return Make(all_headers);
 }
 
-arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
+arrow::Result<std::string_view> HeadersFrame::Get(const std::string& key) {
   for (const auto& pair : headers_) {
     if (pair.first == key) return pair.second;
   }
@@ -252,7 +252,7 @@ arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
 
 Status HeadersFrame::GetStatus(Status* out) {
   static const std::string kUnknownMessage = "Server did not send status message header";
-  util::string_view code_str, message_str;
+  std::string_view code_str, message_str;
   auto status = Get(kHeaderStatus).Value(&code_str);
   if (!status.ok()) {
     return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
@@ -273,7 +273,7 @@ Status HeadersFrame::GetStatus(Status* out) {
   }
   *out = transport_status.ToStatus();
 
-  util::string_view detail_str, bin_str;
+  std::string_view detail_str, bin_str;
   std::optional<std::string> message, detail_message, detail_bin;
   if (!Get(kHeaderStatusCode).Value(&code_str).ok()) {
     // No Arrow status sent, go with the transport status
@@ -363,7 +363,7 @@ Status PayloadHeaderFrame::ToFlightData(internal::FlightData* data) {
     return Status::Invalid("Buffer is too small: expected ", offset + size,
                            " bytes but have ", buffer->size());
   }
-  util::string_view desc(reinterpret_cast<const char*>(buffer->data() + offset), size);
+  std::string_view desc(reinterpret_cast<const char*>(buffer->data() + offset), size);
   data->descriptor.reset(new FlightDescriptor());
   ARROW_ASSIGN_OR_RAISE(*data->descriptor, FlightDescriptor::Deserialize(desc));
   offset += size;
diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
index f5b81ab4147..d14296db097 100644
--- a/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
+++ b/cpp/src/arrow/flight/transport/ucx/ucx_internal.h
@@ -21,6 +21,7 @@
 #include
 #include
+#include <string_view>
 #include
 #include
 
@@ -35,7 +36,6 @@
 #include "arrow/util/future.h"
 #include "arrow/util/logging.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace flight {
@@ -191,8 +191,8 @@ struct Frame {
         std::unique_ptr<Buffer> buffer_)
       : type(type_), size(size_), counter(counter_), buffer(std::move(buffer_)) {}
 
-  util::string_view view() const {
-    return util::string_view(reinterpret_cast<const char*>(buffer->data()), size);
+  std::string_view view() const {
+    return std::string_view(reinterpret_cast<const char*>(buffer->data()), size);
   }
 
   /// \brief Parse a UCX active message header. This will not
@@ -222,7 +222,7 @@ static constexpr uint32_t kUcpAmHandlerId = 0x1024;
 class HeadersFrame {
  public:
   /// \brief Get a header value (or an error if it was not found)
-  arrow::Result<util::string_view> Get(const std::string& key);
+  arrow::Result<std::string_view> Get(const std::string& key);
   /// \brief Extract the server-sent status.
   Status GetStatus(Status* out);
   /// \brief Parse the headers from the buffer.
@@ -240,7 +240,7 @@ class HeadersFrame {
 
  private:
   std::unique_ptr<Buffer> buffer_;
-  std::vector<std::pair<util::string_view, util::string_view>> headers_;
+  std::vector<std::pair<std::string_view, std::string_view>> headers_;
 };
 
 /// \brief A representation of a kPayloadHeader frame (i.e.
all of the diff --git a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc index d7ddbfab06e..398bc438146 100644 --- a/cpp/src/arrow/flight/transport/ucx/ucx_server.cc +++ b/cpp/src/arrow/flight/transport/ucx/ucx_server.cc @@ -362,7 +362,7 @@ class UcxServerImpl : public arrow::flight::internal::ServerTransport { SERVER_RETURN_NOT_OK(driver, driver->ExpectFrameType(*frame, FrameType::kBuffer)); FlightDescriptor descriptor; SERVER_RETURN_NOT_OK(driver, - FlightDescriptor::Deserialize(util::string_view(*frame->buffer)) + FlightDescriptor::Deserialize(std::string_view(*frame->buffer)) .Value(&descriptor)); std::unique_ptr info; diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index a505e6d6e1e..2122e57ccc1 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include "arrow/buffer.h" @@ -29,7 +30,6 @@ #include "arrow/status.h" #include "arrow/table.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" #include "arrow/util/uri.h" namespace arrow { @@ -177,8 +177,7 @@ arrow::Result SchemaResult::SerializeToString() const { return out; } -arrow::Result SchemaResult::Deserialize( - arrow::util::string_view serialized) { +arrow::Result SchemaResult::Deserialize(std::string_view serialized) { pb::SchemaResult pb_schema_result; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized SchemaResult size should not exceed 2 GiB"); @@ -207,7 +206,7 @@ Status FlightDescriptor::SerializeToString(std::string* out) const { } arrow::Result FlightDescriptor::Deserialize( - arrow::util::string_view serialized) { + std::string_view serialized) { pb::FlightDescriptor pb_descriptor; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized FlightDescriptor size should not exceed 2 GiB"); @@ -244,7 +243,7 @@ Status Ticket::SerializeToString(std::string* out) const { return SerializeToString().Value(out); } -arrow::Result Ticket::Deserialize(arrow::util::string_view serialized) { +arrow::Result Ticket::Deserialize(std::string_view serialized) { pb::Ticket pb_ticket; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized Ticket size should not exceed 2 GiB"); @@ -308,7 +307,7 @@ Status FlightInfo::SerializeToString(std::string* out) const { } arrow::Result> FlightInfo::Deserialize( - arrow::util::string_view serialized) { + std::string_view serialized) { pb::FlightInfo pb_info; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized FlightInfo size should not exceed 2 GiB"); @@ -410,8 +409,7 @@ arrow::Result FlightEndpoint::SerializeToString() const { return out; } -arrow::Result FlightEndpoint::Deserialize( - arrow::util::string_view serialized) { +arrow::Result FlightEndpoint::Deserialize(std::string_view serialized) { pb::FlightEndpoint pb_flight_endpoint; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized FlightEndpoint size should not exceed 2 GiB"); @@ -441,7 +439,7 @@ arrow::Result ActionType::SerializeToString() const { return out; } -arrow::Result ActionType::Deserialize(arrow::util::string_view serialized) { +arrow::Result ActionType::Deserialize(std::string_view serialized) { pb::ActionType pb_action_type; if (serialized.size() > static_cast(std::numeric_limits::max())) { return 
Status::Invalid("Serialized ActionType size should not exceed 2 GiB"); @@ -471,7 +469,7 @@ arrow::Result Criteria::SerializeToString() const { return out; } -arrow::Result Criteria::Deserialize(arrow::util::string_view serialized) { +arrow::Result Criteria::Deserialize(std::string_view serialized) { pb::Criteria pb_criteria; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized Criteria size should not exceed 2 GiB"); @@ -502,7 +500,7 @@ arrow::Result Action::SerializeToString() const { return out; } -arrow::Result Action::Deserialize(arrow::util::string_view serialized) { +arrow::Result Action::Deserialize(std::string_view serialized) { pb::Action pb_action; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized Action size should not exceed 2 GiB"); @@ -532,7 +530,7 @@ arrow::Result Result::SerializeToString() const { return out; } -arrow::Result Result::Deserialize(arrow::util::string_view serialized) { +arrow::Result Result::Deserialize(std::string_view serialized) { pb::Result pb_result; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized Result size should not exceed 2 GiB"); @@ -645,7 +643,7 @@ bool BasicAuth::Equals(const BasicAuth& other) const { return (username == other.username) && (password == other.password); } -arrow::Result BasicAuth::Deserialize(arrow::util::string_view serialized) { +arrow::Result BasicAuth::Deserialize(std::string_view serialized) { pb::BasicAuth pb_result; if (serialized.size() > static_cast(std::numeric_limits::max())) { return Status::Invalid("Serialized BasicAuth size should not exceed 2 GiB"); diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index ae9867e44a1..6957c5992a3 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -30,7 +31,6 @@ #include "arrow/ipc/options.h" #include "arrow/ipc/writer.h" #include "arrow/result.h" -#include "arrow/util/string_view.h" namespace arrow { @@ -153,7 +153,7 @@ struct ARROW_FLIGHT_EXPORT ActionType { arrow::Result SerializeToString() const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief Opaque selection criteria for ListFlights RPC @@ -174,7 +174,7 @@ struct ARROW_FLIGHT_EXPORT Criteria { arrow::Result SerializeToString() const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief An action to perform with the DoAction RPC @@ -198,7 +198,7 @@ struct ARROW_FLIGHT_EXPORT Action { arrow::Result SerializeToString() const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief Opaque result returned after executing an action @@ -218,7 +218,7 @@ struct ARROW_FLIGHT_EXPORT Result { arrow::Result SerializeToString() const; /// \brief Deserialize this message from its wire-format representation. 
- static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief message for simple auth @@ -236,7 +236,7 @@ struct ARROW_FLIGHT_EXPORT BasicAuth { } /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); /// \brief Serialize this message to its wire-format representation. arrow::Result SerializeToString() const; @@ -284,7 +284,7 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor { /// /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status Deserialize(const std::string& serialized, FlightDescriptor* out); @@ -334,7 +334,7 @@ struct ARROW_FLIGHT_EXPORT Ticket { /// /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status Deserialize(const std::string& serialized, Ticket* out); @@ -442,7 +442,7 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint { arrow::Result SerializeToString() const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief Staging data structure for messages about to be put on the wire @@ -492,7 +492,7 @@ struct ARROW_FLIGHT_EXPORT SchemaResult { arrow::Result SerializeToString() const; /// \brief Deserialize this message from its wire-format representation. - static arrow::Result Deserialize(arrow::util::string_view serialized); + static arrow::Result Deserialize(std::string_view serialized); private: std::string raw_schema_; @@ -562,7 +562,7 @@ class ARROW_FLIGHT_EXPORT FlightInfo { /// Useful when interoperating with non-Flight systems (e.g. REST /// services) that may want to return Flight types. static arrow::Result> Deserialize( - arrow::util::string_view serialized); + std::string_view serialized); ARROW_DEPRECATED("Deprecated in 8.0.0. Use Result-returning overload instead.") static Status Deserialize(const std::string& serialized, diff --git a/cpp/src/arrow/io/buffered.cc b/cpp/src/arrow/io/buffered.cc index ccfe9a360ab..e0e37c58026 100644 --- a/cpp/src/arrow/io/buffered.cc +++ b/cpp/src/arrow/io/buffered.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "arrow/buffer.h" @@ -28,7 +29,6 @@ #include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" namespace arrow { namespace io { @@ -292,7 +292,7 @@ class BufferedInputStream::Impl : public BufferedBase { return ResizeBuffer(new_buffer_size); } - Result Peek(int64_t nbytes) { + Result Peek(int64_t nbytes) { if (raw_read_bound_ >= 0) { // Do not try to peek more than the total remaining number of bytes. 
nbytes = std::min(nbytes, bytes_buffered_ + (raw_read_bound_ - raw_read_total_));
@@ -324,8 +324,8 @@ class BufferedInputStream::Impl : public BufferedBase {
       nbytes = bytes_buffered_;
     }
     DCHECK(nbytes <= bytes_buffered_);  // Enough bytes available
-    return util::string_view(reinterpret_cast<const char*>(buffer_data_ + buffer_pos_),
-                             static_cast<size_t>(nbytes));
+    return std::string_view(reinterpret_cast<const char*>(buffer_data_ + buffer_pos_),
+                            static_cast<size_t>(nbytes));
   }
 
   int64_t bytes_buffered() const { return bytes_buffered_; }
@@ -458,7 +458,7 @@ std::shared_ptr<InputStream> BufferedInputStream::raw() const { return impl_->ra
 
 Result<int64_t> BufferedInputStream::DoTell() const { return impl_->Tell(); }
 
-Result<util::string_view> BufferedInputStream::DoPeek(int64_t nbytes) {
+Result<std::string_view> BufferedInputStream::DoPeek(int64_t nbytes) {
   return impl_->Peek(nbytes);
 }
 
diff --git a/cpp/src/arrow/io/buffered.h b/cpp/src/arrow/io/buffered.h
index 8116613fa4e..01c0a016dab 100644
--- a/cpp/src/arrow/io/buffered.h
+++ b/cpp/src/arrow/io/buffered.h
@@ -21,10 +21,10 @@
 #include
 #include
+#include <string_view>
 
 #include "arrow/io/concurrency.h"
 #include "arrow/io/interfaces.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -157,7 +157,7 @@ class ARROW_EXPORT BufferedInputStream
   /// \brief Return a zero-copy string view referencing buffered data,
   /// but do not advance the position of the stream. Buffers data and
   /// expands the buffer size if necessary
-  Result<util::string_view> DoPeek(int64_t nbytes) override;
+  Result<std::string_view> DoPeek(int64_t nbytes) override;
 
   class ARROW_NO_EXPORT Impl;
   std::unique_ptr<Impl> impl_;
diff --git a/cpp/src/arrow/io/buffered_test.cc b/cpp/src/arrow/io/buffered_test.cc
index f6f6d61f849..520eaaa9356 100644
--- a/cpp/src/arrow/io/buffered_test.cc
+++ b/cpp/src/arrow/io/buffered_test.cc
@@ -28,6 +28,7 @@
 #include
 #include
 #include
+#include <string_view>
 #include
 #include
 #include
@@ -44,7 +45,6 @@
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/io_util.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace io {
@@ -503,7 +503,7 @@ class TestBufferedInputStreamBound : public ::testing::Test {
 
 TEST_F(TestBufferedInputStreamBound, Basics) {
   std::shared_ptr<Buffer> buffer;
-  util::string_view view;
+  std::string_view view;
 
   // source is at offset 10
   ASSERT_OK_AND_ASSIGN(view, stream_->Peek(10));
@@ -559,7 +559,7 @@ TEST_F(TestBufferedInputStreamBound, Basics) {
 TEST_F(TestBufferedInputStreamBound, LargeFirstPeek) {
   // Test a first peek larger than chunk size
   std::shared_ptr<Buffer> buffer;
-  util::string_view view;
+  std::string_view view;
 
   int64_t n = 70;
   ASSERT_GT(n, chunk_size_);
@@ -592,7 +592,7 @@ TEST_F(TestBufferedInputStreamBound, LargeFirstPeek) {
 
 TEST_F(TestBufferedInputStreamBound, UnboundedPeek) {
   CreateExample(/*bounded=*/false);
 
-  util::string_view view;
+  std::string_view view;
   ASSERT_OK_AND_ASSIGN(view, stream_->Peek(10));
   ASSERT_EQ(10, view.size());
   ASSERT_EQ(50, stream_->bytes_buffered());
diff --git a/cpp/src/arrow/io/concurrency.h b/cpp/src/arrow/io/concurrency.h
index b41ad2c1350..43ceb8debce 100644
--- a/cpp/src/arrow/io/concurrency.h
+++ b/cpp/src/arrow/io/concurrency.h
@@ -116,7 +116,7 @@ class ARROW_EXPORT InputStreamConcurrencyWrapper : public InputStream {
     return derived()->DoRead(nbytes);
   }
 
-  Result<util::string_view> Peek(int64_t nbytes) final {
+  Result<std::string_view> Peek(int64_t nbytes) final {
     auto guard = lock_.exclusive_guard();
     return derived()->DoPeek(nbytes);
   }
@@ -132,7 +132,7 @@ class ARROW_EXPORT InputStreamConcurrencyWrapper : public InputStream {
     And optionally:
 
     Status DoAbort() override;
-    Result<util::string_view> DoPeek(int64_t nbytes) override;
+    Result<std::string_view> DoPeek(int64_t nbytes) override;
 
     These methods should be protected in the derived class and
    InputStreamConcurrencyWrapper declared as a friend with
@@ -145,7 +145,7 @@ class ARROW_EXPORT InputStreamConcurrencyWrapper : public InputStream {
   // have derived classes itself.
   virtual Status DoAbort() { return derived()->DoClose(); }
 
-  virtual Result<util::string_view> DoPeek(int64_t ARROW_ARG_UNUSED(nbytes)) {
+  virtual Result<std::string_view> DoPeek(int64_t ARROW_ARG_UNUSED(nbytes)) {
     return Status::NotImplemented("Peek not implemented");
   }
@@ -186,7 +186,7 @@ class ARROW_EXPORT RandomAccessFileConcurrencyWrapper : public RandomAccessFile
     return derived()->DoRead(nbytes);
   }
 
-  Result<util::string_view> Peek(int64_t nbytes) final {
+  Result<std::string_view> Peek(int64_t nbytes) final {
     auto guard = lock_.exclusive_guard();
     return derived()->DoPeek(nbytes);
   }
@@ -232,7 +232,7 @@ class ARROW_EXPORT RandomAccessFileConcurrencyWrapper : public RandomAccessFile
     And optionally:
 
     Status DoAbort() override;
-    Result<util::string_view> DoPeek(int64_t nbytes) override;
+    Result<std::string_view> DoPeek(int64_t nbytes) override;
 
     These methods should be protected in the derived class and
     RandomAccessFileConcurrencyWrapper declared as a friend with
@@ -245,7 +245,7 @@ class ARROW_EXPORT RandomAccessFileConcurrencyWrapper : public RandomAccessFile
   // have derived classes itself.
   virtual Status DoAbort() { return derived()->DoClose(); }
 
-  virtual Result<util::string_view> DoPeek(int64_t ARROW_ARG_UNUSED(nbytes)) {
+  virtual Result<std::string_view> DoPeek(int64_t ARROW_ARG_UNUSED(nbytes)) {
     return Status::NotImplemented("Peek not implemented");
   }
diff --git a/cpp/src/arrow/io/file_test.cc b/cpp/src/arrow/io/file_test.cc
index 8165c9c0b49..b5c8797b0b0 100644
--- a/cpp/src/arrow/io/file_test.cc
+++ b/cpp/src/arrow/io/file_test.cc
@@ -67,7 +67,7 @@ class FileTestFixture : public ::testing::Test {
     EnsureFileDeleted();
   }
 
-  std::string TempFile(arrow::util::string_view path) {
+  std::string TempFile(std::string_view path) {
     return temp_dir_->path().Join(std::string(path)).ValueOrDie().ToString();
   }
 
@@ -563,7 +563,7 @@ class TestMemoryMappedFile : public ::testing::Test, public MemoryMapFixture {
 
   void TearDown() override { MemoryMapFixture::TearDown(); }
 
-  std::string TempFile(arrow::util::string_view path) {
+  std::string TempFile(std::string_view path) {
     return temp_dir_->path().Join(std::string(path)).ValueOrDie().ToString();
   }
 
diff --git a/cpp/src/arrow/io/interfaces.cc b/cpp/src/arrow/io/interfaces.cc
index 1dfb0bdf8ad..238e297a7f4 100644
--- a/cpp/src/arrow/io/interfaces.cc
+++ b/cpp/src/arrow/io/interfaces.cc
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <string_view>
 #include
 #include
 
@@ -38,7 +39,6 @@
 #include "arrow/util/io_util.h"
 #include "arrow/util/iterator.h"
 #include "arrow/util/logging.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/thread_pool.h"
 
 namespace arrow {
@@ -107,7 +107,7 @@ const IOContext& Readable::io_context() const { return g_default_io_context; }
 
 Status InputStream::Advance(int64_t nbytes) { return Read(nbytes).status(); }
 
-Result<util::string_view> InputStream::Peek(int64_t ARROW_ARG_UNUSED(nbytes)) {
+Result<std::string_view> InputStream::Peek(int64_t ARROW_ARG_UNUSED(nbytes)) {
   return Status::NotImplemented("Peek not implemented");
 }
 
@@ -178,7 +178,7 @@ Status RandomAccessFile::WillNeed(const std::vector<ReadRange>& ranges) {
   return Status::OK();
 }
 
-Status Writable::Write(util::string_view data) {
+Status Writable::Write(std::string_view data) {
   return Write(data.data(), static_cast<int64_t>(data.size()));
 }
 
diff --git a/cpp/src/arrow/io/interfaces.h b/cpp/src/arrow/io/interfaces.h
index 70c0dd8520f..86e9ad2d524 100644
--- a/cpp/src/arrow/io/interfaces.h
+++
b/cpp/src/arrow/io/interfaces.h @@ -20,13 +20,13 @@ #include #include #include +#include #include #include "arrow/io/type_fwd.h" #include "arrow/type_fwd.h" #include "arrow/util/cancel.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" #include "arrow/util/type_fwd.h" #include "arrow/util/visibility.h" @@ -175,7 +175,7 @@ class ARROW_EXPORT Writable { /// \brief Flush buffered bytes, if any virtual Status Flush(); - Status Write(util::string_view data); + Status Write(std::string_view data); }; class ARROW_EXPORT Readable { @@ -227,7 +227,7 @@ class ARROW_EXPORT InputStream : virtual public FileInterface, /// May return NotImplemented on streams that don't support it. /// /// \param[in] nbytes the maximum number of bytes to see - virtual Result Peek(int64_t nbytes); + virtual Result Peek(int64_t nbytes); /// \brief Return true if InputStream is capable of zero copy Buffer reads /// diff --git a/cpp/src/arrow/io/memory.cc b/cpp/src/arrow/io/memory.cc index 6495242e63b..9b2b0313323 100644 --- a/cpp/src/arrow/io/memory.cc +++ b/cpp/src/arrow/io/memory.cc @@ -274,7 +274,7 @@ BufferReader::BufferReader(const uint8_t* data, int64_t size) BufferReader::BufferReader(const Buffer& buffer) : BufferReader(buffer.data(), buffer.size()) {} -BufferReader::BufferReader(const util::string_view& data) +BufferReader::BufferReader(const std::string_view& data) : BufferReader(reinterpret_cast(data.data()), static_cast(data.size())) {} @@ -290,12 +290,12 @@ Result BufferReader::DoTell() const { return position_; } -Result BufferReader::DoPeek(int64_t nbytes) { +Result BufferReader::DoPeek(int64_t nbytes) { RETURN_NOT_OK(CheckClosed()); const int64_t bytes_available = std::min(nbytes, size_ - position_); - return util::string_view(reinterpret_cast(data_) + position_, - static_cast(bytes_available)); + return std::string_view(reinterpret_cast(data_) + position_, + static_cast(bytes_available)); } bool BufferReader::supports_zero_copy() const { return true; } diff --git a/cpp/src/arrow/io/memory.h b/cpp/src/arrow/io/memory.h index 8213439ef74..5c35a6015be 100644 --- a/cpp/src/arrow/io/memory.h +++ b/cpp/src/arrow/io/memory.h @@ -21,12 +21,12 @@ #include #include +#include #include #include "arrow/io/concurrency.h" #include "arrow/io/interfaces.h" #include "arrow/type_fwd.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -149,9 +149,9 @@ class ARROW_EXPORT BufferReader explicit BufferReader(const Buffer& buffer); BufferReader(const uint8_t* data, int64_t size); - /// \brief Instantiate from std::string or arrow::util::string_view. Does not + /// \brief Instantiate from std::string or std::string_view. 
Does not /// own data - explicit BufferReader(const util::string_view& data); + explicit BufferReader(const std::string_view& data); bool closed() const override; @@ -173,7 +173,7 @@ class ARROW_EXPORT BufferReader Result> DoRead(int64_t nbytes); Result DoReadAt(int64_t position, int64_t nbytes, void* out); Result> DoReadAt(int64_t position, int64_t nbytes); - Result DoPeek(int64_t nbytes) override; + Result DoPeek(int64_t nbytes) override; Result DoTell() const; Status DoSeek(int64_t position); diff --git a/cpp/src/arrow/io/memory_test.cc b/cpp/src/arrow/io/memory_test.cc index d361243ad6f..cdcbe240f85 100644 --- a/cpp/src/arrow/io/memory_test.cc +++ b/cpp/src/arrow/io/memory_test.cc @@ -162,10 +162,10 @@ TEST(TestFixedSizeBufferWriter, InvalidWrites) { TEST(TestBufferReader, FromStrings) { // ARROW-3291: construct BufferReader from std::string or - // arrow::util::string_view + // std::string_view std::string data = "data123456"; - auto view = util::string_view(data); + auto view = std::string_view(data); BufferReader reader1(data); BufferReader reader2(view); @@ -208,7 +208,7 @@ TEST(TestBufferReader, Peek) { BufferReader reader(std::make_shared(data)); - util::string_view view; + std::string_view view; ASSERT_OK_AND_ASSIGN(view, reader.Peek(4)); @@ -378,7 +378,7 @@ template void TestSlowInputStream() { using clock = std::chrono::high_resolution_clock; - auto stream = std::make_shared(util::string_view("abcdefghijkl")); + auto stream = std::make_shared(std::string_view("abcdefghijkl")); const double latency = 0.6; auto slow = std::make_shared(stream, latency); @@ -395,8 +395,8 @@ void TestSlowInputStream() { ARROW_UNUSED(dt); #endif - ASSERT_OK_AND_ASSIGN(util::string_view view, slow->Peek(4)); - ASSERT_EQ(view, util::string_view("ghij")); + ASSERT_OK_AND_ASSIGN(std::string_view view, slow->Peek(4)); + ASSERT_EQ(view, std::string_view("ghij")); ASSERT_OK(slow->Close()); ASSERT_TRUE(slow->closed()); @@ -493,7 +493,7 @@ class TestTransformInputStream : public ::testing::Test { TransformInputStream::TransformFunc transform() const { return T(); } void TestEmptyStream() { - auto wrapped = std::make_shared(util::string_view()); + auto wrapped = std::make_shared(std::string_view()); auto stream = std::make_shared(wrapped, transform()); ASSERT_OK_AND_EQ(0, stream->Tell()); diff --git a/cpp/src/arrow/io/slow.cc b/cpp/src/arrow/io/slow.cc index 1042691fa59..7c11a484fc1 100644 --- a/cpp/src/arrow/io/slow.cc +++ b/cpp/src/arrow/io/slow.cc @@ -97,7 +97,7 @@ Result> SlowInputStream::Read(int64_t nbytes) { return stream_->Read(nbytes); } -Result SlowInputStream::Peek(int64_t nbytes) { +Result SlowInputStream::Peek(int64_t nbytes) { return stream_->Peek(nbytes); } @@ -140,7 +140,7 @@ Result> SlowRandomAccessFile::ReadAt(int64_t position, return stream_->ReadAt(position, nbytes); } -Result SlowRandomAccessFile::Peek(int64_t nbytes) { +Result SlowRandomAccessFile::Peek(int64_t nbytes) { return stream_->Peek(nbytes); } diff --git a/cpp/src/arrow/io/slow.h b/cpp/src/arrow/io/slow.h index 1ed90f0c2e9..fdcc56dfa6a 100644 --- a/cpp/src/arrow/io/slow.h +++ b/cpp/src/arrow/io/slow.h @@ -85,7 +85,7 @@ class ARROW_EXPORT SlowInputStream : public SlowInputStreamBase { Result Read(int64_t nbytes, void* out) override; Result> Read(int64_t nbytes) override; - Result Peek(int64_t nbytes) override; + Result Peek(int64_t nbytes) override; Result Tell() const override; }; @@ -107,7 +107,7 @@ class ARROW_EXPORT SlowRandomAccessFile : public SlowInputStreamBase> Read(int64_t nbytes) override; Result ReadAt(int64_t 
position, int64_t nbytes, void* out) override; Result> ReadAt(int64_t position, int64_t nbytes) override; - Result Peek(int64_t nbytes) override; + Result Peek(int64_t nbytes) override; Result GetSize() override; Status Seek(int64_t position) override; diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 667fd00ae21..1b93aeb2f28 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -36,7 +37,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" #include "arrow/util/value_parsing.h" #include "arrow/json/rapidjson_defs.h" @@ -317,7 +317,7 @@ class DecimalConverter final if (json_obj.IsString()) { int32_t precision, scale; DecimalValue d; - auto view = util::string_view(json_obj.GetString(), json_obj.GetStringLength()); + auto view = std::string_view(json_obj.GetString(), json_obj.GetStringLength()); RETURN_NOT_OK(DecimalValue::FromString(view, &d, &precision, &scale)); if (scale != decimal_type_->scale()) { return Status::Invalid("Invalid scale for decimal: expected ", @@ -359,7 +359,7 @@ class TimestampConverter final : public ConcreteConverter { if (json_obj.IsNumber()) { RETURN_NOT_OK(ConvertNumber(json_obj, *this->type_, &value)); } else if (json_obj.IsString()) { - util::string_view view(json_obj.GetString(), json_obj.GetStringLength()); + std::string_view view(json_obj.GetString(), json_obj.GetStringLength()); if (!ParseValue(*timestamp_type_, view.data(), view.size(), &value)) { return Status::Invalid("couldn't parse timestamp from ", view); } @@ -461,7 +461,7 @@ class StringConverter final return this->AppendNull(); } if (json_obj.IsString()) { - auto view = util::string_view(json_obj.GetString(), json_obj.GetStringLength()); + auto view = std::string_view(json_obj.GetString(), json_obj.GetStringLength()); return builder_->Append(view); } else { return JSONTypeError("string", json_obj.GetType()); @@ -492,7 +492,7 @@ class FixedSizeBinaryConverter final return this->AppendNull(); } if (json_obj.IsString()) { - auto view = util::string_view(json_obj.GetString(), json_obj.GetStringLength()); + auto view = std::string_view(json_obj.GetString(), json_obj.GetStringLength()); if (view.length() != static_cast(builder_->byte_width())) { std::stringstream ss; ss << "Invalid string length " << view.length() << " in JSON input for " @@ -906,7 +906,7 @@ Status GetConverter(const std::shared_ptr& type, } // namespace Result> ArrayFromJSON(const std::shared_ptr& type, - util::string_view json_string) { + std::string_view json_string) { std::shared_ptr converter; RETURN_NOT_OK(GetConverter(type, &converter)); @@ -926,12 +926,12 @@ Result> ArrayFromJSON(const std::shared_ptr& ty Result> ArrayFromJSON(const std::shared_ptr& type, const std::string& json_string) { - return ArrayFromJSON(type, util::string_view(json_string)); + return ArrayFromJSON(type, std::string_view(json_string)); } Result> ArrayFromJSON(const std::shared_ptr& type, const char* json_string) { - return ArrayFromJSON(type, util::string_view(json_string)); + return ArrayFromJSON(type, std::string_view(json_string)); } Status ChunkedArrayFromJSON(const std::shared_ptr& type, @@ -948,8 +948,8 @@ Status ChunkedArrayFromJSON(const std::shared_ptr& type, } Status DictArrayFromJSON(const std::shared_ptr& type, - util::string_view indices_json, - util::string_view dictionary_json, std::shared_ptr* out) { + 
std::string_view indices_json, std::string_view dictionary_json, + std::shared_ptr* out) { if (type->id() != Type::DICTIONARY) { return Status::TypeError("DictArrayFromJSON requires dictionary type, got ", *type); } @@ -965,8 +965,8 @@ Status DictArrayFromJSON(const std::shared_ptr& type, .Value(out); } -Status ScalarFromJSON(const std::shared_ptr& type, - util::string_view json_string, std::shared_ptr* out) { +Status ScalarFromJSON(const std::shared_ptr& type, std::string_view json_string, + std::shared_ptr* out) { std::shared_ptr converter; RETURN_NOT_OK(GetConverter(type, &converter)); @@ -985,7 +985,7 @@ Status ScalarFromJSON(const std::shared_ptr& type, } Status DictScalarFromJSON(const std::shared_ptr& type, - util::string_view index_json, util::string_view dictionary_json, + std::string_view index_json, std::string_view dictionary_json, std::shared_ptr* out) { if (type->id() != Type::DICTIONARY) { return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type); diff --git a/cpp/src/arrow/ipc/json_simple.h b/cpp/src/arrow/ipc/json_simple.h index 2fb2e838375..3a730ee6a3f 100644 --- a/cpp/src/arrow/ipc/json_simple.h +++ b/cpp/src/arrow/ipc/json_simple.h @@ -21,10 +21,10 @@ #include #include +#include #include "arrow/status.h" #include "arrow/type_fwd.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -42,7 +42,7 @@ Result> ArrayFromJSON(const std::shared_ptr&, ARROW_EXPORT Result> ArrayFromJSON(const std::shared_ptr&, - util::string_view json); + std::string_view json); ARROW_EXPORT Result> ArrayFromJSON(const std::shared_ptr&, @@ -54,17 +54,16 @@ Status ChunkedArrayFromJSON(const std::shared_ptr& type, std::shared_ptr* out); ARROW_EXPORT -Status DictArrayFromJSON(const std::shared_ptr&, util::string_view indices_json, - util::string_view dictionary_json, std::shared_ptr* out); +Status DictArrayFromJSON(const std::shared_ptr&, std::string_view indices_json, + std::string_view dictionary_json, std::shared_ptr* out); ARROW_EXPORT -Status ScalarFromJSON(const std::shared_ptr&, util::string_view json, +Status ScalarFromJSON(const std::shared_ptr&, std::string_view json, std::shared_ptr* out); ARROW_EXPORT -Status DictScalarFromJSON(const std::shared_ptr&, util::string_view index_json, - util::string_view dictionary_json, - std::shared_ptr* out); +Status DictScalarFromJSON(const std::shared_ptr&, std::string_view index_json, + std::string_view dictionary_json, std::shared_ptr* out); } // namespace json } // namespace internal diff --git a/cpp/src/arrow/ipc/read_write_test.cc b/cpp/src/arrow/ipc/read_write_test.cc index be6fd513e5e..b556c8ed34b 100644 --- a/cpp/src/arrow/ipc/read_write_test.cc +++ b/cpp/src/arrow/ipc/read_write_test.cc @@ -381,7 +381,7 @@ class IpcTestFixture : public io::MemoryMapFixture, public ExtensionTypesMixin { ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ipc-test-")); } - std::string TempFile(util::string_view file) { + std::string TempFile(std::string_view file) { return temp_dir_->path().Join(std::string(file)).ValueOrDie().ToString(); } @@ -891,7 +891,7 @@ class RecursionLimits : public ::testing::Test, public io::MemoryMapFixture { ASSERT_OK_AND_ASSIGN(temp_dir_, TemporaryDir::Make("ipc-recursion-limits-test-")); } - std::string TempFile(util::string_view file) { + std::string TempFile(std::string_view file) { return temp_dir_->path().Join(std::string(file)).ValueOrDie().ToString(); } diff --git a/cpp/src/arrow/json/chunked_builder_test.cc b/cpp/src/arrow/json/chunked_builder_test.cc index 
2d89ab9b026..d1d6e5e5fc3 100644 --- a/cpp/src/arrow/json/chunked_builder_test.cc +++ b/cpp/src/arrow/json/chunked_builder_test.cc @@ -35,7 +35,7 @@ namespace arrow { namespace json { -using util::string_view; +using std::string_view; using internal::checked_cast; using internal::GetCpuThreadPool; diff --git a/cpp/src/arrow/json/chunker.cc b/cpp/src/arrow/json/chunker.cc index b4b4d31eb94..362d8e13f5f 100644 --- a/cpp/src/arrow/json/chunker.cc +++ b/cpp/src/arrow/json/chunker.cc @@ -18,6 +18,7 @@ #include "arrow/json/chunker.h" #include +#include #include #include @@ -28,12 +29,11 @@ #include "arrow/json/options.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" namespace arrow { using internal::make_unique; -using util::string_view; +using std::string_view; namespace json { @@ -140,7 +140,7 @@ class ParsingBoundaryFinder : public BoundaryFinder { return Status::OK(); } - Status FindLast(util::string_view block, int64_t* out_pos) override { + Status FindLast(std::string_view block, int64_t* out_pos) override { const size_t block_length = block.size(); size_t consumed_length = 0; while (consumed_length < block_length) { @@ -164,7 +164,7 @@ class ParsingBoundaryFinder : public BoundaryFinder { return Status::OK(); } - Status FindNth(util::string_view partial, util::string_view block, int64_t count, + Status FindNth(std::string_view partial, std::string_view block, int64_t count, int64_t* out_pos, int64_t* num_found) override { return Status::NotImplemented("ParsingBoundaryFinder::FindNth"); } diff --git a/cpp/src/arrow/json/chunker_test.cc b/cpp/src/arrow/json/chunker_test.cc index 1b4ea4d0824..ed1328fa601 100644 --- a/cpp/src/arrow/json/chunker_test.cc +++ b/cpp/src/arrow/json/chunker_test.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include @@ -28,16 +29,19 @@ #include "arrow/json/test_common.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" +#include "arrow/util/string.h" namespace arrow { + +using internal::StartsWith; + namespace json { // Use no nested objects and no string literals containing braces in this test. // This way the positions of '{' and '}' can be used as simple proxies // for object begin/end. 
-using util::string_view; +using std::string_view; template static std::shared_ptr join(Lines&& lines, std::string delimiter, @@ -154,10 +158,10 @@ void AssertStraddledChunking(Chunker& chunker, const std::shared_ptr& bu AssertChunking(chunker, first_half, 1); std::shared_ptr first_whole, partial; ASSERT_OK(chunker.Process(first_half, &first_whole, &partial)); - ASSERT_TRUE(string_view(*first_half).starts_with(string_view(*first_whole))); + ASSERT_TRUE(StartsWith(std::string_view(*first_half), std::string_view(*first_whole))); std::shared_ptr completion, rest; ASSERT_OK(chunker.ProcessWithPartial(partial, second_half, &completion, &rest)); - ASSERT_TRUE(string_view(*second_half).starts_with(string_view(*completion))); + ASSERT_TRUE(StartsWith(std::string_view(*second_half), std::string_view(*completion))); std::shared_ptr straddling; ASSERT_OK_AND_ASSIGN(straddling, ConcatenateBuffers({partial, completion})); auto length = ConsumeWholeObject(&straddling); diff --git a/cpp/src/arrow/json/converter.cc b/cpp/src/arrow/json/converter.cc index a2f584c0b7f..d677be25ae4 100644 --- a/cpp/src/arrow/json/converter.cc +++ b/cpp/src/arrow/json/converter.cc @@ -18,6 +18,7 @@ #include "arrow/json/converter.h" #include +#include #include #include "arrow/array.h" @@ -30,13 +31,12 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/decimal.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" #include "arrow/util/value_parsing.h" namespace arrow { using internal::checked_cast; -using util::string_view; +using std::string_view; namespace json { diff --git a/cpp/src/arrow/json/object_parser.cc b/cpp/src/arrow/json/object_parser.cc index c857cd537e7..ba4a42aec4c 100644 --- a/cpp/src/arrow/json/object_parser.cc +++ b/cpp/src/arrow/json/object_parser.cc @@ -28,7 +28,7 @@ namespace rj = arrow::rapidjson; class ObjectParser::Impl { public: - Status Parse(arrow::util::string_view json) { + Status Parse(std::string_view json) { document_.Parse(reinterpret_cast(json.data()), static_cast(json.size())); @@ -70,7 +70,7 @@ ObjectParser::ObjectParser() : impl_(new ObjectParser::Impl()) {} ObjectParser::~ObjectParser() = default; -Status ObjectParser::Parse(arrow::util::string_view json) { return impl_->Parse(json); } +Status ObjectParser::Parse(std::string_view json) { return impl_->Parse(json); } Result ObjectParser::GetString(const char* key) const { return impl_->GetString(key); diff --git a/cpp/src/arrow/json/object_parser.h b/cpp/src/arrow/json/object_parser.h index ef93201651a..8f23923d1ce 100644 --- a/cpp/src/arrow/json/object_parser.h +++ b/cpp/src/arrow/json/object_parser.h @@ -18,9 +18,9 @@ #pragma once #include +#include #include "arrow/result.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -34,7 +34,7 @@ class ARROW_EXPORT ObjectParser { ObjectParser(); ~ObjectParser(); - Status Parse(arrow::util::string_view json); + Status Parse(std::string_view json); Result GetString(const char* key) const; Result GetBool(const char* key) const; diff --git a/cpp/src/arrow/json/object_writer.cc b/cpp/src/arrow/json/object_writer.cc index 06d09f81e94..3277807880c 100644 --- a/cpp/src/arrow/json/object_writer.cc +++ b/cpp/src/arrow/json/object_writer.cc @@ -32,7 +32,7 @@ class ObjectWriter::Impl { public: Impl() : root_(rj::kObjectType) {} - void SetString(arrow::util::string_view key, arrow::util::string_view value) { + void SetString(std::string_view key, std::string_view value) { rj::Document::AllocatorType& allocator = document_.GetAllocator(); 
rj::Value str_key(key.data(), allocator); @@ -41,7 +41,7 @@ class ObjectWriter::Impl { root_.AddMember(str_key, str_value, allocator); } - void SetBool(arrow::util::string_view key, bool value) { + void SetBool(std::string_view key, bool value) { rj::Document::AllocatorType& allocator = document_.GetAllocator(); rj::Value str_key(key.data(), allocator); @@ -66,12 +66,11 @@ ObjectWriter::ObjectWriter() : impl_(new ObjectWriter::Impl()) {} ObjectWriter::~ObjectWriter() = default; -void ObjectWriter::SetString(arrow::util::string_view key, - arrow::util::string_view value) { +void ObjectWriter::SetString(std::string_view key, std::string_view value) { impl_->SetString(key, value); } -void ObjectWriter::SetBool(arrow::util::string_view key, bool value) { +void ObjectWriter::SetBool(std::string_view key, bool value) { impl_->SetBool(key, value); } diff --git a/cpp/src/arrow/json/object_writer.h b/cpp/src/arrow/json/object_writer.h index 55ff0ce52bc..b15b09dbdac 100644 --- a/cpp/src/arrow/json/object_writer.h +++ b/cpp/src/arrow/json/object_writer.h @@ -18,8 +18,8 @@ #pragma once #include +#include -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" namespace arrow { @@ -33,8 +33,8 @@ class ARROW_EXPORT ObjectWriter { ObjectWriter(); ~ObjectWriter(); - void SetString(arrow::util::string_view key, arrow::util::string_view value); - void SetBool(arrow::util::string_view key, bool value); + void SetString(std::string_view key, std::string_view value); + void SetBool(std::string_view key, bool value); std::string Serialize(); diff --git a/cpp/src/arrow/json/parser.cc b/cpp/src/arrow/json/parser.cc index 815fa7dc7b7..3774b578a83 100644 --- a/cpp/src/arrow/json/parser.cc +++ b/cpp/src/arrow/json/parser.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -36,7 +37,6 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/logging.h" #include "arrow/util/make_unique.h" -#include "arrow/util/string_view.h" #include "arrow/util/trie.h" #include "arrow/visit_type_inline.h" @@ -45,7 +45,7 @@ namespace arrow { using internal::BitsetStack; using internal::checked_cast; using internal::make_unique; -using util::string_view; +using std::string_view; namespace json { @@ -89,7 +89,7 @@ static arrow::internal::Trie MakeFromTagTrie() { Kind::type Kind::FromTag(const std::shared_ptr& tag) { static arrow::internal::Trie name_to_kind = MakeFromTagTrie(); DCHECK_NE(tag->FindKey("json_kind"), -1); - util::string_view name = tag->value(tag->FindKey("json_kind")); + std::string_view name = tag->value(tag->FindKey("json_kind")); DCHECK_NE(name_to_kind.Find(name), -1); return static_cast(name_to_kind.Find(name)); } diff --git a/cpp/src/arrow/json/parser_test.cc b/cpp/src/arrow/json/parser_test.cc index 2a44ed8375e..e1f346bda3b 100644 --- a/cpp/src/arrow/json/parser_test.cc +++ b/cpp/src/arrow/json/parser_test.cc @@ -21,6 +21,7 @@ #include #include +#include #include #include @@ -29,7 +30,6 @@ #include "arrow/status.h" #include "arrow/testing/gtest_util.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/string_view.h" namespace arrow { @@ -37,7 +37,7 @@ using internal::checked_cast; namespace json { -using util::string_view; +using std::string_view; void AssertUnconvertedStructArraysEqual(const StructArray& expected, const StructArray& actual); diff --git a/cpp/src/arrow/json/reader.cc b/cpp/src/arrow/json/reader.cc index 18aed0235ff..85e527c8bda 100644 --- a/cpp/src/arrow/json/reader.cc +++ b/cpp/src/arrow/json/reader.cc @@ -17,6 +17,7 @@ #include 
"arrow/json/reader.h" +#include #include #include @@ -33,13 +34,12 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/iterator.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" #include "arrow/util/task_group.h" #include "arrow/util/thread_pool.h" namespace arrow { -using util::string_view; +using std::string_view; using internal::checked_cast; using internal::GetCpuThreadPool; diff --git a/cpp/src/arrow/json/reader_test.cc b/cpp/src/arrow/json/reader_test.cc index 976343b5211..4037bf0be66 100644 --- a/cpp/src/arrow/json/reader_test.cc +++ b/cpp/src/arrow/json/reader_test.cc @@ -31,7 +31,7 @@ namespace arrow { namespace json { -using util::string_view; +using std::string_view; using internal::checked_cast; @@ -43,7 +43,7 @@ class ReaderTest : public ::testing::TestWithParam { read_options_, parse_options_)); } - void SetUpReader(util::string_view input) { + void SetUpReader(std::string_view input) { ASSERT_OK(MakeStream(input, &input_)); SetUpReader(); } diff --git a/cpp/src/arrow/json/test_common.h b/cpp/src/arrow/json/test_common.h index 508be0c9102..18007a49638 100644 --- a/cpp/src/arrow/json/test_common.h +++ b/cpp/src/arrow/json/test_common.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -34,7 +35,6 @@ #include "arrow/testing/gtest_util.h" #include "arrow/type.h" #include "arrow/util/checked_cast.h" -#include "arrow/util/string_view.h" #include "arrow/visit_type_inline.h" #include "rapidjson/document.h" @@ -51,7 +51,7 @@ namespace json { namespace rj = arrow::rapidjson; using rj::StringBuffer; -using util::string_view; +using std::string_view; using Writer = rj::Writer; inline static Status OK(bool ok) { return ok ? Status::OK() : Status::Invalid(""); } @@ -216,7 +216,7 @@ static inline std::string PrettyPrint(string_view one_line) { } template -std::string RowsOfOneColumn(util::string_view name, std::initializer_list values, +std::string RowsOfOneColumn(std::string_view name, std::initializer_list values, decltype(std::to_string(*values.begin()))* = nullptr) { std::stringstream ss; for (auto value : values) { @@ -225,7 +225,7 @@ std::string RowsOfOneColumn(util::string_view name, std::initializer_list val return ss.str(); } -inline std::string RowsOfOneColumn(util::string_view name, +inline std::string RowsOfOneColumn(std::string_view name, std::initializer_list values) { std::stringstream ss; for (auto value : values) { diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index ac92287f1bc..61d308a145b 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -26,6 +26,7 @@ #include #include // IWYU pragma: keep #include +#include #include #include @@ -41,7 +42,6 @@ #include "arrow/util/int_util_overflow.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/string.h" -#include "arrow/util/string_view.h" #include "arrow/vendored/datetime.h" #include "arrow/visit_array_inline.h" @@ -57,8 +57,8 @@ class PrettyPrinter { PrettyPrinter(const PrettyPrintOptions& options, std::ostream* sink) : options_(options), indent_(options.indent), sink_(sink) {} - inline void Write(util::string_view data); - inline void WriteIndented(util::string_view data); + inline void Write(std::string_view data); + inline void WriteIndented(std::string_view data); inline void Newline(); inline void Indent(); inline void IndentAfterNewline(); @@ -103,9 +103,9 @@ void PrettyPrinter::CloseArray(const Array& array) { (*sink_) << "]"; } -void PrettyPrinter::Write(util::string_view data) { (*sink_) 
<< data; } +void PrettyPrinter::Write(std::string_view data) { (*sink_) << data; } -void PrettyPrinter::WriteIndented(util::string_view data) { +void PrettyPrinter::WriteIndented(std::string_view data) { Indent(); Write(data); } @@ -173,7 +173,7 @@ class ArrayPrinter : public PrettyPrinter { template Status WritePrimitiveValues(const ArrayType& array, Formatter* formatter) { - auto appender = [&](util::string_view v) { (*sink_) << v; }; + auto appender = [&](std::string_view v) { (*sink_) << v; }; auto format_func = [&](int64_t i) { (*formatter)(array.GetView(i), appender); return Status::OK(); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 5ed92f09476..fcf44fe82f4 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -839,16 +839,16 @@ struct ScalarParseImpl { return std::move(out_); } - ScalarParseImpl(std::shared_ptr type, util::string_view s) + ScalarParseImpl(std::shared_ptr type, std::string_view s) : type_(std::move(type)), s_(s) {} std::shared_ptr type_; - util::string_view s_; + std::string_view s_; std::shared_ptr out_; }; Result> Scalar::Parse(const std::shared_ptr& type, - util::string_view s) { + std::string_view s) { return ScalarParseImpl{type, s}.Finish(); } @@ -871,9 +871,8 @@ std::shared_ptr FormatToBuffer(Formatter&& formatter, const ScalarType& if (!from.is_valid) { return Buffer::FromString("null"); } - return formatter(from.value, [&](util::string_view v) { - return Buffer::FromString(std::string(v)); - }); + return formatter( + from.value, [&](std::string_view v) { return Buffer::FromString(std::string(v)); }); } // error fallback @@ -993,8 +992,7 @@ Status CastImpl(const DateScalar& from, TimestampScalar* to) { // string to any template Status CastImpl(const StringScalar& from, ScalarType* to) { - ARROW_ASSIGN_OR_RAISE(auto out, - Scalar::Parse(to->type, util::string_view(*from.value))); + ARROW_ASSIGN_OR_RAISE(auto out, Scalar::Parse(to->type, std::string_view(*from.value))); to->value = std::move(checked_cast(*out).value); return Status::OK(); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 22532041eca..66e18631334 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -35,7 +36,6 @@ #include "arrow/type_traits.h" #include "arrow/util/compare.h" #include "arrow/util/decimal.h" -#include "arrow/util/string_view.h" #include "arrow/util/visibility.h" #include "arrow/visit_type_inline.h" @@ -95,7 +95,7 @@ struct ARROW_EXPORT Scalar : public std::enable_shared_from_this, Status ValidateFull() const; static Result> Parse(const std::shared_ptr& type, - util::string_view repr); + std::string_view repr); // TODO(bkietz) add compute::CastOptions Result> CastTo(std::shared_ptr to) const; @@ -140,7 +140,7 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { /// \brief Get a mutable pointer to the value of this scalar. May be null. virtual void* mutable_data() = 0; /// \brief Get an immutable view of the value of this scalar as bytes. 
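Scalar::Parse keeps its call pattern across the migration; only the view type changes. An illustrative sketch of the new signature in use (ParseEpochPlusOne is a hypothetical name, not part of this patch):

#include <string_view>
#include "arrow/scalar.h"

arrow::Result<std::shared_ptr<arrow::Scalar>> ParseEpochPlusOne() {
  // String literals convert implicitly to std::string_view.
  std::string_view repr = "1970-01-01 00:00:01";
  return arrow::Scalar::Parse(arrow::timestamp(arrow::TimeUnit::SECOND), repr);
}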
diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h
index 22532041eca..66e18631334 100644
--- a/cpp/src/arrow/scalar.h
+++ b/cpp/src/arrow/scalar.h
@@ -23,6 +23,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 
@@ -35,7 +36,6 @@
 #include "arrow/type_traits.h"
 #include "arrow/util/compare.h"
 #include "arrow/util/decimal.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 #include "arrow/visit_type_inline.h"
 
@@ -95,7 +95,7 @@ struct ARROW_EXPORT Scalar : public std::enable_shared_from_this,
   Status ValidateFull() const;
 
   static Result> Parse(const std::shared_ptr& type,
-                                               util::string_view repr);
+                                               std::string_view repr);
 
   // TODO(bkietz) add compute::CastOptions
   Result> CastTo(std::shared_ptr to) const;
@@ -140,7 +140,7 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar {
   /// \brief Get a mutable pointer to the value of this scalar. May be null.
   virtual void* mutable_data() = 0;
   /// \brief Get an immutable view of the value of this scalar as bytes.
-  virtual util::string_view view() const = 0;
+  virtual std::string_view view() const = 0;
 };
 
 template 
@@ -159,8 +159,8 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase {
   ValueType value{};
 
   void* mutable_data() override { return &value; }
-  util::string_view view() const override {
-    return util::string_view(reinterpret_cast(&value), sizeof(ValueType));
+  std::string_view view() const override {
+    return std::string_view(reinterpret_cast(&value), sizeof(ValueType));
   };
 };
 
@@ -245,8 +245,8 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase {
   void* mutable_data() override {
     return value ? reinterpret_cast(value->mutable_data()) : NULLPTR;
   }
-  util::string_view view() const override {
-    return value ? util::string_view(*value) : util::string_view();
+  std::string_view view() const override {
+    return value ? std::string_view(*value) : std::string_view();
   }
 
  protected:
@@ -415,9 +415,9 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase {
     return reinterpret_cast(value.mutable_native_endian_bytes());
   }
 
-  util::string_view view() const override {
-    return util::string_view(reinterpret_cast(value.native_endian_bytes()),
-                             ValueType::kByteWidth);
+  std::string_view view() const override {
+    return std::string_view(reinterpret_cast(value.native_endian_bytes()),
+                            ValueType::kByteWidth);
   }
 
   ValueType value;
@@ -561,7 +561,7 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase {
     return internal::checked_cast(*value.index)
        .mutable_data();
   }
-  util::string_view view() const override {
+  std::string_view view() const override {
    return internal::checked_cast(*value.index)
        .view();
  }
diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc
index bf001fc6fd9..42315ca1b62 100644
--- a/cpp/src/arrow/scalar_test.cc
+++ b/cpp/src/arrow/scalar_test.cc
@@ -59,7 +59,7 @@ void AssertMakeScalar(const Scalar& expected, MakeScalarArgs&&... args) {
   AssertScalarsEqual(expected, *scalar, /*verbose=*/true);
 }
 
-void AssertParseScalar(const std::shared_ptr& type, const util::string_view& s,
+void AssertParseScalar(const std::shared_ptr& type, const std::string_view& s,
                        const Scalar& expected) {
   ASSERT_OK_AND_ASSIGN(auto scalar, Scalar::Parse(type, s));
   ASSERT_OK(scalar->Validate());
@@ -643,11 +643,11 @@ TEST(TestFixedSizeBinaryScalar, MakeScalar) {
 
   AssertMakeScalar(FixedSizeBinaryScalar(buf, type), type, buf);
 
-  AssertParseScalar(type, util::string_view(data), FixedSizeBinaryScalar(buf, type));
+  AssertParseScalar(type, std::string_view(data), FixedSizeBinaryScalar(buf, type));
 
   // Wrong length
   ASSERT_RAISES(Invalid, MakeScalar(type, Buffer::FromString(data.substr(3))).status());
-  ASSERT_RAISES(Invalid, Scalar::Parse(type, util::string_view(data).substr(3)).status());
+  ASSERT_RAISES(Invalid, Scalar::Parse(type, std::string_view(data).substr(3)).status());
 }
 
 TEST(TestFixedSizeBinaryScalar, ValidateErrors) {
@@ -831,7 +831,7 @@ TEST(TestTimestampScalars, MakeScalar) {
   auto type3 = timestamp(TimeUnit::MICRO);
   auto type4 = timestamp(TimeUnit::NANO);
 
-  util::string_view epoch_plus_1s = "1970-01-01 00:00:01";
+  std::string_view epoch_plus_1s = "1970-01-01 00:00:01";
 
   AssertMakeScalar(TimestampScalar(1, type1), type1, int64_t(1));
   AssertParseScalar(type1, epoch_plus_1s, TimestampScalar(1000, type1));
@@ -992,7 +992,7 @@ TEST(TestDayTimeIntervalScalars, Basics) {
 
 TYPED_TEST(TestNumericScalar, Cast) {
   auto type = TypeTraits::type_singleton();
 
-  for (util::string_view repr : {"0", "1", "3"}) {
+  for (std::string_view repr : {"0", "1", "3"}) {
     std::shared_ptr scalar;
     ASSERT_OK_AND_ASSIGN(scalar, Scalar::Parse(type, repr));
 
@@ -1015,7 +1015,7 @@ TYPED_TEST(TestNumericScalar, Cast) {
     if (is_integer_type::value) {
       ASSERT_OK_AND_ASSIGN(auto cast_to_string, scalar->CastTo(utf8()));
       ASSERT_EQ(
-          util::string_view(*checked_cast(*cast_to_string).value),
+          std::string_view(*checked_cast(*cast_to_string).value),
           repr);
     }
   }
@@ -1609,7 +1609,7 @@ class TestExtensionScalar : public ::testing::Test {
   }
 
  protected:
-  ExtensionScalar MakeUuidScalar(util::string_view value) {
+  ExtensionScalar MakeUuidScalar(std::string_view value) {
     return ExtensionScalar(std::make_shared(
                                std::make_shared(value), storage_type_),
                            type_);
@@ -1618,10 +1618,9 @@ class TestExtensionScalar : public ::testing::Test {
   std::shared_ptr type_, storage_type_;
   const UuidType* uuid_type_{nullptr};
 
-  const util::string_view uuid_string1_{UUID_STRING1};
-  const util::string_view uuid_string2_{UUID_STRING2};
-  const util::string_view uuid_json_{"[\"" UUID_STRING1 "\", \"" UUID_STRING2
-                                     "\", null]"};
+  const std::string_view uuid_string1_{UUID_STRING1};
+  const std::string_view uuid_string2_{UUID_STRING2};
+  const std::string_view uuid_json_{"[\"" UUID_STRING1 "\", \"" UUID_STRING2 "\", null]"};
 };
 
 #undef UUID_STRING1
diff --git a/cpp/src/arrow/stl_iterator_test.cc b/cpp/src/arrow/stl_iterator_test.cc
index 652a66cb516..3fe57ebc0d4 100644
--- a/cpp/src/arrow/stl_iterator_test.cc
+++ b/cpp/src/arrow/stl_iterator_test.cc
@@ -128,11 +128,11 @@ TEST(ArrayIterator, RangeFor) {
 TEST(ArrayIterator, String) {
   auto array = checked_pointer_cast(
       ArrayFromJSON(utf8(), R"(["foo", "bar", null, "quux"])"));
-  std::vector> values;
+  std::vector> values;
   for (const auto v : *array) {
     values.push_back(v);
   }
-  std::vector> expected{"foo", "bar", {}, "quux"};
+  std::vector> expected{"foo", "bar", {}, "quux"};
   ASSERT_EQ(values, expected);
 }
 
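After the migration, range-for over a string array yields std::optional<std::string_view> elements (nulls become empty optionals), exactly as the test above exercises. A hedged sketch of the same pattern as a helper (CollectValues is an illustrative name):

#include <optional>
#include <string_view>
#include <vector>
#include "arrow/array.h"

std::vector<std::optional<std::string_view>> CollectValues(
    const arrow::StringArray& array) {
  std::vector<std::optional<std::string_view>> values;
  for (const auto v : array) {  // ArrayIterator-based range-for, as in the test
    values.push_back(v);
  }
  return values;
}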
@@ -150,11 +150,11 @@ TEST(ArrayIterator, Boolean) {
 TEST(ArrayIterator, FixedSizeBinary) {
   auto array = checked_pointer_cast(
       ArrayFromJSON(fixed_size_binary(3), R"(["foo", "bar", null, "quu"])"));
-  std::vector> values;
+  std::vector> values;
   for (const auto v : *array) {
     values.push_back(v);
   }
-  std::vector> expected{"foo", "bar", {}, "quu"};
+  std::vector> expected{"foo", "bar", {}, "quu"};
   ASSERT_EQ(values, expected);
 }
 
diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc
index 2ba944e41f1..84879321ff1 100644
--- a/cpp/src/arrow/testing/gtest_util.cc
+++ b/cpp/src/arrow/testing/gtest_util.cc
@@ -410,14 +410,14 @@ void AssertDatumsApproxEqual(const Datum& expected, const Datum& actual, bool ve
 }
 
 std::shared_ptr ArrayFromJSON(const std::shared_ptr& type,
-                                     util::string_view json) {
+                                     std::string_view json) {
   EXPECT_OK_AND_ASSIGN(auto out, ipc::internal::json::ArrayFromJSON(type, json));
   return out;
 }
 
 std::shared_ptr DictArrayFromJSON(const std::shared_ptr& type,
-                                         util::string_view indices_json,
-                                         util::string_view dictionary_json) {
+                                         std::string_view indices_json,
+                                         std::string_view dictionary_json) {
   std::shared_ptr out;
   ABORT_NOT_OK(
       ipc::internal::json::DictArrayFromJSON(type, indices_json, dictionary_json, &out));
@@ -432,7 +432,7 @@ std::shared_ptr ChunkedArrayFromJSON(const std::shared_ptr RecordBatchFromJSON(const std::shared_ptr& schema,
-                                                 util::string_view json) {
+                                                 std::string_view json) {
   // Parse as a StructArray
   auto struct_type = struct_(schema->fields());
   std::shared_ptr struct_array = ArrayFromJSON(struct_type, json);
@@ -442,15 +442,15 @@ std::shared_ptr RecordBatchFromJSON(const std::shared_ptr&
 }
 
 std::shared_ptr ScalarFromJSON(const std::shared_ptr& type,
-                                       util::string_view json) {
+                                       std::string_view json) {
   std::shared_ptr out;
   ABORT_NOT_OK(ipc::internal::json::ScalarFromJSON(type, json, &out));
   return out;
 }
 
 std::shared_ptr DictScalarFromJSON(const std::shared_ptr& type,
-                                           util::string_view index_json,
-                                           util::string_view dictionary_json) {
+                                           std::string_view index_json,
+                                           std::string_view dictionary_json) {
   std::shared_ptr out;
   ABORT_NOT_OK(
       ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out));
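The JSON test helpers stay source-compatible for callers: string literals convert implicitly to std::string_view. A minimal usage sketch (Example is an illustrative name):

#include <string_view>
#include "arrow/testing/gtest_util.h"

void Example() {
  // Literals bind to the new std::string_view parameters unchanged.
  auto array = arrow::ArrayFromJSON(arrow::int32(), "[1, 2, null]");
  auto scalar = arrow::ScalarFromJSON(arrow::utf8(), R"("hello")");
}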
diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h
index b6bfcb8e2d3..e21a2888e85 100644
--- a/cpp/src/arrow/testing/gtest_util.h
+++ b/cpp/src/arrow/testing/gtest_util.h
@@ -25,6 +25,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 #include 
@@ -40,7 +41,6 @@
 #include "arrow/type_traits.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/string_builder.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/type_fwd.h"
 
 // NOTE: failing must be inline in the macros below, to get correct file / line number
@@ -316,16 +316,16 @@ ARROW_TESTING_EXPORT void TestInitialized(const Array& array);
 
 ARROW_TESTING_EXPORT
 std::shared_ptr ArrayFromJSON(const std::shared_ptr&,
-                                     util::string_view json);
+                                     std::string_view json);
 
 ARROW_TESTING_EXPORT
 std::shared_ptr DictArrayFromJSON(const std::shared_ptr& type,
-                                         util::string_view indices_json,
-                                         util::string_view dictionary_json);
+                                         std::string_view indices_json,
+                                         std::string_view dictionary_json);
 
 ARROW_TESTING_EXPORT
 std::shared_ptr RecordBatchFromJSON(const std::shared_ptr&,
-                                                 util::string_view);
+                                                 std::string_view);
 
 ARROW_TESTING_EXPORT
 std::shared_ptr ChunkedArrayFromJSON(const std::shared_ptr&,
@@ -333,12 +333,12 @@ std::shared_ptr ChunkedArrayFromJSON(const std::shared_ptr ScalarFromJSON(const std::shared_ptr&,
-                                       util::string_view json);
+                                       std::string_view json);
 
 ARROW_TESTING_EXPORT
 std::shared_ptr DictScalarFromJSON(const std::shared_ptr&,
-                                           util::string_view index_json,
-                                           util::string_view dictionary_json);
+                                           std::string_view index_json,
+                                           std::string_view dictionary_json);
 
 ARROW_TESTING_EXPORT
 std::shared_ptr TableFromJSON(const std::shared_ptr&,
@@ -530,15 +530,3 @@ class ARROW_TESTING_EXPORT GatingTask {
 };
 
 }  // namespace arrow
-
-namespace nonstd {
-namespace sv_lite {
-
-// Without this hint, GTest will print string_views as a container of char
-template >
-void PrintTo(const basic_string_view& view, std::ostream* os) {
-  *os << view;
-}
-
-}  // namespace sv_lite
-}  // namespace nonstd
diff --git a/cpp/src/arrow/testing/json_internal.cc b/cpp/src/arrow/testing/json_internal.cc
index c88e95df016..c1d45aa2e08 100644
--- a/cpp/src/arrow/testing/json_internal.cc
+++ b/cpp/src/arrow/testing/json_internal.cc
@@ -472,7 +472,7 @@ class ArrayWriter {
     return Status::OK();
   }
 
-  void WriteRawNumber(util::string_view v) {
+  void WriteRawNumber(std::string_view v) {
     // Avoid RawNumber() as it misleadingly adds quotes
     // (see https://github.com/Tencent/rapidjson/pull/1155)
     writer_->RawValue(v.data(), v.size(), rj::kNumberType);
@@ -503,7 +503,7 @@ class ArrayWriter {
     static const std::string null_string = "0";
     for (int64_t i = 0; i < arr.length(); ++i) {
       if (arr.IsValid(i)) {
-        fmt(arr.Value(i), [&](util::string_view repr) {
+        fmt(arr.Value(i), [&](std::string_view repr) {
           writer_->String(repr.data(), static_cast(repr.size()));
         });
       } else {
@@ -630,7 +630,7 @@ class ArrayWriter {
     // Represent 64-bit integers as strings, as JSON numbers cannot represent
     // them exactly.
     ::arrow::internal::StringFormatter::ArrowType> formatter;
-    auto append = [this](util::string_view v) {
+    auto append = [this](std::string_view v) {
       writer_->String(v.data(), static_cast(v.size()));
       return Status::OK();
     };
diff --git a/cpp/src/arrow/testing/matchers.h b/cpp/src/arrow/testing/matchers.h
index 4d5bb695757..fa2222ee1ab 100644
--- a/cpp/src/arrow/testing/matchers.h
+++ b/cpp/src/arrow/testing/matchers.h
@@ -412,7 +412,7 @@ DataEqMatcher DataEq(Data&& dat) {
 
 /// Constructs an array with ArrayFromJSON against which arguments are matched
 inline DataEqMatcher DataEqArray(const std::shared_ptr& type,
-                                 util::string_view json) {
+                                 std::string_view json) {
   return DataEq(ArrayFromJSON(type, json));
 }
 
@@ -446,7 +446,7 @@ DataEqMatcher DataEqArray(T type, const std::vector>& v
 
 /// Constructs a scalar with ScalarFromJSON against which arguments are matched
 inline DataEqMatcher DataEqScalar(const std::shared_ptr& type,
-                                  util::string_view json) {
+                                  std::string_view json) {
   return DataEq(ScalarFromJSON(type, json));
 }
 
diff --git a/cpp/src/arrow/testing/random_test.cc b/cpp/src/arrow/testing/random_test.cc
index 588c4f22687..30988ac0d3c 100644
--- a/cpp/src/arrow/testing/random_test.cc
+++ b/cpp/src/arrow/testing/random_test.cc
@@ -360,7 +360,7 @@ TEST(TypeSpecificTests, RepeatedStrings) {
   AssertTypeEqual(field->type(), base_array->type());
   auto array = internal::checked_pointer_cast(base_array);
   ASSERT_OK(array->ValidateFull());
-  util::string_view singular_value = array->GetView(0);
+  std::string_view singular_value = array->GetView(0);
   for (auto slot : *array) {
     if (!slot.has_value()) continue;
     ASSERT_EQ(slot, singular_value);
diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc
index c91fa234e0a..25e32373196 100644
--- a/cpp/src/arrow/type.cc
+++ b/cpp/src/arrow/type.cc
@@ -1152,35 +1152,35 @@ Result FieldRef::FromDotPath(const std::string& dot_path_arg) {
 
   std::vector children;
 
-  util::string_view dot_path = dot_path_arg;
+  std::string_view dot_path = dot_path_arg;
 
   auto parse_name = [&] {
     std::string name;
     for (;;) {
       auto segment_end = dot_path.find_first_of("\\[.");
-      if (segment_end == util::string_view::npos) {
+      if (segment_end == std::string_view::npos) {
         // dot_path doesn't contain any other special characters; consume all
-        name.append(dot_path.begin(), dot_path.end());
+        name.append(dot_path.data(), dot_path.length());
         dot_path = "";
         break;
       }
 
       if (dot_path[segment_end] != '\\') {
         // segment_end points to a subscript for a new FieldRef
-        name.append(dot_path.begin(), segment_end);
+        name.append(dot_path.data(), segment_end);
         dot_path = dot_path.substr(segment_end);
         break;
       }
 
       if (dot_path.size() == segment_end + 1) {
         // dot_path ends with backslash; consume it all
-        name.append(dot_path.begin(), dot_path.end());
+        name.append(dot_path.data(), dot_path.length());
         dot_path = "";
         break;
       }
 
       // append all characters before backslash, then the character which follows it
-      name.append(dot_path.begin(), segment_end);
+      name.append(dot_path.data(), segment_end);
       name.push_back(dot_path[segment_end + 1]);
       dot_path = dot_path.substr(segment_end + 2);
     }
@@ -1198,7 +1198,7 @@ Result FieldRef::FromDotPath(const std::string& dot_path_arg) {
       }
       case '[': {
         auto subscript_end = dot_path.find_first_not_of("0123456789");
-        if (subscript_end == util::string_view::npos || dot_path[subscript_end] != ']') {
+        if (subscript_end == std::string_view::npos || dot_path[subscript_end] != ']') {
           return Status::Invalid("Dot path '", dot_path_arg,
                                  "' contained an unterminated index");
         }
diff --git a/cpp/src/arrow/util/base64.h b/cpp/src/arrow/util/base64.h
index a46884d17e6..5b80e19d896 100644
--- a/cpp/src/arrow/util/base64.h
+++ b/cpp/src/arrow/util/base64.h
@@ -18,18 +18,18 @@
 #pragma once
 
 #include 
+#include 
 
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
 namespace util {
 
 ARROW_EXPORT
-std::string base64_encode(string_view s);
+std::string base64_encode(std::string_view s);
 
 ARROW_EXPORT
-std::string base64_decode(string_view s);
+std::string base64_decode(std::string_view s);
 
 }  // namespace util
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/bitmap.h b/cpp/src/arrow/util/bitmap.h
index 51a5fac97fb..a6df1e561ee 100644
--- a/cpp/src/arrow/util/bitmap.h
+++ b/cpp/src/arrow/util/bitmap.h
@@ -25,6 +25,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 
 #include "arrow/buffer.h"
@@ -32,11 +33,11 @@
 #include "arrow/util/bitmap_ops.h"
 #include "arrow/util/bitmap_reader.h"
 #include "arrow/util/bitmap_writer.h"
+#include "arrow/util/bytes_view.h"
 #include "arrow/util/compare.h"
 #include "arrow/util/endian.h"
 #include "arrow/util/functional.h"
 #include "arrow/util/string_builder.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -49,7 +50,7 @@ class ARROW_EXPORT Bitmap : public util::ToStringOstreamable,
                             public util::EqualityComparable {
  public:
   template 
-  using View = util::basic_string_view;
+  using View = std::basic_string_view;
 
   Bitmap() = default;
 
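The FromDotPath hunk above switches from iterator arguments to explicit pointer-plus-length calls. The reason, sketched below: nonstd::string_view::begin() was a plain const char*, so mixed calls like append(begin(), count) resolved to std::string::append(const char*, size_type); std::string_view::iterator is implementation-defined and need not be a pointer, so the portable spelling goes through data(). AppendPrefix is an illustrative name:

#include <string>
#include <string_view>

void AppendPrefix(std::string& name, std::string_view path, size_t segment_end) {
  // Pointer + count works with any std::string_view implementation.
  name.append(path.data(), segment_end);
}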
"arrow/util/type_traits.h" #include "arrow/util/visibility.h" diff --git a/cpp/src/arrow/util/string_view.h b/cpp/src/arrow/util/bytes_view.h similarity index 72% rename from cpp/src/arrow/util/string_view.h rename to cpp/src/arrow/util/bytes_view.h index 4a51c2ebd9e..b1aacc96ed8 100644 --- a/cpp/src/arrow/util/string_view.h +++ b/cpp/src/arrow/util/bytes_view.h @@ -17,22 +17,13 @@ #pragma once -#define nssv_CONFIG_SELECT_STRING_VIEW nssv_STRING_VIEW_NONSTD - #include -#include - -#include "arrow/vendored/string_view.hpp" // IWYU pragma: export +#include namespace arrow { namespace util { -using nonstd::string_view; - -template > -using basic_string_view = nonstd::basic_string_view; - -using bytes_view = basic_string_view; +using bytes_view = std::basic_string_view; } // namespace util } // namespace arrow diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc index 7bda91cf100..b5e5e69aa7e 100644 --- a/cpp/src/arrow/util/decimal.cc +++ b/cpp/src/arrow/util/decimal.cc @@ -287,7 +287,7 @@ static void AppendLittleEndianArrayToString(const std::array& array const uint32_t* segment = &segments[num_segments - 1]; internal::StringFormatter format; // First segment is formatted as-is. - format(*segment, [&output](util::string_view formatted) { + format(*segment, [&output](std::string_view formatted) { memcpy(output, formatted.data(), formatted.size()); output += formatted.size(); }); @@ -295,7 +295,7 @@ static void AppendLittleEndianArrayToString(const std::array& array --segment; // Right-pad formatted segment such that e.g. 123 is formatted as "000000123". output += 9; - format(*segment, [output](util::string_view formatted) { + format(*segment, [output](std::string_view formatted) { memcpy(output - formatted.size(), formatted.data(), formatted.size()); }); } @@ -355,7 +355,7 @@ static void AdjustIntegerStringWithScale(int32_t scale, std::string* str) { str->push_back('+'); } internal::StringFormatter format; - format(adjusted_exponent, [str](util::string_view formatted) { + format(adjusted_exponent, [str](std::string_view formatted) { str->append(formatted.data(), formatted.size()); }); return; @@ -397,7 +397,7 @@ std::string Decimal128::ToString(int32_t scale) const { // Iterates over input and for each group of kInt64DecimalDigits multiple out by // the appropriate power of 10 necessary to add source parsed as uint64 and // then adds the parsed value of source. 
diff --git a/cpp/src/arrow/util/decimal.cc b/cpp/src/arrow/util/decimal.cc
index 7bda91cf100..b5e5e69aa7e 100644
--- a/cpp/src/arrow/util/decimal.cc
+++ b/cpp/src/arrow/util/decimal.cc
@@ -287,7 +287,7 @@ static void AppendLittleEndianArrayToString(const std::array& array
   const uint32_t* segment = &segments[num_segments - 1];
   internal::StringFormatter format;
   // First segment is formatted as-is.
-  format(*segment, [&output](util::string_view formatted) {
+  format(*segment, [&output](std::string_view formatted) {
     memcpy(output, formatted.data(), formatted.size());
     output += formatted.size();
   });
@@ -295,7 +295,7 @@ static void AppendLittleEndianArrayToString(const std::array& array
     --segment;
     // Right-pad formatted segment such that e.g. 123 is formatted as "000000123".
     output += 9;
-    format(*segment, [output](util::string_view formatted) {
+    format(*segment, [output](std::string_view formatted) {
       memcpy(output - formatted.size(), formatted.data(), formatted.size());
     });
   }
@@ -355,7 +355,7 @@ static void AdjustIntegerStringWithScale(int32_t scale, std::string* str) {
     str->push_back('+');
   }
   internal::StringFormatter format;
-  format(adjusted_exponent, [str](util::string_view formatted) {
+  format(adjusted_exponent, [str](std::string_view formatted) {
     str->append(formatted.data(), formatted.size());
   });
   return;
@@ -397,7 +397,7 @@ std::string Decimal128::ToString(int32_t scale) const {
 // Iterates over input and for each group of kInt64DecimalDigits multiple out by
 // the appropriate power of 10 necessary to add source parsed as uint64 and
 // then adds the parsed value of source.
-static inline void ShiftAndAdd(const util::string_view& input, uint64_t out[],
+static inline void ShiftAndAdd(const std::string_view& input, uint64_t out[],
                                size_t out_size) {
   for (size_t posn = 0; posn < input.size();) {
     const size_t group_size = std::min(kInt64DecimalDigits, input.size() - posn);
@@ -420,8 +420,8 @@ static inline void ShiftAndAdd(const util::string_view& input, uint64_t out[],
 namespace {
 
 struct DecimalComponents {
-  util::string_view whole_digits;
-  util::string_view fractional_digits;
+  std::string_view whole_digits;
+  std::string_view fractional_digits;
   int32_t exponent = 0;
   char sign = 0;
   bool has_exponent = false;
@@ -436,14 +436,14 @@ inline bool IsDigit(char c) { return c >= '0' && c <= '9'; }
 inline bool StartsExponent(char c) { return c == 'e' || c == 'E'; }
 
 inline size_t ParseDigitsRun(const char* s, size_t start, size_t size,
-                             util::string_view* out) {
+                             std::string_view* out) {
   size_t pos;
   for (pos = start; pos < size; ++pos) {
     if (!IsDigit(s[pos])) {
       break;
     }
   }
-  *out = util::string_view(s + start, pos - start);
+  *out = std::string_view(s + start, pos - start);
   return pos;
 }
 
@@ -508,7 +508,7 @@ inline Status ToArrowStatus(DecimalStatus dstatus, int num_bits) {
 }
 
 template 
-Status DecimalFromString(const char* type_name, const util::string_view& s, Decimal* out,
+Status DecimalFromString(const char* type_name, const std::string_view& s, Decimal* out,
                          int32_t* precision, int32_t* scale) {
   if (s.empty()) {
     return Status::Invalid("Empty string cannot be converted to ", type_name);
@@ -573,33 +573,33 @@ Status DecimalFromString(const char* type_name, const util::string_view& s, Deci
 
 }  // namespace
 
-Status Decimal128::FromString(const util::string_view& s, Decimal128* out,
+Status Decimal128::FromString(const std::string_view& s, Decimal128* out,
                               int32_t* precision, int32_t* scale) {
   return DecimalFromString("decimal128", s, out, precision, scale);
 }
 
 Status Decimal128::FromString(const std::string& s, Decimal128* out, int32_t* precision,
                               int32_t* scale) {
-  return FromString(util::string_view(s), out, precision, scale);
+  return FromString(std::string_view(s), out, precision, scale);
 }
 
 Status Decimal128::FromString(const char* s, Decimal128* out, int32_t* precision,
                               int32_t* scale) {
-  return FromString(util::string_view(s), out, precision, scale);
+  return FromString(std::string_view(s), out, precision, scale);
 }
 
-Result Decimal128::FromString(const util::string_view& s) {
+Result Decimal128::FromString(const std::string_view& s) {
   Decimal128 out;
   RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr));
   return std::move(out);
 }
 
 Result Decimal128::FromString(const std::string& s) {
-  return FromString(util::string_view(s));
+  return FromString(std::string_view(s));
 }
 
 Result Decimal128::FromString(const char* s) {
-  return FromString(util::string_view(s));
+  return FromString(std::string_view(s));
 }
 
 // Helper function used by Decimal128::FromBigEndian
@@ -706,33 +706,33 @@ std::string Decimal256::ToString(int32_t scale) const {
   return str;
 }
 
-Status Decimal256::FromString(const util::string_view& s, Decimal256* out,
+Status Decimal256::FromString(const std::string_view& s, Decimal256* out,
                               int32_t* precision, int32_t* scale) {
   return DecimalFromString("decimal256", s, out, precision, scale);
 }
 
 Status Decimal256::FromString(const std::string& s, Decimal256* out, int32_t* precision,
                               int32_t* scale) {
-  return FromString(util::string_view(s), out, precision, scale);
+  return FromString(std::string_view(s), out, precision, scale);
 }
 
 Status Decimal256::FromString(const char* s, Decimal256* out, int32_t* precision,
                               int32_t* scale) {
-  return FromString(util::string_view(s), out, precision, scale);
+  return FromString(std::string_view(s), out, precision, scale);
 }
 
-Result Decimal256::FromString(const util::string_view& s) {
+Result Decimal256::FromString(const std::string_view& s) {
   Decimal256 out;
   RETURN_NOT_OK(FromString(s, &out, nullptr, nullptr));
   return std::move(out);
 }
 
 Result Decimal256::FromString(const std::string& s) {
-  return FromString(util::string_view(s));
+  return FromString(std::string_view(s));
 }
 
 Result Decimal256::FromString(const char* s) {
-  return FromString(util::string_view(s));
+  return FromString(std::string_view(s));
 }
 
 Result Decimal256::FromBigEndian(const uint8_t* bytes, int32_t length) {
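Decimal parsing keeps the same overload set, with std::string_view replacing the vendored type. An illustrative sketch of a call through the view overload (ParseExample is a hypothetical name, not part of this patch):

#include <string_view>
#include "arrow/util/decimal.h"

arrow::Status ParseExample() {
  // The std::string and const char* overloads forward to this one.
  ARROW_ASSIGN_OR_RAISE(auto d,
                        arrow::Decimal128::FromString(std::string_view("12.345")));
  return arrow::Status::OK();
}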
diff --git a/cpp/src/arrow/util/decimal.h b/cpp/src/arrow/util/decimal.h
index 5b26f1f5431..9a863c51bf6 100644
--- a/cpp/src/arrow/util/decimal.h
+++ b/cpp/src/arrow/util/decimal.h
@@ -21,13 +21,13 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 
 #include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/basic_decimal.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 
@@ -95,13 +95,13 @@ class ARROW_EXPORT Decimal128 : public BasicDecimal128 {
 
   /// \brief Convert a decimal string to a Decimal128 value, optionally including
   /// precision and scale if they're passed in and not null.
-  static Status FromString(const util::string_view& s, Decimal128* out,
-                           int32_t* precision, int32_t* scale = NULLPTR);
+  static Status FromString(const std::string_view& s, Decimal128* out, int32_t* precision,
+                           int32_t* scale = NULLPTR);
   static Status FromString(const std::string& s, Decimal128* out, int32_t* precision,
                            int32_t* scale = NULLPTR);
   static Status FromString(const char* s, Decimal128* out, int32_t* precision,
                            int32_t* scale = NULLPTR);
-  static Result FromString(const util::string_view& s);
+  static Result FromString(const std::string_view& s);
   static Result FromString(const std::string& s);
   static Result FromString(const char* s);
 
@@ -211,13 +211,13 @@ class ARROW_EXPORT Decimal256 : public BasicDecimal256 {
 
   /// \brief Convert a decimal string to a Decimal256 value, optionally including
   /// precision and scale if they're passed in and not null.
-  static Status FromString(const util::string_view& s, Decimal256* out,
-                           int32_t* precision, int32_t* scale = NULLPTR);
+  static Status FromString(const std::string_view& s, Decimal256* out, int32_t* precision,
+                           int32_t* scale = NULLPTR);
   static Status FromString(const std::string& s, Decimal256* out, int32_t* precision,
                            int32_t* scale = NULLPTR);
   static Status FromString(const char* s, Decimal256* out, int32_t* precision,
                            int32_t* scale = NULLPTR);
-  static Result FromString(const util::string_view& s);
+  static Result FromString(const std::string_view& s);
   static Result FromString(const std::string& s);
   static Result FromString(const char* s);
 
diff --git a/cpp/src/arrow/util/delimiting.cc b/cpp/src/arrow/util/delimiting.cc
index fe1b6ea3126..4ae3646e321 100644
--- a/cpp/src/arrow/util/delimiting.cc
+++ b/cpp/src/arrow/util/delimiting.cc
@@ -32,14 +32,14 @@ Status StraddlingTooLarge() {
 
 class NewlineBoundaryFinder : public BoundaryFinder {
  public:
-  Status FindFirst(util::string_view partial, util::string_view block,
+  Status FindFirst(std::string_view partial, std::string_view block,
                    int64_t* out_pos) override {
     auto pos = block.find_first_of(newline_delimiters);
-    if (pos == util::string_view::npos) {
+    if (pos == std::string_view::npos) {
       *out_pos = kNoDelimiterFound;
     } else {
       auto end = block.find_first_not_of(newline_delimiters, pos);
-      if (end == util::string_view::npos) {
+      if (end == std::string_view::npos) {
        end = block.length();
      }
      *out_pos = static_cast(end);
@@ -47,13 +47,13 @@ class NewlineBoundaryFinder : public BoundaryFinder {
     return Status::OK();
   }
 
-  Status FindLast(util::string_view block, int64_t* out_pos) override {
+  Status FindLast(std::string_view block, int64_t* out_pos) override {
     auto pos = block.find_last_of(newline_delimiters);
-    if (pos == util::string_view::npos) {
+    if (pos == std::string_view::npos) {
       *out_pos = kNoDelimiterFound;
     } else {
       auto end = block.find_first_not_of(newline_delimiters, pos);
-      if (end == util::string_view::npos) {
+      if (end == std::string_view::npos) {
        end = block.length();
      }
      *out_pos = static_cast(end);
@@ -61,15 +61,15 @@ class NewlineBoundaryFinder : public BoundaryFinder {
     return Status::OK();
   }
 
-  Status FindNth(util::string_view partial, util::string_view block, int64_t count,
+  Status FindNth(std::string_view partial, std::string_view block, int64_t count,
                  int64_t* out_pos, int64_t* num_found) override {
-    DCHECK(partial.find_first_of(newline_delimiters) == util::string_view::npos);
+    DCHECK(partial.find_first_of(newline_delimiters) == std::string_view::npos);
 
     int64_t found = 0;
     int64_t pos = kNoDelimiterFound;
     auto cur_pos = block.find_first_of(newline_delimiters);
-    while (cur_pos != util::string_view::npos) {
+    while (cur_pos != std::string_view::npos) {
       if (block[cur_pos] == '\r' && cur_pos + 1 < block.length() &&
           block[cur_pos + 1] == '\n') {
         cur_pos += 2;
@@ -108,7 +108,7 @@ Chunker::Chunker(std::shared_ptr delimiter)
 Status Chunker::Process(std::shared_ptr block, std::shared_ptr* whole,
                         std::shared_ptr* partial) {
   int64_t last_pos = -1;
-  RETURN_NOT_OK(boundary_finder_->FindLast(util::string_view(*block), &last_pos));
+  RETURN_NOT_OK(boundary_finder_->FindLast(std::string_view(*block), &last_pos));
   if (last_pos == BoundaryFinder::kNoDelimiterFound) {
     // No delimiter found
     *whole = SliceBuffer(block, 0, 0);
@@ -132,8 +132,8 @@ Status Chunker::ProcessWithPartial(std::shared_ptr partial,
     return Status::OK();
   }
   int64_t first_pos = -1;
-  RETURN_NOT_OK(boundary_finder_->FindFirst(util::string_view(*partial),
-                                            util::string_view(*block), &first_pos));
+  RETURN_NOT_OK(boundary_finder_->FindFirst(std::string_view(*partial),
+                                            std::string_view(*block), &first_pos));
   if (first_pos == BoundaryFinder::kNoDelimiterFound) {
     // No delimiter in block => the current object is too large for block size
     return StraddlingTooLarge();
@@ -155,8 +155,8 @@ Status Chunker::ProcessFinal(std::shared_ptr partial,
     return Status::OK();
   }
   int64_t first_pos = -1;
-  RETURN_NOT_OK(boundary_finder_->FindFirst(util::string_view(*partial),
-                                            util::string_view(*block), &first_pos));
+  RETURN_NOT_OK(boundary_finder_->FindFirst(std::string_view(*partial),
+                                            std::string_view(*block), &first_pos));
   if (first_pos == BoundaryFinder::kNoDelimiterFound) {
     // No delimiter in block => it's entirely a completion of partial
     *completion = block;
@@ -175,7 +175,7 @@ Status Chunker::ProcessSkip(std::shared_ptr partial,
   int64_t pos;
   int64_t num_found;
   ARROW_RETURN_NOT_OK(boundary_finder_->FindNth(
-      util::string_view(*partial), util::string_view(*block), *count, &pos, &num_found));
+      std::string_view(*partial), std::string_view(*block), *count, &pos, &num_found));
   if (pos == BoundaryFinder::kNoDelimiterFound) {
     return StraddlingTooLarge();
   }
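The chunker code above follows one conversion pattern throughout: wrap a Buffer in a std::string_view, then use the view's search methods, with std::string_view::npos replacing the vendored npos. A simplified sketch loosely modeled on FindLast (it returns the position just past the last delimiter, or -1; FindLastNewline is an illustrative name and skips the delimiter-run handling of the real finder):

#include <string_view>

int64_t FindLastNewline(std::string_view block) {
  auto pos = block.find_last_of("\r\n");
  return pos == std::string_view::npos ? -1 : static_cast<int64_t>(pos + 1);
}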
diff --git a/cpp/src/arrow/util/delimiting.h b/cpp/src/arrow/util/delimiting.h
index b4b868340db..161ad0bfddf 100644
--- a/cpp/src/arrow/util/delimiting.h
+++ b/cpp/src/arrow/util/delimiting.h
@@ -19,10 +19,10 @@
 
 #include 
 #include 
+#include 
 
 #include "arrow/status.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -43,7 +43,7 @@ class ARROW_EXPORT BoundaryFinder {
   /// The returned `out_pos` is relative to `block`'s start and should point
   /// to the first character after the first delimiter.
   /// `out_pos` will be -1 if no delimiter is found.
-  virtual Status FindFirst(util::string_view partial, util::string_view block,
+  virtual Status FindFirst(std::string_view partial, std::string_view block,
                            int64_t* out_pos) = 0;
 
   /// \brief Find the position of the last delimiter inside block
   ///
@@ -51,7 +51,7 @@ class ARROW_EXPORT BoundaryFinder {
   /// The returned `out_pos` is relative to `block`'s start and should point
   /// to the first character after the last delimiter.
   /// `out_pos` will be -1 if no delimiter is found.
-  virtual Status FindLast(util::string_view block, int64_t* out_pos) = 0;
+  virtual Status FindLast(std::string_view block, int64_t* out_pos) = 0;
 
   /// \brief Find the position of the Nth delimiter inside the block
   ///
   /// `out_pos` will be -1 if no delimiter is found.
   ///
   /// The returned `num_found` is the number of delimiters actually found
-  virtual Status FindNth(util::string_view partial, util::string_view block,
-                         int64_t count, int64_t* out_pos, int64_t* num_found) = 0;
+  virtual Status FindNth(std::string_view partial, std::string_view block, int64_t count,
+                         int64_t* out_pos, int64_t* num_found) = 0;
 
   static constexpr int64_t kNoDelimiterFound = -1;
 
diff --git a/cpp/src/arrow/util/formatting.h b/cpp/src/arrow/util/formatting.h
index 335aba8c5e3..a69c7131c37 100644
--- a/cpp/src/arrow/util/formatting.h
+++ b/cpp/src/arrow/util/formatting.h
@@ -25,6 +25,7 @@
 #include 
 #include 
 #include 
+#include 
 #include 
 #include 
 
@@ -33,7 +34,6 @@
 #include "arrow/type_traits.h"
 #include "arrow/util/double_conversion.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/time.h"
 #include "arrow/util/visibility.h"
 #include "arrow/vendored/datetime.h"
@@ -60,7 +60,7 @@ template 
 using enable_if_formattable = enable_if_t::value, R>;
 
 template 
-using Return = decltype(std::declval()(util::string_view{}));
+using Return = decltype(std::declval()(std::string_view{}));
 
 /////////////////////////////////////////////////////////////////////////
 // Boolean formatting
@@ -76,10 +76,10 @@ class StringFormatter {
   Return operator()(bool value, Appender&& append) {
     if (value) {
       const char string[] = "true";
-      return append(util::string_view(string));
+      return append(std::string_view(string));
     } else {
       const char string[] = "false";
-      return append(util::string_view(string));
+      return append(std::string_view(string));
     }
   }
 };
@@ -135,8 +135,8 @@ void FormatAllDigitsLeftPadded(Int value, size_t pad, char pad_char, char** curs
 }
 
 template 
-util::string_view ViewDigitBuffer(const std::array& buffer,
-                                  char* cursor) {
+std::string_view ViewDigitBuffer(const std::array& buffer,
+                                 char* cursor) {
   auto buffer_end = buffer.data() + BUFFER_SIZE;
   return {cursor, static_cast(buffer_end - cursor)};
 }
@@ -260,7 +260,7 @@ class FloatToStringFormatterMixin : public FloatToStringFormatter {
   Return operator()(value_type value, Appender&& append) {
     char buffer[buffer_size];
     int size = FormatFloat(value, buffer, buffer_size);
-    return append(util::string_view(buffer, size));
+    return append(std::string_view(buffer, size));
   }
 };
 
diff --git a/cpp/src/arrow/util/formatting_util_test.cc b/cpp/src/arrow/util/formatting_util_test.cc
index a5760859990..eddf76fe845 100644
--- a/cpp/src/arrow/util/formatting_util_test.cc
+++ b/cpp/src/arrow/util/formatting_util_test.cc
@@ -33,7 +33,7 @@ using internal::StringFormatter;
 
 class StringAppender {
  public:
-  Status operator()(util::string_view v) {
+  Status operator()(std::string_view v) {
     string_.append(v.data(), v.size());
     return Status::OK();
   }
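The StringFormatter Appender contract is unchanged apart from the view type: any callable taking std::string_view works, and its return type propagates through Return. A sketch of an appender collecting into a std::string (FormatInt is an illustrative name):

#include <string>
#include <string_view>
#include "arrow/util/formatting.h"

std::string FormatInt(int32_t value) {
  arrow::internal::StringFormatter<arrow::Int32Type> formatter;
  std::string out;
  // The lambda is the Appender; it receives a view over a stack buffer,
  // so the bytes must be copied out before the call returns.
  formatter(value, [&](std::string_view v) { out.append(v.data(), v.size()); });
  return out;
}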
diff --git a/cpp/src/arrow/util/hashing.h b/cpp/src/arrow/util/hashing.h
index ca5a6c766bd..bb04a364cda 100644
--- a/cpp/src/arrow/util/hashing.h
+++ b/cpp/src/arrow/util/hashing.h
@@ -103,11 +103,11 @@ struct ScalarHelper::value>
 
 template 
 struct ScalarHelper::value>>
+                    enable_if_t::value>>
     : public ScalarHelperBase {
-  // ScalarHelper specialization for util::string_view
+  // ScalarHelper specialization for std::string_view
 
-  static hash_t ComputeHash(const util::string_view& value) {
+  static hash_t ComputeHash(const std::string_view& value) {
     return ComputeStringHash(value.data(), static_cast(value.size()));
   }
 };
@@ -641,7 +641,7 @@ class BinaryMemoTable : public MemoTable {
     }
   }
 
-  int32_t Get(const util::string_view& value) const {
+  int32_t Get(const std::string_view& value) const {
     return Get(value.data(), static_cast(value.length()));
   }
 
@@ -669,7 +669,7 @@ class BinaryMemoTable : public MemoTable {
   }
 
   template 
-  Status GetOrInsert(const util::string_view& value, Func1&& on_found,
+  Status GetOrInsert(const std::string_view& value, Func1&& on_found,
                      Func2&& on_not_found, int32_t* out_memo_index) {
     return GetOrInsert(value.data(), static_cast(value.length()),
                        std::forward(on_found), std::forward(on_not_found),
@@ -682,7 +682,7 @@ class BinaryMemoTable : public MemoTable {
         data, length, [](int32_t i) {}, [](int32_t i) {}, out_memo_index);
   }
 
-  Status GetOrInsert(const util::string_view& value, int32_t* out_memo_index) {
+  Status GetOrInsert(const std::string_view& value, int32_t* out_memo_index) {
     return GetOrInsert(value.data(), static_cast(value.length()), out_memo_index);
   }
 
@@ -817,8 +817,8 @@ class BinaryMemoTable : public MemoTable {
   }
 
   // Visit the stored values in insertion order.
-  // The visitor function should have the signature `void(util::string_view)`
-  // or `void(const util::string_view&)`.
+  // The visitor function should have the signature `void(std::string_view)`
+  // or `void(const std::string_view&)`.
   template 
   void VisitValues(int32_t start, VisitFunc&& visit) const {
    for (int32_t i = start; i < size(); ++i) {
@@ -841,8 +841,8 @@ class BinaryMemoTable : public MemoTable {
   std::pair Lookup(hash_t h, const void* data,
                    builder_offset_type length) const {
     auto cmp_func = [=](const Payload* payload) {
-      util::string_view lhs = binary_builder_.GetView(payload->memo_index);
-      util::string_view rhs(static_cast(data), length);
+      std::string_view lhs = binary_builder_.GetView(payload->memo_index);
+      std::string_view rhs(static_cast(data), length);
       return lhs == rhs;
     };
     return hash_table_.Lookup(h, cmp_func);
   }
 
@@ -850,7 +850,7 @@ public:
   Status MergeTable(const BinaryMemoTable& other_table) {
-    other_table.VisitValues(0, [this](const util::string_view& other_value) {
+    other_table.VisitValues(0, [this](const std::string_view& other_value) {
       int32_t unused;
       DCHECK_OK(this->GetOrInsert(other_value, &unused));
     });
@@ -918,7 +918,7 @@ struct StringViewHash {
   // std::hash compatible hasher for use with std::unordered_*
   // (the std::hash specialization provided by nonstd constructs std::string
   // temporaries then invokes std::hash against those)
-  hash_t operator()(const util::string_view& value) const {
+  hash_t operator()(const std::string_view& value) const {
     return ComputeStringHash<0>(value.data(), static_cast(value.size()));
   }
 };
diff --git a/cpp/src/arrow/util/hashing_test.cc b/cpp/src/arrow/util/hashing_test.cc
index 116e305e59e..6589f098afd 100644
--- a/cpp/src/arrow/util/hashing_test.cc
+++ b/cpp/src/arrow/util/hashing_test.cc
@@ -440,7 +440,7 @@ TEST(BinaryMemoTable, Basics) {
   {
     const int32_t start_offset = 1;
     std::vector actual;
-    table.VisitValues(start_offset, [&](const util::string_view& v) {
+    table.VisitValues(start_offset, [&](const std::string_view& v) {
       actual.emplace_back(v.data(), v.length());
    });
    EXPECT_THAT(actual, testing::ElementsAre(B, C, D, E, F, ""));
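A hedged sketch of the memo-table API after the change: keys can be passed directly as std::string_view. The constructor arguments and MemoExample name are assumptions for illustration; consult hashing.h for the exact signature:

#include <string_view>
#include "arrow/util/hashing.h"

arrow::Status MemoExample(arrow::MemoryPool* pool) {
  // BinaryMemoTable is templated on the builder used for value storage.
  arrow::internal::BinaryMemoTable<arrow::BinaryBuilder> table(pool);
  int32_t index = -1;
  ARROW_RETURN_NOT_OK(table.GetOrInsert(std::string_view("hello"), &index));
  return arrow::Status::OK();
}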
diff --git a/cpp/src/arrow/util/reflection_internal.h b/cpp/src/arrow/util/reflection_internal.h
index 0440a2eb563..2e994aa4b70 100644
--- a/cpp/src/arrow/util/reflection_internal.h
+++ b/cpp/src/arrow/util/reflection_internal.h
@@ -18,11 +18,11 @@
 #pragma once
 
 #include 
+#include 
 #include 
 #include 
 
 #include "arrow/type_traits.h"
-#include "arrow/util/string_view.h"
 
 namespace arrow {
 namespace internal {
@@ -81,14 +81,14 @@ struct DataMemberProperty {
   void set(Class* obj, Type value) const { (*obj).*ptr_ = std::move(value); }
 
-  constexpr util::string_view name() const { return name_; }
+  constexpr std::string_view name() const { return name_; }
 
-  util::string_view name_;
+  std::string_view name_;
   Type Class::*ptr_;
 };
 
 template 
-constexpr DataMemberProperty DataMember(util::string_view name,
+constexpr DataMemberProperty DataMember(std::string_view name,
                                         Type Class::*ptr) {
   return {name, ptr};
 }
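Property names are now std::string_view, still usable in constexpr contexts. A sketch of declaring reflected members; the Person type mirrors the test below, and MakeProperties usage is an assumption based on that test:

#include <string>
#include "arrow/util/reflection_internal.h"

struct Person {
  int age;
  std::string name;
};

static const auto kPersonProperties =
    arrow::internal::MakeProperties(arrow::internal::DataMember("age", &Person::age),
                                    arrow::internal::DataMember("name", &Person::name));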
diff --git a/cpp/src/arrow/util/reflection_test.cc b/cpp/src/arrow/util/reflection_test.cc
index 8ca9077ddc6..d2d6379bece 100644
--- a/cpp/src/arrow/util/reflection_test.cc
+++ b/cpp/src/arrow/util/reflection_test.cc
@@ -16,6 +16,7 @@
 // under the License.
 
 #include 
+#include 
 
 #include 
 
@@ -48,7 +49,7 @@ struct EqualsImpl {
 template 
 struct ToStringImpl {
   template 
-  ToStringImpl(util::string_view class_name, const Class& obj, const Properties& props)
+  ToStringImpl(std::string_view class_name, const Class& obj, const Properties& props)
       : class_name_(class_name), obj_(obj), members_(props.size()) {
     props.ForEach(*this);
   }
@@ -61,10 +62,10 @@ struct ToStringImpl {
   }
 
   std::string Finish() {
-    return class_name_.to_string() + "{" + JoinStrings(members_, ",") + "}";
+    return std::string(class_name_) + "{" + JoinStrings(members_, ",") + "}";
   }
 
-  util::string_view class_name_;
+  std::string_view class_name_;
   const Class& obj_;
   std::vector members_;
 };
@@ -73,7 +74,7 @@ struct ToStringImpl {
 template 
 struct FromStringImpl {
   template 
-  FromStringImpl(util::string_view class_name, util::string_view repr,
+  FromStringImpl(std::string_view class_name, std::string_view repr,
                  const Properties& props) {
     Init(class_name, repr, props.size());
     props.ForEach(*this);
@@ -81,8 +82,8 @@ struct FromStringImpl {
 
   void Fail() { obj_ = std::nullopt; }
 
-  void Init(util::string_view class_name, util::string_view repr, size_t num_properties) {
-    if (!repr.starts_with(class_name)) return Fail();
+  void Init(std::string_view class_name, std::string_view repr, size_t num_properties) {
+    if (!StartsWith(repr, class_name)) return Fail();
     repr = repr.substr(class_name.size());
 
     if (repr.empty()) return Fail();
@@ -99,7 +100,7 @@ struct FromStringImpl {
     if (!obj_) return;
 
     auto first_colon = members_[i].find_first_of(':');
-    if (first_colon == util::string_view::npos) return Fail();
+    if (first_colon == std::string_view::npos) return Fail();
 
     auto name = members_[i].substr(0, first_colon);
     if (name != prop.name()) return Fail();
 
     auto value_repr = members_[i].substr(first_colon + 1);
     typename Property::Type value;
     try {
-      std::stringstream ss(value_repr.to_string());
+      std::stringstream ss{std::string{value_repr}};
       ss >> value;
       if (!ss.eof()) return Fail();
     } catch (...) {
@@ -117,7 +118,7 @@ struct FromStringImpl {
   }
 
   std::optional obj_ = Class{};
-  std::vector members_;
+  std::vector members_;
 };
 
 // unmodified structure which we wish to reflect on:
@@ -146,7 +147,7 @@ std::string ToString(const Person& obj) {
 
 void PrintTo(const Person& obj, std::ostream* os) { *os << ToString(obj); }
 
-std::optional PersonFromString(util::string_view repr) {
+std::optional PersonFromString(std::string_view repr) {
   return FromStringImpl("Person", repr, kPersonProperties).obj_;
 }
 
diff --git a/cpp/src/arrow/util/string.cc b/cpp/src/arrow/util/string.cc
index 09df881a9b0..2055b4f47ea 100644
--- a/cpp/src/arrow/util/string.cc
+++ b/cpp/src/arrow/util/string.cc
@@ -69,9 +69,9 @@ std::string HexEncode(const char* data, size_t length) {
   return HexEncode(reinterpret_cast(data), length);
 }
 
-std::string HexEncode(util::string_view str) { return HexEncode(str.data(), str.size()); }
+std::string HexEncode(std::string_view str) { return HexEncode(str.data(), str.size()); }
 
-std::string Escape(util::string_view str) { return Escape(str.data(), str.size()); }
+std::string Escape(std::string_view str) { return Escape(str.data(), str.size()); }
 
 Status ParseHexValue(const char* data, uint8_t* out) {
   char c1 = data[0];
@@ -92,9 +92,9 @@ Status ParseHexValue(const char* data, uint8_t* out) {
 
 namespace internal {
 
-std::vector SplitString(util::string_view v, char delimiter,
-                                           int64_t limit) {
-  std::vector parts;
+std::vector SplitString(std::string_view v, char delimiter,
+                                          int64_t limit) {
+  std::vector parts;
   size_t start = 0, end;
   while (true) {
     if (limit > 0 && static_cast(limit - 1) <= parts.size()) {
@@ -113,7 +113,7 @@ std::vector SplitString(util::string_view v, char delimiter,
 
 template 
 static std::string JoinStringLikes(const std::vector& strings,
-                                   util::string_view delimiter) {
+                                   std::string_view delimiter) {
   if (strings.size() == 0) {
     return "";
   }
@@ -125,13 +125,13 @@ static std::string JoinStringLikes(const std::vector& strings,
   return out;
 }
 
-std::string JoinStrings(const std::vector& strings,
-                        util::string_view delimiter) {
+std::string JoinStrings(const std::vector& strings,
+                        std::string_view delimiter) {
   return JoinStringLikes(strings, delimiter);
 }
 
 std::string JoinStrings(const std::vector& strings,
-                        util::string_view delimiter) {
+                        std::string_view delimiter) {
   return JoinStringLikes(strings, delimiter);
 }
 
@@ -152,7 +152,7 @@ std::string TrimString(std::string value) {
   return value;
 }
 
-bool AsciiEqualsCaseInsensitive(util::string_view left, util::string_view right) {
+bool AsciiEqualsCaseInsensitive(std::string_view left, std::string_view right) {
   // TODO: ASCII validation
   if (left.size() != right.size()) {
     return false;
@@ -166,7 +166,7 @@ bool AsciiEqualsCaseInsensitive(util::string_view left, util::string_view right)
   return true;
 }
 
-std::string AsciiToLower(util::string_view value) {
+std::string AsciiToLower(std::string_view value) {
   // TODO: ASCII validation
   std::string result = std::string(value);
   std::transform(result.begin(), result.end(), result.begin(),
@@ -174,7 +174,7 @@ std::string AsciiToLower(util::string_view value) {
   return result;
 }
 
-std::string AsciiToUpper(util::string_view value) {
+std::string AsciiToUpper(std::string_view value) {
   // TODO: ASCII validation
   std::string result = std::string(value);
   std::transform(result.begin(), result.end(), result.begin(),
@@ -182,17 +182,17 @@ std::string AsciiToUpper(util::string_view value) {
   return result;
 }
 
-std::optional Replace(util::string_view s, util::string_view token,
-                      util::string_view replacement) {
+std::optional Replace(std::string_view s, std::string_view token,
+                      std::string_view replacement) {
   size_t token_start = s.find(token);
   if (token_start == std::string::npos) {
     return std::nullopt;
   }
-  return s.substr(0, token_start).to_string() + replacement.to_string() +
-         s.substr(token_start + token.size()).to_string();
+  return std::string(s.substr(0, token_start)) + std::string(replacement) +
+         std::string(s.substr(token_start + token.size()));
 }
 
-Result ParseBoolean(util::string_view value) {
+Result ParseBoolean(std::string_view value) {
   if (AsciiEqualsCaseInsensitive(value, "true") || value == "1") {
     return true;
   } else if (AsciiEqualsCaseInsensitive(value, "false") || value == "0") {
diff --git a/cpp/src/arrow/util/string.h b/cpp/src/arrow/util/string.h
index fd9a3d1e063..ec2ccd11ef5 100644
--- a/cpp/src/arrow/util/string.h
+++ b/cpp/src/arrow/util/string.h
@@ -19,10 +19,10 @@
 
 #include 
 #include 
+#include 
 #include 
 
 #include "arrow/result.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -35,47 +35,59 @@ ARROW_EXPORT std::string Escape(const char* data, size_t length);
 
 ARROW_EXPORT std::string HexEncode(const char* data, size_t length);
 
-ARROW_EXPORT std::string HexEncode(util::string_view str);
+ARROW_EXPORT std::string HexEncode(std::string_view str);
 
-ARROW_EXPORT std::string Escape(util::string_view str);
+ARROW_EXPORT std::string Escape(std::string_view str);
 
 ARROW_EXPORT Status ParseHexValue(const char* data, uint8_t* out);
 
 namespace internal {
 
+/// Like std::string_view::starts_with in C++20
+inline bool StartsWith(std::string_view s, std::string_view prefix) {
+  return s.length() >= prefix.length() &&
+         (s.empty() || s.substr(0, prefix.length()) == prefix);
+}
+
+/// Like std::string_view::ends_with in C++20
+inline bool EndsWith(std::string_view s, std::string_view suffix) {
+  return s.length() >= suffix.length() &&
+         (s.empty() || s.substr(s.length() - suffix.length()) == suffix);
+}
+
 /// \brief Split a string with a delimiter
 ARROW_EXPORT
-std::vector SplitString(util::string_view v, char delim,
-                                           int64_t limit = 0);
+std::vector SplitString(std::string_view v, char delim,
+                                          int64_t limit = 0);
 
 /// \brief Join strings with a delimiter
 ARROW_EXPORT
-std::string JoinStrings(const std::vector& strings,
-                        util::string_view delimiter);
+std::string JoinStrings(const std::vector& strings,
+                        std::string_view delimiter);
 
 /// \brief Join strings with a delimiter
 ARROW_EXPORT
 std::string JoinStrings(const std::vector& strings,
-                        util::string_view delimiter);
+                        std::string_view delimiter);
 
 /// \brief Trim whitespace from left and right sides of string
 ARROW_EXPORT
 std::string TrimString(std::string value);
 
 ARROW_EXPORT
-bool AsciiEqualsCaseInsensitive(util::string_view left, util::string_view right);
+bool AsciiEqualsCaseInsensitive(std::string_view left, std::string_view right);
 
 ARROW_EXPORT
-std::string AsciiToLower(util::string_view value);
+std::string AsciiToLower(std::string_view value);
 
 ARROW_EXPORT
-std::string AsciiToUpper(util::string_view value);
+std::string AsciiToUpper(std::string_view value);
 
 /// \brief Search for the first instance of a token and replace it or return nullopt if
 /// the token is not found.
 ARROW_EXPORT
-std::optional Replace(util::string_view s, util::string_view token,
-                      util::string_view replacement);
+std::optional Replace(std::string_view s, std::string_view token,
+                      std::string_view replacement);
 
 /// \brief Get boolean value from string
 ///
@@ -83,6 +95,7 @@ std::optional Replace(util::string_view s, util::string_view token,
 /// If "0", "false" (case-insensitive), returns false
 /// Otherwise, returns Status::Invalid
 ARROW_EXPORT
-arrow::Result ParseBoolean(util::string_view value);
+arrow::Result ParseBoolean(std::string_view value);
+
 }  // namespace internal
 }  // namespace arrow
diff --git a/cpp/src/arrow/util/string_test.cc b/cpp/src/arrow/util/string_test.cc
index 2aa6fccbd9a..a1aac17ab50 100644
--- a/cpp/src/arrow/util/string_test.cc
+++ b/cpp/src/arrow/util/string_test.cc
@@ -166,5 +166,33 @@ TEST(SplitString, LimitZero) {
   EXPECT_EQ(parts[2], "c");
 }
 
+TEST(StartsWith, Basics) {
+  std::string empty{};
+  std::string abc{"abc"};
+  std::string abcdef{"abcdef"};
+  std::string def{"def"};
+  ASSERT_TRUE(StartsWith(empty, empty));
+  ASSERT_TRUE(StartsWith(abc, empty));
+  ASSERT_TRUE(StartsWith(abc, abc));
+  ASSERT_TRUE(StartsWith(abcdef, abc));
+  ASSERT_FALSE(StartsWith(abc, abcdef));
+  ASSERT_FALSE(StartsWith(def, abcdef));
+  ASSERT_FALSE(StartsWith(abcdef, def));
+}
+
+TEST(EndsWith, Basics) {
+  std::string empty{};
+  std::string abc{"abc"};
+  std::string abcdef{"abcdef"};
+  std::string def{"def"};
+  ASSERT_TRUE(EndsWith(empty, empty));
+  ASSERT_TRUE(EndsWith(abc, empty));
+  ASSERT_TRUE(EndsWith(abc, abc));
+  ASSERT_TRUE(EndsWith(abcdef, def));
+  ASSERT_FALSE(EndsWith(abcdef, abc));
+  ASSERT_FALSE(EndsWith(def, abcdef));
+  ASSERT_FALSE(EndsWith(abcdef, abc));
+}
+
 }  // namespace internal
 }  // namespace arrow
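For quick reference, both helpers treat the empty string as a prefix or suffix of anything, matching the C++20 members they stand in for. A one-line usage sketch (IsParquetFile is an illustrative name):

#include <string_view>
#include "arrow/util/string.h"

bool IsParquetFile(std::string_view path) {
  return arrow::internal::EndsWith(path, ".parquet");
}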
diff --git a/cpp/src/arrow/util/trie.cc b/cpp/src/arrow/util/trie.cc
index 7fa7f852eb4..ec2aed302f5 100644
--- a/cpp/src/arrow/util/trie.cc
+++ b/cpp/src/arrow/util/trie.cc
@@ -91,7 +91,7 @@ Status TrieBuilder::AppendChildNode(Trie::Node* parent, uint8_t ch, Trie::Node&&
 }
 
 Status TrieBuilder::CreateChildNode(Trie::Node* parent, uint8_t ch,
-                                    util::string_view substring) {
+                                    std::string_view substring) {
   const auto kMaxSubstringLength = Trie::kMaxSubstringLength;
 
   while (substring.length() > kMaxSubstringLength) {
@@ -112,7 +112,7 @@ Status TrieBuilder::CreateChildNode(Trie::Node* parent, uint8_t ch,
 }
 
 Status TrieBuilder::CreateChildNode(Trie::Node* parent, char ch,
-                                    util::string_view substring) {
+                                    std::string_view substring) {
   return CreateChildNode(parent, static_cast<uint8_t>(ch), substring);
 }
 
@@ -147,7 +147,7 @@ Status TrieBuilder::SplitNode(fast_index_type node_index, fast_index_type split_
   return Status::OK();
 }
 
-Status TrieBuilder::Append(util::string_view s, bool allow_duplicate) {
+Status TrieBuilder::Append(std::string_view s, bool allow_duplicate) {
   // Find or create node for string
   fast_index_type node_index = 0;
   fast_index_type pos = 0;
diff --git a/cpp/src/arrow/util/trie.h b/cpp/src/arrow/util/trie.h
index b250cca647d..7815d4d1ecc 100644
--- a/cpp/src/arrow/util/trie.h
+++ b/cpp/src/arrow/util/trie.h
@@ -23,12 +23,12 @@
 #include <iostream>
 #include <limits>
 #include <string>
+#include <string_view>
 #include <utility>
 #include <vector>
 
 #include "arrow/status.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -45,10 +45,10 @@ class SmallString {
   template <typename T>
   SmallString(const T& v) {  // NOLINT implicit constructor
-    *this = util::string_view(v);
+    *this = std::string_view(v);
   }
 
-  SmallString& operator=(const util::string_view s) {
+  SmallString& operator=(const std::string_view s) {
 #ifndef NDEBUG
     CheckSize(s.size());
 #endif
@@ -58,18 +58,16 @@ class SmallString {
   }
 
   SmallString& operator=(const std::string& s) {
-    *this = util::string_view(s);
+    *this = std::string_view(s);
     return *this;
   }
 
   SmallString& operator=(const char* s) {
-    *this = util::string_view(s);
+    *this = std::string_view(s);
     return *this;
   }
 
-  explicit operator util::string_view() const {
-    return util::string_view(data_, length_);
-  }
+  explicit operator std::string_view() const { return std::string_view(data_, length_); }
 
   const char* data() const { return data_; }
   size_t length() const { return length_; }
@@ -82,21 +80,21 @@ class SmallString {
   }
 
   SmallString substr(size_t pos) const {
-    return SmallString(util::string_view(*this).substr(pos));
+    return SmallString(std::string_view(*this).substr(pos));
   }
 
   SmallString substr(size_t pos, size_t count) const {
-    return SmallString(util::string_view(*this).substr(pos, count));
+    return SmallString(std::string_view(*this).substr(pos, count));
   }
 
   template <typename T>
   bool operator==(T&& other) const {
-    return util::string_view(*this) == util::string_view(std::forward<T>(other));
+    return std::string_view(*this) == std::string_view(std::forward<T>(other));
   }
 
   template <typename T>
   bool operator!=(T&& other) const {
-    return util::string_view(*this) != util::string_view(std::forward<T>(other));
+    return std::string_view(*this) != std::string_view(std::forward<T>(other));
  }
 
  protected:
@@ -108,7 +106,7 @@ class SmallString {
 
 template <size_t N>
 std::ostream& operator<<(std::ostream& os, const SmallString<N>& str) {
-  return os << util::string_view(str);
+  return os << std::string_view(str);
 }
 
 // A trie class for byte strings, optimized for small sets of short strings.
@@ -123,7 +121,7 @@ class ARROW_EXPORT Trie {
   Trie(Trie&&) = default;
   Trie& operator=(Trie&&) = default;
 
-  int32_t Find(util::string_view s) const {
+  int32_t Find(std::string_view s) const {
     const Node* node = &nodes_[0];
     fast_index_type pos = 0;
     if (s.length() > static_cast<size_t>(kMaxIndex)) {
@@ -222,7 +220,7 @@ class ARROW_EXPORT TrieBuilder {
  public:
   TrieBuilder();
 
-  Status Append(util::string_view s, bool allow_duplicate = false);
+  Status Append(std::string_view s, bool allow_duplicate = false);
   Trie Finish();
 
  protected:
@@ -233,8 +231,8 @@ class ARROW_EXPORT TrieBuilder {
   // Append an already constructed child node to the parent
   Status AppendChildNode(Trie::Node* parent, uint8_t ch, Trie::Node&& node);
   // Create a matching child node from this parent
-  Status CreateChildNode(Trie::Node* parent, uint8_t ch, util::string_view substring);
-  Status CreateChildNode(Trie::Node* parent, char ch, util::string_view substring);
+  Status CreateChildNode(Trie::Node* parent, uint8_t ch, std::string_view substring);
+  Status CreateChildNode(Trie::Node* parent, char ch, std::string_view substring);
 
   Trie trie_;
diff --git a/cpp/src/arrow/util/trie_benchmark.cc b/cpp/src/arrow/util/trie_benchmark.cc
index 868accc3744..b938f87d8d1 100644
--- a/cpp/src/arrow/util/trie_benchmark.cc
+++ b/cpp/src/arrow/util/trie_benchmark.cc
@@ -86,7 +86,7 @@ BENCHMARK(TrieLookupNotFound);
 
 #ifdef ARROW_WITH_BENCHMARKS_REFERENCE
 
-static inline bool InlinedNullLookup(util::string_view s) {
+static inline bool InlinedNullLookup(std::string_view s) {
   // An inlined version of trie lookup for a specific set of strings
   // (see AllNulls())
   auto size = s.length();
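For context, a minimal sketch (assuming the TrieBuilder/Trie declarations above, and that Find() reports the append index of a match, -1 otherwise) of probing a trie with std::string_view keys:

#include <iostream>
#include <string_view>

#include "arrow/util/logging.h"
#include "arrow/util/trie.h"

int main() {
  // Build a small trie of null-like spellings, then probe it with
  // std::string_view keys (previously util::string_view).
  arrow::internal::TrieBuilder builder;
  ARROW_CHECK_OK(builder.Append("NULL"));
  ARROW_CHECK_OK(builder.Append("N/A"));
  ARROW_CHECK_OK(builder.Append("null"));
  arrow::internal::Trie trie = builder.Finish();

  std::cout << trie.Find(std::string_view("N/A")) << "\n";   // 1
  std::cout << trie.Find(std::string_view("none")) << "\n";  // -1
}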
diff --git a/cpp/src/arrow/util/trie_test.cc b/cpp/src/arrow/util/trie_test.cc
index cfe66689da5..9c6b7678a46 100644
--- a/cpp/src/arrow/util/trie_test.cc
+++ b/cpp/src/arrow/util/trie_test.cc
@@ -36,7 +36,7 @@ TEST(SmallString, Basics) {
   {
     SS s;
     ASSERT_EQ(s.length(), 0);
-    ASSERT_EQ(util::string_view(s), util::string_view(""));
+    ASSERT_EQ(std::string_view(s), std::string_view(""));
     ASSERT_EQ(s, "");
     ASSERT_NE(s, "x");
     ASSERT_EQ(sizeof(s), 6);
@@ -44,7 +44,7 @@ TEST(SmallString, Basics) {
   {
     SS s("abc");
     ASSERT_EQ(s.length(), 3);
-    ASSERT_EQ(util::string_view(s), util::string_view("abc"));
+    ASSERT_EQ(std::string_view(s), std::string_view("abc"));
     ASSERT_EQ(std::memcmp(s.data(), "abc", 3), 0);
     ASSERT_EQ(s, "abc");
     ASSERT_NE(s, "ab");
@@ -55,23 +55,23 @@ TEST(SmallString, Assign) {
   using SS = SmallString<5>;
   auto s = SS();
 
-  s = util::string_view("abc");
+  s = std::string_view("abc");
   ASSERT_EQ(s.length(), 3);
-  ASSERT_EQ(util::string_view(s), util::string_view("abc"));
+  ASSERT_EQ(std::string_view(s), std::string_view("abc"));
   ASSERT_EQ(std::memcmp(s.data(), "abc", 3), 0);
   ASSERT_EQ(s, "abc");
   ASSERT_NE(s, "ab");
 
   s = std::string("ghijk");
   ASSERT_EQ(s.length(), 5);
-  ASSERT_EQ(util::string_view(s), util::string_view("ghijk"));
+  ASSERT_EQ(std::string_view(s), std::string_view("ghijk"));
   ASSERT_EQ(std::memcmp(s.data(), "ghijk", 5), 0);
   ASSERT_EQ(s, "ghijk");
   ASSERT_NE(s, "");
 
   s = SS("xy");
   ASSERT_EQ(s.length(), 2);
-  ASSERT_EQ(util::string_view(s), util::string_view("xy"));
+  ASSERT_EQ(std::string_view(s), std::string_view("xy"));
   ASSERT_EQ(std::memcmp(s.data(), "xy", 2), 0);
   ASSERT_EQ(s, "xy");
   ASSERT_NE(s, "xyz");
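The assignments exercised by these tests look like the following in application code; a minimal sketch assuming the SmallString definition from the trie.h hunk:

#include <iostream>
#include <string_view>

#include "arrow/util/trie.h"

int main() {
  // SmallString<N> is an inline, fixed-capacity string used by the trie;
  // after this patch it converts to and from std::string_view.
  arrow::internal::SmallString<8> s("abc");
  s = std::string_view("xyz");
  std::cout << std::string_view(s) << "\n";  // xyz
  std::cout << s.substr(1, 2) << "\n";       // yz
}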
diff --git a/cpp/src/arrow/util/uri.cc b/cpp/src/arrow/util/uri.cc
index abfc9de8b49..ced1b18404c 100644
--- a/cpp/src/arrow/util/uri.cc
+++ b/cpp/src/arrow/util/uri.cc
@@ -20,9 +20,9 @@
 #include <cstring>
 #include <sstream>
 #include <string>
+#include <string_view>
 #include <vector>
 
-#include "arrow/util/string_view.h"
 #include "arrow/util/value_parsing.h"
 
 #include "arrow/vendored/uriparser/Uri.h"
@@ -31,7 +31,7 @@ namespace internal {
 
 namespace {
 
-util::string_view TextRangeToView(const UriTextRangeStructA& range) {
+std::string_view TextRangeToView(const UriTextRangeStructA& range) {
   if (range.first == nullptr) {
     return "";
   } else {
@@ -50,7 +50,7 @@ std::string TextRangeToString(const UriTextRangeStructA& range) {
 
 bool IsTextRangeSet(const UriTextRangeStructA& range) { return range.first != nullptr; }
 
 #ifdef _WIN32
-bool IsDriveSpec(const util::string_view s) {
+bool IsDriveSpec(const std::string_view s) {
   return (s.length() >= 2 && s[1] == ':' &&
           ((s[0] >= 'A' && s[0] <= 'Z') || (s[0] >= 'a' && s[0] <= 'z')));
 }
@@ -72,7 +72,7 @@ std::string UriEscape(const std::string& s) {
   return escaped;
 }
 
-std::string UriUnescape(const util::string_view s) {
+std::string UriUnescape(const std::string_view s) {
   std::string result(s);
   if (!result.empty()) {
     auto end = uriUnescapeInPlaceA(&result[0]);
@@ -94,7 +94,7 @@ std::string UriEncodeHost(const std::string& host) {
   }
 }
 
-bool IsValidUriScheme(const arrow::util::string_view s) {
+bool IsValidUriScheme(const std::string_view s) {
   auto is_alpha = [](char c) { return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z'); };
   auto is_scheme_char = [&](char c) {
     return is_alpha(c) || (c >= '0' && c <= '9') || c == '+' || c == '-' || c == '.';
@@ -133,7 +133,7 @@ struct Uri::Impl {
   std::vector<char> data_;
   std::string string_rep_;
   int32_t port_;
-  std::vector<util::string_view> path_segments_;
+  std::vector<std::string_view> path_segments_;
   bool is_file_uri_;
   bool is_absolute_path_;
 };
@@ -162,7 +162,7 @@ int32_t Uri::port() const { return impl_->port_; }
 std::string Uri::username() const {
   auto userpass = TextRangeToView(impl_->uri_.userInfo);
   auto sep_pos = userpass.find_first_of(':');
-  if (sep_pos == util::string_view::npos) {
+  if (sep_pos == std::string_view::npos) {
     return UriUnescape(userpass);
   } else {
     return UriUnescape(userpass.substr(0, sep_pos));
@@ -172,7 +172,7 @@ std::string Uri::username() const {
 std::string Uri::password() const {
   auto userpass = TextRangeToView(impl_->uri_.userInfo);
   auto sep_pos = userpass.find_first_of(':');
-  if (sep_pos == util::string_view::npos) {
+  if (sep_pos == std::string_view::npos) {
     return std::string();
   } else {
     return UriUnescape(userpass.substr(sep_pos + 1));
diff --git a/cpp/src/arrow/util/uri.h b/cpp/src/arrow/util/uri.h
index 50d9eccf82f..10853b4b777 100644
--- a/cpp/src/arrow/util/uri.h
+++ b/cpp/src/arrow/util/uri.h
@@ -20,11 +20,11 @@
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <string_view>
 #include <utility>
 #include <vector>
 
 #include "arrow/type_fwd.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -38,7 +38,7 @@ class ARROW_EXPORT Uri {
   Uri(Uri&&);
   Uri& operator=(Uri&&);
 
-  // XXX Should we use util::string_view instead? These functions are
+  // XXX Should we use std::string_view instead? These functions are
   // not performance-critical.
 
   /// The URI scheme, such as "http", or the empty string if the URI has no
@@ -93,7 +93,7 @@ ARROW_EXPORT
 std::string UriEscape(const std::string& s);
 
 ARROW_EXPORT
-std::string UriUnescape(const arrow::util::string_view s);
+std::string UriUnescape(const std::string_view s);
 
 /// Encode a host for use within a URI, such as "localhost",
 /// "127.0.0.1", or "[::1]".
@@ -102,7 +102,7 @@ std::string UriEncodeHost(const std::string& host);
 
 /// Whether the string is a syntactically valid URI scheme according to RFC 3986.
 ARROW_EXPORT
-bool IsValidUriScheme(const arrow::util::string_view s);
+bool IsValidUriScheme(const std::string_view s);
 
 /// Create a file uri from a given absolute path
 ARROW_EXPORT
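username()/password() above split the userinfo component on the first ':' using std::string_view::npos. A minimal sketch of the accessor behavior, assuming the Uri API shown in the uri.h hunk plus its existing Parse(const std::string&) entry point:

#include <iostream>

#include "arrow/util/logging.h"
#include "arrow/util/uri.h"

int main() {
  arrow::internal::Uri uri;
  ARROW_CHECK_OK(uri.Parse("https://user:secret@example.com:8080/path"));
  std::cout << uri.scheme() << "\n";    // https
  std::cout << uri.username() << "\n";  // user
  std::cout << uri.password() << "\n";  // secret
  std::cout << uri.port() << "\n";      // 8080
}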
diff --git a/cpp/src/arrow/util/utf8.cc b/cpp/src/arrow/util/utf8.cc
index e589e1763e6..67f04709621 100644
--- a/cpp/src/arrow/util/utf8.cc
+++ b/cpp/src/arrow/util/utf8.cc
@@ -96,7 +96,7 @@ bool ValidateUTF8(const uint8_t* data, int64_t size) {
   return ValidateUTF8Inline(data, size);
 }
 
-bool ValidateUTF8(const util::string_view& str) { return ValidateUTF8Inline(str); }
+bool ValidateUTF8(const std::string_view& str) { return ValidateUTF8Inline(str); }
 
 static const uint8_t kBOM[] = {0xEF, 0xBB, 0xBF};
 
diff --git a/cpp/src/arrow/util/utf8.h b/cpp/src/arrow/util/utf8.h
index eab207d2a02..909113055d1 100644
--- a/cpp/src/arrow/util/utf8.h
+++ b/cpp/src/arrow/util/utf8.h
@@ -20,10 +20,10 @@
 #include <cstdint>
 #include <memory>
 #include <string>
+#include <string_view>
 
 #include "arrow/type_fwd.h"
 #include "arrow/util/macros.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/visibility.h"
 
 namespace arrow {
@@ -41,7 +41,7 @@ ARROW_EXPORT void InitializeUTF8();
 
 ARROW_EXPORT bool ValidateUTF8(const uint8_t* data, int64_t size);
 
-ARROW_EXPORT bool ValidateUTF8(const util::string_view& str);
+ARROW_EXPORT bool ValidateUTF8(const std::string_view& str);
 
 // Skip UTF8 byte order mark, if any.
 ARROW_EXPORT
diff --git a/cpp/src/arrow/util/utf8_internal.h b/cpp/src/arrow/util/utf8_internal.h
index 9d2954e9d1c..0ce7dd76200 100644
--- a/cpp/src/arrow/util/utf8_internal.h
+++ b/cpp/src/arrow/util/utf8_internal.h
@@ -22,6 +22,7 @@
 #include <cstdint>
 #include <cstring>
 #include <memory>
+#include <string_view>
 
 #if defined(ARROW_HAVE_NEON) || defined(ARROW_HAVE_SSE4_2)
 #include <xsimd/xsimd.hpp>
@@ -30,7 +31,6 @@
 #include "arrow/type_fwd.h"
 #include "arrow/util/macros.h"
 #include "arrow/util/simd.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/ubsan.h"
 #include "arrow/util/utf8.h"
 #include "arrow/util/visibility.h"
@@ -201,7 +201,7 @@ static inline bool ValidateUTF8Inline(const uint8_t* data, int64_t size) {
   return ARROW_PREDICT_TRUE(state == internal::kUTF8ValidateAccept);
 }
 
-static inline bool ValidateUTF8Inline(const util::string_view& str) {
+static inline bool ValidateUTF8Inline(const std::string_view& str) {
   const uint8_t* data = reinterpret_cast<const uint8_t*>(str.data());
   const size_t length = str.size();
 
@@ -266,7 +266,7 @@ static inline bool ValidateAscii(const uint8_t* data, int64_t len) {
 #endif
 }
 
-static inline bool ValidateAscii(const util::string_view& str) {
+static inline bool ValidateAscii(const std::string_view& str) {
   const uint8_t* data = reinterpret_cast<const uint8_t*>(str.data());
   const size_t length = str.size();
 
diff --git a/cpp/src/arrow/util/value_parsing_benchmark.cc b/cpp/src/arrow/util/value_parsing_benchmark.cc
index 40d139316e5..2c4a32b7a1b 100644
--- a/cpp/src/arrow/util/value_parsing_benchmark.cc
+++ b/cpp/src/arrow/util/value_parsing_benchmark.cc
@@ -24,6 +24,7 @@
 #include <limits>
 #include <random>
 #include <string>
+#include <string_view>
 #include <type_traits>
 #include <vector>
 
@@ -32,7 +33,6 @@
 #include "arrow/testing/random.h"
 #include "arrow/type.h"
 #include "arrow/util/formatting.h"
-#include "arrow/util/string_view.h"
 #include "arrow/util/value_parsing.h"
 
 namespace arrow {
@@ -218,7 +218,7 @@ static void TimestampParsingStrptime(
 }
 
 struct DummyAppender {
-  Status operator()(util::string_view v) {
+  Status operator()(std::string_view v) {
    if (pos_ >= static_cast<int32_t>(v.size())) {
      pos_ = 0;
    }
diff --git a/cpp/src/arrow/vendored/base64.cpp b/cpp/src/arrow/vendored/base64.cpp
index 0de11955b7d..6f53c0524e7 100644
--- a/cpp/src/arrow/vendored/base64.cpp
+++ b/cpp/src/arrow/vendored/base64.cpp
@@ -87,13 +87,13 @@ static std::string base64_encode(unsigned char const* bytes_to_encode, unsigned
 
 }
 
-std::string base64_encode(string_view string_to_encode) {
+std::string base64_encode(std::string_view string_to_encode) {
   auto bytes_to_encode = reinterpret_cast<const unsigned char*>(string_to_encode.data());
   auto in_len = static_cast<unsigned int>(string_to_encode.size());
   return base64_encode(bytes_to_encode, in_len);
 }
 
-std::string base64_decode(string_view encoded_string) {
+std::string base64_decode(std::string_view encoded_string) {
   size_t in_len = encoded_string.size();
   int i = 0;
   int j = 0;
diff --git a/cpp/src/arrow/vendored/string_view.hpp b/cpp/src/arrow/vendored/string_view.hpp
deleted file mode 100644
index a2d5567854f..00000000000
--- a/cpp/src/arrow/vendored/string_view.hpp
+++ /dev/null
@@ -1,1531 +0,0 @@
-// Vendored from git changeset v1.4.0
-
-// Copyright 2017-2020 by Martin Moene
-//
-// string-view lite, a C++17-like string_view for C++98 and later.
-// For more information see https://github.com/martinmoene/string-view-lite
-//
-// Distributed under the Boost Software License, Version 1.0.
-// (See accompanying file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt) - -#pragma once - -#ifndef NONSTD_SV_LITE_H_INCLUDED -#define NONSTD_SV_LITE_H_INCLUDED - -#define string_view_lite_MAJOR 1 -#define string_view_lite_MINOR 4 -#define string_view_lite_PATCH 0 - -#define string_view_lite_VERSION nssv_STRINGIFY(string_view_lite_MAJOR) "." nssv_STRINGIFY(string_view_lite_MINOR) "." nssv_STRINGIFY(string_view_lite_PATCH) - -#define nssv_STRINGIFY( x ) nssv_STRINGIFY_( x ) -#define nssv_STRINGIFY_( x ) #x - -// string-view lite configuration: - -#define nssv_STRING_VIEW_DEFAULT 0 -#define nssv_STRING_VIEW_NONSTD 1 -#define nssv_STRING_VIEW_STD 2 - -#if !defined( nssv_CONFIG_SELECT_STRING_VIEW ) -# define nssv_CONFIG_SELECT_STRING_VIEW ( nssv_HAVE_STD_STRING_VIEW ? nssv_STRING_VIEW_STD : nssv_STRING_VIEW_NONSTD ) -#endif - -#if defined( nssv_CONFIG_SELECT_STD_STRING_VIEW ) || defined( nssv_CONFIG_SELECT_NONSTD_STRING_VIEW ) -# error nssv_CONFIG_SELECT_STD_STRING_VIEW and nssv_CONFIG_SELECT_NONSTD_STRING_VIEW are deprecated and removed, please use nssv_CONFIG_SELECT_STRING_VIEW=nssv_STRING_VIEW_... -#endif - -#ifndef nssv_CONFIG_STD_SV_OPERATOR -# define nssv_CONFIG_STD_SV_OPERATOR 0 -#endif - -#ifndef nssv_CONFIG_USR_SV_OPERATOR -# define nssv_CONFIG_USR_SV_OPERATOR 1 -#endif - -#ifdef nssv_CONFIG_CONVERSION_STD_STRING -# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS nssv_CONFIG_CONVERSION_STD_STRING -# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS nssv_CONFIG_CONVERSION_STD_STRING -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS -# define nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS 1 -#endif - -#ifndef nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS -# define nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS 1 -#endif - -// Control presence of exception handling (try and auto discover): - -#ifndef nssv_CONFIG_NO_EXCEPTIONS -# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND) -# define nssv_CONFIG_NO_EXCEPTIONS 0 -# else -# define nssv_CONFIG_NO_EXCEPTIONS 1 -# endif -#endif - -// C++ language version detection (C++20 is speculative): -// Note: VC14.0/1900 (VS2015) lacks too much from C++14. - -#ifndef nssv_CPLUSPLUS -# if defined(_MSVC_LANG ) && !defined(__clang__) -# define nssv_CPLUSPLUS (_MSC_VER == 1900 ? 
201103L : _MSVC_LANG ) -# else -# define nssv_CPLUSPLUS __cplusplus -# endif -#endif - -#define nssv_CPP98_OR_GREATER ( nssv_CPLUSPLUS >= 199711L ) -#define nssv_CPP11_OR_GREATER ( nssv_CPLUSPLUS >= 201103L ) -#define nssv_CPP11_OR_GREATER_ ( nssv_CPLUSPLUS >= 201103L ) -#define nssv_CPP14_OR_GREATER ( nssv_CPLUSPLUS >= 201402L ) -#define nssv_CPP17_OR_GREATER ( nssv_CPLUSPLUS >= 201703L ) -#define nssv_CPP20_OR_GREATER ( nssv_CPLUSPLUS >= 202000L ) - -// use C++17 std::string_view if available and requested: - -#if nssv_CPP17_OR_GREATER && defined(__has_include ) -# if __has_include( ) -# define nssv_HAVE_STD_STRING_VIEW 1 -# else -# define nssv_HAVE_STD_STRING_VIEW 0 -# endif -#else -# define nssv_HAVE_STD_STRING_VIEW 0 -#endif - -#define nssv_USES_STD_STRING_VIEW ( (nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_STD) || ((nssv_CONFIG_SELECT_STRING_VIEW == nssv_STRING_VIEW_DEFAULT) && nssv_HAVE_STD_STRING_VIEW) ) - -#define nssv_HAVE_STARTS_WITH ( nssv_CPP20_OR_GREATER || !nssv_USES_STD_STRING_VIEW ) -#define nssv_HAVE_ENDS_WITH nssv_HAVE_STARTS_WITH - -// -// Use C++17 std::string_view: -// - -#if nssv_USES_STD_STRING_VIEW - -#include - -// Extensions for std::string: - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - -template< class CharT, class Traits, class Allocator = std::allocator > -std::basic_string -to_string( std::basic_string_view v, Allocator const & a = Allocator() ) -{ - return std::basic_string( v.begin(), v.end(), a ); -} - -template< class CharT, class Traits, class Allocator > -std::basic_string_view -to_string_view( std::basic_string const & s ) -{ - return std::basic_string_view( s.data(), s.size() ); -} - -// Literal operators sv and _sv: - -#if nssv_CONFIG_STD_SV_OPERATOR - -using namespace std::literals::string_view_literals; - -#endif - -#if nssv_CONFIG_USR_SV_OPERATOR - -inline namespace literals { -inline namespace string_view_literals { - - -constexpr std::string_view operator "" _sv( const char* str, size_t len ) noexcept // (1) -{ - return std::string_view{ str, len }; -} - -constexpr std::u16string_view operator "" _sv( const char16_t* str, size_t len ) noexcept // (2) -{ - return std::u16string_view{ str, len }; -} - -constexpr std::u32string_view operator "" _sv( const char32_t* str, size_t len ) noexcept // (3) -{ - return std::u32string_view{ str, len }; -} - -constexpr std::wstring_view operator "" _sv( const wchar_t* str, size_t len ) noexcept // (4) -{ - return std::wstring_view{ str, len }; -} - -}} // namespace literals::string_view_literals - -#endif // nssv_CONFIG_USR_SV_OPERATOR - -} // namespace nonstd - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { - -using std::string_view; -using std::wstring_view; -using std::u16string_view; -using std::u32string_view; -using std::basic_string_view; - -// literal "sv" and "_sv", see above - -using std::operator==; -using std::operator!=; -using std::operator<; -using std::operator<=; -using std::operator>; -using std::operator>=; - -using std::operator<<; - -} // namespace nonstd - -#else // nssv_HAVE_STD_STRING_VIEW - -// -// Before C++17: use string_view lite: -// - -// Compiler versions: -// -// MSVC++ 6.0 _MSC_VER == 1200 nssv_COMPILER_MSVC_VERSION == 60 (Visual Studio 6.0) -// MSVC++ 7.0 _MSC_VER == 1300 nssv_COMPILER_MSVC_VERSION == 70 (Visual Studio .NET 2002) -// MSVC++ 7.1 _MSC_VER == 1310 nssv_COMPILER_MSVC_VERSION == 71 (Visual Studio .NET 2003) -// MSVC++ 8.0 _MSC_VER == 1400 nssv_COMPILER_MSVC_VERSION == 80 (Visual Studio 
2005) -// MSVC++ 9.0 _MSC_VER == 1500 nssv_COMPILER_MSVC_VERSION == 90 (Visual Studio 2008) -// MSVC++ 10.0 _MSC_VER == 1600 nssv_COMPILER_MSVC_VERSION == 100 (Visual Studio 2010) -// MSVC++ 11.0 _MSC_VER == 1700 nssv_COMPILER_MSVC_VERSION == 110 (Visual Studio 2012) -// MSVC++ 12.0 _MSC_VER == 1800 nssv_COMPILER_MSVC_VERSION == 120 (Visual Studio 2013) -// MSVC++ 14.0 _MSC_VER == 1900 nssv_COMPILER_MSVC_VERSION == 140 (Visual Studio 2015) -// MSVC++ 14.1 _MSC_VER >= 1910 nssv_COMPILER_MSVC_VERSION == 141 (Visual Studio 2017) -// MSVC++ 14.2 _MSC_VER >= 1920 nssv_COMPILER_MSVC_VERSION == 142 (Visual Studio 2019) - -#if defined(_MSC_VER ) && !defined(__clang__) -# define nssv_COMPILER_MSVC_VER (_MSC_VER ) -# define nssv_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900 ) ) ) -#else -# define nssv_COMPILER_MSVC_VER 0 -# define nssv_COMPILER_MSVC_VERSION 0 -#endif - -#define nssv_COMPILER_VERSION( major, minor, patch ) ( 10 * ( 10 * (major) + (minor) ) + (patch) ) - -#if defined(__clang__) -# define nssv_COMPILER_CLANG_VERSION nssv_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) -#else -# define nssv_COMPILER_CLANG_VERSION 0 -#endif - -#if defined(__GNUC__) && !defined(__clang__) -# define nssv_COMPILER_GNUC_VERSION nssv_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) -#else -# define nssv_COMPILER_GNUC_VERSION 0 -#endif - -// half-open range [lo..hi): -#define nssv_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) ) - -// Presence of language and library features: - -#ifdef _HAS_CPP0X -# define nssv_HAS_CPP0X _HAS_CPP0X -#else -# define nssv_HAS_CPP0X 0 -#endif - -// Unless defined otherwise below, consider VC14 as C++11 for variant-lite: - -#if nssv_COMPILER_MSVC_VER >= 1900 -# undef nssv_CPP11_OR_GREATER -# define nssv_CPP11_OR_GREATER 1 -#endif - -#define nssv_CPP11_90 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1500) -#define nssv_CPP11_100 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1600) -#define nssv_CPP11_110 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1700) -#define nssv_CPP11_120 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1800) -#define nssv_CPP11_140 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1900) -#define nssv_CPP11_141 (nssv_CPP11_OR_GREATER_ || nssv_COMPILER_MSVC_VER >= 1910) - -#define nssv_CPP14_000 (nssv_CPP14_OR_GREATER) -#define nssv_CPP17_000 (nssv_CPP17_OR_GREATER) - -// Presence of C++11 language features: - -#define nssv_HAVE_CONSTEXPR_11 nssv_CPP11_140 -#define nssv_HAVE_EXPLICIT_CONVERSION nssv_CPP11_140 -#define nssv_HAVE_INLINE_NAMESPACE nssv_CPP11_140 -#define nssv_HAVE_NOEXCEPT nssv_CPP11_140 -#define nssv_HAVE_NULLPTR nssv_CPP11_100 -#define nssv_HAVE_REF_QUALIFIER nssv_CPP11_140 -#define nssv_HAVE_UNICODE_LITERALS nssv_CPP11_140 -#define nssv_HAVE_USER_DEFINED_LITERALS nssv_CPP11_140 -#define nssv_HAVE_WCHAR16_T nssv_CPP11_100 -#define nssv_HAVE_WCHAR32_T nssv_CPP11_100 - -#if ! 
( ( nssv_CPP11_OR_GREATER && nssv_COMPILER_CLANG_VERSION ) || nssv_BETWEEN( nssv_COMPILER_CLANG_VERSION, 300, 400 ) ) -# define nssv_HAVE_STD_DEFINED_LITERALS nssv_CPP11_140 -#else -# define nssv_HAVE_STD_DEFINED_LITERALS 0 -#endif - -// Presence of C++14 language features: - -#define nssv_HAVE_CONSTEXPR_14 nssv_CPP14_000 - -// Presence of C++17 language features: - -#define nssv_HAVE_NODISCARD nssv_CPP17_000 - -// Presence of C++ library features: - -#define nssv_HAVE_STD_HASH nssv_CPP11_120 - -// C++ feature usage: - -#if nssv_HAVE_CONSTEXPR_11 -# define nssv_constexpr constexpr -#else -# define nssv_constexpr /*constexpr*/ -#endif - -#if nssv_HAVE_CONSTEXPR_14 -# define nssv_constexpr14 constexpr -#else -# define nssv_constexpr14 /*constexpr*/ -#endif - -#if nssv_HAVE_EXPLICIT_CONVERSION -# define nssv_explicit explicit -#else -# define nssv_explicit /*explicit*/ -#endif - -#if nssv_HAVE_INLINE_NAMESPACE -# define nssv_inline_ns inline -#else -# define nssv_inline_ns /*inline*/ -#endif - -#if nssv_HAVE_NOEXCEPT -# define nssv_noexcept noexcept -#else -# define nssv_noexcept /*noexcept*/ -#endif - -//#if nssv_HAVE_REF_QUALIFIER -//# define nssv_ref_qual & -//# define nssv_refref_qual && -//#else -//# define nssv_ref_qual /*&*/ -//# define nssv_refref_qual /*&&*/ -//#endif - -#if nssv_HAVE_NULLPTR -# define nssv_nullptr nullptr -#else -# define nssv_nullptr NULL -#endif - -#if nssv_HAVE_NODISCARD -# define nssv_nodiscard [[nodiscard]] -#else -# define nssv_nodiscard /*[[nodiscard]]*/ -#endif - -// Additional includes: - -#include -#include -#include -#include -#include -#include // std::char_traits<> - -#if ! nssv_CONFIG_NO_EXCEPTIONS -# include -#endif - -#if nssv_CPP11_OR_GREATER -# include -#endif - -// Clang, GNUC, MSVC warning suppression macros: - -#if defined(__clang__) -# pragma clang diagnostic ignored "-Wreserved-user-defined-literal" -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wuser-defined-literals" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wliteral-suffix" -#endif // __clang__ - -#if nssv_COMPILER_MSVC_VERSION >= 140 -# define nssv_SUPPRESS_MSGSL_WARNING(expr) [[gsl::suppress(expr)]] -# define nssv_SUPPRESS_MSVC_WARNING(code, descr) __pragma(warning(suppress: code) ) -# define nssv_DISABLE_MSVC_WARNINGS(codes) __pragma(warning(push)) __pragma(warning(disable: codes)) -#else -# define nssv_SUPPRESS_MSGSL_WARNING(expr) -# define nssv_SUPPRESS_MSVC_WARNING(code, descr) -# define nssv_DISABLE_MSVC_WARNINGS(codes) -#endif - -#if defined(__clang__) -# define nssv_RESTORE_WARNINGS() _Pragma("clang diagnostic pop") -#elif defined(__GNUC__) -# define nssv_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop") -#elif nssv_COMPILER_MSVC_VERSION >= 140 -# define nssv_RESTORE_WARNINGS() __pragma(warning(pop )) -#else -# define nssv_RESTORE_WARNINGS() -#endif - -// Suppress the following MSVC (GSL) warnings: -// - C4455, non-gsl : 'operator ""sv': literal suffix identifiers that do not -// start with an underscore are reserved -// - C26472, gsl::t.1 : don't use a static_cast for arithmetic conversions; -// use brace initialization, gsl::narrow_cast or gsl::narow -// - C26481: gsl::b.1 : don't use pointer arithmetic. 
Use span instead - -nssv_DISABLE_MSVC_WARNINGS( 4455 26481 26472 ) -//nssv_DISABLE_CLANG_WARNINGS( "-Wuser-defined-literals" ) -//nssv_DISABLE_GNUC_WARNINGS( -Wliteral-suffix ) - -namespace nonstd { namespace sv_lite { - -#if nssv_CPP11_OR_GREATER - -namespace detail { - -#if nssv_CPP14_OR_GREATER - -template< typename CharT > -inline constexpr std::size_t length( CharT * s, std::size_t result = 0 ) -{ - CharT * v = s; - std::size_t r = result; - while ( *v != '\0' ) { - ++v; - ++r; - } - return r; -} - -#else // nssv_CPP14_OR_GREATER - -// Expect tail call optimization to make length() non-recursive: - -template< typename CharT > -inline constexpr std::size_t length( CharT * s, std::size_t result = 0 ) -{ - return *s == '\0' ? result : length( s + 1, result + 1 ); -} - -#endif // nssv_CPP14_OR_GREATER - -} // namespace detail - -#endif // nssv_CPP11_OR_GREATER - -template -< - class CharT, - class Traits = std::char_traits -> -class basic_string_view; - -// -// basic_string_view: -// - -template -< - class CharT, - class Traits /* = std::char_traits */ -> -class basic_string_view -{ -public: - // Member types: - - typedef Traits traits_type; - typedef CharT value_type; - - typedef CharT * pointer; - typedef CharT const * const_pointer; - typedef CharT & reference; - typedef CharT const & const_reference; - - typedef const_pointer iterator; - typedef const_pointer const_iterator; - typedef std::reverse_iterator< const_iterator > reverse_iterator; - typedef std::reverse_iterator< const_iterator > const_reverse_iterator; - - typedef std::size_t size_type; - typedef std::ptrdiff_t difference_type; - - // 24.4.2.1 Construction and assignment: - - nssv_constexpr basic_string_view() nssv_noexcept - : data_( nssv_nullptr ) - , size_( 0 ) - {} - -#if nssv_CPP11_OR_GREATER - nssv_constexpr basic_string_view( basic_string_view const & other ) nssv_noexcept = default; -#else - nssv_constexpr basic_string_view( basic_string_view const & other ) nssv_noexcept - : data_( other.data_) - , size_( other.size_) - {} -#endif - - nssv_constexpr basic_string_view( CharT const * s, size_type count ) nssv_noexcept // non-standard noexcept - : data_( s ) - , size_( count ) - {} - - nssv_constexpr basic_string_view( CharT const * s) nssv_noexcept // non-standard noexcept - : data_( s ) -#if nssv_CPP17_OR_GREATER - , size_( Traits::length(s) ) -#elif nssv_CPP11_OR_GREATER - , size_( detail::length(s) ) -#else - , size_( Traits::length(s) ) -#endif - {} - - // Assignment: - -#if nssv_CPP11_OR_GREATER - nssv_constexpr14 basic_string_view & operator=( basic_string_view const & other ) nssv_noexcept = default; -#else - nssv_constexpr14 basic_string_view & operator=( basic_string_view const & other ) nssv_noexcept - { - data_ = other.data_; - size_ = other.size_; - return *this; - } -#endif - - // 24.4.2.2 Iterator support: - - nssv_constexpr const_iterator begin() const nssv_noexcept { return data_; } - nssv_constexpr const_iterator end() const nssv_noexcept { return data_ + size_; } - - nssv_constexpr const_iterator cbegin() const nssv_noexcept { return begin(); } - nssv_constexpr const_iterator cend() const nssv_noexcept { return end(); } - - nssv_constexpr const_reverse_iterator rbegin() const nssv_noexcept { return const_reverse_iterator( end() ); } - nssv_constexpr const_reverse_iterator rend() const nssv_noexcept { return const_reverse_iterator( begin() ); } - - nssv_constexpr const_reverse_iterator crbegin() const nssv_noexcept { return rbegin(); } - nssv_constexpr const_reverse_iterator crend() const 
nssv_noexcept { return rend(); } - - // 24.4.2.3 Capacity: - - nssv_constexpr size_type size() const nssv_noexcept { return size_; } - nssv_constexpr size_type length() const nssv_noexcept { return size_; } - nssv_constexpr size_type max_size() const nssv_noexcept { return (std::numeric_limits< size_type >::max)(); } - - // since C++20 - nssv_nodiscard nssv_constexpr bool empty() const nssv_noexcept - { - return 0 == size_; - } - - // 24.4.2.4 Element access: - - nssv_constexpr const_reference operator[]( size_type pos ) const - { - return data_at( pos ); - } - - nssv_constexpr14 const_reference at( size_type pos ) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert( pos < size() ); -#else - if ( pos >= size() ) - { - throw std::out_of_range("nonstd::string_view::at()"); - } -#endif - return data_at( pos ); - } - - nssv_constexpr const_reference front() const { return data_at( 0 ); } - nssv_constexpr const_reference back() const { return data_at( size() - 1 ); } - - nssv_constexpr const_pointer data() const nssv_noexcept { return data_; } - - // 24.4.2.5 Modifiers: - - nssv_constexpr14 void remove_prefix( size_type n ) - { - assert( n <= size() ); - data_ += n; - size_ -= n; - } - - nssv_constexpr14 void remove_suffix( size_type n ) - { - assert( n <= size() ); - size_ -= n; - } - - nssv_constexpr14 void swap( basic_string_view & other ) nssv_noexcept - { - using std::swap; - swap( data_, other.data_ ); - swap( size_, other.size_ ); - } - - // 24.4.2.6 String operations: - - size_type copy( CharT * dest, size_type n, size_type pos = 0 ) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert( pos <= size() ); -#else - if ( pos > size() ) - { - throw std::out_of_range("nonstd::string_view::copy()"); - } -#endif - const size_type rlen = (std::min)( n, size() - pos ); - - (void) Traits::copy( dest, data() + pos, rlen ); - - return rlen; - } - - nssv_constexpr14 basic_string_view substr( size_type pos = 0, size_type n = npos ) const - { -#if nssv_CONFIG_NO_EXCEPTIONS - assert( pos <= size() ); -#else - if ( pos > size() ) - { - throw std::out_of_range("nonstd::string_view::substr()"); - } -#endif - return basic_string_view( data() + pos, (std::min)( n, size() - pos ) ); - } - - // compare(), 6x: - - nssv_constexpr14 int compare( basic_string_view other ) const nssv_noexcept // (1) - { - if ( const int result = Traits::compare( data(), other.data(), (std::min)( size(), other.size() ) ) ) - { - return result; - } - - return size() == other.size() ? 0 : size() < other.size() ? 
-1 : 1; - } - - nssv_constexpr int compare( size_type pos1, size_type n1, basic_string_view other ) const // (2) - { - return substr( pos1, n1 ).compare( other ); - } - - nssv_constexpr int compare( size_type pos1, size_type n1, basic_string_view other, size_type pos2, size_type n2 ) const // (3) - { - return substr( pos1, n1 ).compare( other.substr( pos2, n2 ) ); - } - - nssv_constexpr int compare( CharT const * s ) const // (4) - { - return compare( basic_string_view( s ) ); - } - - nssv_constexpr int compare( size_type pos1, size_type n1, CharT const * s ) const // (5) - { - return substr( pos1, n1 ).compare( basic_string_view( s ) ); - } - - nssv_constexpr int compare( size_type pos1, size_type n1, CharT const * s, size_type n2 ) const // (6) - { - return substr( pos1, n1 ).compare( basic_string_view( s, n2 ) ); - } - - // 24.4.2.7 Searching: - - // starts_with(), 3x, since C++20: - - nssv_constexpr bool starts_with( basic_string_view v ) const nssv_noexcept // (1) - { - return size() >= v.size() && compare( 0, v.size(), v ) == 0; - } - - nssv_constexpr bool starts_with( CharT c ) const nssv_noexcept // (2) - { - return starts_with( basic_string_view( &c, 1 ) ); - } - - nssv_constexpr bool starts_with( CharT const * s ) const // (3) - { - return starts_with( basic_string_view( s ) ); - } - - // ends_with(), 3x, since C++20: - - nssv_constexpr bool ends_with( basic_string_view v ) const nssv_noexcept // (1) - { - return size() >= v.size() && compare( size() - v.size(), npos, v ) == 0; - } - - nssv_constexpr bool ends_with( CharT c ) const nssv_noexcept // (2) - { - return ends_with( basic_string_view( &c, 1 ) ); - } - - nssv_constexpr bool ends_with( CharT const * s ) const // (3) - { - return ends_with( basic_string_view( s ) ); - } - - // find(), 4x: - - nssv_constexpr14 size_type find( basic_string_view v, size_type pos = 0 ) const nssv_noexcept // (1) - { - return assert( v.size() == 0 || v.data() != nssv_nullptr ) - , pos >= size() - ? npos - : to_pos( std::search( cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq ) ); - } - - nssv_constexpr14 size_type find( CharT c, size_type pos = 0 ) const nssv_noexcept // (2) - { - return find( basic_string_view( &c, 1 ), pos ); - } - - nssv_constexpr14 size_type find( CharT const * s, size_type pos, size_type n ) const // (3) - { - return find( basic_string_view( s, n ), pos ); - } - - nssv_constexpr14 size_type find( CharT const * s, size_type pos = 0 ) const // (4) - { - return find( basic_string_view( s ), pos ); - } - - // rfind(), 4x: - - nssv_constexpr14 size_type rfind( basic_string_view v, size_type pos = npos ) const nssv_noexcept // (1) - { - if ( size() < v.size() ) - { - return npos; - } - - if ( v.empty() ) - { - return (std::min)( size(), pos ); - } - - const_iterator last = cbegin() + (std::min)( size() - v.size(), pos ) + v.size(); - const_iterator result = std::find_end( cbegin(), last, v.cbegin(), v.cend(), Traits::eq ); - - return result != last ? 
size_type( result - cbegin() ) : npos; - } - - nssv_constexpr14 size_type rfind( CharT c, size_type pos = npos ) const nssv_noexcept // (2) - { - return rfind( basic_string_view( &c, 1 ), pos ); - } - - nssv_constexpr14 size_type rfind( CharT const * s, size_type pos, size_type n ) const // (3) - { - return rfind( basic_string_view( s, n ), pos ); - } - - nssv_constexpr14 size_type rfind( CharT const * s, size_type pos = npos ) const // (4) - { - return rfind( basic_string_view( s ), pos ); - } - - // find_first_of(), 4x: - - nssv_constexpr size_type find_first_of( basic_string_view v, size_type pos = 0 ) const nssv_noexcept // (1) - { - return pos >= size() - ? npos - : to_pos( std::find_first_of( cbegin() + pos, cend(), v.cbegin(), v.cend(), Traits::eq ) ); - } - - nssv_constexpr size_type find_first_of( CharT c, size_type pos = 0 ) const nssv_noexcept // (2) - { - return find_first_of( basic_string_view( &c, 1 ), pos ); - } - - nssv_constexpr size_type find_first_of( CharT const * s, size_type pos, size_type n ) const // (3) - { - return find_first_of( basic_string_view( s, n ), pos ); - } - - nssv_constexpr size_type find_first_of( CharT const * s, size_type pos = 0 ) const // (4) - { - return find_first_of( basic_string_view( s ), pos ); - } - - // find_last_of(), 4x: - - nssv_constexpr size_type find_last_of( basic_string_view v, size_type pos = npos ) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? find_last_of( v, size() - 1 ) - : to_pos( std::find_first_of( const_reverse_iterator( cbegin() + pos + 1 ), crend(), v.cbegin(), v.cend(), Traits::eq ) ); - } - - nssv_constexpr size_type find_last_of( CharT c, size_type pos = npos ) const nssv_noexcept // (2) - { - return find_last_of( basic_string_view( &c, 1 ), pos ); - } - - nssv_constexpr size_type find_last_of( CharT const * s, size_type pos, size_type count ) const // (3) - { - return find_last_of( basic_string_view( s, count ), pos ); - } - - nssv_constexpr size_type find_last_of( CharT const * s, size_type pos = npos ) const // (4) - { - return find_last_of( basic_string_view( s ), pos ); - } - - // find_first_not_of(), 4x: - - nssv_constexpr size_type find_first_not_of( basic_string_view v, size_type pos = 0 ) const nssv_noexcept // (1) - { - return pos >= size() - ? npos - : to_pos( std::find_if( cbegin() + pos, cend(), not_in_view( v ) ) ); - } - - nssv_constexpr size_type find_first_not_of( CharT c, size_type pos = 0 ) const nssv_noexcept // (2) - { - return find_first_not_of( basic_string_view( &c, 1 ), pos ); - } - - nssv_constexpr size_type find_first_not_of( CharT const * s, size_type pos, size_type count ) const // (3) - { - return find_first_not_of( basic_string_view( s, count ), pos ); - } - - nssv_constexpr size_type find_first_not_of( CharT const * s, size_type pos = 0 ) const // (4) - { - return find_first_not_of( basic_string_view( s ), pos ); - } - - // find_last_not_of(), 4x: - - nssv_constexpr size_type find_last_not_of( basic_string_view v, size_type pos = npos ) const nssv_noexcept // (1) - { - return empty() - ? npos - : pos >= size() - ? 
find_last_not_of( v, size() - 1 ) - : to_pos( std::find_if( const_reverse_iterator( cbegin() + pos + 1 ), crend(), not_in_view( v ) ) ); - } - - nssv_constexpr size_type find_last_not_of( CharT c, size_type pos = npos ) const nssv_noexcept // (2) - { - return find_last_not_of( basic_string_view( &c, 1 ), pos ); - } - - nssv_constexpr size_type find_last_not_of( CharT const * s, size_type pos, size_type count ) const // (3) - { - return find_last_not_of( basic_string_view( s, count ), pos ); - } - - nssv_constexpr size_type find_last_not_of( CharT const * s, size_type pos = npos ) const // (4) - { - return find_last_not_of( basic_string_view( s ), pos ); - } - - // Constants: - -#if nssv_CPP17_OR_GREATER - static nssv_constexpr size_type npos = size_type(-1); -#elif nssv_CPP11_OR_GREATER - enum : size_type { npos = size_type(-1) }; -#else - enum { npos = size_type(-1) }; -#endif - -private: - struct not_in_view - { - const basic_string_view v; - - nssv_constexpr explicit not_in_view( basic_string_view v_ ) : v( v_ ) {} - - nssv_constexpr bool operator()( CharT c ) const - { - return npos == v.find_first_of( c ); - } - }; - - nssv_constexpr size_type to_pos( const_iterator it ) const - { - return it == cend() ? npos : size_type( it - cbegin() ); - } - - nssv_constexpr size_type to_pos( const_reverse_iterator it ) const - { - return it == crend() ? npos : size_type( crend() - it - 1 ); - } - - nssv_constexpr const_reference data_at( size_type pos ) const - { -#if nssv_BETWEEN( nssv_COMPILER_GNUC_VERSION, 1, 500 ) - return data_[pos]; -#else - return assert( pos < size() ), data_[pos]; -#endif - } - -private: - const_pointer data_; - size_type size_; - -public: -#if nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS - - template< class Allocator > - basic_string_view( std::basic_string const & s ) nssv_noexcept - : data_( s.data() ) - , size_( s.size() ) - {} - -#if nssv_HAVE_EXPLICIT_CONVERSION - - template< class Allocator > - explicit operator std::basic_string() const - { - return to_string( Allocator() ); - } - -#endif // nssv_HAVE_EXPLICIT_CONVERSION - -#if nssv_CPP11_OR_GREATER - - template< class Allocator = std::allocator > - std::basic_string - to_string( Allocator const & a = Allocator() ) const - { - return std::basic_string( begin(), end(), a ); - } - -#else - - std::basic_string - to_string() const - { - return std::basic_string( begin(), end() ); - } - - template< class Allocator > - std::basic_string - to_string( Allocator const & a ) const - { - return std::basic_string( begin(), end(), a ); - } - -#endif // nssv_CPP11_OR_GREATER - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_CLASS_METHODS -}; - -// -// Non-member functions: -// - -// 24.4.3 Non-member comparison functions: -// lexicographically compare two string views (function template): - -template< class CharT, class Traits > -nssv_constexpr bool operator== ( - basic_string_view lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.compare( rhs ) == 0 ; } - -template< class CharT, class Traits > -nssv_constexpr bool operator!= ( - basic_string_view lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.compare( rhs ) != 0 ; } - -template< class CharT, class Traits > -nssv_constexpr bool operator< ( - basic_string_view lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.compare( rhs ) < 0 ; } - -template< class CharT, class Traits > -nssv_constexpr bool operator<= ( - basic_string_view lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.compare( rhs ) <= 0 ; } - -template< class CharT, class Traits > 
-nssv_constexpr bool operator> ( - basic_string_view lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.compare( rhs ) > 0 ; } - -template< class CharT, class Traits > -nssv_constexpr bool operator>= ( - basic_string_view lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.compare( rhs ) >= 0 ; } - -// Let S be basic_string_view, and sv be an instance of S. -// Implementations shall provide sufficient additional overloads marked -// constexpr and noexcept so that an object t with an implicit conversion -// to S can be compared according to Table 67. - -#if ! nssv_CPP11_OR_GREATER || nssv_BETWEEN( nssv_COMPILER_MSVC_VERSION, 100, 141 ) - -// accomodate for older compilers: - -// == - -template< class CharT, class Traits> -nssv_constexpr bool operator==( - basic_string_view lhs, - CharT const * rhs ) nssv_noexcept -{ return lhs.compare( rhs ) == 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator==( - CharT const * lhs, - basic_string_view rhs ) nssv_noexcept -{ return rhs.compare( lhs ) == 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator==( - basic_string_view lhs, - std::basic_string rhs ) nssv_noexcept -{ return lhs.size() == rhs.size() && lhs.compare( rhs ) == 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator==( - std::basic_string rhs, - basic_string_view lhs ) nssv_noexcept -{ return lhs.size() == rhs.size() && lhs.compare( rhs ) == 0; } - -// != - -template< class CharT, class Traits> -nssv_constexpr bool operator!=( - basic_string_view lhs, - char const * rhs ) nssv_noexcept -{ return lhs.compare( rhs ) != 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator!=( - char const * lhs, - basic_string_view rhs ) nssv_noexcept -{ return rhs.compare( lhs ) != 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator!=( - basic_string_view lhs, - std::basic_string rhs ) nssv_noexcept -{ return lhs.size() != rhs.size() && lhs.compare( rhs ) != 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator!=( - std::basic_string rhs, - basic_string_view lhs ) nssv_noexcept -{ return lhs.size() != rhs.size() || rhs.compare( lhs ) != 0; } - -// < - -template< class CharT, class Traits> -nssv_constexpr bool operator<( - basic_string_view lhs, - char const * rhs ) nssv_noexcept -{ return lhs.compare( rhs ) < 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator<( - char const * lhs, - basic_string_view rhs ) nssv_noexcept -{ return rhs.compare( lhs ) > 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator<( - basic_string_view lhs, - std::basic_string rhs ) nssv_noexcept -{ return lhs.compare( rhs ) < 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator<( - std::basic_string rhs, - basic_string_view lhs ) nssv_noexcept -{ return rhs.compare( lhs ) > 0; } - -// <= - -template< class CharT, class Traits> -nssv_constexpr bool operator<=( - basic_string_view lhs, - char const * rhs ) nssv_noexcept -{ return lhs.compare( rhs ) <= 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator<=( - char const * lhs, - basic_string_view rhs ) nssv_noexcept -{ return rhs.compare( lhs ) >= 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator<=( - basic_string_view lhs, - std::basic_string rhs ) nssv_noexcept -{ return lhs.compare( rhs ) <= 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator<=( - std::basic_string rhs, - basic_string_view lhs ) 
nssv_noexcept -{ return rhs.compare( lhs ) >= 0; } - -// > - -template< class CharT, class Traits> -nssv_constexpr bool operator>( - basic_string_view lhs, - char const * rhs ) nssv_noexcept -{ return lhs.compare( rhs ) > 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator>( - char const * lhs, - basic_string_view rhs ) nssv_noexcept -{ return rhs.compare( lhs ) < 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator>( - basic_string_view lhs, - std::basic_string rhs ) nssv_noexcept -{ return lhs.compare( rhs ) > 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator>( - std::basic_string rhs, - basic_string_view lhs ) nssv_noexcept -{ return rhs.compare( lhs ) < 0; } - -// >= - -template< class CharT, class Traits> -nssv_constexpr bool operator>=( - basic_string_view lhs, - char const * rhs ) nssv_noexcept -{ return lhs.compare( rhs ) >= 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator>=( - char const * lhs, - basic_string_view rhs ) nssv_noexcept -{ return rhs.compare( lhs ) <= 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator>=( - basic_string_view lhs, - std::basic_string rhs ) nssv_noexcept -{ return lhs.compare( rhs ) >= 0; } - -template< class CharT, class Traits> -nssv_constexpr bool operator>=( - std::basic_string rhs, - basic_string_view lhs ) nssv_noexcept -{ return rhs.compare( lhs ) <= 0; } - -#else // newer compilers: - -#define nssv_BASIC_STRING_VIEW_I(T,U) typename std::decay< basic_string_view >::type - -#if nssv_BETWEEN( nssv_COMPILER_MSVC_VERSION, 140, 150 ) -# define nssv_MSVC_ORDER(x) , int=x -#else -# define nssv_MSVC_ORDER(x) /*, int=x*/ -#endif - -// == - -template< class CharT, class Traits nssv_MSVC_ORDER(1) > -nssv_constexpr bool operator==( - basic_string_view lhs, - nssv_BASIC_STRING_VIEW_I(CharT, Traits) rhs ) nssv_noexcept -{ return lhs.compare( rhs ) == 0; } - -template< class CharT, class Traits nssv_MSVC_ORDER(2) > -nssv_constexpr bool operator==( - nssv_BASIC_STRING_VIEW_I(CharT, Traits) lhs, - basic_string_view rhs ) nssv_noexcept -{ return lhs.size() == rhs.size() && lhs.compare( rhs ) == 0; } - -// != - -template< class CharT, class Traits nssv_MSVC_ORDER(1) > -nssv_constexpr bool operator!= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept -{ return lhs.size() != rhs.size() || lhs.compare( rhs ) != 0 ; } - -template< class CharT, class Traits nssv_MSVC_ORDER(2) > -nssv_constexpr bool operator!= ( - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs, - basic_string_view < CharT, Traits > rhs ) nssv_noexcept -{ return lhs.compare( rhs ) != 0 ; } - -// < - -template< class CharT, class Traits nssv_MSVC_ORDER(1) > -nssv_constexpr bool operator< ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept -{ return lhs.compare( rhs ) < 0 ; } - -template< class CharT, class Traits nssv_MSVC_ORDER(2) > -nssv_constexpr bool operator< ( - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs, - basic_string_view < CharT, Traits > rhs ) nssv_noexcept -{ return lhs.compare( rhs ) < 0 ; } - -// <= - -template< class CharT, class Traits nssv_MSVC_ORDER(1) > -nssv_constexpr bool operator<= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept -{ return lhs.compare( rhs ) <= 0 ; } - -template< class CharT, class Traits nssv_MSVC_ORDER(2) > -nssv_constexpr bool operator<= ( - nssv_BASIC_STRING_VIEW_I( CharT, 
Traits ) lhs, - basic_string_view < CharT, Traits > rhs ) nssv_noexcept -{ return lhs.compare( rhs ) <= 0 ; } - -// > - -template< class CharT, class Traits nssv_MSVC_ORDER(1) > -nssv_constexpr bool operator> ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept -{ return lhs.compare( rhs ) > 0 ; } - -template< class CharT, class Traits nssv_MSVC_ORDER(2) > -nssv_constexpr bool operator> ( - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs, - basic_string_view < CharT, Traits > rhs ) nssv_noexcept -{ return lhs.compare( rhs ) > 0 ; } - -// >= - -template< class CharT, class Traits nssv_MSVC_ORDER(1) > -nssv_constexpr bool operator>= ( - basic_string_view < CharT, Traits > lhs, - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) rhs ) nssv_noexcept -{ return lhs.compare( rhs ) >= 0 ; } - -template< class CharT, class Traits nssv_MSVC_ORDER(2) > -nssv_constexpr bool operator>= ( - nssv_BASIC_STRING_VIEW_I( CharT, Traits ) lhs, - basic_string_view < CharT, Traits > rhs ) nssv_noexcept -{ return lhs.compare( rhs ) >= 0 ; } - -#undef nssv_MSVC_ORDER -#undef nssv_BASIC_STRING_VIEW_I - -#endif // compiler-dependent approach to comparisons - -// 24.4.4 Inserters and extractors: - -namespace detail { - -template< class Stream > -void write_padding( Stream & os, std::streamsize n ) -{ - for ( std::streamsize i = 0; i < n; ++i ) - os.rdbuf()->sputc( os.fill() ); -} - -template< class Stream, class View > -Stream & write_to_stream( Stream & os, View const & sv ) -{ - typename Stream::sentry sentry( os ); - - if ( !os ) - return os; - - const std::streamsize length = static_cast( sv.length() ); - - // Whether, and how, to pad: - const bool pad = ( length < os.width() ); - const bool left_pad = pad && ( os.flags() & std::ios_base::adjustfield ) == std::ios_base::right; - - if ( left_pad ) - write_padding( os, os.width() - length ); - - // Write span characters: - os.rdbuf()->sputn( sv.begin(), length ); - - if ( pad && !left_pad ) - write_padding( os, os.width() - length ); - - // Reset output stream width: - os.width( 0 ); - - return os; -} - -} // namespace detail - -template< class CharT, class Traits > -std::basic_ostream & -operator<<( - std::basic_ostream& os, - basic_string_view sv ) -{ - return detail::write_to_stream( os, sv ); -} - -// Several typedefs for common character types are provided: - -typedef basic_string_view string_view; -typedef basic_string_view wstring_view; -#if nssv_HAVE_WCHAR16_T -typedef basic_string_view u16string_view; -typedef basic_string_view u32string_view; -#endif - -}} // namespace nonstd::sv_lite - -// -// 24.4.6 Suffix for basic_string_view literals: -// - -#if nssv_HAVE_USER_DEFINED_LITERALS - -namespace nonstd { -nssv_inline_ns namespace literals { -nssv_inline_ns namespace string_view_literals { - -#if nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - -nssv_constexpr nonstd::sv_lite::string_view operator "" sv( const char* str, size_t len ) nssv_noexcept // (1) -{ - return nonstd::sv_lite::string_view{ str, len }; -} - -nssv_constexpr nonstd::sv_lite::u16string_view operator "" sv( const char16_t* str, size_t len ) nssv_noexcept // (2) -{ - return nonstd::sv_lite::u16string_view{ str, len }; -} - -nssv_constexpr nonstd::sv_lite::u32string_view operator "" sv( const char32_t* str, size_t len ) nssv_noexcept // (3) -{ - return nonstd::sv_lite::u32string_view{ str, len }; -} - -nssv_constexpr nonstd::sv_lite::wstring_view operator "" sv( const wchar_t* str, size_t len ) nssv_noexcept // (4) -{ - return 
nonstd::sv_lite::wstring_view{ str, len }; -} - -#endif // nssv_CONFIG_STD_SV_OPERATOR && nssv_HAVE_STD_DEFINED_LITERALS - -#if nssv_CONFIG_USR_SV_OPERATOR - -nssv_constexpr nonstd::sv_lite::string_view operator "" _sv( const char* str, size_t len ) nssv_noexcept // (1) -{ - return nonstd::sv_lite::string_view{ str, len }; -} - -nssv_constexpr nonstd::sv_lite::u16string_view operator "" _sv( const char16_t* str, size_t len ) nssv_noexcept // (2) -{ - return nonstd::sv_lite::u16string_view{ str, len }; -} - -nssv_constexpr nonstd::sv_lite::u32string_view operator "" _sv( const char32_t* str, size_t len ) nssv_noexcept // (3) -{ - return nonstd::sv_lite::u32string_view{ str, len }; -} - -nssv_constexpr nonstd::sv_lite::wstring_view operator "" _sv( const wchar_t* str, size_t len ) nssv_noexcept // (4) -{ - return nonstd::sv_lite::wstring_view{ str, len }; -} - -#endif // nssv_CONFIG_USR_SV_OPERATOR - -}}} // namespace nonstd::literals::string_view_literals - -#endif - -// -// Extensions for std::string: -// - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -namespace nonstd { -namespace sv_lite { - -// Exclude MSVC 14 (19.00): it yields ambiguous to_string(): - -#if nssv_CPP11_OR_GREATER && nssv_COMPILER_MSVC_VERSION != 140 - -template< class CharT, class Traits, class Allocator = std::allocator > -std::basic_string -to_string( basic_string_view v, Allocator const & a = Allocator() ) -{ - return std::basic_string( v.begin(), v.end(), a ); -} - -#else - -template< class CharT, class Traits > -std::basic_string -to_string( basic_string_view v ) -{ - return std::basic_string( v.begin(), v.end() ); -} - -template< class CharT, class Traits, class Allocator > -std::basic_string -to_string( basic_string_view v, Allocator const & a ) -{ - return std::basic_string( v.begin(), v.end(), a ); -} - -#endif // nssv_CPP11_OR_GREATER - -template< class CharT, class Traits, class Allocator > -basic_string_view -to_string_view( std::basic_string const & s ) -{ - return basic_string_view( s.data(), s.size() ); -} - -}} // namespace nonstd::sv_lite - -#endif // nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS - -// -// make types and algorithms available in namespace nonstd: -// - -namespace nonstd { - -using sv_lite::basic_string_view; -using sv_lite::string_view; -using sv_lite::wstring_view; - -#if nssv_HAVE_WCHAR16_T -using sv_lite::u16string_view; -#endif -#if nssv_HAVE_WCHAR32_T -using sv_lite::u32string_view; -#endif - -// literal "sv" - -using sv_lite::operator==; -using sv_lite::operator!=; -using sv_lite::operator<; -using sv_lite::operator<=; -using sv_lite::operator>; -using sv_lite::operator>=; - -using sv_lite::operator<<; - -#if nssv_CONFIG_CONVERSION_STD_STRING_FREE_FUNCTIONS -using sv_lite::to_string; -using sv_lite::to_string_view; -#endif - -} // namespace nonstd - -// 24.4.5 Hash support (C++11): - -// Note: The hash value of a string view object is equal to the hash value of -// the corresponding string object. 
-
-#if nssv_HAVE_STD_HASH
-
-#include <functional>
-
-namespace std {
-
-template<>
-struct hash< nonstd::string_view >
-{
-public:
-    std::size_t operator()( nonstd::string_view v ) const nssv_noexcept
-    {
-        return std::hash<std::string>()( std::string( v.data(), v.size() ) );
-    }
-};
-
-template<>
-struct hash< nonstd::wstring_view >
-{
-public:
-    std::size_t operator()( nonstd::wstring_view v ) const nssv_noexcept
-    {
-        return std::hash<std::wstring>()( std::wstring( v.data(), v.size() ) );
-    }
-};
-
-template<>
-struct hash< nonstd::u16string_view >
-{
-public:
-    std::size_t operator()( nonstd::u16string_view v ) const nssv_noexcept
-    {
-        return std::hash<std::u16string>()( std::u16string( v.data(), v.size() ) );
-    }
-};
-
-template<>
-struct hash< nonstd::u32string_view >
-{
-public:
-    std::size_t operator()( nonstd::u32string_view v ) const nssv_noexcept
-    {
-        return std::hash<std::u32string>()( std::u32string( v.data(), v.size() ) );
-    }
-};
-
-} // namespace std
-
-#endif // nssv_HAVE_STD_HASH
-
-nssv_RESTORE_WARNINGS()
-
-#endif // nssv_HAVE_STD_STRING_VIEW
-#endif // NONSTD_SV_LITE_H_INCLUDED
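With the vendored string-view-lite gone, the hash specializations it supplied are no longer needed: std::hash<std::string_view> comes straight from the standard library. A minimal standalone sketch:

#include <functional>
#include <iostream>
#include <string_view>
#include <unordered_set>

int main() {
  // std::string_view is hashable and usable as an unordered container key
  // out of the box in C++17.
  std::unordered_set<std::string_view> seen;
  seen.insert("alpha");
  seen.insert("beta");
  std::cout << seen.count("alpha") << "\n";  // 1
  std::cout << std::hash<std::string_view>{}("alpha") << "\n";
}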
the type's `c_type`, if any // - for boolean arrays, a `bool` -// - for binary, string and fixed-size binary arrays, a `util::string_view` +// - for binary, string and fixed-size binary arrays, a `std::string_view` template struct ArraySpanVisitor { diff --git a/cpp/src/gandiva/gdv_function_stubs.cc b/cpp/src/gandiva/gdv_function_stubs.cc index cf0f4f9b917..a2a0216d314 100644 --- a/cpp/src/gandiva/gdv_function_stubs.cc +++ b/cpp/src/gandiva/gdv_function_stubs.cc @@ -120,7 +120,7 @@ bool gdv_fn_in_expr_lookup_utf8(int64_t ptr, const char* data, int data_len, } gandiva::InHolder* holder = reinterpret_cast*>(ptr); - return holder->HasValue(arrow::util::string_view(data, data_len)); + return holder->HasValue(std::string_view(data, data_len)); } int32_t gdv_fn_populate_varlen_vector(int64_t context_ptr, int8_t* data_ptr, @@ -205,8 +205,7 @@ const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t return ""; } // use arrow method to encode base64 string - std::string encoded_str = - arrow::util::base64_encode(arrow::util::string_view(in, in_len)); + std::string encoded_str = arrow::util::base64_encode(std::string_view(in, in_len)); *out_len = static_cast(encoded_str.length()); // allocate memory for response char* ret = reinterpret_cast( @@ -233,8 +232,7 @@ const char* gdv_fn_base64_decode_utf8(int64_t context, const char* in, int32_t i return ""; } // use arrow method to decode base64 string - std::string decoded_str = - arrow::util::base64_decode(arrow::util::string_view(in, in_len)); + std::string decoded_str = arrow::util::base64_decode(std::string_view(in, in_len)); *out_len = static_cast(decoded_str.length()); // allocate memory for response char* ret = reinterpret_cast( diff --git a/cpp/src/gandiva/gdv_string_function_stubs.cc b/cpp/src/gandiva/gdv_string_function_stubs.cc index 0c963f4417f..cf04de3a8e1 100644 --- a/cpp/src/gandiva/gdv_string_function_stubs.cc +++ b/cpp/src/gandiva/gdv_string_function_stubs.cc @@ -21,11 +21,11 @@ #include #include +#include #include #include #include "arrow/util/double_conversion.h" -#include "arrow/util/string_view.h" #include "arrow/util/utf8_internal.h" #include "arrow/util/value_parsing.h" @@ -102,7 +102,7 @@ const char* gdv_fn_regexp_extract_utf8_utf8_int32(int64_t ptr, int64_t holder_pt *out_len = 0; \ return ""; \ } \ - arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \ + arrow::Status status = formatter(value, [&](std::string_view v) { \ int64_t size = static_cast(v.size()); \ *out_len = static_cast(len < size ? len : size); \ memcpy(ret, v.data(), *out_len); \ @@ -138,7 +138,7 @@ const char* gdv_fn_regexp_extract_utf8_utf8_int32(int64_t ptr, int64_t holder_pt *out_len = 0; \ return ""; \ } \ - arrow::Status status = formatter(value, [&](arrow::util::string_view v) { \ + arrow::Status status = formatter(value, [&](std::string_view v) { \ int64_t size = static_cast(v.size()); \ *out_len = static_cast(len < size ? 
len : size); \ memcpy(ret, v.data(), *out_len); \ diff --git a/cpp/src/gandiva/in_holder.h b/cpp/src/gandiva/in_holder.h index d55ab5ec55f..65262969c5d 100644 --- a/cpp/src/gandiva/in_holder.h +++ b/cpp/src/gandiva/in_holder.h @@ -72,19 +72,17 @@ class InHolder { } } - bool HasValue(arrow::util::string_view value) const { - return values_lookup_.count(value) == 1; - } + bool HasValue(std::string_view value) const { return values_lookup_.count(value) == 1; } private: struct string_view_hash { public: - std::size_t operator()(arrow::util::string_view v) const { + std::size_t operator()(std::string_view v) const { return arrow::internal::ComputeStringHash<0>(v.data(), v.length()); } }; - std::unordered_set values_lookup_; + std::unordered_set values_lookup_; const std::unordered_set values_; }; diff --git a/cpp/src/parquet/arrow/reader_internal.cc b/cpp/src/parquet/arrow/reader_internal.cc index 64fcc451808..e428c206bfc 100644 --- a/cpp/src/parquet/arrow/reader_internal.cc +++ b/cpp/src/parquet/arrow/reader_internal.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -43,7 +44,6 @@ #include "arrow/util/endian.h" #include "arrow/util/int_util_overflow.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" #include "arrow/util/ubsan.h" #include "parquet/arrow/reader.h" diff --git a/cpp/src/parquet/arrow/schema.cc b/cpp/src/parquet/arrow/schema.cc index 716083f8a58..79c18c9b410 100644 --- a/cpp/src/parquet/arrow/schema.cc +++ b/cpp/src/parquet/arrow/schema.cc @@ -30,6 +30,7 @@ #include "arrow/util/checked_cast.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" +#include "arrow/util/string.h" #include "arrow/util/value_parsing.h" #include "parquet/arrow/schema_internal.h" @@ -44,6 +45,7 @@ using arrow::FieldVector; using arrow::KeyValueMetadata; using arrow::Status; using arrow::internal::checked_cast; +using arrow::internal::EndsWith; using ArrowType = arrow::DataType; using ArrowTypeId = arrow::Type; @@ -496,8 +498,8 @@ Status PopulateLeaf(int column_index, const std::shared_ptr& field, // If the name is array or ends in _tuple, this should be a list of struct // even for single child elements. 
bool HasStructListName(const GroupNode& node) { - ::arrow::util::string_view name{node.name()}; - return name == "array" || name.ends_with("_tuple"); + ::std::string_view name{node.name()}; + return name == "array" || EndsWith(name, "_tuple"); } Status GroupToStruct(const GroupNode& node, LevelInfo current_levels, diff --git a/cpp/src/parquet/encoding.cc b/cpp/src/parquet/encoding.cc index 5a0184b1860..bcefc68fa03 100644 --- a/cpp/src/parquet/encoding.cc +++ b/cpp/src/parquet/encoding.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -42,7 +43,6 @@ #include "arrow/util/int_util_overflow.h" #include "arrow/util/logging.h" #include "arrow/util/rle_encoding.h" -#include "arrow/util/string_view.h" #include "arrow/util/ubsan.h" #include "arrow/visit_data_inline.h" #include "parquet/exception.h" @@ -56,7 +56,7 @@ using arrow::Status; using arrow::VisitNullBitmapInline; using arrow::internal::AddWithOverflow; using arrow::internal::checked_cast; -using arrow::util::string_view; +using std::string_view; template using ArrowPoolVector = std::vector>; @@ -154,7 +154,7 @@ class PlainEncoder : public EncoderImpl, virtual public TypedEncoder { PARQUET_THROW_NOT_OK(::arrow::VisitArraySpanInline( *array.data(), - [&](::arrow::util::string_view view) { + [&](::std::string_view view) { if (ARROW_PREDICT_FALSE(view.size() > kMaxByteArraySize)) { return Status::Invalid("Parquet cannot store strings with size 2GB or more"); } @@ -617,7 +617,7 @@ class DictEncoderImpl : public EncoderImpl, virtual public DictEncoder { void PutBinaryArray(const ArrayType& array) { PARQUET_THROW_NOT_OK(::arrow::VisitArraySpanInline( *array.data(), - [&](::arrow::util::string_view view) { + [&](::std::string_view view) { if (ARROW_PREDICT_FALSE(view.size() > kMaxByteArraySize)) { return Status::Invalid("Parquet cannot store strings with size 2GB or more"); } @@ -658,7 +658,7 @@ void DictEncoderImpl::WriteDict(uint8_t* buffer) { // ByteArray and FLBA already have the dictionary encoded in their data heaps template <> void DictEncoderImpl::WriteDict(uint8_t* buffer) { - memo_table_.VisitValues(0, [&buffer](const ::arrow::util::string_view& v) { + memo_table_.VisitValues(0, [&buffer](const ::std::string_view& v) { uint32_t len = static_cast(v.length()); memcpy(buffer, &len, sizeof(len)); buffer += sizeof(len); @@ -669,7 +669,7 @@ void DictEncoderImpl::WriteDict(uint8_t* buffer) { template <> void DictEncoderImpl::WriteDict(uint8_t* buffer) { - memo_table_.VisitValues(0, [&](const ::arrow::util::string_view& v) { + memo_table_.VisitValues(0, [&](const ::std::string_view& v) { DCHECK_EQ(v.length(), static_cast(type_length_)); memcpy(buffer, v.data(), type_length_); buffer += type_length_; diff --git a/cpp/src/parquet/encryption/crypto_factory.cc b/cpp/src/parquet/encryption/crypto_factory.cc index 384516bff47..316793c73db 100644 --- a/cpp/src/parquet/encryption/crypto_factory.cc +++ b/cpp/src/parquet/encryption/crypto_factory.cc @@ -15,10 +15,11 @@ // specific language governing permissions and limitations // under the License. 
+#include + #include "arrow/result.h" #include "arrow/util/logging.h" #include "arrow/util/string.h" -#include "arrow/util/string_view.h" #include "parquet/encryption/crypto_factory.h" #include "parquet/encryption/encryption_internal.h" @@ -94,7 +95,7 @@ ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties int dek_length, const std::string& column_keys, FileKeyWrapper* key_wrapper) { ColumnPathToEncryptionPropertiesMap encrypted_columns; - std::vector<::arrow::util::string_view> key_to_columns = + std::vector<::std::string_view> key_to_columns = ::arrow::internal::SplitString(column_keys, ';'); for (size_t i = 0; i < key_to_columns.size(); ++i) { std::string cur_key_to_columns = @@ -103,7 +104,7 @@ ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties continue; } - std::vector<::arrow::util::string_view> parts = + std::vector<::std::string_view> parts = ::arrow::internal::SplitString(cur_key_to_columns, ':'); if (parts.size() != 2) { std::ostringstream message; @@ -118,7 +119,7 @@ ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties } std::string column_names_str = ::arrow::internal::TrimString(std::string(parts[1])); - std::vector<::arrow::util::string_view> column_names = + std::vector<::std::string_view> column_names = ::arrow::internal::SplitString(column_names_str, ','); if (0 == column_names.size()) { throw ParquetException("No columns to encrypt defined for key: " + column_key_id); diff --git a/cpp/src/parquet/encryption/key_toolkit_internal.cc b/cpp/src/parquet/encryption/key_toolkit_internal.cc index dc9c070e7a3..6e0e4e6c65e 100644 --- a/cpp/src/parquet/encryption/key_toolkit_internal.cc +++ b/cpp/src/parquet/encryption/key_toolkit_internal.cc @@ -45,7 +45,7 @@ std::string EncryptKeyLocally(const std::string& key_bytes, const std::string& m static_cast(aad.size()), reinterpret_cast(&encrypted_key[0])); return ::arrow::util::base64_encode( - ::arrow::util::string_view(encrypted_key.data(), encrypted_key_len)); + ::std::string_view(encrypted_key.data(), encrypted_key_len)); } std::string DecryptKeyLocally(const std::string& encoded_encrypted_key, diff --git a/cpp/src/parquet/metadata.cc b/cpp/src/parquet/metadata.cc index 1b2a3df9c43..fed45fa2d82 100644 --- a/cpp/src/parquet/metadata.cc +++ b/cpp/src/parquet/metadata.cc @@ -21,13 +21,13 @@ #include #include #include +#include #include #include #include "arrow/io/memory.h" #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" -#include "arrow/util/string_view.h" #include "parquet/encryption/encryption_internal.h" #include "parquet/encryption/internal_file_decryptor.h" #include "parquet/exception.h" @@ -1050,8 +1050,8 @@ class ApplicationVersionParser { private: bool IsSpace(const std::string& string, const size_t& offset) { - auto target = ::arrow::util::string_view(string).substr(offset, 1); - return target.find_first_of(spaces_) != ::arrow::util::string_view::npos; + auto target = ::std::string_view(string).substr(offset, 1); + return target.find_first_of(spaces_) != ::std::string_view::npos; } void RemovePrecedingSpaces(const std::string& string, size_t& start, diff --git a/cpp/src/parquet/reader_test.cc b/cpp/src/parquet/reader_test.cc index 7776d995c02..a43238d1369 100644 --- a/cpp/src/parquet/reader_test.cc +++ b/cpp/src/parquet/reader_test.cc @@ -153,7 +153,7 @@ TEST_F(TestTextDeltaLengthByteArray, TestTextScanner) { ASSERT_FALSE(is_null); std::string expected = expected_prefix + std::to_string(i * i); ASSERT_TRUE(val.len == 
expected.length()); - ASSERT_EQ(::arrow::util::string_view(reinterpret_cast(val.ptr), val.len), + ASSERT_EQ(::std::string_view(reinterpret_cast(val.ptr), val.len), expected); } ASSERT_FALSE(scanner->HasNext()); @@ -200,9 +200,9 @@ TEST_F(TestTextDeltaLengthByteArray, TestBatchRead) { auto expected = expected_prefix + std::to_string((i + values_read) * (i + values_read)); ASSERT_TRUE(values[i].len == expected.length()); - ASSERT_EQ(::arrow::util::string_view(reinterpret_cast(values[i].ptr), - values[i].len), - expected); + ASSERT_EQ( + ::std::string_view(reinterpret_cast(values[i].ptr), values[i].len), + expected); } values_read += curr_batch_read; } diff --git a/cpp/src/parquet/stream_writer.cc b/cpp/src/parquet/stream_writer.cc index 253ebf1bc91..dc76c2935d4 100644 --- a/cpp/src/parquet/stream_writer.cc +++ b/cpp/src/parquet/stream_writer.cc @@ -136,7 +136,7 @@ StreamWriter& StreamWriter::operator<<(const std::string& v) { return WriteVariableLength(v.data(), v.size()); } -StreamWriter& StreamWriter::operator<<(::arrow::util::string_view v) { +StreamWriter& StreamWriter::operator<<(::std::string_view v) { return WriteVariableLength(v.data(), v.size()); } diff --git a/cpp/src/parquet/stream_writer.h b/cpp/src/parquet/stream_writer.h index 5801011e166..f95d39fd1d5 100644 --- a/cpp/src/parquet/stream_writer.h +++ b/cpp/src/parquet/stream_writer.h @@ -23,9 +23,9 @@ #include #include #include +#include #include -#include "arrow/util/string_view.h" #include "parquet/column_writer.h" #include "parquet/file_writer.h" @@ -123,7 +123,7 @@ class PARQUET_EXPORT StreamWriter { /// \brief Helper class to write fixed length strings. /// This is useful as the standard string view (such as - /// arrow::util::string_view) is for variable length data. + /// std::string_view) is for variable length data. struct PARQUET_EXPORT FixedStringView { FixedStringView() = default; @@ -149,7 +149,7 @@ class PARQUET_EXPORT StreamWriter { /// \brief Output operators for variable length strings. StreamWriter& operator<<(const char* v); StreamWriter& operator<<(const std::string& v); - StreamWriter& operator<<(::arrow::util::string_view v); + StreamWriter& operator<<(::std::string_view v); /// \brief Output operator for optional fields. 
template diff --git a/cpp/src/parquet/types.cc b/cpp/src/parquet/types.cc index 349fc682aad..532fd4c3d7b 100644 --- a/cpp/src/parquet/types.cc +++ b/cpp/src/parquet/types.cc @@ -73,7 +73,7 @@ std::unique_ptr GetCodec(Compression::type codec, int compression_level) return result; } -std::string FormatStatValue(Type::type parquet_type, ::arrow::util::string_view val) { +std::string FormatStatValue(Type::type parquet_type, ::std::string_view val) { std::stringstream result; const char* bytes = val.data(); diff --git a/cpp/src/parquet/types.h b/cpp/src/parquet/types.h index b419bf5dcf9..183a3705291 100644 --- a/cpp/src/parquet/types.h +++ b/cpp/src/parquet/types.h @@ -20,11 +20,11 @@ #include #include #include +#include #include #include #include - -#include "arrow/util/string_view.h" +#include #include "parquet/platform.h" #include "parquet/type_fwd.h" @@ -538,7 +538,7 @@ struct ByteArray { ByteArray() : len(0), ptr(NULLPTR) {} ByteArray(uint32_t len, const uint8_t* ptr) : len(len), ptr(ptr) {} - ByteArray(::arrow::util::string_view view) // NOLINT implicit conversion + ByteArray(::std::string_view view) // NOLINT implicit conversion : ByteArray(static_cast(view.size()), reinterpret_cast(view.data())) {} uint32_t len; @@ -743,7 +743,7 @@ PARQUET_EXPORT std::string ConvertedTypeToString(ConvertedType::type t); PARQUET_EXPORT std::string TypeToString(Type::type t); PARQUET_EXPORT std::string FormatStatValue(Type::type parquet_type, - ::arrow::util::string_view val); + ::std::string_view val); PARQUET_EXPORT int GetTypeByteSize(Type::type t); diff --git a/docs/source/cpp/gdb.rst b/docs/source/cpp/gdb.rst index 609f11a993a..ed1810a6720 100644 --- a/docs/source/cpp/gdb.rst +++ b/docs/source/cpp/gdb.rst @@ -165,4 +165,3 @@ Important utility classes are also covered: * :class:`arrow::Status` and :class:`arrow::Result` * :class:`arrow::Buffer` and subclasses * :class:`arrow::Decimal128`, :class:`arrow::Decimal256` -* :class:`arrow::util::string_view`, :class:`arrow::util::Variant` diff --git a/python/pyarrow/src/arrow_to_pandas.cc b/python/pyarrow/src/arrow_to_pandas.cc index 437f0f11925..ba67eb10553 100644 --- a/python/pyarrow/src/arrow_to_pandas.cc +++ b/python/pyarrow/src/arrow_to_pandas.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +44,6 @@ #include "arrow/util/logging.h" #include "arrow/util/macros.h" #include "arrow/util/parallel.h" -#include "arrow/util/string_view.h" #include "arrow/visit_type_inline.h" #include "arrow/compute/api.h" @@ -586,7 +586,7 @@ template struct MemoizationTraits> { // For binary, we memoize string_view as a scalar value to avoid having to // unnecessarily copy the memory into the memo table data structure - using Scalar = util::string_view; + using Scalar = std::string_view; }; // Generic Array -> PyObject** converter that handles object deduplication, if @@ -1018,7 +1018,7 @@ struct ObjectWriterVisitor { enable_if_t::value || is_fixed_size_binary_type::value, Status> Visit(const Type& type) { - auto WrapValue = [](const util::string_view& view, PyObject** out) { + auto WrapValue = [](const std::string_view& view, PyObject** out) { *out = WrapBytes::Wrap(view.data(), view.length()); if (*out == nullptr) { PyErr_Clear(); diff --git a/python/pyarrow/src/common.h b/python/pyarrow/src/common.h index 768ff8dce44..59f15c8a135 100644 --- a/python/pyarrow/src/common.h +++ b/python/pyarrow/src/common.h @@ -18,6 +18,7 @@ #pragma once #include +#include #include #include "arrow/buffer.h" diff --git a/python/pyarrow/src/datetime.cc 
b/python/pyarrow/src/datetime.cc index 9604b529753..c4591ab50e0 100644 --- a/python/pyarrow/src/datetime.cc +++ b/python/pyarrow/src/datetime.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include "arrow/array.h" #include "arrow/scalar.h" @@ -40,14 +41,14 @@ namespace { // Same as Regex '([+-])(0[0-9]|1[0-9]|2[0-3]):([0-5][0-9])$'. // GCC 4.9 doesn't support regex, so handcode until support for it // is dropped. -bool MatchFixedOffset(const std::string& tz, util::string_view* sign, - util::string_view* hour, util::string_view* minute) { +bool MatchFixedOffset(const std::string& tz, std::string_view* sign, + std::string_view* hour, std::string_view* minute) { if (tz.size() < 5) { return false; } const char* iter = tz.data(); if (*iter == '+' || *iter == '-') { - *sign = util::string_view(iter, 1); + *sign = std::string_view(iter, 1); iter++; if (tz.size() < 6) { return false; @@ -55,7 +56,7 @@ bool MatchFixedOffset(const std::string& tz, util::string_view* sign, } if ((((*iter == '0' || *iter == '1') && *(iter + 1) >= '0' && *(iter + 1) <= '9') || (*iter == '2' && *(iter + 1) >= '0' && *(iter + 1) <= '3'))) { - *hour = util::string_view(iter, 2); + *hour = std::string_view(iter, 2); iter += 2; } else { return false; @@ -66,7 +67,7 @@ bool MatchFixedOffset(const std::string& tz, util::string_view* sign, iter++; if (*iter >= '0' && *iter <= '5' && *(iter + 1) >= '0' && *(iter + 1) <= '9') { - *minute = util::string_view(iter, 2); + *minute = std::string_view(iter, 2); iter += 2; } else { return false; @@ -389,7 +390,7 @@ Result PyTZInfo_utcoffset_hhmm(PyObject* pytzinfo) { // Converted from python. See https://github.com/apache/arrow/pull/7604 // for details. Result StringToTzinfo(const std::string& tz) { - util::string_view sign_str, hour_str, minute_str; + std::string_view sign_str, hour_str, minute_str; OwnedRef pytz; OwnedRef zoneinfo; OwnedRef datetime; diff --git a/python/pyarrow/src/gdb.cc b/python/pyarrow/src/gdb.cc index c681dfe9caa..16530a032d7 100644 --- a/python/pyarrow/src/gdb.cc +++ b/python/pyarrow/src/gdb.cc @@ -34,7 +34,6 @@ #include "arrow/util/key_value_metadata.h" #include "arrow/util/logging.h" #include "arrow/util/macros.h" -#include "arrow/util/string_view.h" namespace arrow { @@ -81,7 +80,7 @@ class UuidType : public ExtensionType { }; std::shared_ptr SliceArrayFromJSON(const std::shared_ptr& ty, - util::string_view json, int64_t offset = 0, + std::string_view json, int64_t offset = 0, int64_t length = -1) { auto array = *ArrayFromJSON(ty, json); if (length != -1) { @@ -121,12 +120,9 @@ void TestSession() { auto error_detail_result = Result(error_detail_status); // String views - util::string_view string_view_empty{}; - util::string_view string_view_abc{"abc"}; + std::string_view string_view_abc{"abc"}; std::string special_chars = std::string("foo\"bar") + '\x00' + "\r\n\t\x1f"; - util::string_view string_view_special_chars(special_chars); - std::string very_long = "abc" + std::string(5000, 'K') + "xyz"; - util::string_view string_view_very_long(very_long); + std::string_view string_view_special_chars(special_chars); // Buffers Buffer buffer_null{nullptr, 0}; diff --git a/python/pyarrow/tests/test_gdb.py b/python/pyarrow/tests/test_gdb.py index 3056cb4326d..d0d241cc564 100644 --- a/python/pyarrow/tests/test_gdb.py +++ b/python/pyarrow/tests/test_gdb.py @@ -264,20 +264,6 @@ def test_status(gdb_arrow): 'detail=[custom-detail-id] "This is a detail"))') -def test_string_view(gdb_arrow): - check_stack_repr(gdb_arrow, "string_view_empty", - 
"arrow::util::string_view of size 0") - check_stack_repr(gdb_arrow, "string_view_abc", - 'arrow::util::string_view of size 3, "abc"') - check_stack_repr( - gdb_arrow, "string_view_special_chars", - r'arrow::util::string_view of size 12, "foo\"bar\000\r\n\t\037"') - check_stack_repr( - gdb_arrow, "string_view_very_long", - 'arrow::util::string_view of size 5006, ' - '"abc", \'K\' ...') - - def test_buffer_stack(gdb_arrow): check_stack_repr(gdb_arrow, "buffer_null", "arrow::Buffer of size 0, read-only") diff --git a/r/src/altrep.cpp b/r/src/altrep.cpp index 97bb72b3df7..8e19d13a0ef 100644 --- a/r/src/altrep.cpp +++ b/r/src/altrep.cpp @@ -779,7 +779,7 @@ struct AltrepVectorString : public AltrepVectorBase> { std::string stripped_string_; const bool strip_out_nuls_; bool nul_was_stripped_; - util::string_view view_; + std::string_view view_; }; // Get a single string, as a CHARSXP SEXP from data2. diff --git a/r/src/array_to_vector.cpp b/r/src/array_to_vector.cpp index dccc29537ed..d7c51e79359 100644 --- a/r/src/array_to_vector.cpp +++ b/r/src/array_to_vector.cpp @@ -374,11 +374,11 @@ struct Converter_String : public Converter { bool Parallel() const { return false; } private: - static SEXP r_string_from_view(arrow::util::string_view view) { + static SEXP r_string_from_view(std::string_view view) { return Rf_mkCharLenCE(view.data(), view.size(), CE_UTF8); } - static SEXP r_string_from_view_strip_nul(arrow::util::string_view view, + static SEXP r_string_from_view_strip_nul(std::string_view view, bool* nul_was_stripped) { const char* old_string = view.data(); @@ -391,7 +391,7 @@ struct Converter_String : public Converter { if (nul_count == 1) { // first nul spotted: allocate stripped string storage - stripped_string = view.to_string(); + stripped_string = std::string(view); stripped_len = i; } From d7258aa1ee6e38b3a638b6a87c69659b73bb2ba4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Percy=20Camilo=20Trive=C3=B1o=20Aucahuasi?= Date: Wed, 21 Sep 2022 03:18:46 -0500 Subject: [PATCH 113/133] ARROW-17686: [C++] Add custom ToPrint to AsofJoinBasicTest (#14172) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add a simple printer compatible with [gtest printers](https://github.com/google/googletest/blob/main/googletest/include/gtest/gtest-printers.h), hopefully this will solve some issues with valgrind (3.13.0) on ubuntu 18.04 running the unit test `arrow-compute-asof-join-node-test`. 
Jira ticket: https://issues.apache.org/jira/browse/ARROW-17686

Authored-by: Percy Camilo Triveño Aucahuasi
Signed-off-by: Antoine Pitrou
---
 cpp/src/arrow/compute/exec/asof_join_node_test.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc
index c8dbd27d7b6..2b6613021ff 100644
--- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc
+++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc
@@ -633,6 +633,10 @@ struct BasicTest {
 
 using AsofJoinBasicParams = std::tuple<std::function<void(BasicTest&)>, std::string>;
 
+void PrintTo(const AsofJoinBasicParams& x, ::std::ostream* os) {
+  *os << "AsofJoinBasicParams: " << std::get<1>(x);
+}
+
 struct AsofJoinBasicTest : public testing::TestWithParam<AsofJoinBasicParams> {};
 
 class AsofJoinTest : public testing::Test {};

From fc98c95bed9324971478d600ea7f82bd76d3e2c9 Mon Sep 17 00:00:00 2001
From: david dali susanibar arce
Date: Wed, 21 Sep 2022 06:54:57 -0500
Subject: [PATCH 114/133] ARROW-15745: [Java] Deprecate redundant iterable of
 ScanTask (#14168)

Deprecate the redundant iterable of ScanTask, since there are no more
ScanTasks on the C++ side.

Authored-by: david dali susanibar arce
Signed-off-by: David Li
---
 .../arrow/dataset/jni/NativeScanTask.java     |  1 +
 .../arrow/dataset/jni/NativeScanner.java      | 13 ++++
 .../arrow/dataset/scanner/ScanTask.java       |  1 +
 .../apache/arrow/dataset/scanner/Scanner.java | 10 +++
 .../org/apache/arrow/dataset/TestDataset.java |  9 +--
 .../dataset/file/TestFileSystemDataset.java   | 69 ++++++------------
 .../arrow/dataset/jni/TestNativeDataset.java  |  4 +-
 7 files changed, 51 insertions(+), 56 deletions(-)

diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java
index e4764236dad..7747dd60340 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanTask.java
@@ -25,6 +25,7 @@
  * id via {@link JniWrapper}, thus we allow only one-time execution of method {@link #execute()}. If a re-scan
  * operation is expected, call {@link NativeDataset#newScan} to create a new scanner instance.
  */
+@Deprecated
 public class NativeScanTask implements ScanTask {
   private final NativeScanner scanner;
 
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
index de18f9e5e0b..8ca8e5cf50e 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeScanner.java
@@ -68,6 +68,19 @@ ArrowReader execute() {
   }
 
   @Override
Create a " + + "new scanner instead"); + } + return new NativeReader(context.getAllocator()); + } + + @Override + @Deprecated public Iterable scan() { if (closed) { throw new NativeInstanceReleasedException(); diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java index 434f5c9a6fa..16b8aeefb61 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanTask.java @@ -26,6 +26,7 @@ * ScanTask is meant to be a unit of work to be dispatched. The implementation * must be thread and concurrent safe. */ +@Deprecated public interface ScanTask extends AutoCloseable { /** diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java index 93a1b08f366..43749b7db8e 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/Scanner.java @@ -17,6 +17,7 @@ package org.apache.arrow.dataset.scanner; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.Schema; /** @@ -24,12 +25,21 @@ */ public interface Scanner extends AutoCloseable { + /** + * Read the dataset as a stream of record batches. + * + * @return a {@link ArrowReader}. + */ + ArrowReader scanBatches(); + /** * Perform the scan operation. * * @return a iterable set of {@link ScanTask}s. Each task is considered independent and it is allowed * to execute the tasks concurrently to gain better performance. + * @deprecated use {@link #scanBatches()} instead. */ + @Deprecated Iterable scan(); /** diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java index 15224534d28..2516c409593 100644 --- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestDataset.java @@ -28,7 +28,6 @@ import java.util.stream.StreamSupport; import org.apache.arrow.dataset.scanner.ScanOptions; -import org.apache.arrow.dataset.scanner.ScanTask; import org.apache.arrow.dataset.scanner.Scanner; import org.apache.arrow.dataset.source.Dataset; import org.apache.arrow.dataset.source.DatasetFactory; @@ -63,9 +62,7 @@ protected List collectResultFromFactory(DatasetFactory factory final Dataset dataset = factory.finish(); final Scanner scanner = dataset.newScan(options); try { - final List ret = stream(scanner.scan()) - .flatMap(t -> stream(collectTaskData(t))) - .collect(Collectors.toList()); + final List ret = collectTaskData(scanner); AutoCloseables.close(scanner, dataset); return ret; } catch (RuntimeException e) { @@ -75,8 +72,8 @@ protected List collectResultFromFactory(DatasetFactory factory } } - protected List collectTaskData(ScanTask scanTask) { - try (ArrowReader reader = scanTask.execute()) { + protected List collectTaskData(Scanner scan) { + try (ArrowReader reader = scan.scanBatches()) { List batches = new ArrayList<>(); while (reader.loadNextBatch()) { VectorSchemaRoot root = reader.getVectorSchemaRoot(); diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java index b8d51a3edb1..9dc5f2b655a 100644 --- 
a/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java +++ b/java/dataset/src/test/java/org/apache/arrow/dataset/file/TestFileSystemDataset.java @@ -42,7 +42,6 @@ import org.apache.arrow.dataset.jni.NativeDataset; import org.apache.arrow.dataset.jni.NativeInstanceReleasedException; import org.apache.arrow.dataset.jni.NativeMemoryPool; -import org.apache.arrow.dataset.jni.NativeScanTask; import org.apache.arrow.dataset.jni.NativeScanner; import org.apache.arrow.dataset.jni.TestNativeDataset; import org.apache.arrow.dataset.scanner.ScanOptions; @@ -88,7 +87,7 @@ public void testBaseParquetRead() throws Exception { Schema schema = inferResultSchemaFromFactory(factory, options); List datum = collectResultFromFactory(factory, options); - assertSingleTaskProduced(factory, options); + assertScanBatchesProduced(factory, options); assertEquals(1, datum.size()); assertEquals(2, schema.getFields().size()); assertEquals("id", schema.getFields().get(0).getName()); @@ -112,7 +111,7 @@ public void testParquetProjectSingleColumn() throws Exception { List datum = collectResultFromFactory(factory, options); org.apache.avro.Schema expectedSchema = truncateAvroSchema(writeSupport.getAvroSchema(), 0, 1); - assertSingleTaskProduced(factory, options); + assertScanBatchesProduced(factory, options); assertEquals(1, schema.getFields().size()); assertEquals("id", schema.getFields().get(0).getName()); assertEquals(Types.MinorType.INT.getType(), schema.getFields().get(0).getType()); @@ -139,7 +138,7 @@ public void testParquetBatchSize() throws Exception { Schema schema = inferResultSchemaFromFactory(factory, options); List datum = collectResultFromFactory(factory, options); - assertSingleTaskProduced(factory, options); + assertScanBatchesProduced(factory, options); assertEquals(3, datum.size()); datum.forEach(batch -> assertEquals(1, batch.getLength())); checkParquetReadResult(schema, writeSupport.getWrittenRecords(), datum); @@ -163,7 +162,7 @@ public void testParquetDirectoryRead() throws Exception { Schema schema = inferResultSchemaFromFactory(factory, options); List datum = collectResultFromFactory(factory, options); - assertSingleTaskProduced(factory, options); + assertScanBatchesProduced(factory, options); assertEquals(7, datum.size()); datum.forEach(batch -> assertEquals(1, batch.getLength())); checkParquetReadResult(schema, expectedJsonUnordered, datum); @@ -182,7 +181,7 @@ public void testEmptyProjectSelectsZeroColumns() throws Exception { List datum = collectResultFromFactory(factory, options); org.apache.avro.Schema expectedSchema = org.apache.avro.Schema.createRecord(Collections.emptyList()); - assertSingleTaskProduced(factory, options); + assertScanBatchesProduced(factory, options); assertEquals(0, schema.getFields().size()); assertEquals(1, datum.size()); checkParquetReadResult(schema, @@ -204,7 +203,7 @@ public void testNullProjectSelectsAllColumns() throws Exception { Schema schema = inferResultSchemaFromFactory(factory, options); List datum = collectResultFromFactory(factory, options); - assertSingleTaskProduced(factory, options); + assertScanBatchesProduced(factory, options); assertEquals(1, datum.size()); assertEquals(2, schema.getFields().size()); assertEquals("id", schema.getFields().get(0).getName()); @@ -233,7 +232,7 @@ public void testNoErrorWhenCloseAgain() throws Exception { } @Test - public void testErrorThrownWhenScanAgain() throws Exception { + public void testErrorThrownWhenScanBatchesAgain() throws Exception { ParquetWriteSupport writeSupport = 
ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), @@ -241,25 +240,18 @@ public void testErrorThrownWhenScanAgain() throws Exception { NativeDataset dataset = factory.finish(); ScanOptions options = new ScanOptions(100); NativeScanner scanner = dataset.newScan(options); - List taskList1 = collect(scanner.scan()); - List taskList2 = collect(scanner.scan()); - NativeScanTask task1 = taskList1.get(0); - NativeScanTask task2 = taskList2.get(0); - List datum = collectTaskData(task1); - + List datum = collectTaskData(scanner); AutoCloseables.close(datum); - - UnsupportedOperationException uoe = assertThrows(UnsupportedOperationException.class, task2::execute); - Assertions.assertEquals("NativeScanner cannot be executed more than once. Consider creating new scanner instead", + UnsupportedOperationException uoe = assertThrows(UnsupportedOperationException.class, + scanner::scanBatches); + Assertions.assertEquals("NativeScanner can only be executed once. Create a new scanner instead", uoe.getMessage()); - AutoCloseables.close(taskList1); - AutoCloseables.close(taskList2); AutoCloseables.close(scanner, dataset, factory); } @Test - public void testScanInOtherThread() throws Exception { + public void testScanBatchesInOtherThread() throws Exception { ExecutorService executor = Executors.newSingleThreadExecutor(); ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); @@ -268,17 +260,14 @@ public void testScanInOtherThread() throws Exception { NativeDataset dataset = factory.finish(); ScanOptions options = new ScanOptions(100); NativeScanner scanner = dataset.newScan(options); - List taskList = collect(scanner.scan()); - NativeScanTask task = taskList.get(0); - List datum = executor.submit(() -> collectTaskData(task)).get(); + List datum = executor.submit(() -> collectTaskData(scanner)).get(); AutoCloseables.close(datum); - AutoCloseables.close(taskList); AutoCloseables.close(scanner, dataset, factory); } @Test - public void testErrorThrownWhenScanAfterScannerClose() throws Exception { + public void testErrorThrownWhenScanBatchesAfterScannerClose() throws Exception { ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), @@ -287,28 +276,13 @@ public void testErrorThrownWhenScanAfterScannerClose() throws Exception { ScanOptions options = new ScanOptions(100); NativeScanner scanner = dataset.newScan(options); scanner.close(); - assertThrows(NativeInstanceReleasedException.class, scanner::scan); - AutoCloseables.close(factory); - } - - @Test - public void testErrorThrownWhenExecuteTaskAfterTaskClose() throws Exception { - ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a"); + assertThrows(NativeInstanceReleasedException.class, scanner::scanBatches); - FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), - FileFormat.PARQUET, writeSupport.getOutputURI()); - NativeDataset dataset = factory.finish(); - ScanOptions options = new ScanOptions(100); - NativeScanner scanner = dataset.newScan(options); - List tasks = collect(scanner.scan()); - NativeScanTask task = tasks.get(0); - task.close(); - 
assertThrows(NativeInstanceReleasedException.class, task::execute);
     AutoCloseables.close(factory);
   }
 
   @Test
-  public void testErrorThrownWhenIterateOnIteratorAfterTaskClose() throws Exception {
+  public void testErrorThrownWhenReadAfterNativeReaderClose() throws Exception {
     ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a");
 
     FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(),
@@ -316,11 +290,10 @@ public void testErrorThrownWhenIterateOnIteratorAfterTaskClose() throws Exceptio
     NativeDataset dataset = factory.finish();
     ScanOptions options = new ScanOptions(100);
     NativeScanner scanner = dataset.newScan(options);
-    List<NativeScanTask> tasks = collect(scanner.scan());
-    NativeScanTask task = tasks.get(0);
-    ArrowReader reader = task.execute();
-    task.close();
+    ArrowReader reader = scanner.scanBatches();
+    scanner.close();
     assertThrows(NativeInstanceReleasedException.class, reader::loadNextBatch);
+
     AutoCloseables.close(factory);
   }
 
@@ -348,7 +321,7 @@ public void testBaseArrowIpcRead() throws Exception {
     Schema schema = inferResultSchemaFromFactory(factory, options);
     List<ArrowRecordBatch> datum = collectResultFromFactory(factory, options);
 
-    assertSingleTaskProduced(factory, options);
+    assertScanBatchesProduced(factory, options);
     assertEquals(1, datum.size());
     assertEquals(1, schema.getFields().size());
     assertEquals("ints", schema.getFields().get(0).getName());
@@ -376,7 +349,7 @@ public void testBaseOrcRead() throws Exception {
     Schema schema = inferResultSchemaFromFactory(factory, options);
     List<ArrowRecordBatch> datum = collectResultFromFactory(factory, options);
 
-    assertSingleTaskProduced(factory, options);
+    assertScanBatchesProduced(factory, options);
     assertEquals(1, datum.size());
     assertEquals(1, schema.getFields().size());
     assertEquals("ints", schema.getFields().get(0).getName());
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java
index 2a86a256883..d0f91769096 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/jni/TestNativeDataset.java
@@ -25,9 +25,9 @@
 import org.junit.Assert;
 
 public abstract class TestNativeDataset extends TestDataset {
-  protected void assertSingleTaskProduced(DatasetFactory factory, ScanOptions options) {
+  protected void assertScanBatchesProduced(DatasetFactory factory, ScanOptions options) {
     final Dataset dataset = factory.finish();
     final Scanner scanner = dataset.newScan(options);
-    Assert.assertEquals(1L, stream(scanner.scan()).count());
+    Assert.assertNotNull(scanner.scanBatches());
   }
 }

From bd433c014990d708c084e1f3d166c93116fdbd2f Mon Sep 17 00:00:00 2001
From: Igor Suhorukov
Date: Wed, 21 Sep 2022 15:06:46 +0300
Subject: [PATCH 115/133] ARROW-17629: [Java] Bind DB column to Arrow Map type
 in JdbcToArrowUtils (#14134)

This pull request adds support for mapping the map type (hstore translated by
the JDBC driver to java.util.Map, or JSON text/varchar). The mapping to
MapConsumer must be expressed manually via
new JdbcToArrowConfigBuilder(...).setJdbcToArrowTypeConverter(USER_CUSTOM_jdbcToArrowTypeConverter).
This is now possible as part of
[ARROW-17630](https://issues.apache.org/jira/browse/ARROW-17630), which allows
the user to distinguish columns by number and related external metadata.
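A minimal usage sketch of the manual converter registration described above (the `allocator` and `calendar` variables and the assumption that result-set column 3 carries the hstore/JSON map are illustrative, not part of this patch):

```java
// Hypothetical wiring; only APIs introduced or touched by this patch are used.
JdbcToArrowConfig config = new JdbcToArrowConfigBuilder(allocator, calendar)
    .setJdbcToArrowTypeConverter(fieldInfo -> {
      // JdbcFieldInfo.getColumn() is 1-based; 0 means the info was not built
      // from result-set metadata, so it can never match a real column here.
      if (fieldInfo.getColumn() == 3) {  // assumed position of the map column
        return new ArrowType.Map(/*keysSorted*/ false);
      }
      // Fall back to the default JDBC-to-Arrow mapping for all other columns.
      return JdbcToArrowUtils.getArrowTypeFromJdbcType(fieldInfo, calendar);
    })
    .build();
```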
Authored-by: igor.suhorukov Signed-off-by: David Li --- java/adapter/jdbc/pom.xml | 2 - .../arrow/adapter/jdbc/JdbcFieldInfo.java | 12 ++ .../arrow/adapter/jdbc/JdbcToArrowConfig.java | 2 +- .../arrow/adapter/jdbc/JdbcToArrowUtils.java | 13 +++ .../adapter/jdbc/consumer/MapConsumer.java | 104 ++++++++++++++++++ .../adapter/jdbc/AbstractJdbcToArrowTest.java | 38 ++++++- .../adapter/jdbc/JdbcToArrowTestHelper.java | 50 +++++++++ .../jdbc/h2/JdbcToArrowCharSetTest.java | 20 ++-- .../jdbc/h2/JdbcToArrowDataTypesTest.java | 20 ++-- .../jdbc/h2/JdbcToArrowMapDataTypeTest.java | 75 +++++++++++++ .../adapter/jdbc/h2/JdbcToArrowNullTest.java | 45 +++++--- .../h2/JdbcToArrowOptionalColumnsTest.java | 6 +- .../adapter/jdbc/h2/JdbcToArrowTest.java | 65 +++++------ .../jdbc/h2/JdbcToArrowTimeZoneTest.java | 14 ++- .../h2/JdbcToArrowVectorIteratorTest.java | 2 +- .../resources/h2/test1_all_datatypes_h2.yml | 27 ++--- .../h2/test1_all_datatypes_null_h2.yml | 13 ++- ...t1_all_datatypes_selected_null_rows_h2.yml | 16 +-- .../src/test/resources/h2/test1_map_h2.yml | 33 ++++++ .../h2/test1_selected_datatypes_null_h2.yml | 5 +- 20 files changed, 452 insertions(+), 110 deletions(-) create mode 100644 java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/consumer/MapConsumer.java create mode 100644 java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowMapDataTypeTest.java create mode 100644 java/adapter/jdbc/src/test/resources/h2/test1_map_h2.yml diff --git a/java/adapter/jdbc/pom.xml b/java/adapter/jdbc/pom.xml index aaadda0375f..4355b3c5014 100644 --- a/java/adapter/jdbc/pom.xml +++ b/java/adapter/jdbc/pom.xml @@ -67,13 +67,11 @@ com.fasterxml.jackson.core jackson-databind - test com.fasterxml.jackson.core jackson-core - test diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcFieldInfo.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcFieldInfo.java index 3443a1e44c1..3237c9bf97b 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcFieldInfo.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcFieldInfo.java @@ -35,6 +35,7 @@ * */ public class JdbcFieldInfo { + private final int column; private final int jdbcType; private final int nullability; private final int precision; @@ -53,6 +54,7 @@ public JdbcFieldInfo(int jdbcType) { (jdbcType != Types.DECIMAL && jdbcType != Types.NUMERIC), "DECIMAL and NUMERIC types require a precision and scale; please use another constructor."); + this.column = 0; this.jdbcType = jdbcType; this.nullability = ResultSetMetaData.columnNullableUnknown; this.precision = 0; @@ -68,6 +70,7 @@ public JdbcFieldInfo(int jdbcType) { * @param scale The field's numeric scale. */ public JdbcFieldInfo(int jdbcType, int precision, int scale) { + this.column = 0; this.jdbcType = jdbcType; this.nullability = ResultSetMetaData.columnNullableUnknown; this.precision = precision; @@ -84,6 +87,7 @@ public JdbcFieldInfo(int jdbcType, int precision, int scale) { * @param scale The field's numeric scale. 
*/ public JdbcFieldInfo(int jdbcType, int nullability, int precision, int scale) { + this.column = 0; this.jdbcType = jdbcType; this.nullability = nullability; this.precision = precision; @@ -106,6 +110,7 @@ public JdbcFieldInfo(ResultSetMetaData rsmd, int column) throws SQLException { column <= rsmd.getColumnCount(), "The index must be within the number of columns (1 to %s, inclusive)", rsmd.getColumnCount()); + this.column = column; this.jdbcType = rsmd.getColumnType(column); this.nullability = rsmd.isNullable(column); this.precision = rsmd.getPrecision(column); @@ -139,4 +144,11 @@ public int getPrecision() { public int getScale() { return scale; } + + /** + * The column index for query column. + */ + public int getColumn() { + return column; + } } diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java index d11682c7d4a..012cd95c0b2 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowConfig.java @@ -211,7 +211,7 @@ public final class JdbcToArrowConfig { // set up type converter this.jdbcToArrowTypeConverter = jdbcToArrowTypeConverter != null ? jdbcToArrowTypeConverter : - jdbcFieldInfo -> JdbcToArrowUtils.getArrowTypeFromJdbcType(jdbcFieldInfo, calendar); + (jdbcFieldInfo) -> JdbcToArrowUtils.getArrowTypeFromJdbcType(jdbcFieldInfo, calendar); } /** diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java index e782efde3d4..93c6a80c107 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java @@ -31,6 +31,7 @@ import java.sql.Timestamp; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; import java.util.Calendar; import java.util.HashMap; import java.util.List; @@ -49,6 +50,7 @@ import org.apache.arrow.adapter.jdbc.consumer.FloatConsumer; import org.apache.arrow.adapter.jdbc.consumer.IntConsumer; import org.apache.arrow.adapter.jdbc.consumer.JdbcConsumer; +import org.apache.arrow.adapter.jdbc.consumer.MapConsumer; import org.apache.arrow.adapter.jdbc.consumer.NullConsumer; import org.apache.arrow.adapter.jdbc.consumer.SmallIntConsumer; import org.apache.arrow.adapter.jdbc.consumer.TimeConsumer; @@ -76,6 +78,7 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -279,6 +282,14 @@ public static Schema jdbcToArrowSchema(ResultSetMetaData rsmd, JdbcToArrowConfig children = new ArrayList(); final ArrowType childType = config.getJdbcToArrowTypeConverter().apply(arrayFieldInfo); children.add(new Field("child", FieldType.nullable(childType), null)); + } else if (arrowType.getTypeID() == ArrowType.ArrowTypeID.Map) { + FieldType mapType = new FieldType(false, ArrowType.Struct.INSTANCE, null, null); + FieldType keyType = new FieldType(false, new ArrowType.Utf8(), null, null); + FieldType valueType = new FieldType(false, new ArrowType.Utf8(), null, null); + children = new 
ArrayList<>();
+      children.add(new Field("child", mapType,
+          Arrays.asList(new Field(MapVector.KEY_NAME, keyType, null),
+              new Field(MapVector.VALUE_NAME, valueType, null))));
     }
 
     fields.add(new Field(columnName, fieldType, children));
@@ -471,6 +482,8 @@ static JdbcConsumer getConsumer(ArrowType arrowType, int columnIndex, boolean nu
       JdbcConsumer delegate = getConsumer(childVector.getField().getType(), JDBC_ARRAY_VALUE_COLUMN,
           childVector.getField().isNullable(), childVector, config);
       return ArrayConsumer.createConsumer((ListVector) vector, delegate, columnIndex, nullable);
+    case Map:
+      return MapConsumer.createConsumer((MapVector) vector, columnIndex, nullable);
     case Null:
       return new NullConsumer((NullVector) vector);
     default:
diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/consumer/MapConsumer.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/consumer/MapConsumer.java
new file mode 100644
index 00000000000..07a071bfc09
--- /dev/null
+++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/consumer/MapConsumer.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adapter.jdbc.consumer;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.util.Map;
+
+import org.apache.arrow.memory.ArrowBuf;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
+import org.apache.arrow.vector.util.ObjectMapperFactory;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+/**
+ * Consumer which consume map type values from {@link ResultSet}.
+ * Write the data into {@link org.apache.arrow.vector.complex.MapVector}.
+ */
+public class MapConsumer extends BaseConsumer<MapVector> {
+
+
+  private final UnionMapWriter writer;
+  private final ObjectMapper objectMapper = ObjectMapperFactory.newObjectMapper();
+  private final TypeReference<Map<String, String>> typeReference = new TypeReference<Map<String, String>>() {};
+  private int currentRow;
+
+  /**
+   * Creates a consumer for {@link MapVector}.
+   */
+  public static MapConsumer createConsumer(MapVector mapVector, int index, boolean nullable) {
+    return new MapConsumer(mapVector, index);
+  }
+
+  /**
+   * Instantiate a MapConsumer.
+   */
+  public MapConsumer(MapVector vector, int index) {
+    super(vector, index);
+    writer = vector.getWriter();
+  }
+
+  @Override
+  public void consume(ResultSet resultSet) throws SQLException, IOException {
+    Object map = resultSet.getObject(columnIndexInResultSet);
+    writer.setPosition(currentRow++);
+    if (map != null) {
+      if (map instanceof String) {
+        writeJavaMapIntoVector(objectMapper.readValue((String) map, typeReference));
+      } else if (map instanceof Map) {
+        writeJavaMapIntoVector((Map<String, String>) map);
+      } else {
+        throw new IllegalArgumentException("Unknown type of map type column from JDBC " + map.getClass().getName());
+      }
+    } else {
+      writer.writeNull();
+    }
+  }
+
+  private void writeJavaMapIntoVector(Map<String, String> map) {
+    BufferAllocator allocator = vector.getAllocator();
+    writer.startMap();
+    map.forEach((key, value) -> {
+      byte[] keyBytes = key.getBytes(StandardCharsets.UTF_8);
+      byte[] valueBytes = value != null ? value.getBytes(StandardCharsets.UTF_8) : null;
+      try (
+          ArrowBuf keyBuf = allocator.buffer(keyBytes.length);
+          ArrowBuf valueBuf = valueBytes != null ? allocator.buffer(valueBytes.length) : null;
+      ) {
+        writer.startEntry();
+        keyBuf.writeBytes(keyBytes);
+        writer.key().varChar().writeVarChar(0, keyBytes.length, keyBuf);
+        if (valueBytes != null) {
+          valueBuf.writeBytes(valueBytes);
+          writer.value().varChar().writeVarChar(0, valueBytes.length, valueBuf);
+        } else {
+          writer.value().varChar().writeNull();
+        }
+        writer.endEntry();
+      }
+    });
+    writer.endMap();
+  }
+}
+
diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/AbstractJdbcToArrowTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/AbstractJdbcToArrowTest.java
index 645e343ffd0..dc36ef9f827 100644
--- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/AbstractJdbcToArrowTest.java
+++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/AbstractJdbcToArrowTest.java
@@ -21,6 +21,7 @@
 import java.sql.Connection;
 import java.sql.DriverManager;
 import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
 import java.sql.Statement;
 import java.sql.Types;
@@ -28,11 +29,13 @@
 import java.util.HashMap;
 import java.util.Map;
 import java.util.TimeZone;
+import java.util.function.Function;
 
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.memory.RootAllocator;
 import org.apache.arrow.util.Preconditions;
 import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.pojo.ArrowType;
 import org.apache.arrow.vector.util.ValueVectorUtility;
 import org.junit.After;
 import org.junit.Before;
@@ -58,6 +61,7 @@ public abstract class AbstractJdbcToArrowTest {
   protected static final String DOUBLE = "DOUBLE_FIELD7";
   protected static final String INT = "INT_FIELD1";
   protected static final String LIST = "LIST_FIELD19";
+  protected static final String MAP = "MAP_FIELD20";
   protected static final String REAL = "REAL_FIELD8";
   protected static final String SMALLINT = "SMALLINT_FIELD4";
   protected static final String TIME = "TIME_FIELD9";
@@ -155,8 +159,10 @@ public static Object[][] prepareTestData(String[] testFiles, @SuppressWarnings("
    * Abstract method to implement logic to assert test various datatype values.
    *
    * @param root VectorSchemaRoot for test
+   * @param isIncludeMapVector is this dataset checks includes map column.
+ * Jdbc type to 'map' mapping declared in configuration only manually */ - public abstract void testDataSets(VectorSchemaRoot root); + public abstract void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector); /** * For the given SQL query, execute and fetch the data from Relational DB and convert it to Arrow objects. @@ -342,4 +348,34 @@ public static VectorSchemaRoot sqlToArrow(ResultSet resultSet, JdbcToArrowConfig return root; } + /** + * Register MAP_FIELD20 as ArrowType.Map + * @param calendar Calendar instance to use for Date, Time and Timestamp datasets, or null if none. + * @param rsmd ResultSetMetaData to lookup column name from result set metadata + * @return typeConverter instance with mapping column to Map type + */ + protected Function jdbcToArrowTypeConverter( + Calendar calendar, ResultSetMetaData rsmd) { + return (jdbcFieldInfo) -> { + String columnLabel = null; + try { + int columnIndex = jdbcFieldInfo.getColumn(); + if (columnIndex != 0) { + columnLabel = rsmd.getColumnLabel(columnIndex); + } + } catch (SQLException e) { + throw new RuntimeException(e); + } + if (MAP.equals(columnLabel)) { + return new ArrowType.Map(false); + } else { + return JdbcToArrowUtils.getArrowTypeFromJdbcType(jdbcFieldInfo, calendar); + } + }; + } + + protected ResultSetMetaData getQueryMetaData(String query) throws SQLException { + return conn.createStatement().executeQuery(query).getMetaData(); + } + } diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java index e7b7fe0455b..5d1fb2276cc 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java @@ -26,7 +26,9 @@ import java.nio.charset.Charset; import java.sql.ResultSetMetaData; import java.sql.SQLException; +import java.util.AbstractMap; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -47,8 +49,17 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.JsonStringArrayList; +import org.apache.arrow.vector.util.JsonStringHashMap; +import org.apache.arrow.vector.util.ObjectMapperFactory; +import org.apache.arrow.vector.util.Text; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; /** * This is a Helper class which has functionalities to read and assert the values from the given FieldVector object. 
@@ -240,6 +251,45 @@ public static void assertListVectorValues(ListVector listVector, int rowCount, I } } + public static void assertMapVectorValues(MapVector mapVector, int rowCount, Map[] values) { + assertEquals(rowCount, mapVector.getValueCount()); + + for (int j = 0; j < mapVector.getValueCount(); j++) { + if (values[j] == null) { + assertTrue(mapVector.isNull(j)); + } else { + JsonStringArrayList> actualSource = + (JsonStringArrayList>) mapVector.getObject(j); + Map actualMap = null; + if (actualSource != null && !actualSource.isEmpty()) { + actualMap = actualSource.stream().map(entry -> + new AbstractMap.SimpleEntry<>(entry.get("key").toString(), + entry.get("value") != null ? entry.get("value").toString() : null)) + .collect(HashMap::new, (collector, val) -> collector.put(val.getKey(), val.getValue()), HashMap::putAll); + } + assertEquals(values[j], actualMap); + } + } + } + + public static Map[] getMapValues(String[] values, String dataType) { + String[] dataArr = getValues(values, dataType); + Map[] maps = new Map[dataArr.length]; + ObjectMapper objectMapper = ObjectMapperFactory.newObjectMapper(); + TypeReference> typeReference = new TypeReference>() {}; + for (int idx = 0; idx < dataArr.length; idx++) { + String jsonString = dataArr[idx].replace("|", ","); + if (!jsonString.isEmpty()) { + try { + maps[idx] = objectMapper.readValue(jsonString, typeReference); + } catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + } + } + return maps; + } + public static void assertNullValues(BaseValueVector vector, int rowCount) { assertEquals(rowCount, vector.getValueCount()); diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowCharSetTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowCharSetTest.java index b548c9169af..422b55070aa 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowCharSetTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowCharSetTest.java @@ -109,22 +109,22 @@ public static Collection getTestData() throws SQLException, ClassNotFo @Test public void testJdbcToArrowValues() throws SQLException, IOException { testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance())); - testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE))); + Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE)), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())); - testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()))); + new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery())), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE))); + new RootAllocator(Integer.MAX_VALUE)), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - Calendar.getInstance())); + Calendar.getInstance()), false); testDataSets(sqlToArrow( conn.createStatement().executeQuery(table.getQuery()), - new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()).build())); + new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), 
Calendar.getInstance()).build()), false); testDataSets(sqlToArrow( conn, table.getQuery(), - new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()).build())); + new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()).build()), false); } @Test @@ -139,8 +139,10 @@ public void testJdbcSchemaMetadata() throws SQLException { * This method calls the assert methods for various DataSets. * * @param root VectorSchemaRoot for test + * @param isIncludeMapVector is this dataset checks includes map column. + * Jdbc type to 'map' mapping declared in configuration only manually */ - public void testDataSets(VectorSchemaRoot root) { + public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) { JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root); assertVarcharVectorValues((VarCharVector) root.getVector(CLOB), table.getRowCount(), diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowDataTypesTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowDataTypesTest.java index 9810cd2e796..ae4fffd0f94 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowDataTypesTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowDataTypesTest.java @@ -147,25 +147,25 @@ public static Collection getTestData() throws SQLException, ClassNotFo @Test public void testJdbcToArrowValues() throws SQLException, IOException { testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance())); - testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE))); + Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE)), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())); - testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()))); + new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery())), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE))); - testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), Calendar.getInstance())); + new RootAllocator(Integer.MAX_VALUE)), false); + testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), Calendar.getInstance()), false); testDataSets(sqlToArrow( conn.createStatement().executeQuery(table.getQuery()), new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) - .build())); + .build()), false); testDataSets(sqlToArrow( conn, table.getQuery(), new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) - .build())); + .build()), false); } @Test @@ -182,8 +182,10 @@ public void testJdbcSchemaMetadata() throws SQLException { * This method calls the assert methods for various DataSets. * * @param root VectorSchemaRoot for test + * @param isIncludeMapVector is this dataset checks includes map column. 
+ * The JDBC type to 'map' mapping must be declared manually in the configuration. */
- public void testDataSets(VectorSchemaRoot root) {
+ public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) {
 JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root);
 switch (table.getType()) {
diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowMapDataTypeTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowMapDataTypeTest.java
new file mode 100644
index 00000000000..43862a93c39
--- /dev/null
+++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowMapDataTypeTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.adapter.jdbc.h2;
+
+import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertMapVectorValues;
+import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getMapValues;
+
+import java.io.IOException;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.util.Calendar;
+
+import org.apache.arrow.adapter.jdbc.AbstractJdbcToArrowTest;
+import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.MapVector;
+import org.junit.Test;
+
+/**
+ * Test MapConsumer with the OTHER JDBC type.
+ */
+public class JdbcToArrowMapDataTypeTest extends AbstractJdbcToArrowTest {
+
+ public JdbcToArrowMapDataTypeTest() throws IOException {
+ this.table = getTable("h2/test1_map_h2.yml", JdbcToArrowMapDataTypeTest.class);
+ }
+
+ /**
+ * Test method for JdbcToArrow functionality with a Map read from a Types.OTHER column.
+ */
+ @Test
+ public void testJdbcToArrowValues() throws SQLException, IOException {
+ Calendar calendar = Calendar.getInstance();
+ ResultSetMetaData rsmd = getQueryMetaData(table.getQuery());
+ testDataSets(sqlToArrow(
+ conn.createStatement().executeQuery(table.getQuery()),
+ new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), calendar)
+ .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd))
+ .build()), true);
+ testDataSets(sqlToArrow(
+ conn,
+ table.getQuery(),
+ new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), calendar)
+ .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd))
+ .build()), true);
+ }
+
+ /**
+ * This method calls the assert methods for various DataSets.
+ *
+ * @param root VectorSchemaRoot for test
+ * @param isIncludeMapVector whether the dataset includes a map column;
+ * Jdbc type to 'map' mapping declared in configuration only manually + */ + public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) { + assertMapVectorValues((MapVector) root.getVector(MAP), table.getRowCount(), + getMapValues(table.getValues(), MAP)); + } +} diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowNullTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowNullTest.java index e021b276fbe..5731f27c5b3 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowNullTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowNullTest.java @@ -26,6 +26,7 @@ import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertFloat8VectorValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertIntVectorValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertListVectorValues; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertMapVectorValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertNullValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertSmallIntVectorValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertTimeStampVectorValues; @@ -42,6 +43,7 @@ import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getIntValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getListValues; import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getLongValues; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getMapValues; import java.io.IOException; import java.sql.ResultSetMetaData; @@ -72,6 +74,7 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Test; import org.junit.runner.RunWith; @@ -123,25 +126,29 @@ public static Collection getTestData() throws SQLException, ClassNotFo @Test public void testJdbcToArrowValues() throws SQLException, IOException { testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance())); - testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE))); + Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE)), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())); - testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()))); + new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery())), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE))); - testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), Calendar.getInstance())); + new RootAllocator(Integer.MAX_VALUE)), false); + testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), Calendar.getInstance()), false); + Calendar calendar = Calendar.getInstance(); + ResultSetMetaData rsmd = getQueryMetaData(table.getQuery()); testDataSets(sqlToArrow( 
conn.createStatement().executeQuery(table.getQuery()), new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) - .build())); + .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd)) + .build()), true); testDataSets(sqlToArrow( conn, table.getQuery(), new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) - .build())); + .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd)) + .build()), true); } @Test @@ -158,8 +165,10 @@ public void testJdbcSchemaMetadata() throws SQLException { * This method calls the assert methods for various DataSets. * * @param root VectorSchemaRoot for test + * @param isIncludeMapVector is this dataset checks includes map column. + * Jdbc type to 'map' mapping declared in configuration only manually */ - public void testDataSets(VectorSchemaRoot root) { + public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) { JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root); switch (table.getType()) { @@ -167,10 +176,10 @@ public void testDataSets(VectorSchemaRoot root) { sqlToArrowTestNullValues(table.getVectors(), root, table.getRowCount()); break; case SELECTED_NULL_COLUMN: - sqlToArrowTestSelectedNullColumnsValues(table.getVectors(), root, table.getRowCount()); + sqlToArrowTestSelectedNullColumnsValues(table.getVectors(), root, table.getRowCount(), isIncludeMapVector); break; case SELECTED_NULL_ROW: - testAllVectorValues(root); + testAllVectorValues(root, isIncludeMapVector); break; default: // do nothing @@ -178,7 +187,7 @@ public void testDataSets(VectorSchemaRoot root) { } } - private void testAllVectorValues(VectorSchemaRoot root) { + private void testAllVectorValues(VectorSchemaRoot root, boolean isIncludeMapVector) { JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root); assertBigIntVectorValues((BigIntVector) root.getVector(BIGINT), table.getRowCount(), @@ -234,6 +243,10 @@ private void testAllVectorValues(VectorSchemaRoot root) { assertListVectorValues((ListVector) root.getVector(LIST), table.getRowCount(), getListValues(table.getValues(), LIST)); + if (isIncludeMapVector) { + assertMapVectorValues((MapVector) root.getVector(MAP), table.getRowCount(), + getMapValues(table.getValues(), MAP)); + } } /** @@ -270,8 +283,11 @@ public void sqlToArrowTestNullValues(String[] vectors, VectorSchemaRoot root, in * @param vectors Vectors to test * @param root VectorSchemaRoot for test * @param rowCount number of rows + * @param isIncludeMapVector is this dataset checks includes map column. 
+ * The JDBC type to 'map' mapping must be declared manually in the configuration. */
- public void sqlToArrowTestSelectedNullColumnsValues(String[] vectors, VectorSchemaRoot root, int rowCount) {
+ public void sqlToArrowTestSelectedNullColumnsValues(String[] vectors, VectorSchemaRoot root, int rowCount,
+ boolean isIncludeMapVector) {
 assertNullValues((BigIntVector) root.getVector(vectors[0]), rowCount);
 assertNullValues((DecimalVector) root.getVector(vectors[1]), rowCount);
 assertNullValues((Float8Vector) root.getVector(vectors[2]), rowCount);
@@ -286,6 +302,9 @@ public void sqlToArrowTestSelectedNullColumnsValues(String[] vectors, VectorSche
 assertNullValues((VarCharVector) root.getVector(vectors[11]), rowCount);
 assertNullValues((BitVector) root.getVector(vectors[12]), rowCount);
 assertNullValues((ListVector) root.getVector(vectors[13]), rowCount);
+ if (isIncludeMapVector) {
+ assertNullValues((MapVector) root.getVector(vectors[14]), rowCount);
+ }
 }
 }
diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowOptionalColumnsTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowOptionalColumnsTest.java
index 84960dc8880..eebcbe64c0e 100644
--- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowOptionalColumnsTest.java
+++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowOptionalColumnsTest.java
@@ -71,7 +71,7 @@ public static Collection getTestData() throws SQLException, ClassNotFo
 */
 @Test
 public void testJdbcToArrowValues() throws SQLException, IOException {
- testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE)));
+ testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE)), false);
 }
 /**
@@ -79,8 +79,10 @@ public void testJdbcToArrowValues() throws SQLException, IOException {
 * nullable in the VectorSchemaRoot, and that a SQL `NOT NULL` column becomes non-nullable.
 *
 * @param root VectorSchemaRoot for test
+ * @param isIncludeMapVector whether the dataset includes a map column;
+ * Jdbc type to 'map' mapping declared in configuration only manually */ - public void testDataSets(VectorSchemaRoot root) { + public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) { JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root); assertTrue(root.getSchema().getFields().get(0).isNullable()); diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTest.java index f9cf72d5dd1..7641fa7f165 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTest.java @@ -17,31 +17,7 @@ package org.apache.arrow.adapter.jdbc.h2; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertBigIntVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertBitVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertBooleanVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertDateVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertDecimalVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertFloat4VectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertFloat8VectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertIntVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertListVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertNullVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertSmallIntVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertTimeStampVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertTimeVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertTinyIntVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertVarBinaryVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.assertVarcharVectorValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getBinaryValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getBooleanValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getCharArray; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getDecimalValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getDoubleValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getFloatValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getIntValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getListValues; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getLongValues; +import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.*; import static org.junit.Assert.assertEquals; import java.io.IOException; @@ -81,6 +57,7 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.types.pojo.Schema; import org.junit.Test; import org.junit.runner.RunWith; @@ -126,35 +103,41 @@ public static Collection getTestData() 
throws SQLException, ClassNotFo @Test public void testJdbcToArrowValues() throws SQLException, IOException { testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance())); - testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE))); + Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE)), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())); - testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()))); + new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()), false); + testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery())), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE))); + new RootAllocator(Integer.MAX_VALUE)), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - Calendar.getInstance())); + Calendar.getInstance()), false); + Calendar calendar = Calendar.getInstance(); + ResultSetMetaData rsmd = getQueryMetaData(table.getQuery()); testDataSets(sqlToArrow( conn.createStatement().executeQuery(table.getQuery()), new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) - .build())); + .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd)) + .build()), true); testDataSets(sqlToArrow( conn, table.getQuery(), - new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance()) + new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), calendar) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) - .build())); + .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd)) + .build()), true); } @Test public void testJdbcSchemaMetadata() throws SQLException { - JdbcToArrowConfig config = new JdbcToArrowConfigBuilder(new RootAllocator(0), Calendar.getInstance(), true) + Calendar calendar = Calendar.getInstance(); + ResultSetMetaData rsmd = getQueryMetaData(table.getQuery()); + JdbcToArrowConfig config = new JdbcToArrowConfigBuilder(new RootAllocator(0), calendar, true) .setReuseVectorSchemaRoot(reuseVectorSchemaRoot) + .setJdbcToArrowTypeConverter(jdbcToArrowTypeConverter(calendar, rsmd)) .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP) .build(); - ResultSetMetaData rsmd = conn.createStatement().executeQuery(table.getQuery()).getMetaData(); Schema schema = JdbcToArrowUtils.jdbcToArrowSchema(rsmd, config); JdbcToArrowTestHelper.assertFieldMetadataMatchesResultSetMetadata(rsmd, schema); } @@ -163,10 +146,11 @@ public void testJdbcSchemaMetadata() throws SQLException { * This method calls the assert methods for various DataSets. * * @param root VectorSchemaRoot for test + * @param isIncludeMapVector is this dataset checks includes map column. 
+ * Jdbc type to 'map' mapping declared in configuration only manually */ - public void testDataSets(VectorSchemaRoot root) { + public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) { JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root); - assertBigIntVectorValues((BigIntVector) root.getVector(BIGINT), table.getRowCount(), getLongValues(table.getValues(), BIGINT)); @@ -222,6 +206,11 @@ public void testDataSets(VectorSchemaRoot root) { assertListVectorValues((ListVector) root.getVector(LIST), table.getRowCount(), getListValues(table.getValues(), LIST)); + + if (isIncludeMapVector) { + assertMapVectorValues((MapVector) root.getVector(MAP), table.getRowCount(), + getMapValues(table.getValues(), MAP)); + } } @Test diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTimeZoneTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTimeZoneTest.java index f5ddbdb9bf0..462a75da514 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTimeZoneTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowTimeZoneTest.java @@ -105,22 +105,22 @@ public static Collection getTestData() throws SQLException, ClassNotFo @Test public void testJdbcToArrowValues() throws SQLException, IOException { testDataSets(sqlToArrow(conn, table.getQuery(), new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone())))); + Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone())))); + new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))), false); testDataSets(sqlToArrow(conn.createStatement().executeQuery(table.getQuery()), - Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone())))); + Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))), false); testDataSets(sqlToArrow( conn.createStatement().executeQuery(table.getQuery()), new JdbcToArrowConfigBuilder( new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))).build())); + Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))).build()), false); testDataSets(sqlToArrow( conn, table.getQuery(), new JdbcToArrowConfigBuilder( new RootAllocator(Integer.MAX_VALUE), - Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))).build())); + Calendar.getInstance(TimeZone.getTimeZone(table.getTimezone()))).build()), false); } @Test @@ -136,8 +136,10 @@ public void testJdbcSchemaMetadata() throws SQLException { * This method calls the assert methods for various DataSets. * * @param root VectorSchemaRoot for test + * @param isIncludeMapVector is this dataset checks includes map column. 
+ * Jdbc type to 'map' mapping declared in configuration only manually */ - public void testDataSets(VectorSchemaRoot root) { + public void testDataSets(VectorSchemaRoot root, boolean isIncludeMapVector) { JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(root); switch (table.getType()) { diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowVectorIteratorTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowVectorIteratorTest.java index 84ec3a45620..762f6e764b4 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowVectorIteratorTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/h2/JdbcToArrowVectorIteratorTest.java @@ -516,7 +516,7 @@ public void testJdbcToArrowCustomTypeConversion() throws SQLException, IOExcepti } // second experiment, using customized type converter - builder.setJdbcToArrowTypeConverter(fieldInfo -> { + builder.setJdbcToArrowTypeConverter((fieldInfo) -> { switch (fieldInfo.getJdbcType()) { case Types.REAL: // this is different from the default type converter diff --git a/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_h2.yml b/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_h2.yml index 45b8c9c7713..ff76acf8d7c 100644 --- a/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_h2.yml +++ b/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_h2.yml @@ -14,61 +14,61 @@ name: 'test1_all_datatypes_h2' create: 'CREATE TABLE table1 (int_field1 INT, bool_field2 BOOLEAN, tinyint_field3 TINYINT, smallint_field4 SMALLINT, bigint_field5 BIGINT, decimal_field6 DECIMAL(20,2), double_field7 DOUBLE, real_field8 REAL, time_field9 TIME, date_field10 DATE, timestamp_field11 TIMESTAMP, binary_field12 BINARY(100), varchar_field13 VARCHAR(256), blob_field14 BLOB, clob_field15 CLOB, char_field16 CHAR(16), bit_field17 BIT, - null_field18 NULL, list_field19 ARRAY);' + null_field18 NULL, list_field19 ARRAY, map_field20 VARCHAR(256));' data: - 'INSERT INTO table1 VALUES (101, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (1, 2, 3));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (1, 2, 3), ''{"a":"b","key":"12345"}'');' - 'INSERT INTO table1 VALUES (102, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (1, 2));' + 
''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (1, 2),''{"c":"d"}'');' - 'INSERT INTO table1 VALUES (103, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (1));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (1),''{"e":"f"}'');' - 'INSERT INTO table1 VALUES (104, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (2, 3, 4));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (2, 3, 4),''{"g":"h"}'');' - 'INSERT INTO table1 VALUES (null, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (2, 3));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (2, 3),''{"i":"j"}'');' - 'INSERT INTO table1 VALUES (null, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (2));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (2),''{"k":"l"}'');' - 'INSERT INTO table1 VALUES (107, 1, 45, 12000, 
92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (3, 4, 5));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (3, 4, 5),''{"m":"n"}'');' - 'INSERT INTO table1 VALUES (108, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (3, 4));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (3, 4),''{"o":"p"}'');' - 'INSERT INTO table1 VALUES (109, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (3));' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (3),''{"q":"r"}'');' - 'INSERT INTO table1 VALUES (110, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', - ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, ());' + ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', 1, null, (),''{"s":"t"}'');' query: 'select int_field1, bool_field2, tinyint_field3, smallint_field4, bigint_field5, decimal_field6, double_field7, real_field8, - time_field9, date_field10, timestamp_field11, binary_field12, varchar_field13, blob_field14, clob_field15, char_field16, bit_field17, null_field18, list_field19 from table1' + 
time_field9, date_field10, timestamp_field11, binary_field12, varchar_field13, blob_field14, clob_field15, char_field16, bit_field17, null_field18, list_field19, map_field20 from table1' drop: 'DROP table table1;' @@ -119,4 +119,5 @@ values: some text that needs to be converted to clob,some text that needs to be converted to clob, some text that needs to be converted to clob,some text that needs to be converted to clob, some text that needs to be converted to clob,some text that needs to be converted to clob' - - 'LIST_FIELD19=(1;2;3),(1;2),(1),(2;3;4),(2;3),(2),(3;4;5),(3;4),(3),()' \ No newline at end of file + - 'LIST_FIELD19=(1;2;3),(1;2),(1),(2;3;4),(2;3),(2),(3;4;5),(3;4),(3),()' + - 'MAP_FIELD20={"a":"b"|"key":"12345"},{"c":"d"},{"e":"f"},{"g":"h"},{"i":"j"},{"k":"l"},{"m":"n"},{"o":"p"},{"q":"r"},{"s":"t"}' \ No newline at end of file diff --git a/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_null_h2.yml b/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_null_h2.yml index 1edcc556334..e1b1a1adcbb 100644 --- a/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_null_h2.yml +++ b/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_null_h2.yml @@ -32,20 +32,21 @@ vectors: - 'CHAR_FIELD16' - 'BIT_FIELD17' - 'LIST_FIELD19' + - 'MAP_FIELD20' rowCount: '5' create: 'CREATE TABLE table1 (int_field1 INT, bool_field2 BOOLEAN, tinyint_field3 TINYINT, smallint_field4 SMALLINT, bigint_field5 BIGINT, decimal_field6 DECIMAL(20,2), double_field7 DOUBLE, real_field8 REAL, time_field9 TIME, date_field10 DATE, timestamp_field11 TIMESTAMP, binary_field12 BINARY(100), varchar_field13 VARCHAR(256), blob_field14 BLOB, clob_field15 CLOB, char_field16 CHAR(16), bit_field17 BIT, - list_field19 ARRAY);' + list_field19 ARRAY,map_field20 VARCHAR(256));' data: - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' query: 'select int_field1, bool_field2, tinyint_field3, smallint_field4, bigint_field5, decimal_field6, double_field7, real_field8, time_field9, date_field10, timestamp_field11, binary_field12, varchar_field13, blob_field14, clob_field15, char_field16, bit_field17, diff --git 
a/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_selected_null_rows_h2.yml b/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_selected_null_rows_h2.yml index c07ab7d4c0f..0521ce2f9c3 100644 --- a/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_selected_null_rows_h2.yml +++ b/java/adapter/jdbc/src/test/resources/h2/test1_all_datatypes_selected_null_rows_h2.yml @@ -32,34 +32,35 @@ vectors: - 'CHAR_FIELD16' - 'BIT_FIELD17' - 'LIST_FIELD19' + - 'MAP_FIELD20' create: 'CREATE TABLE table1 (int_field1 INT, bool_field2 BOOLEAN, tinyint_field3 TINYINT, smallint_field4 SMALLINT, bigint_field5 BIGINT, decimal_field6 DECIMAL(20,2), double_field7 DOUBLE, real_field8 REAL, time_field9 TIME, date_field10 DATE, timestamp_field11 TIMESTAMP, binary_field12 BINARY(100), varchar_field13 VARCHAR(256), blob_field14 BLOB, clob_field15 CLOB, char_field16 CHAR(16), bit_field17 BIT, - list_field19 ARRAY);' + list_field19 ARRAY, map_field20 VARCHAR(256));' data: - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' - 'INSERT INTO table1 VALUES (101, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', - 1, (1, 2, 3));' + 1, (1, 2, 3),''{"a":"b"}'');' - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' - 'INSERT INTO table1 VALUES (101, 1, 45, 12000, 92233720, 17345667789.23, 56478356785.345, 56478356785.345, PARSEDATETIME(''12:45:35 GMT'', ''HH:mm:ss z''), PARSEDATETIME(''2018-02-12 GMT'', ''yyyy-MM-dd z''), PARSEDATETIME(''2018-02-12 12:45:35 GMT'', ''yyyy-MM-dd HH:mm:ss z''), ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to varchar'', ''736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279'', ''some text that needs to be converted to clob'', ''some char text'', - 1, (1, 2, 3));' + 1, (1, 2, 3),''{"c":"d"}'');' - - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' + - 'INSERT INTO table1 VALUES (null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null);' query: 'select int_field1, bool_field2, tinyint_field3, smallint_field4, bigint_field5, decimal_field6, double_field7, real_field8, time_field9, date_field10, timestamp_field11, binary_field12, varchar_field13, blob_field14, clob_field15, char_field16, bit_field17, - list_field19 from table1' + list_field19, map_field20 from table1' drop: 'DROP table table1;' @@ -87,3 +88,4 @@ values: 
null,736f6d6520746578742074686174206e6565647320746f20626520636f6e76657274656420746f2062696e617279,null' - 'CLOB_FIELD15=null,some text that needs to be converted to clob,null,some text that needs to be converted to clob,null' - 'LIST_FIELD19=null,(1;2;3),null,(1;2;3),null' + - 'MAP_FIELD20=null,{"a":"b"},null,{"c":"d"},null' diff --git a/java/adapter/jdbc/src/test/resources/h2/test1_map_h2.yml b/java/adapter/jdbc/src/test/resources/h2/test1_map_h2.yml new file mode 100644 index 00000000000..a1800d20af6 --- /dev/null +++ b/java/adapter/jdbc/src/test/resources/h2/test1_map_h2.yml @@ -0,0 +1,33 @@ +#Licensed to the Apache Software Foundation (ASF) under one or more contributor +#license agreements. See the NOTICE file distributed with this work for additional +#information regarding copyright ownership. The ASF licenses this file to +#You under the Apache License, Version 2.0 (the "License"); you may not use +#this file except in compliance with the License. You may obtain a copy of +#the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required +#by applicable law or agreed to in writing, software distributed under the +#License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS +#OF ANY KIND, either express or implied. See the License for the specific +#language governing permissions and limitations under the License. + +name: 'test1_map_h2' + +type: 'map' + +vector: 'MAP_FIELD20' + +create: 'CREATE TABLE table1 (map_field20 OTHER);' + +rowCount: '4' + +data: + - 'INSERT INTO table1 VALUES (X''aced00057372002e6f72672e6170616368652e6172726f772e766563746f722e7574696c2e4a736f6e537472696e67486173684d61709819d7169e7a2ecf020000787200176a6176612e7574696c2e4c696e6b6564486173684d617034c04e5c106cc0fb0200015a000b6163636573734f72646572787200116a6176612e7574696c2e486173684d61700507dac1c31660d103000246000a6c6f6164466163746f724900097468726573686f6c6478703f4000000000000c7708000000100000000374000161740001627400033132337400067177657274797400057a78637662740001217800'');' + - 'INSERT INTO table1 VALUES (X''aced00057372002e6f72672e6170616368652e6172726f772e766563746f722e7574696c2e4a736f6e537472696e67486173684d61709819d7169e7a2ecf020000787200176a6176612e7574696c2e4c696e6b6564486173684d617034c04e5c106cc0fb0200015a000b6163636573734f72646572787200116a6176612e7574696c2e486173684d61700507dac1c31660d103000246000a6c6f6164466163746f724900097468726573686f6c6478703f4000000000000c77080000001000000003740001617400016274000163740001647400033132337400067177657274797800'');' + - 'INSERT INTO table1 VALUES (X''aced00057372002e6f72672e6170616368652e6172726f772e766563746f722e7574696c2e4a736f6e537472696e67486173684d61709819d7169e7a2ecf020000787200176a6176612e7574696c2e4c696e6b6564486173684d617034c04e5c106cc0fb0200015a000b6163636573734f72646572787200116a6176612e7574696c2e486173684d61700507dac1c31660d103000246000a6c6f6164466163746f724900097468726573686f6c6478703f4000000000000c7708000000100000000174000074000576616c75657800'');' + - 'INSERT INTO table1 VALUES (X''aced00057372002e6f72672e6170616368652e6172726f772e766563746f722e7574696c2e4a736f6e537472696e67486173684d61709819d7169e7a2ecf020000787200176a6176612e7574696c2e4c696e6b6564486173684d617034c04e5c106cc0fb0200015a000b6163636573734f72646572787200116a6176612e7574696c2e486173684d61700507dac1c31660d103000246000a6c6f6164466163746f724900097468726573686f6c6478703f4000000000000c7708000000100000000274000b6e6f6e456d7074794b65797074000c736f6d654f746865724b65797400007800'');' + +query: 'select map_field20 from table1;' + +drop: 'DROP table table1;' + +values: + - 
'MAP_FIELD20={"a":"b"|"123":"qwerty"|"zxcvb":"!"},{"a":"b"|"123":"qwerty"|"c":"d"},{"":"value"},{"nonEmptyKey":null|"someOtherKey":""}' \ No newline at end of file diff --git a/java/adapter/jdbc/src/test/resources/h2/test1_selected_datatypes_null_h2.yml b/java/adapter/jdbc/src/test/resources/h2/test1_selected_datatypes_null_h2.yml index 16324de12a0..e8d1d5de02c 100644 --- a/java/adapter/jdbc/src/test/resources/h2/test1_selected_datatypes_null_h2.yml +++ b/java/adapter/jdbc/src/test/resources/h2/test1_selected_datatypes_null_h2.yml @@ -28,13 +28,14 @@ vectors: - 'CHAR_FIELD16' - 'BIT_FIELD17' - 'LIST_FIELD19' + - 'MAP_FIELD20' rowCount: '5' create: 'CREATE TABLE table1 (int_field1 INT, bool_field2 BOOLEAN, tinyint_field3 TINYINT, smallint_field4 SMALLINT, bigint_field5 BIGINT, decimal_field6 DECIMAL(20,2), double_field7 DOUBLE, real_field8 REAL, time_field9 TIME, date_field10 DATE, timestamp_field11 TIMESTAMP, binary_field12 BINARY(100), varchar_field13 VARCHAR(256), blob_field14 BLOB, clob_field15 CLOB, char_field16 CHAR(16), bit_field17 BIT, - list_field19 ARRAY);' + list_field19 ARRAY, map_field20 VARCHAR(256));' data: - 'INSERT INTO table1 (int_field1, bool_field2, tinyint_field3, smallint_field4) VALUES (102, 0, 46, 12001);' @@ -43,6 +44,6 @@ data: - 'INSERT INTO table1 (int_field1, bool_field2, tinyint_field3, smallint_field4) VALUES (102, 0, 46, 12001);' - 'INSERT INTO table1 (int_field1, bool_field2, tinyint_field3, smallint_field4) VALUES (102, 0, 46, 12001);' -query: 'select bigint_field5, decimal_field6, double_field7, real_field8, time_field9, date_field10, timestamp_field11, binary_field12, varchar_field13, blob_field14, clob_field15, char_field16, bit_field17, list_field19 from table1' +query: 'select bigint_field5, decimal_field6, double_field7, real_field8, time_field9, date_field10, timestamp_field11, binary_field12, varchar_field13, blob_field14, clob_field15, char_field16, bit_field17, list_field19, map_field20 from table1' drop: 'DROP table table1;' \ No newline at end of file From 8d11e3ddb69db4499707c0aff2e6e35f49f7dfb2 Mon Sep 17 00:00:00 2001 From: Rok Mihevc Date: Wed, 21 Sep 2022 15:08:36 +0200 Subject: [PATCH 116/133] ARROW-16981: [C++] Expose jemalloc statistics for logging (#13516) This is to resolve [ARROW-16981](https://issues.apache.org/jira/browse/ARROW-16981). 
Lead-authored-by: Rok
Co-authored-by: Rok Mihevc
Signed-off-by: Antoine Pitrou
---
 cpp/src/arrow/memory_pool.cc | 23 ++++++-
 cpp/src/arrow/memory_pool.h | 33 +++++++++
 cpp/src/arrow/memory_pool_jemalloc.cc | 52 ++++++++++++++
 cpp/src/arrow/memory_pool_test.cc | 98 ++++++++++++++++++++++++++-
 4 files changed, 204 insertions(+), 2 deletions(-)
diff --git a/cpp/src/arrow/memory_pool.cc b/cpp/src/arrow/memory_pool.cc
index 99cb0682462..638bbb3ab7f 100644
--- a/cpp/src/arrow/memory_pool.cc
+++ b/cpp/src/arrow/memory_pool.cc
@@ -667,8 +667,29 @@ MemoryPool* default_memory_pool() {
 #ifndef ARROW_JEMALLOC
 Status jemalloc_set_decay_ms(int ms) {
- return Status::Invalid("jemalloc support is not built");
+ return Status::NotImplemented("jemalloc support is not built");
 }
+
+Result jemalloc_get_stat(const char* name) {
+ return Status::NotImplemented("jemalloc support is not built");
+}
+
+Status jemalloc_peak_reset() {
+ return Status::NotImplemented("jemalloc support is not built");
+}
+
+Status jemalloc_stats_print(const char* opts) {
+ return Status::NotImplemented("jemalloc support is not built");
+}
+
+Status jemalloc_stats_print(std::function write_cb, const char* opts) {
+ return Status::NotImplemented("jemalloc support is not built");
+}
+
+Result jemalloc_stats_string(const char* opts) {
+ return Status::NotImplemented("jemalloc support is not built");
+}
+
 #endif
 ///////////////////////////////////////////////////////////////////////
diff --git a/cpp/src/arrow/memory_pool.h b/cpp/src/arrow/memory_pool.h
index 58b375af3a9..dba55268d69 100644
--- a/cpp/src/arrow/memory_pool.h
+++ b/cpp/src/arrow/memory_pool.h
@@ -19,9 +19,11 @@
 #include
 #include
+#include
 #include
 #include
+#include "arrow/result.h"
 #include "arrow/status.h"
 #include "arrow/type_fwd.h"
 #include "arrow/util/visibility.h"
@@ -175,6 +177,37 @@ ARROW_EXPORT
 Status jemalloc_memory_pool(MemoryPool** out);
 ARROW_EXPORT
 Status jemalloc_set_decay_ms(int ms);
+/// \brief Get basic statistics from jemalloc's mallctl.
+/// See the MALLCTL NAMESPACE section in the jemalloc documentation for the
+/// available stats.
+ARROW_EXPORT
+Result jemalloc_get_stat(const char* name);
+
+/// \brief Reset the counter for peak bytes allocated in the calling thread to zero.
+/// This affects subsequent calls to thread.peak.read, but not the values returned by
+/// thread.allocated or thread.deallocated.
+ARROW_EXPORT
+Status jemalloc_peak_reset();
+
+/// \brief Print summary statistics in human-readable form to stderr.
+/// See the malloc_stats_print documentation in the jemalloc project for the
+/// available opt flags.
+ARROW_EXPORT
+Status jemalloc_stats_print(const char* opts = "");
+
+/// \brief Print summary statistics in human-readable form using a callback.
+/// See the malloc_stats_print documentation in the jemalloc project for the
+/// available opt flags.
+ARROW_EXPORT
+Status jemalloc_stats_print(std::function write_cb,
+ const char* opts = "");
+
+/// \brief Get summary statistics in human-readable form.
+/// See the malloc_stats_print documentation in the jemalloc project for the
+/// available opt flags.
+ARROW_EXPORT
+Result jemalloc_stats_string(const char* opts = "");
+
 /// \brief Return a process-wide memory pool based on mimalloc.
 ///
 /// May return NotImplemented if mimalloc is not available.
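A hedged sketch of the callback overload declared above (assuming the callback is a `std::function` invoked once per `const char*` chunk, as the implementation below suggests), collecting the report into a string instead of printing to stderr:

```cpp
#include <string>

#include "arrow/memory_pool.h"
#include "arrow/result.h"

arrow::Result<std::string> CollectJemallocReport() {
  std::string out;
  // Each chunk jemalloc emits is appended to `out`; "J" requests JSON output.
  auto write_cb = [&out](const char* str) { out.append(str); };
  ARROW_RETURN_NOT_OK(arrow::jemalloc_stats_print(write_cb, "J"));
  return out;
}
```

This is essentially the pattern that `jemalloc_stats_string()` wraps, as the implementation below shows.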
diff --git a/cpp/src/arrow/memory_pool_jemalloc.cc b/cpp/src/arrow/memory_pool_jemalloc.cc
index 48a5bac137b..03d2b28ee3e 100644
--- a/cpp/src/arrow/memory_pool_jemalloc.cc
+++ b/cpp/src/arrow/memory_pool_jemalloc.cc
@@ -16,6 +16,7 @@
 // under the License.
 #include "arrow/memory_pool_internal.h"
+#include "arrow/util/io_util.h"
 #include "arrow/util/logging.h" // IWYU pragma: keep
 // We can't put the jemalloc memory pool implementation into
@@ -153,4 +154,55 @@ Status jemalloc_set_decay_ms(int ms) {
 #undef RETURN_IF_JEMALLOC_ERROR
+Result jemalloc_get_stat(const char* name) {
+ size_t sz = sizeof(uint64_t);
+ int err;
+ uint64_t value;
+
+ // Update the statistics cached by mallctl.
+ if (std::strcmp(name, "stats.allocated") == 0 ||
+ std::strcmp(name, "stats.active") == 0 ||
+ std::strcmp(name, "stats.metadata") == 0 ||
+ std::strcmp(name, "stats.resident") == 0 ||
+ std::strcmp(name, "stats.mapped") == 0 ||
+ std::strcmp(name, "stats.retained") == 0) {
+ uint64_t epoch;
+ mallctl("epoch", &epoch, &sz, &epoch, sz);
+ }
+
+ err = mallctl(name, &value, &sz, nullptr, 0);
+
+ if (err) {
+ return arrow::internal::IOErrorFromErrno(err, "Failed retrieving ", name);
+ }
+
+ return value;
+}
+
+Status jemalloc_peak_reset() {
+ int err = mallctl("thread.peak.reset", nullptr, nullptr, nullptr, 0);
+ return err ? arrow::internal::IOErrorFromErrno(err, "Failed resetting thread.peak.")
+ : Status::OK();
+}
+
+Result jemalloc_stats_string(const char* opts) {
+ std::string stats;
+ auto write_cb = [&stats](const char* str) { stats.append(str); };
+ ARROW_UNUSED(jemalloc_stats_print(write_cb, opts));
+ return stats;
+}
+
+Status jemalloc_stats_print(const char* opts) {
+ malloc_stats_print(nullptr, nullptr, opts);
+ return Status::OK();
+}
+
+Status jemalloc_stats_print(std::function write_cb, const char* opts) {
+ auto cb_wrapper = [](void* opaque, const char* str) {
+ (*static_cast*>(opaque))(str);
+ };
+ malloc_stats_print(cb_wrapper, &write_cb, opts);
+ return Status::OK();
+}
+
 } // namespace arrow
diff --git a/cpp/src/arrow/memory_pool_test.cc b/cpp/src/arrow/memory_pool_test.cc
index 591d86a23f5..5ac14a44b9a 100644
--- a/cpp/src/arrow/memory_pool_test.cc
+++ b/cpp/src/arrow/memory_pool_test.cc
@@ -25,6 +25,7 @@
 #include "arrow/status.h"
 #include "arrow/testing/gtest_util.h"
 #include "arrow/util/config.h"
+#include "arrow/util/logging.h"
 namespace arrow {
@@ -168,7 +169,102 @@ TEST(Jemalloc, SetDirtyPageDecayMillis) {
 #ifdef ARROW_JEMALLOC
 ASSERT_OK(jemalloc_set_decay_ms(0));
 #else
- ASSERT_RAISES(Invalid, jemalloc_set_decay_ms(0));
+ ASSERT_RAISES(NotImplemented, jemalloc_set_decay_ms(0));
+#endif
+}
+
+TEST(Jemalloc, GetAllocationStats) {
+#ifdef ARROW_JEMALLOC
+ uint8_t* data;
+ int64_t allocated, active, metadata, resident, mapped, retained, allocated0, active0,
+ metadata0, resident0, mapped0, retained0;
+ int64_t thread_allocated, thread_deallocated, thread_peak_read, thread_allocated0,
+ thread_deallocated0, thread_peak_read0;
+ MemoryPool* pool = nullptr;
+ ABORT_NOT_OK(jemalloc_memory_pool(&pool));
+ ASSERT_EQ("jemalloc", pool->backend_name());
+
+ // Record stats before allocating
+ ASSERT_OK_AND_ASSIGN(allocated0, jemalloc_get_stat("stats.allocated"));
+ ASSERT_OK_AND_ASSIGN(active0, jemalloc_get_stat("stats.active"));
+ ASSERT_OK_AND_ASSIGN(metadata0, jemalloc_get_stat("stats.metadata"));
+ ASSERT_OK_AND_ASSIGN(resident0, jemalloc_get_stat("stats.resident"));
+ ASSERT_OK_AND_ASSIGN(mapped0, jemalloc_get_stat("stats.mapped"));
+ ASSERT_OK_AND_ASSIGN(retained0,
jemalloc_get_stat("stats.retained")); + ASSERT_OK_AND_ASSIGN(thread_allocated0, jemalloc_get_stat("thread.allocated")); + ASSERT_OK_AND_ASSIGN(thread_deallocated0, jemalloc_get_stat("thread.deallocated")); + ASSERT_OK_AND_ASSIGN(thread_peak_read0, jemalloc_get_stat("thread.peak.read")); + + // Allocate memory + ASSERT_OK(pool->Allocate(1025, &data)); + ASSERT_EQ(pool->bytes_allocated(), 1025); + ASSERT_OK(pool->Reallocate(1025, 1023, &data)); + ASSERT_EQ(pool->bytes_allocated(), 1023); + + // Record stats after allocating + ASSERT_OK_AND_ASSIGN(allocated, jemalloc_get_stat("stats.allocated")); + ASSERT_OK_AND_ASSIGN(active, jemalloc_get_stat("stats.active")); + ASSERT_OK_AND_ASSIGN(metadata, jemalloc_get_stat("stats.metadata")); + ASSERT_OK_AND_ASSIGN(resident, jemalloc_get_stat("stats.resident")); + ASSERT_OK_AND_ASSIGN(mapped, jemalloc_get_stat("stats.mapped")); + ASSERT_OK_AND_ASSIGN(retained, jemalloc_get_stat("stats.retained")); + ASSERT_OK_AND_ASSIGN(thread_allocated, jemalloc_get_stat("thread.allocated")); + ASSERT_OK_AND_ASSIGN(thread_deallocated, jemalloc_get_stat("thread.deallocated")); + ASSERT_OK_AND_ASSIGN(thread_peak_read, jemalloc_get_stat("thread.peak.read")); + + // Check allocated stats pre-allocation + ASSERT_NEAR(allocated0, 120000, 100000); + ASSERT_NEAR(active0, 75000, 70000); + ASSERT_NEAR(metadata0, 3000000, 1000000); + ASSERT_NEAR(resident0, 3000000, 1000000); + ASSERT_NEAR(mapped0, 6500000, 1000000); + ASSERT_NEAR(retained0, 1500000, 2000000); + + // Check allocated stats change due to allocation + ASSERT_NEAR(allocated - allocated0, 70000, 50000); + ASSERT_NEAR(active - active0, 100000, 90000); + ASSERT_NEAR(metadata - metadata0, 500, 460); + ASSERT_NEAR(resident - resident0, 120000, 110000); + ASSERT_NEAR(mapped - mapped0, 100000, 90000); + ASSERT_NEAR(retained - retained0, 0, 40000); + + ASSERT_NEAR(thread_peak_read - thread_peak_read0, 1024, 700); + ASSERT_NEAR(thread_allocated - thread_allocated0, 2500, 500); + ASSERT_EQ(thread_deallocated - thread_deallocated0, 1280); + + // Resetting thread peak read metric + ASSERT_OK(pool->Allocate(12560, &data)); + ASSERT_OK_AND_ASSIGN(thread_peak_read, jemalloc_get_stat("thread.peak.read")); + ASSERT_NEAR(thread_peak_read, 15616, 1000); + ASSERT_OK(jemalloc_peak_reset()); + ASSERT_OK(pool->Allocate(1256, &data)); + ASSERT_OK_AND_ASSIGN(thread_peak_read, jemalloc_get_stat("thread.peak.read")); + ASSERT_NEAR(thread_peak_read, 1280, 100); + + // Print statistics to stderr + ASSERT_OK(jemalloc_stats_print("J")); + + // Read statistics into std::string + ASSERT_OK_AND_ASSIGN(std::string stats, jemalloc_stats_string("Jax")); + + // Read statistics into std::string with a lambda + std::string stats2; + auto write_cb = [&stats2](const char* str) { stats2.append(str); }; + ASSERT_OK(jemalloc_stats_print(write_cb, "Jax")); + + ASSERT_EQ(stats.rfind("{\"jemalloc\":{\"version\"", 0), 0); + ASSERT_EQ(stats2.rfind("{\"jemalloc\":{\"version\"", 0), 0); + ASSERT_EQ(stats.substr(0, 100), stats2.substr(0, 100)); +#else + std::string stats; + auto write_cb = [&stats](const char* str) { stats.append(str); }; + ASSERT_RAISES(NotImplemented, jemalloc_get_stat("thread.peak.read")); + ASSERT_RAISES(NotImplemented, jemalloc_get_stat("stats.allocated")); + ASSERT_RAISES(NotImplemented, jemalloc_get_stat("stats.allocated")); + ASSERT_RAISES(NotImplemented, jemalloc_get_stat("stats.allocatedp")); + ASSERT_RAISES(NotImplemented, jemalloc_peak_reset()); + ASSERT_RAISES(NotImplemented, jemalloc_stats_print(write_cb, "Jax")); + 
ASSERT_RAISES(NotImplemented, jemalloc_stats_print("ax")); #endif } From 82f8826deb8dec86f5b5fde6b56cc212798f0c41 Mon Sep 17 00:00:00 2001 From: Miles Granger Date: Wed, 21 Sep 2022 16:37:49 +0200 Subject: [PATCH 117/133] ARROW-16174: [Python] Fix FixedSizeListArray.flatten() on sliced input (#14000) [ARROW-16174](https://issues.apache.org/jira/browse/ARROW-16174) Current behavior ```python import pyarrow as pa array = pa.array([[1], [2], [3]], type=pa.list_(pa.int64(), list_size=1)) array[2:].flatten().to_pylist() [1, 2, 3] ``` After this patch ```python import pyarrow as pa array = pa.array([[1], [2], [3]], type=pa.list_(pa.int64(), list_size=1)) array[2:].flatten().to_pylist() [3] ``` Authored-by: Miles Granger Signed-off-by: Antoine Pitrou --- python/pyarrow/array.pxi | 12 +----- python/pyarrow/lib.pxd | 2 +- python/pyarrow/tests/test_array.py | 64 ++++++++++++++++++++---------- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index 4b35697874c..40b7bc22c84 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -2125,7 +2125,7 @@ cdef class MapArray(ListArray): return pyarrow_wrap_array(( self.ap).items()) -cdef class FixedSizeListArray(Array): +cdef class FixedSizeListArray(BaseListArray): """ Concrete class for Arrow arrays of a fixed size list data type. """ @@ -2212,16 +2212,6 @@ cdef class FixedSizeListArray(Array): @property def values(self): - return self.flatten() - - def flatten(self): - """ - Unnest this FixedSizeListArray by one level. - - Returns - ------- - result : Array - """ cdef CFixedSizeListArray* arr = self.ap return pyarrow_wrap_array(arr.values()) diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 67db3d2ffb8..25725a49570 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -383,7 +383,7 @@ cdef class MapArray(ListArray): pass -cdef class FixedSizeListArray(Array): +cdef class FixedSizeListArray(BaseListArray): pass diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index 39049a7859c..154ca348d2b 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -2685,30 +2685,49 @@ def test_list_array_flatten(offset_type, list_type_factory): assert arr2.values.values.equals(arr0) -@pytest.mark.parametrize('list_type_factory', [pa.list_, pa.large_list]) -def test_list_value_parent_indices(list_type_factory): +@pytest.mark.parametrize('list_type', [ + pa.list_(pa.int32()), + pa.list_(pa.int32(), list_size=2), + pa.large_list(pa.int32())]) +def test_list_value_parent_indices(list_type): arr = pa.array( [ - [0, 1, 2], + [0, 1], None, - [], + [None, None], [3, 4] - ], type=list_type_factory(pa.int32())) - expected = pa.array([0, 0, 0, 3, 3], type=pa.int64()) + ], type=list_type) + expected = pa.array([0, 0, 2, 2, 3, 3], type=pa.int64()) assert arr.value_parent_indices().equals(expected) -@pytest.mark.parametrize(('offset_type', 'list_type_factory'), - [(pa.int32(), pa.list_), (pa.int64(), pa.large_list)]) -def test_list_value_lengths(offset_type, list_type_factory): - arr = pa.array( - [ - [0, 1, 2], - None, - [], - [3, 4] - ], type=list_type_factory(pa.int32())) - expected = pa.array([3, None, 0, 2], type=offset_type) +@pytest.mark.parametrize(('offset_type', 'list_type'), + [(pa.int32(), pa.list_(pa.int32())), + (pa.int32(), pa.list_(pa.int32(), list_size=2)), + (pa.int64(), pa.large_list(pa.int32()))]) +def test_list_value_lengths(offset_type, list_type): + + # FixedSizeListArray needs fixed 
list sizes + if getattr(list_type, "list_size", None): + arr = pa.array( + [ + [0, 1], + None, + [None, None], + [3, 4] + ], type=list_type) + expected = pa.array([2, None, 2, 2], type=offset_type) + + # Otherwise create variable list sizes + else: + arr = pa.array( + [ + [0, 1, 2], + None, + [], + [3, 4] + ], type=list_type) + expected = pa.array([3, None, 0, 2], type=offset_type) assert arr.value_lengths().equals(expected) @@ -2771,7 +2790,6 @@ def test_fixed_size_list_array_flatten(): typ1 = pa.list_(pa.int64(), 2) arr1 = pa.array([ [1, 2], [3, 4], [5, 6], - None, None, None, [7, None], None, [8, 9] ], type=typ1) assert arr1.type.equals(typ1) @@ -2779,15 +2797,19 @@ typ0 = pa.int64() arr0 = pa.array([ - 1, 2, 3, 4, 5, 6, - None, None, None, None, None, None, - 7, None, None, None, 8, 9, + 1, 2, 3, 4, 5, 6, 7, None, 8, 9, ], type=typ0) assert arr0.type.equals(typ0) assert arr1.flatten().equals(arr0) assert arr2.flatten().flatten().equals(arr0) +def test_fixed_size_list_array_flatten_with_slice(): + array = pa.array([[1], [2], [3]], + type=pa.list_(pa.float64(), list_size=1)) + assert array[2:].flatten() == pa.array([3], type=pa.float64()) + + def test_map_array_values_offsets(): ty = pa.map_(pa.utf8(), pa.int32()) ty_values = pa.struct([pa.field("key", pa.utf8(), nullable=False), From f0303652b4934a9f767dca88268016c69375687d Mon Sep 17 00:00:00 2001 From: Nic Crane Date: Wed, 21 Sep 2022 15:46:58 +0100 Subject: [PATCH 118/133] ARROW-17699: [R] Add better error message for when a non-schema is passed into open_dataset() (#14108) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The example below shows what happens if the `schema()` function is passed in as the `schema` argument via the `...` argument to `open_dataset()`. Before (a later check for something else catches it, thus giving a misleading message): ``` library(dplyr) library(arrow) tf <- tempfile() dir.create(tf) write_dataset(mtcars, tf, format = "csv") open_dataset(tf, format = "csv", schema = schema) %>% collect() #> Error in `CsvFileFormat$create()`: #> ! Values in `column_names` must match `schema` field names #> ✖ `column_names` and `schema` field names match but are not in the same order ``` After (more accurate error message): ``` library(dplyr) library(arrow) tf <- tempfile() dir.create(tf) write_dataset(mtcars, tf, format = "csv") open_dataset(tf, format = "csv", schema = schema) %>% collect() #> Error in `CsvFileFormat$create()`: #> ! `schema` must be an object of class 'Schema' not 'function'. ``` Authored-by: Nic Crane Signed-off-by: Nic Crane --- r/R/dataset-format.R | 7 +++++++ r/tests/testthat/test-dataset-csv.R | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/r/R/dataset-format.R b/r/R/dataset-format.R index 948abf2829e..ad2a38a4f40 100644 --- a/r/R/dataset-format.R +++ b/r/R/dataset-format.R @@ -128,6 +128,13 @@ CsvFileFormat$create <- function(..., options <- list(...) schema <- options[["schema"]] + if (!is.null(schema) && !inherits(schema, "Schema")) { + abort(paste0( + "`schema` must be an object of class 'Schema' not '", + class(schema)[1], + "'."
+ )) + } column_names <- read_options$column_names schema_names <- names(schema) diff --git a/r/tests/testthat/test-dataset-csv.R b/r/tests/testthat/test-dataset-csv.R index 0718746624e..969bb82e40a 100644 --- a/r/tests/testthat/test-dataset-csv.R +++ b/r/tests/testthat/test-dataset-csv.R @@ -380,3 +380,13 @@ test_that("skip argument in open_dataset", { ) expect_equal(collect(ds), tbl) }) + +test_that("error message if non-schema passed in as schema to open_dataset", { + + # passing in the schema function, not an actual schema + expect_error( + open_dataset(csv_dir, format = "csv", schema = schema), + regexp = "`schema` must be an object of class 'Schema' not 'function'.", + fixed = TRUE + ) +}) From 58be6a317ff09eefb53c3f0122e4d4eedd166977 Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 21 Sep 2022 11:05:57 -0400 Subject: [PATCH 119/133] ARROW-17169: [Go][Parquet] Panic in bitmap writer with Nullable List of Struct (#14183) When building the Nullable List of Struct column for record reading we didn't account for the worst-case scenario for building the final array. We need to handle the upper bound case of `offsetData[validityIO.Read]`+`validityIO.NullCount` Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/parquet/pqarrow/column_readers.go | 7 ++++- go/parquet/pqarrow/encode_arrow_test.go | 36 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/go/parquet/pqarrow/column_readers.go b/go/parquet/pqarrow/column_readers.go index 73577b616ee..c3bf8ecd4ff 100644 --- a/go/parquet/pqarrow/column_readers.go +++ b/go/parquet/pqarrow/column_readers.go @@ -388,7 +388,12 @@ func (lr *listReader) BuildArray(lenBound int64) (*arrow.Chunked, error) { return nil, err } - arr, err := lr.itemRdr.BuildArray(int64(offsetData[int(validityIO.Read)])) + // if the parent (itemRdr) has nulls and is a nested type like list + // then we need BuildArray to account for that with the number of + // definition levels when building out the bitmap. 
So the upper bound + // we need to reserve space for is the worst-case scenario: + // the value of the last offset plus the null count + arr, err := lr.itemRdr.BuildArray(int64(offsetData[int(validityIO.Read)]) + validityIO.NullCount) if err != nil { return nil, err } diff --git a/go/parquet/pqarrow/encode_arrow_test.go b/go/parquet/pqarrow/encode_arrow_test.go index c9aeb19c4a2..a7185bfae4f 100644 --- a/go/parquet/pqarrow/encode_arrow_test.go +++ b/go/parquet/pqarrow/encode_arrow_test.go @@ -1288,6 +1288,42 @@ func (ps *ParquetIOTestSuite) TestNull() { ps.roundTripTable(expected, true) } +// ARROW-17169 +func (ps *ParquetIOTestSuite) TestNullableListOfStruct() { + bldr := array.NewListBuilder(memory.DefaultAllocator, arrow.StructOf( + arrow.Field{Name: "a", Type: arrow.PrimitiveTypes.Int32}, + arrow.Field{Name: "b", Type: arrow.BinaryTypes.String}, + )) + defer bldr.Release() + + stBldr := bldr.ValueBuilder().(*array.StructBuilder) + aBldr := stBldr.FieldBuilder(0).(*array.Int32Builder) + bBldr := stBldr.FieldBuilder(1).(*array.StringBuilder) + + for i := 0; i < 320; i++ { + if i%5 == 0 { + bldr.AppendNull() + continue + } + bldr.Append(true) + for j := 0; j < 4; j++ { + stBldr.Append(true) + aBldr.Append(int32(i + j)) + bBldr.Append(strconv.Itoa(i + j)) + } + } + + arr := bldr.NewArray() + defer arr.Release() + + field := arrow.Field{Name: "x", Type: arr.DataType(), Nullable: true} + expected := array.NewTable(arrow.NewSchema([]arrow.Field{field}, nil), + []arrow.Column{*arrow.NewColumn(field, arrow.NewChunked(field.Type, []arrow.Array{arr}))}, -1) + defer expected.Release() + + ps.roundTripTable(expected, false) +} + func TestParquetArrowIO(t *testing.T) { suite.Run(t, new(ParquetIOTestSuite)) } From 87d102eb96b7a4245171f3b72522ba5017bf23a0 Mon Sep 17 00:00:00 2001 From: rtpsw Date: Wed, 21 Sep 2022 18:25:36 +0300 Subject: [PATCH 120/133] ARROW-17696: [C++] arrow-compute-asof-join-node-test inordinately slow (#14190) See https://issues.apache.org/jira/browse/ARROW-17696 Authored-by: Yaron Gvili Signed-off-by: Antoine Pitrou --- cpp/src/arrow/compute/exec/asof_join_node_test.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/exec/asof_join_node_test.cc b/cpp/src/arrow/compute/exec/asof_join_node_test.cc index 2b6613021ff..919bfdbde49 100644 --- a/cpp/src/arrow/compute/exec/asof_join_node_test.cc +++ b/cpp/src/arrow/compute/exec/asof_join_node_test.cc @@ -568,7 +568,7 @@ struct BasicTest { std::uniform_int_distribution r0_distribution(0, r0_types.size() - 1); std::uniform_int_distribution r1_distribution(0, r1_types.size() - 1); - for (int i = 0; i < 1000; i++) { + for (int i = 0; i < 100; i++) { auto time_type = time_types[time_distribution(engine)]; ARROW_SCOPED_TRACE("Time type: ", *time_type); auto key_type = key_types[key_distribution(engine)]; @@ -584,8 +584,7 @@ struct BasicTest { auto end_time = std::chrono::system_clock::now(); std::chrono::duration diff = end_time - start_time; - if (diff.count() > 2) { - // this normally happens on slow CI systems, but is fine + if (diff.count() > 0.2) { break; } } From cf66aa03208bd87956f24c7e748186c70eff4b0c Mon Sep 17 00:00:00 2001 From: Matt Topol Date: Wed, 21 Sep 2022 15:20:56 -0400 Subject: [PATCH 121/133] ARROW-17804: [Go][CSV] Add Date32 and Time32 parsers (#14192) Given the recent addition of schema inference for CSVs, the CSV reader should be able to parse all the types that inference can produce Authored-by: Matt Topol Signed-off-by: Matt Topol --- go/arrow/csv/common.go | 2 +-
go/arrow/csv/reader.go | 39 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/go/arrow/csv/common.go b/go/arrow/csv/common.go index f27ba4c0a98..e4de1d86876 100644 --- a/go/arrow/csv/common.go +++ b/go/arrow/csv/common.go @@ -159,7 +159,7 @@ func WithNullWriter(null string) Option { } } -// WithBoolWriter override the default bool formatter with a fucntion that returns +// WithBoolWriter override the default bool formatter with a function that returns // a string representaton of bool states. i.e. True, False, 1, 0 func WithBoolWriter(fmtr func(bool) string) Option { return func(cfg config) { diff --git a/go/arrow/csv/reader.go b/go/arrow/csv/reader.go index cd3902affca..6e92a40e415 100644 --- a/go/arrow/csv/reader.go +++ b/go/arrow/csv/reader.go @@ -446,7 +446,14 @@ func (r *Reader) initFieldConverter(bldr array.Builder) func(string) { return func(str string) { r.parseTimestamp(bldr, str, dt.Unit) } - + case *arrow.Date32Type: + return func(str string) { + r.parseDate32(bldr, str) + } + case *arrow.Time32Type: + return func(str string) { + r.parseTime32(bldr, str, dt.Unit) + } default: panic(fmt.Errorf("arrow/csv: unhandled field type %T", bldr.Type())) } @@ -644,6 +651,36 @@ func (r *Reader) parseTimestamp(field array.Builder, str string, unit arrow.Time field.(*array.TimestampBuilder).Append(v) } +func (r *Reader) parseDate32(field array.Builder, str string) { + if r.isNull(str) { + field.AppendNull() + return + } + + tm, err := time.Parse("2006-01-02", str) + if err != nil && r.err == nil { + r.err = err + field.AppendNull() + return + } + field.(*array.Date32Builder).Append(arrow.Date32FromTime(tm)) +} + +func (r *Reader) parseTime32(field array.Builder, str string, unit arrow.TimeUnit) { + if r.isNull(str) { + field.AppendNull() + return + } + + val, err := arrow.Time32FromString(str, unit) + if err != nil && r.err == nil { + r.err = err + field.AppendNull() + return + } + field.(*array.Time32Builder).Append(val) +} + // Retain increases the reference count by 1. // Retain may be called simultaneously from multiple goroutines. func (r *Reader) Retain() { From 29143259ecc63a9436482f65a6e5e0757329be73 Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Wed, 21 Sep 2022 22:07:09 +0200 Subject: [PATCH 122/133] ARROW-17800: [C++] Fix failures in jemalloc stats tests (#14194) - Provide compatibility for 32-bit platforms - Avoid memory leak in tests - Make checks less strict Authored-by: Antoine Pitrou Signed-off-by: Antoine Pitrou --- cpp/src/arrow/memory_pool_jemalloc.cc | 28 ++++++++++++++++++++------- cpp/src/arrow/memory_pool_test.cc | 23 +++++++++++++--------- 2 files changed, 35 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/memory_pool_jemalloc.cc b/cpp/src/arrow/memory_pool_jemalloc.cc index 03d2b28ee3e..c7d73c84910 100644 --- a/cpp/src/arrow/memory_pool_jemalloc.cc +++ b/cpp/src/arrow/memory_pool_jemalloc.cc @@ -155,9 +155,8 @@ Status jemalloc_set_decay_ms(int ms) { #undef RETURN_IF_JEMALLOC_ERROR Result jemalloc_get_stat(const char* name) { - size_t sz = sizeof(uint64_t); + size_t sz; int err; - uint64_t value; // Update the statistics cached by mallctl. 
if (std::strcmp(name, "stats.allocated") == 0 || @@ -167,16 +166,31 @@ Result jemalloc_get_stat(const char* name) { std::strcmp(name, "stats.mapped") == 0 || std::strcmp(name, "stats.retained") == 0) { uint64_t epoch; + sz = sizeof(epoch); mallctl("epoch", &epoch, &sz, &epoch, sz); } - err = mallctl(name, &value, &sz, nullptr, 0); - - if (err) { - return arrow::internal::IOErrorFromErrno(err, "Failed retrieving ", &name); + // Depending on the stat being queried and on the platform, we could need + // to pass a uint32_t or uint64_t pointer. Try both. + { + uint64_t value = 0; + sz = sizeof(value); + err = mallctl(name, &value, &sz, nullptr, 0); + if (!err) { + return value; + } + } + // EINVAL means the given value length (`sz`) was incorrect. + if (err == EINVAL) { + uint32_t value = 0; + sz = sizeof(value); + err = mallctl(name, &value, &sz, nullptr, 0); + if (!err) { + return value; + } } - return value; + return arrow::internal::IOErrorFromErrno(err, "Failed retrieving ", &name); } Status jemalloc_peak_reset() { diff --git a/cpp/src/arrow/memory_pool_test.cc b/cpp/src/arrow/memory_pool_test.cc index 5ac14a44b9a..61e0abf3e66 100644 --- a/cpp/src/arrow/memory_pool_test.cc +++ b/cpp/src/arrow/memory_pool_test.cc @@ -180,6 +180,7 @@ TEST(Jemalloc, GetAllocationStats) { metadata0, resident0, mapped0, retained0; int64_t thread_allocated, thread_deallocated, thread_peak_read, thread_allocated0, thread_deallocated0, thread_peak_read0; + MemoryPool* pool = nullptr; ABORT_NOT_OK(jemalloc_memory_pool(&pool)); ASSERT_EQ("jemalloc", pool->backend_name()); @@ -211,14 +212,15 @@ TEST(Jemalloc, GetAllocationStats) { ASSERT_OK_AND_ASSIGN(thread_allocated, jemalloc_get_stat("thread.allocated")); ASSERT_OK_AND_ASSIGN(thread_deallocated, jemalloc_get_stat("thread.deallocated")); ASSERT_OK_AND_ASSIGN(thread_peak_read, jemalloc_get_stat("thread.peak.read")); + pool->Free(data, 1023); // Check allocated stats pre-allocation - ASSERT_NEAR(allocated0, 120000, 100000); - ASSERT_NEAR(active0, 75000, 70000); - ASSERT_NEAR(metadata0, 3000000, 1000000); - ASSERT_NEAR(resident0, 3000000, 1000000); - ASSERT_NEAR(mapped0, 6500000, 1000000); - ASSERT_NEAR(retained0, 1500000, 2000000); + ASSERT_GT(allocated0, 0); + ASSERT_GT(active0, 0); + ASSERT_GT(metadata0, 0); + ASSERT_GT(resident0, 0); + ASSERT_GT(mapped0, 0); + ASSERT_GE(retained0, 0); // Check allocated stats change due to allocation ASSERT_NEAR(allocated - allocated0, 70000, 50000); @@ -233,13 +235,16 @@ TEST(Jemalloc, GetAllocationStats) { ASSERT_EQ(thread_deallocated - thread_deallocated0, 1280); // Resetting thread peak read metric - ASSERT_OK(pool->Allocate(12560, &data)); + ASSERT_OK(pool->Allocate(100000, &data)); ASSERT_OK_AND_ASSIGN(thread_peak_read, jemalloc_get_stat("thread.peak.read")); - ASSERT_NEAR(thread_peak_read, 15616, 1000); + ASSERT_NEAR(thread_peak_read, 100000, 50000); + pool->Free(data, 100000); ASSERT_OK(jemalloc_peak_reset()); + ASSERT_OK(pool->Allocate(1256, &data)); ASSERT_OK_AND_ASSIGN(thread_peak_read, jemalloc_get_stat("thread.peak.read")); - ASSERT_NEAR(thread_peak_read, 1280, 100); + ASSERT_NEAR(thread_peak_read, 1256, 100); + pool->Free(data, 1256); // Print statistics to stderr ASSERT_OK(jemalloc_stats_print("J")); From 17a2137612970fb9347c9a336a729cf172c0d67e Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Wed, 21 Sep 2022 16:34:45 -0400 Subject: [PATCH 123/133] requested fixes --- cpp/src/arrow/compute/exec/plan_test.cc | 169 +++++++--------------- cpp/src/arrow/compute/exec/source_node.cc | 22 +-- 2 files changed, 61 
insertions(+), 130 deletions(-) diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 2fa4cde9d20..ba3a7e46dc8 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -296,146 +296,83 @@ TEST(ExecPlanExecution, TableSourceSinkError) { Raises(StatusCode::Invalid, HasSubstr("batch_size > 0"))); } -TEST(ExecPlanExecution, ArrayVectorSourceSink) { - for (int num_threads : {1, 4}) { - ASSERT_OK_AND_ASSIGN(auto io_executor, - arrow::internal::ThreadPool::Make(num_threads)); - ExecContext exec_context(default_memory_pool(), io_executor.get()); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context)); - AsyncGenerator> sink_gen; - - auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches)); - auto arrayvec_it_maker = [&arrayvecs]() { - return MakeVectorIterator>(arrayvecs); - }; - - ASSERT_OK(Declaration::Sequence( - { - {"array_source", ArrayVectorSourceNodeOptions{exp_batches.schema, - arrayvec_it_maker}}, - {"sink", SinkNodeOptions{&sink_gen}}, - }) - .AddToPlan(plan.get())); - - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); - } -} - -TEST(ExecPlanExecution, ArrayVectorSourceSinkError) { +template +void test_source_sink_error( + std::string source_factory_name, + std::function>(const BatchesWithSchema&)> + to_elements) { ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); std::shared_ptr no_schema; auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches)); - auto arrayvec_it_maker = [&arrayvecs]() { - return MakeVectorIterator>(arrayvecs); + ASSERT_OK_AND_ASSIGN(auto elements, to_elements(exp_batches)); + auto element_it_maker = [&elements]() { + return MakeVectorIterator(elements); }; - auto null_executor_options = - ArrayVectorSourceNodeOptions{exp_batches.schema, arrayvec_it_maker}; - ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_executor_options), - Raises(StatusCode::Invalid, HasSubstr("not null"))); + auto null_executor_options = OptionsType{exp_batches.schema, element_it_maker}; + ASSERT_OK(MakeExecNode(source_factory_name, plan.get(), {}, null_executor_options)); - auto null_schema_options = ArrayVectorSourceNodeOptions{no_schema, arrayvec_it_maker}; - ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_schema_options), + auto null_schema_options = OptionsType{no_schema, element_it_maker}; + ASSERT_THAT(MakeExecNode(source_factory_name, plan.get(), {}, null_schema_options), Raises(StatusCode::Invalid, HasSubstr("not null"))); } -TEST(ExecPlanExecution, ExecBatchSourceSink) { - for (int num_threads : {1, 4}) { - ASSERT_OK_AND_ASSIGN(auto io_executor, - arrow::internal::ThreadPool::Make(num_threads)); - ExecContext exec_context(default_memory_pool(), io_executor.get()); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context)); - AsyncGenerator> sink_gen; +template +void test_source_sink( + std::string source_factory_name, + std::function>(const BatchesWithSchema&)> + to_elements) { + ASSERT_OK_AND_ASSIGN(auto io_executor, arrow::internal::ThreadPool::Make(1)); + ExecContext exec_context(default_memory_pool(), io_executor.get()); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context)); + AsyncGenerator> sink_gen; - auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches)); - auto exec_batch_it_maker = [&exec_batches]() { - 
return MakeVectorIterator>(exec_batches); - }; + auto exp_batches = MakeBasicBatches(); + ASSERT_OK_AND_ASSIGN(auto elements, to_elements(exp_batches)); + auto element_it_maker = [&elements]() { + return MakeVectorIterator(elements); + }; - ASSERT_OK(Declaration::Sequence( - { - {"exec_source", ExecBatchSourceNodeOptions{exp_batches.schema, - exec_batch_it_maker}}, - {"sink", SinkNodeOptions{&sink_gen}}, - }) - .AddToPlan(plan.get())); + ASSERT_OK(Declaration::Sequence({ + {source_factory_name, + OptionsType{exp_batches.schema, element_it_maker}}, + {"sink", SinkNodeOptions{&sink_gen}}, + }) + .AddToPlan(plan.get())); - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); - } + ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); } -TEST(ExecPlanExecution, ExecBatchSourceSinkError) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - std::shared_ptr no_schema; +TEST(ExecPlanExecution, ArrayVectorSourceSink) { + test_source_sink, ArrayVectorSourceNodeOptions>( + "array_vector_source", ToArrayVectors); +} - auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches)); - auto exec_batch_it_maker = [&exec_batches]() { - return MakeVectorIterator>(exec_batches); - }; +TEST(ExecPlanExecution, ArrayVectorSourceSinkError) { + test_source_sink_error, ArrayVectorSourceNodeOptions>( + "array_vector_source", ToArrayVectors); +} - auto null_executor_options = - ExecBatchSourceNodeOptions{exp_batches.schema, exec_batch_it_maker}; - ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_executor_options), - Raises(StatusCode::Invalid, HasSubstr("not null"))); +TEST(ExecPlanExecution, ExecBatchSourceSink) { + test_source_sink, ExecBatchSourceNodeOptions>( + "exec_batch_source", ToExecBatches); +} - auto null_schema_options = ExecBatchSourceNodeOptions{no_schema, exec_batch_it_maker}; - ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_schema_options), - Raises(StatusCode::Invalid, HasSubstr("not null"))); +TEST(ExecPlanExecution, ExecBatchSourceSinkError) { + test_source_sink_error, ExecBatchSourceNodeOptions>( + "exec_batch_source", ToExecBatches); } TEST(ExecPlanExecution, RecordBatchSourceSink) { - for (int num_threads : {1, 4}) { - ASSERT_OK_AND_ASSIGN(auto io_executor, - arrow::internal::ThreadPool::Make(num_threads)); - ExecContext exec_context(default_memory_pool(), io_executor.get()); - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context)); - AsyncGenerator> sink_gen; - - auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches)); - auto record_batch_it_maker = [&record_batches]() { - return MakeVectorIterator>(record_batches); - }; - - ASSERT_OK(Declaration::Sequence({ - {"record_source", - RecordBatchSourceNodeOptions{ - exp_batches.schema, record_batch_it_maker}}, - {"sink", SinkNodeOptions{&sink_gen}}, - }) - .AddToPlan(plan.get())); - - ASSERT_THAT(StartAndCollect(plan.get(), sink_gen), - Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches)))); - } + test_source_sink, RecordBatchSourceNodeOptions>( + "record_batch_source", ToRecordBatches); } TEST(ExecPlanExecution, RecordBatchSourceSinkError) { - ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make()); - std::shared_ptr no_schema; - - auto exp_batches = MakeBasicBatches(); - ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches)); - auto 
record_batch_it_maker = [&record_batches]() { - return MakeVectorIterator>(record_batches); - }; - - auto null_executor_options = - RecordBatchSourceNodeOptions{exp_batches.schema, record_batch_it_maker}; - ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_executor_options), - Raises(StatusCode::Invalid, HasSubstr("not null"))); - - auto null_schema_options = - RecordBatchSourceNodeOptions{no_schema, record_batch_it_maker}; - ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_schema_options), - Raises(StatusCode::Invalid, HasSubstr("not null"))); + test_source_sink_error, RecordBatchSourceNodeOptions>( + "record_batch_source", ToRecordBatches); } TEST(ExecPlanExecution, SinkNodeBackpressure) { diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc index 73fc7e5ebda..be712af6353 100644 --- a/cpp/src/arrow/compute/exec/source_node.cc +++ b/cpp/src/arrow/compute/exec/source_node.cc @@ -24,6 +24,7 @@ #include "arrow/compute/exec/util.h" #include "arrow/compute/exec_internal.h" #include "arrow/datum.h" +#include "arrow/io/util_internal.h" #include "arrow/result.h" #include "arrow/table.h" #include "arrow/util/async_generator.h" @@ -307,22 +308,15 @@ struct SchemaSourceNode : public SourceNode { auto io_executor = plan->exec_context()->executor(); auto it = it_maker(); - RETURN_NOT_OK(ValidateSchemaSourceNodeInput(io_executor, schema, This::kKindName)); - ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema)); - return plan->EmplaceNode(plan, schema, generator); - } - - static arrow::Status ValidateSchemaSourceNodeInput( - arrow::internal::Executor* io_executor, const std::shared_ptr& schema, - const char* kKindName) { if (schema == NULLPTR) { - return Status::Invalid(kKindName, " requires schema which is not null"); + return Status::Invalid(This::kKindName, " requires schema which is not null"); } if (io_executor == NULLPTR) { - return Status::Invalid(kKindName, " requires IO-executor which is not null"); + io_executor = io::internal::GetIOThreadPool(); } - return Status::OK(); + ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema)); + return plan->EmplaceNode(plan, schema, generator); } }; @@ -434,9 +428,9 @@ namespace internal { void RegisterSourceNode(ExecFactoryRegistry* registry) { DCHECK_OK(registry->AddFactory("source", SourceNode::Make)); DCHECK_OK(registry->AddFactory("table_source", TableSourceNode::Make)); - DCHECK_OK(registry->AddFactory("record_source", RecordBatchSourceNode::Make)); - DCHECK_OK(registry->AddFactory("exec_source", ExecBatchSourceNode::Make)); - DCHECK_OK(registry->AddFactory("array_source", ArrayVectorSourceNode::Make)); + DCHECK_OK(registry->AddFactory("record_batch_source", RecordBatchSourceNode::Make)); + DCHECK_OK(registry->AddFactory("exec_batch_source", ExecBatchSourceNode::Make)); + DCHECK_OK(registry->AddFactory("array_vector_source", ArrayVectorSourceNode::Make)); } } // namespace internal From 43e66a928e29a811b0e14f2a2d7ffa9f8290ccbe Mon Sep 17 00:00:00 2001 From: Igor Suhorukov Date: Wed, 21 Sep 2022 23:44:41 +0300 Subject: [PATCH 124/133] ARROW-17659: [Java] Populate JDBC schema name metadata when config.shouldIncludeMetadata provided (#14196) The current implementation includes [catalog,table,column,type](https://github.com/apache/arrow/blob/master/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java#L248) metadata, but the schema name metadata field is missing.
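For example, with this change the field-level metadata of a column read through the JDBC adapter would carry a schema name next to the existing keys (illustrative values, not taken from the patch):

```
SQL_CATALOG_NAME = mydb
SQL_SCHEMA_NAME  = public
SQL_TABLE_NAME   = orders
SQL_COLUMN_NAME  = order_id
SQL_TYPE         = int8
```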
In PostgreSQL terms, a catalog is a database and a schema is a namespace inside a database, so the catalog name alone is insufficient for addressing a table without the schema. The proposed change is: + metadata.put(Constants.SQL_SCHEMA_NAME_KEY, rsmd.getSchemaName(i)); Authored-by: igor.suhorukov Signed-off-by: David Li --- .../apache/arrow/adapter/jdbc/Constants.java | 1 + .../arrow/adapter/jdbc/JdbcToArrowUtils.java | 1 + .../adapter/jdbc/JdbcToArrowTestHelper.java | 3 +- ...expectedSchemaWithCommentsAndJdbcMeta.json | 42 ++++++++++++------- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/Constants.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/Constants.java index aaadacb5436..5b01077b179 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/Constants.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/Constants.java @@ -24,6 +24,7 @@ public class Constants { private Constants() {} public static final String SQL_CATALOG_NAME_KEY = "SQL_CATALOG_NAME"; + public static final String SQL_SCHEMA_NAME_KEY = "SQL_SCHEMA_NAME"; public static final String SQL_TABLE_NAME_KEY = "SQL_TABLE_NAME"; public static final String SQL_COLUMN_NAME_KEY = "SQL_COLUMN_NAME"; public static final String SQL_TYPE_KEY = "SQL_TYPE"; diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java index 93c6a80c107..dc79f6efff3 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/JdbcToArrowUtils.java @@ -253,6 +253,7 @@ public static Schema jdbcToArrowSchema(ResultSetMetaData rsmd, JdbcToArrowConfig if (config.shouldIncludeMetadata()) { metadata = new HashMap<>(); metadata.put(Constants.SQL_CATALOG_NAME_KEY, rsmd.getCatalogName(i)); + metadata.put(Constants.SQL_SCHEMA_NAME_KEY, rsmd.getSchemaName(i)); metadata.put(Constants.SQL_TABLE_NAME_KEY, rsmd.getTableName(i)); metadata.put(Constants.SQL_COLUMN_NAME_KEY, columnName); metadata.put(Constants.SQL_TYPE_KEY, rsmd.getColumnTypeName(i)); diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java index 5d1fb2276cc..d5f896ba7df 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcToArrowTestHelper.java @@ -324,9 +324,10 @@ public static void assertFieldMetadataMatchesResultSetMetadata(ResultSetMetaData Map metadata = fields.get(i - 1).getMetadata(); assertNotNull(metadata); - assertEquals(4, metadata.size()); + assertEquals(5, metadata.size()); assertEquals(rsmd.getCatalogName(i), metadata.get(Constants.SQL_CATALOG_NAME_KEY)); + assertEquals(rsmd.getSchemaName(i), metadata.get(Constants.SQL_SCHEMA_NAME_KEY)); assertEquals(rsmd.getTableName(i), metadata.get(Constants.SQL_TABLE_NAME_KEY)); assertEquals(rsmd.getColumnLabel(i), metadata.get(Constants.SQL_COLUMN_NAME_KEY)); assertEquals(rsmd.getColumnTypeName(i), metadata.get(Constants.SQL_TYPE_KEY)); diff --git a/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json b/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json index 967ce0a08e1..9b25d635d4b 100644 ---
a/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json +++ b/java/adapter/jdbc/src/test/resources/h2/expectedSchemaWithCommentsAndJdbcMeta.json @@ -9,11 +9,8 @@ }, "children" : [ ], "metadata" : [ { - "value" : "Record identifier", - "key" : "comment" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" + "value" : "PUBLIC", + "key" : "SQL_SCHEMA_NAME" }, { "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", "key" : "SQL_CATALOG_NAME" @@ -23,6 +20,12 @@ }, { "value" : "BIGINT", "key" : "SQL_TYPE" + }, { + "value" : "Record identifier", + "key" : "comment" + }, { + "value" : "TABLE1", + "key" : "SQL_TABLE_NAME" } ] }, { "name" : "NAME", @@ -32,11 +35,8 @@ }, "children" : [ ], "metadata" : [ { - "value" : "Name of record", - "key" : "comment" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" + "value" : "PUBLIC", + "key" : "SQL_SCHEMA_NAME" }, { "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", "key" : "SQL_CATALOG_NAME" @@ -46,6 +46,12 @@ }, { "value" : "VARCHAR", "key" : "SQL_TYPE" + }, { + "value" : "Name of record", + "key" : "comment" + }, { + "value" : "TABLE1", + "key" : "SQL_TABLE_NAME" } ] }, { "name" : "COLUMN1", @@ -55,6 +61,9 @@ }, "children" : [ ], "metadata" : [ { + "value" : "PUBLIC", + "key" : "SQL_SCHEMA_NAME" + }, { "value" : "TABLE1", "key" : "SQL_TABLE_NAME" }, { @@ -77,11 +86,8 @@ }, "children" : [ ], "metadata" : [ { - "value" : "Informative description of columnN", - "key" : "comment" - }, { - "value" : "TABLE1", - "key" : "SQL_TABLE_NAME" + "value" : "PUBLIC", + "key" : "SQL_SCHEMA_NAME" }, { "value" : "JDBCTOARROWTEST?CHARACTERENCODING=UTF-8", "key" : "SQL_CATALOG_NAME" @@ -91,6 +97,12 @@ }, { "value" : "INTEGER", "key" : "SQL_TYPE" + }, { + "value" : "Informative description of columnN", + "key" : "comment" + }, { + "value" : "TABLE1", + "key" : "SQL_TABLE_NAME" } ] } ], "metadata" : [ { From 311fe3e875f5273cd988b052740d85216e3216b3 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 22 Sep 2022 16:40:42 +0800 Subject: [PATCH 125/133] ARROW-17790: [C++][Gandiva] Adapt to LLVM opaque pointer (#14187) Starting from LLVM 13, LLVM IR has been shifting towards a unified opaque pointer type, i.e. pointers without pointee types. It has provided workarounds until LLVM 15. The temporary workarounds need to be replaced in order to support LLVM 15 and onwards. We need to supply the pointee type to the CreateGEP and CreateLoad methods. 
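For illustration, a minimal sketch of the API shift (not verbatim from this patch; `builder`, `ptr` and `idx` are assumed to be an `llvm::IRBuilder<>*`, an `i64*` value and an index value):

```cpp
// Before: the pointee type was inferred from the pointer
// (deprecated, and removed with opaque pointers):
//   auto* slot  = builder->CreateGEP(ptr, idx);
//   auto* value = builder->CreateLoad(slot);
// After: the element type is passed explicitly.
llvm::Type* i64_ty = llvm::Type::getInt64Ty(builder->getContext());
llvm::Value* slot = builder->CreateGEP(i64_ty, ptr, idx, "slot");
llvm::Value* value = builder->CreateLoad(i64_ty, slot, "value");
```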
For more background info, see https://llvm.org/docs/OpaquePointers.html and https://lists.llvm.org/pipermail/llvm-dev/2015-February/081822.html Related issues: https://issues.apache.org/jira/browse/ARROW-14363 https://issues.apache.org/jira/browse/ARROW-17728 https://issues.apache.org/jira/browse/ARROW-17775 Lead-authored-by: Jin Shang Co-authored-by: jinshang Signed-off-by: Antoine Pitrou --- .github/workflows/cpp.yml | 10 +++- .github/workflows/python.yml | 3 + .github/workflows/ruby.yml | 3 + cpp/src/arrow/array/array_test.cc | 3 - cpp/src/arrow/filesystem/gcsfs_test.cc | 1 + cpp/src/arrow/filesystem/s3_test_util.cc | 1 + cpp/src/arrow/flight/test_util.cc | 3 +- cpp/src/gandiva/decimal_ir.cc | 13 +++-- cpp/src/gandiva/engine.cc | 8 +-- cpp/src/gandiva/engine_llvm_test.cc | 5 +- cpp/src/gandiva/llvm_generator.cc | 72 +++++++++++++++--------- cpp/src/gandiva/llvm_generator.h | 2 +- cpp/src/gandiva/llvm_includes.h | 13 ----- 13 files changed, 76 insertions(+), 61 deletions(-) diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 2642a6ec1a2..9c7e09382ab 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -56,7 +56,7 @@ jobs: name: ${{ matrix.title }} runs-on: ubuntu-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 60 + timeout-minutes: 75 strategy: fail-fast: false matrix: @@ -122,7 +122,7 @@ jobs: name: AMD64 macOS 11 C++ runs-on: macos-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 60 + timeout-minutes: 75 strategy: fail-fast: false env: @@ -179,7 +179,11 @@ jobs: key: cpp-ccache-macos-${{ hashFiles('cpp/**') }} restore-keys: cpp-ccache-macos- - name: Build - run: ci/scripts/cpp_build.sh $(pwd) $(pwd)/build + # use brew version of clang, to be consistent with LLVM lib, see ARROW-17790. + run: | + export CC=$(brew --prefix llvm)/bin/clang + export CXX=$(brew --prefix llvm)/bin/clang++ + ci/scripts/cpp_build.sh $(pwd) $(pwd)/build - name: Test shell: bash run: | diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 5ccbceabea7..1c72f1b706c 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -160,7 +160,10 @@ jobs: -r python/requirements-test.txt - name: Build shell: bash + # use brew version of clang, to be consistent with LLVM lib, see ARROW-17790. run: | + export CC=$(brew --prefix llvm)/bin/clang + export CXX=$(brew --prefix llvm)/bin/clang++ export PYTHON=python3 ci/scripts/cpp_build.sh $(pwd) $(pwd)/build ci/scripts/python_build.sh $(pwd) $(pwd)/build diff --git a/.github/workflows/ruby.yml b/.github/workflows/ruby.yml index 4dd61befab6..35cc1b1c29d 100644 --- a/.github/workflows/ruby.yml +++ b/.github/workflows/ruby.yml @@ -164,7 +164,10 @@ jobs: key: ruby-ccache-macos-${{ hashFiles('cpp/**') }} restore-keys: ruby-ccache-macos- - name: Build C++ + # use brew version of clang, to be consistent with LLVM lib, see ARROW-17790. 
run: | + export CC=$(brew --prefix llvm)/bin/clang + export CXX=$(brew --prefix llvm)/bin/clang++ ci/scripts/cpp_build.sh $(pwd) $(pwd)/build - name: Build GLib run: | diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index c00e54ecb80..d4ad1578b77 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -2831,8 +2831,6 @@ class DecimalTest : public ::testing::TestWithParam { auto type = std::make_shared(precision, 4); auto builder = std::make_shared(type); - size_t null_count = 0; - const size_t size = draw.size(); ARROW_EXPECT_OK(builder->Reserve(size)); @@ -2842,7 +2840,6 @@ class DecimalTest : public ::testing::TestWithParam { ARROW_EXPECT_OK(builder->Append(draw[i])); } else { ARROW_EXPECT_OK(builder->AppendNull()); - ++null_count; } } diff --git a/cpp/src/arrow/filesystem/gcsfs_test.cc b/cpp/src/arrow/filesystem/gcsfs_test.cc index 48d56f7b7bb..f64834e591b 100644 --- a/cpp/src/arrow/filesystem/gcsfs_test.cc +++ b/cpp/src/arrow/filesystem/gcsfs_test.cc @@ -17,6 +17,7 @@ #include // Missing include in boost/process +#define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805 // This boost/asio/io_context.hpp include is needless for no MinGW // build. // diff --git a/cpp/src/arrow/filesystem/s3_test_util.cc b/cpp/src/arrow/filesystem/s3_test_util.cc index f5a054a8efa..d7e0cbc92d7 100644 --- a/cpp/src/arrow/filesystem/s3_test_util.cc +++ b/cpp/src/arrow/filesystem/s3_test_util.cc @@ -34,6 +34,7 @@ #ifdef __MINGW32__ #include #endif +#define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805 // We need BOOST_USE_WINDOWS_H definition with MinGW when we use // boost/process.hpp. See BOOST_USE_WINDOWS_H=1 in // cpp/cmake_modules/ThirdpartyToolchain.cmake for details. diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index d858c15db6d..41e4dcaeddb 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -30,11 +30,12 @@ // We need Windows fixes before including Boost #include "arrow/util/windows_compatibility.h" +#include #include +#define BOOST_NO_CXX98_FUNCTION_BASE // ARROW-17805 // We need BOOST_USE_WINDOWS_H definition with MinGW when we use // boost/process.hpp. See BOOST_USE_WINDOWS_H=1 in // cpp/cmake_modules/ThirdpartyToolchain.cmake for details. -#include #include #include "arrow/array.h" diff --git a/cpp/src/gandiva/decimal_ir.cc b/cpp/src/gandiva/decimal_ir.cc index 5d5d30b4a75..b22e7ad5b5e 100644 --- a/cpp/src/gandiva/decimal_ir.cc +++ b/cpp/src/gandiva/decimal_ir.cc @@ -96,8 +96,9 @@ void DecimalIR::InitializeIntrinsics() { // CPP: return kScaleMultipliers[scale] llvm::Value* DecimalIR::GetScaleMultiplier(llvm::Value* scale) { auto const_array = module()->getGlobalVariable(kScaleMultipliersName); - auto ptr = CreateGEP(ir_builder(), const_array, {types()->i32_constant(0), scale}); - return CreateLoad(ir_builder(), ptr); + auto ptr = ir_builder()->CreateGEP(const_array->getValueType(), const_array, + {types()->i32_constant(0), scale}); + return ir_builder()->CreateLoad(types()->i128_type(), ptr); } // CPP: x <= y ? 
y : x @@ -248,8 +249,8 @@ llvm::Value* DecimalIR::AddLarge(const ValueFull& x, const ValueFull& y, ir_builder()->CreateCall(module()->getFunction("add_large_decimal128_decimal128"), args); - auto out_high = CreateLoad(ir_builder(), out_high_ptr); - auto out_low = CreateLoad(ir_builder(), out_low_ptr); + auto out_high = ir_builder()->CreateLoad(types()->i64_type(), out_high_ptr); + auto out_low = ir_builder()->CreateLoad(types()->i64_type(), out_low_ptr); auto sum = ValueSplit(out_high, out_low).AsInt128(this); ADD_TRACE_128("AddLarge : sum", sum); return sum; @@ -445,8 +446,8 @@ llvm::Value* DecimalIR::CallDecimalFunction(const std::string& function_name, // Make call to pre-compiled IR function. ir_builder()->CreateCall(module()->getFunction(function_name), dis_assembled_args); - auto out_high = CreateLoad(ir_builder(), out_high_ptr); - auto out_low = CreateLoad(ir_builder(), out_low_ptr); + auto out_high = ir_builder()->CreateLoad(i64, out_high_ptr); + auto out_low = ir_builder()->CreateLoad(i64, out_low_ptr); result = ValueSplit(out_high, out_low).AsInt128(this); } else { DCHECK_NE(return_type, types()->void_type()); diff --git a/cpp/src/gandiva/engine.cc b/cpp/src/gandiva/engine.cc index 3bd52917756..8d1057e9c57 100644 --- a/cpp/src/gandiva/engine.cc +++ b/cpp/src/gandiva/engine.cc @@ -234,11 +234,11 @@ Status Engine::LoadPreCompiledIR() { Status::CodeGenError("Could not load module from IR: ", buffer_or_error.getError().message())); - std::unique_ptr buffer = move(buffer_or_error.get()); + std::unique_ptr buffer = std::move(buffer_or_error.get()); /// Parse the IR module. llvm::Expected> module_or_error = - llvm::getOwningLazyBitcodeModule(move(buffer), *context()); + llvm::getOwningLazyBitcodeModule(std::move(buffer), *context()); if (!module_or_error) { // NOTE: llvm::handleAllErrors() fails linking with RTTI-disabled LLVM builds // (ARROW-5148) @@ -247,14 +247,14 @@ Status Engine::LoadPreCompiledIR() { stream << module_or_error.takeError(); return Status::CodeGenError(stream.str()); } - std::unique_ptr ir_module = move(module_or_error.get()); + std::unique_ptr ir_module = std::move(module_or_error.get()); // set dataLayout SetDataLayout(ir_module.get()); ARROW_RETURN_IF(llvm::verifyModule(*ir_module, &llvm::errs()), Status::CodeGenError("verify of IR Module failed")); - ARROW_RETURN_IF(llvm::Linker::linkModules(*module_, move(ir_module)), + ARROW_RETURN_IF(llvm::Linker::linkModules(*module_, std::move(ir_module)), Status::CodeGenError("failed to link IR Modules")); return Status::OK(); diff --git a/cpp/src/gandiva/engine_llvm_test.cc b/cpp/src/gandiva/engine_llvm_test.cc index 0bf6413cf65..9baaa82d2e0 100644 --- a/cpp/src/gandiva/engine_llvm_test.cc +++ b/cpp/src/gandiva/engine_llvm_test.cc @@ -80,8 +80,9 @@ class TestEngine : public ::testing::Test { loop_var->addIncoming(loop_update, loop_body); // get the current value - llvm::Value* offset = CreateGEP(builder, arg_elements, loop_var, "offset"); - llvm::Value* current_value = CreateLoad(builder, offset, "value"); + llvm::Value* offset = + builder->CreateGEP(types->i64_type(), arg_elements, loop_var, "offset"); + llvm::Value* current_value = builder->CreateLoad(types->i64_type(), offset, "value"); // setup sum PHI llvm::Value* sum_update = builder->CreateAdd(sum, current_value, "sum+ith"); diff --git a/cpp/src/gandiva/llvm_generator.cc b/cpp/src/gandiva/llvm_generator.cc index 58efef9676f..06159099745 100644 --- a/cpp/src/gandiva/llvm_generator.cc +++ b/cpp/src/gandiva/llvm_generator.cc @@ -26,6 +26,7 @@ #include "gandiva/dex.h" 
#include "gandiva/expr_decomposer.h" #include "gandiva/expression.h" +#include "gandiva/llvm_types.h" #include "gandiva/lvalue.h" namespace gandiva { @@ -162,18 +163,18 @@ Status LLVMGenerator::Execute(const arrow::RecordBatch& record_batch, return Status::OK(); } -llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, int idx, - const std::string& name) { +llvm::Value* LLVMGenerator::LoadVectorAtIndex(llvm::Value* arg_addrs, llvm::Type* type, + int idx, const std::string& name) { auto* idx_val = types()->i32_constant(idx); - auto* offset = CreateGEP(ir_builder(), arg_addrs, idx_val, name + "_mem_addr"); - return CreateLoad(ir_builder(), offset, name + "_mem"); + auto* offset = ir_builder()->CreateGEP(type, arg_addrs, idx_val, name + "_mem_addr"); + return ir_builder()->CreateLoad(type, offset, name + "_mem"); } /// Get reference to validity array at specified index in the args list. llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx, FieldPtr field) { const std::string& name = field->name(); - llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, types()->i64_type(), idx, name); return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), name + "_varray"); } @@ -181,7 +182,7 @@ llvm::Value* LLVMGenerator::GetValidityReference(llvm::Value* arg_addrs, int idx llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, int idx, FieldPtr field) { const std::string& name = field->name(); - llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, types()->i64_type(), idx, name); return ir_builder()->CreateIntToPtr(load, types()->i8_ptr_type(), name + "_buf_ptr"); } @@ -189,7 +190,7 @@ llvm::Value* LLVMGenerator::GetDataBufferPtrReference(llvm::Value* arg_addrs, in llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx, FieldPtr field) { const std::string& name = field->name(); - llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, types()->i64_type(), idx, name); llvm::Type* base_type = types()->DataVecType(field->type()); llvm::Value* ret; if (base_type->isPointerTy()) { @@ -205,13 +206,13 @@ llvm::Value* LLVMGenerator::GetDataReference(llvm::Value* arg_addrs, int idx, llvm::Value* LLVMGenerator::GetOffsetsReference(llvm::Value* arg_addrs, int idx, FieldPtr field) { const std::string& name = field->name(); - llvm::Value* load = LoadVectorAtIndex(arg_addrs, idx, name); + llvm::Value* load = LoadVectorAtIndex(arg_addrs, types()->i64_type(), idx, name); return ir_builder()->CreateIntToPtr(load, types()->i32_ptr_type(), name + "_oarray"); } /// Get reference to local bitmap array at specified index in the args list. 
llvm::Value* LLVMGenerator::GetLocalBitMapReference(llvm::Value* arg_bitmaps, int idx) { - llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, idx, ""); + llvm::Value* load = LoadVectorAtIndex(arg_bitmaps, types()->i64_type(), idx, ""); return ir_builder()->CreateIntToPtr(load, types()->i64_ptr_type(), std::to_string(idx) + "_lbmap"); } @@ -278,16 +279,21 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, arguments.push_back(types()->i64_ptr_type()); // offsets arguments.push_back(types()->i64_ptr_type()); // bitmaps arguments.push_back(types()->i64_ptr_type()); // holders + llvm::Type* selection_vector_type; switch (selection_vector_mode) { case SelectionVector::MODE_NONE: case SelectionVector::MODE_UINT16: arguments.push_back(types()->ptr_type(types()->i16_type())); + selection_vector_type = types()->i16_type(); break; case SelectionVector::MODE_UINT32: arguments.push_back(types()->i32_ptr_type()); + selection_vector_type = types()->i32_type(); break; case SelectionVector::MODE_UINT64: arguments.push_back(types()->i64_ptr_type()); + selection_vector_type = types()->i64_type(); + break; } arguments.push_back(types()->i64_type()); // ctx_ptr arguments.push_back(types()->i64_type()); // nrec @@ -338,8 +344,9 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, std::vector slice_offsets; for (int idx = 0; idx < buffer_count; idx++) { - auto offsetAddr = CreateGEP(builder, arg_addr_offsets, types()->i32_constant(idx)); - auto offset = CreateLoad(builder, offsetAddr); + auto offsetAddr = builder->CreateGEP(types()->i64_type(), arg_addr_offsets, + types()->i32_constant(idx)); + auto offset = builder->CreateLoad(types()->i64_type(), offsetAddr); slice_offsets.push_back(offset); } @@ -351,9 +358,11 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, llvm::Value* position_var = loop_var; if (selection_vector_mode != SelectionVector::MODE_NONE) { + auto selection_vector_addr = + builder->CreateGEP(selection_vector_type, arg_selection_vector, loop_var); position_var = builder->CreateIntCast( - CreateLoad(builder, CreateGEP(builder, arg_selection_vector, loop_var), - "uncasted_position_var"), + builder->CreateLoad(selection_vector_type, selection_vector_addr, + "uncasted_position_var"), types()->i64_type(), true, "position_var"); } @@ -378,7 +387,8 @@ Status LLVMGenerator::CodeGenExprValue(DexPtr value_expr, int buffer_count, SetPackedBitValue(output_ref, loop_var, output_value->data()); } else if (arrow::is_primitive(output_type_id) || output_type_id == arrow::Type::DECIMAL) { - llvm::Value* slot_offset = CreateGEP(builder, output_ref, loop_var); + auto slot_offset = + builder->CreateGEP(types()->IRType(output_type_id), output_ref, loop_var); builder->CreateStore(output_value->data(), slot_offset); } else if (arrow::is_binary_like(output_type_id)) { // Var-len output. Make a function call to populate the data. 
@@ -564,6 +574,7 @@ LLVMGenerator::Visitor::Visitor(LLVMGenerator* generator, llvm::Function* functi void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); + auto types = generator_->types(); llvm::Value* slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); llvm::Value* slot_index = builder->CreateAdd(loop_var_, GetSliceOffset(dex.DataIdx())); llvm::Value* slot_value; @@ -576,15 +587,16 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { break; case arrow::Type::DECIMAL: { - auto slot_offset = CreateGEP(builder, slot_ref, slot_index); - slot_value = CreateLoad(builder, slot_offset, dex.FieldName()); + auto slot_offset = builder->CreateGEP(types->i128_type(), slot_ref, slot_index); + slot_value = builder->CreateLoad(types->i128_type(), slot_offset, dex.FieldName()); lvalue = generator_->BuildDecimalLValue(slot_value, dex.FieldType()); break; } default: { - auto slot_offset = CreateGEP(builder, slot_ref, slot_index); - slot_value = CreateLoad(builder, slot_offset, dex.FieldName()); + auto type = types->IRType(dex.FieldType()->id()); + auto slot_offset = builder->CreateGEP(type, slot_ref, slot_index); + slot_value = builder->CreateLoad(type, slot_offset, dex.FieldName()); lvalue = std::make_shared(slot_value); break; } @@ -597,6 +609,7 @@ void LLVMGenerator::Visitor::Visit(const VectorReadFixedLenValueDex& dex) { void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { llvm::IRBuilder<>* builder = ir_builder(); llvm::Value* slot; + auto types = generator_->types(); // compute len from the offsets array. llvm::Value* offsets_slot_ref = @@ -605,14 +618,15 @@ void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { builder->CreateAdd(loop_var_, GetSliceOffset(dex.OffsetsIdx())); // => offset_start = offsets[loop_var] - slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index); - llvm::Value* offset_start = CreateLoad(builder, slot, "offset_start"); + slot = builder->CreateGEP(types->i32_type(), offsets_slot_ref, offsets_slot_index); + llvm::Value* offset_start = + builder->CreateLoad(types->i32_type(), slot, "offset_start"); // => offset_end = offsets[loop_var + 1] llvm::Value* offsets_slot_index_next = builder->CreateAdd( offsets_slot_index, generator_->types()->i64_constant(1), "loop_var+1"); - slot = CreateGEP(builder, offsets_slot_ref, offsets_slot_index_next); - llvm::Value* offset_end = CreateLoad(builder, slot, "offset_end"); + slot = builder->CreateGEP(types->i32_type(), offsets_slot_ref, offsets_slot_index_next); + auto offset_end = builder->CreateLoad(types->i32_type(), slot, "offset_end"); // => len_value = offset_end - offset_start llvm::Value* len_value = @@ -621,7 +635,7 @@ void LLVMGenerator::Visitor::Visit(const VectorReadVarLenValueDex& dex) { // get the data from the data array, at offset 'offset_start'. llvm::Value* data_slot_ref = GetBufferReference(dex.DataIdx(), kBufferTypeData, dex.Field()); - llvm::Value* data_value = CreateGEP(builder, data_slot_ref, offset_start); + auto data_value = builder->CreateGEP(types->i8_type(), data_slot_ref, offset_start); ADD_VISITOR_TRACE("visit var-len data vector " + dex.FieldName() + " len %T", len_value); result_.reset(new LValue(data_value, len_value)); @@ -831,7 +845,7 @@ void LLVMGenerator::Visitor::Visit(const NullableInternalFuncDex& dex) { result_ = BuildFunctionCall(native_function, arrow_return_type, ¶ms); // load the result validity and truncate to i1. 
- llvm::Value* result_valid_i8 = CreateLoad(builder, result_valid_ptr); + auto result_valid_i8 = builder->CreateLoad(types->i8_type(), result_valid_ptr); llvm::Value* result_valid = builder->CreateTrunc(result_valid_i8, types->i1_type()); // set validity bit in the local bitmap. @@ -1038,7 +1052,7 @@ void LLVMGenerator::Visitor::VisitInExpression(const InExprDexBase& dex) { builder->SetInsertPoint(entry_block_); llvm::Value* in_holder = generator_->LoadVectorAtIndex( - arg_holder_ptrs_, dex_instance.get_holder_idx(), "in_holder"); + arg_holder_ptrs_, types->i64_type(), dex_instance.get_holder_idx(), "in_holder"); builder->SetInsertPoint(saved_block); params.push_back(in_holder); @@ -1255,7 +1269,9 @@ LValuePtr LLVMGenerator::Visitor::BuildFunctionCall(const NativeFunction* func, ? decimalIR.CallDecimalFunction(func->pc_name(), llvm_return_type, *params) : generator_->AddFunctionCall(func->pc_name(), llvm_return_type, *params); auto value_len = - (result_len_ptr == nullptr) ? nullptr : CreateLoad(builder, result_len_ptr); + (result_len_ptr == nullptr) + ? nullptr + : builder->CreateLoad(result_len_ptr->getAllocatedType(), result_len_ptr); return std::make_shared(value, value_len); } } @@ -1278,8 +1294,8 @@ std::vector LLVMGenerator::Visitor::BuildParams( llvm::BasicBlock* saved_block = builder->GetInsertBlock(); builder->SetInsertPoint(entry_block_); - llvm::Value* holder = - generator_->LoadVectorAtIndex(arg_holder_ptrs_, holder_idx, "holder"); + auto holder = generator_->LoadVectorAtIndex( + arg_holder_ptrs_, generator_->types()->i64_type(), holder_idx, "holder"); builder->SetInsertPoint(saved_block); params.push_back(holder); diff --git a/cpp/src/gandiva/llvm_generator.h b/cpp/src/gandiva/llvm_generator.h index 693119128ea..fa13af74350 100644 --- a/cpp/src/gandiva/llvm_generator.h +++ b/cpp/src/gandiva/llvm_generator.h @@ -183,7 +183,7 @@ class GANDIVA_EXPORT LLVMGenerator { Status Add(const ExpressionPtr expr, const FieldDescriptorPtr output); /// Generate code to load the vector at specified index in the 'arg_addrs' array. - llvm::Value* LoadVectorAtIndex(llvm::Value* arg_addrs, int idx, + llvm::Value* LoadVectorAtIndex(llvm::Value* arg_addrs, llvm::Type* type, int idx, const std::string& name); /// Generate code to load the vector at specified index and cast it as bitmap. 
diff --git a/cpp/src/gandiva/llvm_includes.h b/cpp/src/gandiva/llvm_includes.h index 37f915eb571..3d455591895 100644 --- a/cpp/src/gandiva/llvm_includes.h +++ b/cpp/src/gandiva/llvm_includes.h @@ -41,16 +41,3 @@ #if defined(_MSC_VER) #pragma warning(pop) #endif - -// Workaround for deprecated builder methods as of LLVM 13: ARROW-14363 -inline llvm::Value* CreateGEP(llvm::IRBuilder<>* builder, llvm::Value* Ptr, - llvm::ArrayRef IdxList, - const llvm::Twine& Name = "") { - return builder->CreateGEP(Ptr->getType()->getScalarType()->getPointerElementType(), Ptr, - IdxList, Name); -} - -inline llvm::LoadInst* CreateLoad(llvm::IRBuilder<>* builder, llvm::Value* Ptr, - const llvm::Twine& Name = "") { - return builder->CreateLoad(Ptr->getType()->getPointerElementType(), Ptr, Name); -} From 4d3001b79daed6141edeb3b9691f553c468debf7 Mon Sep 17 00:00:00 2001 From: Yaron Gvili Date: Mon, 5 Sep 2022 02:52:37 -0400 Subject: [PATCH 126/133] ARROW-17610: [C++] Support additional source types in SourceNode --- cpp/src/arrow/compute/exec/options.h | 25 ++++ cpp/src/arrow/compute/exec/plan_test.cc | 142 ++++++++++++++++++++++ cpp/src/arrow/compute/exec/source_node.cc | 141 +++++++++++++++++++++ cpp/src/arrow/compute/exec/test_util.cc | 32 +++++ cpp/src/arrow/compute/exec/test_util.h | 12 ++ 5 files changed, 352 insertions(+) diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h index c5edc0610c5..97516d3f86e 100644 --- a/cpp/src/arrow/compute/exec/options.h +++ b/cpp/src/arrow/compute/exec/options.h @@ -77,6 +77,31 @@ class ARROW_EXPORT TableSourceNodeOptions : public ExecNodeOptions { int64_t max_batch_size; }; +/// \brief An extended Source node which accepts a schema +/// +/// ItMaker is a maker of an iterator of tabular data. +template +class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions { + public: + SchemaSourceNodeOptions(std::shared_ptr schema, ItMaker it_maker) + : schema(schema), it_maker(it_maker) {} + + // the schema of the record batches from the iterator + std::shared_ptr schema; + + // maker of an iterator which acts as the data source + ItMaker it_maker; +}; + +using ExecBatchIteratorMaker = std::function>()>; +using ExecBatchSourceNodeOptions = SchemaSourceNodeOptions; + +using RecordBatchIteratorMaker = std::function>()>; +using RecordBatchSourceNodeOptions = SchemaSourceNodeOptions; + +using ArrayVectorIteratorMaker = std::function>()>; +using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions; + /// \brief Make a node which excludes some rows from batches passed through it /// /// filter_expression will be evaluated against each batch which is pushed to diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc index 1dd071975ee..5834e813dd7 100644 --- a/cpp/src/arrow/compute/exec/plan_test.cc +++ b/cpp/src/arrow/compute/exec/plan_test.cc @@ -296,6 +296,148 @@ TEST(ExecPlanExecution, TableSourceSinkError) { Raises(StatusCode::Invalid, HasSubstr("batch_size > 0"))); } +TEST(ExecPlanExecution, ArrayVectorSourceSink) { + for (int num_threads : {1, 4}) { + ASSERT_OK_AND_ASSIGN(auto io_executor, + arrow::internal::ThreadPool::Make(num_threads)); + ExecContext exec_context(default_memory_pool(), io_executor.get()); + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context)); + AsyncGenerator> sink_gen; + + auto exp_batches = MakeBasicBatches(); + ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches)); + auto arrayvec_it_maker = [&arrayvecs]() { + return MakeVectorIterator>(arrayvecs); + }; 
+
+    ASSERT_OK(Declaration::Sequence(
+                  {
+                      {"array_source", ArrayVectorSourceNodeOptions{exp_batches.schema,
+                                                                    arrayvec_it_maker}},
+                      {"sink", SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+    ASSERT_EQ(res, exp_batches.batches);
+  }
+}
+
+TEST(ExecPlanExecution, ArrayVectorSourceSinkError) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  std::shared_ptr<Schema> no_schema;
+
+  auto exp_batches = MakeBasicBatches();
+  ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches));
+  auto arrayvec_it_maker = [&arrayvecs]() {
+    return MakeVectorIterator<std::shared_ptr<ArrayVector>>(arrayvecs);
+  };
+
+  auto null_executor_options =
+      ArrayVectorSourceNodeOptions{exp_batches.schema, arrayvec_it_maker};
+  ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_executor_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+
+  auto null_schema_options = ArrayVectorSourceNodeOptions{no_schema, arrayvec_it_maker};
+  ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_schema_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+}
+
+TEST(ExecPlanExecution, ExecBatchSourceSink) {
+  for (int num_threads : {1, 4}) {
+    ASSERT_OK_AND_ASSIGN(auto io_executor,
+                         arrow::internal::ThreadPool::Make(num_threads));
+    ExecContext exec_context(default_memory_pool(), io_executor.get());
+    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
+    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+    auto exp_batches = MakeBasicBatches();
+    ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches));
+    auto exec_batch_it_maker = [&exec_batches]() {
+      return MakeVectorIterator<std::shared_ptr<ExecBatch>>(exec_batches);
+    };
+
+    ASSERT_OK(Declaration::Sequence(
+                  {
+                      {"exec_source", ExecBatchSourceNodeOptions{exp_batches.schema,
+                                                                 exec_batch_it_maker}},
+                      {"sink", SinkNodeOptions{&sink_gen}},
+                  })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+    ASSERT_EQ(res, exp_batches.batches);
+  }
+}
+
+TEST(ExecPlanExecution, ExecBatchSourceSinkError) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  std::shared_ptr<Schema> no_schema;
+
+  auto exp_batches = MakeBasicBatches();
+  ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches));
+  auto exec_batch_it_maker = [&exec_batches]() {
+    return MakeVectorIterator<std::shared_ptr<ExecBatch>>(exec_batches);
+  };
+
+  auto null_executor_options =
+      ExecBatchSourceNodeOptions{exp_batches.schema, exec_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_executor_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+
+  auto null_schema_options = ExecBatchSourceNodeOptions{no_schema, exec_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_schema_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+}
+
+TEST(ExecPlanExecution, RecordBatchSourceSink) {
+  for (int num_threads : {1, 4}) {
+    ASSERT_OK_AND_ASSIGN(auto io_executor,
+                         arrow::internal::ThreadPool::Make(num_threads));
+    ExecContext exec_context(default_memory_pool(), io_executor.get());
+    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
+    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+
+    auto exp_batches = MakeBasicBatches();
+    ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches));
+    auto record_batch_it_maker = [&record_batches]() {
+      return MakeVectorIterator<std::shared_ptr<RecordBatch>>(record_batches);
+    };
+
+    ASSERT_OK(Declaration::Sequence({
+                                        {"record_source",
+                                         RecordBatchSourceNodeOptions{
+                                             exp_batches.schema, record_batch_it_maker}},
+                                        {"sink", SinkNodeOptions{&sink_gen}},
+                                    })
+                  .AddToPlan(plan.get()));
+
+    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
+    ASSERT_EQ(res, exp_batches.batches);
+  }
+}
+
+TEST(ExecPlanExecution, RecordBatchSourceSinkError) {
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
+  std::shared_ptr<Schema> no_schema;
+
+  auto exp_batches = MakeBasicBatches();
+  ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches));
+  auto record_batch_it_maker = [&record_batches]() {
+    return MakeVectorIterator<std::shared_ptr<RecordBatch>>(record_batches);
+  };
+
+  auto null_executor_options =
+      RecordBatchSourceNodeOptions{exp_batches.schema, record_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_executor_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+
+  auto null_schema_options =
+      RecordBatchSourceNodeOptions{no_schema, record_batch_it_maker};
+  ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_schema_options),
+              Raises(StatusCode::Invalid, HasSubstr("not null")));
+}
+
 TEST(ExecPlanExecution, SinkNodeBackpressure) {
   std::optional<ExecBatch> batch =
       ExecBatchFromJSON({int32(), boolean()},
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 1d51a5c1d28..e9244077e84 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -291,6 +291,144 @@ struct TableSourceNode : public SourceNode {
   }
 };
 
+template <typename This, typename Options>
+struct SchemaSourceNode : public SourceNode {
+  SchemaSourceNode(ExecPlan* plan, std::shared_ptr<Schema> schema,
+                   arrow::AsyncGenerator<util::optional<ExecBatch>> generator)
+      : SourceNode(plan, schema, generator) {}
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 0, This::kKindName));
+    const auto& cast_options = checked_cast<const Options&>(options);
+    auto& it_maker = cast_options.it_maker;
+    auto& schema = cast_options.schema;
+
+    auto io_executor = plan->exec_context()->executor();
+    auto it = it_maker();
+
+    RETURN_NOT_OK(ValidateSchemaSourceNodeInput(io_executor, schema, This::kKindName));
+    ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema));
+    return plan->EmplaceNode<This>(plan, schema, MakeOrderedGenerator(generator));
+  }
+
+  static arrow::Status ValidateSchemaSourceNodeInput(
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema,
+      const char* kKindName) {
+    if (schema == NULLPTR) {
+      return Status::Invalid(kKindName, " requires schema which is not null");
+    }
+    if (io_executor == NULLPTR) {
+      return Status::Invalid(kKindName, " requires IO-executor which is not null");
+    }
+
+    return Status::OK();
+  }
+
+  template <typename T>
+  static arrow::AsyncGenerator<util::optional<T>> MakeOrderedGenerator(
+      arrow::AsyncGenerator<util::optional<T>>& unordered_gen) {
+    using Enum = Enumerated<util::optional<T>>;
+    auto enum_gen = MakeEnumeratedGenerator(unordered_gen);
+    auto seq_gen = MakeSequencingGenerator(
+        enum_gen,
+        /*compare=*/[](const Enum& a, const Enum& b) { return a.index > b.index; },
+        /*is_next=*/[](const Enum& a, const Enum& b) { return a.index + 1 == b.index; },
+        /*initial_value=*/Enum{{}, 0, false});
+    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
+  }
+};
+
+struct RecordBatchSourceNode
+    : public SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions> {
+  using RecordBatchSchemaSourceNode =
+      SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions>;
+
+  using RecordBatchSchemaSourceNode::Make;
+  using RecordBatchSchemaSourceNode::RecordBatchSchemaSourceNode;
+
+  const char* kind_name() const override { return kKindName; }
+
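+  // Each concrete source implements MakeGenerator to adapt its item iterator
+  // into the AsyncGenerator of util::optional<ExecBatch> that SourceNode
+  // consumes; MakeBackgroundGenerator moves the iteration onto io_executor.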
+  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+      Iterator<std::shared_ptr<RecordBatch>>& batch_it,
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
+    auto to_exec_batch =
+        [schema](const std::shared_ptr<RecordBatch>& batch) -> util::optional<ExecBatch> {
+      if (batch == NULLPTR || *batch->schema() != *schema) {
+        return util::nullopt;
+      }
+      return util::optional<ExecBatch>(ExecBatch(*batch));
+    };
+    auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+  }
+
+  static const char kKindName[];
+};
+
+const char RecordBatchSourceNode::kKindName[] = "RecordBatchSourceNode";
+
+struct ExecBatchSourceNode
+    : public SchemaSourceNode<ExecBatchSourceNode, ExecBatchSourceNodeOptions> {
+  using ExecBatchSchemaSourceNode =
+      SchemaSourceNode<ExecBatchSourceNode, ExecBatchSourceNodeOptions>;
+
+  using ExecBatchSchemaSourceNode::ExecBatchSchemaSourceNode;
+  using ExecBatchSchemaSourceNode::Make;
+
+  const char* kind_name() const override { return kKindName; }
+
+  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+      Iterator<std::shared_ptr<ExecBatch>>& batch_it,
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
+    auto to_exec_batch =
+        [&schema](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
+      return batch == NULLPTR ? util::nullopt : util::optional<ExecBatch>(*batch);
+    };
+    auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+  }
+
+  static const char kKindName[];
+};
+
+const char ExecBatchSourceNode::kKindName[] = "ExecBatchSourceNode";
+
+struct ArrayVectorSourceNode
+    : public SchemaSourceNode<ArrayVectorSourceNode, ArrayVectorSourceNodeOptions> {
+  using ArrayVectorSchemaSourceNode =
+      SchemaSourceNode<ArrayVectorSourceNode, ArrayVectorSourceNodeOptions>;
+
+  using ArrayVectorSchemaSourceNode::ArrayVectorSchemaSourceNode;
+  using ArrayVectorSchemaSourceNode::Make;
+
+  const char* kind_name() const override { return kKindName; }
+
+  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+      Iterator<std::shared_ptr<ArrayVector>>& arrayvec_it,
+      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
+    auto to_exec_batch =
+        [&schema](
+            const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
+      if (arrayvec == NULLPTR || arrayvec->size() == 0) {
+        return util::nullopt;
+      }
+      std::vector<Datum> datumvec;
+      for (const auto& array : *arrayvec) {
+        datumvec.push_back(Datum(array));
+      }
+      return util::optional<ExecBatch>(
+          ExecBatch(std::move(datumvec), (*arrayvec)[0]->length()));
+    };
+    auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+  }
+
+  static const char kKindName[];
+};
+
+const char ArrayVectorSourceNode::kKindName[] = "ArrayVectorSourceNode";
+
 }  // namespace
 
 namespace internal {
@@ -298,6 +436,9 @@ namespace internal {
 void RegisterSourceNode(ExecFactoryRegistry* registry) {
   DCHECK_OK(registry->AddFactory("source", SourceNode::Make));
   DCHECK_OK(registry->AddFactory("table_source", TableSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("record_source", RecordBatchSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("exec_source", ExecBatchSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("array_source", ArrayVectorSourceNode::Make));
 }
 
 }  // namespace internal
diff --git a/cpp/src/arrow/compute/exec/test_util.cc b/cpp/src/arrow/compute/exec/test_util.cc
index efb91a708ab..7eede9b6bf8 100644
--- a/cpp/src/arrow/compute/exec/test_util.cc
+++ b/cpp/src/arrow/compute/exec/test_util.cc
@@ -258,6 +258,38 @@ BatchesWithSchema MakeBatchesFromString(const std::shared_ptr<Schema>& schema,
   return out_batches;
 }
 
+Result<std::vector<std::shared_ptr<ArrayVector>>> ToArrayVectors(
+    const BatchesWithSchema& batches_with_schema) {
+  std::vector<std::shared_ptr<ArrayVector>> arrayvecs;
+  for (auto batch : batches_with_schema.batches) {
+    ARROW_ASSIGN_OR_RAISE(auto record_batch,
+                          batch.ToRecordBatch(batches_with_schema.schema));
+    arrayvecs.push_back(std::make_shared<ArrayVector>(record_batch->columns()));
+  }
+  return arrayvecs;
+}
+
+Result<std::vector<std::shared_ptr<ExecBatch>>> ToExecBatches(
+    const BatchesWithSchema& batches_with_schema) {
+  std::vector<std::shared_ptr<ExecBatch>> exec_batches;
+  for (auto batch : batches_with_schema.batches) {
+    auto exec_batch = std::make_shared<ExecBatch>(batch);
+    exec_batches.push_back(exec_batch);
+  }
+  return exec_batches;
+}
+
+Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
+    const BatchesWithSchema& batches_with_schema) {
+  std::vector<std::shared_ptr<RecordBatch>> record_batches;
+  for (auto batch : batches_with_schema.batches) {
+    ARROW_ASSIGN_OR_RAISE(auto record_batch,
+                          batch.ToRecordBatch(batches_with_schema.schema));
+    record_batches.push_back(record_batch);
+  }
+  return record_batches;
+}
+
 Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab) {
   std::vector<SortKey> sort_keys;
   for (auto&& f : tab->schema()->fields()) {
diff --git a/cpp/src/arrow/compute/exec/test_util.h b/cpp/src/arrow/compute/exec/test_util.h
index ae7eac61e95..ccfd7d2431a 100644
--- a/cpp/src/arrow/compute/exec/test_util.h
+++ b/cpp/src/arrow/compute/exec/test_util.h
@@ -113,6 +113,18 @@ BatchesWithSchema MakeBatchesFromString(const std::shared_ptr<Schema>& schema,
                                         const std::vector<util::string_view>& json_strings,
                                         int multiplicity = 1);
 
+ARROW_TESTING_EXPORT
+Result<std::vector<std::shared_ptr<ArrayVector>>> ToArrayVectors(
+    const BatchesWithSchema& batches_with_schema);
+
+ARROW_TESTING_EXPORT
+Result<std::vector<std::shared_ptr<ExecBatch>>> ToExecBatches(
+    const BatchesWithSchema& batches);
+
+ARROW_TESTING_EXPORT
+Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
+    const BatchesWithSchema& batches);
+
 ARROW_TESTING_EXPORT
 Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab);

From 66a4a295ecb40d57f4fe15a2bfa4d654d368c841 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Mon, 5 Sep 2022 10:19:16 -0400
Subject: [PATCH 127/133] fix source ordering

---
 cpp/src/arrow/compute/exec/options.h      |  2 +-
 cpp/src/arrow/compute/exec/source_node.cc | 73 +++++++++++++++++------
 cpp/src/arrow/util/async_generator.h      |  2 +-
 3 files changed, 57 insertions(+), 20 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 97516d3f86e..7eabd3ef561 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -84,7 +84,7 @@ template <typename ItMaker>
 class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
  public:
   SchemaSourceNodeOptions(std::shared_ptr<Schema> schema, ItMaker it_maker)
-      : schema(schema), it_maker(it_maker) {}
+      : schema(schema), it_maker(std::move(it_maker)) {}
 
   // the schema of the record batches from the iterator
   std::shared_ptr<Schema> schema;
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index e9244077e84..5aa947a31eb 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -309,7 +309,7 @@ struct SchemaSourceNode : public SourceNode {
 
     RETURN_NOT_OK(ValidateSchemaSourceNodeInput(io_executor, schema, This::kKindName));
     ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema));
-    return plan->EmplaceNode<This>(plan, schema, MakeOrderedGenerator(generator));
+    return plan->EmplaceNode<This>(plan, schema, generator);
   }
 
   static arrow::Status ValidateSchemaSourceNodeInput(
@@ -326,16 +326,33 @@ struct SchemaSourceNode : public SourceNode {
   }
 
-  template <typename T>
-  static arrow::AsyncGenerator<util::optional<T>> MakeOrderedGenerator(
-      arrow::AsyncGenerator<util::optional<T>>& unordered_gen) {
-    using Enum = Enumerated<util::optional<T>>;
-    auto enum_gen = MakeEnumeratedGenerator(unordered_gen);
-    auto seq_gen = MakeSequencingGenerator(
-        enum_gen,
+  template <typename Item>
+  static Iterator<Enumerated<Item>> MakeEnumeratedIterator(Iterator<Item> it) {
+    struct {
+      int64_t index = 0;
+      Enumerated<Item> operator()(const Item& item) {
+        return Enumerated<Item>{item, index++, false};
+      }
+    } enumerator;
+    return MakeMapIterator(std::move(enumerator), std::move(it));
+  }
+
+  template <typename T>
+  static arrow::AsyncGenerator<T> MakeUnenumeratedGenerator(
+      const arrow::AsyncGenerator<Enumerated<T>>& enum_gen) {
+    using Enum = Enumerated<T>;
+    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
+  }
+
+  template <typename T>
+  static arrow::AsyncGenerator<T> MakeOrderedGenerator(
+      const arrow::AsyncGenerator<Enumerated<T>>& unordered_gen) {
+    using Enum = Enumerated<T>;
+    auto enum_gen = MakeSequencingGenerator(
+        unordered_gen,
         /*compare=*/[](const Enum& a, const Enum& b) { return a.index > b.index; },
         /*is_next=*/[](const Enum& a, const Enum& b) { return a.index + 1 == b.index; },
-        /*initial_value=*/Enum{{}, 0, false});
-    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
+        /*initial_value=*/Enum{{}, 0});
+    return MakeUnenumeratedGenerator(enum_gen);
   }
 };
 
@@ -344,9 +361,13 @@ struct RecordBatchSourceNode
   using RecordBatchSchemaSourceNode =
       SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions>;
 
-  using RecordBatchSchemaSourceNode::Make;
   using RecordBatchSchemaSourceNode::RecordBatchSchemaSourceNode;
 
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    return RecordBatchSchemaSourceNode::Make(plan, inputs, options);
+  }
+
   const char* kind_name() const override { return kKindName; }
 
   static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
       Iterator<std::shared_ptr<RecordBatch>>& batch_it,
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
         [schema](const std::shared_ptr<RecordBatch>& batch) -> util::optional<ExecBatch> {
       if (batch == NULLPTR || *batch->schema() != *schema) {
         return util::nullopt;
       }
       return util::optional<ExecBatch>(ExecBatch(*batch));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
-    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
+    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
+                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
+    return MakeUnenumeratedGenerator(std::move(enum_gen));
   }
 
   static const char kKindName[];
@@ -374,7 +398,11 @@ struct ExecBatchSourceNode
       SchemaSourceNode<ExecBatchSourceNode, ExecBatchSourceNodeOptions>;
 
   using ExecBatchSchemaSourceNode::ExecBatchSchemaSourceNode;
-  using ExecBatchSchemaSourceNode::Make;
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    return ExecBatchSchemaSourceNode::Make(plan, inputs, options);
+  }
 
   const char* kind_name() const override { return kKindName; }
 
@@ -382,11 +410,14 @@ struct ExecBatchSourceNode
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
-        [&schema](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
+        [](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
       return batch == NULLPTR ? util::nullopt : util::optional<ExecBatch>(*batch);
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
-    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
+    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
+                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
+    return MakeUnenumeratedGenerator(std::move(enum_gen));
   }
 
   static const char kKindName[];
@@ -400,7 +431,11 @@ struct ArrayVectorSourceNode
       SchemaSourceNode<ArrayVectorSourceNode, ArrayVectorSourceNodeOptions>;
 
   using ArrayVectorSchemaSourceNode::ArrayVectorSchemaSourceNode;
-  using ArrayVectorSchemaSourceNode::Make;
+
+  static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
+                                const ExecNodeOptions& options) {
+    return ArrayVectorSchemaSourceNode::Make(plan, inputs, options);
+  }
 
   const char* kind_name() const override { return kKindName; }
 
@@ -408,8 +443,7 @@ struct ArrayVectorSourceNode
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
-        [&schema](
-            const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
+        [](const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
       if (arrayvec == NULLPTR || arrayvec->size() == 0) {
         return util::nullopt;
       }
@@ -421,7 +455,10 @@ struct ArrayVectorSourceNode
           ExecBatch(std::move(datumvec), (*arrayvec)[0]->length()));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it));
-    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
+    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
+    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
+                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
+    return MakeUnenumeratedGenerator(std::move(enum_gen));
   }
 
   static const char kKindName[];
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index 0d51208ac72..d64532ff53a 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -1499,7 +1499,7 @@ AsyncGenerator<T> MakeConcatenatedGenerator(AsyncGenerator<AsyncGenerator<T>> so
 template <typename T>
 struct Enumerated {
   T value;
-  int index;
+  int64_t index;
   bool last;
 };

From e11b9bdb0f3adc6573fe0e1192facf82b197aa8d Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Mon, 5 Sep 2022 15:47:09 -0400
Subject: [PATCH 128/133] add doc strings

---
 cpp/src/arrow/compute/exec/options.h | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 7eabd3ef561..08cd91eb926 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -94,11 +94,14 @@ class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
 };
 
 using ExecBatchIteratorMaker = std::function<Iterator<std::shared_ptr<ExecBatch>>()>;
+/// \brief An extended Source node which accepts a schema and exec-batches
 using ExecBatchSourceNodeOptions = SchemaSourceNodeOptions<ExecBatchIteratorMaker>;
 
 using RecordBatchIteratorMaker = std::function<Iterator<std::shared_ptr<RecordBatch>>()>;
+/// \brief An extended Source node which accepts a schema and record-batches
 using RecordBatchSourceNodeOptions = SchemaSourceNodeOptions<RecordBatchIteratorMaker>;
 
+/// \brief An extended Source node which accepts a schema and array-vectors
 using ArrayVectorIteratorMaker = std::function<Iterator<std::shared_ptr<ArrayVector>>()>;
 using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions<ArrayVectorIteratorMaker>;

From 81d34458df648feb7f69ee654c0cc6317d9678e6 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Tue, 6 Sep 2022 03:29:40 -0400
Subject: [PATCH 129/133] fix test failure

---
 cpp/src/arrow/compute/exec/source_node.cc | 5 ++++-
 cpp/src/arrow/util/async_generator.h      | 2 +-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 5aa947a31eb..6fc6e92b96d 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -327,8 +327,11 @@ struct SchemaSourceNode : public SourceNode {
 
   template <typename Item>
   static Iterator<Enumerated<Item>> MakeEnumeratedIterator(Iterator<Item> it) {
+    // TODO: Should Enumerated<>.index be changed to int64_t? Currently, this change
+    // causes dataset unit-test failures
+    using index_t = decltype(Enumerated<Item>{}.index);
     struct {
-      int64_t index = 0;
+      index_t index = 0;
       Enumerated<Item> operator()(const Item& item) {
         return Enumerated<Item>{item, index++, false};
       }
diff --git a/cpp/src/arrow/util/async_generator.h b/cpp/src/arrow/util/async_generator.h
index d64532ff53a..0d51208ac72 100644
--- a/cpp/src/arrow/util/async_generator.h
+++ b/cpp/src/arrow/util/async_generator.h
@@ -1499,7 +1499,7 @@ AsyncGenerator<T> MakeConcatenatedGenerator(AsyncGenerator<AsyncGenerator<T>> so
 template <typename T>
 struct Enumerated {
   T value;
-  int64_t index;
+  int index;
   bool last;
 };

From 8e3c91b0a6b416d47e892c96ca2c5b6c26d84618 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Sat, 10 Sep 2022 17:09:30 -0400
Subject: [PATCH 130/133] remove ordering

---
 cpp/src/arrow/compute/exec/plan_test.cc   | 12 +++---
 cpp/src/arrow/compute/exec/source_node.cc | 48 ++---------------------
 2 files changed, 9 insertions(+), 51 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 5834e813dd7..1e036a0fe2a 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -318,8 +318,8 @@ TEST(ExecPlanExecution, ArrayVectorSourceSink) {
                   .AddToPlan(plan.get()));
 
-    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
-    ASSERT_EQ(res, exp_batches.batches);
+    ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+                Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
   }
 }
 
@@ -365,8 +365,8 @@ TEST(ExecPlanExecution, ExecBatchSourceSink) {
                   .AddToPlan(plan.get()));
 
-    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
-    ASSERT_EQ(res, exp_batches.batches);
+    ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+                Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
   }
 }
 
@@ -412,8 +412,8 @@ TEST(ExecPlanExecution, RecordBatchSourceSink) {
                   .AddToPlan(plan.get()));
 
-    ASSERT_FINISHES_OK_AND_ASSIGN(auto res, StartAndCollect(plan.get(), sink_gen));
-    ASSERT_EQ(res, exp_batches.batches);
+    ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+                Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
   }
 }
 
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 6fc6e92b96d..94f8bb2b5a2 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -324,39 +324,6 @@ struct SchemaSourceNode : public SourceNode {
 
     return Status::OK();
   }
-
-  template <typename Item>
-  static Iterator<Enumerated<Item>> MakeEnumeratedIterator(Iterator<Item> it) {
-    // TODO: Should Enumerated<>.index be changed to int64_t? Currently, this change
-    // causes dataset unit-test failures
-    using index_t = decltype(Enumerated<Item>{}.index);
-    struct {
-      index_t index = 0;
-      Enumerated<Item> operator()(const Item& item) {
-        return Enumerated<Item>{item, index++, false};
-      }
-    } enumerator;
-    return MakeMapIterator(std::move(enumerator), std::move(it));
-  }
-
-  template <typename T>
-  static arrow::AsyncGenerator<T> MakeUnenumeratedGenerator(
-      const arrow::AsyncGenerator<Enumerated<T>>& enum_gen) {
-    using Enum = Enumerated<T>;
-    return MakeMappedGenerator(enum_gen, [](const Enum& e) { return e.value; });
-  }
-
-  template <typename T>
-  static arrow::AsyncGenerator<T> MakeOrderedGenerator(
-      const arrow::AsyncGenerator<Enumerated<T>>& unordered_gen) {
-    using Enum = Enumerated<T>;
-    auto enum_gen = MakeSequencingGenerator(
-        unordered_gen,
-        /*compare=*/[](const Enum& a, const Enum& b) { return a.index > b.index; },
-        /*is_next=*/[](const Enum& a, const Enum& b) { return a.index + 1 == b.index; },
-        /*initial_value=*/Enum{{}, 0});
-    return MakeUnenumeratedGenerator(enum_gen);
-  }
 };
 
 struct RecordBatchSourceNode
@@ -384,10 +351,7 @@ struct RecordBatchSourceNode
       return util::optional<ExecBatch>(ExecBatch(*batch));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
-    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
-    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
-                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
-    return MakeUnenumeratedGenerator(std::move(enum_gen));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
   }
 
   static const char kKindName[];
@@ -417,10 +381,7 @@ struct ExecBatchSourceNode
       return batch == NULLPTR ? util::nullopt : util::optional<ExecBatch>(*batch);
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
-    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
-    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
-                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
-    return MakeUnenumeratedGenerator(std::move(enum_gen));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
   }
 
   static const char kKindName[];
@@ -458,10 +419,7 @@ struct ArrayVectorSourceNode
           ExecBatch(std::move(datumvec), (*arrayvec)[0]->length()));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it));
-    auto enum_it = MakeEnumeratedIterator(std::move(exec_batch_it));
-    ARROW_ASSIGN_OR_RAISE(auto enum_gen,
-                          MakeBackgroundGenerator(std::move(enum_it), io_executor));
-    return MakeUnenumeratedGenerator(std::move(enum_gen));
+    return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
   }
 
   static const char kKindName[];

From 3f98c8670a5ec6912188aac088f9216593b832f4 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Sun, 11 Sep 2022 02:50:10 -0400
Subject: [PATCH 131/133] support RecordBatchReader maker

---
 cpp/src/arrow/compute/exec/options.h | 9 +++++----
 cpp/src/arrow/record_batch.h         | 2 +-
 2 files changed, 6 insertions(+), 5 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/options.h b/cpp/src/arrow/compute/exec/options.h
index 08cd91eb926..657356de8e5 100644
--- a/cpp/src/arrow/compute/exec/options.h
+++ b/cpp/src/arrow/compute/exec/options.h
@@ -27,6 +27,7 @@
 #include "arrow/compute/api_vector.h"
 #include "arrow/compute/exec.h"
 #include "arrow/compute/exec/expression.h"
+#include "arrow/record_batch.h"
 #include "arrow/result.h"
 #include "arrow/util/async_generator.h"
 #include "arrow/util/async_util.h"
@@ -93,6 +94,10 @@ class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
   ItMaker it_maker;
 };
 
+/// \brief An extended Source node which accepts a schema and array-vectors
+using ArrayVectorIteratorMaker = std::function<Iterator<std::shared_ptr<ArrayVector>>()>;
+using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions<ArrayVectorIteratorMaker>;
+
 using ExecBatchIteratorMaker = std::function<Iterator<std::shared_ptr<ExecBatch>>()>;
 /// \brief An extended Source node which accepts a schema and exec-batches
 using ExecBatchSourceNodeOptions = SchemaSourceNodeOptions<ExecBatchIteratorMaker>;
@@ -101,10 +106,6 @@ using RecordBatchIteratorMaker = std::function<Iterator<std::shared_ptr<RecordBatch>>()>;
 /// \brief An extended Source node which accepts a schema and record-batches
 using RecordBatchSourceNodeOptions = SchemaSourceNodeOptions<RecordBatchIteratorMaker>;
 
-/// \brief An extended Source node which accepts a schema and array-vectors
-using ArrayVectorIteratorMaker = std::function<Iterator<std::shared_ptr<ArrayVector>>()>;
-using ArrayVectorSourceNodeOptions = SchemaSourceNodeOptions<ArrayVectorIteratorMaker>;
-
 /// \brief Make a node which excludes some rows from batches passed through it
 ///
 /// filter_expression will be evaluated against each batch which is pushed to
diff --git a/cpp/src/arrow/record_batch.h b/cpp/src/arrow/record_batch.h
index 8bc70322560..32c8e5fa795 100644
--- a/cpp/src/arrow/record_batch.h
+++ b/cpp/src/arrow/record_batch.h
@@ -217,7 +217,7 @@ struct ARROW_EXPORT RecordBatchWithMetadata {
 };
 
 /// \brief Abstract interface for reading stream of record batches
-class ARROW_EXPORT RecordBatchReader {
+class ARROW_EXPORT RecordBatchReader : public Iterator<std::shared_ptr<RecordBatch>> {
  public:
   using ValueType = std::shared_ptr<RecordBatch>;

From 6e57054507b230adfd510b276d150dcf44d7cd98 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Wed, 21 Sep 2022 16:34:45 -0400
Subject: [PATCH 132/133] requested fixes

---
 cpp/src/arrow/compute/exec/plan_test.cc   | 169 +++++++---------------
 cpp/src/arrow/compute/exec/source_node.cc |  22 +--
 2 files changed, 61 insertions(+), 130 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 1e036a0fe2a..8f0cae4a2ee 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -296,146 +296,83 @@ TEST(ExecPlanExecution, TableSourceSinkError) {
               Raises(StatusCode::Invalid, HasSubstr("batch_size > 0")));
 }
 
-TEST(ExecPlanExecution, ArrayVectorSourceSink) {
-  for (int num_threads : {1, 4}) {
-    ASSERT_OK_AND_ASSIGN(auto io_executor,
-                         arrow::internal::ThreadPool::Make(num_threads));
-    ExecContext exec_context(default_memory_pool(), io_executor.get());
-    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
-
-    auto exp_batches = MakeBasicBatches();
-    ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches));
-    auto arrayvec_it_maker = [&arrayvecs]() {
-      return MakeVectorIterator<std::shared_ptr<ArrayVector>>(arrayvecs);
-    };
-
-    ASSERT_OK(Declaration::Sequence(
-                  {
-                      {"array_source", ArrayVectorSourceNodeOptions{exp_batches.schema,
-                                                                    arrayvec_it_maker}},
-                      {"sink", SinkNodeOptions{&sink_gen}},
-                  })
-                  .AddToPlan(plan.get()));
-
-    ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-                Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
-  }
-}
-
-TEST(ExecPlanExecution, ArrayVectorSourceSinkError) {
+template <typename ElementType, typename OptionsType>
+void test_source_sink_error(
+    std::string source_factory_name,
+    std::function<Result<std::vector<ElementType>>(const BatchesWithSchema&)>
+        to_elements) {
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
   std::shared_ptr<Schema> no_schema;
 
   auto exp_batches = MakeBasicBatches();
-  ASSERT_OK_AND_ASSIGN(auto arrayvecs, ToArrayVectors(exp_batches));
-  auto arrayvec_it_maker = [&arrayvecs]() {
-    return MakeVectorIterator<std::shared_ptr<ArrayVector>>(arrayvecs);
+  ASSERT_OK_AND_ASSIGN(auto elements, to_elements(exp_batches));
+  auto element_it_maker = [&elements]() {
+    return MakeVectorIterator(elements);
   };
 
-  auto null_executor_options =
-      ArrayVectorSourceNodeOptions{exp_batches.schema, arrayvec_it_maker};
-  ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_executor_options),
-              Raises(StatusCode::Invalid, HasSubstr("not null")));
+  auto null_executor_options = OptionsType{exp_batches.schema, element_it_maker};
+  ASSERT_OK(MakeExecNode(source_factory_name, plan.get(), {}, null_executor_options));
 
-  auto null_schema_options = ArrayVectorSourceNodeOptions{no_schema, arrayvec_it_maker};
-  ASSERT_THAT(MakeExecNode("array_source", plan.get(), {}, null_schema_options),
+  auto null_schema_options = OptionsType{no_schema, element_it_maker};
+  ASSERT_THAT(MakeExecNode(source_factory_name, plan.get(), {}, null_schema_options),
               Raises(StatusCode::Invalid, HasSubstr("not null")));
 }
 
-TEST(ExecPlanExecution, ExecBatchSourceSink) {
-  for (int num_threads : {1, 4}) {
-    ASSERT_OK_AND_ASSIGN(auto io_executor,
-                         arrow::internal::ThreadPool::Make(num_threads));
-    ExecContext exec_context(default_memory_pool(), io_executor.get());
-    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+template <typename ElementType, typename OptionsType>
+void test_source_sink(
+    std::string source_factory_name,
+    std::function<Result<std::vector<ElementType>>(const BatchesWithSchema&)>
+        to_elements) {
+  ASSERT_OK_AND_ASSIGN(auto io_executor, arrow::internal::ThreadPool::Make(1));
+  ExecContext exec_context(default_memory_pool(), io_executor.get());
+  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
+  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
 
-    auto exp_batches = MakeBasicBatches();
-    ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches));
-    auto exec_batch_it_maker = [&exec_batches]() {
-      return MakeVectorIterator<std::shared_ptr<ExecBatch>>(exec_batches);
-    };
+  auto exp_batches = MakeBasicBatches();
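+  // to_elements converts the expected batches into the element type this
+  // source consumes (ArrayVector, ExecBatch, or RecordBatch).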
+  ASSERT_OK_AND_ASSIGN(auto elements, to_elements(exp_batches));
+  auto element_it_maker = [&elements]() {
+    return MakeVectorIterator(elements);
+  };
 
-    ASSERT_OK(Declaration::Sequence(
-                  {
-                      {"exec_source", ExecBatchSourceNodeOptions{exp_batches.schema,
-                                                                 exec_batch_it_maker}},
-                      {"sink", SinkNodeOptions{&sink_gen}},
-                  })
-                  .AddToPlan(plan.get()));
+  ASSERT_OK(Declaration::Sequence({
+                                      {source_factory_name,
+                                       OptionsType{exp_batches.schema, element_it_maker}},
+                                      {"sink", SinkNodeOptions{&sink_gen}},
+                                  })
+                .AddToPlan(plan.get()));
 
-    ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-                Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
-  }
+  ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
+              Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
 }
 
-TEST(ExecPlanExecution, ExecBatchSourceSinkError) {
-  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  std::shared_ptr<Schema> no_schema;
+TEST(ExecPlanExecution, ArrayVectorSourceSink) {
+  test_source_sink<std::shared_ptr<ArrayVector>, ArrayVectorSourceNodeOptions>(
+      "array_vector_source", ToArrayVectors);
+}
 
-  auto exp_batches = MakeBasicBatches();
-  ASSERT_OK_AND_ASSIGN(auto exec_batches, ToExecBatches(exp_batches));
-  auto exec_batch_it_maker = [&exec_batches]() {
-    return MakeVectorIterator<std::shared_ptr<ExecBatch>>(exec_batches);
-  };
+TEST(ExecPlanExecution, ArrayVectorSourceSinkError) {
+  test_source_sink_error<std::shared_ptr<ArrayVector>, ArrayVectorSourceNodeOptions>(
+      "array_vector_source", ToArrayVectors);
+}
 
-  auto null_executor_options =
-      ExecBatchSourceNodeOptions{exp_batches.schema, exec_batch_it_maker};
-  ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_executor_options),
-              Raises(StatusCode::Invalid, HasSubstr("not null")));
+TEST(ExecPlanExecution, ExecBatchSourceSink) {
+  test_source_sink<std::shared_ptr<ExecBatch>, ExecBatchSourceNodeOptions>(
+      "exec_batch_source", ToExecBatches);
+}
 
-  auto null_schema_options = ExecBatchSourceNodeOptions{no_schema, exec_batch_it_maker};
-  ASSERT_THAT(MakeExecNode("exec_source", plan.get(), {}, null_schema_options),
-              Raises(StatusCode::Invalid, HasSubstr("not null")));
+TEST(ExecPlanExecution, ExecBatchSourceSinkError) {
+  test_source_sink_error<std::shared_ptr<ExecBatch>, ExecBatchSourceNodeOptions>(
+      "exec_batch_source", ToExecBatches);
 }
 
 TEST(ExecPlanExecution, RecordBatchSourceSink) {
-  for (int num_threads : {1, 4}) {
-    ASSERT_OK_AND_ASSIGN(auto io_executor,
-                         arrow::internal::ThreadPool::Make(num_threads));
-    ExecContext exec_context(default_memory_pool(), io_executor.get());
-    ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
-    AsyncGenerator<util::optional<ExecBatch>> sink_gen;
-
-    auto exp_batches = MakeBasicBatches();
-    ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches));
-    auto record_batch_it_maker = [&record_batches]() {
-      return MakeVectorIterator<std::shared_ptr<RecordBatch>>(record_batches);
-    };
-
-    ASSERT_OK(Declaration::Sequence({
-                                        {"record_source",
-                                         RecordBatchSourceNodeOptions{
-                                             exp_batches.schema, record_batch_it_maker}},
-                                        {"sink", SinkNodeOptions{&sink_gen}},
-                                    })
-                  .AddToPlan(plan.get()));
-
-    ASSERT_THAT(StartAndCollect(plan.get(), sink_gen),
-                Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
-  }
+  test_source_sink<std::shared_ptr<RecordBatch>, RecordBatchSourceNodeOptions>(
+      "record_batch_source", ToRecordBatches);
 }
 
 TEST(ExecPlanExecution, RecordBatchSourceSinkError) {
-  ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
-  std::shared_ptr<Schema> no_schema;
-
-  auto exp_batches = MakeBasicBatches();
-  ASSERT_OK_AND_ASSIGN(auto record_batches, ToRecordBatches(exp_batches));
-  auto record_batch_it_maker = [&record_batches]() {
-    return MakeVectorIterator<std::shared_ptr<RecordBatch>>(record_batches);
-  };
-
-  auto null_executor_options =
-      RecordBatchSourceNodeOptions{exp_batches.schema, record_batch_it_maker};
-  ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_executor_options),
-              Raises(StatusCode::Invalid, HasSubstr("not null")));
-
-  auto null_schema_options =
-      RecordBatchSourceNodeOptions{no_schema, record_batch_it_maker};
-  ASSERT_THAT(MakeExecNode("record_source", plan.get(), {}, null_schema_options),
-              Raises(StatusCode::Invalid, HasSubstr("not null")));
+  test_source_sink_error<std::shared_ptr<RecordBatch>, RecordBatchSourceNodeOptions>(
+      "record_batch_source", ToRecordBatches);
 }
 
 TEST(ExecPlanExecution, SinkNodeBackpressure) {
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 94f8bb2b5a2..097b7d8f1d2 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -25,6 +25,7 @@
 #include "arrow/compute/exec/util.h"
 #include "arrow/compute/exec_internal.h"
 #include "arrow/datum.h"
+#include "arrow/io/util_internal.h"
 #include "arrow/result.h"
 #include "arrow/table.h"
 #include "arrow/util/async_generator.h"
@@ -307,22 +308,15 @@ struct SchemaSourceNode : public SourceNode {
 
     auto io_executor = plan->exec_context()->executor();
     auto it = it_maker();
 
-    RETURN_NOT_OK(ValidateSchemaSourceNodeInput(io_executor, schema, This::kKindName));
-    ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema));
-    return plan->EmplaceNode<This>(plan, schema, generator);
-  }
-
-  static arrow::Status ValidateSchemaSourceNodeInput(
-      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema,
-      const char* kKindName) {
     if (schema == NULLPTR) {
-      return Status::Invalid(kKindName, " requires schema which is not null");
+      return Status::Invalid(This::kKindName, " requires schema which is not null");
    }
     if (io_executor == NULLPTR) {
-      return Status::Invalid(kKindName, " requires IO-executor which is not null");
+      io_executor = io::internal::GetIOThreadPool();
     }
 
-    return Status::OK();
+    ARROW_ASSIGN_OR_RAISE(auto generator, This::MakeGenerator(it, io_executor, schema));
+    return plan->EmplaceNode<This>(plan, schema, generator);
   }
 };
 
@@ -434,9 +428,9 @@ namespace internal {
 void RegisterSourceNode(ExecFactoryRegistry* registry) {
   DCHECK_OK(registry->AddFactory("source", SourceNode::Make));
   DCHECK_OK(registry->AddFactory("table_source", TableSourceNode::Make));
-  DCHECK_OK(registry->AddFactory("record_source", RecordBatchSourceNode::Make));
-  DCHECK_OK(registry->AddFactory("exec_source", ExecBatchSourceNode::Make));
-  DCHECK_OK(registry->AddFactory("array_source", ArrayVectorSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("record_batch_source", RecordBatchSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("exec_batch_source", ExecBatchSourceNode::Make));
+  DCHECK_OK(registry->AddFactory("array_vector_source", ArrayVectorSourceNode::Make));
 }
 
 }  // namespace internal

From 541e53d653347e603ba6f5e99881f8e3faa08dc8 Mon Sep 17 00:00:00 2001
From: Yaron Gvili
Date: Thu, 22 Sep 2022 05:56:46 -0400
Subject: [PATCH 133/133] rebase

---
 cpp/src/arrow/compute/exec/plan_test.cc   |  2 +-
 cpp/src/arrow/compute/exec/source_node.cc | 24 +++++++++++------------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/cpp/src/arrow/compute/exec/plan_test.cc b/cpp/src/arrow/compute/exec/plan_test.cc
index 8f0cae4a2ee..87afbca7bf1 100644
--- a/cpp/src/arrow/compute/exec/plan_test.cc
+++ b/cpp/src/arrow/compute/exec/plan_test.cc
@@ -326,7 +326,7 @@ void test_source_sink(
   ASSERT_OK_AND_ASSIGN(auto io_executor, arrow::internal::ThreadPool::Make(1));
   ExecContext exec_context(default_memory_pool(), io_executor.get());
   ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context));
-  AsyncGenerator<util::optional<ExecBatch>> sink_gen;
+  AsyncGenerator<std::optional<ExecBatch>> sink_gen;
 
   auto exp_batches = MakeBasicBatches();
   ASSERT_OK_AND_ASSIGN(auto elements, to_elements(exp_batches));
diff --git a/cpp/src/arrow/compute/exec/source_node.cc b/cpp/src/arrow/compute/exec/source_node.cc
index 097b7d8f1d2..c1183f66060 100644
--- a/cpp/src/arrow/compute/exec/source_node.cc
+++ b/cpp/src/arrow/compute/exec/source_node.cc
@@ -295,7 +295,7 @@ struct TableSourceNode : public SourceNode {
 template <typename This, typename Options>
 struct SchemaSourceNode : public SourceNode {
   SchemaSourceNode(ExecPlan* plan, std::shared_ptr<Schema> schema,
-                   arrow::AsyncGenerator<util::optional<ExecBatch>> generator)
+                   arrow::AsyncGenerator<std::optional<ExecBatch>> generator)
       : SourceNode(plan, schema, generator) {}
 
   static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
@@ -334,15 +334,15 @@ struct RecordBatchSourceNode
 
   const char* kind_name() const override { return kKindName; }
 
-  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+  static Result<arrow::AsyncGenerator<std::optional<ExecBatch>>> MakeGenerator(
       Iterator<std::shared_ptr<RecordBatch>>& batch_it,
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
-        [schema](const std::shared_ptr<RecordBatch>& batch) -> util::optional<ExecBatch> {
+        [schema](const std::shared_ptr<RecordBatch>& batch) -> std::optional<ExecBatch> {
       if (batch == NULLPTR || *batch->schema() != *schema) {
-        return util::nullopt;
+        return std::nullopt;
       }
-      return util::optional<ExecBatch>(ExecBatch(*batch));
+      return std::optional<ExecBatch>(ExecBatch(*batch));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
     return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
@@ -367,12 +367,12 @@ struct ExecBatchSourceNode
 
   const char* kind_name() const override { return kKindName; }
 
-  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+  static Result<arrow::AsyncGenerator<std::optional<ExecBatch>>> MakeGenerator(
      Iterator<std::shared_ptr<ExecBatch>>& batch_it,
      arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
    auto to_exec_batch =
-        [](const std::shared_ptr<ExecBatch>& batch) -> util::optional<ExecBatch> {
-      return batch == NULLPTR ? util::nullopt : util::optional<ExecBatch>(*batch);
+        [](const std::shared_ptr<ExecBatch>& batch) -> std::optional<ExecBatch> {
+      return batch == NULLPTR ? std::nullopt : std::optional<ExecBatch>(*batch);
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
     return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
@@ -397,19 +397,19 @@ struct ArrayVectorSourceNode
 
   const char* kind_name() const override { return kKindName; }
 
-  static Result<arrow::AsyncGenerator<util::optional<ExecBatch>>> MakeGenerator(
+  static Result<arrow::AsyncGenerator<std::optional<ExecBatch>>> MakeGenerator(
       Iterator<std::shared_ptr<ArrayVector>>& arrayvec_it,
       arrow::internal::Executor* io_executor, const std::shared_ptr<Schema>& schema) {
     auto to_exec_batch =
-        [](const std::shared_ptr<ArrayVector>& arrayvec) -> util::optional<ExecBatch> {
+        [](const std::shared_ptr<ArrayVector>& arrayvec) -> std::optional<ExecBatch> {
       if (arrayvec == NULLPTR || arrayvec->size() == 0) {
-        return util::nullopt;
+        return std::nullopt;
       }
       std::vector<Datum> datumvec;
       for (const auto& array : *arrayvec) {
        datumvec.push_back(Datum(array));
       }
-      return util::optional<ExecBatch>(
+      return std::optional<ExecBatch>(
           ExecBatch(std::move(datumvec), (*arrayvec)[0]->length()));
     };
     auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(arrayvec_it));