diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index ab0e90c9f47..4c6f97faf95 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -61,6 +61,79 @@ using internal::OptionalBitmapEquals; namespace { +// TODO also handle HALF_FLOAT NaNs + +enum FloatingEqualityFlags : int8_t { Approximate = 1, NansEqual = 2 }; + +template +struct FloatingEquality { + bool operator()(T x, T y) { return x == y; } +}; + +template +struct FloatingEquality { + bool operator()(T x, T y) { return (x == y) || (std::isnan(x) && std::isnan(y)); } +}; + +template +struct FloatingEquality { + explicit FloatingEquality(const EqualOptions& options) + : epsilon(static_cast(options.atol())) {} + + bool operator()(T x, T y) { return (fabs(x - y) <= epsilon) || (x == y); } + + const T epsilon; +}; + +template +struct FloatingEquality { + explicit FloatingEquality(const EqualOptions& options) + : epsilon(static_cast(options.atol())) {} + + bool operator()(T x, T y) { + return (fabs(x - y) <= epsilon) || (x == y) || (std::isnan(x) && std::isnan(y)); + } + + const T epsilon; +}; + +template +void VisitFloatingEquality(const EqualOptions& options, bool floating_approximate, + Visitor&& visit) { + if (options.nans_equal()) { + if (floating_approximate) { + visit(FloatingEquality{options}); + } else { + visit(FloatingEquality{}); + } + } else { + if (floating_approximate) { + visit(FloatingEquality{options}); + } else { + visit(FloatingEquality{}); + } + } +} + +inline bool IdentityImpliesEqualityNansNotEqual(const DataType& type) { + if (type.id() == Type::FLOAT || type.id() == Type::DOUBLE) { + return false; + } + for (const auto& child : type.fields()) { + if (!IdentityImpliesEqualityNansNotEqual(*child->type())) { + return false; + } + } + return true; +} + +inline bool IdentityImpliesEquality(const DataType& type, const EqualOptions& options) { + if (options.nans_equal()) { + return true; + } + return IdentityImpliesEqualityNansNotEqual(type); +} + bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, int64_t left_start_idx, int64_t left_end_idx, int64_t right_start_idx, const EqualOptions& options, @@ -299,6 +372,26 @@ class RangeDataEqualsImpl { } protected: + // For CompareFloating (templated local classes or lambdas not supported in C++11) + template + struct ComparatorVisitor { + RangeDataEqualsImpl* impl; + const CType* left_values; + const CType* right_values; + + template + void operator()(CompareFunction&& compare) { + impl->VisitValues([&](int64_t i) { + const CType x = left_values[i + impl->left_start_idx_]; + const CType y = right_values[i + impl->right_start_idx_]; + return compare(x, y); + }); + } + }; + + template + friend struct ComparatorVisitor; + template Status ComparePrimitive(const TypeClass&) { const CType* left_values = left_.GetValues(1); @@ -312,40 +405,12 @@ class RangeDataEqualsImpl { template Status CompareFloating(const TypeClass&) { - using T = typename TypeClass::c_type; - const T* left_values = left_.GetValues(1); - const T* right_values = right_.GetValues(1); - - if (floating_approximate_) { - const T epsilon = static_cast(options_.atol()); - if (options_.nans_equal()) { - VisitValues([&](int64_t i) { - const T x = left_values[i + left_start_idx_]; - const T y = right_values[i + right_start_idx_]; - return (fabs(x - y) <= epsilon) || (x == y) || (std::isnan(x) && std::isnan(y)); - }); - } else { - VisitValues([&](int64_t i) { - const T x = left_values[i + left_start_idx_]; - const T y = right_values[i + right_start_idx_]; - return (fabs(x - y) <= epsilon) || (x == y); - }); - } - } else { - if (options_.nans_equal()) { - VisitValues([&](int64_t i) { - const T x = left_values[i + left_start_idx_]; - const T y = right_values[i + right_start_idx_]; - return (x == y) || (std::isnan(x) && std::isnan(y)); - }); - } else { - VisitValues([&](int64_t i) { - const T x = left_values[i + left_start_idx_]; - const T y = right_values[i + right_start_idx_]; - return x == y; - }); - } - } + using CType = typename TypeClass::c_type; + const CType* left_values = left_.GetValues(1); + const CType* right_values = right_.GetValues(1); + + ComparatorVisitor visitor{this, left_values, right_values}; + VisitFloatingEquality(options_, floating_approximate_, visitor); return Status::OK(); } @@ -471,7 +536,8 @@ bool CompareArrayRanges(const ArrayData& left, const ArrayData& right, // Right range too small return false; } - if (&left == &right && left_start_idx == right_start_idx) { + if (&left == &right && left_start_idx == right_start_idx && + IdentityImpliesEquality(*left.type, options)) { return true; } // Compare values @@ -605,11 +671,22 @@ class TypeEqualsVisitor { bool result_; }; +bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, + bool floating_approximate); +bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, + bool floating_approximate); + class ScalarEqualsVisitor { public: - explicit ScalarEqualsVisitor(const Scalar& right, - const EqualOptions& opts = EqualOptions::Defaults()) - : right_(right), result_(false), options_(opts) {} + // PRE-CONDITIONS: + // - the types are equal + // - the scalars are non-null + explicit ScalarEqualsVisitor(const Scalar& right, const EqualOptions& opts, + bool floating_approximate) + : right_(right), + options_(opts), + floating_approximate_(floating_approximate), + result_(false) {} Status Visit(const NullScalar& left) { result_ = true; @@ -623,33 +700,19 @@ class ScalarEqualsVisitor { } template - typename std::enable_if::value || - std::is_base_of::value, + typename std::enable_if<(is_primitive_ctype::value || + is_temporal_type::value), Status>::type - Visit(const T& left_) { - const auto& right = checked_cast(right_); - if (options_.nans_equal()) { - result_ = right.value == left_.value || - (std::isnan(right.value) && std::isnan(left_.value)); - } else { - result_ = right.value == left_.value; - } - return Status::OK(); - } - - template - typename std::enable_if< - (std::is_base_of, T>::value && - !std::is_base_of::value && - !std::is_base_of::value) || - std::is_base_of, T>::value, - Status>::type Visit(const T& left_) { const auto& right = checked_cast(right_); result_ = right.value == left_.value; return Status::OK(); } + Status Visit(const FloatScalar& left) { return CompareFloating(left); } + + Status Visit(const DoubleScalar& left) { return CompareFloating(left); } + template typename std::enable_if::value, Status>::type Visit(const T& left) { @@ -672,25 +735,25 @@ class ScalarEqualsVisitor { Status Visit(const ListScalar& left) { const auto& right = checked_cast(right_); - result_ = internal::SharedPtrEquals(left.value, right.value); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } Status Visit(const LargeListScalar& left) { const auto& right = checked_cast(right_); - result_ = internal::SharedPtrEquals(left.value, right.value); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } Status Visit(const MapScalar& left) { const auto& right = checked_cast(right_); - result_ = internal::SharedPtrEquals(left.value, right.value); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } Status Visit(const FixedSizeListScalar& left) { const auto& right = checked_cast(right_); - result_ = internal::SharedPtrEquals(left.value, right.value); + result_ = ArrayEquals(*left.value, *right.value, options_, floating_approximate_); return Status::OK(); } @@ -702,7 +765,8 @@ class ScalarEqualsVisitor { } else { bool all_equals = true; for (size_t i = 0; i < left.value.size() && all_equals; i++) { - all_equals &= internal::SharedPtrEquals(left.value[i], right.value[i]); + all_equals &= ScalarEquals(*left.value[i], *right.value[i], options_, + floating_approximate_); } result_ = all_equals; } @@ -713,7 +777,7 @@ class ScalarEqualsVisitor { Status Visit(const UnionScalar& left) { const auto& right = checked_cast(right_); if (left.is_valid && right.is_valid) { - result_ = left.value->Equals(*right.value); + result_ = ScalarEquals(*left.value, *right.value, options_, floating_approximate_); } else if (!left.is_valid && !right.is_valid) { result_ = true; } else { @@ -724,8 +788,10 @@ class ScalarEqualsVisitor { Status Visit(const DictionaryScalar& left) { const auto& right = checked_cast(right_); - result_ = left.value.index->Equals(right.value.index) && - left.value.dictionary->Equals(right.value.dictionary); + result_ = ScalarEquals(*left.value.index, *right.value.index, options_, + floating_approximate_) && + ArrayEquals(*left.value.dictionary, *right.value.dictionary, options_, + floating_approximate_); return Status::OK(); } @@ -736,9 +802,33 @@ class ScalarEqualsVisitor { bool result() const { return result_; } protected: + // For CompareFloating (templated local classes or lambdas not supported in C++11) + template + struct ComparatorVisitor { + const ScalarType& left; + const ScalarType& right; + bool* result; + + template + void operator()(CompareFunction&& compare) { + *result = compare(left.value, right.value); + } + }; + + template + Status CompareFloating(const ScalarType& left) { + using CType = decltype(left.value); + + ComparatorVisitor visitor{left, checked_cast(right_), + &result_}; + VisitFloatingEquality(options_, floating_approximate_, visitor); + return Status::OK(); + } + const Scalar& right_; - bool result_; const EqualOptions options_; + const bool floating_approximate_; + bool result_; }; Status PrintDiff(const Array& left, const Array& right, std::ostream* os); @@ -804,6 +894,35 @@ bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_ return are_equal; } +bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts, + bool floating_approximate) { + if (left.length() != right.length()) { + ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink())); + return false; + } + return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate); +} + +bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options, + bool floating_approximate) { + if (&left == &right && IdentityImpliesEquality(*left.type, options)) { + return true; + } + if (!left.type->Equals(right.type)) { + return false; + } + if (left.is_valid != right.is_valid) { + return false; + } + if (!left.is_valid) { + return true; + } + ScalarEqualsVisitor visitor(right, options, floating_approximate); + auto error = VisitScalarInline(left, &visitor); + DCHECK_OK(error); + return visitor.result(); +} + } // namespace bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, @@ -823,21 +942,24 @@ bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_ } bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) { - if (left.length() != right.length()) { - ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink())); - return false; - } const bool floating_approximate = false; - return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate); + return ArrayEquals(left, right, opts, floating_approximate); } bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) { - if (left.length() != right.length()) { - ARROW_IGNORE_EXPR(PrintDiff(left, right, opts.diff_sink())); - return false; - } const bool floating_approximate = true; - return ArrayRangeEquals(left, right, 0, left.length(), 0, opts, floating_approximate); + return ArrayEquals(left, right, opts, floating_approximate); +} + +bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) { + const bool floating_approximate = false; + return ScalarEquals(left, right, options, floating_approximate); +} + +bool ScalarApproxEquals(const Scalar& left, const Scalar& right, + const EqualOptions& options) { + const bool floating_approximate = true; + return ScalarEquals(left, right, options, floating_approximate); } namespace { @@ -1179,21 +1301,4 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata } } -bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) { - bool are_equal = false; - if (&left == &right) { - are_equal = true; - } else if (!left.type->Equals(right.type)) { - are_equal = false; - } else if (left.is_valid != right.is_valid) { - are_equal = false; - } else { - ScalarEqualsVisitor visitor(right, options); - auto error = VisitScalarInline(left, &visitor); - DCHECK_OK(error); - are_equal = visitor.result(); - } - return are_equal; -} - } // namespace arrow diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index 73a8401460b..387105de9e7 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -122,4 +122,12 @@ bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right, bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options = EqualOptions::Defaults()); +/// Returns true if scalars are approximately equal +/// \param[in] left a Scalar +/// \param[in] right a Scalar +/// \param[in] options comparison options +bool ARROW_EXPORT +ScalarApproxEquals(const Scalar& left, const Scalar& right, + const EqualOptions& options = EqualOptions::Defaults()); + } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.h b/cpp/src/arrow/compute/kernels/codegen_internal.h index 2f00a23f7fc..75bc7ba99d3 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.h +++ b/cpp/src/arrow/compute/kernels/codegen_internal.h @@ -512,28 +512,30 @@ struct ScalarUnary { using OutValue = typename GetOutputType::T; using Arg0Value = typename GetViewType::T; - static void Array(KernelContext* ctx, const ArrayData& arg0, Datum* out) { + static void ExecArray(KernelContext* ctx, const ArrayData& arg0, Datum* out) { ArrayIterator arg0_it(arg0); OutputAdapter::Write(ctx, out, [&]() -> OutValue { return Op::template Call(ctx, arg0_it()); }); } - static void Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) { + static void ExecScalar(KernelContext* ctx, const Scalar& arg0, Datum* out) { + Scalar* out_scalar = out->scalar().get(); if (arg0.is_valid) { Arg0Value arg0_val = UnboxScalar::Unbox(arg0); + out_scalar->is_valid = true; BoxScalar::Box(Op::template Call(ctx, arg0_val), - out->scalar().get()); + out_scalar); } else { - out->value = MakeNullScalar(arg0.type); + out_scalar->is_valid = false; } } static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (batch[0].kind() == Datum::ARRAY) { - return Array(ctx, *batch[0].array(), out); + return ExecArray(ctx, *batch[0].array(), out); } else { - return Scalar(ctx, *batch[0].scalar(), out); + return ExecScalar(ctx, *batch[0].scalar(), out); } } }; diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index 9b3ed2a57e3..a19abe82873 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -84,7 +84,7 @@ class TestBinaryArithmetic : public TestBase { auto exp = MakeScalar(expected); ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr)); - AssertScalarsEqual(*exp, *actual.scalar(), /*verbose=*/true); + AssertScalarsApproxEqual(*exp, *actual.scalar(), /*verbose=*/true); } // (Scalar, Array) @@ -144,8 +144,8 @@ class TestBinaryArithmetic : public TestBase { const auto expected_scalar = *expected->GetScalar(i); ASSERT_OK_AND_ASSIGN( actual, func(*left->GetScalar(i), *right->GetScalar(i), options_, nullptr)); - AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true, - equal_options_); + AssertScalarsApproxEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true, + equal_options_); } } diff --git a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc index 68d6f09c76d..ec612e497f4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_cast_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_cast_test.cc @@ -94,7 +94,7 @@ class TestCast : public TestBase { AssertArraysEqual(expected, *result, /*verbose=*/true); if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) { - // ARROW-9194 + // ARROW-10835 check_scalar = false; } @@ -111,7 +111,7 @@ class TestCast : public TestBase { ASSERT_RAISES(Invalid, Cast(input, out_type, options)); if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) { - // ARROW-9194 + // ARROW-10835 check_scalar = false; } diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 9e038024e06..4c5b2e1bcee 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -44,6 +44,10 @@ bool Scalar::Equals(const Scalar& other, const EqualOptions& options) const { return ScalarEquals(*this, other, options); } +bool Scalar::ApproxEquals(const Scalar& other, const EqualOptions& options) const { + return ScalarApproxEquals(*this, other, options); +} + struct ScalarHashImpl { static std::hash string_hash; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 80157a750cb..1fa866c8623 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -65,6 +65,9 @@ struct ARROW_EXPORT Scalar : public util::EqualityComparable { bool Equals(const Scalar& other, const EqualOptions& options = EqualOptions::Defaults()) const; + bool ApproxEquals(const Scalar& other, + const EqualOptions& options = EqualOptions::Defaults()) const; + struct ARROW_EXPORT Hash { size_t operator()(const Scalar& scalar) const { return hash(scalar); } diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 71f1ae04ce2..30a39e6e4c0 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. +#include #include +#include #include #include #include @@ -96,6 +98,12 @@ TYPED_TEST(TestNumericScalar, Basics) { ASSERT_FALSE(one->Equals(ScalarType(2))); ASSERT_TRUE(two->Equals(ScalarType(2))); ASSERT_FALSE(two->Equals(ScalarType(3))); + + ASSERT_TRUE(null->ApproxEquals(*null_value)); + ASSERT_TRUE(one->ApproxEquals(ScalarType(1))); + ASSERT_FALSE(one->ApproxEquals(ScalarType(2))); + ASSERT_TRUE(two->ApproxEquals(ScalarType(2))); + ASSERT_FALSE(two->ApproxEquals(ScalarType(3))); } TYPED_TEST(TestNumericScalar, Hashing) { @@ -127,6 +135,199 @@ TYPED_TEST(TestNumericScalar, MakeScalar) { ASSERT_EQ(ScalarType(3), *three); } +template +class TestRealScalar : public ::testing::Test { + public: + using CType = typename T::c_type; + using ScalarType = typename TypeTraits::ScalarType; + + void SetUp() { + type_ = TypeTraits::type_singleton(); + + scalar_val_ = std::make_shared(static_cast(1)); + ASSERT_TRUE(scalar_val_->is_valid); + + scalar_other_ = std::make_shared(static_cast(1.1)); + ASSERT_TRUE(scalar_other_->is_valid); + + const CType nan_value = std::numeric_limits::quiet_NaN(); + scalar_nan_ = std::make_shared(nan_value); + ASSERT_TRUE(scalar_nan_->is_valid); + + const CType other_nan_value = std::numeric_limits::quiet_NaN(); + scalar_other_nan_ = std::make_shared(other_nan_value); + ASSERT_TRUE(scalar_other_nan_->is_valid); + } + + void TestNanEquals() { + EqualOptions options = EqualOptions::Defaults(); + ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->Equals(*scalar_nan_, options)); + ASSERT_FALSE(scalar_nan_->Equals(*scalar_other_nan_, options)); + + options = options.nans_equal(true); + ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options)); + ASSERT_TRUE(scalar_nan_->Equals(*scalar_nan_, options)); + ASSERT_TRUE(scalar_nan_->Equals(*scalar_other_nan_, options)); + } + + void TestApproxEquals() { + // The scalars are unequal with the small delta + EqualOptions options = EqualOptions::Defaults().atol(0.05); + ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options)); + ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options)); + + // After enlarging the delta, they become equal + options = options.atol(0.15); + ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options)); + ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options)); + + options = options.nans_equal(true); + ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options)); + ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options)); + ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options)); + + options = options.atol(0.05); + ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options)); + ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options)); + ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options)); + ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options)); + } + + void TestStructOf() { + auto ty = struct_({field("float", type_)}); + + StructScalar struct_val({scalar_val_}, ty); + StructScalar struct_other_val({scalar_other_}, ty); + StructScalar struct_nan({scalar_nan_}, ty); + StructScalar struct_other_nan({scalar_other_nan_}, ty); + + EqualOptions options = EqualOptions::Defaults().atol(0.05); + ASSERT_FALSE(struct_val.Equals(struct_other_val, options)); + ASSERT_FALSE(struct_other_val.Equals(struct_val, options)); + ASSERT_FALSE(struct_nan.Equals(struct_val, options)); + ASSERT_FALSE(struct_nan.Equals(struct_nan, options)); + ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options)); + ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options)); + ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options)); + + options = options.atol(0.15); + ASSERT_FALSE(struct_val.Equals(struct_other_val, options)); + ASSERT_FALSE(struct_other_val.Equals(struct_val, options)); + ASSERT_FALSE(struct_nan.Equals(struct_val, options)); + ASSERT_FALSE(struct_nan.Equals(struct_nan, options)); + ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options)); + ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options)); + ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options)); + + options = options.nans_equal(true); + ASSERT_FALSE(struct_val.Equals(struct_other_val, options)); + ASSERT_FALSE(struct_other_val.Equals(struct_val, options)); + ASSERT_FALSE(struct_nan.Equals(struct_val, options)); + ASSERT_TRUE(struct_nan.Equals(struct_nan, options)); + ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options)); + ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options)); + ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options)); + ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options)); + ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options)); + + options = options.atol(0.05); + ASSERT_FALSE(struct_val.Equals(struct_other_val, options)); + ASSERT_FALSE(struct_other_val.Equals(struct_val, options)); + ASSERT_FALSE(struct_nan.Equals(struct_val, options)); + ASSERT_TRUE(struct_nan.Equals(struct_nan, options)); + ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options)); + ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options)); + ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options)); + ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options)); + ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options)); + ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options)); + } + + void TestListOf() { + auto ty = list(type_); + + ListScalar list_val(ArrayFromJSON(type_, "[0, null, 1.0]"), ty); + ListScalar list_other_val(ArrayFromJSON(type_, "[0, null, 1.1]"), ty); + ListScalar list_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty); + ListScalar list_other_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty); + + EqualOptions options = EqualOptions::Defaults().atol(0.05); + ASSERT_TRUE(list_val.Equals(list_val, options)); + ASSERT_FALSE(list_val.Equals(list_other_val, options)); + ASSERT_FALSE(list_nan.Equals(list_val, options)); + ASSERT_FALSE(list_nan.Equals(list_nan, options)); + ASSERT_FALSE(list_nan.Equals(list_other_nan, options)); + ASSERT_TRUE(list_val.ApproxEquals(list_val, options)); + ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_val, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options)); + + options = options.atol(0.15); + ASSERT_TRUE(list_val.Equals(list_val, options)); + ASSERT_FALSE(list_val.Equals(list_other_val, options)); + ASSERT_FALSE(list_nan.Equals(list_val, options)); + ASSERT_FALSE(list_nan.Equals(list_nan, options)); + ASSERT_FALSE(list_nan.Equals(list_other_nan, options)); + ASSERT_TRUE(list_val.ApproxEquals(list_val, options)); + ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_val, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options)); + + options = options.nans_equal(true); + ASSERT_TRUE(list_val.Equals(list_val, options)); + ASSERT_FALSE(list_val.Equals(list_other_val, options)); + ASSERT_FALSE(list_nan.Equals(list_val, options)); + ASSERT_TRUE(list_nan.Equals(list_nan, options)); + ASSERT_TRUE(list_nan.Equals(list_other_nan, options)); + ASSERT_TRUE(list_val.ApproxEquals(list_val, options)); + ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_val, options)); + ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options)); + ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options)); + + options = options.atol(0.05); + ASSERT_TRUE(list_val.Equals(list_val, options)); + ASSERT_FALSE(list_val.Equals(list_other_val, options)); + ASSERT_FALSE(list_nan.Equals(list_val, options)); + ASSERT_TRUE(list_nan.Equals(list_nan, options)); + ASSERT_TRUE(list_nan.Equals(list_other_nan, options)); + ASSERT_TRUE(list_val.ApproxEquals(list_val, options)); + ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options)); + ASSERT_FALSE(list_nan.ApproxEquals(list_val, options)); + ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options)); + ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options)); + } + + protected: + std::shared_ptr type_; + std::shared_ptr scalar_val_, scalar_other_, scalar_nan_, scalar_other_nan_; +}; + +TYPED_TEST_SUITE(TestRealScalar, RealArrowTypes); + +TYPED_TEST(TestRealScalar, NanEquals) { this->TestNanEquals(); } + +TYPED_TEST(TestRealScalar, ApproxEquals) { this->TestApproxEquals(); } + +TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); } + +TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); } + TEST(TestDecimal128Scalar, Basics) { auto ty = decimal128(3, 2); auto pi = Decimal128Scalar(Decimal128("3.14"), ty); diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 41ab7a4e8b5..c89831072b4 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -148,13 +148,20 @@ void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool ve void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose, const EqualOptions& options) { - std::stringstream diff; - // ARROW-8956, ScalarEquals returns false when both are null - if (!expected.is_valid && !actual.is_valid) { - // We consider both being null to be equal in this function - return; - } if (!expected.Equals(actual, options)) { + std::stringstream diff; + if (verbose) { + diff << "Expected:\n" << expected.ToString(); + diff << "\nActual:\n" << actual.ToString(); + } + FAIL() << diff.str(); + } +} + +void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool verbose, + const EqualOptions& options) { + if (!expected.ApproxEquals(actual, options)) { + std::stringstream diff; if (verbose) { diff << "Expected:\n" << expected.ToString(); diff << "\nActual:\n" << actual.ToString(); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 7a0c9b7c257..96695c33ef2 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -171,6 +171,9 @@ ARROW_TESTING_EXPORT void AssertArraysApproxEqual( ARROW_TESTING_EXPORT void AssertScalarsEqual( const Scalar& expected, const Scalar& actual, bool verbose = false, const EqualOptions& options = EqualOptions::Defaults()); +ARROW_TESTING_EXPORT void AssertScalarsApproxEqual( + const Scalar& expected, const Scalar& actual, bool verbose = false, + const EqualOptions& options = EqualOptions::Defaults()); ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual, bool check_metadata = false);