From 49a3dd167ced33b94abdc355c5e8e7f2a0f15944 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Thu, 28 Mar 2024 01:39:35 +0800 Subject: [PATCH 01/31] WIP --- cpp/src/arrow/scalar.cc | 60 ++++++++++++++++- cpp/src/arrow/scalar.h | 139 ++++++++++++++++++++++++++++++++++------ 2 files changed, 178 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 6996b46c8b6..e88cec9cf8f 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -542,6 +542,30 @@ struct ScalarValidateImpl { } }; +// // Helper function to fill scratch space, base case for recursion +// void FillScratchSpaceX() {} + +// // Recursive variadic function to fill scratch space +// template +// void FillScratchSpaceX(T first, Args... args) { +// // Ensure that the argument does not exceed the bounds of the scratch space +// static_assert(offset + sizeof(T) <= +// sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_), +// "Total size of arguments exceeds scratch space size."); + +// // Cast the scratch space at the given offset to the type of the current argument and +// // assign it +// *reinterpret_cast(scratch_space_ + offset) = first; + +// // Calculate the next offset based on the size of the current type T +// constexpr size_t next_offset = offset + sizeof(T); + +// // Recursively fill the scratch space with the remaining arguments +// if constexpr (sizeof...(args) > 0) { // Use if constexpr to stop recursion when +// // there are no more arguments +// FillScratchSpaceX(args...); +// } +// } } // namespace size_t Scalar::hash() const { return ScalarHashImpl(*this).hash_; } @@ -554,8 +578,38 @@ Status Scalar::ValidateFull() const { return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } -BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type) - : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} +// template +// BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, +// Args... args) +// : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), +// std::forward(args)...) {} + +BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, + FillScratchSpaceByValueFn fn) + : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), + std::move(fn)) {} + +BinaryScalar::BinaryScalar(std::string s, std::shared_ptr type) + : BinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} + +BinaryViewScalar::BinaryViewScalar(std::string s, std::shared_ptr type) + : BinaryViewScalar(Buffer::FromString(std::move(s)), std::move(type)) {} + +LargeBinaryScalar::LargeBinaryScalar(std::string s, std::shared_ptr type) + : LargeBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} + +// BinaryViewType::c_type* BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space_, +// bool is_valid, +// const Buffer* value) { +// static_assert(sizeof(BinaryViewType::c_type) <= +// sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); +// auto* view = new (&scratch_space_) BinaryViewType::c_type; +// if (is_valid) { +// *view = util::ToBinaryView(std::string_view{*value}, 0, 0); +// } else { +// *view = {}; +// } +// } FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, std::shared_ptr type, @@ -730,7 +784,7 @@ Result TimestampScalar::FromISO8601(std::string_view iso8601, SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) - : UnionScalar(std::move(type), type_code, /*is_valid=*/true), + : UnionScalar(std::move(type), type_code, /*is_valid=*/true, FillScratchSpace), value(std::move(value)) { this->child_id = checked_cast(*this->type).child_ids()[type_code]; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 65c5ee4df0a..2306b135a9f 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -136,6 +136,33 @@ struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { // Scalar- including binary scalars where we need to create a buffer // that looks like two 32-bit or 64-bit offsets. alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; + + using FillScratchFn = std::function; + + explicit ArraySpanFillFromScalarScratchSpace(FillScratchFn fn) { + fn(scratch_space_); + } + + // template + // explicit ArraySpanFillFromScalarScratchSpace(Args... args) { + // FillScratchSpace<0>(std::forward(args)...); + // } + + private: + // Helper function to fill scratch space, base case for recursion + void FillScratchSpace() {} + + // Recursive variadic function to fill scratch space + template + void FillScratchSpace(T first, Args... args) { + static_assert(offset + sizeof(T) <= sizeof(scratch_space_), + "Total size of arguments exceeds scratch space size."); + *reinterpret_cast(scratch_space_ + offset) = first; + constexpr size_t next_offset = offset + sizeof(T); + if constexpr (sizeof...(args) > 0) { + FillScratchSpace(args...); + } + } }; struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { @@ -248,7 +275,21 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase, private internal::ArraySpanFillFromScalarScratchSpace { - using internal::PrimitiveScalarBase::PrimitiveScalarBase; + // using internal::PrimitiveScalarBase::PrimitiveScalarBase; + using FillScratchSpaceByValueFn = std::function; + + // template + // explicit BaseBinaryScalar(std::shared_ptr type, Args... args) + // : PrimitiveScalarBase(std::move(type)), + // ArraySpanFillFromScalarScratchSpace(std::forward(args)...) {} + + BaseBinaryScalar(std::shared_ptr type, + FillScratchSpaceByValueFn fn) + : PrimitiveScalarBase(std::move(type)), + ArraySpanFillFromScalarScratchSpace([&](uint8_t* scratch_space) { + fn(scratch_space, false, nullptr); + }) {} + using ValueType = std::shared_ptr; std::shared_ptr value; @@ -263,23 +304,51 @@ struct ARROW_EXPORT BaseBinaryScalar return value ? std::string_view(*value) : std::string_view(); } - BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type) - : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {} + // template + // BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, + // Args... args) + // : internal::PrimitiveScalarBase{std::move(type), true}, + // ArraySpanFillFromScalarScratchSpace(std::forward(args)...), + // value(std::move(value)) {} + + BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, + FillScratchSpaceByValueFn fn) + : internal::PrimitiveScalarBase{std::move(type), true}, + ArraySpanFillFromScalarScratchSpace([&](uint8_t* scratch_space) { + fn(scratch_space, true, value.get()); + }), + value(std::move(value)) {} friend ArraySpan; - BaseBinaryScalar(std::string s, std::shared_ptr type); + // template + // BaseBinaryScalar(std::string s, std::shared_ptr type, Args... args); + + BaseBinaryScalar(std::string s, std::shared_ptr type, + FillScratchSpaceByValueFn fn); }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { - using BaseBinaryScalar::BaseBinaryScalar; + // using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryType; + explicit BinaryScalar(std::shared_ptr type) + : BaseBinaryScalar(std::move(type), FillScratchSpace) {} + + BinaryScalar(std::shared_ptr value, std::shared_ptr type) + : BaseBinaryScalar(std::move(value), std::move(type), FillScratchSpace) {} + + BinaryScalar(std::string s, std::shared_ptr type); + explicit BinaryScalar(std::shared_ptr value) : BinaryScalar(std::move(value), binary()) {} - explicit BinaryScalar(std::string s) : BaseBinaryScalar(std::move(s), binary()) {} + explicit BinaryScalar(std::string s) : BinaryScalar(std::move(s), binary()) {} BinaryScalar() : BinaryScalar(binary()) {} + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Buffer* value) {} }; struct ARROW_EXPORT StringScalar : public BinaryScalar { @@ -287,26 +356,38 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar { using TypeClass = StringType; explicit StringScalar(std::shared_ptr value) - : StringScalar(std::move(value), utf8()) {} + : BinaryScalar(std::move(value), utf8()) {} explicit StringScalar(std::string s) : BinaryScalar(std::move(s), utf8()) {} - StringScalar() : StringScalar(utf8()) {} + StringScalar() : BinaryScalar(utf8()) {} }; struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { - using BaseBinaryScalar::BaseBinaryScalar; + // using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryViewType; + explicit BinaryViewScalar(std::shared_ptr type) + : BaseBinaryScalar(std::move(type), FillScratchSpace) {} + + BinaryViewScalar(std::shared_ptr value, std::shared_ptr type) + : BaseBinaryScalar(std::move(value), std::move(type), FillScratchSpace) {} + + BinaryViewScalar(std::string s, std::shared_ptr type); + explicit BinaryViewScalar(std::shared_ptr value) : BinaryViewScalar(std::move(value), binary_view()) {} explicit BinaryViewScalar(std::string s) - : BaseBinaryScalar(std::move(s), binary_view()) {} + : BinaryViewScalar(std::move(s), binary_view()) {} BinaryViewScalar() : BinaryViewScalar(binary_view()) {} std::string_view view() const override { return std::string_view(*this->value); } + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Buffer* value) {} }; struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { @@ -314,28 +395,37 @@ struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { using TypeClass = StringViewType; explicit StringViewScalar(std::shared_ptr value) - : StringViewScalar(std::move(value), utf8_view()) {} + : BinaryViewScalar(std::move(value), utf8_view()) {} explicit StringViewScalar(std::string s) : BinaryViewScalar(std::move(s), utf8_view()) {} - StringViewScalar() : StringViewScalar(utf8_view()) {} + StringViewScalar() : BinaryViewScalar(utf8_view()) {} }; struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = LargeBinaryType; + explicit LargeBinaryScalar(std::shared_ptr type) + : BaseBinaryScalar(std::move(type), FillScratchSpace) {} + LargeBinaryScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type)) {} + : BaseBinaryScalar(std::move(value), std::move(type), FillScratchSpace) {} + + LargeBinaryScalar(std::string s, std::shared_ptr type); explicit LargeBinaryScalar(std::shared_ptr value) : LargeBinaryScalar(std::move(value), large_binary()) {} explicit LargeBinaryScalar(std::string s) - : BaseBinaryScalar(std::move(s), large_binary()) {} + : LargeBinaryScalar(std::move(s), large_binary()) {} LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {} + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Buffer* value) {} }; struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { @@ -343,7 +433,7 @@ struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { using TypeClass = LargeStringType; explicit LargeStringScalar(std::shared_ptr value) - : LargeStringScalar(std::move(value), large_utf8()) {} + : LargeBinaryScalar(std::move(value), large_utf8()) {} explicit LargeStringScalar(std::string s) : LargeBinaryScalar(std::move(s), large_utf8()) {} @@ -583,8 +673,13 @@ struct ARROW_EXPORT UnionScalar : public Scalar, virtual const std::shared_ptr& child_value() const = 0; protected: - UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid) - : Scalar(std::move(type), is_valid), type_code(type_code) {} + using FillScratchSpaceByValueFn = std::function; + UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid, + FillScratchSpaceByValueFn fn) + : Scalar(std::move(type), is_valid), + internal::ArraySpanFillFromScalarScratchSpace( + [&](uint8_t* scratch_space) { fn(scratch_space, type_code); }), + type_code(type_code) {} friend struct ArraySpan; }; @@ -611,6 +706,10 @@ struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { /// to construct a vector of scalars static std::shared_ptr FromValue(std::shared_ptr value, int field_index, std::shared_ptr type); + + private: + static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code) { + } }; struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { @@ -624,8 +723,12 @@ struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { const std::shared_ptr& child_value() const override { return this->value; } DenseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) - : UnionScalar(std::move(type), type_code, value->is_valid), + : UnionScalar(std::move(type), type_code, value->is_valid, FillScratchSpace), value(std::move(value)) {} + + private: + static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code) { + } }; struct ARROW_EXPORT RunEndEncodedScalar From 3f21d40ef4e95d6d1267ab1c584c724329ef0e0c Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Fri, 29 Mar 2024 01:56:10 +0800 Subject: [PATCH 02/31] WIP --- cpp/src/arrow/scalar.cc | 95 ++++++++++++++++++++++++---- cpp/src/arrow/scalar.h | 134 +++++++++++++++++++--------------------- 2 files changed, 144 insertions(+), 85 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index e88cec9cf8f..3d2722018d5 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -586,18 +586,27 @@ Status Scalar::ValidateFull() const { BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, FillScratchSpaceByValueFn fn) - : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), - std::move(fn)) {} + : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), std::move(fn)) { +} BinaryScalar::BinaryScalar(std::string s, std::shared_ptr type) : BinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} +void BinaryScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Buffer* value) {} + BinaryViewScalar::BinaryViewScalar(std::string s, std::shared_ptr type) : BinaryViewScalar(Buffer::FromString(std::move(s)), std::move(type)) {} +void BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Buffer* value) {} + LargeBinaryScalar::LargeBinaryScalar(std::string s, std::shared_ptr type) : LargeBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} +void LargeBinaryScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Buffer* value) {} + // BinaryViewType::c_type* BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space_, // bool is_valid, // const Buffer* value) { @@ -630,22 +639,58 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid) : FixedSizeBinaryScalar(Buffer::FromString(std::move(s)), is_valid) {} BaseListScalar::BaseListScalar(std::shared_ptr value, - std::shared_ptr type, bool is_valid) - : Scalar{std::move(type), is_valid}, value(std::move(value)) { + std::shared_ptr type, bool is_valid, + FillScratchSpaceByValueFn fn) + : Scalar{std::move(type), is_valid}, + ArraySpanFillFromScalarScratchSpace( + [&](uint8_t* scratch_space) { fn(scratch_space, is_valid, value.get()); }), + value(std::move(value)) { ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); } +ListScalar::ListScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid) + : BaseListScalar(std::move(value), std::move(type), is_valid, + ListScalar::FillScratchSpace) {} + ListScalar::ListScalar(std::shared_ptr value, bool is_valid) - : BaseListScalar(value, list(value->type()), is_valid) {} + : ListScalar(std::move(value), list(value->type()), is_valid) {} + +void ListScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Array* value) {} + +LargeListScalar::LargeListScalar(std::shared_ptr value, + std::shared_ptr type, bool is_valid) + : BaseListScalar(std::move(value), std::move(type), is_valid, + LargeListScalar::FillScratchSpace) {} LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) - : BaseListScalar(value, large_list(value->type()), is_valid) {} + : LargeListScalar(std::move(value), large_list(value->type()), is_valid) {} + +void LargeListScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Array* value) {} + +ListViewScalar::ListViewScalar(std::shared_ptr value, + std::shared_ptr type, bool is_valid) + : BaseListScalar(std::move(value), std::move(type), is_valid, + ListViewScalar::FillScratchSpace) {} ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) - : BaseListScalar(value, list_view(value->type()), is_valid) {} + : ListViewScalar(std::move(value), list_view(value->type()), is_valid) {} + +void ListViewScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Array* value) {} + +LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, + std::shared_ptr type, bool is_valid) + : BaseListScalar(std::move(value), std::move(type), is_valid, + LargeListViewScalar::FillScratchSpace) {} LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) - : BaseListScalar(value, large_list_view(value->type()), is_valid) {} + : LargeListViewScalar(std::move(value), large_list_view(value->type()), is_valid) {} + +void LargeListViewScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Array* value) {} inline std::shared_ptr MakeMapType(const std::shared_ptr& pair_type) { ARROW_CHECK_EQ(pair_type->id(), Type::STRUCT); @@ -653,21 +698,34 @@ inline std::shared_ptr MakeMapType(const std::shared_ptr& pa return map(pair_type->field(0)->type(), pair_type->field(1)->type()); } +MapScalar::MapScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid) + : BaseListScalar(std::move(value), std::move(type), is_valid, + MapScalar::FillScratchSpace) {} + MapScalar::MapScalar(std::shared_ptr value, bool is_valid) - : BaseListScalar(value, MakeMapType(value->type()), is_valid) {} + : MapScalar(std::move(value), MakeMapType(value->type()), is_valid) {} + +void MapScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Array* value) {} FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) - : BaseListScalar(value, std::move(type), is_valid) { + : BaseListScalar(std::move(value), std::move(type), is_valid, + FixedSizeListScalar::FillScratchSpace) { ARROW_CHECK_EQ(this->value->length(), checked_cast(*this->type).list_size()); } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, bool is_valid) - : BaseListScalar( - value, fixed_size_list(value->type(), static_cast(value->length())), + : FixedSizeListScalar( + std::move(value), + fixed_size_list(value->type(), static_cast(value->length())), is_valid) {} +void FixedSizeListScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, + const Array* value) {} + Result> StructScalar::Make( ScalarVector values, std::vector field_names) { if (values.size() != field_names.size()) { @@ -697,9 +755,17 @@ Result> StructScalar::field(FieldRef ref) const { } } +void SparseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {} + +void DenseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {} + RunEndEncodedScalar::RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type) - : Scalar{std::move(type), value->is_valid}, value{std::move(value)} { + : Scalar{std::move(type), value->is_valid}, + ArraySpanFillFromScalarScratchSpace([&](uint8_t* scratch_space) { + FillScratchSpace(scratch_space, *(this->type)); + }), + value{std::move(value)} { ARROW_CHECK_EQ(this->type->id(), Type::RUN_END_ENCODED); } @@ -710,6 +776,9 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr& type) RunEndEncodedScalar::~RunEndEncodedScalar() = default; +void RunEndEncodedScalar::FillScratchSpace(uint8_t* scratch_space, const DataType& type) { +} + DictionaryScalar::DictionaryScalar(std::shared_ptr type) : internal::PrimitiveScalarBase(std::move(type)), value{MakeNullScalar(checked_cast(*this->type).index_type()), diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 2306b135a9f..876c9bab120 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -137,32 +137,12 @@ struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { // that looks like two 32-bit or 64-bit offsets. alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; - using FillScratchFn = std::function; + protected: + using FillScratchSpaceFn = std::function; - explicit ArraySpanFillFromScalarScratchSpace(FillScratchFn fn) { + explicit ArraySpanFillFromScalarScratchSpace(FillScratchSpaceFn fn) { fn(scratch_space_); } - - // template - // explicit ArraySpanFillFromScalarScratchSpace(Args... args) { - // FillScratchSpace<0>(std::forward(args)...); - // } - - private: - // Helper function to fill scratch space, base case for recursion - void FillScratchSpace() {} - - // Recursive variadic function to fill scratch space - template - void FillScratchSpace(T first, Args... args) { - static_assert(offset + sizeof(T) <= sizeof(scratch_space_), - "Total size of arguments exceeds scratch space size."); - *reinterpret_cast(scratch_space_ + offset) = first; - constexpr size_t next_offset = offset + sizeof(T); - if constexpr (sizeof...(args) > 0) { - FillScratchSpace(args...); - } - } }; struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { @@ -275,21 +255,6 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase, private internal::ArraySpanFillFromScalarScratchSpace { - // using internal::PrimitiveScalarBase::PrimitiveScalarBase; - using FillScratchSpaceByValueFn = std::function; - - // template - // explicit BaseBinaryScalar(std::shared_ptr type, Args... args) - // : PrimitiveScalarBase(std::move(type)), - // ArraySpanFillFromScalarScratchSpace(std::forward(args)...) {} - - BaseBinaryScalar(std::shared_ptr type, - FillScratchSpaceByValueFn fn) - : PrimitiveScalarBase(std::move(type)), - ArraySpanFillFromScalarScratchSpace([&](uint8_t* scratch_space) { - fn(scratch_space, false, nullptr); - }) {} - using ValueType = std::shared_ptr; std::shared_ptr value; @@ -304,31 +269,27 @@ struct ARROW_EXPORT BaseBinaryScalar return value ? std::string_view(*value) : std::string_view(); } - // template - // BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, - // Args... args) - // : internal::PrimitiveScalarBase{std::move(type), true}, - // ArraySpanFillFromScalarScratchSpace(std::forward(args)...), - // value(std::move(value)) {} + using FillScratchSpaceByValueFn = std::function; + + BaseBinaryScalar(std::shared_ptr type, FillScratchSpaceByValueFn fn) + : PrimitiveScalarBase(std::move(type)), + ArraySpanFillFromScalarScratchSpace( + [&](uint8_t* scratch_space) { fn(scratch_space, false, nullptr); }) {} BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, FillScratchSpaceByValueFn fn) : internal::PrimitiveScalarBase{std::move(type), true}, - ArraySpanFillFromScalarScratchSpace([&](uint8_t* scratch_space) { - fn(scratch_space, true, value.get()); - }), + ArraySpanFillFromScalarScratchSpace( + [&](uint8_t* scratch_space) { fn(scratch_space, true, value.get()); }), value(std::move(value)) {} - friend ArraySpan; - // template - // BaseBinaryScalar(std::string s, std::shared_ptr type, Args... args); - BaseBinaryScalar(std::string s, std::shared_ptr type, FillScratchSpaceByValueFn fn); + + friend ArraySpan; }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { - // using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryType; explicit BinaryScalar(std::shared_ptr type) @@ -348,7 +309,7 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { private: static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value) {} + const Buffer* value); }; struct ARROW_EXPORT StringScalar : public BinaryScalar { @@ -364,7 +325,6 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar { }; struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { - // using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryViewType; explicit BinaryViewScalar(std::shared_ptr type) @@ -387,7 +347,7 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { private: static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value) {} + const Buffer* value); }; struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { @@ -425,7 +385,7 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { private: static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value) {} + const Buffer* value); }; struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { @@ -595,11 +555,12 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar; + using FillScratchSpaceByValueFn = std::function; + BaseListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid = true); + bool is_valid, FillScratchSpaceByValueFn fn); std::shared_ptr value; @@ -609,37 +570,59 @@ struct ARROW_EXPORT BaseListScalar struct ARROW_EXPORT ListScalar : public BaseListScalar { using TypeClass = ListType; - using BaseListScalar::BaseListScalar; + + ListScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid = true); explicit ListScalar(std::shared_ptr value, bool is_valid = true); + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); }; struct ARROW_EXPORT LargeListScalar : public BaseListScalar { using TypeClass = LargeListType; - using BaseListScalar::BaseListScalar; + LargeListScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid = true); explicit LargeListScalar(std::shared_ptr value, bool is_valid = true); + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); }; struct ARROW_EXPORT ListViewScalar : public BaseListScalar { using TypeClass = ListViewType; - using BaseListScalar::BaseListScalar; + ListViewScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid = true); explicit ListViewScalar(std::shared_ptr value, bool is_valid = true); + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); }; struct ARROW_EXPORT LargeListViewScalar : public BaseListScalar { using TypeClass = LargeListViewType; - using BaseListScalar::BaseListScalar; + LargeListViewScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid = true); explicit LargeListViewScalar(std::shared_ptr value, bool is_valid = true); + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); }; struct ARROW_EXPORT MapScalar : public BaseListScalar { using TypeClass = MapType; - using BaseListScalar::BaseListScalar; + // using BaseListScalar::BaseListScalar; + MapScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid = true); explicit MapScalar(std::shared_ptr value, bool is_valid = true); + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); }; struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { @@ -649,6 +632,9 @@ struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { bool is_valid = true); explicit FixedSizeListScalar(std::shared_ptr value, bool is_valid = true); + + private: + static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); }; struct ARROW_EXPORT StructScalar : public Scalar { @@ -707,9 +693,8 @@ struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { static std::shared_ptr FromValue(std::shared_ptr value, int field_index, std::shared_ptr type); - private: - static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code) { - } + private: + static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code); }; struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { @@ -726,9 +711,8 @@ struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { : UnionScalar(std::move(type), type_code, value->is_valid, FillScratchSpace), value(std::move(value)) {} - private: - static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code) { - } + private: + static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code); }; struct ARROW_EXPORT RunEndEncodedScalar @@ -747,13 +731,19 @@ struct ARROW_EXPORT RunEndEncodedScalar ~RunEndEncodedScalar() override; const std::shared_ptr& run_end_type() const { - return ree_type().run_end_type(); + return ree_type(*type).run_end_type(); } - const std::shared_ptr& value_type() const { return ree_type().value_type(); } + const std::shared_ptr& value_type() const { + return ree_type(*type).value_type(); + } private: - const TypeClass& ree_type() const { return internal::checked_cast(*type); } + static const TypeClass& ree_type(const DataType& type) { + return internal::checked_cast(type); + } + + static void FillScratchSpace(uint8_t* scratch_space, const DataType& type); friend ArraySpan; }; From d2ba93aef91aecd54baec68825fb7f6daa6b3e03 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Fri, 29 Mar 2024 17:03:16 +0800 Subject: [PATCH 03/31] Scalar initialize scratch space in ctor --- cpp/src/arrow/scalar.cc | 214 +++++++++++++++++++++++++--------------- cpp/src/arrow/scalar.h | 182 +++++++++++++++++++++++++--------- 2 files changed, 267 insertions(+), 129 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 3d2722018d5..079b400b86e 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -543,29 +543,35 @@ struct ScalarValidateImpl { }; // // Helper function to fill scratch space, base case for recursion -// void FillScratchSpaceX() {} - -// // Recursive variadic function to fill scratch space -// template -// void FillScratchSpaceX(T first, Args... args) { -// // Ensure that the argument does not exceed the bounds of the scratch space -// static_assert(offset + sizeof(T) <= -// sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_), -// "Total size of arguments exceeds scratch space size."); - -// // Cast the scratch space at the given offset to the type of the current argument and -// // assign it -// *reinterpret_cast(scratch_space_ + offset) = first; - -// // Calculate the next offset based on the size of the current type T -// constexpr size_t next_offset = offset + sizeof(T); - -// // Recursively fill the scratch space with the remaining arguments -// if constexpr (sizeof...(args) > 0) { // Use if constexpr to stop recursion when -// // there are no more arguments -// FillScratchSpaceX(args...); -// } -// } +// void FillScalarScratchSpaceHelperInternal(uint8_t* scratch_space) {} + +// Recursive variadic function to fill scratch space +template +void FillScalarScratchSpaceHelperInternal(uint8_t* scratch_space, T first, Args... args) { + // Ensure that the argument does not exceed the bounds of the scratch space + static_assert(offset + sizeof(T) <= + sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_), + "Total size of arguments exceeds scratch space size."); + + // Cast the scratch space at the given offset to the type of the current argument and + // assign it + *reinterpret_cast(scratch_space + offset) = first; + + // Calculate the next offset based on the size of the current type T + constexpr size_t next_offset = offset + sizeof(T); + + // Recursively fill the scratch space with the remaining arguments + if constexpr (sizeof...(args) > 0) { // Use if constexpr to stop recursion when + // there are no more arguments + FillScalarScratchSpaceHelperInternal(scratch_space, + std::forward(args)...); + } +} + +template +void FillScalarScratchSpaceHelper(Args... args) { + FillScalarScratchSpaceHelperInternal<0>(std::forward(args)...); +} } // namespace size_t Scalar::hash() const { return ScalarHashImpl(*this).hash_; } @@ -578,47 +584,44 @@ Status Scalar::ValidateFull() const { return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } -// template -// BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, -// Args... args) -// : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), -// std::forward(args)...) {} - BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, - FillScratchSpaceByValueFn fn) - : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), std::move(fn)) { -} + const FillScalarScratchSpace& fill) + : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), fill) {} BinaryScalar::BinaryScalar(std::string s, std::shared_ptr type) : BinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} -void BinaryScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value) {} +void BinaryScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper( + scratch_space.scratch_space_, int32_t(0), + is_valid_ ? static_cast(value_->size()) : int32_t(0)); +} BinaryViewScalar::BinaryViewScalar(std::string s, std::shared_ptr type) : BinaryViewScalar(Buffer::FromString(std::move(s)), std::move(type)) {} -void BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value) {} +void BinaryViewScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + static_assert(sizeof(BinaryViewType::c_type) <= + sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); + auto* view = new (&scratch_space.scratch_space_) BinaryViewType::c_type; + if (is_valid_) { + *view = util::ToBinaryView(std::string_view{*value_}, 0, 0); + } else { + *view = {}; + } +} LargeBinaryScalar::LargeBinaryScalar(std::string s, std::shared_ptr type) : LargeBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} -void LargeBinaryScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value) {} - -// BinaryViewType::c_type* BinaryViewScalar::FillScratchSpace(uint8_t* scratch_space_, -// bool is_valid, -// const Buffer* value) { -// static_assert(sizeof(BinaryViewType::c_type) <= -// sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); -// auto* view = new (&scratch_space_) BinaryViewType::c_type; -// if (is_valid) { -// *view = util::ToBinaryView(std::string_view{*value}, 0, 0); -// } else { -// *view = {}; -// } -// } +void LargeBinaryScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper( + scratch_space.scratch_space_, int64_t(0), + is_valid_ ? static_cast(value_->size()) : int64_t(0)); +} FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, std::shared_ptr type, @@ -640,10 +643,9 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid) BaseListScalar::BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid, - FillScratchSpaceByValueFn fn) + const FillScalarScratchSpace& fill) : Scalar{std::move(type), is_valid}, - ArraySpanFillFromScalarScratchSpace( - [&](uint8_t* scratch_space) { fn(scratch_space, is_valid, value.get()); }), + ArraySpanFillFromScalarScratchSpace(fill), value(std::move(value)) { ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); } @@ -651,46 +653,60 @@ BaseListScalar::BaseListScalar(std::shared_ptr value, ListScalar::ListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - ListScalar::FillScratchSpace) {} + ListScalar::FillScalarScratchSpace(value.get())) {} ListScalar::ListScalar(std::shared_ptr value, bool is_valid) : ListScalar(std::move(value), list(value->type()), is_valid) {} -void ListScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Array* value) {} +void ListScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper( + scratch_space.scratch_space_, int32_t(0), + value_ ? static_cast(value_->length()) : int32_t(0)); +} LargeListScalar::LargeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - LargeListScalar::FillScratchSpace) {} + LargeListScalar::FillScalarScratchSpace(value.get())) {} LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) : LargeListScalar(std::move(value), large_list(value->type()), is_valid) {} -void LargeListScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Array* value) {} +void LargeListScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(0), + value_ ? value_->length() : int64_t(0)); +} ListViewScalar::ListViewScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - ListViewScalar::FillScratchSpace) {} + ListViewScalar::FillScalarScratchSpace(value.get())) {} ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) : ListViewScalar(std::move(value), list_view(value->type()), is_valid) {} -void ListViewScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Array* value) {} +void ListViewScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper( + scratch_space.scratch_space_, int32_t(0), + value_ ? static_cast(value_->length()) : int32_t(0)); +} LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - LargeListViewScalar::FillScratchSpace) {} + LargeListViewScalar::FillScalarScratchSpace(value.get())) {} LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) : LargeListViewScalar(std::move(value), large_list_view(value->type()), is_valid) {} -void LargeListViewScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Array* value) {} +void LargeListViewScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(0), + value_ ? value_->length() : int64_t(0)); +} inline std::shared_ptr MakeMapType(const std::shared_ptr& pair_type) { ARROW_CHECK_EQ(pair_type->id(), Type::STRUCT); @@ -701,18 +717,22 @@ inline std::shared_ptr MakeMapType(const std::shared_ptr& pa MapScalar::MapScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - MapScalar::FillScratchSpace) {} + MapScalar::FillScalarScratchSpace(value.get())) {} MapScalar::MapScalar(std::shared_ptr value, bool is_valid) : MapScalar(std::move(value), MakeMapType(value->type()), is_valid) {} -void MapScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Array* value) {} +void MapScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + FillScalarScratchSpaceHelper( + scratch_space.scratch_space_, int32_t(0), + value_ ? static_cast(value_->length()) : int32_t(0)); +} FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - FixedSizeListScalar::FillScratchSpace) { + FixedSizeListScalar::FillScalarScratchSpace(value.get())) { ARROW_CHECK_EQ(this->value->length(), checked_cast(*this->type).list_size()); } @@ -723,9 +743,6 @@ FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, bool is_v fixed_size_list(value->type(), static_cast(value->length())), is_valid) {} -void FixedSizeListScalar::FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Array* value) {} - Result> StructScalar::Make( ScalarVector values, std::vector field_names) { if (values.size() != field_names.size()) { @@ -755,16 +772,10 @@ Result> StructScalar::field(FieldRef ref) const { } } -void SparseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {} - -void DenseUnionScalar::FillScratchSpace(uint8_t* scratch_space, int8_t type_code) {} - RunEndEncodedScalar::RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type) : Scalar{std::move(type), value->is_valid}, - ArraySpanFillFromScalarScratchSpace([&](uint8_t* scratch_space) { - FillScratchSpace(scratch_space, *(this->type)); - }), + ArraySpanFillFromScalarScratchSpace(FillScalarScratchSpace(ree_type(this->type))), value{std::move(value)} { ARROW_CHECK_EQ(this->type->id(), Type::RUN_END_ENCODED); } @@ -776,7 +787,19 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr& type) RunEndEncodedScalar::~RunEndEncodedScalar() = default; -void RunEndEncodedScalar::FillScratchSpace(uint8_t* scratch_space, const DataType& type) { +void RunEndEncodedScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + switch (ree_type_->id()) { + case Type::INT16: + FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int16_t(1)); + break; + case Type::INT32: + FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int32_t(1)); + break; + default: + DCHECK_EQ(ree_type_->id(), Type::INT64); + FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(1)); + } } DictionaryScalar::DictionaryScalar(std::shared_ptr type) @@ -853,7 +876,8 @@ Result TimestampScalar::FromISO8601(std::string_view iso8601, SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) - : UnionScalar(std::move(type), type_code, /*is_valid=*/true, FillScratchSpace), + : UnionScalar(std::move(type), type_code, /*is_valid=*/true, + FillScalarScratchSpace(type_code)), value(std::move(value)) { this->child_id = checked_cast(*this->type).child_ids()[type_code]; @@ -880,6 +904,32 @@ std::shared_ptr SparseUnionScalar::FromValue(std::shared_ptr val namespace { +struct UnionScratchSpace { + alignas(int64_t) int8_t type_code; + alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; +}; +static_assert(sizeof(UnionScratchSpace) <= + sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); + +} // namespace + +void SparseUnionScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + auto* union_scratch_space = + reinterpret_cast(&scratch_space.scratch_space_); + union_scratch_space->type_code = type_code_; +} + +void DenseUnionScalar::FillScalarScratchSpace::Fill( + const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { + auto* union_scratch_space = + reinterpret_cast(&scratch_space.scratch_space_); + union_scratch_space->type_code = type_code_; + FillScalarScratchSpaceHelper(union_scratch_space->offsets, int32_t(0), int32_t(1)); +} + +namespace { + template using scalar_constructor_has_arrow_type = std::is_constructible::ScalarType, std::shared_ptr>; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 876c9bab120..d3f2252931c 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -131,6 +131,14 @@ struct ARROW_EXPORT NullScalar : public Scalar { namespace internal { +struct ArraySpanFillFromScalarScratchSpace; + +// Helper class to fill scratch space in a polymorphic way. +struct FillScalarScratchSpace { + virtual ~FillScalarScratchSpace() = default; + virtual void Fill(const ArraySpanFillFromScalarScratchSpace& scratch_space) const = 0; +}; + struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { // 16 bytes of scratch space to enable ArraySpan to be a view onto any // Scalar- including binary scalars where we need to create a buffer @@ -138,10 +146,8 @@ struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; protected: - using FillScratchSpaceFn = std::function; - - explicit ArraySpanFillFromScalarScratchSpace(FillScratchSpaceFn fn) { - fn(scratch_space_); + explicit ArraySpanFillFromScalarScratchSpace(const FillScalarScratchSpace& fill) { + fill.Fill(*this); } }; @@ -269,22 +275,27 @@ struct ARROW_EXPORT BaseBinaryScalar return value ? std::string_view(*value) : std::string_view(); } - using FillScratchSpaceByValueFn = std::function; + protected: + struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { + FillScalarScratchSpace(bool is_valid, const Buffer* value) + : is_valid_(is_valid), value_(value) {} - BaseBinaryScalar(std::shared_ptr type, FillScratchSpaceByValueFn fn) - : PrimitiveScalarBase(std::move(type)), - ArraySpanFillFromScalarScratchSpace( - [&](uint8_t* scratch_space) { fn(scratch_space, false, nullptr); }) {} + protected: + bool is_valid_; + const Buffer* value_; + }; + + BaseBinaryScalar(std::shared_ptr type, const FillScalarScratchSpace& fill) + : PrimitiveScalarBase(std::move(type)), ArraySpanFillFromScalarScratchSpace(fill) {} BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, - FillScratchSpaceByValueFn fn) + const FillScalarScratchSpace& fill) : internal::PrimitiveScalarBase{std::move(type), true}, - ArraySpanFillFromScalarScratchSpace( - [&](uint8_t* scratch_space) { fn(scratch_space, true, value.get()); }), + ArraySpanFillFromScalarScratchSpace(fill), value(std::move(value)) {} BaseBinaryScalar(std::string s, std::shared_ptr type, - FillScratchSpaceByValueFn fn); + const FillScalarScratchSpace& fill); friend ArraySpan; }; @@ -293,10 +304,11 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { using TypeClass = BinaryType; explicit BinaryScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), FillScratchSpace) {} + : BaseBinaryScalar(std::move(type), FillScalarScratchSpace(false, nullptr)) {} BinaryScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type), FillScratchSpace) {} + : BaseBinaryScalar(std::move(value), std::move(type), + FillScalarScratchSpace(true, value.get())) {} BinaryScalar(std::string s, std::shared_ptr type); @@ -308,8 +320,12 @@ struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { BinaryScalar() : BinaryScalar(binary()) {} private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value); + struct FillScalarScratchSpace : public BaseBinaryScalar::FillScalarScratchSpace { + using BaseBinaryScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT StringScalar : public BinaryScalar { @@ -328,10 +344,11 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { using TypeClass = BinaryViewType; explicit BinaryViewScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), FillScratchSpace) {} + : BaseBinaryScalar(std::move(type), FillScalarScratchSpace(false, nullptr)) {} BinaryViewScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type), FillScratchSpace) {} + : BaseBinaryScalar(std::move(value), std::move(type), + FillScalarScratchSpace(true, value.get())) {} BinaryViewScalar(std::string s, std::shared_ptr type); @@ -346,8 +363,12 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { std::string_view view() const override { return std::string_view(*this->value); } private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value); + struct FillScalarScratchSpace : public BaseBinaryScalar::FillScalarScratchSpace { + using BaseBinaryScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { @@ -368,10 +389,11 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { using TypeClass = LargeBinaryType; explicit LargeBinaryScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), FillScratchSpace) {} + : BaseBinaryScalar(std::move(type), FillScalarScratchSpace(false, nullptr)) {} LargeBinaryScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type), FillScratchSpace) {} + : BaseBinaryScalar(std::move(value), std::move(type), + FillScalarScratchSpace(true, value.get())) {} LargeBinaryScalar(std::string s, std::shared_ptr type); @@ -384,8 +406,12 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {} private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, - const Buffer* value); + struct FillScalarScratchSpace : public BaseBinaryScalar::FillScalarScratchSpace { + using BaseBinaryScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { @@ -557,12 +583,18 @@ struct ARROW_EXPORT BaseListScalar private internal::ArraySpanFillFromScalarScratchSpace { using ValueType = std::shared_ptr; - using FillScratchSpaceByValueFn = std::function; + std::shared_ptr value; - BaseListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid, FillScratchSpaceByValueFn fn); + protected: + struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { + FillScalarScratchSpace(const Array* value) : value_(value) {} - std::shared_ptr value; + protected: + const Array* value_; + }; + + BaseListScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid, const FillScalarScratchSpace& fill); private: friend struct ArraySpan; @@ -577,7 +609,12 @@ struct ARROW_EXPORT ListScalar : public BaseListScalar { explicit ListScalar(std::shared_ptr value, bool is_valid = true); private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); + struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { + using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT LargeListScalar : public BaseListScalar { @@ -588,7 +625,12 @@ struct ARROW_EXPORT LargeListScalar : public BaseListScalar { explicit LargeListScalar(std::shared_ptr value, bool is_valid = true); private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); + struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { + using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT ListViewScalar : public BaseListScalar { @@ -599,7 +641,12 @@ struct ARROW_EXPORT ListViewScalar : public BaseListScalar { explicit ListViewScalar(std::shared_ptr value, bool is_valid = true); private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); + struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { + using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT LargeListViewScalar : public BaseListScalar { @@ -610,7 +657,12 @@ struct ARROW_EXPORT LargeListViewScalar : public BaseListScalar { explicit LargeListViewScalar(std::shared_ptr value, bool is_valid = true); private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); + struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { + using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT MapScalar : public BaseListScalar { @@ -622,7 +674,12 @@ struct ARROW_EXPORT MapScalar : public BaseListScalar { explicit MapScalar(std::shared_ptr value, bool is_valid = true); private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); + struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { + using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { @@ -634,7 +691,12 @@ struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { explicit FixedSizeListScalar(std::shared_ptr value, bool is_valid = true); private: - static void FillScratchSpace(uint8_t* scratch_space, bool is_valid, const Array* value); + struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { + using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override {} + }; }; struct ARROW_EXPORT StructScalar : public Scalar { @@ -659,12 +721,17 @@ struct ARROW_EXPORT UnionScalar : public Scalar, virtual const std::shared_ptr& child_value() const = 0; protected: - using FillScratchSpaceByValueFn = std::function; + struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { + explicit FillScalarScratchSpace(int8_t type_code) : type_code_(type_code) {} + + protected: + int8_t type_code_; + }; + UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid, - FillScratchSpaceByValueFn fn) + const FillScalarScratchSpace& fill) : Scalar(std::move(type), is_valid), - internal::ArraySpanFillFromScalarScratchSpace( - [&](uint8_t* scratch_space) { fn(scratch_space, type_code); }), + internal::ArraySpanFillFromScalarScratchSpace(fill), type_code(type_code) {} friend struct ArraySpan; @@ -694,7 +761,12 @@ struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { std::shared_ptr type); private: - static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code); + struct FillScalarScratchSpace : public UnionScalar::FillScalarScratchSpace { + using UnionScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { @@ -708,11 +780,17 @@ struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { const std::shared_ptr& child_value() const override { return this->value; } DenseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) - : UnionScalar(std::move(type), type_code, value->is_valid, FillScratchSpace), + : UnionScalar(std::move(type), type_code, value->is_valid, + FillScalarScratchSpace(type_code)), value(std::move(value)) {} private: - static void FillScratchSpace(uint8_t* scratch_space, int8_t type_code); + struct FillScalarScratchSpace : public UnionScalar::FillScalarScratchSpace { + using UnionScalar::FillScalarScratchSpace::FillScalarScratchSpace; + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + }; }; struct ARROW_EXPORT RunEndEncodedScalar @@ -731,19 +809,29 @@ struct ARROW_EXPORT RunEndEncodedScalar ~RunEndEncodedScalar() override; const std::shared_ptr& run_end_type() const { - return ree_type(*type).run_end_type(); + return ree_type(type)->run_end_type(); } const std::shared_ptr& value_type() const { - return ree_type(*type).value_type(); + return ree_type(type)->value_type(); } private: - static const TypeClass& ree_type(const DataType& type) { - return internal::checked_cast(type); + static std::shared_ptr ree_type( + const std::shared_ptr& type) { + return internal::checked_pointer_cast(type); } - static void FillScratchSpace(uint8_t* scratch_space, const DataType& type); + struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { + explicit FillScalarScratchSpace(std::shared_ptr ree_type) + : ree_type_(std::move(ree_type)) {} + + void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) + const override; + + private: + const std::shared_ptr ree_type_; + }; friend ArraySpan; }; From 6647f05a07979a47bebaa7d1e3322873510bb6bf Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Fri, 29 Mar 2024 17:54:41 +0800 Subject: [PATCH 04/31] Remove scratch writing in ArraySpan --- cpp/src/arrow/array/data.cc | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 80c411dfa6a..51d328518cf 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -285,10 +285,10 @@ namespace { template BufferSpan OffsetsForScalar(uint8_t* scratch_space, offset_type value_size) { - auto* offsets = reinterpret_cast(scratch_space); - offsets[0] = 0; - offsets[1] = static_cast(value_size); - static_assert(2 * sizeof(offset_type) <= 16); + // auto* offsets = reinterpret_cast(scratch_space); + // offsets[0] = 0; + // offsets[1] = static_cast(value_size); + // static_assert(2 * sizeof(offset_type) <= 16); return {scratch_space, sizeof(offset_type) * 2}; } @@ -297,9 +297,9 @@ std::pair OffsetsAndSizesForScalar(uint8_t* scratch_spac offset_type value_size) { auto* offsets = scratch_space; auto* sizes = scratch_space + sizeof(offset_type); - reinterpret_cast(offsets)[0] = 0; - reinterpret_cast(sizes)[0] = value_size; - static_assert(2 * sizeof(offset_type) <= 16); + // reinterpret_cast(offsets)[0] = 0; + // reinterpret_cast(sizes)[0] = value_size; + // static_assert(2 * sizeof(offset_type) <= 16); return {BufferSpan{offsets, sizeof(offset_type)}, BufferSpan{sizes, sizeof(offset_type)}}; } @@ -428,14 +428,14 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->buffers[1].size = BinaryViewType::kSize; this->buffers[1].data = scalar.scratch_space_; - static_assert(sizeof(BinaryViewType::c_type) <= sizeof(scalar.scratch_space_)); - auto* view = new (&scalar.scratch_space_) BinaryViewType::c_type; - if (scalar.is_valid) { - *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0); - this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1}); - } else { - *view = {}; - } + // static_assert(sizeof(BinaryViewType::c_type) <= sizeof(scalar.scratch_space_)); + // auto* view = new (&scalar.scratch_space_) BinaryViewType::c_type; + // if (scalar.is_valid) { + // *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0); + // this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1}); + // } else { + // *view = {}; + // } } else if (type_id == Type::FIXED_SIZE_BINARY) { const auto& scalar = checked_cast(value); this->buffers[1].data = const_cast(scalar.value->data()); @@ -492,7 +492,7 @@ void ArraySpan::FillFromScalar(const Scalar& value) { // First buffer is kept null since unions have no validity vector this->buffers[0] = {}; - union_scratch_space->type_code = checked_cast(value).type_code; + // union_scratch_space->type_code = checked_cast(value).type_code; this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); this->buffers[1].size = 1; @@ -541,7 +541,7 @@ void ArraySpan::FillFromScalar(const Scalar& value) { e.null_count = 0; e.buffers[1].data = scalar.scratch_space_; e.buffers[1].size = sizeof(run_end); - reinterpret_cast(scalar.scratch_space_)[0] = run_end; + // reinterpret_cast(scalar.scratch_space_)[0] = run_end; }; switch (scalar.run_end_type()->id()) { From aba9dcb5fa5508d85799ffc7030ced471d485974 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Sat, 30 Mar 2024 02:35:46 +0800 Subject: [PATCH 05/31] Fix --- cpp/src/arrow/array/data.cc | 12 +++--- cpp/src/arrow/scalar.cc | 43 +++++++++++--------- cpp/src/arrow/scalar.h | 81 +++++++++++++++++++++++-------------- 3 files changed, 79 insertions(+), 57 deletions(-) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 51d328518cf..7be2658835b 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -430,12 +430,12 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->buffers[1].data = scalar.scratch_space_; // static_assert(sizeof(BinaryViewType::c_type) <= sizeof(scalar.scratch_space_)); // auto* view = new (&scalar.scratch_space_) BinaryViewType::c_type; - // if (scalar.is_valid) { - // *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0); - // this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1}); - // } else { - // *view = {}; - // } + if (scalar.is_valid) { + // *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0); + this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1}); + } else { + // *view = {}; + } } else if (type_id == Type::FIXED_SIZE_BINARY) { const auto& scalar = checked_cast(value); this->buffers[1].data = const_cast(scalar.value->data()); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 079b400b86e..c0ff64197bb 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -584,9 +584,11 @@ Status Scalar::ValidateFull() const { return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } +template BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, - const FillScalarScratchSpace& fill) - : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), fill) {} + FillScalarScratchSpaceFactoryT factory) + : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), + std::move(factory)) {} BinaryScalar::BinaryScalar(std::string s, std::shared_ptr type) : BinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} @@ -641,11 +643,12 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(const std::shared_ptr& valu FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid) : FixedSizeBinaryScalar(Buffer::FromString(std::move(s)), is_valid) {} +template BaseListScalar::BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid, - const FillScalarScratchSpace& fill) + FillScalarScratchSpaceFactoryT factory) : Scalar{std::move(type), is_valid}, - ArraySpanFillFromScalarScratchSpace(fill), + ArraySpanFillFromScalarScratchSpace(factory(value.get())), value(std::move(value)) { ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); } @@ -653,10 +656,10 @@ BaseListScalar::BaseListScalar(std::shared_ptr value, ListScalar::ListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - ListScalar::FillScalarScratchSpace(value.get())) {} + FillScalarScratchSpaceFactory) {} ListScalar::ListScalar(std::shared_ptr value, bool is_valid) - : ListScalar(std::move(value), list(value->type()), is_valid) {} + : ListScalar(value, list(value->type()), is_valid) {} void ListScalar::FillScalarScratchSpace::Fill( const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { @@ -668,10 +671,10 @@ void ListScalar::FillScalarScratchSpace::Fill( LargeListScalar::LargeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - LargeListScalar::FillScalarScratchSpace(value.get())) {} + FillScalarScratchSpaceFactory) {} LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) - : LargeListScalar(std::move(value), large_list(value->type()), is_valid) {} + : LargeListScalar(value, large_list(value->type()), is_valid) {} void LargeListScalar::FillScalarScratchSpace::Fill( const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { @@ -682,10 +685,10 @@ void LargeListScalar::FillScalarScratchSpace::Fill( ListViewScalar::ListViewScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - ListViewScalar::FillScalarScratchSpace(value.get())) {} + FillScalarScratchSpaceFactory) {} ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) - : ListViewScalar(std::move(value), list_view(value->type()), is_valid) {} + : ListViewScalar(value, list_view(value->type()), is_valid) {} void ListViewScalar::FillScalarScratchSpace::Fill( const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { @@ -697,10 +700,10 @@ void ListViewScalar::FillScalarScratchSpace::Fill( LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - LargeListViewScalar::FillScalarScratchSpace(value.get())) {} + FillScalarScratchSpaceFactory) {} LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) - : LargeListViewScalar(std::move(value), large_list_view(value->type()), is_valid) {} + : LargeListViewScalar(value, large_list_view(value->type()), is_valid) {} void LargeListViewScalar::FillScalarScratchSpace::Fill( const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { @@ -717,10 +720,10 @@ inline std::shared_ptr MakeMapType(const std::shared_ptr& pa MapScalar::MapScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - MapScalar::FillScalarScratchSpace(value.get())) {} + FillScalarScratchSpaceFactory) {} MapScalar::MapScalar(std::shared_ptr value, bool is_valid) - : MapScalar(std::move(value), MakeMapType(value->type()), is_valid) {} + : MapScalar(value, MakeMapType(value->type()), is_valid) {} void MapScalar::FillScalarScratchSpace::Fill( const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { @@ -732,15 +735,14 @@ void MapScalar::FillScalarScratchSpace::Fill( FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid, - FixedSizeListScalar::FillScalarScratchSpace(value.get())) { + FillScalarScratchSpaceFactory) { ARROW_CHECK_EQ(this->value->length(), checked_cast(*this->type).list_size()); } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, bool is_valid) : FixedSizeListScalar( - std::move(value), - fixed_size_list(value->type(), static_cast(value->length())), + value, fixed_size_list(value->type(), static_cast(value->length())), is_valid) {} Result> StructScalar::Make( @@ -775,7 +777,8 @@ Result> StructScalar::field(FieldRef ref) const { RunEndEncodedScalar::RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type) : Scalar{std::move(type), value->is_valid}, - ArraySpanFillFromScalarScratchSpace(FillScalarScratchSpace(ree_type(this->type))), + ArraySpanFillFromScalarScratchSpace( + FillScalarScratchSpace(*ree_type(this->type)->run_end_type())), value{std::move(value)} { ARROW_CHECK_EQ(this->type->id(), Type::RUN_END_ENCODED); } @@ -789,7 +792,7 @@ RunEndEncodedScalar::~RunEndEncodedScalar() = default; void RunEndEncodedScalar::FillScalarScratchSpace::Fill( const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - switch (ree_type_->id()) { + switch (run_end_type_.id()) { case Type::INT16: FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int16_t(1)); break; @@ -797,7 +800,7 @@ void RunEndEncodedScalar::FillScalarScratchSpace::Fill( FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int32_t(1)); break; default: - DCHECK_EQ(ree_type_->id(), Type::INT64); + DCHECK_EQ(run_end_type_.id(), Type::INT64); FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(1)); } } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index d3f2252931c..6fae46a1cff 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -275,6 +275,25 @@ struct ARROW_EXPORT BaseBinaryScalar return value ? std::string_view(*value) : std::string_view(); } + protected: + template + BaseBinaryScalar(std::shared_ptr type, FillScalarScratchSpaceFactoryT factory) + : PrimitiveScalarBase(std::move(type)), + ArraySpanFillFromScalarScratchSpace(factory(false, nullptr)) {} + + template + BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, + FillScalarScratchSpaceFactoryT factory) + : internal::PrimitiveScalarBase{std::move(type), true}, + ArraySpanFillFromScalarScratchSpace(factory(true, value.get())), + value(std::move(value)) {} + + template + BaseBinaryScalar(std::string s, std::shared_ptr type, + FillScalarScratchSpaceFactoryT factory); + + friend ArraySpan; + protected: struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { FillScalarScratchSpace(bool is_valid, const Buffer* value) @@ -285,30 +304,23 @@ struct ARROW_EXPORT BaseBinaryScalar const Buffer* value_; }; - BaseBinaryScalar(std::shared_ptr type, const FillScalarScratchSpace& fill) - : PrimitiveScalarBase(std::move(type)), ArraySpanFillFromScalarScratchSpace(fill) {} - - BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, - const FillScalarScratchSpace& fill) - : internal::PrimitiveScalarBase{std::move(type), true}, - ArraySpanFillFromScalarScratchSpace(fill), - value(std::move(value)) {} - - BaseBinaryScalar(std::string s, std::shared_ptr type, - const FillScalarScratchSpace& fill); - - friend ArraySpan; + template + static FillScalarScratchSpaceT FillScalarScratchSpaceFactory(bool is_valid, + const Buffer* value) { + return {is_valid, value}; + } }; struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { using TypeClass = BinaryType; explicit BinaryScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), FillScalarScratchSpace(false, nullptr)) {} + : BaseBinaryScalar(std::move(type), + FillScalarScratchSpaceFactory) {} BinaryScalar(std::shared_ptr value, std::shared_ptr type) : BaseBinaryScalar(std::move(value), std::move(type), - FillScalarScratchSpace(true, value.get())) {} + FillScalarScratchSpaceFactory) {} BinaryScalar(std::string s, std::shared_ptr type); @@ -344,11 +356,12 @@ struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { using TypeClass = BinaryViewType; explicit BinaryViewScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), FillScalarScratchSpace(false, nullptr)) {} + : BaseBinaryScalar(std::move(type), + FillScalarScratchSpaceFactory) {} BinaryViewScalar(std::shared_ptr value, std::shared_ptr type) : BaseBinaryScalar(std::move(value), std::move(type), - FillScalarScratchSpace(true, value.get())) {} + FillScalarScratchSpaceFactory) {} BinaryViewScalar(std::string s, std::shared_ptr type); @@ -389,11 +402,12 @@ struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { using TypeClass = LargeBinaryType; explicit LargeBinaryScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), FillScalarScratchSpace(false, nullptr)) {} + : BaseBinaryScalar(std::move(type), + FillScalarScratchSpaceFactory) {} LargeBinaryScalar(std::shared_ptr value, std::shared_ptr type) : BaseBinaryScalar(std::move(value), std::move(type), - FillScalarScratchSpace(true, value.get())) {} + FillScalarScratchSpaceFactory) {} LargeBinaryScalar(std::string s, std::shared_ptr type); @@ -585,19 +599,25 @@ struct ARROW_EXPORT BaseListScalar std::shared_ptr value; + template + BaseListScalar(std::shared_ptr value, std::shared_ptr type, + bool is_valid, FillScalarScratchSpaceFactoryT factory); + + private: + friend struct ArraySpan; + protected: struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { - FillScalarScratchSpace(const Array* value) : value_(value) {} + explicit FillScalarScratchSpace(const Array* value) : value_(value) {} protected: const Array* value_; }; - BaseListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid, const FillScalarScratchSpace& fill); - - private: - friend struct ArraySpan; + template + static FillScalarScratchSpaceT FillScalarScratchSpaceFactory(const Array* value) { + return FillScalarScratchSpaceT{value}; + } }; struct ARROW_EXPORT ListScalar : public BaseListScalar { @@ -817,20 +837,19 @@ struct ARROW_EXPORT RunEndEncodedScalar } private: - static std::shared_ptr ree_type( - const std::shared_ptr& type) { - return internal::checked_pointer_cast(type); + static std::shared_ptr ree_type(const std::shared_ptr& type) { + return internal::checked_pointer_cast(type); } struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { - explicit FillScalarScratchSpace(std::shared_ptr ree_type) - : ree_type_(std::move(ree_type)) {} + explicit FillScalarScratchSpace(const DataType& run_end_type) + : run_end_type_(run_end_type) {} void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const override; private: - const std::shared_ptr ree_type_; + const DataType& run_end_type_; }; friend ArraySpan; From 39407f49eeb0b6048bbed00f1880f23570b9d579 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Sat, 30 Mar 2024 02:50:34 +0800 Subject: [PATCH 06/31] Fix lint --- cpp/src/arrow/scalar.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 6fae46a1cff..f93ff2984f5 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -279,7 +279,7 @@ struct ARROW_EXPORT BaseBinaryScalar template BaseBinaryScalar(std::shared_ptr type, FillScalarScratchSpaceFactoryT factory) : PrimitiveScalarBase(std::move(type)), - ArraySpanFillFromScalarScratchSpace(factory(false, nullptr)) {} + ArraySpanFillFromScalarScratchSpace(factory(false, NULLPTR)) {} template BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, From d957a090d0dbabe9816f8be8cb0435e5c8cf3e1d Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Sat, 30 Mar 2024 03:01:16 +0800 Subject: [PATCH 07/31] Add concurrent fill from scalar test case --- cpp/src/arrow/array/array_test.cc | 34 +++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 21ac1a09f56..1411e08c158 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -823,6 +824,39 @@ TEST_F(TestArray, TestFillFromScalar) { } } +TEST_F(TestArray, TestConcurrentFillFromScalar) { + for (auto type : TestArrayUtilitiesAgainstTheseTypes()) { + ARROW_SCOPED_TRACE("type = ", type->ToString()); + for (auto seed : {0u, 0xdeadbeef, 42u}) { + ARROW_SCOPED_TRACE("seed = ", seed); + + Field field("", type, /*nullable=*/true, + key_value_metadata({{"extension_allow_random_storage", "true"}})); + auto array = random::GenerateArray(field, 1, seed); + + ASSERT_OK_AND_ASSIGN(auto scalar, array->GetScalar(0)); + + // Lambda to create fill an ArraySpan with the scalar and use the ArraySpan a bit. + auto array_span_from_scalar = [&]() { + ArraySpan span(*scalar); + auto roundtripped_array = span.ToArray(); + ASSERT_OK(roundtripped_array->ValidateFull()); + + AssertArraysEqual(*array, *roundtripped_array); + ASSERT_OK_AND_ASSIGN(auto roundtripped_scalar, roundtripped_array->GetScalar(0)); + AssertScalarsEqual(*scalar, *roundtripped_scalar); + }; + + // Two concurrent calls to the lambda are just enough for TSAN to detect a race + // condition. + auto fut1 = std::async(std::launch::async, array_span_from_scalar); + auto fut2 = std::async(std::launch::async, array_span_from_scalar); + fut1.get(); + fut2.get(); + } + } +} + TEST_F(TestArray, ExtensionSpanRoundTrip) { // Other types are checked in MakeEmptyArray but MakeEmptyArray doesn't // work for extension types so we check that here From aed65575292f410f492c74e92de2ca797ba7a75b Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 1 Apr 2024 02:34:46 +0800 Subject: [PATCH 08/31] WIP --- .../arrow/compute/kernels/codegen_internal.h | 66 +++++++++---------- .../arrow/compute/kernels/scalar_compare.cc | 20 +++--- cpp/src/arrow/scalar.cc | 12 +--- cpp/src/arrow/scalar.h | 26 ++++---- 4 files changed, 60 insertions(+), 64 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 72b29057b82..8e21393743e 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -369,42 +369,42 @@ struct UnboxScalar { } }; -template -struct BoxScalar; - -template -struct BoxScalar> { - using T = typename GetOutputType::T; - static void Box(T val, Scalar* out) { - // Enables BoxScalar to work on a (for example) Time64Scalar - T* mutable_data = reinterpret_cast( - checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data()); - *mutable_data = val; - } -}; +// template +// struct BoxScalar; + +// template +// struct BoxScalar> { +// using T = typename GetOutputType::T; +// static void Box(T val, Scalar* out) { +// // Enables BoxScalar to work on a (for example) Time64Scalar +// T* mutable_data = reinterpret_cast( +// checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data()); +// *mutable_data = val; +// } +// }; -template -struct BoxScalar> { - using T = typename GetOutputType::T; - using ScalarType = typename TypeTraits::ScalarType; - static void Box(T val, Scalar* out) { - checked_cast(out)->value = std::make_shared(val); - } -}; +// template +// struct BoxScalar> { +// using T = typename GetOutputType::T; +// using ScalarType = typename TypeTraits::ScalarType; +// static void Box(T val, Scalar* out) { +// checked_cast(out)->value = std::make_shared(val); +// } +// }; -template <> -struct BoxScalar { - using T = Decimal128; - using ScalarType = Decimal128Scalar; - static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } -}; +// template <> +// struct BoxScalar { +// using T = Decimal128; +// using ScalarType = Decimal128Scalar; +// static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } +// }; -template <> -struct BoxScalar { - using T = Decimal256; - using ScalarType = Decimal256Scalar; - static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } -}; +// template <> +// struct BoxScalar { +// using T = Decimal256; +// using ScalarType = Decimal256Scalar; +// static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } +// }; // A VisitArraySpanInline variant that calls its visitor function with logical // values, such as Decimal128 rather than std::string_view. diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index daf8ed76d62..f6dc0a15db8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -491,8 +491,9 @@ template struct ScalarMinMax { using OutValue = typename GetOutputType::T; - static void ExecScalar(const ExecSpan& batch, - const ElementWiseAggregateOptions& options, Scalar* out) { + static Result> ExecScalar( + const ExecSpan& batch, const ElementWiseAggregateOptions& options, + std::shared_ptr type) { // All arguments are scalar OutValue value{}; bool valid = false; @@ -502,8 +503,8 @@ struct ScalarMinMax { const Scalar& scalar = *arg.scalar; if (!scalar.is_valid) { if (options.skip_nulls) continue; - out->is_valid = false; - return; + valid = false; + break; } if (!valid) { value = UnboxScalar::Unbox(scalar); @@ -513,9 +514,10 @@ struct ScalarMinMax { value, UnboxScalar::Unbox(scalar)); } } - out->is_valid = valid; if (valid) { - BoxScalar::Box(value, out); + return MakeScalar(std::move(type), std::move(value)); + } else { + return MakeNullScalar(std::move(type)); } } @@ -536,9 +538,9 @@ struct ScalarMinMax { bool initialize_output = true; if (scalar_count > 0) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_scalar, - MakeScalar(out->type()->GetSharedPtr(), 0)); - ExecScalar(batch, options, temp_scalar.get()); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr temp_scalar, + ExecScalar(batch, options, out->type()->GetSharedPtr(), temp_scalar.get())); if (temp_scalar->is_valid) { const auto value = UnboxScalar::Unbox(*temp_scalar); initialize_output = false; diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index c0ff64197bb..1820200d929 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -557,13 +557,10 @@ void FillScalarScratchSpaceHelperInternal(uint8_t* scratch_space, T first, Args. // assign it *reinterpret_cast(scratch_space + offset) = first; - // Calculate the next offset based on the size of the current type T - constexpr size_t next_offset = offset + sizeof(T); - // Recursively fill the scratch space with the remaining arguments if constexpr (sizeof...(args) > 0) { // Use if constexpr to stop recursion when // there are no more arguments - FillScalarScratchSpaceHelperInternal(scratch_space, + FillScalarScratchSpaceHelperInternal(scratch_space, std::forward(args)...); } } @@ -1450,13 +1447,10 @@ struct ToTypeVisitor : CastImplVisitor { } // namespace Result> Scalar::CastTo(std::shared_ptr to) const { - std::shared_ptr out = MakeNullScalar(to); if (is_valid) { - out->is_valid = true; - ToTypeVisitor unpack_to_type{*this, to, out.get()}; - RETURN_NOT_OK(VisitTypeInline(*to, &unpack_to_type)); + return ToTypeVisitor{*this, std::move(to)}.Finish(); } - return out; + return MakeNullScalar(std::move(to)); } void PrintTo(const Scalar& scalar, std::ostream* os) { *os << scalar.ToString(); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index f93ff2984f5..57c02c5c37e 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -159,7 +159,7 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { /// \brief Get a const pointer to the value of this scalar. May be null. virtual const void* data() const = 0; /// \brief Get a mutable pointer to the value of this scalar. May be null. - virtual void* mutable_data() = 0; + // virtual void* mutable_data() = 0; /// \brief Get an immutable view of the value of this scalar as bytes. virtual std::string_view view() const = 0; }; @@ -180,7 +180,7 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase { ValueType value{}; const void* data() const override { return &value; } - void* mutable_data() override { return &value; } + // void* mutable_data() override { return &value; } std::string_view view() const override { return std::string_view(reinterpret_cast(&value), sizeof(ValueType)); }; @@ -263,14 +263,14 @@ struct ARROW_EXPORT BaseBinaryScalar private internal::ArraySpanFillFromScalarScratchSpace { using ValueType = std::shared_ptr; - std::shared_ptr value; + const std::shared_ptr value; const void* data() const override { return value ? reinterpret_cast(value->data()) : NULLPTR; } - void* mutable_data() override { - return value ? reinterpret_cast(value->mutable_data()) : NULLPTR; - } + // void* mutable_data() override { + // return value ? reinterpret_cast(value->mutable_data()) : NULLPTR; + // } std::string_view view() const override { return value ? std::string_view(*value) : std::string_view(); } @@ -572,9 +572,9 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase { return reinterpret_cast(value.native_endian_bytes()); } - void* mutable_data() override { - return reinterpret_cast(value.mutable_native_endian_bytes()); - } + // void* mutable_data() override { + // return reinterpret_cast(value.mutable_native_endian_bytes()); + // } std::string_view view() const override { return std::string_view(reinterpret_cast(value.native_endian_bytes()), @@ -880,10 +880,10 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase { const void* data() const override { return internal::checked_cast(*value.index).data(); } - void* mutable_data() override { - return internal::checked_cast(*value.index) - .mutable_data(); - } + // void* mutable_data() override { + // return internal::checked_cast(*value.index) + // .mutable_data(); + // } std::string_view view() const override { return internal::checked_cast(*value.index) .view(); From 4ff67b117d43b786d319aae38778d0a119d15a7a Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 1 Apr 2024 18:05:50 +0800 Subject: [PATCH 09/31] Fix --- .../arrow/compute/kernels/scalar_compare.cc | 2 +- cpp/src/arrow/scalar.cc | 300 +++++++++++------- cpp/src/arrow/scalar_test.cc | 66 ++-- 3 files changed, 230 insertions(+), 138 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index f6dc0a15db8..e235e86d233 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -540,7 +540,7 @@ struct ScalarMinMax { if (scalar_count > 0) { ARROW_ASSIGN_OR_RAISE( std::shared_ptr temp_scalar, - ExecScalar(batch, options, out->type()->GetSharedPtr(), temp_scalar.get())); + ExecScalar(batch, options, out->type()->GetSharedPtr())); if (temp_scalar->is_valid) { const auto value = UnboxScalar::Unbox(*temp_scalar); initialize_output = false; diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 1820200d929..95bc73c2fdb 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -561,7 +561,7 @@ void FillScalarScratchSpaceHelperInternal(uint8_t* scratch_space, T first, Args. if constexpr (sizeof...(args) > 0) { // Use if constexpr to stop recursion when // there are no more arguments FillScalarScratchSpaceHelperInternal(scratch_space, - std::forward(args)...); + std::forward(args)...); } } @@ -1142,58 +1142,72 @@ std::shared_ptr FormatToBuffer(Formatter&& formatter, const ScalarType& } // error fallback -Status CastImpl(const Scalar& from, Scalar* to) { +template +Result> CastImpl(const Scalar& from, + std::shared_ptr to_type) { return Status::NotImplemented("casting scalars of type ", *from.type, " to type ", - *to->type); + *to_type); } // numeric to numeric -template -Status CastImpl(const NumericScalar& from, NumericScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); +template +enable_if_number>> CastImpl( + const NumericScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // numeric to boolean -template -Status CastImpl(const NumericScalar& from, BooleanScalar* to) { - constexpr auto zero = static_cast(0); - to->value = from.value != zero; - return Status::OK(); +template +enable_if_boolean>> CastImpl( + const NumericScalar& from, std::shared_ptr to_type) { + constexpr auto zero = static_cast(0); + return std::make_shared(from.value != zero, std::move(to_type)); } // boolean to numeric -template -Status CastImpl(const BooleanScalar& from, NumericScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); +template +enable_if_number>> CastImpl( + const BooleanScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // numeric to temporal -template +template typename std::enable_if::value && !std::is_same::value && !std::is_same::value, - Status>::type -CastImpl(const NumericScalar& from, TemporalScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); + Result>>::type +CastImpl(const NumericScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // temporal to numeric -template -typename std::enable_if::value && +template +typename std::enable_if::value && + std::is_base_of::value && !std::is_same::value && !std::is_same::value, - Status>::type -CastImpl(const TemporalScalar& from, NumericScalar* to) { - to->value = static_cast(from.value); - return Status::OK(); + Result>>::type +CastImpl(const TemporalScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + return std::make_shared(static_cast(from.value), + std::move(to_type)); } // timestamp to timestamp -Status CastImpl(const TimestampScalar& from, TimestampScalar* to) { - return util::ConvertTimestampValue(from.type, to->type, from.value).Value(&to->value); +template +enable_if_timestamp>> CastImpl( + const TimestampScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE(auto value, + util::ConvertTimestampValue(from.type, to_type, from.value)); + return std::make_shared(value, std::move(to_type)); } template @@ -1202,101 +1216,134 @@ std::shared_ptr AsTimestampType(const std::shared_ptr& type) } // duration to duration -Status CastImpl(const DurationScalar& from, DurationScalar* to) { - return util::ConvertTimestampValue(AsTimestampType(from.type), - AsTimestampType(to->type), from.value) - .Value(&to->value); +template +enable_if_duration>> CastImpl( + const DurationScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE( + auto value, + util::ConvertTimestampValue(AsTimestampType(from.type), + AsTimestampType(to_type), from.value)); + return std::make_shared(value, std::move(to_type)); } // time to time -template -enable_if_time CastImpl(const TimeScalar& from, ToScalar* to) { - return util::ConvertTimestampValue(AsTimestampType(from.type), - AsTimestampType(to->type), from.value) - .Value(&to->value); +template +enable_if_time>> CastImpl( + const TimeScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE( + auto value, util::ConvertTimestampValue(AsTimestampType(from.type), + AsTimestampType(to_type), from.value)); + return std::make_shared(value, std::move(to_type)); } constexpr int64_t kMillisecondsInDay = 86400000; // date to date -Status CastImpl(const Date32Scalar& from, Date64Scalar* to) { - to->value = from.value * kMillisecondsInDay; - return Status::OK(); +template +enable_if_t::value, Result>> +CastImpl(const Date32Scalar& from, std::shared_ptr to_type) { + return std::make_shared(from.value * kMillisecondsInDay, + std::move(to_type)); } -Status CastImpl(const Date64Scalar& from, Date32Scalar* to) { - to->value = static_cast(from.value / kMillisecondsInDay); - return Status::OK(); +template +enable_if_t::value, Result>> +CastImpl(const Date64Scalar& from, std::shared_ptr to_type) { + return std::make_shared( + static_cast(from.value / kMillisecondsInDay), std::move(to_type)); } // timestamp to date -Status CastImpl(const TimestampScalar& from, Date64Scalar* to) { +template +enable_if_t::value, Result>> +CastImpl(const TimestampScalar& from, std::shared_ptr to_type) { ARROW_ASSIGN_OR_RAISE( auto millis, util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value)); - to->value = millis - millis % kMillisecondsInDay; - return Status::OK(); + return std::make_shared(millis - millis % kMillisecondsInDay, + std::move(to_type)); } -Status CastImpl(const TimestampScalar& from, Date32Scalar* to) { +template +enable_if_t::value, Result>> +CastImpl(const TimestampScalar& from, std::shared_ptr to_type) { ARROW_ASSIGN_OR_RAISE( auto millis, util::ConvertTimestampValue(from.type, timestamp(TimeUnit::MILLI), from.value)); - to->value = static_cast(millis / kMillisecondsInDay); - return Status::OK(); + return std::make_shared(static_cast(millis / kMillisecondsInDay), + std::move(to_type)); } // date to timestamp -template -Status CastImpl(const DateScalar& from, TimestampScalar* to) { +template +enable_if_timestamp>> CastImpl( + const DateScalar& from, std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; int64_t millis = from.value; - if (std::is_same::value) { + if (std::is_same::value) { millis *= kMillisecondsInDay; } - return util::ConvertTimestampValue(timestamp(TimeUnit::MILLI), to->type, millis) - .Value(&to->value); + ARROW_ASSIGN_OR_RAISE(auto value, util::ConvertTimestampValue( + timestamp(TimeUnit::MILLI), to_type, millis)); + return std::make_shared(value, std::move(to_type)); } // string to any -template -Status CastImpl(const StringScalar& from, ScalarType* to) { - ARROW_ASSIGN_OR_RAISE(auto out, Scalar::Parse(to->type, std::string_view(*from.value))); - to->value = std::move(checked_cast(*out).value); - return Status::OK(); +template +Result> CastImpl(const StringScalar& from, + std::shared_ptr to_type) { + using ToScalar = typename TypeTraits::ScalarType; + ARROW_ASSIGN_OR_RAISE(auto out, + Scalar::Parse(std::move(to_type), std::string_view(*from.value))); + DCHECK(checked_pointer_cast(out) != nullptr); + return std::move(out); } // binary/large binary/large string to string -template -enable_if_t && - !std::is_same::value, - Status> -CastImpl(const ScalarType& from, StringScalar* to) { - to->value = from.value; - return Status::OK(); +template +enable_if_t::value && + std::is_base_of_v && + !std::is_same::value, + Result>> +CastImpl(const From& from, std::shared_ptr to_type) { + return std::make_shared(from.value, std::move(to_type)); } // formattable to string -template , // note: Value unused but necessary to trigger SFINAE if Formatter is // undefined typename Value = typename Formatter::value_type> -Status CastImpl(const ScalarType& from, StringScalar* to) { - to->value = FormatToBuffer(Formatter{from.type.get()}, from); - return Status::OK(); -} - -Status CastImpl(const Decimal128Scalar& from, StringScalar* to) { - auto from_type = checked_cast(from.type.get()); - to->value = Buffer::FromString(from.value.ToString(from_type->scale())); - return Status::OK(); +typename std::enable_if_t::value, + Result>> +CastImpl(const From& from, std::shared_ptr to_type) { + return std::make_shared(FormatToBuffer(Formatter{from.type.get()}, from), + std::move(to_type)); } -Status CastImpl(const Decimal256Scalar& from, StringScalar* to) { - auto from_type = checked_cast(from.type.get()); - to->value = Buffer::FromString(from.value.ToString(from_type->scale())); - return Status::OK(); -} - -Status CastImpl(const StructScalar& from, StringScalar* to) { +// template +// typename std::enable_if_t::value, +// Result>> +// CastImpl(const Decimal128Scalar& from, std::shared_ptr to_type) { +// auto from_type = checked_cast(from.type.get()); +// return std::make_shared( +// Buffer::FromString(from.value.ToString(from_type->scale())), std::move(to_type)); +// } + +// template +// typename std::enable_if_t::value, +// Result>> +// CastImpl(const Decimal256Scalar& from, std::shared_ptr to_type) { +// auto from_type = checked_cast(from.type.get()); +// return std::make_shared( +// Buffer::FromString(from.value.ToString(from_type->scale())), std::move(to_type)); +// } + +template +typename std::enable_if_t::value, + Result>> +CastImpl(const StructScalar& from, std::shared_ptr to_type) { std::stringstream ss; ss << '{'; for (int i = 0; static_cast(i) < from.value.size(); i++) { @@ -1305,24 +1352,23 @@ Status CastImpl(const StructScalar& from, StringScalar* to) { << " = " << from.value[i]->ToString(); } ss << '}'; - to->value = Buffer::FromString(ss.str()); - return Status::OK(); + return std::make_shared(Buffer::FromString(ss.str()), std::move(to_type)); } // casts between variable-length and fixed-length list types template -enable_if_list_type CastImpl( - const BaseListScalar& from, ToScalar* to) { +enable_if_list_type>> +CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { if constexpr (sizeof(typename ToScalar::TypeClass::offset_type) < sizeof(int64_t)) { if (from.value->length() > std::numeric_limits::max()) { return Status::Invalid(from.type->ToString(), " too large to cast to ", - to->type->ToString()); + to_type->ToString()); } } if constexpr (is_fixed_size_list_type::value) { - const auto& fixed_size_list_type = checked_cast(*to->type); + const auto& fixed_size_list_type = checked_cast(*to_type); if (from.value->length() != fixed_size_list_type.list_size()) { return Status::Invalid("Cannot cast ", from.type->ToString(), " of length ", from.value->length(), " to fixed size list of length ", @@ -1330,13 +1376,14 @@ enable_if_list_type CastImpl( } } - DCHECK_EQ(from.is_valid, to->is_valid); - to->value = from.value; - return Status::OK(); + return std::make_shared(from.value, std::move(to_type), from.is_valid); } // list based types (list, large list and map (fixed sized list too)) to string -Status CastImpl(const BaseListScalar& from, StringScalar* to) { +template +typename std::enable_if_t::value, + Result>> +CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { std::stringstream ss; ss << from.type->ToString() << "["; for (int64_t i = 0; i < from.value->length(); i++) { @@ -1345,11 +1392,14 @@ Status CastImpl(const BaseListScalar& from, StringScalar* to) { ss << value->ToString(); } ss << ']'; - to->value = Buffer::FromString(ss.str()); - return Status::OK(); + return std::make_shared(Buffer::FromString(ss.str()), std::move(to_type)); } -Status CastImpl(const UnionScalar& from, StringScalar* to) { +// union types to string +template +typename std::enable_if_t::value, + Result>> +CastImpl(const UnionScalar& from, std::shared_ptr to_type) { const auto& union_ty = checked_cast(*from.type); std::stringstream ss; const Scalar* selected_value; @@ -1361,8 +1411,7 @@ Status CastImpl(const UnionScalar& from, StringScalar* to) { } ss << "union{" << union_ty.field(union_ty.child_ids()[from.type_code])->ToString() << " = " << selected_value->ToString() << '}'; - to->value = Buffer::FromString(ss.str()); - return Status::OK(); + return std::make_shared(Buffer::FromString(ss.str()), std::move(to_type)); } struct CastImplVisitor { @@ -1372,34 +1421,39 @@ struct CastImplVisitor { const Scalar& from_; const std::shared_ptr& to_type_; - Scalar* out_; + std::shared_ptr out_ = nullptr; }; template struct FromTypeVisitor : CastImplVisitor { using ToScalar = typename TypeTraits::ScalarType; - FromTypeVisitor(const Scalar& from, const std::shared_ptr& to_type, - Scalar* out) - : CastImplVisitor{from, to_type, out} {} + FromTypeVisitor(const Scalar& from, const std::shared_ptr& to_type) + : CastImplVisitor{from, to_type} {} template Status Visit(const FromType&) { - return CastImpl(checked_cast::ScalarType&>(from_), - checked_cast(out_)); + ARROW_ASSIGN_OR_RAISE( + out_, CastImpl( + checked_cast::ScalarType&>(from_), + std::move(to_type_))); + return Status::OK(); } // identity cast only for parameter free types template typename std::enable_if_t::is_parameter_free, Status> Visit( const ToType&) { - checked_cast(out_)->value = checked_cast(from_).value; + ARROW_ASSIGN_OR_RAISE(out_, MakeScalar(std::move(to_type_), + checked_cast(from_).value)); return Status::OK(); } Status CastFromListLike(const BaseListType& base_list_type) { - return CastImpl(checked_cast(from_), - checked_cast(out_)); + ARROW_ASSIGN_OR_RAISE(out_, + CastImpl(checked_cast(from_), + std::move(to_type_))); + return Status::OK(); } Status Visit(const ListType& list_type) { return CastFromListLike(list_type); } @@ -1412,19 +1466,29 @@ struct FromTypeVisitor : CastImplVisitor { return CastFromListLike(fixed_size_list_type); } + Status Visit(const ListViewType& list_view_type) { + return CastFromListLike(list_view_type); + } + + Status Visit(const LargeListViewType& large_list_view_type) { + return CastFromListLike(large_list_view_type); + } + Status Visit(const NullType&) { return NotImplemented(); } Status Visit(const DictionaryType&) { return NotImplemented(); } Status Visit(const ExtensionType&) { return NotImplemented(); } }; struct ToTypeVisitor : CastImplVisitor { - ToTypeVisitor(const Scalar& from, const std::shared_ptr& to_type, Scalar* out) - : CastImplVisitor{from, to_type, out} {} + ToTypeVisitor(const Scalar& from, const std::shared_ptr& to_type) + : CastImplVisitor{from, to_type} {} template Status Visit(const ToType&) { - FromTypeVisitor unpack_from_type{from_, to_type_, out_}; - return VisitTypeInline(*from_.type, &unpack_from_type); + FromTypeVisitor unpack_from_type{from_, to_type_}; + ARROW_RETURN_NOT_OK(VisitTypeInline(*from_.type, &unpack_from_type)); + out_ = std::move(unpack_from_type.out_); + return Status::OK(); } Status Visit(const NullType&) { @@ -1435,13 +1499,19 @@ struct ToTypeVisitor : CastImplVisitor { } Status Visit(const DictionaryType& dict_type) { - auto& out = checked_cast(out_)->value; ARROW_ASSIGN_OR_RAISE(auto cast_value, from_.CastTo(dict_type.value_type())); - ARROW_ASSIGN_OR_RAISE(out.dictionary, MakeArrayFromScalar(*cast_value, 1)); - return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index); + ARROW_ASSIGN_OR_RAISE(auto dictionary, MakeArrayFromScalar(*cast_value, 1)); + ARROW_ASSIGN_OR_RAISE(auto index, Int32Scalar(0).CastTo(dict_type.index_type())); + out_ = DictionaryScalar::Make(std::move(index), std::move(dictionary)); + return Status::OK(); } Status Visit(const ExtensionType&) { return NotImplemented(); } + + Result> Finish() && { + ARROW_RETURN_NOT_OK(VisitTypeInline(*to_type_, this)); + return std::move(out_); + } }; } // namespace diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 09dfde32271..10ca8963048 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -95,6 +95,30 @@ TEST(TestNullScalar, ValidateErrors) { AssertValidationFails(scalar); } +TEST(TestNullScalar, Cast) { + NullScalar scalar; + for (auto to_type : { + int8(), + float64(), + date32(), + time32(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND), + duration(TimeUnit::SECOND), + utf8(), + large_binary(), + list(int32()), + struct_({field("f", int32())}), + map(utf8(), int32()), + decimal(12, 2), + list_view(int32()), + large_list(int32()), + }) { + ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(to_type)); + ASSERT_EQ(casted->type->id(), to_type->id()); + ASSERT_FALSE(casted->is_valid); + } +} + template class TestNumericScalar : public ::testing::Test { public: @@ -470,6 +494,8 @@ TYPED_TEST_SUITE(TestDecimalScalar, DecimalArrowTypes); TYPED_TEST(TestDecimalScalar, Basics) { this->TestBasics(); } +TYPED_TEST(TestDecimalScalar, Cast) {} + TEST(TestBinaryScalar, Basics) { std::string data = "test data"; auto buf = std::make_shared(data); @@ -580,19 +606,25 @@ class TestStringScalar : public ::testing::Test { } void TestValidateErrors() { - // Inconsistent is_valid / value - ScalarType scalar(Buffer::FromString("xxx")); - scalar.is_valid = false; - AssertValidationFails(scalar); + { + // Inconsistent is_valid / value + ScalarType scalar(Buffer::FromString("xxx")); + scalar.is_valid = false; + AssertValidationFails(scalar); + } - auto null_scalar = MakeNullScalar(type_); - null_scalar->is_valid = true; - AssertValidationFails(*null_scalar); + { + auto null_scalar = MakeNullScalar(type_); + null_scalar->is_valid = true; + AssertValidationFails(*null_scalar); + } - // Invalid UTF8 - scalar = ScalarType(Buffer::FromString("\xff")); - ASSERT_OK(scalar.Validate()); - ASSERT_RAISES(Invalid, scalar.ValidateFull()); + { + // Invalid UTF8 + ScalarType scalar(Buffer::FromString("\xff")); + ASSERT_OK(scalar.Validate()); + ASSERT_RAISES(Invalid, scalar.ValidateFull()); + } } protected: @@ -676,8 +708,7 @@ TEST(TestFixedSizeBinaryScalar, ValidateErrors) { FixedSizeBinaryScalar scalar(buf, type); ASSERT_OK(scalar.ValidateFull()); - scalar.value = SliceBuffer(buf, 1); - AssertValidationFails(scalar); + ASSERT_RAISES(Invalid, MakeScalar(type, SliceBuffer(buf, 1))); } TEST(TestDateScalars, Basics) { @@ -1973,15 +2004,6 @@ TEST_F(TestExtensionScalar, ValidateErrors) { // If the scalar is null it's okay scalar.is_valid = false; ASSERT_OK(scalar.ValidateFull()); - - // Invalid storage scalar (wrong length) - std::shared_ptr invalid_storage = MakeNullScalar(storage_type_); - invalid_storage->is_valid = true; - static_cast(invalid_storage.get())->value = - std::make_shared("123"); - AssertValidationFails(*invalid_storage); - scalar = ExtensionScalar(invalid_storage, type_); - AssertValidationFails(scalar); } } // namespace arrow From c82b1fda9820430dcb3d6a6eaccff3e972f2d46d Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 1 Apr 2024 18:08:14 +0800 Subject: [PATCH 10/31] Format --- cpp/src/arrow/compute/kernels/scalar_compare.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_compare.cc b/cpp/src/arrow/compute/kernels/scalar_compare.cc index e235e86d233..9b2fd987d81 100644 --- a/cpp/src/arrow/compute/kernels/scalar_compare.cc +++ b/cpp/src/arrow/compute/kernels/scalar_compare.cc @@ -538,9 +538,8 @@ struct ScalarMinMax { bool initialize_output = true; if (scalar_count > 0) { - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr temp_scalar, - ExecScalar(batch, options, out->type()->GetSharedPtr())); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_scalar, + ExecScalar(batch, options, out->type()->GetSharedPtr())); if (temp_scalar->is_valid) { const auto value = UnboxScalar::Unbox(*temp_scalar); initialize_output = false; From c90a4d767b8398b4f21a3043026d8bac3143d5ef Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 1 Apr 2024 20:31:18 +0800 Subject: [PATCH 11/31] Fix --- cpp/src/arrow/scalar.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 95bc73c2fdb..cd83e904f98 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -1301,7 +1301,7 @@ Result> CastImpl(const StringScalar& from, // binary/large binary/large string to string template -enable_if_t::value && +enable_if_t::value && std::is_base_of_v && !std::is_same::value, Result>> From c68b7c8c327d894847d6c6c773b1d77b136bbc40 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 1 Apr 2024 21:01:27 +0800 Subject: [PATCH 12/31] Fix --- cpp/src/arrow/scalar.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index cd83e904f98..57e5da0081b 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -1162,7 +1162,7 @@ enable_if_number>> CastImpl( template enable_if_boolean>> CastImpl( const NumericScalar& from, std::shared_ptr to_type) { - constexpr auto zero = static_cast(0); + constexpr auto zero = static_cast(0); return std::make_shared(from.value != zero, std::move(to_type)); } From f1aba1e4dd7e1003b190d17bb9bcc2c46538335e Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 2 Apr 2024 01:39:37 +0800 Subject: [PATCH 13/31] Use CRTP to fill scalar scratch space --- cpp/src/arrow/array/data.cc | 51 ++++-- cpp/src/arrow/scalar.cc | 195 +++++++-------------- cpp/src/arrow/scalar.h | 341 +++++++++++++----------------------- 3 files changed, 217 insertions(+), 370 deletions(-) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 7be2658835b..67581130b31 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -415,16 +415,18 @@ void ArraySpan::FillFromScalar(const Scalar& value) { data_size = scalar.value->size(); } if (is_binary_like(type_id)) { + const auto& binary_scalar = checked_cast(value); this->buffers[1] = - OffsetsForScalar(scalar.scratch_space_, static_cast(data_size)); + OffsetsForScalar(binary_scalar.scratch_space_, static_cast(data_size)); } else { // is_large_binary_like - this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, data_size); + const auto& large_binary_scalar = checked_cast(value); + this->buffers[1] = OffsetsForScalar(large_binary_scalar.scratch_space_, data_size); } this->buffers[2].data = const_cast(data_buffer); this->buffers[2].size = data_size; } else if (type_id == Type::BINARY_VIEW || type_id == Type::STRING_VIEW) { - const auto& scalar = checked_cast(value); + const auto& scalar = checked_cast(value); this->buffers[1].size = BinaryViewType::kSize; this->buffers[1].data = scalar.scratch_space_; @@ -456,17 +458,26 @@ void ArraySpan::FillFromScalar(const Scalar& value) { &this->child_data[0]); } - if (type_id == Type::LIST || type_id == Type::MAP) { + if (type_id == Type::LIST) { + const auto& list_scalar = checked_cast(value); + this->buffers[1] = OffsetsForScalar(list_scalar.scratch_space_, + static_cast(value_length)); + } else if (type_id == Type::MAP) { + const auto& map_scalar = checked_cast(value); this->buffers[1] = - OffsetsForScalar(scalar.scratch_space_, static_cast(value_length)); + OffsetsForScalar(map_scalar.scratch_space_, static_cast(value_length)); } else if (type_id == Type::LARGE_LIST) { - this->buffers[1] = OffsetsForScalar(scalar.scratch_space_, value_length); + const auto& large_list_scalar = checked_cast(value); + this->buffers[1] = OffsetsForScalar(large_list_scalar.scratch_space_, value_length); } else if (type_id == Type::LIST_VIEW) { + const auto& list_view_scalar = checked_cast(value); std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar( - scalar.scratch_space_, static_cast(value_length)); + list_view_scalar.scratch_space_, static_cast(value_length)); } else if (type_id == Type::LARGE_LIST_VIEW) { + const auto& large_list_view_scalar = + checked_cast(value); std::tie(this->buffers[1], this->buffers[2]) = - OffsetsAndSizesForScalar(scalar.scratch_space_, value_length); + OffsetsAndSizesForScalar(large_list_view_scalar.scratch_space_, value_length); } else { DCHECK_EQ(type_id, Type::FIXED_SIZE_LIST); // FIXED_SIZE_LIST: does not have a second buffer @@ -485,20 +496,22 @@ void ArraySpan::FillFromScalar(const Scalar& value) { alignas(int64_t) int8_t type_code; alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; }; - static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_)); - auto* union_scratch_space = reinterpret_cast( - &checked_cast(value).scratch_space_); + // static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_)); // First buffer is kept null since unions have no validity vector this->buffers[0] = {}; - // union_scratch_space->type_code = checked_cast(value).type_code; - this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); - this->buffers[1].size = 1; - this->child_data.resize(this->type->num_fields()); if (type_id == Type::DENSE_UNION) { const auto& scalar = checked_cast(value); + auto* union_scratch_space = + reinterpret_cast(&scalar.scratch_space_); + + // union_scratch_space->type_code = checked_cast(value).type_code; + this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); + this->buffers[1].size = 1; + this->buffers[2] = OffsetsForScalar(union_scratch_space->offsets, static_cast(1)); // We can't "see" the other arrays in the union, but we put the "active" @@ -517,6 +530,14 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } } else { const auto& scalar = checked_cast(value); + auto* union_scratch_space = + reinterpret_cast(&scalar.scratch_space_); + + // union_scratch_space->type_code = checked_cast(value).type_code; + this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); + this->buffers[1].size = 1; + // Sparse union scalars have a full complement of child values even // though only one of them is relevant, so we just fill them in here for (int i = 0; i < static_cast(this->child_data.size()); ++i) { diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 57e5da0081b..dc9c4c54726 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -542,33 +542,23 @@ struct ScalarValidateImpl { } }; -// // Helper function to fill scratch space, base case for recursion -// void FillScalarScratchSpaceHelperInternal(uint8_t* scratch_space) {} - -// Recursive variadic function to fill scratch space template -void FillScalarScratchSpaceHelperInternal(uint8_t* scratch_space, T first, Args... args) { - // Ensure that the argument does not exceed the bounds of the scratch space - static_assert(offset + sizeof(T) <= - sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_), - "Total size of arguments exceeds scratch space size."); - - // Cast the scratch space at the given offset to the type of the current argument and - // assign it +void FillScalarScratchSpaceHelper(uint8_t* scratch_space, T first, Args... args) { + // static_assert(offset + sizeof(T) <= + // sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_), + // "Total size of arguments exceeds scratch space size."); *reinterpret_cast(scratch_space + offset) = first; - - // Recursively fill the scratch space with the remaining arguments - if constexpr (sizeof...(args) > 0) { // Use if constexpr to stop recursion when - // there are no more arguments - FillScalarScratchSpaceHelperInternal(scratch_space, - std::forward(args)...); + if constexpr (sizeof...(args) > 0) { + FillScalarScratchSpaceHelper(scratch_space, + std::forward(args)...); } } template -void FillScalarScratchSpaceHelper(Args... args) { - FillScalarScratchSpaceHelperInternal<0>(std::forward(args)...); +void FillScalarScratchSpace(Args... args) { + FillScalarScratchSpaceHelper<0>(std::forward(args)...); } + } // namespace size_t Scalar::hash() const { return ScalarHashImpl(*this).hash_; } @@ -581,45 +571,28 @@ Status Scalar::ValidateFull() const { return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } -template -BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type, - FillScalarScratchSpaceFactoryT factory) - : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type), - std::move(factory)) {} - -BinaryScalar::BinaryScalar(std::string s, std::shared_ptr type) - : BinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} +BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type) + : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} -void BinaryScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper( - scratch_space.scratch_space_, int32_t(0), - is_valid_ ? static_cast(value_->size()) : int32_t(0)); +void BinaryScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int32_t(0), + is_valid ? static_cast(value->size()) : int32_t(0)); } -BinaryViewScalar::BinaryViewScalar(std::string s, std::shared_ptr type) - : BinaryViewScalar(Buffer::FromString(std::move(s)), std::move(type)) {} - -void BinaryViewScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { +void BinaryViewScalar::FillScratchSpace() { static_assert(sizeof(BinaryViewType::c_type) <= - sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); - auto* view = new (&scratch_space.scratch_space_) BinaryViewType::c_type; - if (is_valid_) { - *view = util::ToBinaryView(std::string_view{*value_}, 0, 0); + sizeof(ArraySpanFillFromScalarScratchSpace::scratch_space_)); + auto* view = new (&scratch_space_) BinaryViewType::c_type; + if (is_valid) { + *view = util::ToBinaryView(std::string_view{*value}, 0, 0); } else { *view = {}; } } -LargeBinaryScalar::LargeBinaryScalar(std::string s, std::shared_ptr type) - : LargeBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} - -void LargeBinaryScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper( - scratch_space.scratch_space_, int64_t(0), - is_valid_ ? static_cast(value_->size()) : int64_t(0)); +void LargeBinaryScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int64_t(0), + is_valid ? static_cast(value->size()) : int64_t(0)); } FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, @@ -640,72 +613,42 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(const std::shared_ptr& valu FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid) : FixedSizeBinaryScalar(Buffer::FromString(std::move(s)), is_valid) {} -template BaseListScalar::BaseListScalar(std::shared_ptr value, - std::shared_ptr type, bool is_valid, - FillScalarScratchSpaceFactoryT factory) - : Scalar{std::move(type), is_valid}, - ArraySpanFillFromScalarScratchSpace(factory(value.get())), - value(std::move(value)) { + std::shared_ptr type, bool is_valid) + : Scalar{std::move(type), is_valid}, value(std::move(value)) { ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); } -ListScalar::ListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid) - : BaseListScalar(std::move(value), std::move(type), is_valid, - FillScalarScratchSpaceFactory) {} - ListScalar::ListScalar(std::shared_ptr value, bool is_valid) - : ListScalar(value, list(value->type()), is_valid) {} + : BaseListScalar(value, list(value->type()), is_valid) {} -void ListScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper( - scratch_space.scratch_space_, int32_t(0), - value_ ? static_cast(value_->length()) : int32_t(0)); +void ListScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int32_t(0), + value ? static_cast(value->length()) : int32_t(0)); } -LargeListScalar::LargeListScalar(std::shared_ptr value, - std::shared_ptr type, bool is_valid) - : BaseListScalar(std::move(value), std::move(type), is_valid, - FillScalarScratchSpaceFactory) {} - LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) : LargeListScalar(value, large_list(value->type()), is_valid) {} -void LargeListScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(0), - value_ ? value_->length() : int64_t(0)); +void LargeListScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int64_t(0), + value ? value->length() : int64_t(0)); } -ListViewScalar::ListViewScalar(std::shared_ptr value, - std::shared_ptr type, bool is_valid) - : BaseListScalar(std::move(value), std::move(type), is_valid, - FillScalarScratchSpaceFactory) {} - ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) : ListViewScalar(value, list_view(value->type()), is_valid) {} -void ListViewScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper( - scratch_space.scratch_space_, int32_t(0), - value_ ? static_cast(value_->length()) : int32_t(0)); +void ListViewScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int32_t(0), + value ? static_cast(value->length()) : int32_t(0)); } -LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, - std::shared_ptr type, bool is_valid) - : BaseListScalar(std::move(value), std::move(type), is_valid, - FillScalarScratchSpaceFactory) {} - LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) : LargeListViewScalar(value, large_list_view(value->type()), is_valid) {} -void LargeListViewScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(0), - value_ ? value_->length() : int64_t(0)); +void LargeListViewScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int64_t(0), + value ? value->length() : int64_t(0)); } inline std::shared_ptr MakeMapType(const std::shared_ptr& pair_type) { @@ -714,31 +657,23 @@ inline std::shared_ptr MakeMapType(const std::shared_ptr& pa return map(pair_type->field(0)->type(), pair_type->field(1)->type()); } -MapScalar::MapScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid) - : BaseListScalar(std::move(value), std::move(type), is_valid, - FillScalarScratchSpaceFactory) {} - MapScalar::MapScalar(std::shared_ptr value, bool is_valid) : MapScalar(value, MakeMapType(value->type()), is_valid) {} -void MapScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - FillScalarScratchSpaceHelper( - scratch_space.scratch_space_, int32_t(0), - value_ ? static_cast(value_->length()) : int32_t(0)); +void MapScalar::FillScratchSpace() { + FillScalarScratchSpace(scratch_space_, int32_t(0), + value ? static_cast(value->length()) : int32_t(0)); } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) - : BaseListScalar(std::move(value), std::move(type), is_valid, - FillScalarScratchSpaceFactory) { + : BaseListScalar(std::move(value), std::move(type), is_valid) { ARROW_CHECK_EQ(this->value->length(), checked_cast(*this->type).list_size()); } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, bool is_valid) - : FixedSizeListScalar( + : BaseListScalar( value, fixed_size_list(value->type(), static_cast(value->length())), is_valid) {} @@ -773,10 +708,7 @@ Result> StructScalar::field(FieldRef ref) const { RunEndEncodedScalar::RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type) - : Scalar{std::move(type), value->is_valid}, - ArraySpanFillFromScalarScratchSpace( - FillScalarScratchSpace(*ree_type(this->type)->run_end_type())), - value{std::move(value)} { + : Scalar{std::move(type), value->is_valid}, value{std::move(value)} { ARROW_CHECK_EQ(this->type->id(), Type::RUN_END_ENCODED); } @@ -787,18 +719,18 @@ RunEndEncodedScalar::RunEndEncodedScalar(const std::shared_ptr& type) RunEndEncodedScalar::~RunEndEncodedScalar() = default; -void RunEndEncodedScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - switch (run_end_type_.id()) { +void RunEndEncodedScalar::FillScratchSpace() { + auto run_end = run_end_type()->id(); + switch (run_end) { case Type::INT16: - FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int16_t(1)); + FillScalarScratchSpace(scratch_space_, int16_t(1)); break; case Type::INT32: - FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int32_t(1)); + FillScalarScratchSpace(scratch_space_, int32_t(1)); break; default: - DCHECK_EQ(run_end_type_.id(), Type::INT64); - FillScalarScratchSpaceHelper(scratch_space.scratch_space_, int64_t(1)); + DCHECK_EQ(run_end, Type::INT64); + FillScalarScratchSpace(scratch_space_, int64_t(1)); } } @@ -876,8 +808,7 @@ Result TimestampScalar::FromISO8601(std::string_view iso8601, SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) - : UnionScalar(std::move(type), type_code, /*is_valid=*/true, - FillScalarScratchSpace(type_code)), + : UnionScalar(std::move(type), type_code, /*is_valid=*/true), value(std::move(value)) { this->child_id = checked_cast(*this->type).child_ids()[type_code]; @@ -908,24 +839,20 @@ struct UnionScratchSpace { alignas(int64_t) int8_t type_code; alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; }; -static_assert(sizeof(UnionScratchSpace) <= - sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); +// static_assert(sizeof(UnionScratchSpace) <= +// sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); } // namespace -void SparseUnionScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - auto* union_scratch_space = - reinterpret_cast(&scratch_space.scratch_space_); - union_scratch_space->type_code = type_code_; +void SparseUnionScalar::FillScratchSpace() { + auto* union_scratch_space = reinterpret_cast(&scratch_space_); + union_scratch_space->type_code = type_code; } -void DenseUnionScalar::FillScalarScratchSpace::Fill( - const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) const { - auto* union_scratch_space = - reinterpret_cast(&scratch_space.scratch_space_); - union_scratch_space->type_code = type_code_; - FillScalarScratchSpaceHelper(union_scratch_space->offsets, int32_t(0), int32_t(1)); +void DenseUnionScalar::FillScratchSpace() { + auto* union_scratch_space = reinterpret_cast(&scratch_space_); + union_scratch_space->type_code = type_code; + FillScalarScratchSpace(union_scratch_space->offsets, int32_t(0), int32_t(1)); } namespace { diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 57c02c5c37e..6dc2028afa2 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -131,24 +131,17 @@ struct ARROW_EXPORT NullScalar : public Scalar { namespace internal { -struct ArraySpanFillFromScalarScratchSpace; - -// Helper class to fill scratch space in a polymorphic way. -struct FillScalarScratchSpace { - virtual ~FillScalarScratchSpace() = default; - virtual void Fill(const ArraySpanFillFromScalarScratchSpace& scratch_space) const = 0; -}; - +template struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { // 16 bytes of scratch space to enable ArraySpan to be a view onto any // Scalar- including binary scalars where we need to create a buffer // that looks like two 32-bit or 64-bit offsets. alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; - protected: - explicit ArraySpanFillFromScalarScratchSpace(const FillScalarScratchSpace& fill) { - fill.Fill(*this); - } + private: + ArraySpanFillFromScalarScratchSpace() { static_cast(this)->FillScratchSpace(); } + + friend Impl; }; struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { @@ -258,9 +251,8 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { using NumericScalar::NumericScalar; }; -struct ARROW_EXPORT BaseBinaryScalar - : public internal::PrimitiveScalarBase, - private internal::ArraySpanFillFromScalarScratchSpace { +struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { + using internal::PrimitiveScalarBase::PrimitiveScalarBase; using ValueType = std::shared_ptr; const std::shared_ptr value; @@ -275,69 +267,32 @@ struct ARROW_EXPORT BaseBinaryScalar return value ? std::string_view(*value) : std::string_view(); } - protected: - template - BaseBinaryScalar(std::shared_ptr type, FillScalarScratchSpaceFactoryT factory) - : PrimitiveScalarBase(std::move(type)), - ArraySpanFillFromScalarScratchSpace(factory(false, NULLPTR)) {} - - template - BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type, - FillScalarScratchSpaceFactoryT factory) - : internal::PrimitiveScalarBase{std::move(type), true}, - ArraySpanFillFromScalarScratchSpace(factory(true, value.get())), - value(std::move(value)) {} + BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type) + : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {} - template - BaseBinaryScalar(std::string s, std::shared_ptr type, - FillScalarScratchSpaceFactoryT factory); - - friend ArraySpan; - - protected: - struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { - FillScalarScratchSpace(bool is_valid, const Buffer* value) - : is_valid_(is_valid), value_(value) {} - - protected: - bool is_valid_; - const Buffer* value_; - }; - - template - static FillScalarScratchSpaceT FillScalarScratchSpaceFactory(bool is_valid, - const Buffer* value) { - return {is_valid, value}; - } + BaseBinaryScalar(std::string s, std::shared_ptr type); }; -struct ARROW_EXPORT BinaryScalar : public BaseBinaryScalar { +struct ARROW_EXPORT BinaryScalar + : public BaseBinaryScalar, + private internal::ArraySpanFillFromScalarScratchSpace { + using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryType; - - explicit BinaryScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), - FillScalarScratchSpaceFactory) {} - - BinaryScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type), - FillScalarScratchSpaceFactory) {} - - BinaryScalar(std::string s, std::shared_ptr type); + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit BinaryScalar(std::shared_ptr value) : BinaryScalar(std::move(value), binary()) {} - explicit BinaryScalar(std::string s) : BinaryScalar(std::move(s), binary()) {} + explicit BinaryScalar(std::string s) : BaseBinaryScalar(std::move(s), binary()) {} BinaryScalar() : BinaryScalar(binary()) {} private: - struct FillScalarScratchSpace : public BaseBinaryScalar::FillScalarScratchSpace { - using BaseBinaryScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT StringScalar : public BinaryScalar { @@ -345,43 +300,36 @@ struct ARROW_EXPORT StringScalar : public BinaryScalar { using TypeClass = StringType; explicit StringScalar(std::shared_ptr value) - : BinaryScalar(std::move(value), utf8()) {} + : StringScalar(std::move(value), utf8()) {} explicit StringScalar(std::string s) : BinaryScalar(std::move(s), utf8()) {} - StringScalar() : BinaryScalar(utf8()) {} + StringScalar() : StringScalar(utf8()) {} }; -struct ARROW_EXPORT BinaryViewScalar : public BaseBinaryScalar { +struct ARROW_EXPORT BinaryViewScalar + : public BaseBinaryScalar, + private internal::ArraySpanFillFromScalarScratchSpace { + using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = BinaryViewType; - - explicit BinaryViewScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), - FillScalarScratchSpaceFactory) {} - - BinaryViewScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type), - FillScalarScratchSpaceFactory) {} - - BinaryViewScalar(std::string s, std::shared_ptr type); + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit BinaryViewScalar(std::shared_ptr value) : BinaryViewScalar(std::move(value), binary_view()) {} explicit BinaryViewScalar(std::string s) - : BinaryViewScalar(std::move(s), binary_view()) {} + : BaseBinaryScalar(std::move(s), binary_view()) {} BinaryViewScalar() : BinaryViewScalar(binary_view()) {} std::string_view view() const override { return std::string_view(*this->value); } private: - struct FillScalarScratchSpace : public BaseBinaryScalar::FillScalarScratchSpace { - using BaseBinaryScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { @@ -389,43 +337,38 @@ struct ARROW_EXPORT StringViewScalar : public BinaryViewScalar { using TypeClass = StringViewType; explicit StringViewScalar(std::shared_ptr value) - : BinaryViewScalar(std::move(value), utf8_view()) {} + : StringViewScalar(std::move(value), utf8_view()) {} explicit StringViewScalar(std::string s) : BinaryViewScalar(std::move(s), utf8_view()) {} - StringViewScalar() : BinaryViewScalar(utf8_view()) {} + StringViewScalar() : StringViewScalar(utf8_view()) {} }; -struct ARROW_EXPORT LargeBinaryScalar : public BaseBinaryScalar { +struct ARROW_EXPORT LargeBinaryScalar + : public BaseBinaryScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using BaseBinaryScalar::BaseBinaryScalar; using TypeClass = LargeBinaryType; - - explicit LargeBinaryScalar(std::shared_ptr type) - : BaseBinaryScalar(std::move(type), - FillScalarScratchSpaceFactory) {} + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; LargeBinaryScalar(std::shared_ptr value, std::shared_ptr type) - : BaseBinaryScalar(std::move(value), std::move(type), - FillScalarScratchSpaceFactory) {} - - LargeBinaryScalar(std::string s, std::shared_ptr type); + : BaseBinaryScalar(std::move(value), std::move(type)) {} explicit LargeBinaryScalar(std::shared_ptr value) : LargeBinaryScalar(std::move(value), large_binary()) {} explicit LargeBinaryScalar(std::string s) - : LargeBinaryScalar(std::move(s), large_binary()) {} + : BaseBinaryScalar(std::move(s), large_binary()) {} LargeBinaryScalar() : LargeBinaryScalar(large_binary()) {} private: - struct FillScalarScratchSpace : public BaseBinaryScalar::FillScalarScratchSpace { - using BaseBinaryScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { @@ -433,7 +376,7 @@ struct ARROW_EXPORT LargeStringScalar : public LargeBinaryScalar { using TypeClass = LargeStringType; explicit LargeStringScalar(std::shared_ptr value) - : LargeBinaryScalar(std::move(value), large_utf8()) {} + : LargeStringScalar(std::move(value), large_utf8()) {} explicit LargeStringScalar(std::string s) : LargeBinaryScalar(std::move(s), large_utf8()) {} @@ -592,114 +535,98 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar; std::shared_ptr value; - template BaseListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid, FillScalarScratchSpaceFactoryT factory); - - private: - friend struct ArraySpan; - - protected: - struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { - explicit FillScalarScratchSpace(const Array* value) : value_(value) {} - - protected: - const Array* value_; - }; - - template - static FillScalarScratchSpaceT FillScalarScratchSpaceFactory(const Array* value) { - return FillScalarScratchSpaceT{value}; - } + bool is_valid = true); }; -struct ARROW_EXPORT ListScalar : public BaseListScalar { +struct ARROW_EXPORT ListScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = ListType; - - ListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid = true); + using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; explicit ListScalar(std::shared_ptr value, bool is_valid = true); private: - struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { - using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT LargeListScalar : public BaseListScalar { +struct ARROW_EXPORT LargeListScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = LargeListType; + using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; - LargeListScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid = true); explicit LargeListScalar(std::shared_ptr value, bool is_valid = true); private: - struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { - using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT ListViewScalar : public BaseListScalar { +struct ARROW_EXPORT ListViewScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = ListViewType; + using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; - ListViewScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid = true); explicit ListViewScalar(std::shared_ptr value, bool is_valid = true); private: - struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { - using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT LargeListViewScalar : public BaseListScalar { +struct ARROW_EXPORT LargeListViewScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = LargeListViewType; + using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; - LargeListViewScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid = true); explicit LargeListViewScalar(std::shared_ptr value, bool is_valid = true); private: - struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { - using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT MapScalar : public BaseListScalar { +struct ARROW_EXPORT MapScalar + : public BaseListScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = MapType; - // using BaseListScalar::BaseListScalar; + using BaseListScalar::BaseListScalar; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; - MapScalar(std::shared_ptr value, std::shared_ptr type, - bool is_valid = true); explicit MapScalar(std::shared_ptr value, bool is_valid = true); private: - struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { - using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { @@ -709,14 +636,6 @@ struct ARROW_EXPORT FixedSizeListScalar : public BaseListScalar { bool is_valid = true); explicit FixedSizeListScalar(std::shared_ptr value, bool is_valid = true); - - private: - struct FillScalarScratchSpace : public BaseListScalar::FillScalarScratchSpace { - using BaseListScalar::FillScalarScratchSpace::FillScalarScratchSpace; - - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override {} - }; }; struct ARROW_EXPORT StructScalar : public Scalar { @@ -734,31 +653,22 @@ struct ARROW_EXPORT StructScalar : public Scalar { std::vector field_names); }; -struct ARROW_EXPORT UnionScalar : public Scalar, - private internal::ArraySpanFillFromScalarScratchSpace { +struct ARROW_EXPORT UnionScalar : public Scalar { int8_t type_code; virtual const std::shared_ptr& child_value() const = 0; protected: - struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { - explicit FillScalarScratchSpace(int8_t type_code) : type_code_(type_code) {} - - protected: - int8_t type_code_; - }; - - UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid, - const FillScalarScratchSpace& fill) - : Scalar(std::move(type), is_valid), - internal::ArraySpanFillFromScalarScratchSpace(fill), - type_code(type_code) {} - - friend struct ArraySpan; + UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid) + : Scalar(std::move(type), is_valid), type_code(type_code) {} }; -struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { +struct ARROW_EXPORT SparseUnionScalar + : public UnionScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = SparseUnionType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; // Even though only one of the union values is relevant for this scalar, we // nonetheless construct a vector of scalars, one per union value, to have @@ -781,16 +691,18 @@ struct ARROW_EXPORT SparseUnionScalar : public UnionScalar { std::shared_ptr type); private: - struct FillScalarScratchSpace : public UnionScalar::FillScalarScratchSpace { - using UnionScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; -struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { +struct ARROW_EXPORT DenseUnionScalar + : public UnionScalar, + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = DenseUnionType; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; // For DenseUnionScalar, we can make a valid ArraySpan of length 1 from this // scalar @@ -800,24 +712,23 @@ struct ARROW_EXPORT DenseUnionScalar : public UnionScalar { const std::shared_ptr& child_value() const override { return this->value; } DenseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) - : UnionScalar(std::move(type), type_code, value->is_valid, - FillScalarScratchSpace(type_code)), + : UnionScalar(std::move(type), type_code, value->is_valid), value(std::move(value)) {} private: - struct FillScalarScratchSpace : public UnionScalar::FillScalarScratchSpace { - using UnionScalar::FillScalarScratchSpace::FillScalarScratchSpace; + void FillScratchSpace(); - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - }; + friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; struct ARROW_EXPORT RunEndEncodedScalar : public Scalar, - private internal::ArraySpanFillFromScalarScratchSpace { + private internal::ArraySpanFillFromScalarScratchSpace { using TypeClass = RunEndEncodedType; using ValueType = std::shared_ptr; + using ArraySpanFillFromScalarScratchSpace = + internal::ArraySpanFillFromScalarScratchSpace; ValueType value; @@ -829,30 +740,18 @@ struct ARROW_EXPORT RunEndEncodedScalar ~RunEndEncodedScalar() override; const std::shared_ptr& run_end_type() const { - return ree_type(type)->run_end_type(); + return ree_type().run_end_type(); } - const std::shared_ptr& value_type() const { - return ree_type(type)->value_type(); - } + const std::shared_ptr& value_type() const { return ree_type().value_type(); } private: - static std::shared_ptr ree_type(const std::shared_ptr& type) { - return internal::checked_pointer_cast(type); - } - - struct FillScalarScratchSpace : public internal::FillScalarScratchSpace { - explicit FillScalarScratchSpace(const DataType& run_end_type) - : run_end_type_(run_end_type) {} + const TypeClass& ree_type() const { return internal::checked_cast(*type); } - void Fill(const internal::ArraySpanFillFromScalarScratchSpace& scratch_space) - const override; - - private: - const DataType& run_end_type_; - }; + void FillScratchSpace(); friend ArraySpan; + friend ArraySpanFillFromScalarScratchSpace; }; /// \brief A Scalar value for DictionaryType From 02155aac8e02caf0ba8edfc92d64f49f40e6cc64 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 2 Apr 2024 02:07:23 +0800 Subject: [PATCH 14/31] Refine --- cpp/src/arrow/array/data.cc | 14 +++++++------- cpp/src/arrow/scalar.cc | 18 ++---------------- cpp/src/arrow/scalar.h | 12 +++++++++++- 3 files changed, 20 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 67581130b31..a6728d6f362 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -491,11 +491,11 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->child_data[i].FillFromScalar(*scalar.value[i]); } } else if (is_union(type_id)) { - // Dense union needs scratch space to store both offsets and a type code - struct UnionScratchSpace { - alignas(int64_t) int8_t type_code; - alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; - }; + // // Dense union needs scratch space to store both offsets and a type code + // struct UnionScratchSpace { + // alignas(int64_t) int8_t type_code; + // alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; + // }; // static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_)); // First buffer is kept null since unions have no validity vector @@ -505,7 +505,7 @@ void ArraySpan::FillFromScalar(const Scalar& value) { if (type_id == Type::DENSE_UNION) { const auto& scalar = checked_cast(value); auto* union_scratch_space = - reinterpret_cast(&scalar.scratch_space_); + reinterpret_cast(&scalar.scratch_space_); // union_scratch_space->type_code = checked_cast(value).type_code; @@ -531,7 +531,7 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } else { const auto& scalar = checked_cast(value); auto* union_scratch_space = - reinterpret_cast(&scalar.scratch_space_); + reinterpret_cast(&scalar.scratch_space_); // union_scratch_space->type_code = checked_cast(value).type_code; diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index dc9c4c54726..18aed8402be 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -544,9 +544,7 @@ struct ScalarValidateImpl { template void FillScalarScratchSpaceHelper(uint8_t* scratch_space, T first, Args... args) { - // static_assert(offset + sizeof(T) <= - // sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_), - // "Total size of arguments exceeds scratch space size."); + static_assert(offset + sizeof(T) <= internal::kScalarScratchSpaceSize); *reinterpret_cast(scratch_space + offset) = first; if constexpr (sizeof...(args) > 0) { FillScalarScratchSpaceHelper(scratch_space, @@ -580,8 +578,7 @@ void BinaryScalar::FillScratchSpace() { } void BinaryViewScalar::FillScratchSpace() { - static_assert(sizeof(BinaryViewType::c_type) <= - sizeof(ArraySpanFillFromScalarScratchSpace::scratch_space_)); + static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize); auto* view = new (&scratch_space_) BinaryViewType::c_type; if (is_valid) { *view = util::ToBinaryView(std::string_view{*value}, 0, 0); @@ -833,17 +830,6 @@ std::shared_ptr SparseUnionScalar::FromValue(std::shared_ptr val return std::make_shared(field_values, type_code, std::move(type)); } -namespace { - -struct UnionScratchSpace { - alignas(int64_t) int8_t type_code; - alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; -}; -// static_assert(sizeof(UnionScratchSpace) <= -// sizeof(internal::ArraySpanFillFromScalarScratchSpace::scratch_space_)); - -} // namespace - void SparseUnionScalar::FillScratchSpace() { auto* union_scratch_space = reinterpret_cast(&scratch_space_); union_scratch_space->type_code = type_code; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 6dc2028afa2..8156765dfb3 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -131,12 +131,14 @@ struct ARROW_EXPORT NullScalar : public Scalar { namespace internal { +constexpr auto kScalarScratchSpaceSize = sizeof(int64_t) * 2; + template struct ARROW_EXPORT ArraySpanFillFromScalarScratchSpace { // 16 bytes of scratch space to enable ArraySpan to be a view onto any // Scalar- including binary scalars where we need to create a buffer // that looks like two 32-bit or 64-bit offsets. - alignas(int64_t) mutable uint8_t scratch_space_[sizeof(int64_t) * 2]; + alignas(int64_t) mutable uint8_t scratch_space_[kScalarScratchSpaceSize]; private: ArraySpanFillFromScalarScratchSpace() { static_cast(this)->FillScratchSpace(); } @@ -661,6 +663,14 @@ struct ARROW_EXPORT UnionScalar : public Scalar { protected: UnionScalar(std::shared_ptr type, int8_t type_code, bool is_valid) : Scalar(std::move(type), is_valid), type_code(type_code) {} + + struct UnionScratchSpace { + alignas(int64_t) int8_t type_code; + alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; + }; + static_assert(sizeof(UnionScratchSpace) <= internal::kScalarScratchSpaceSize); + + friend ArraySpan; }; struct ARROW_EXPORT SparseUnionScalar From 7522c19e8a8d1ea0b3da735ea8c9b3568e602d5a Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 2 Apr 2024 15:14:10 +0800 Subject: [PATCH 15/31] Fix --- cpp/src/arrow/scalar.cc | 8 ++++---- cpp/src/arrow/scalar.h | 5 +++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 18aed8402be..aead10efa9c 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -625,7 +625,7 @@ void ListScalar::FillScratchSpace() { } LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) - : LargeListScalar(value, large_list(value->type()), is_valid) {} + : BaseListScalar(value, large_list(value->type()), is_valid) {} void LargeListScalar::FillScratchSpace() { FillScalarScratchSpace(scratch_space_, int64_t(0), @@ -633,7 +633,7 @@ void LargeListScalar::FillScratchSpace() { } ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) - : ListViewScalar(value, list_view(value->type()), is_valid) {} + : BaseListScalar(value, list_view(value->type()), is_valid) {} void ListViewScalar::FillScratchSpace() { FillScalarScratchSpace(scratch_space_, int32_t(0), @@ -641,7 +641,7 @@ void ListViewScalar::FillScratchSpace() { } LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) - : LargeListViewScalar(value, large_list_view(value->type()), is_valid) {} + : BaseListScalar(value, large_list_view(value->type()), is_valid) {} void LargeListViewScalar::FillScratchSpace() { FillScalarScratchSpace(scratch_space_, int64_t(0), @@ -655,7 +655,7 @@ inline std::shared_ptr MakeMapType(const std::shared_ptr& pa } MapScalar::MapScalar(std::shared_ptr value, bool is_valid) - : MapScalar(value, MakeMapType(value->type()), is_valid) {} + : BaseListScalar(value, MakeMapType(value->type()), is_valid) {} void MapScalar::FillScratchSpace() { FillScalarScratchSpace(scratch_space_, int32_t(0), diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 8156765dfb3..7b2a246c239 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -538,12 +538,13 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar; - std::shared_ptr value; - BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid = true); + + std::shared_ptr value; }; struct ARROW_EXPORT ListScalar From c7c8be4dc1d56797ab8438abb20b303aa59e03a0 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 2 Apr 2024 15:19:21 +0800 Subject: [PATCH 16/31] Fix ctor --- cpp/src/arrow/scalar.h | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 7b2a246c239..dc2ce2ccc5a 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -254,7 +254,6 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { }; struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { - using internal::PrimitiveScalarBase::PrimitiveScalarBase; using ValueType = std::shared_ptr; const std::shared_ptr value; @@ -269,6 +268,9 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { return value ? std::string_view(*value) : std::string_view(); } + BaseBinaryScalar(std::shared_ptr type) + : internal::PrimitiveScalarBase(std::move(type)) {} + BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type) : internal::PrimitiveScalarBase{std::move(type), true}, value(std::move(value)) {} From 385129cc75fad782213fd4ff4780eb3ca1c8162f Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 2 Apr 2024 15:32:37 +0800 Subject: [PATCH 17/31] Fix lint --- cpp/src/arrow/scalar.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index dc2ce2ccc5a..9c607eabf75 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -268,7 +268,7 @@ struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { return value ? std::string_view(*value) : std::string_view(); } - BaseBinaryScalar(std::shared_ptr type) + explicit BaseBinaryScalar(std::shared_ptr type) : internal::PrimitiveScalarBase(std::move(type)) {} BaseBinaryScalar(std::shared_ptr value, std::shared_ptr type) From 28c911a941ffb721f80b8ae4881ed26e206b48a1 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Tue, 2 Apr 2024 23:34:13 +0800 Subject: [PATCH 18/31] Fix c_glib scalar --- c_glib/arrow-glib/scalar.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/c_glib/arrow-glib/scalar.cpp b/c_glib/arrow-glib/scalar.cpp index def6b151483..20417006441 100644 --- a/c_glib/arrow-glib/scalar.cpp +++ b/c_glib/arrow-glib/scalar.cpp @@ -1063,7 +1063,8 @@ garrow_base_binary_scalar_get_value(GArrowBaseBinaryScalar *scalar) if (!priv->value) { const auto arrow_scalar = std::static_pointer_cast( garrow_scalar_get_raw(GARROW_SCALAR(scalar))); - priv->value = garrow_buffer_new_raw(&(arrow_scalar->value)); + priv->value = garrow_buffer_new_raw( + const_cast *>(&(arrow_scalar->value))); } return priv->value; } From 029b6f63837c772e95bf7f94767047ee5bf8618b Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Sat, 6 Apr 2024 00:58:52 +0800 Subject: [PATCH 19/31] Fix --- cpp/src/arrow/scalar.cc | 10 +++++++--- cpp/src/arrow/scalar.h | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index aead10efa9c..6b9cc37cc3e 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -566,6 +566,7 @@ Status Scalar::Validate() const { } Status Scalar::ValidateFull() const { + // TODO: Validate scratch space content? return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } @@ -573,14 +574,16 @@ BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} void BinaryScalar::FillScratchSpace() { + // TODO: Test value being nullptr. FillScalarScratchSpace(scratch_space_, int32_t(0), - is_valid ? static_cast(value->size()) : int32_t(0)); + value ? static_cast(value->size()) : int32_t(0)); } void BinaryViewScalar::FillScratchSpace() { + // TODO: Test value being nullptr. static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize); auto* view = new (&scratch_space_) BinaryViewType::c_type; - if (is_valid) { + if (value) { *view = util::ToBinaryView(std::string_view{*value}, 0, 0); } else { *view = {}; @@ -588,8 +591,9 @@ void BinaryViewScalar::FillScratchSpace() { } void LargeBinaryScalar::FillScratchSpace() { + // TODO: Test value being nullptr. FillScalarScratchSpace(scratch_space_, int64_t(0), - is_valid ? static_cast(value->size()) : int64_t(0)); + value ? static_cast(value->size()) : int64_t(0)); } FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 9c607eabf75..df5afd25fcb 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -256,7 +256,7 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { using ValueType = std::shared_ptr; - const std::shared_ptr value; + const std::shared_ptr value = NULLPTR; const void* data() const override { return value ? reinterpret_cast(value->data()) : NULLPTR; From a0759d3ffc2872da29d80e2b158bdb058843ec76 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Sun, 7 Apr 2024 12:30:51 +0800 Subject: [PATCH 20/31] Fix scalar cast from union to string --- cpp/src/arrow/scalar.cc | 4 ++-- cpp/src/arrow/scalar_test.cc | 12 ++++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 6b9cc37cc3e..58516801ee2 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -1313,8 +1313,8 @@ CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { } // union types to string -template -typename std::enable_if_t::value, +template +typename std::enable_if_t::value, Result>> CastImpl(const UnionScalar& from, std::shared_ptr to_type) { const auto& union_ty = checked_cast(*from.type); diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 10ca8963048..ea4eb145cfe 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -1711,6 +1711,16 @@ class TestUnionScalar : public ::testing::Test { } } + void TestToString() { + ASSERT_EQ(union_alpha_->ToString(), "union{string: string = alpha}"); + ASSERT_EQ(union_beta_->ToString(), "union{string: string = beta}"); + ASSERT_EQ(union_two_->ToString(), "union{number: uint64 = 2}"); + ASSERT_EQ(union_other_two_->ToString(), "union{other_number: uint64 = 2}"); + ASSERT_EQ(union_three_->ToString(), "union{number: uint64 = 3}"); + ASSERT_EQ(union_string_null_->ToString(), "null"); + ASSERT_EQ(union_number_null_->ToString(), "null"); + } + protected: std::shared_ptr type_; const UnionType* union_type_; @@ -1729,6 +1739,8 @@ TYPED_TEST(TestUnionScalar, Equals) { this->TestEquals(); } TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); } +TYPED_TEST(TestUnionScalar, ToString) { this->TestToString(); } + class TestSparseUnionScalar : public TestUnionScalar {}; TEST_F(TestSparseUnionScalar, GetScalar) { From 2448857b62f4bb198a7d72581892f9c8bcfbcbab Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Sun, 7 Apr 2024 15:29:54 +0800 Subject: [PATCH 21/31] Fix map scalar to string --- cpp/src/arrow/scalar.cc | 2 ++ cpp/src/arrow/scalar_test.cc | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 58516801ee2..f6d07505e3d 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -1391,6 +1391,8 @@ struct FromTypeVisitor : CastImplVisitor { return CastFromListLike(large_list_view_type); } + Status Visit(const MapType& map_type) { return CastFromListLike(map_type); } + Status Visit(const NullType&) { return NotImplemented(); } Status Visit(const DictionaryType&) { return NotImplemented(); } Status Visit(const ExtensionType&) { return NotImplemented(); } diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index ea4eb145cfe..48074750fd4 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -1286,6 +1286,17 @@ TEST(TestMapScalar, Cast) { CheckListCastError(scalar, invalid_cast_type); } +TEST(TestMapScalar, ToString) { + auto key_value_type = struct_({field("key", utf8(), false), field("value", int8())}); + auto value = ArrayFromJSON(key_value_type, + R"([{"key": "a", "value": 1}, {"key": "b", "value": 2}])"); + auto scalar = MapScalar(value); + + ASSERT_EQ( + scalar.ToString(), + R"(map[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])"); +} + TEST(TestStructScalar, FieldAccess) { StructScalar abc({MakeScalar(true), MakeNullScalar(int32()), MakeScalar("hello"), MakeNullScalar(int64())}, From 420170eec161863051c080bd7cc43b90305826a5 Mon Sep 17 00:00:00 2001 From: Rossi Sun Date: Mon, 8 Apr 2024 14:26:07 +0800 Subject: [PATCH 22/31] Make value immutable for all scalars with scratch space (#6) * Make value immutable for all scalars with scratch space * Glib change for list scalar * Fix * Fix glib build --- c_glib/arrow-glib/scalar.cpp | 3 +- cpp/src/arrow/scalar.cc | 21 ++++-- cpp/src/arrow/scalar.h | 11 ++- cpp/src/arrow/scalar_test.cc | 129 +++++++++++++++++++++++------------ 4 files changed, 105 insertions(+), 59 deletions(-) diff --git a/c_glib/arrow-glib/scalar.cpp b/c_glib/arrow-glib/scalar.cpp index 20417006441..f965b497030 100644 --- a/c_glib/arrow-glib/scalar.cpp +++ b/c_glib/arrow-glib/scalar.cpp @@ -1984,7 +1984,8 @@ garrow_base_list_scalar_get_value(GArrowBaseListScalar *scalar) if (!priv->value) { const auto arrow_scalar = std::static_pointer_cast( garrow_scalar_get_raw(GARROW_SCALAR(scalar))); - priv->value = garrow_array_new_raw(&(arrow_scalar->value)); + priv->value = garrow_array_new_raw( + const_cast *>(&(arrow_scalar->value))); } return priv->value; } diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index f6d07505e3d..c8c05e6cbe8 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -617,7 +617,9 @@ FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::string s, bool is_valid) BaseListScalar::BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : Scalar{std::move(type), is_valid}, value(std::move(value)) { - ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); + if (this->value) { + ARROW_CHECK(this->type->field(0)->type()->Equals(this->value->type())); + } } ListScalar::ListScalar(std::shared_ptr value, bool is_valid) @@ -669,8 +671,10 @@ void MapScalar::FillScratchSpace() { FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid) { - ARROW_CHECK_EQ(this->value->length(), - checked_cast(*this->type).list_size()); + if (value) { + ARROW_CHECK_EQ(this->value->length(), + checked_cast(*this->type).list_size()); + } } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, bool is_valid) @@ -811,11 +815,14 @@ SparseUnionScalar::SparseUnionScalar(ValueType value, int8_t type_code, std::shared_ptr type) : UnionScalar(std::move(type), type_code, /*is_valid=*/true), value(std::move(value)) { - this->child_id = - checked_cast(*this->type).child_ids()[type_code]; + const auto child_ids = checked_cast(*this->type).child_ids(); + if (type_code >= 0 && static_cast(type_code) < child_ids.size() && + child_ids[type_code] != UnionType::kInvalidChildId) { + this->child_id = child_ids[type_code]; - // Fix nullness based on whether the selected child is null - this->is_valid = this->value[this->child_id]->is_valid; + // Fix nullness based on whether the selected child is null + this->is_valid = this->value[this->child_id]->is_valid; + } } std::shared_ptr SparseUnionScalar::FromValue(std::shared_ptr value, diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index df5afd25fcb..7b422222466 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -540,13 +540,12 @@ struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar; BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid = true); - std::shared_ptr value; + const std::shared_ptr value; }; struct ARROW_EXPORT ListScalar @@ -659,7 +658,7 @@ struct ARROW_EXPORT StructScalar : public Scalar { }; struct ARROW_EXPORT UnionScalar : public Scalar { - int8_t type_code; + const int8_t type_code; virtual const std::shared_ptr& child_value() const = 0; @@ -687,7 +686,7 @@ struct ARROW_EXPORT SparseUnionScalar // nonetheless construct a vector of scalars, one per union value, to have // enough data to reconstruct a valid ArraySpan of length 1 from this scalar using ValueType = std::vector>; - ValueType value; + const ValueType value; // The value index corresponding to the active type code int child_id; @@ -720,7 +719,7 @@ struct ARROW_EXPORT DenseUnionScalar // For DenseUnionScalar, we can make a valid ArraySpan of length 1 from this // scalar using ValueType = std::shared_ptr; - ValueType value; + const ValueType value; const std::shared_ptr& child_value() const override { return this->value; } @@ -743,7 +742,7 @@ struct ARROW_EXPORT RunEndEncodedScalar using ArraySpanFillFromScalarScratchSpace = internal::ArraySpanFillFromScalarScratchSpace; - ValueType value; + const ValueType value; RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type); diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 48074750fd4..32aa8dfa0fa 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -1167,24 +1167,25 @@ class TestListLikeScalar : public ::testing::Test { } void TestValidateErrors() { - ScalarType scalar(value_); - scalar.is_valid = false; - ASSERT_OK(scalar.ValidateFull()); - - // Value must be defined - scalar = ScalarType(value_); - scalar.value = nullptr; - AssertValidationFails(scalar); + { + ScalarType scalar(value_); + scalar.is_valid = false; + ASSERT_OK(scalar.ValidateFull()); + } - // Inconsistent child type - scalar = ScalarType(value_); - scalar.value = ArrayFromJSON(int32(), "[1, 2, null]"); - AssertValidationFails(scalar); + { + // Value must be defined + ScalarType scalar(nullptr, type_); + scalar.is_valid = true; + AssertValidationFails(scalar); + } - // Invalid UTF8 in child data - scalar = ScalarType(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]")); - ASSERT_OK(scalar.Validate()); - ASSERT_RAISES(Invalid, scalar.ValidateFull()); + { + // Invalid UTF8 in child data + ScalarType scalar(ArrayFromJSON(utf8(), "[null, null, \"\xff\"]")); + ASSERT_OK(scalar.Validate()); + ASSERT_RAISES(Invalid, scalar.ValidateFull()); + } } void TestHashing() { @@ -1576,17 +1577,41 @@ void CheckGetNullUnionScalar(const Array& arr, int64_t index) { ASSERT_FALSE(checked_cast(*scalar).child_value()->is_valid); } +std::shared_ptr MakeUnionScalar(const SparseUnionType& type, int8_t type_code, + std::shared_ptr field_value, + int field_index) { + ScalarVector field_values; + for (int i = 0; i < type.num_fields(); ++i) { + if (i == field_index) { + field_values.emplace_back(std::move(field_value)); + } else { + field_values.emplace_back(MakeNullScalar(type.field(i)->type())); + } + } + return std::make_shared(std::move(field_values), type_code, + type.GetSharedPtr()); +} + std::shared_ptr MakeUnionScalar(const SparseUnionType& type, std::shared_ptr field_value, int field_index) { - return SparseUnionScalar::FromValue(field_value, field_index, type.GetSharedPtr()); + return SparseUnionScalar::FromValue(std::move(field_value), field_index, + type.GetSharedPtr()); +} + +std::shared_ptr MakeUnionScalar(const DenseUnionType& type, int8_t type_code, + std::shared_ptr field_value, + int field_index) { + return std::make_shared(std::move(field_value), type_code, + type.GetSharedPtr()); } std::shared_ptr MakeUnionScalar(const DenseUnionType& type, std::shared_ptr field_value, int field_index) { int8_t type_code = type.type_codes()[field_index]; - return std::make_shared(field_value, type_code, type.GetSharedPtr()); + return std::make_shared(std::move(field_value), type_code, + type.GetSharedPtr()); } std::shared_ptr MakeSpecificNullScalar(const DenseUnionType& type, @@ -1634,7 +1659,13 @@ class TestUnionScalar : public ::testing::Test { std::shared_ptr ScalarFromValue(int field_index, std::shared_ptr field_value) { - return MakeUnionScalar(*union_type_, field_value, field_index); + return MakeUnionScalar(*union_type_, std::move(field_value), field_index); + } + + std::shared_ptr ScalarFromTypeCodeAndValue(int8_t type_code, + std::shared_ptr field_value, + int field_index) { + return MakeUnionScalar(*union_type_, type_code, std::move(field_value), field_index); } std::shared_ptr SpecificNull(int field_index) { @@ -1652,40 +1683,48 @@ class TestUnionScalar : public ::testing::Test { } void TestValidateErrors() { - // Type code doesn't exist - auto scalar = ScalarFromValue(0, alpha_); - UnionScalar* union_scalar = static_cast(scalar.get()); - - // Invalid type code - union_scalar->type_code = 0; - AssertValidationFails(*union_scalar); + { + // Invalid type code + auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0); + AssertValidationFails(*scalar); + } - union_scalar->is_valid = false; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(0, alpha_, 0); + scalar->is_valid = false; + AssertValidationFails(*scalar); + } - union_scalar->type_code = -42; - union_scalar->is_valid = true; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0); + AssertValidationFails(*scalar); + } - union_scalar->is_valid = false; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(-42, alpha_, 0); + scalar->is_valid = false; + AssertValidationFails(*scalar); + } // Type code doesn't correspond to child type if (type_->id() == ::arrow::Type::DENSE_UNION) { - union_scalar->type_code = 42; - union_scalar->is_valid = true; - AssertValidationFails(*union_scalar); - - scalar = ScalarFromValue(2, two_); - union_scalar = static_cast(scalar.get()); - union_scalar->type_code = 3; - AssertValidationFails(*union_scalar); + { + auto scalar = ScalarFromTypeCodeAndValue(42, alpha_, 0); + AssertValidationFails(*scalar); + } + + { + auto scalar = ScalarFromTypeCodeAndValue(3, two_, 2); + AssertValidationFails(*scalar); + } } - // underlying value has invalid UTF8 - scalar = ScalarFromValue(0, std::make_shared("\xff")); - ASSERT_OK(scalar->Validate()); - ASSERT_RAISES(Invalid, scalar->ValidateFull()); + { + // underlying value has invalid UTF8 + auto scalar = ScalarFromValue(0, std::make_shared("\xff")); + ASSERT_OK(scalar->Validate()); + ASSERT_RAISES(Invalid, scalar->ValidateFull()); + } } void TestEquals() { From 1ece2a118ee527fd8768a7ca428e311d7804dba8 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 15:17:25 +0800 Subject: [PATCH 23/31] Cleanup --- cpp/src/arrow/array/data.cc | 65 ++++++++++--------------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index a6728d6f362..52f21d21e16 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -283,25 +283,15 @@ void ArraySpan::SetMembers(const ArrayData& data) { namespace { -template -BufferSpan OffsetsForScalar(uint8_t* scratch_space, offset_type value_size) { - // auto* offsets = reinterpret_cast(scratch_space); - // offsets[0] = 0; - // offsets[1] = static_cast(value_size); - // static_assert(2 * sizeof(offset_type) <= 16); - return {scratch_space, sizeof(offset_type) * 2}; +BufferSpan OffsetsForScalar(uint8_t* scratch_space, int64_t offset_width) { + return {scratch_space, offset_width}; } -template std::pair OffsetsAndSizesForScalar(uint8_t* scratch_space, - offset_type value_size) { + int64_t offset_width) { auto* offsets = scratch_space; - auto* sizes = scratch_space + sizeof(offset_type); - // reinterpret_cast(offsets)[0] = 0; - // reinterpret_cast(sizes)[0] = value_size; - // static_assert(2 * sizeof(offset_type) <= 16); - return {BufferSpan{offsets, sizeof(offset_type)}, - BufferSpan{sizes, sizeof(offset_type)}}; + auto* sizes = scratch_space + offset_width; + return {BufferSpan{offsets, offset_width}, BufferSpan{sizes, offset_width}}; } int GetNumBuffers(const DataType& type) { @@ -416,12 +406,12 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } if (is_binary_like(type_id)) { const auto& binary_scalar = checked_cast(value); - this->buffers[1] = - OffsetsForScalar(binary_scalar.scratch_space_, static_cast(data_size)); + this->buffers[1] = OffsetsForScalar(binary_scalar.scratch_space_, sizeof(int32_t)); } else { // is_large_binary_like const auto& large_binary_scalar = checked_cast(value); - this->buffers[1] = OffsetsForScalar(large_binary_scalar.scratch_space_, data_size); + this->buffers[1] = + OffsetsForScalar(large_binary_scalar.scratch_space_, sizeof(int64_t)); } this->buffers[2].data = const_cast(data_buffer); this->buffers[2].size = data_size; @@ -430,13 +420,8 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->buffers[1].size = BinaryViewType::kSize; this->buffers[1].data = scalar.scratch_space_; - // static_assert(sizeof(BinaryViewType::c_type) <= sizeof(scalar.scratch_space_)); - // auto* view = new (&scalar.scratch_space_) BinaryViewType::c_type; if (scalar.is_valid) { - // *view = util::ToBinaryView(std::string_view{*scalar.value}, 0, 0); this->buffers[2] = internal::PackVariadicBuffers({&scalar.value, 1}); - } else { - // *view = {}; } } else if (type_id == Type::FIXED_SIZE_BINARY) { const auto& scalar = checked_cast(value); @@ -445,12 +430,10 @@ void ArraySpan::FillFromScalar(const Scalar& value) { } else if (is_var_length_list_like(type_id) || type_id == Type::FIXED_SIZE_LIST) { const auto& scalar = checked_cast(value); - int64_t value_length = 0; this->child_data.resize(1); if (scalar.value != nullptr) { // When the scalar is null, scalar.value can also be null this->child_data[0].SetMembers(*scalar.value->data()); - value_length = scalar.value->length(); } else { // Even when the value is null, we still must populate the // child_data to yield a valid array. Tedious @@ -460,24 +443,23 @@ void ArraySpan::FillFromScalar(const Scalar& value) { if (type_id == Type::LIST) { const auto& list_scalar = checked_cast(value); - this->buffers[1] = OffsetsForScalar(list_scalar.scratch_space_, - static_cast(value_length)); + this->buffers[1] = OffsetsForScalar(list_scalar.scratch_space_, sizeof(int32_t)); } else if (type_id == Type::MAP) { const auto& map_scalar = checked_cast(value); - this->buffers[1] = - OffsetsForScalar(map_scalar.scratch_space_, static_cast(value_length)); + this->buffers[1] = OffsetsForScalar(map_scalar.scratch_space_, sizeof(int32_t)); } else if (type_id == Type::LARGE_LIST) { const auto& large_list_scalar = checked_cast(value); - this->buffers[1] = OffsetsForScalar(large_list_scalar.scratch_space_, value_length); + this->buffers[1] = + OffsetsForScalar(large_list_scalar.scratch_space_, sizeof(int64_t)); } else if (type_id == Type::LIST_VIEW) { const auto& list_view_scalar = checked_cast(value); - std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar( - list_view_scalar.scratch_space_, static_cast(value_length)); + std::tie(this->buffers[1], this->buffers[2]) = + OffsetsAndSizesForScalar(list_view_scalar.scratch_space_, sizeof(int32_t)); } else if (type_id == Type::LARGE_LIST_VIEW) { const auto& large_list_view_scalar = checked_cast(value); - std::tie(this->buffers[1], this->buffers[2]) = - OffsetsAndSizesForScalar(large_list_view_scalar.scratch_space_, value_length); + std::tie(this->buffers[1], this->buffers[2]) = OffsetsAndSizesForScalar( + large_list_view_scalar.scratch_space_, sizeof(int64_t)); } else { DCHECK_EQ(type_id, Type::FIXED_SIZE_LIST); // FIXED_SIZE_LIST: does not have a second buffer @@ -491,13 +473,6 @@ void ArraySpan::FillFromScalar(const Scalar& value) { this->child_data[i].FillFromScalar(*scalar.value[i]); } } else if (is_union(type_id)) { - // // Dense union needs scratch space to store both offsets and a type code - // struct UnionScratchSpace { - // alignas(int64_t) int8_t type_code; - // alignas(int64_t) uint8_t offsets[sizeof(int32_t) * 2]; - // }; - // static_assert(sizeof(UnionScratchSpace) <= sizeof(UnionScalar::scratch_space_)); - // First buffer is kept null since unions have no validity vector this->buffers[0] = {}; @@ -507,13 +482,10 @@ void ArraySpan::FillFromScalar(const Scalar& value) { auto* union_scratch_space = reinterpret_cast(&scalar.scratch_space_); - // union_scratch_space->type_code = checked_cast(value).type_code; this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); this->buffers[1].size = 1; - this->buffers[2] = - OffsetsForScalar(union_scratch_space->offsets, static_cast(1)); + this->buffers[2] = OffsetsForScalar(union_scratch_space->offsets, sizeof(int32_t)); // We can't "see" the other arrays in the union, but we put the "active" // union array in the right place and fill zero-length arrays for the // others @@ -533,8 +505,6 @@ void ArraySpan::FillFromScalar(const Scalar& value) { auto* union_scratch_space = reinterpret_cast(&scalar.scratch_space_); - // union_scratch_space->type_code = checked_cast(value).type_code; this->buffers[1].data = reinterpret_cast(&union_scratch_space->type_code); this->buffers[1].size = 1; @@ -562,7 +532,6 @@ void ArraySpan::FillFromScalar(const Scalar& value) { e.null_count = 0; e.buffers[1].data = scalar.scratch_space_; e.buffers[1].size = sizeof(run_end); - // reinterpret_cast(scalar.scratch_space_)[0] = run_end; }; switch (scalar.run_end_type()->id()) { From c328656a4be1a0548ec4b4447e630b46a75182c3 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 15:32:58 +0800 Subject: [PATCH 24/31] More cleanup useless code --- .../arrow/compute/kernels/codegen_internal.h | 37 ------------------- cpp/src/arrow/scalar.cc | 19 ---------- 2 files changed, 56 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 8e21393743e..097ee1de45b 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -369,43 +369,6 @@ struct UnboxScalar { } }; -// template -// struct BoxScalar; - -// template -// struct BoxScalar> { -// using T = typename GetOutputType::T; -// static void Box(T val, Scalar* out) { -// // Enables BoxScalar to work on a (for example) Time64Scalar -// T* mutable_data = reinterpret_cast( -// checked_cast<::arrow::internal::PrimitiveScalarBase*>(out)->mutable_data()); -// *mutable_data = val; -// } -// }; - -// template -// struct BoxScalar> { -// using T = typename GetOutputType::T; -// using ScalarType = typename TypeTraits::ScalarType; -// static void Box(T val, Scalar* out) { -// checked_cast(out)->value = std::make_shared(val); -// } -// }; - -// template <> -// struct BoxScalar { -// using T = Decimal128; -// using ScalarType = Decimal128Scalar; -// static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } -// }; - -// template <> -// struct BoxScalar { -// using T = Decimal256; -// using ScalarType = Decimal256Scalar; -// static void Box(T val, Scalar* out) { checked_cast(out)->value = val; } -// }; - // A VisitArraySpanInline variant that calls its visitor function with logical // values, such as Decimal128 rather than std::string_view. diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index c8c05e6cbe8..22f9c773e6f 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -566,7 +566,6 @@ Status Scalar::Validate() const { } Status Scalar::ValidateFull() const { - // TODO: Validate scratch space content? return ScalarValidateImpl(/*full_validation=*/true).Validate(*this); } @@ -1246,24 +1245,6 @@ CastImpl(const From& from, std::shared_ptr to_type) { std::move(to_type)); } -// template -// typename std::enable_if_t::value, -// Result>> -// CastImpl(const Decimal128Scalar& from, std::shared_ptr to_type) { -// auto from_type = checked_cast(from.type.get()); -// return std::make_shared( -// Buffer::FromString(from.value.ToString(from_type->scale())), std::move(to_type)); -// } - -// template -// typename std::enable_if_t::value, -// Result>> -// CastImpl(const Decimal256Scalar& from, std::shared_ptr to_type) { -// auto from_type = checked_cast(from.type.get()); -// return std::make_shared( -// Buffer::FromString(from.value.ToString(from_type->scale())), std::move(to_type)); -// } - template typename std::enable_if_t::value, Result>> From af6760f0b9df5c7d0cec0a3fd7ec75716121468d Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 15:36:49 +0800 Subject: [PATCH 25/31] Comment --- cpp/src/arrow/array/array_test.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 1411e08c158..1656454aa4d 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -824,6 +824,8 @@ TEST_F(TestArray, TestFillFromScalar) { } } +// GH-40069: Data-race when concurrent calling ArraySpan::FillFromScalar of the same +// scalar instance. TEST_F(TestArray, TestConcurrentFillFromScalar) { for (auto type : TestArrayUtilitiesAgainstTheseTypes()) { ARROW_SCOPED_TRACE("type = ", type->ToString()); From d13c36f2e9a8172c3c6392aceef19c0d8ff07fd4 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 16:11:40 +0800 Subject: [PATCH 26/31] Fix --- cpp/src/arrow/array/data.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/array/data.cc b/cpp/src/arrow/array/data.cc index 52f21d21e16..ff3112ec1fc 100644 --- a/cpp/src/arrow/array/data.cc +++ b/cpp/src/arrow/array/data.cc @@ -284,7 +284,7 @@ void ArraySpan::SetMembers(const ArrayData& data) { namespace { BufferSpan OffsetsForScalar(uint8_t* scratch_space, int64_t offset_width) { - return {scratch_space, offset_width}; + return {scratch_space, offset_width * 2}; } std::pair OffsetsAndSizesForScalar(uint8_t* scratch_space, From 78c5c9eb88f6e2ec62c61b266f295add6de4aeda Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 16:40:18 +0800 Subject: [PATCH 27/31] Refine test --- cpp/src/arrow/scalar_test.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 32aa8dfa0fa..a7bd3c53325 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -2066,6 +2066,15 @@ TEST_F(TestExtensionScalar, ValidateErrors) { // If the scalar is null it's okay scalar.is_valid = false; ASSERT_OK(scalar.ValidateFull()); + + // Invalid storage scalar (invalid UTF8) + ASSERT_OK_AND_ASSIGN(std::shared_ptr invalid_storage, + MakeScalar(utf8(), std::make_shared("\xff"))); + ASSERT_OK(invalid_storage->Validate()); + ASSERT_RAISES(Invalid, invalid_storage->ValidateFull()); + scalar = ExtensionScalar(invalid_storage, type_); + ASSERT_OK(scalar.Validate()); + ASSERT_RAISES(Invalid, scalar.ValidateFull()); } } // namespace arrow From f4fbb2840331b4a5acc89acb383cff6aba29c5b4 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 19:27:53 +0800 Subject: [PATCH 28/31] More cast tests --- cpp/src/arrow/scalar.cc | 48 ++++------------ cpp/src/arrow/scalar.h | 36 +++++++----- cpp/src/arrow/scalar_test.cc | 107 ++++++++++++++++++++++++++++++----- 3 files changed, 125 insertions(+), 66 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 22f9c773e6f..e3eafceba25 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -670,7 +670,7 @@ void MapScalar::FillScratchSpace() { FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid) : BaseListScalar(std::move(value), std::move(type), is_valid) { - if (value) { + if (this->value) { ARROW_CHECK_EQ(this->value->length(), checked_cast(*this->type).list_size()); } @@ -1261,18 +1261,18 @@ CastImpl(const StructScalar& from, std::shared_ptr to_type) { } // casts between variable-length and fixed-length list types -template -enable_if_list_type>> -CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { - if constexpr (sizeof(typename ToScalar::TypeClass::offset_type) < sizeof(int64_t)) { - if (from.value->length() > - std::numeric_limits::max()) { +template +std::enable_if_t::value && is_list_type::value, + Result>> +CastImpl(const From& from, std::shared_ptr to_type) { + if constexpr (sizeof(typename To::offset_type) < sizeof(int64_t)) { + if (from.value->length() > std::numeric_limits::max()) { return Status::Invalid(from.type->ToString(), " too large to cast to ", to_type->ToString()); } } - if constexpr (is_fixed_size_list_type::value) { + if constexpr (is_fixed_size_list_type::value) { const auto& fixed_size_list_type = checked_cast(*to_type); if (from.value->length() != fixed_size_list_type.list_size()) { return Status::Invalid("Cannot cast ", from.type->ToString(), " of length ", @@ -1281,12 +1281,13 @@ CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { } } + using ToScalar = typename TypeTraits::ScalarType; return std::make_shared(from.value, std::move(to_type), from.is_valid); } // list based types (list, large list and map (fixed sized list too)) to string -template -typename std::enable_if_t::value, +template +typename std::enable_if_t::value, Result>> CastImpl(const BaseListScalar& from, std::shared_ptr to_type) { std::stringstream ss; @@ -1354,33 +1355,6 @@ struct FromTypeVisitor : CastImplVisitor { return Status::OK(); } - Status CastFromListLike(const BaseListType& base_list_type) { - ARROW_ASSIGN_OR_RAISE(out_, - CastImpl(checked_cast(from_), - std::move(to_type_))); - return Status::OK(); - } - - Status Visit(const ListType& list_type) { return CastFromListLike(list_type); } - - Status Visit(const LargeListType& large_list_type) { - return CastFromListLike(large_list_type); - } - - Status Visit(const FixedSizeListType& fixed_size_list_type) { - return CastFromListLike(fixed_size_list_type); - } - - Status Visit(const ListViewType& list_view_type) { - return CastFromListLike(list_view_type); - } - - Status Visit(const LargeListViewType& large_list_view_type) { - return CastFromListLike(large_list_view_type); - } - - Status Visit(const MapType& map_type) { return CastFromListLike(map_type); } - Status Visit(const NullType&) { return NotImplemented(); } Status Visit(const DictionaryType&) { return NotImplemented(); } Status Visit(const ExtensionType&) { return NotImplemented(); } diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 7b422222466..a7ee6a417d9 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -153,8 +153,6 @@ struct ARROW_EXPORT PrimitiveScalarBase : public Scalar { using Scalar::Scalar; /// \brief Get a const pointer to the value of this scalar. May be null. virtual const void* data() const = 0; - /// \brief Get a mutable pointer to the value of this scalar. May be null. - // virtual void* mutable_data() = 0; /// \brief Get an immutable view of the value of this scalar as bytes. virtual std::string_view view() const = 0; }; @@ -175,7 +173,6 @@ struct ARROW_EXPORT PrimitiveScalar : public PrimitiveScalarBase { ValueType value{}; const void* data() const override { return &value; } - // void* mutable_data() override { return &value; } std::string_view view() const override { return std::string_view(reinterpret_cast(&value), sizeof(ValueType)); }; @@ -256,14 +253,15 @@ struct ARROW_EXPORT DoubleScalar : public NumericScalar { struct ARROW_EXPORT BaseBinaryScalar : public internal::PrimitiveScalarBase { using ValueType = std::shared_ptr; + // The value is not supposed to be modified after construction, because subclasses have + // a scratch space whose content need to be kept consistent with the value. It is also + // the user of this class's responsibility to ensure that the buffer is not written to + // accidentally. const std::shared_ptr value = NULLPTR; const void* data() const override { return value ? reinterpret_cast(value->data()) : NULLPTR; } - // void* mutable_data() override { - // return value ? reinterpret_cast(value->mutable_data()) : NULLPTR; - // } std::string_view view() const override { return value ? std::string_view(*value) : std::string_view(); } @@ -519,10 +517,6 @@ struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase { return reinterpret_cast(value.native_endian_bytes()); } - // void* mutable_data() override { - // return reinterpret_cast(value.mutable_native_endian_bytes()); - // } - std::string_view view() const override { return std::string_view(reinterpret_cast(value.native_endian_bytes()), ValueType::kByteWidth); @@ -545,6 +539,10 @@ struct ARROW_EXPORT BaseListScalar : public Scalar { BaseListScalar(std::shared_ptr value, std::shared_ptr type, bool is_valid = true); + // The value is not supposed to be modified after construction, because subclasses have + // a scratch space whose content need to be kept consistent with the value. It is also + // the user of this class's responsibility to ensure that the array is not modified + // accidentally. const std::shared_ptr value; }; @@ -658,6 +656,8 @@ struct ARROW_EXPORT StructScalar : public Scalar { }; struct ARROW_EXPORT UnionScalar : public Scalar { + // The type code is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with it. const int8_t type_code; virtual const std::shared_ptr& child_value() const = 0; @@ -686,6 +686,10 @@ struct ARROW_EXPORT SparseUnionScalar // nonetheless construct a vector of scalars, one per union value, to have // enough data to reconstruct a valid ArraySpan of length 1 from this scalar using ValueType = std::vector>; + // The value is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with the value. It is also the user of + // this class's responsibility to ensure that the scalars of the vector is not modified + // to accidentally. const ValueType value; // The value index corresponding to the active type code @@ -719,6 +723,10 @@ struct ARROW_EXPORT DenseUnionScalar // For DenseUnionScalar, we can make a valid ArraySpan of length 1 from this // scalar using ValueType = std::shared_ptr; + // The value is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with the value. It is also the user of + // this class's responsibility to ensure that the elements of the vector is not modified + // accidentally. const ValueType value; const std::shared_ptr& child_value() const override { return this->value; } @@ -742,6 +750,10 @@ struct ARROW_EXPORT RunEndEncodedScalar using ArraySpanFillFromScalarScratchSpace = internal::ArraySpanFillFromScalarScratchSpace; + // The value is not supposed to be modified after construction, because the scratch + // space's content need to be kept consistent with the value. It is also the user of + // this class's responsibility to ensure that the wrapped scalar is not modified + // accidentally. const ValueType value; RunEndEncodedScalar(std::shared_ptr value, std::shared_ptr type); @@ -791,10 +803,6 @@ struct ARROW_EXPORT DictionaryScalar : public internal::PrimitiveScalarBase { const void* data() const override { return internal::checked_cast(*value.index).data(); } - // void* mutable_data() override { - // return internal::checked_cast(*value.index) - // .mutable_data(); - // } std::string_view view() const override { return internal::checked_cast(*value.index) .view(); diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index a7bd3c53325..c9698e7b945 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -95,7 +95,8 @@ TEST(TestNullScalar, ValidateErrors) { AssertValidationFails(scalar); } -TEST(TestNullScalar, Cast) { +// Test Scalar::CastTo goes to the right CastImpl specialization. +TEST(TestNullScalar, CastTo) { NullScalar scalar; for (auto to_type : { int8(), @@ -112,6 +113,8 @@ TEST(TestNullScalar, Cast) { decimal(12, 2), list_view(int32()), large_list(int32()), + dense_union({field("string", utf8()), field("number", uint64())}), + sparse_union({field("string", utf8()), field("number", uint64())}), }) { ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(to_type)); ASSERT_EQ(casted->type->id(), to_type->id()); @@ -488,13 +491,42 @@ class TestDecimalScalar : public ::testing::Test { ::testing::HasSubstr("does not fit in precision of"), invalid.ValidateFull()); } + + // Test Scalar::CastTo goes to the right CastImpl specialization. + void TestCastTo() { + const auto ty = std::make_shared(3, 2); + const auto pi = ScalarType(ValueType(314), ty); + + ASSERT_OK_AND_ASSIGN(auto pi_str, pi.CastTo(utf8())); + ASSERT_TRUE(pi_str->Equals(StringScalar("3.14"))); + + for (auto to_type : { + int8(), + float64(), + date32(), + time32(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND), + duration(TimeUnit::SECOND), + large_binary(), + list(int32()), + struct_({field("f", int32())}), + map(utf8(), int32()), + decimal(12, 2), + list_view(int32()), + large_list(int32()), + dense_union({field("string", utf8()), field("number", uint64())}), + sparse_union({field("string", utf8()), field("number", uint64())}), + }) { + ASSERT_RAISES(NotImplemented, pi.CastTo(to_type)); + } + } }; TYPED_TEST_SUITE(TestDecimalScalar, DecimalArrowTypes); TYPED_TEST(TestDecimalScalar, Basics) { this->TestBasics(); } -TYPED_TEST(TestDecimalScalar, Cast) {} +TYPED_TEST(TestDecimalScalar, CastTo) { this->TestCastTo(); } TEST(TestBinaryScalar, Basics) { std::string data = "test data"; @@ -1287,15 +1319,40 @@ TEST(TestMapScalar, Cast) { CheckListCastError(scalar, invalid_cast_type); } -TEST(TestMapScalar, ToString) { +// Test Scalar::CastTo goes to the right CastImpl specialization. +TEST(TestMapScalar, CastTo) { auto key_value_type = struct_({field("key", utf8(), false), field("value", int8())}); auto value = ArrayFromJSON(key_value_type, R"([{"key": "a", "value": 1}, {"key": "b", "value": 2}])"); auto scalar = MapScalar(value); - ASSERT_EQ( - scalar.ToString(), - R"(map[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])"); + // Supported cast types. + { + ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(utf8())); + ASSERT_TRUE(casted->Equals(StringScalar( + R"(map[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])"))); + } + + // Unsupported cast types. + for (auto to_type : { + int8(), + float64(), + date32(), + time32(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND), + duration(TimeUnit::SECOND), + large_binary(), + list(int32()), + struct_({field("f", int32())}), + map(utf8(), int32()), + decimal(12, 2), + list_view(int32()), + large_list(int32()), + dense_union({field("string", utf8()), field("number", uint64())}), + sparse_union({field("string", utf8()), field("number", uint64())}), + }) { + ASSERT_RAISES(NotImplemented, scalar.CastTo(to_type)); + } } TEST(TestStructScalar, FieldAccess) { @@ -1761,14 +1818,34 @@ class TestUnionScalar : public ::testing::Test { } } - void TestToString() { - ASSERT_EQ(union_alpha_->ToString(), "union{string: string = alpha}"); - ASSERT_EQ(union_beta_->ToString(), "union{string: string = beta}"); - ASSERT_EQ(union_two_->ToString(), "union{number: uint64 = 2}"); - ASSERT_EQ(union_other_two_->ToString(), "union{other_number: uint64 = 2}"); - ASSERT_EQ(union_three_->ToString(), "union{number: uint64 = 3}"); - ASSERT_EQ(union_string_null_->ToString(), "null"); - ASSERT_EQ(union_number_null_->ToString(), "null"); + // Test Scalar::CastTo goes to the right CastImpl specialization. + void TestCastTo() { + // Supported cast types. + { + ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8())); + ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})"))); + } + + // Unsupported cast types. + for (auto to_type : { + int8(), + float64(), + date32(), + time32(TimeUnit::SECOND), + timestamp(TimeUnit::SECOND), + duration(TimeUnit::SECOND), + large_binary(), + list(int32()), + struct_({field("f", int32())}), + map(utf8(), int32()), + decimal(12, 2), + list_view(int32()), + large_list(int32()), + dense_union({field("string", utf8()), field("number", uint64())}), + sparse_union({field("string", utf8()), field("number", uint64())}), + }) { + ASSERT_RAISES(NotImplemented, union_alpha_->CastTo(to_type)); + } } protected: @@ -1789,7 +1866,7 @@ TYPED_TEST(TestUnionScalar, Equals) { this->TestEquals(); } TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); } -TYPED_TEST(TestUnionScalar, ToString) { this->TestToString(); } +TYPED_TEST(TestUnionScalar, CastTo) { this->TestCastTo(); } class TestSparseUnionScalar : public TestUnionScalar {}; From 57a547d8e38f27fb470eefcc9cf7f54bf1231520 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 23:07:40 +0800 Subject: [PATCH 29/31] Add more cast tests --- cpp/src/arrow/scalar.cc | 1 + cpp/src/arrow/scalar_test.cc | 201 +++++++++++++++++++---------------- 2 files changed, 111 insertions(+), 91 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index e3eafceba25..902abef7542 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -1245,6 +1245,7 @@ CastImpl(const From& from, std::shared_ptr to_type) { std::move(to_type)); } +// struct to string template typename std::enable_if_t::value, Result>> diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index c9698e7b945..ca6f26c4a37 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -95,8 +95,7 @@ TEST(TestNullScalar, ValidateErrors) { AssertValidationFails(scalar); } -// Test Scalar::CastTo goes to the right CastImpl specialization. -TEST(TestNullScalar, CastTo) { +TEST(TestNullScalar, Cast) { NullScalar scalar; for (auto to_type : { int8(), @@ -116,12 +115,48 @@ TEST(TestNullScalar, CastTo) { dense_union({field("string", utf8()), field("number", uint64())}), sparse_union({field("string", utf8()), field("number", uint64())}), }) { + // Cast() function doesn't support casting null scalar, use Scalar::CastTo() instead. ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(to_type)); ASSERT_EQ(casted->type->id(), to_type->id()); ASSERT_FALSE(casted->is_valid); } } +TEST(TestBooleanScalar, Cast) { + for (auto b : {true, false}) { + BooleanScalar scalar(b); + ARROW_SCOPED_TRACE("boolean value: ", scalar.ToString()); + + // Boolean type (identity cast). + { + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, boolean())); + ASSERT_TRUE(casted.scalar()->Equals(scalar)) << casted.scalar()->ToString(); + } + + // Numeric types. + for (auto to_type : { + int8(), + uint16(), + int32(), + uint64(), + float32(), + float64(), + }) { + ARROW_SCOPED_TRACE("to type: ", to_type->ToString()); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, to_type)); + ASSERT_EQ(casted.scalar()->type->id(), to_type->id()); + ASSERT_EQ(casted.scalar()->ToString(), std::to_string(b)); + } + + // String type. + { + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8())); + ASSERT_EQ(casted.scalar()->type->id(), utf8()->id()); + ASSERT_EQ(casted.scalar()->ToString(), scalar.ToString()); + } + } +} + template class TestNumericScalar : public ::testing::Test { public: @@ -492,33 +527,13 @@ class TestDecimalScalar : public ::testing::Test { invalid.ValidateFull()); } - // Test Scalar::CastTo goes to the right CastImpl specialization. - void TestCastTo() { + void TestCast() { const auto ty = std::make_shared(3, 2); const auto pi = ScalarType(ValueType(314), ty); - ASSERT_OK_AND_ASSIGN(auto pi_str, pi.CastTo(utf8())); - ASSERT_TRUE(pi_str->Equals(StringScalar("3.14"))); - - for (auto to_type : { - int8(), - float64(), - date32(), - time32(TimeUnit::SECOND), - timestamp(TimeUnit::SECOND), - duration(TimeUnit::SECOND), - large_binary(), - list(int32()), - struct_({field("f", int32())}), - map(utf8(), int32()), - decimal(12, 2), - list_view(int32()), - large_list(int32()), - dense_union({field("string", utf8()), field("number", uint64())}), - sparse_union({field("string", utf8()), field("number", uint64())}), - }) { - ASSERT_RAISES(NotImplemented, pi.CastTo(to_type)); - } + ASSERT_OK_AND_ASSIGN(auto casted, Cast(pi, utf8())); + ASSERT_TRUE(casted.scalar()->Equals(StringScalar("3.14"))) + << casted.scalar()->ToString(); } }; @@ -526,7 +541,7 @@ TYPED_TEST_SUITE(TestDecimalScalar, DecimalArrowTypes); TYPED_TEST(TestDecimalScalar, Basics) { this->TestBasics(); } -TYPED_TEST(TestDecimalScalar, CastTo) { this->TestCastTo(); } +TYPED_TEST(TestDecimalScalar, Cast) { this->TestCast(); } TEST(TestBinaryScalar, Basics) { std::string data = "test data"; @@ -609,6 +624,14 @@ TEST(TestBinaryScalar, ValidateErrors) { AssertValidationFails(*null_scalar); } +TEST(TestBinaryScalar, Cast) { + BinaryScalar scalar(Buffer::FromString("test data")); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8())); + ASSERT_EQ(casted.scalar()->type->id(), utf8()->id()); + AssertBufferEqual(*checked_cast(*casted.scalar()).value, + *scalar.value); +} + template class TestStringScalar : public ::testing::Test { public: @@ -686,6 +709,11 @@ TEST(TestStringScalar, MakeScalarString) { ASSERT_EQ(StringScalar("three"), *three); } +TEST(TestStringScalar, Cast) { + std::string s = "test data"; + // TODO: all others. +} + TEST(TestFixedSizeBinaryScalar, Basics) { std::string data = "test data"; auto buf = std::make_shared(data); @@ -743,6 +771,15 @@ TEST(TestFixedSizeBinaryScalar, ValidateErrors) { ASSERT_RAISES(Invalid, MakeScalar(type, SliceBuffer(buf, 1))); } +TEST(TestFixedSizeBinaryScalar, Cast) { + std::string data = "test data"; + FixedSizeBinaryScalar scalar(data); + ASSERT_OK_AND_ASSIGN(auto casted, Cast(scalar, utf8())); + ASSERT_EQ(casted.scalar()->type->id(), utf8()->id()); + AssertBufferEqual(*checked_cast(*casted.scalar()).value, + *scalar.value); +} + TEST(TestDateScalars, Basics) { int32_t i32_val = 1; Date32Scalar date32_val(i32_val); @@ -1259,6 +1296,12 @@ class TestListLikeScalar : public ::testing::Test { auto invalid_cast_type = fixed_size_list(value_->type(), 5); CheckListCastError(scalar, invalid_cast_type); + + // Cast() function doesn't support casting list-like to string, use Scalar::CastTo() + // instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_EQ(casted_str->type->id(), utf8()->id()); + ASSERT_EQ(casted_str->ToString(), scalar.ToString()); } protected: @@ -1288,6 +1331,24 @@ TEST(TestFixedSizeListScalar, ValidateErrors) { AssertValidationFails(scalar); } +TEST(TestFixedSizeListScalar, Cast) { + const auto ty = fixed_size_list(int16(), 3); + FixedSizeListScalar scalar(ArrayFromJSON(int16(), "[1, 2, 5]"), ty); + + CheckListCast(scalar, list(int16())); + CheckListCast(scalar, large_list(int16())); + CheckListCast(scalar, fixed_size_list(int16(), 3)); + + auto invalid_cast_type = fixed_size_list(int16(), 4); + CheckListCastError(scalar, invalid_cast_type); + + // Cast() function doesn't support casting list-like to string, use Scalar::CastTo() + // instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_EQ(casted_str->type->id(), utf8()->id()); + ASSERT_EQ(casted_str->ToString(), scalar.ToString()); +} + TEST(TestMapScalar, Basics) { auto value = ArrayFromJSON(struct_({field("key", utf8(), false), field("value", int8())}), @@ -1317,42 +1378,12 @@ TEST(TestMapScalar, Cast) { auto invalid_cast_type = fixed_size_list(key_value_type, 5); CheckListCastError(scalar, invalid_cast_type); -} - -// Test Scalar::CastTo goes to the right CastImpl specialization. -TEST(TestMapScalar, CastTo) { - auto key_value_type = struct_({field("key", utf8(), false), field("value", int8())}); - auto value = ArrayFromJSON(key_value_type, - R"([{"key": "a", "value": 1}, {"key": "b", "value": 2}])"); - auto scalar = MapScalar(value); - - // Supported cast types. - { - ASSERT_OK_AND_ASSIGN(auto casted, scalar.CastTo(utf8())); - ASSERT_TRUE(casted->Equals(StringScalar( - R"(map[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])"))); - } - // Unsupported cast types. - for (auto to_type : { - int8(), - float64(), - date32(), - time32(TimeUnit::SECOND), - timestamp(TimeUnit::SECOND), - duration(TimeUnit::SECOND), - large_binary(), - list(int32()), - struct_({field("f", int32())}), - map(utf8(), int32()), - decimal(12, 2), - list_view(int32()), - large_list(int32()), - dense_union({field("string", utf8()), field("number", uint64())}), - sparse_union({field("string", utf8()), field("number", uint64())}), - }) { - ASSERT_RAISES(NotImplemented, scalar.CastTo(to_type)); - } + // Cast() function doesn't support casting map to string, use Scalar::CastTo() instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_TRUE(casted_str->Equals(StringScalar( + R"(map[{key:string = a, value:int8 = 1}, {key:string = b, value:int8 = 2}])"))) + << casted_str->ToString(); } TEST(TestStructScalar, FieldAccess) { @@ -1445,6 +1476,16 @@ TEST(TestStructScalar, ValidateErrors) { ASSERT_RAISES(Invalid, scalar.ValidateFull()); } +TEST(TestStructScalar, Cast) { + auto ty = struct_({field("i", int32()), field("s", utf8())}); + StructScalar scalar({MakeScalar(42), MakeScalar("xxx")}, ty); + + // Cast() function doesn't support casting map to string, use Scalar::CastTo() instead. + ASSERT_OK_AND_ASSIGN(auto casted_str, scalar.CastTo(utf8())); + ASSERT_TRUE(casted_str->Equals(StringScalar(R"({i:int32 = 42, s:string = xxx})"))) + << casted_str->ToString(); +} + TEST(TestDictionaryScalar, Basics) { for (auto index_ty : all_dictionary_index_types()) { auto ty = dictionary(index_ty, utf8()); @@ -1818,34 +1859,12 @@ class TestUnionScalar : public ::testing::Test { } } - // Test Scalar::CastTo goes to the right CastImpl specialization. - void TestCastTo() { - // Supported cast types. - { - ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8())); - ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})"))); - } - - // Unsupported cast types. - for (auto to_type : { - int8(), - float64(), - date32(), - time32(TimeUnit::SECOND), - timestamp(TimeUnit::SECOND), - duration(TimeUnit::SECOND), - large_binary(), - list(int32()), - struct_({field("f", int32())}), - map(utf8(), int32()), - decimal(12, 2), - list_view(int32()), - large_list(int32()), - dense_union({field("string", utf8()), field("number", uint64())}), - sparse_union({field("string", utf8()), field("number", uint64())}), - }) { - ASSERT_RAISES(NotImplemented, union_alpha_->CastTo(to_type)); - } + void TestCast() { + // Cast() function doesn't support casting union to string, use Scalar::CastTo() + // instead. + ASSERT_OK_AND_ASSIGN(auto casted, union_alpha_->CastTo(utf8())); + ASSERT_TRUE(casted->Equals(StringScalar(R"(union{string: string = alpha})"))) + << casted->ToString(); } protected: @@ -1866,7 +1885,7 @@ TYPED_TEST(TestUnionScalar, Equals) { this->TestEquals(); } TYPED_TEST(TestUnionScalar, MakeNullScalar) { this->TestMakeNullScalar(); } -TYPED_TEST(TestUnionScalar, CastTo) { this->TestCastTo(); } +TYPED_TEST(TestUnionScalar, Cast) { this->TestCast(); } class TestSparseUnionScalar : public TestUnionScalar {}; From c1e0526469796de754d505a8fe527dccdf7a0f2b Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Mon, 8 Apr 2024 23:11:33 +0800 Subject: [PATCH 30/31] Remove code --- cpp/src/arrow/scalar_test.cc | 5 ----- 1 file changed, 5 deletions(-) diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index ca6f26c4a37..104a5697b57 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -709,11 +709,6 @@ TEST(TestStringScalar, MakeScalarString) { ASSERT_EQ(StringScalar("three"), *three); } -TEST(TestStringScalar, Cast) { - std::string s = "test data"; - // TODO: all others. -} - TEST(TestFixedSizeBinaryScalar, Basics) { std::string data = "test data"; auto buf = std::make_shared(data); From 7bf23ab6b9857009283e4ff2fcd962d576468983 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Fri, 12 Apr 2024 01:47:00 +0800 Subject: [PATCH 31/31] Address comment --- cpp/src/arrow/scalar.cc | 61 ++++++++++++++++++----------------------- 1 file changed, 27 insertions(+), 34 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 902abef7542..8e8d3903663 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -542,19 +542,10 @@ struct ScalarValidateImpl { } }; -template -void FillScalarScratchSpaceHelper(uint8_t* scratch_space, T first, Args... args) { - static_assert(offset + sizeof(T) <= internal::kScalarScratchSpaceSize); - *reinterpret_cast(scratch_space + offset) = first; - if constexpr (sizeof...(args) > 0) { - FillScalarScratchSpaceHelper(scratch_space, - std::forward(args)...); - } -} - -template -void FillScalarScratchSpace(Args... args) { - FillScalarScratchSpaceHelper<0>(std::forward(args)...); +template +void FillScalarScratchSpace(void* scratch_space, T const (&arr)[N]) { + static_assert(sizeof(arr) <= internal::kScalarScratchSpaceSize); + std::memcpy(scratch_space, arr, sizeof(arr)); } } // namespace @@ -573,13 +564,12 @@ BaseBinaryScalar::BaseBinaryScalar(std::string s, std::shared_ptr type : BaseBinaryScalar(Buffer::FromString(std::move(s)), std::move(type)) {} void BinaryScalar::FillScratchSpace() { - // TODO: Test value being nullptr. - FillScalarScratchSpace(scratch_space_, int32_t(0), - value ? static_cast(value->size()) : int32_t(0)); + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->size()) : int32_t(0)}); } void BinaryViewScalar::FillScratchSpace() { - // TODO: Test value being nullptr. static_assert(sizeof(BinaryViewType::c_type) <= internal::kScalarScratchSpaceSize); auto* view = new (&scratch_space_) BinaryViewType::c_type; if (value) { @@ -590,9 +580,9 @@ void BinaryViewScalar::FillScratchSpace() { } void LargeBinaryScalar::FillScratchSpace() { - // TODO: Test value being nullptr. - FillScalarScratchSpace(scratch_space_, int64_t(0), - value ? static_cast(value->size()) : int64_t(0)); + FillScalarScratchSpace( + scratch_space_, + {int64_t(0), value ? static_cast(value->size()) : int64_t(0)}); } FixedSizeBinaryScalar::FixedSizeBinaryScalar(std::shared_ptr value, @@ -625,32 +615,34 @@ ListScalar::ListScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, list(value->type()), is_valid) {} void ListScalar::FillScratchSpace() { - FillScalarScratchSpace(scratch_space_, int32_t(0), - value ? static_cast(value->length()) : int32_t(0)); + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->length()) : int32_t(0)}); } LargeListScalar::LargeListScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, large_list(value->type()), is_valid) {} void LargeListScalar::FillScratchSpace() { - FillScalarScratchSpace(scratch_space_, int64_t(0), - value ? value->length() : int64_t(0)); + FillScalarScratchSpace(scratch_space_, + {int64_t(0), value ? value->length() : int64_t(0)}); } ListViewScalar::ListViewScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, list_view(value->type()), is_valid) {} void ListViewScalar::FillScratchSpace() { - FillScalarScratchSpace(scratch_space_, int32_t(0), - value ? static_cast(value->length()) : int32_t(0)); + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->length()) : int32_t(0)}); } LargeListViewScalar::LargeListViewScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, large_list_view(value->type()), is_valid) {} void LargeListViewScalar::FillScratchSpace() { - FillScalarScratchSpace(scratch_space_, int64_t(0), - value ? value->length() : int64_t(0)); + FillScalarScratchSpace(scratch_space_, + {int64_t(0), value ? value->length() : int64_t(0)}); } inline std::shared_ptr MakeMapType(const std::shared_ptr& pair_type) { @@ -663,8 +655,9 @@ MapScalar::MapScalar(std::shared_ptr value, bool is_valid) : BaseListScalar(value, MakeMapType(value->type()), is_valid) {} void MapScalar::FillScratchSpace() { - FillScalarScratchSpace(scratch_space_, int32_t(0), - value ? static_cast(value->length()) : int32_t(0)); + FillScalarScratchSpace( + scratch_space_, + {int32_t(0), value ? static_cast(value->length()) : int32_t(0)}); } FixedSizeListScalar::FixedSizeListScalar(std::shared_ptr value, @@ -727,14 +720,14 @@ void RunEndEncodedScalar::FillScratchSpace() { auto run_end = run_end_type()->id(); switch (run_end) { case Type::INT16: - FillScalarScratchSpace(scratch_space_, int16_t(1)); + FillScalarScratchSpace(scratch_space_, {int16_t(1)}); break; case Type::INT32: - FillScalarScratchSpace(scratch_space_, int32_t(1)); + FillScalarScratchSpace(scratch_space_, {int32_t(1)}); break; default: DCHECK_EQ(run_end, Type::INT64); - FillScalarScratchSpace(scratch_space_, int64_t(1)); + FillScalarScratchSpace(scratch_space_, {int64_t(1)}); } } @@ -848,7 +841,7 @@ void SparseUnionScalar::FillScratchSpace() { void DenseUnionScalar::FillScratchSpace() { auto* union_scratch_space = reinterpret_cast(&scratch_space_); union_scratch_space->type_code = type_code; - FillScalarScratchSpace(union_scratch_space->offsets, int32_t(0), int32_t(1)); + FillScalarScratchSpace(union_scratch_space->offsets, {int32_t(0), int32_t(1)}); } namespace {