diff --git a/c_glib/arrow-glib/scalar.cpp b/c_glib/arrow-glib/scalar.cpp index def6b151483..f965b497030 100644 --- a/c_glib/arrow-glib/scalar.cpp +++ b/c_glib/arrow-glib/scalar.cpp @@ -1063,7 +1063,8 @@ garrow_base_binary_scalar_get_value(GArrowBaseBinaryScalar *scalar) if (!priv->value) { const auto arrow_scalar = std::static_pointer_cast( garrow_scalar_get_raw(GARROW_SCALAR(scalar))); - priv->value = garrow_buffer_new_raw(&(arrow_scalar->value)); + priv->value = garrow_buffer_new_raw( + const_cast *>(&(arrow_scalar->value))); } return priv->value; } @@ -1983,7 +1984,8 @@ garrow_base_list_scalar_get_value(GArrowBaseListScalar *scalar) if (!priv->value) { const auto arrow_scalar = std::static_pointer_cast( garrow_scalar_get_raw(GARROW_SCALAR(scalar))); - priv->value = garrow_array_new_raw(&(arrow_scalar->value)); + priv->value = garrow_array_new_raw( + const_cast *>(&(arrow_scalar->value))); } return priv->value; } diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 21ac1a09f56..1656454aa4d 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -823,6 +824,41 @@ TEST_F(TestArray, TestFillFromScalar) { } } +// GH-40069: Data-race when concurrent calling ArraySpan::FillFromScalar of the same +// scalar instance. +TEST_F(TestArray, TestConcurrentFillFromScalar) { + for (auto type : TestArrayUtilitiesAgainstTheseTypes()) { + ARROW_SCOPED_TRACE("type = ", type->ToString()); + for (auto seed : {0u, 0xdeadbeef, 42u}) { + ARROW_SCOPED_TRACE("seed = ", seed); + + Field field("", type, /*nullable=*/true, + key_value_metadata({{"extension_allow_random_storage", "true"}})); + auto array = random::GenerateArray(field, 1, seed); + + ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0)); + + // Lambda to create fill an ArraySpan with the scalar and use the ArraySpan a bit. + auto array_span_from_scalar = [&]() { + ArraySpan span(*scalar); + auto roundtripped_array = span.ToArray(); + ASSERT_OK(roundtripped_array->ValidateFull()); + + AssertArraysEqual(*array, *roundtripped_array); + ASSERT_OK_AND_ASSIGN(auto roundtripped_scalar, roundtripped_array->GetScalar(0)); + AssertScalarsEqual(*scalar, *roundtripped_scalar); + }; + + // Two concurrent calls to the lambda are just enough for TSAN to detect a race + // condition. + auto fut1 = std::async(std::launch::async, array_span_from_scalar); + auto fut2 = std::async(std::launch::async, array_span_from_scalar); + fut1.get(); + fut2.get(); + } + } +} + TEST_F(TestArray, ExtensionSpanRoundTrip) { // Other types are checked in MakeEmptyArray but MakeEmptyArray doesn't // work for extension types so we check that here diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 80c411dfa6a..ff3112ec1fc 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -283,25 +283,15 @@ void ArraySpan::SetMembers(const ArrayData& data) { namespace { -template -BufferSpan OffsetsForScalar(uint8_t* scratch_space, offset_type value_size) { - auto* offsets = reinterpret_cast(scratch_space); - offsets[0] = 0; - offsets[1] = static_cast(value_size); - static_assert(2 * sizeof(offset_type) <= 16); - return {scratch_space, sizeof(offset_type) * 2}; +BufferSpan OffsetsForScalar(uint8_t* scratch_space, int64_t offset_width) { + return {scratch_space, offset_width * 2}; } -template std::pair OffsetsAndSizesForScalar(uint8_t* scratch_space, - offset_type value_size) { + int64_t offset_width) { auto* offsets = scratch_space; - auto* sizes = scratch_space + sizeof(offset_type); - reinterpret_cast(offsets)[0] = 0; - reinterpret_cast(sizes)[0] = value_size; - static_assert(2 * sizeof(offset_type) <= 16); - return {BufferSpan{offsets, sizeof(offset_type)}, - BufferSpan{sizes, sizeof(offset_type)}}; + auto* sizes = scratch_space + offset_width; + return {BufferSpan{offsets, offset_width}, BufferSpan{sizes, offset_width}}; } int GetNumBuffers(const DataType& type) { @@ -415,26 +405,23 @@ void ArraySpan::FillFromScalar(const Scalar& value) { data_size = scalar.value->size(); } if (is_binary_like(type_id)) { - this->buffers[1] = - OffsetsForScalar(scalar.scratch_space_, static_cast(data_size)); + const auto& binary_scalar = checked_cast(value); + this->buffers[1] = OffsetsForScalar(binary_scalar.scratch_space_, sizeof(int32_t)); } else { // is_large_binary_like - this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, data_size); + const auto& large_binary_scalar = checked_cast(value); + this->buffers[1] = + OffsetsForScalar(large_binary_scalar.scratch_space_, sizeof(int64_t)); } this->buffers[2].data = const_cast(data_buffer); this->buffers[2].size = data_size; } else if (type_id == Type::BINARY_VIEW || type_id == Type::STRING_VIEW) { - const auto& scalar = checked_cast(value); + const auto& scalar = checked_cast(value); this->buffers[1].size = BinaryViewType::kSize; this->buffers[1].data = scalar.scratch_space_; - static_assert(sizeof(BinaryViewType::c_type) <= sizeof(scalar.scratch_space_)); - auto* view = new (&scalar.scratch_space_) BinaryViewType::c_type; if (scalar.is_valid) { - *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0); this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1}); - } else { - *view = {}; } } else if (type_id == Type::FIXED_SIZE_BINARY) { const auto& scalar = checked_cast(value); @@ -443,12 +430,10 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } else if (is_var_length_list_like(type_id) || type_id == Type::FIXED_SIZE_LIST) { const auto& scalar = checked_cast(value); - int64_t value_length = 0; this->child_data.resize(1); if (scalar.value != nullptr) { // When the scalar is null, scalar.value can also be null this->child_data[0].SetMembers(*scalar.value->data()); - value_length = scalar.value->length(); } else { // Even when the value is null, we still must populate the // child_data to yield a valid array. Tedious @@ -456,17 +441,25 @@ void ArraySpan::FillFromScalar(const Scalar& value) { &this->child_data[0]); } - if (type_id == Type::LIST || type_id == Type::MAP) { - this->buffers[1] = - OffsetsForScalar(scalar.scratch_space_, static_cast(value_length)); + if (type_id == Type::LIST) { + const auto& list_scalar = checked_cast(value); + this->buffers[1] = OffsetsForScalar(list_scalar.scratch_space_, sizeof(int32_t)); + } else if (type_id == Type::MAP) { + const auto& map_scalar = checked_cast(value); + this->buffers[1] = OffsetsForScalar(map_scalar.scratch_space_, sizeof(int32_t)); } else if (type_id == Type::LARGE_LIST) { - this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, value_length); + const auto& large_list_scalar = checked_cast(value); + this->buffers[1] = + OffsetsForScalar(large_list_scalar.scratch_space_, sizeof(int64_t)); } else if (type_id == Type::LIST_VIEW) { - std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar( - scalar.scratch_space_, static_cast(value_length)); - } else if (type_id == Type::LARGE_LIST_VIEW) { + const auto& list_view_scalar = checked_cast(value); std::tie(this->buffers[1], this->buffers[2]) = - OffsetsAndSizesForScalar(scalar.scratch_space_, value_length); + OffsetsAndSizesForScalar(list_view_scalar.scratch_space_, sizeof(int32_t)); + } else if (type_id == Type::LARGE_LIST_VIEW) { + const auto& large_list_view_scalar = + checked_cast(value); + std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar( + large_list_view_scalar.scratch_space_, sizeof(int64_t)); } else { DCHECK_EQ(type_id, Type::FIXED_SIZE_LIST); // FIXED_SIZE_LIST: does not have a second buffer @@ -480,27 +473,19 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->child_data[i].FillFromScalar(*scalar.value[i]); } } else if (is_union(type_id)) { - // Dense union needs scratch space to store both offsets and a type code - struct UnionScratchSpace { - alignas(int64_t) int8_t type_code; - alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; - }; - static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_)); - auto* union_scratch_space = reinterpret_cast( - &checked_cast(value).scratch_space_); - // First buffer is kept null since unions have no validity vector this->buffers[0] = {}; - union_scratch_space->type_code = checked_cast(value).type_code; - this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); - this->buffers[1].size = 1; - this->child_data.resize(this->type->num_fields()); if (type_id == Type::DENSE_UNION) { const auto& scalar = checked_cast(value); - this->buffers[2] = - OffsetsForScalar(union_scratch_space->offsets, static_cast(1)); + auto* union_scratch_space = + reinterpret_cast(&scalar.scratch_space_); + + this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); + this->buffers[1].size = 1; + + this->buffers[2] = OffsetsForScalar(union_scratch_space->offsets, sizeof(int32_t)); // We can't "see" the other arrays in the union, but we put the "active" // union array in the right place and fill zero-length arrays for the // others @@ -517,6 +502,12 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } } else { const auto& scalar = checked_cast(value); + auto* union_scratch_space = + reinterpret_cast(&scalar.scratch_space_); + + this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); + this->buffers[1].size = 1; + // Sparse union scalars have a full complement of child values even // though only one of them is relevant, so we just fill them in here for (int i = 0; i < static_cast(this->child_data.size()); ++i) { @@ -541,7 +532,6 @@ void ArraySpan::FillFromScalar(const Scalar& value) { e.null_count = 0; e.buffers[1].data = scalar.scratch_space_; e.buffers[1].size = sizeof(run_end); - reinterpret_cast(scalar.scratch_space_)[0] = run_end; }; switch (scalar.run_end_type()->id()) { diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 72b29057b82..097ee1de45b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -369,43 +369,6 @@ struct UnboxScalar { } }; -template -struct BoxScalar; - -template -struct BoxScalar> { - using T = typename GetOutputType::T; - static void Box(T val, Scalar* out) { - // Enables BoxScalar to work on a (for example) Time64Scalar - T* mutable_data = reinterpret_cast( - checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data()); - *mutable_data = val; - } -}; - -template -struct BoxScalar> { - using T = typename GetOutputType::T; - using ScalarType = typename TypeTraits::ScalarType; - static void Box(T val, Scalar* out) { - checked_cast(out)->value = std::make_shared(val); - } -}; - -template <> -struct BoxScalar { - using T = Decimal128; - using ScalarType = Decimal128Scalar; - static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } -}; - -template <> -struct BoxScalar { - using T = Decimal256; - using ScalarType = Decimal256Scalar; - static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } -}; - // A VisitArraySpanInline variant that calls its visitor function with logical // values, such as Decimal128 rather than std::string_view. diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index daf8ed76d62..9b2fd987d81 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -491,8 +491,9 @@ template struct ScalarMinMax { using OutValue = typename GetOutputType::T; - static void ExecScalar(const ExecSpan& batch, - const ElementWiseAggregateOptions& options, Scalar* out) { + static Result> ExecScalar( + const ExecSpan& batch, const ElementWiseAggregateOptions& options, + std::shared_ptr type) { // All arguments are scalar OutValue value{}; bool valid = false; @@ -502,8 +503,8 @@ struct ScalarMinMax { const Scalar& scalar = *arg.scalar; if (!scalar.is_valid) { if (options.skip_nulls) continue; - out->is_valid = false; - return; + valid = false; + break; } if (!valid) { value = UnboxScalar::Unbox(scalar); @@ -513,9 +514,10 @@ struct ScalarMinMax { value, UnboxScalar::Unbox(scalar)); } } - out->is_valid = valid; if (valid) { - BoxScalar::Box(value, out); + return MakeScalar(std::move(type), std::move(value)); + } else { + return MakeNullScalar(std::move(type)); } } @@ -537,8 +539,7 @@ struct ScalarMinMax { bool initialize_output = true; if (scalar_count > 0) { ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_scalar, - MakeScalar(out->type()->GetSharedPtr(), 0)); - ExecScalar(batch, options, temp_scalar.get()); + ExecScalar(batch, options, out->type()->GetSharedPtr())); if (temp_scalar->is_valid) { const auto value = UnboxScalar::Unbox(*temp_scalar); initialize_output = false; diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 6996b46c8b6..8e8d3903663 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -542,6 +542,12 @@ struct ScalarValidateImpl { } }; +template +void FillScalarScratchSpace(void* scratch_space, T const (&arr)[N]) { + static_assert(sizeof(arr) <= internal::kScalarScratchSpaceSize); + std::memcpy(scratch_space, arr, sizeof(arr)); +} + } // namespace size_t Scalar::hash() const { return ScalarHashImpl(*this).hash_; } @@ -557,6 +563,28 @@ Status Scalar::ValidateFull() const { BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type) : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} +void BinaryScalar::FillScratchSpace() { + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->size()) : int32_t(0)}); +} + +void BinaryViewScalar::FillScratchSpace() { + static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize); + auto* view = new (&scratch_space_) BinaryViewType::c_type; + if (value) { + *view = util::ToBinaryView(std::string_view{*value}, 0, 0); + } else { + *view = {}; + } +} + +void LargeBinaryScalar::FillScratchSpace() { + FillScalarScratchSpace( + scratch_space_, + {int64_t(0), value ? static_cast(value->size()) : int64_t(0)}); +} + FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) @@ -578,21 +606,45 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid) BaseListScalar::BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : Scalar{std::move(type), is_valid}, value(std::move(value)) { - ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); + if (this->value) { + ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); + } } ListScalar::ListScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, list(value->type()), is_valid) {} +void ListScalar::FillScratchSpace() { + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->length()) : int32_t(0)}); +} + LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, large_list(value->type()), is_valid) {} +void LargeListScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, + {int64_t(0), value ? value->length() : int64_t(0)}); +} + ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, list_view(value->type()), is_valid) {} +void ListViewScalar::FillScratchSpace() { + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->length()) : int32_t(0)}); +} + LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, large_list_view(value->type()), is_valid) {} +void LargeListViewScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, + {int64_t(0), value ? value->length() : int64_t(0)}); +} + inline std::shared_ptr MakeMapType(const std::shared_ptr& pair_type) { ARROW_CHECK_EQ(pair_type->id(), Type::STRUCT); ARROW_CHECK_EQ(pair_type->num_fields(), 2); @@ -602,11 +654,19 @@ inline std::shared_ptr MakeMapType(const std::shared_ptr& pa MapScalar::MapScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, MakeMapType(value->type()), is_valid) {} +void MapScalar::FillScratchSpace() { + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->length()) : int32_t(0)}); +} + FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) - : BaseListScalar(value, std::move(type), is_valid) { - ARROW_CHECK_EQ(this->value->length(), - checked_cast(*this->type).list_size()); + : BaseListScalar(std::move(value), std::move(type), is_valid) { + if (this->value) { + ARROW_CHECK_EQ(this->value->length(), + checked_cast(*this->type).list_size()); + } } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, bool is_valid) @@ -656,6 +716,21 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr& type) RunEndEncodedScalar::~RunEndEncodedScalar() = default; +void RunEndEncodedScalar::FillScratchSpace() { + auto run_end = run_end_type()->id(); + switch (run_end) { + case Type::INT16: + FillScalarScratchSpace(scratch_space_, {int16_t(1)}); + break; + case Type::INT32: + FillScalarScratchSpace(scratch_space_, {int32_t(1)}); + break; + default: + DCHECK_EQ(run_end, Type::INT64); + FillScalarScratchSpace(scratch_space_, {int64_t(1)}); + } +} + DictionaryScalar::DictionaryScalar(std::shared_ptr type) : internal::PrimitiveScalarBase(std::move(type)), value{MakeNullScalar(checked_cast(*this->type).index_type()), @@ -732,11 +807,14 @@ SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) : UnionScalar(std::move(type), type_code, /*is_valid=*/true), value(std::move(value)) { - this->child_id = - checked_cast(*this->type).child_ids()[type_code]; + const auto child_ids = checked_cast(*this->type).child_ids(); + if (type_code >= 0 && static_cast(type_code) < child_ids.size() && + child_ids[type_code] != UnionType::kInvalidChildId) { + this->child_id = child_ids[type_code]; - // Fix nullness based on whether the selected child is null - this->is_valid = this->value[this->child_id]->is_valid; + // Fix nullness based on whether the selected child is null + this->is_valid = this->value[this->child_id]->is_valid; + } } std::shared_ptr SparseUnionScalar::FromValue(std::shared_ptr value, @@ -755,6 +833,17 @@ std::shared_ptr SparseUnionScalar::FromValue(std::shared_ptr val return std::make_shared(field_values, type_code, std::move(type)); } +void SparseUnionScalar::FillScratchSpace() { + auto* union_scratch_space = reinterpret_cast(&scratch_space_); + union_scratch_space->type_code = type_code; +} + +void DenseUnionScalar::FillScratchSpace() { + auto* union_scratch_space = reinterpret_cast(&scratch_space_); + union_scratch_space->type_code = type_code; + FillScalarScratchSpace(union_scratch_space->offsets, {int32_t(0), int32_t(1)}); +} + namespace { template @@ -969,58 +1058,72 @@ std::shared_ptr FormatToBuffer(Formatter&& formatter, const ScalarType& } // error fallback -Status CastImpl(const Scalar& from, Scalar* to) { +template +Result> CastImpl(const Scalar& from, + std::shared_ptr to_type) { return Status::NotImplemented("casting scalars of type ", *from.type, " to type ", - *to->type); + *to_type); } // numeric to numeric -template -Status CastImpl(const NumericScalar& from, NumericScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); +template +enable_if_number>> CastImpl( + const NumericScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // numeric to boolean -template -Status CastImpl(const NumericScalar& from, BooleanScalar* to) { - constexpr auto zero = static_cast(0); - to->value = from.value != zero; - return Status::OK(); +template +enable_if_boolean>> CastImpl( + const NumericScalar& from, std::shared_ptr to_type) { + constexpr auto zero = static_cast(0); + return std::make_shared(from.value != zero, std::move(to_type)); } // boolean to numeric -template -Status CastImpl(const BooleanScalar& from, NumericScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); +template +enable_if_number>> CastImpl( + const BooleanScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // numeric to temporal -template +template typename std::enable_if::value && !std::is_same::value && !std::is_same::value, - Status>::type -CastImpl(const NumericScalar& from, TemporalScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); + Result>>::type +CastImpl(const NumericScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // temporal to numeric -template -typename std::enable_if::value && +template +typename std::enable_if::value && + std::is_base_of::value && !std::is_same::value && !std::is_same::value, - Status>::type -CastImpl(const TemporalScalar& from, NumericScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); + Result>>::type +CastImpl(const TemporalScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // timestamp to timestamp -Status CastImpl(const TimestampScalar& from, TimestampScalar* to) { - return util::ConvertTimestampValue(from.type, to->type, from.value).Value(&to->value); +template +enable_if_timestamp>> CastImpl( + const TimestampScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE(auto value, + util::ConvertTimestampValue(from.type, to_type, from.value)); + return std::make_shared(value, std::move(to_type)); } template @@ -1029,101 +1132,117 @@ std::shared_ptr AsTimestampType(const std::shared_ptr& type) } // duration to duration -Status CastImpl(const DurationScalar& from, DurationScalar* to) { - return util::ConvertTimestampValue(AsTimestampType(from.type), - AsTimestampType(to->type), from.value) - .Value(&to->value); +template +enable_if_duration>> CastImpl( + const DurationScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE( + auto value, + util::ConvertTimestampValue(AsTimestampType(from.type), + AsTimestampType(to_type), from.value)); + return std::make_shared(value, std::move(to_type)); } // time to time -template -enable_if_time CastImpl(const TimeScalar& from, ToScalar* to) { - return util::ConvertTimestampValue(AsTimestampType(from.type), - AsTimestampType(to->type), from.value) - .Value(&to->value); +template +enable_if_time>> CastImpl( + const TimeScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE( + auto value, util::ConvertTimestampValue(AsTimestampType(from.type), + AsTimestampType(to_type), from.value)); + return std::make_shared(value, std::move(to_type)); } constexpr int64_t kMillisecondsInDay = 86400000; // date to date -Status CastImpl(const Date32Scalar& from, Date64Scalar* to) { - to->value = from.value * kMillisecondsInDay; - return Status::OK(); +template +enable_if_t::value, Result>> +CastImpl(const Date32Scalar& from, std::shared_ptr to_type) { + return std::make_shared(from.value * kMillisecondsInDay, + std::move(to_type)); } -Status CastImpl(const Date64Scalar& from, Date32Scalar* to) { - to->value = static_cast(from.value / kMillisecondsInDay); - return Status::OK(); +template +enable_if_t::value, Result>> +CastImpl(const Date64Scalar& from, std::shared_ptr to_type) { + return std::make_shared( + static_cast(from.value / kMillisecondsInDay), std::move(to_type)); } // timestamp to date -Status CastImpl(const TimestampScalar& from, Date64Scalar* to) { +template +enable_if_t::value, Result>> +CastImpl(const TimestampScalar& from, std::shared_ptr to_type) { ARROW_ASSIGN_OR_RAISE( auto millis, util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value)); - to->value = millis - millis % kMillisecondsInDay; - return Status::OK(); + return std::make_shared(millis - millis % kMillisecondsInDay, + std::move(to_type)); } -Status CastImpl(const TimestampScalar& from, Date32Scalar* to) { +template +enable_if_t::value, Result>> +CastImpl(const TimestampScalar& from, std::shared_ptr to_type) { ARROW_ASSIGN_OR_RAISE( auto millis, util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value)); - to->value = static_cast(millis / kMillisecondsInDay); - return Status::OK(); + return std::make_shared(static_cast(millis / kMillisecondsInDay), + std::move(to_type)); } // date to timestamp -template -Status CastImpl(const DateScalar& from, TimestampScalar* to) { +template +enable_if_timestamp>> CastImpl( + const DateScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; int64_t millis = from.value; - if (std::is_same::value) { + if (std::is_same::value) { millis *= kMillisecondsInDay; } - return util::ConvertTimestampValue(timestamp(TimeUnit::MILLI), to->type, millis) - .Value(&to->value); + ARROW_ASSIGN_OR_RAISE(auto value, util::ConvertTimestampValue( + timestamp(TimeUnit::MILLI), to_type, millis)); + return std::make_shared(value, std::move(to_type)); } // string to any -template -Status CastImpl(const StringScalar& from, ScalarType* to) { - ARROW_ASSIGN_OR_RAISE(auto out, Scalar::Parse(to->type, std::string_view(*from.value))); - to->value = std::move(checked_cast(*out).value); - return Status::OK(); +template +Result> CastImpl(const StringScalar& from, + std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE(auto out, + Scalar::Parse(std::move(to_type), std::string_view(*from.value))); + DCHECK(checked_pointer_cast(out) != nullptr); + return std::move(out); } // binary/large binary/large string to string -template -enable_if_t && - !std::is_same::value, - Status> -CastImpl(const ScalarType& from, StringScalar* to) { - to->value = from.value; - return Status::OK(); +template +enable_if_t::value && + std::is_base_of_v && + !std::is_same::value, + Result>> +CastImpl(const From& from, std::shared_ptr to_type) { + return std::make_shared(from.value, std::move(to_type)); } // formattable to string -template , // note: Value unused but necessary to trigger SFINAE if Formatter is // undefined typename Value = typename Formatter::value_type> -Status CastImpl(const ScalarType& from, StringScalar* to) { - to->value = FormatToBuffer(Formatter{from.type.get()}, from); - return Status::OK(); -} - -Status CastImpl(const Decimal128Scalar& from, StringScalar* to) { - auto from_type = checked_cast(from.type.get()); - to->value = Buffer::FromString(from.value.ToString(from_type->scale())); - return Status::OK(); -} - -Status CastImpl(const Decimal256Scalar& from, StringScalar* to) { - auto from_type = checked_cast(from.type.get()); - to->value = Buffer::FromString(from.value.ToString(from_type->scale())); - return Status::OK(); +typename std::enable_if_t::value, + Result>> +CastImpl(const From& from, std::shared_ptr to_type) { + return std::make_shared(FormatToBuffer(Formatter{from.type.get()}, from), + std::move(to_type)); } -Status CastImpl(const StructScalar& from, StringScalar* to) { +// struct to string +template +typename std::enable_if_t::value, + Result>> +CastImpl(const StructScalar& from, std::shared_ptr to_type) { std::stringstream ss; ss << '{'; for (int i = 0; static_cast(i) < from.value.size(); i++) { @@ -1132,24 +1251,23 @@ Status CastImpl(const StructScalar& from, StringScalar* to) { << " = " << from.value[i]->ToString(); } ss << '}'; - to->value = Buffer::FromString(ss.str()); - return Status::OK(); + return std::make_shared(Buffer::FromString(ss.str()), std::move(to_type)); } // casts between variable-length and fixed-length list types -template -enable_if_list_type CastImpl( - const BaseListScalar& from, ToScalar* to) { - if constexpr (sizeof(typename ToScalar::TypeClass::offset_type) < sizeof(int64_t)) { - if (from.value->length() > - std::numeric_limits::max()) { +template +std::enable_if_t::value && is_list_type::value, + Result>> +CastImpl(const From& from, std::shared_ptr to_type) { + if constexpr (sizeof(typename To::offset_type) < sizeof(int64_t)) { + if (from.value->length() > std::numeric_limits::max()) { return Status::Invalid(from.type->ToString(), " too large to cast to ", - to->type->ToString()); + to_type->ToString()); } } - if constexpr (is_fixed_size_list_type::value) { - const auto& fixed_size_list_type = checked_cast(*to->type); + if constexpr (is_fixed_size_list_type::value) { + const auto& fixed_size_list_type = checked_cast(*to_type); if (from.value->length() != fixed_size_list_type.list_size()) { return Status::Invalid("Cannot cast ", from.type->ToString(), " of length ", from.value->length(), " to fixed size list of length ", @@ -1157,13 +1275,15 @@ enable_if_list_type CastImpl( } } - DCHECK_EQ(from.is_valid, to->is_valid); - to->value = from.value; - return Status::OK(); + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(from.value, std::move(to_type), from.is_valid); } // list based types (list, large list and map (fixed sized list too)) to string -Status CastImpl(const BaseListScalar& from, StringScalar* to) { +template +typename std::enable_if_t::value, + Result>> +CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { std::stringstream ss; ss << from.type->ToString() << "["; for (int64_t i = 0; i < from.value->length(); i++) { @@ -1172,11 +1292,14 @@ Status CastImpl(const BaseListScalar& from, StringScalar* to) { ss << value->ToString(); } ss << ']'; - to->value = Buffer::FromString(ss.str()); - return Status::OK(); + return std::make_shared(Buffer::FromString(ss.str()), std::move(to_type)); } -Status CastImpl(const UnionScalar& from, StringScalar* to) { +// union types to string +template +typename std::enable_if_t::value, + Result>> +CastImpl(const UnionScalar& from, std::shared_ptr to_type) { const auto& union_ty = checked_cast(*from.type); std::stringstream ss; const Scalar* selected_value; @@ -1188,8 +1311,7 @@ Status CastImpl(const UnionScalar& from, StringScalar* to) { } ss << "union{" << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() << " = " << selected_value->ToString() << '}'; - to->value = Buffer::FromString(ss.str()); - return Status::OK(); + return std::make_shared(Buffer::FromString(ss.str()), std::move(to_type)); } struct CastImplVisitor { @@ -1199,59 +1321,49 @@ struct CastImplVisitor { const Scalar& from_; const std::shared_ptr& to_type_; - Scalar* out_; + std::shared_ptr out_ = nullptr; }; template struct FromTypeVisitor : CastImplVisitor { using ToScalar = typename TypeTraits::ScalarType; - FromTypeVisitor(const Scalar& from, const std::shared_ptr& to_type, - Scalar* out) - : CastImplVisitor{from, to_type, out} {} + FromTypeVisitor(const Scalar& from, const std::shared_ptr& to_type) + : CastImplVisitor{from, to_type} {} template Status Visit(const FromType&) { - return CastImpl(checked_cast::ScalarType&>(from_), - checked_cast(out_)); + ARROW_ASSIGN_OR_RAISE( + out_, CastImpl( + checked_cast::ScalarType&>(from_), + std::move(to_type_))); + return Status::OK(); } // identity cast only for parameter free types template typename std::enable_if_t::is_parameter_free, Status> Visit( const ToType&) { - checked_cast(out_)->value = checked_cast(from_).value; + ARROW_ASSIGN_OR_RAISE(out_, MakeScalar(std::move(to_type_), + checked_cast(from_).value)); return Status::OK(); } - Status CastFromListLike(const BaseListType& base_list_type) { - return CastImpl(checked_cast(from_), - checked_cast(out_)); - } - - Status Visit(const ListType& list_type) { return CastFromListLike(list_type); } - - Status Visit(const LargeListType& large_list_type) { - return CastFromListLike(large_list_type); - } - - Status Visit(const FixedSizeListType& fixed_size_list_type) { - return CastFromListLike(fixed_size_list_type); - } - Status Visit(const NullType&) { return NotImplemented(); } Status Visit(const DictionaryType&) { return NotImplemented(); } Status Visit(const ExtensionType&) { return NotImplemented(); } }; struct ToTypeVisitor : CastImplVisitor { - ToTypeVisitor(const Scalar& from, const std::shared_ptr& to_type, Scalar* out) - : CastImplVisitor{from, to_type, out} {} + ToTypeVisitor(const Scalar& from, const std::shared_ptr& to_type) + : CastImplVisitor{from, to_type} {} template Status Visit(const ToType&) { - FromTypeVisitor unpack_from_type{from_, to_type_, out_}; - return VisitTypeInline(*from_.type, &unpack_from_type); + FromTypeVisitor unpack_from_type{from_, to_type_}; + ARROW_RETURN_NOT_OK(VisitTypeInline(*from_.type, &unpack_from_type)); + out_ = std::move(unpack_from_type.out_); + return Status::OK(); } Status Visit(const NullType&) { @@ -1262,25 +1374,28 @@ struct ToTypeVisitor : CastImplVisitor { } Status Visit(const DictionaryType& dict_type) { - auto& out = checked_cast(out_)->value; ARROW_ASSIGN_OR_RAISE(auto cast_value, from_.CastTo(dict_type.value_type())); - ARROW_ASSIGN_OR_RAISE(out.dictionary, MakeArrayFromScalar(*cast_value, 1)); - return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index); + ARROW_ASSIGN_OR_RAISE(auto dictionary, MakeArrayFromScalar(*cast_value, 1)); + ARROW_ASSIGN_OR_RAISE(auto index, Int32Scalar(0).CastTo(dict_type.index_type())); + out_ = DictionaryScalar::Make(std::move(index), std::move(dictionary)); + return Status::OK(); } Status Visit(const ExtensionType&) { return NotImplemented(); } + + Result> Finish() && { + ARROW_RETURN_NOT_OK(VisitTypeInline(*to_type_, this)); + return std::move(out_); + } }; } // namespace Result> Scalar::CastTo(std::shared_ptr to) const { - std::shared_ptr out = MakeNullScalar(to); if (is_valid) { - out->is_valid = true; - ToTypeVisitor unpack_to_type{*this, to, out.get()}; - RETURN_NOT_OK(VisitTypeInline(*to, &unpack_to_type)); + return ToTypeVisitor{*this, std::move(to)}.Finish(); } - return out; + return MakeNullScalar(std::move(to)); } void PrintTo(const Scalar& scalar, std::ostream* os) { *os << scalar.ToString(); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 65c5ee4df0a..a7ee6a417d9 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -131,11 +131,19 @@ struct ARROW_EXPORT NullScalar : public Scalar { namespace internal { +constexpr auto kScalarScratchSpaceSize = sizeof(int64_t) * 2; + +template struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { // 16 bytes of scratch space to enable ArraySpan to be a view onto any // Scalar- including binary scalars where we need to create a buffer // that looks like two 32-bit or 64-bit offsets. - alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; + alignas(int64_t) mutable uint8_t scratch_space_[kScalarScratchSpaceSize]; + + private: + ArraySpanFillFromScalarScratchSpace() { static_cast(this)->FillScratchSpace(); } + + friend Impl; }; struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { @@ -145,8 +153,6 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { using Scalar::Scalar; /// \brief Get a const pointer to the value of this scalar. May be null. virtual const void* data() const = 0; - /// \brief Get a mutable pointer to the value of this scalar. May be null. - virtual void* mutable_data() = 0; /// \brief Get an immutable view of the value of this scalar as bytes. virtual std::string_view view() const = 0; }; @@ -167,7 +173,6 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase { ValueType value{}; const void* data() const override { return &value; } - void* mutable_data() override { return &value; } std::string_view view() const override { return std::string_view(reinterpret_cast(&value), sizeof(ValueType)); }; @@ -245,34 +250,38 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { using NumericScalar::NumericScalar; }; -struct ARROW_EXPORT BaseBinaryScalar - : public internal::PrimitiveScalarBase, - private internal::ArraySpanFillFromScalarScratchSpace { - using internal::PrimitiveScalarBase::PrimitiveScalarBase; +struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { using ValueType = std::shared_ptr; - std::shared_ptr value; + // The value is not supposed to be modified after construction, because subclasses have + // a scratch space whose content need to be kept consistent with the value. It is also + // the user of this class's responsibility to ensure that the buffer is not written to + // accidentally. + const std::shared_ptr value = NULLPTR; const void* data() const override { return value ? reinterpret_cast(value->data()) : NULLPTR; } - void* mutable_data() override { - return value ? reinterpret_cast(value->mutable_data()) : NULLPTR; - } std::string_view view() const override { return value ? std::string_view(*value) : std::string_view(); } + explicit BaseBinaryScalar(std::shared_ptr type) + : internal::PrimitiveScalarBase(std::move(type)) {} + BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type) : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {} - friend ArraySpan; BaseBinaryScalar(std::string s, std::shared_ptr type); }; -struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { +struct ARROW_EXPORT BinaryScalar + : public BaseBinaryScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit BinaryScalar(std::shared_ptr value) : BinaryScalar(std::move(value), binary()) {} @@ -280,6 +289,12 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { explicit BinaryScalar(std::string s) : BaseBinaryScalar(std::move(s), binary()) {} BinaryScalar() : BinaryScalar(binary()) {} + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT StringScalar : public BinaryScalar { @@ -294,9 +309,13 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar { StringScalar() : StringScalar(utf8()) {} }; -struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { +struct ARROW_EXPORT BinaryViewScalar + : public BaseBinaryScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryViewType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit BinaryViewScalar(std::shared_ptr value) : BinaryViewScalar(std::move(value), binary_view()) {} @@ -307,6 +326,12 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { BinaryViewScalar() : BinaryViewScalar(binary_view()) {} std::string_view view() const override { return std::string_view(*this->value); } + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { @@ -322,9 +347,13 @@ struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { StringViewScalar() : StringViewScalar(utf8_view()) {} }; -struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { +struct ARROW_EXPORT LargeBinaryScalar + : public BaseBinaryScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = LargeBinaryType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; LargeBinaryScalar(std::shared_ptr value, std::shared_ptr type) : BaseBinaryScalar(std::move(value), std::move(type)) {} @@ -336,6 +365,12 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { : BaseBinaryScalar(std::move(s), large_binary()) {} LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {} + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { @@ -482,10 +517,6 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase { return reinterpret_cast(value.native_endian_bytes()); } - void* mutable_data() override { - return reinterpret_cast(value.mutable_native_endian_bytes()); - } - std::string_view view() const override { return std::string_view(reinterpret_cast(value.native_endian_bytes()), ValueType::kByteWidth); @@ -502,54 +533,102 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar; BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid = true); - std::shared_ptr value; - - private: - friend struct ArraySpan; + // The value is not supposed to be modified after construction, because subclasses have + // a scratch space whose content need to be kept consistent with the value. It is also + // the user of this class's responsibility to ensure that the array is not modified + // accidentally. + const std::shared_ptr value; }; -struct ARROW_EXPORT ListScalar : public BaseListScalar { +struct ARROW_EXPORT ListScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = ListType; using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit ListScalar(std::shared_ptr value, bool is_valid = true); + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT LargeListScalar : public BaseListScalar { +struct ARROW_EXPORT LargeListScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = LargeListType; using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit LargeListScalar(std::shared_ptr value, bool is_valid = true); + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT ListViewScalar : public BaseListScalar { +struct ARROW_EXPORT ListViewScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = ListViewType; using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit ListViewScalar(std::shared_ptr value, bool is_valid = true); + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT LargeListViewScalar : public BaseListScalar { +struct ARROW_EXPORT LargeListViewScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = LargeListViewType; using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit LargeListViewScalar(std::shared_ptr value, bool is_valid = true); + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT MapScalar : public BaseListScalar { +struct ARROW_EXPORT MapScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = MapType; using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit MapScalar(std::shared_ptr value, bool is_valid = true); + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { @@ -576,9 +655,10 @@ struct ARROW_EXPORT StructScalar : public Scalar { std::vector field_names); }; -struct ARROW_EXPORT UnionScalar : public Scalar, - private internal::ArraySpanFillFromScalarScratchSpace { - int8_t type_code; +struct ARROW_EXPORT UnionScalar : public Scalar { + // The type code is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with it. + const int8_t type_code; virtual const std::shared_ptr& child_value() const = 0; @@ -586,17 +666,31 @@ struct ARROW_EXPORT UnionScalar : public Scalar, UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid) : Scalar(std::move(type), is_valid), type_code(type_code) {} - friend struct ArraySpan; + struct UnionScratchSpace { + alignas(int64_t) int8_t type_code; + alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; + }; + static_assert(sizeof(UnionScratchSpace) <= internal::kScalarScratchSpaceSize); + + friend ArraySpan; }; -struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { +struct ARROW_EXPORT SparseUnionScalar + : public UnionScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = SparseUnionType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; // Even though only one of the union values is relevant for this scalar, we // nonetheless construct a vector of scalars, one per union value, to have // enough data to reconstruct a valid ArraySpan of length 1 from this scalar using ValueType = std::vector>; - ValueType value; + // The value is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with the value. It is also the user of + // this class's responsibility to ensure that the scalars of the vector is not modified + // to accidentally. + const ValueType value; // The value index corresponding to the active type code int child_id; @@ -611,30 +705,56 @@ struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { /// to construct a vector of scalars static std::shared_ptr FromValue(std::shared_ptr value, int field_index, std::shared_ptr type); + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { +struct ARROW_EXPORT DenseUnionScalar + : public UnionScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = DenseUnionType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; // For DenseUnionScalar, we can make a valid ArraySpan of length 1 from this // scalar using ValueType = std::shared_ptr; - ValueType value; + // The value is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with the value. It is also the user of + // this class's responsibility to ensure that the elements of the vector is not modified + // accidentally. + const ValueType value; const std::shared_ptr& child_value() const override { return this->value; } DenseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) : UnionScalar(std::move(type), type_code, value->is_valid), value(std::move(value)) {} + + private: + void FillScratchSpace(); + + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT RunEndEncodedScalar : public Scalar, - private internal::ArraySpanFillFromScalarScratchSpace { + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = RunEndEncodedType; using ValueType = std::shared_ptr; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; - ValueType value; + // The value is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with the value. It is also the user of + // this class's responsibility to ensure that the wrapped scalar is not modified + // accidentally. + const ValueType value; RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type); @@ -652,7 +772,10 @@ struct ARROW_EXPORT RunEndEncodedScalar private: const TypeClass& ree_type() const { return internal::checked_cast(*type); } + void FillScratchSpace(); + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; /// \brief A Scalar value for DictionaryType @@ -680,10 +803,6 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase { const void* data() const override { return internal::checked_cast(*value.index).data(); } - void* mutable_data() override { - return internal::checked_cast(*value.index) - .mutable_data(); - } std::string_view view() const override { return internal::checked_cast(*value.index) .view(); diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 09dfde32271..104a5697b57 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -95,6 +95,68 @@ TEST(TestNullScalar, ValidateErrors) { AssertValidationFails(scalar); } +TEST(TestNullScalar, Cast) { + NullScalar scalar; + for (auto to_type : { + int8(), + float64(), + date32(), + time32(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND), + duration(TimeUnit::SECOND), + utf8(), + large_binary(), + list(int32()), + struct_({field("f", int32())}), + map(utf8(), int32()), + decimal(12, 2), + list_view(int32()), + large_list(int32()), + dense_union({field("string", utf8()), field("number", uint64())}), + sparse_union({field("string", utf8()), field("number", uint64())}), + }) { + // Cast() function doesn't support casting null scalar, use Scalar::CastTo() instead. + ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(to_type)); + ASSERT_EQ(casted->type->id(), to_type->id()); + ASSERT_FALSE(casted->is_valid); + } +} + +TEST(TestBooleanScalar, Cast) { + for (auto b : {true, false}) { + BooleanScalar scalar(b); + ARROW_SCOPED_TRACE("boolean value: ", scalar.ToString()); + + // Boolean type (identity cast). + { + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, boolean())); + ASSERT_TRUE(casted.scalar()->Equals(scalar)) << casted.scalar()->ToString(); + } + + // Numeric types. + for (auto to_type : { + int8(), + uint16(), + int32(), + uint64(), + float32(), + float64(), + }) { + ARROW_SCOPED_TRACE("to type: ", to_type->ToString()); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, to_type)); + ASSERT_EQ(casted.scalar()->type->id(), to_type->id()); + ASSERT_EQ(casted.scalar()->ToString(), std::to_string(b)); + } + + // String type. + { + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8())); + ASSERT_EQ(casted.scalar()->type->id(), utf8()->id()); + ASSERT_EQ(casted.scalar()->ToString(), scalar.ToString()); + } + } +} + template class TestNumericScalar : public ::testing::Test { public: @@ -464,12 +526,23 @@ class TestDecimalScalar : public ::testing::Test { ::testing::HasSubstr("does not fit in precision of"), invalid.ValidateFull()); } + + void TestCast() { + const auto ty = std::make_shared(3, 2); + const auto pi = ScalarType(ValueType(314), ty); + + ASSERT_OK_AND_ASSIGN(auto casted, Cast(pi, utf8())); + ASSERT_TRUE(casted.scalar()->Equals(StringScalar("3.14"))) + << casted.scalar()->ToString(); + } }; TYPED_TEST_SUITE(TestDecimalScalar, DecimalArrowTypes); TYPED_TEST(TestDecimalScalar, Basics) { this->TestBasics(); } +TYPED_TEST(TestDecimalScalar, Cast) { this->TestCast(); } + TEST(TestBinaryScalar, Basics) { std::string data = "test data"; auto buf = std::make_shared(data); @@ -551,6 +624,14 @@ TEST(TestBinaryScalar, ValidateErrors) { AssertValidationFails(*null_scalar); } +TEST(TestBinaryScalar, Cast) { + BinaryScalar scalar(Buffer::FromString("test data")); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8())); + ASSERT_EQ(casted.scalar()->type->id(), utf8()->id()); + AssertBufferEqual(*checked_cast(*casted.scalar()).value, + *scalar.value); +} + template class TestStringScalar : public ::testing::Test { public: @@ -580,19 +661,25 @@ class TestStringScalar : public ::testing::Test { } void TestValidateErrors() { - // Inconsistent is_valid / value - ScalarType scalar(Buffer::FromString("xxx")); - scalar.is_valid = false; - AssertValidationFails(scalar); + { + // Inconsistent is_valid / value + ScalarType scalar(Buffer::FromString("xxx")); + scalar.is_valid = false; + AssertValidationFails(scalar); + } - auto null_scalar = MakeNullScalar(type_); - null_scalar->is_valid = true; - AssertValidationFails(*null_scalar); + { + auto null_scalar = MakeNullScalar(type_); + null_scalar->is_valid = true; + AssertValidationFails(*null_scalar); + } - // Invalid UTF8 - scalar = ScalarType(Buffer::FromString("\xff")); - ASSERT_OK(scalar.Validate()); - ASSERT_RAISES(Invalid, scalar.ValidateFull()); + { + // Invalid UTF8 + ScalarType scalar(Buffer::FromString("\xff")); + ASSERT_OK(scalar.Validate()); + ASSERT_RAISES(Invalid, scalar.ValidateFull()); + } } protected: @@ -676,8 +763,16 @@ TEST(TestFixedSizeBinaryScalar, ValidateErrors) { FixedSizeBinaryScalar scalar(buf, type); ASSERT_OK(scalar.ValidateFull()); - scalar.value = SliceBuffer(buf, 1); - AssertValidationFails(scalar); + ASSERT_RAISES(Invalid, MakeScalar(type, SliceBuffer(buf, 1))); +} + +TEST(TestFixedSizeBinaryScalar, Cast) { + std::string data = "test data"; + FixedSizeBinaryScalar scalar(data); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8())); + ASSERT_EQ(casted.scalar()->type->id(), utf8()->id()); + AssertBufferEqual(*checked_cast(*casted.scalar()).value, + *scalar.value); } TEST(TestDateScalars, Basics) { @@ -1136,24 +1231,25 @@ class TestListLikeScalar : public ::testing::Test { } void TestValidateErrors() { - ScalarType scalar(value_); - scalar.is_valid = false; - ASSERT_OK(scalar.ValidateFull()); - - // Value must be defined - scalar = ScalarType(value_); - scalar.value = nullptr; - AssertValidationFails(scalar); + { + ScalarType scalar(value_); + scalar.is_valid = false; + ASSERT_OK(scalar.ValidateFull()); + } - // Inconsistent child type - scalar = ScalarType(value_); - scalar.value = ArrayFromJSON(int32(), "[1, 2, null]"); - AssertValidationFails(scalar); + { + // Value must be defined + ScalarType scalar(nullptr, type_); + scalar.is_valid = true; + AssertValidationFails(scalar); + } - // Invalid UTF8 in child data - scalar = ScalarType(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]")); - ASSERT_OK(scalar.Validate()); - ASSERT_RAISES(Invalid, scalar.ValidateFull()); + { + // Invalid UTF8 in child data + ScalarType scalar(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]")); + ASSERT_OK(scalar.Validate()); + ASSERT_RAISES(Invalid, scalar.ValidateFull()); + } } void TestHashing() { @@ -1195,6 +1291,12 @@ class TestListLikeScalar : public ::testing::Test { auto invalid_cast_type = fixed_size_list(value_->type(), 5); CheckListCastError(scalar, invalid_cast_type); + + // Cast() function doesn't support casting list-like to string, use Scalar::CastTo() + // instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_EQ(casted_str->type->id(), utf8()->id()); + ASSERT_EQ(casted_str->ToString(), scalar.ToString()); } protected: @@ -1224,6 +1326,24 @@ TEST(TestFixedSizeListScalar, ValidateErrors) { AssertValidationFails(scalar); } +TEST(TestFixedSizeListScalar, Cast) { + const auto ty = fixed_size_list(int16(), 3); + FixedSizeListScalar scalar(ArrayFromJSON(int16(), "[1, 2, 5]"), ty); + + CheckListCast(scalar, list(int16())); + CheckListCast(scalar, large_list(int16())); + CheckListCast(scalar, fixed_size_list(int16(), 3)); + + auto invalid_cast_type = fixed_size_list(int16(), 4); + CheckListCastError(scalar, invalid_cast_type); + + // Cast() function doesn't support casting list-like to string, use Scalar::CastTo() + // instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_EQ(casted_str->type->id(), utf8()->id()); + ASSERT_EQ(casted_str->ToString(), scalar.ToString()); +} + TEST(TestMapScalar, Basics) { auto value = ArrayFromJSON(struct_({field("key", utf8(), false), field("value", int8())}), @@ -1253,6 +1373,12 @@ TEST(TestMapScalar, Cast) { auto invalid_cast_type = fixed_size_list(key_value_type, 5); CheckListCastError(scalar, invalid_cast_type); + + // Cast() function doesn't support casting map to string, use Scalar::CastTo() instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_TRUE(casted_str->Equals(StringScalar( + R"(map[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])"))) + << casted_str->ToString(); } TEST(TestStructScalar, FieldAccess) { @@ -1345,6 +1471,16 @@ TEST(TestStructScalar, ValidateErrors) { ASSERT_RAISES(Invalid, scalar.ValidateFull()); } +TEST(TestStructScalar, Cast) { + auto ty = struct_({field("i", int32()), field("s", utf8())}); + StructScalar scalar({MakeScalar(42), MakeScalar("xxx")}, ty); + + // Cast() function doesn't support casting map to string, use Scalar::CastTo() instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_TRUE(casted_str->Equals(StringScalar(R"({i:int32 = 42, s:string = xxx})"))) + << casted_str->ToString(); +} + TEST(TestDictionaryScalar, Basics) { for (auto index_ty : all_dictionary_index_types()) { auto ty = dictionary(index_ty, utf8()); @@ -1534,17 +1670,41 @@ void CheckGetNullUnionScalar(const Array& arr, int64_t index) { ASSERT_FALSE(checked_cast(*scalar).child_value()->is_valid); } +std::shared_ptr MakeUnionScalar(const SparseUnionType& type, int8_t type_code, + std::shared_ptr field_value, + int field_index) { + ScalarVector field_values; + for (int i = 0; i < type.num_fields(); ++i) { + if (i == field_index) { + field_values.emplace_back(std::move(field_value)); + } else { + field_values.emplace_back(MakeNullScalar(type.field(i)->type())); + } + } + return std::make_shared(std::move(field_values), type_code, + type.GetSharedPtr()); +} + std::shared_ptr MakeUnionScalar(const SparseUnionType& type, std::shared_ptr field_value, int field_index) { - return SparseUnionScalar::FromValue(field_value, field_index, type.GetSharedPtr()); + return SparseUnionScalar::FromValue(std::move(field_value), field_index, + type.GetSharedPtr()); +} + +std::shared_ptr MakeUnionScalar(const DenseUnionType& type, int8_t type_code, + std::shared_ptr field_value, + int field_index) { + return std::make_shared(std::move(field_value), type_code, + type.GetSharedPtr()); } std::shared_ptr MakeUnionScalar(const DenseUnionType& type, std::shared_ptr field_value, int field_index) { int8_t type_code = type.type_codes()[field_index]; - return std::make_shared(field_value, type_code, type.GetSharedPtr()); + return std::make_shared(std::move(field_value), type_code, + type.GetSharedPtr()); } std::shared_ptr MakeSpecificNullScalar(const DenseUnionType& type, @@ -1592,7 +1752,13 @@ class TestUnionScalar : public ::testing::Test { std::shared_ptr ScalarFromValue(int field_index, std::shared_ptr field_value) { - return MakeUnionScalar(*union_type_, field_value, field_index); + return MakeUnionScalar(*union_type_, std::move(field_value), field_index); + } + + std::shared_ptr ScalarFromTypeCodeAndValue(int8_t type_code, + std::shared_ptr field_value, + int field_index) { + return MakeUnionScalar(*union_type_, type_code, std::move(field_value), field_index); } std::shared_ptr SpecificNull(int field_index) { @@ -1610,40 +1776,48 @@ class TestUnionScalar : public ::testing::Test { } void TestValidateErrors() { - // Type code doesn't exist - auto scalar = ScalarFromValue(0, alpha_); - UnionScalar* union_scalar = static_cast(scalar.get()); - - // Invalid type code - union_scalar->type_code = 0; - AssertValidationFails(*union_scalar); + { + // Invalid type code + auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0); + AssertValidationFails(*scalar); + } - union_scalar->is_valid = false; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0); + scalar->is_valid = false; + AssertValidationFails(*scalar); + } - union_scalar->type_code = -42; - union_scalar->is_valid = true; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0); + AssertValidationFails(*scalar); + } - union_scalar->is_valid = false; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0); + scalar->is_valid = false; + AssertValidationFails(*scalar); + } // Type code doesn't correspond to child type if (type_->id() == ::arrow::Type::DENSE_UNION) { - union_scalar->type_code = 42; - union_scalar->is_valid = true; - AssertValidationFails(*union_scalar); - - scalar = ScalarFromValue(2, two_); - union_scalar = static_cast(scalar.get()); - union_scalar->type_code = 3; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(42, alpha_, 0); + AssertValidationFails(*scalar); + } + + { + auto scalar = ScalarFromTypeCodeAndValue(3, two_, 2); + AssertValidationFails(*scalar); + } } - // underlying value has invalid UTF8 - scalar = ScalarFromValue(0, std::make_shared("\xff")); - ASSERT_OK(scalar->Validate()); - ASSERT_RAISES(Invalid, scalar->ValidateFull()); + { + // underlying value has invalid UTF8 + auto scalar = ScalarFromValue(0, std::make_shared("\xff")); + ASSERT_OK(scalar->Validate()); + ASSERT_RAISES(Invalid, scalar->ValidateFull()); + } } void TestEquals() { @@ -1680,6 +1854,14 @@ class TestUnionScalar : public ::testing::Test { } } + void TestCast() { + // Cast() function doesn't support casting union to string, use Scalar::CastTo() + // instead. + ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8())); + ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})"))) + << casted->ToString(); + } + protected: std::shared_ptr type_; const UnionType* union_type_; @@ -1698,6 +1880,8 @@ TYPED_TEST(TestUnionScalar, Equals) { this->TestEquals(); } TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); } +TYPED_TEST(TestUnionScalar, Cast) { this->TestCast(); } + class TestSparseUnionScalar : public TestUnionScalar {}; TEST_F(TestSparseUnionScalar, GetScalar) { @@ -1974,14 +2158,14 @@ TEST_F(TestExtensionScalar, ValidateErrors) { scalar.is_valid = false; ASSERT_OK(scalar.ValidateFull()); - // Invalid storage scalar (wrong length) - std::shared_ptr invalid_storage = MakeNullScalar(storage_type_); - invalid_storage->is_valid = true; - static_cast(invalid_storage.get())->value = - std::make_shared("123"); - AssertValidationFails(*invalid_storage); + // Invalid storage scalar (invalid UTF8) + ASSERT_OK_AND_ASSIGN(std::shared_ptr invalid_storage, + MakeScalar(utf8(), std::make_shared("\xff"))); + ASSERT_OK(invalid_storage->Validate()); + ASSERT_RAISES(Invalid, invalid_storage->ValidateFull()); scalar = ExtensionScalar(invalid_storage, type_); - AssertValidationFails(scalar); + ASSERT_OK(scalar.Validate()); + ASSERT_RAISES(Invalid, scalar.ValidateFull()); } } // namespace arrow