From f042310e1c826bb3ffec58c66186361612f386e5 Mon Sep 17 00:00:00 2001 From: Benjamin Kietzman Date: Wed, 24 Feb 2021 13:16:31 -0500 Subject: [PATCH] ARROW-11767: [C++] Scalar::Hash may segfault --- cpp/src/arrow/scalar.cc | 14 +++++++++++--- cpp/src/arrow/scalar_test.cc | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index ee4d0ecad8f..399eac675f4 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -90,9 +90,13 @@ struct ScalarHashImpl { return Status::OK(); } + Status Visit(const DictionaryScalar& s) { + AccumulateHashFrom(*s.value.index); + return Status::OK(); + } + // TODO(bkietz) implement less wimpy hashing when these have ValueType Status Visit(const UnionScalar& s) { return Status::OK(); } - Status Visit(const DictionaryScalar& s) { return Status::OK(); } Status Visit(const ExtensionScalar& s) { return Status::OK(); } template @@ -127,14 +131,18 @@ struct ScalarHashImpl { return Status::OK(); } - explicit ScalarHashImpl(const Scalar& scalar) { AccumulateHashFrom(scalar); } + explicit ScalarHashImpl(const Scalar& scalar) : hash_(scalar.type->Hash()) { + if (scalar.is_valid) { + AccumulateHashFrom(scalar); + } + } void AccumulateHashFrom(const Scalar& scalar) { DCHECK_OK(StdHash(scalar.type->fingerprint())); DCHECK_OK(VisitScalarInline(scalar, this)); } - size_t hash_ = 0; + size_t hash_; }; size_t Scalar::Hash::hash(const Scalar& scalar) { return ScalarHashImpl(scalar).hash_; } diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 16c2f92d13b..d99debb2ba9 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -111,10 +111,12 @@ TYPED_TEST(TestNumericScalar, Hashing) { using ScalarType = typename TypeTraits::ScalarType; std::unordered_set, Scalar::Hash, Scalar::PtrsEqual> set; + set.emplace(std::make_shared()); for (T i = 0; i < 10; ++i) { set.emplace(std::make_shared(i)); } + ASSERT_FALSE(set.emplace(std::make_shared()).second); for (T i = 0; i < 10; ++i) { ASSERT_FALSE(set.emplace(std::make_shared(i)).second); } @@ -406,6 +408,23 @@ TEST(TestBinaryScalar, Basics) { ASSERT_FALSE(two->Equals(BinaryScalar(Buffer::FromString("else")))); } +TEST(TestBinaryScalar, Hashing) { + auto FromInt = [](int i) { + return std::make_shared(Buffer::FromString(std::to_string(i))); + }; + + std::unordered_set, Scalar::Hash, Scalar::PtrsEqual> set; + set.emplace(std::make_shared()); + for (int i = 0; i < 10; ++i) { + set.emplace(FromInt(i)); + } + + ASSERT_FALSE(set.emplace(std::make_shared()).second); + for (int i = 0; i < 10; ++i) { + ASSERT_FALSE(set.emplace(FromInt(i)).second); + } +} + TEST(TestStringScalar, MakeScalar) { auto three = MakeScalar("three"); ASSERT_EQ(StringScalar("three"), *three);