// Patch summary (preamble text placed before the first `diff --git` header).
//
// Files touched: arrow/chunked_array.h, arrow/compute/exec.cc,
// arrow/compute/exec/expression.cc, arrow/compute/exec_internal.h,
// arrow/compute/function.cc, arrow/datum.cc, arrow/datum.h, arrow/table.h.
//
// 1. Accessors return by const reference instead of by value:
//    ChunkedArray::type(), Table::schema(), Datum::type(), Datum::schema().
//    Datum's "no type" / "no schema" fall-through paths now return a reference
//    to a function-local static empty pointer, since a reference cannot be
//    bound to a `nullptr` temporary. The returned reference is only valid
//    while the object it was obtained from is alive; callers that need to
//    keep the pointer should copy it.
//
// 2. KernelExecutor gains a pure virtual CheckResultType(out, function_name).
//    The KernelExecutorImpl override returns Status::TypeError when `out` has
//    a non-null type that does not Equals() `output_descr_.type`, and
//    Status::OK() otherwise (including when `out` has no type).
//
// 3. ExecuteScalarExpression and Function::Execute keep the DatumAccumulator
//    on the stack instead of allocating it with make_shared, and run
//    CheckResultType through DCHECK_OK inside `#ifndef NDEBUG`.
//
// Review change relative to the submitted patch: the local `out` in both call
// sites is declared non-const. Both functions return `Result`, so a const
// local would have to be copied into the return value; a non-const local is
// eligible for the implicit move on `return out;`.
//
// NOTE(review): angle-bracketed template arguments (e.g. `std::shared_ptr&`,
// `Result>`) and the original line breaks appear to have been lost in this
// copy of the patch - restore them from the upstream change before applying.
diff --git a/cpp/src/arrow/chunked_array.h b/cpp/src/arrow/chunked_array.h index 2ace045c2bf..86d9b2b51fe 100644 --- a/cpp/src/arrow/chunked_array.h +++ b/cpp/src/arrow/chunked_array.h @@ -128,7 +128,7 @@ class ARROW_EXPORT ChunkedArray { /// there are zero chunks Result> View(const std::shared_ptr& type) const; - std::shared_ptr type() const { return type_; } + const std::shared_ptr& type() const { return type_; } /// \brief Determine if two chunked arrays are equal. /// diff --git a/cpp/src/arrow/compute/exec.cc b/cpp/src/arrow/compute/exec.cc index 7d6db9f58db..70093f30be6 100644 --- a/cpp/src/arrow/compute/exec.cc +++ b/cpp/src/arrow/compute/exec.cc @@ -594,6 +594,16 @@ class KernelExecutorImpl : public KernelExecutor { return out; } + Status CheckResultType(const Datum& out, const char* function_name) override { + const auto& type = out.type(); + if (type != nullptr && !type->Equals(output_descr_.type)) { + return Status::TypeError( + "kernel type result mismatch for function '", function_name, "': declared as ", + output_descr_.type->ToString(), ", actual is ", type->ToString()); + } + return Status::OK(); + } + ExecContext* exec_context() { return kernel_ctx_->exec_context(); } KernelState* state() { return kernel_ctx_->state(); } diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 4aab64a46a4..67a9f3c40ff 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -540,9 +540,13 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i auto options = call->options.get(); RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, descrs, options})); - auto listener = std::make_shared(); - RETURN_NOT_OK(executor->Execute(arguments, listener.get())); - return executor->WrapResults(arguments, listener->values()); + compute::detail::DatumAccumulator listener; + RETURN_NOT_OK(executor->Execute(arguments, &listener)); + auto out = 
executor->WrapResults(arguments, listener.values()); +#ifndef NDEBUG + DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str())); +#endif + return out; } namespace { diff --git a/cpp/src/arrow/compute/exec_internal.h b/cpp/src/arrow/compute/exec_internal.h index 55daa243cd3..74124f02267 100644 --- a/cpp/src/arrow/compute/exec_internal.h +++ b/cpp/src/arrow/compute/exec_internal.h @@ -120,6 +120,9 @@ class ARROW_EXPORT KernelExecutor { virtual Datum WrapResults(const std::vector& args, const std::vector& outputs) = 0; + /// \brief Check the actual result type against the resolved output type + virtual Status CheckResultType(const Datum& out, const char* function_name) = 0; + static std::unique_ptr MakeScalar(); static std::unique_ptr MakeVector(); static std::unique_ptr MakeScalarAggregate(); diff --git a/cpp/src/arrow/compute/function.cc b/cpp/src/arrow/compute/function.cc index 05d14d03b16..555b518786b 100644 --- a/cpp/src/arrow/compute/function.cc +++ b/cpp/src/arrow/compute/function.cc @@ -30,6 +30,7 @@ #include "arrow/compute/registry.h" #include "arrow/datum.h" #include "arrow/util/cpu_info.h" +#include "arrow/util/logging.h" namespace arrow { @@ -230,9 +231,13 @@ Result Function::Execute(const std::vector& args, } RETURN_NOT_OK(executor->Init(&kernel_ctx, {kernel, inputs, options})); - auto listener = std::make_shared(); - RETURN_NOT_OK(executor->Execute(implicitly_cast_args, listener.get())); - return executor->WrapResults(implicitly_cast_args, listener->values()); + detail::DatumAccumulator listener; + RETURN_NOT_OK(executor->Execute(implicitly_cast_args, &listener)); + auto out = executor->WrapResults(implicitly_cast_args, listener.values()); +#ifndef NDEBUG + DCHECK_OK(executor->CheckResultType(out, name_.c_str())); +#endif + return out; } Status Function::Validate() const { diff --git a/cpp/src/arrow/datum.cc b/cpp/src/arrow/datum.cc index dd10fce3e4d..d3ff6aba0af 100644 --- a/cpp/src/arrow/datum.cc +++ b/cpp/src/arrow/datum.cc @@ 
-86,7 +86,7 @@ std::shared_ptr Datum::make_array() const { return MakeArray(util::get>(this->value)); } -std::shared_ptr Datum::type() const { +const std::shared_ptr& Datum::type() const { if (this->kind() == Datum::ARRAY) { return util::get>(this->value)->type; } @@ -96,17 +96,19 @@ std::shared_ptr Datum::type() const { if (this->kind() == Datum::SCALAR) { return util::get>(this->value)->type; } - return nullptr; + static std::shared_ptr no_type; + return no_type; } -std::shared_ptr Datum::schema() const { +const std::shared_ptr& Datum::schema() const { if (this->kind() == Datum::RECORD_BATCH) { return util::get>(this->value)->schema(); } if (this->kind() == Datum::TABLE) { return util::get>(this->value)->schema(); } - return nullptr; + static std::shared_ptr no_schema; + return no_schema; } int64_t Datum::length() const { diff --git a/cpp/src/arrow/datum.h b/cpp/src/arrow/datum.h index 6ba6af7f79e..da851d917d8 100644 --- a/cpp/src/arrow/datum.h +++ b/cpp/src/arrow/datum.h @@ -251,12 +251,12 @@ struct ARROW_EXPORT Datum { /// \brief The value type of the variant, if any /// /// \return nullptr if no type - std::shared_ptr type() const; + const std::shared_ptr& type() const; /// \brief The schema of the variant, if any /// /// \return nullptr if no schema - std::shared_ptr schema() const; + const std::shared_ptr& schema() const; /// \brief The value length of the variant, if any /// diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index f1e5f23eed8..b313e926257 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -92,7 +92,7 @@ class ARROW_EXPORT Table { const std::shared_ptr& array); /// \brief Return the table schema - std::shared_ptr schema() const { return schema_; } + const std::shared_ptr& schema() const { return schema_; } /// \brief Return a column by index virtual std::shared_ptr column(int i) const = 0;