diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index d7e433f4844..402bf1aad98 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -384,6 +384,7 @@ if(ARROW_COMPUTE) compute/function_internal.cc compute/kernel.cc compute/registry.cc + compute/memory_resources.cc compute/kernels/aggregate_basic.cc compute/kernels/aggregate_mode.cc compute/kernels/aggregate_quantile.cc @@ -424,7 +425,9 @@ if(ARROW_COMPUTE) compute/exec/util.cc compute/exec/hash_join.cc compute/exec/hash_join_node.cc - compute/exec/task_util.cc) + compute/exec/task_util.cc + compute/exec/data_holder_node.cc + ) append_avx2_src(compute/kernels/aggregate_basic_avx2.cc) append_avx512_src(compute/kernels/aggregate_basic_avx512.cc) diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 50f1ad4fd0b..4a518551db6 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -34,6 +34,7 @@ #include "arrow/compute/exec_internal.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" +#include "arrow/compute/memory_resources.h" #include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/pretty_print.h" @@ -1015,9 +1016,13 @@ std::unique_ptr KernelExecutor::MakeScalarAggregate() { } // namespace detail ExecContext::ExecContext(MemoryPool* pool, ::arrow::internal::Executor* executor, - FunctionRegistry* func_registry) + FunctionRegistry* func_registry, + MemoryResources* memory_resources) : pool_(pool), executor_(executor) { this->func_registry_ = func_registry == nullptr ? GetFunctionRegistry() : func_registry; + + this->memory_resources_ = + memory_resources == nullptr ? GetMemoryResources(pool) : memory_resources; } CpuInfo* ExecContext::cpu_info() const { return CpuInfo::GetInstance(); } diff --git a/cpp/src/arrow/compute/exec.h b/cpp/src/arrow/compute/exec.h index 7707622bc53..4117c21b370 100644 --- a/cpp/src/arrow/compute/exec.h +++ b/cpp/src/arrow/compute/exec.h @@ -29,6 +29,7 @@ #include "arrow/array/data.h" #include "arrow/compute/exec/expression.h" +#include "arrow/compute/memory_resources.h" #include "arrow/datum.h" #include "arrow/memory_pool.h" #include "arrow/result.h" @@ -62,7 +63,8 @@ class ARROW_EXPORT ExecContext { // If no function registry passed, the default is used. explicit ExecContext(MemoryPool* pool = default_memory_pool(), ::arrow::internal::Executor* executor = NULLPTR, - FunctionRegistry* func_registry = NULLPTR); + FunctionRegistry* func_registry = NULLPTR, + MemoryResources* memory_resources = NULLPTR); /// \brief The MemoryPool used for allocations, default is /// default_memory_pool(). @@ -78,6 +80,11 @@ class ARROW_EXPORT ExecContext { /// registry provided by GetFunctionRegistry. FunctionRegistry* func_registry() const { return func_registry_; } + /// \brief The MemoryResources for looking up memory resources by memory level + /// and getting data holders to enable out of core processing. Defaults to the + /// instance provided by GetMemoryResources. + MemoryResources* memory_resources() const { return memory_resources_; } + // \brief Set maximum length unit of work for kernel execution. Larger // contiguous array inputs will be split into smaller chunks, and, if // possible and enabled, processed in parallel. The default chunksize is @@ -124,6 +131,7 @@ class ARROW_EXPORT ExecContext { int64_t exec_chunksize_ = std::numeric_limits::max(); bool preallocate_contiguous_ = true; bool use_threads_ = true; + MemoryResources* memory_resources_; }; ARROW_EXPORT ExecContext* default_exec_context(); diff --git a/cpp/src/arrow/compute/exec/CMakeLists.txt b/cpp/src/arrow/compute/exec/CMakeLists.txt index ccc36c093e8..e72281f903d 100644 --- a/cpp/src/arrow/compute/exec/CMakeLists.txt +++ b/cpp/src/arrow/compute/exec/CMakeLists.txt @@ -27,6 +27,7 @@ add_arrow_compute_test(expression_test add_arrow_compute_test(plan_test PREFIX "arrow-compute") add_arrow_compute_test(hash_join_node_test PREFIX "arrow-compute") add_arrow_compute_test(union_node_test PREFIX "arrow-compute") +add_arrow_compute_test(data_holder_node_test PREFIX "arrow-compute") add_arrow_compute_test(util_test PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/exec/data_holder_node.cc b/cpp/src/arrow/compute/exec/data_holder_node.cc new file mode 100644 index 00000000000..ea396ff8dfe --- /dev/null +++ b/cpp/src/arrow/compute/exec/data_holder_node.cc @@ -0,0 +1,220 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include "arrow/api.h" +#include "arrow/compute/api.h" + +#include "arrow/compute/memory_resources.h" +#include "arrow/util/async_generator.h" +#include "arrow/util/logging.h" + +#include "arrow/compute/exec.h" +#include "arrow/compute/exec/exec_plan.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/future.h" +#include "arrow/util/make_unique.h" +#include "arrow/util/thread_pool.h" + +namespace arrow { + +using internal::checked_cast; + +namespace compute { + +class DataHolderManager { + public: + explicit DataHolderManager(ExecContext* context) + : context_(context), gen_(), producer_(gen_.producer()) {} + + Status Push(const std::shared_ptr& batch) { + bool pushed = false; + auto resources = context_->memory_resources(); + for (auto memory_resource : resources->memory_resources()) { + auto memory_used = memory_resource->memory_used(); + if (memory_used < memory_resource->memory_limit()) { + ARROW_ASSIGN_OR_RAISE(auto data_holder, memory_resource->GetDataHolder(batch)); + this->producer_.Push(std::move(data_holder)); + pushed = true; + break; + } + } + if (!pushed) { + return Status::Invalid("No memory resource registered at all in the exec_context"); + } + return Status::OK(); + } + AsyncGenerator> generator() { return gen_; } + + public: + ExecContext* context_; + PushGenerator> gen_; + PushGenerator>::Producer producer_; +}; + +class DataHolderNode : public ExecNode { + public: + DataHolderNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, + std::shared_ptr output_schema, int num_outputs) + : ExecNode(plan, std::move(inputs), input_labels, std::move(output_schema), + /*num_outputs=*/num_outputs) { + executor_ = plan->exec_context()->executor(); + + data_holder_manager_ = + ::arrow::internal::make_unique(plan->exec_context()); + + auto status = task_group_.AddTask([this]() -> Result> { + ARROW_DCHECK(executor_ != nullptr); + return executor_->Submit(this->stop_source_.token(), [this] { + auto generator = this->data_holder_manager_->generator(); + auto iterator = MakeGeneratorIterator(std::move(generator)); + while (true) { + ARROW_ASSIGN_OR_RAISE(auto result, iterator.Next()); + if (IsIterationEnd(result)) { + break; + } + ARROW_ASSIGN_OR_RAISE(ExecBatch batch, result->Get()); + this->outputs_[0]->InputReceived(this, batch); + } + return Status::OK(); + }); + }); + if (!status.ok()) { + if (input_counter_.Cancel()) { + this->Finish(status); + } + inputs_[0]->StopProducing(this); + } + } + + void ErrorReceived(ExecNode* input, Status error) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->ErrorReceived(this, std::move(error)); + } + + void InputFinished(ExecNode* input, int total_batches) override { + DCHECK_EQ(input, inputs_[0]); + outputs_[0]->InputFinished(this, total_batches); + if (input_counter_.SetTotal(total_batches)) { + this->Finish(); + } + } + + static Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + auto schema = inputs[0]->output_schema(); + return plan->EmplaceNode(plan, std::move(inputs), + std::vector{"target"}, + std::move(schema), /*num_outputs=*/1); + } + + const char* kind_name() const override { return "DataHolderNode"; } + + void InputReceived(ExecNode* input, ExecBatch batch) override { + if (finished_.is_finished()) { + return; + } + auto status = task_group_.AddTask([this, batch]() -> Result> { + return this->executor_->Submit(this->stop_source_.token(), [this, batch]() { + auto pool = this->plan()->exec_context()->memory_pool(); + ARROW_ASSIGN_OR_RAISE(auto record_batch, + batch.ToRecordBatch(this->output_schema(), pool)); + Status status = data_holder_manager_->Push(record_batch); + if (ErrorIfNotOk(status)) { + return status; + } + if (this->input_counter_.Increment()) { + this->Finish(status); + } + return Status::OK(); + }); + }); + if (!status.ok()) { + if (input_counter_.Cancel()) { + this->Finish(status); + } + inputs_[0]->StopProducing(this); + return; + } + } + + Status StartProducing() override { return Status::OK(); } + + void PauseProducing(ExecNode* output) override {} + + void ResumeProducing(ExecNode* output) override {} + + void StopProducing(ExecNode* output) override { + DCHECK_EQ(output, outputs_[0]); + StopProducing(); + } + + void StopProducing() override { + if (executor_) { + this->stop_source_.RequestStop(); + } + if (input_counter_.Cancel()) { + this->Finish(); + } + inputs_[0]->StopProducing(this); + } + + Future<> finished() override { return finished_; } + + std::string ToStringExtra() const override { return ""; } + + protected: + void Finish(Status finish_st = Status::OK()) { + this->data_holder_manager_->producer_.Close(); + + task_group_.End().AddCallback([this, finish_st](const Status& st) { + Status final_status = finish_st & st; + this->finished_.MarkFinished(final_status); + }); + } + + protected: + // Counter for the number of batches received + AtomicCounter input_counter_; + + // Future to sync finished + Future<> finished_ = Future<>::Make(); + + // The task group for the corresponding batches + util::AsyncTaskGroup task_group_; + + ::arrow::internal::Executor* executor_; + + // Variable used to cancel remaining tasks in the executor + StopSource stop_source_; + + std::unique_ptr data_holder_manager_; +}; + +namespace internal { + +void RegisterDataHolderNode(ExecFactoryRegistry* registry) { + DCHECK_OK(registry->AddFactory("data_holder", DataHolderNode::Make)); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/data_holder_node_test.cc b/cpp/src/arrow/compute/exec/data_holder_node_test.cc new file mode 100644 index 00000000000..db51cc494b1 --- /dev/null +++ b/cpp/src/arrow/compute/exec/data_holder_node_test.cc @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/api.h" +#include "arrow/compute/exec/options.h" +#include "arrow/compute/exec/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/matchers.h" +#include "arrow/testing/random.h" + +using testing::UnorderedElementsAreArray; + +namespace arrow { +namespace compute { + +struct TestDataHolderNode : public ::testing::Test { + static constexpr int kNumBatches = 10; + + TestDataHolderNode() : rng_(0) {} + + std::shared_ptr GenerateRandomSchema(size_t num_inputs) { + static std::vector> some_arrow_types = { + arrow::null(), arrow::boolean(), arrow::int8(), arrow::int16(), + arrow::int32(), arrow::int64(), arrow::float16(), arrow::float32(), + arrow::float64(), arrow::utf8(), arrow::binary(), arrow::date32()}; + + std::vector> fields(num_inputs); + std::default_random_engine gen(42); + std::uniform_int_distribution types_dist( + 0, static_cast(some_arrow_types.size()) - 1); + for (size_t i = 0; i < num_inputs; i++) { + int random_index = types_dist(gen); + auto col_type = some_arrow_types.at(random_index); + fields[i] = + field("column_" + std::to_string(i) + "_" + col_type->ToString(), col_type); + } + return schema(fields); + } + + void GenerateBatchesFromSchema(const std::shared_ptr& schema, + size_t num_batches, BatchesWithSchema* out_batches, + int multiplicity = 1, int64_t batch_size = 4) { + if (num_batches == 0) { + auto empty_record_batch = ExecBatch(*rng_.BatchOf(schema->fields(), 0)); + out_batches->batches.push_back(empty_record_batch); + } else { + for (size_t j = 0; j < num_batches; j++) { + out_batches->batches.push_back( + ExecBatch(*rng_.BatchOf(schema->fields(), batch_size))); + } + } + + size_t batch_count = out_batches->batches.size(); + for (int repeat = 1; repeat < multiplicity; ++repeat) { + for (size_t i = 0; i < batch_count; ++i) { + out_batches->batches.push_back(out_batches->batches[i]); + } + } + out_batches->schema = schema; + } + + void CheckRunOutput(const std::vector& batches, + const BatchesWithSchema& exp_batches) { + ExecContext exec_context(default_memory_pool(), + ::arrow::internal::GetCpuThreadPool()); + + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_context)); + + Declaration union_decl{"union", ExecNodeOptions{}}; + + for (const auto& batch : batches) { + union_decl.inputs.emplace_back(Declaration{ + "source", SourceNodeOptions{batch.schema, batch.gen(/*parallel=*/true, + /*slow=*/false)}}); + } + AsyncGenerator> sink_gen; + + if (batches.size() == 0) { + ASSERT_RAISES(Invalid, Declaration::Sequence({union_decl, + {"data_holder", ExecNodeOptions{}}, + {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + return; + } else { + ASSERT_OK(Declaration::Sequence({union_decl, + {"data_holder", ExecNodeOptions{}}, + {"sink", SinkNodeOptions{&sink_gen}}}) + .AddToPlan(plan.get())); + } + Future> actual = StartAndCollect(plan.get(), sink_gen); + + auto expected_matcher = + Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))); + ASSERT_THAT(actual, expected_matcher); + } + + void CheckDataHolderExecNode(size_t num_input_nodes, size_t num_batches) { + auto random_schema = GenerateRandomSchema(num_input_nodes); + + std::vector> all_record_batches; + std::vector input_batches(num_input_nodes); + BatchesWithSchema exp_batches; + exp_batches.schema = random_schema; + for (size_t i = 0; i < num_input_nodes; i++) { + GenerateBatchesFromSchema(random_schema, num_batches, &input_batches[i]); + for (const auto& batch : input_batches[i].batches) { + exp_batches.batches.push_back(batch); + } + } + CheckRunOutput(input_batches, exp_batches); + } + + ::arrow::random::RandomArrayGenerator rng_; +}; + +TEST_F(TestDataHolderNode, TestNonEmpty) { + for (int64_t num_input_nodes : {1, 2, 4, 8}) { + this->CheckDataHolderExecNode(num_input_nodes, kNumBatches); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index 7e7824d8524..8ea098ceafa 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -462,6 +462,7 @@ void RegisterUnionNode(ExecFactoryRegistry*); void RegisterAggregateNode(ExecFactoryRegistry*); void RegisterSinkNode(ExecFactoryRegistry*); void RegisterHashJoinNode(ExecFactoryRegistry*); +void RegisterDataHolderNode(ExecFactoryRegistry*); } // namespace internal @@ -476,6 +477,7 @@ ExecFactoryRegistry* default_exec_factory_registry() { internal::RegisterAggregateNode(this); internal::RegisterSinkNode(this); internal::RegisterHashJoinNode(this); + internal::RegisterDataHolderNode(this); } Result GetFactory(const std::string& factory_name) override { diff --git a/cpp/src/arrow/compute/memory_resources.cc b/cpp/src/arrow/compute/memory_resources.cc new file mode 100644 index 00000000000..20e195fafbc --- /dev/null +++ b/cpp/src/arrow/compute/memory_resources.cc @@ -0,0 +1,270 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/memory_resources.h" +#include "arrow/compute/exec.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/util/logging.h" + +#include +#include +#include +#include + +#include +#include +#include +#include +#include "arrow/io/file.h" + +#ifdef __APPLE__ +#include +#include +#endif + +#ifdef __linux__ +#include +#include +#endif + +// Windows APIs +#include "arrow/util/windows_compatibility.h" + +namespace arrow { + +namespace compute { + +std::string MemoryLevelName(MemoryLevel memory_level) { + static const char* MemoryLevelNames[] = {ARROW_STRINGIFY(MemoryLevel::kDiskLevel), + ARROW_STRINGIFY(MemoryLevel::kCpuLevel), + ARROW_STRINGIFY(MemoryLevel::kGpuLevel)}; + + return MemoryLevelNames[static_cast(memory_level)]; +} + +std::string MemoryResource::ToString() const { return MemoryLevelName(memory_level_); } + +class CPUDataHolder : public DataHolder { + public: + explicit CPUDataHolder(const std::shared_ptr& record_batch) + : DataHolder(MemoryLevel::kCpuLevel), record_batch_(std::move(record_batch)) {} + + Result Get() override { return ExecBatch(*record_batch_); } + + private: + std::shared_ptr record_batch_; +}; + +namespace { + +std::string RandomString(std::size_t length) { + const std::string characters = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + std::random_device random_device; + std::mt19937 generator(random_device()); + std::uniform_int_distribution<> distribution(0, characters.size() - 1); + std::string random_string; + for (std::size_t i = 0; i < length; ++i) { + random_string += characters[distribution(generator)]; + } + return random_string; +} + +} // namespace + +Status StoreRecordBatch(const std::shared_ptr& record_batch, + const std::shared_ptr& filesystem, + const std::string& file_path) { + auto output = filesystem->OpenOutputStream(file_path).ValueOrDie(); + auto writer = + arrow::ipc::MakeFileWriter(output.get(), record_batch->schema()).ValueOrDie(); + ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch)); + return writer->Close(); +} +Result> RecoverRecordBatch( + const std::shared_ptr& filesystem, const std::string& file_path) { + ARROW_ASSIGN_OR_RAISE(auto input, filesystem->OpenInputFile(file_path)); + ARROW_ASSIGN_OR_RAISE(auto reader, arrow::ipc::feather::Reader::Open(input)); + std::shared_ptr table; + ARROW_RETURN_NOT_OK(reader->Read(&table)); + TableBatchReader batch_iter(*table); + ARROW_ASSIGN_OR_RAISE(auto batch, batch_iter.Next()); + return batch; +} + +class DiskDataHolder : public DataHolder { + public: + DiskDataHolder(const std::shared_ptr& record_batch, + MemoryPool* memory_pool) + : DataHolder(MemoryLevel::kDiskLevel), memory_pool_(memory_pool) { + std::string root_path; + std::string file_name = "data-holder-temp-" + RandomString(64) + ".feather"; + + filesystem_ = + arrow::fs::FileSystemFromUri(cache_storage_root_path, &root_path).ValueOrDie(); + + file_path_ = root_path + file_name; + status_ = StoreRecordBatch(record_batch, filesystem_, file_path_); + } + + Result Get() override { + ARROW_RETURN_NOT_OK(status_); + ARROW_ASSIGN_OR_RAISE(auto record_batch, RecoverRecordBatch(filesystem_, file_path_)); + return ExecBatch(*record_batch); + } + + private: + std::string file_path_; + Status status_; + MemoryPool* memory_pool_; + std::shared_ptr filesystem_; + const std::string cache_storage_root_path = "file:///tmp/"; +}; + +MemoryResources::~MemoryResources() {} + +std::unique_ptr MemoryResources::Make() { + return std::unique_ptr(new MemoryResources()); +} + +Status MemoryResources::AddMemoryResource(std::shared_ptr resource) { + auto level = static_cast(resource->memory_level()); + if (stats_[level] != nullptr) { + return Status::KeyError("Already have a resource type registered with name: ", + resource->ToString()); + } + stats_[level] = std::move(resource); + return Status::OK(); +} + +size_t MemoryResources::size() const { return stats_.size(); } + +Result MemoryResources::memory_resource(MemoryLevel memory_level) const { + auto level = static_cast(memory_level); + if (stats_[level] == nullptr) { + return Status::KeyError("No memory resource registered with level: ", + MemoryLevelName(memory_level)); + } + return stats_[level].get(); +} + +std::vector MemoryResources::memory_resources() const { + std::vector arr; + for (auto&& resource : stats_) { + if (resource != nullptr) { + arr.push_back(resource.get()); + } + } + return arr; +} + +namespace { + +size_t GetTotalMemorySize() { +#ifdef __APPLE__ + int mib[2]; + size_t physical_memory; + size_t length; + // Get the Physical memory size + mib[0] = CTL_HW; + mib[1] = HW_MEMSIZE; + length = sizeof(size_t); + sysctl(mib, 2, &physical_memory, &length, NULL, 0); + return physical_memory; +#elif defined(_MSC_VER) + MEMORYSTATUSEX status; + status.dwLength = sizeof(status); + GlobalMemoryStatusEx(&status); + return status.ullTotalPhys; +#else // Linux + struct sysinfo si; + sysinfo(&si); + return (size_t)si.freeram; +#endif +} + +struct CPUMemoryResource : public MemoryResource { + CPUMemoryResource(arrow::MemoryPool* pool, float memory_limit_threshold = 0.75) + : MemoryResource(MemoryLevel::kCpuLevel), pool_(pool) { + total_memory_size_ = GetTotalMemorySize(); + memory_limit_ = memory_limit_threshold * total_memory_size_; + } + + int64_t memory_used() override { return pool_->bytes_allocated(); } + + int64_t memory_limit() override { return memory_limit_; } + + Result> GetDataHolder( + const std::shared_ptr& batch) override { + auto data_holder = std::make_shared(batch); + return data_holder; + } + + private: + arrow::MemoryPool* pool_; + int64_t memory_limit_; + int64_t total_memory_size_; +}; + +class DiskMemoryResource : public MemoryResource { + public: + DiskMemoryResource(arrow::MemoryPool* pool) + : MemoryResource(MemoryLevel::kDiskLevel), pool_(pool) { + memory_used_ = 0; + memory_limit_ = std::numeric_limits::max(); + } + + int64_t memory_limit() override { return memory_limit_; } + + int64_t memory_used() override { return memory_used_; } + + Result> GetDataHolder( + const std::shared_ptr& batch) override { + auto data_holder = std::make_shared(batch, pool_); + return data_holder; + } + + private: + int64_t memory_used_; + int64_t memory_limit_; + arrow::MemoryPool* pool_; +}; + +static std::unique_ptr CreateBuiltInMemoryResources(MemoryPool* pool) { + auto resources = MemoryResources::Make(); + + // CPU MemoryLevel + auto cpu_level = std::make_shared(pool); + DCHECK_OK(resources->AddMemoryResource(std::move(cpu_level))); + + // Disk MemoryLevel + auto disk_level = std::make_shared(pool); + DCHECK_OK(resources->AddMemoryResource(std::move(disk_level))); + + return resources; +} + +} // namespace + +MemoryResources* GetMemoryResources(MemoryPool* pool) { + static auto resources = CreateBuiltInMemoryResources(pool); + return resources.get(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/memory_resources.h b/cpp/src/arrow/compute/memory_resources.h new file mode 100644 index 00000000000..88e9902e977 --- /dev/null +++ b/cpp/src/arrow/compute/memory_resources.h @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "arrow/memory_pool.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/util/macros.h" + +#include +#include +#include +#include +#include + +namespace arrow { + +namespace compute { + +struct ExecBatch; + +enum class MemoryLevel : int { kGpuLevel, kCpuLevel, kDiskLevel, kNumLevels }; + +class ARROW_EXPORT DataHolder { + public: + explicit DataHolder(MemoryLevel memory_level) : memory_level_(memory_level) {} + + MemoryLevel memory_level() const { return memory_level_; }; + + virtual Result Get() = 0; + + private: + MemoryLevel memory_level_; +}; + +class ARROW_EXPORT MemoryResource { + public: + explicit MemoryResource(MemoryLevel memory_level) : memory_level_(memory_level) {} + + virtual ~MemoryResource() = default; + + MemoryLevel memory_level() const { return memory_level_; } + + std::string ToString() const; + + virtual int64_t memory_limit() = 0; + + virtual int64_t memory_used() = 0; + + virtual Result> GetDataHolder( + const std::shared_ptr& batch) = 0; + + private: + MemoryLevel memory_level_; +}; + +class ARROW_EXPORT MemoryResources { + public: + ~MemoryResources(); + + static std::unique_ptr Make(); + + Status AddMemoryResource(std::shared_ptr resource); + + size_t size() const; + + Result memory_resource(MemoryLevel level) const; + + std::vector memory_resources() const; + + private: + MemoryResources() {} + + private: + std::array, + static_cast(MemoryLevel::kNumLevels)> + stats_ = {}; +}; + +ARROW_EXPORT MemoryResources* GetMemoryResources(MemoryPool* pool); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/type_fwd.h b/cpp/src/arrow/compute/type_fwd.h index 127929ced58..659e2220823 100644 --- a/cpp/src/arrow/compute/type_fwd.h +++ b/cpp/src/arrow/compute/type_fwd.h @@ -42,6 +42,7 @@ struct KernelState; class Expression; class ExecNode; +class DataHolder; class ExecPlan; class ExecNodeOptions; class ExecFactoryRegistry;