diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index 87e86d89677..e5a27d18d00 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -2068,6 +2068,24 @@ void CheckApproxEquals() { ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().nans_equal(true))); } +template +void CheckFloatApproxEqualsWithAtol() { + using c_type = typename TYPE::c_type; + auto type = TypeTraits::type_singleton(); + std::shared_ptr a, b; + ArrayFromVector(type, {true}, {static_cast(0.5)}, &a); + ArrayFromVector(type, {true}, {static_cast(0.6)}, &b); + auto options = EqualOptions::Defaults().atol(0.2); + + ASSERT_FALSE(a->Equals(b)); + ASSERT_TRUE(a->Equals(b, options.use_atol(true))); + ASSERT_TRUE(a->ApproxEquals(b, options)); + + ASSERT_FALSE(a->RangeEquals(0, 1, 0, b, options)); + ASSERT_TRUE(a->RangeEquals(0, 1, 0, b, options.use_atol(true))); + ASSERT_TRUE(ArrayRangeApproxEquals(*a, *b, 0, 1, 0, options)); +} + template void CheckSliceApproxEquals() { using T = typename TYPE::c_type; @@ -2272,6 +2290,11 @@ TEST(TestPrimitiveAdHoc, FloatingApproxEquals) { CheckApproxEquals(); } +TEST(TestPrimitiveAdHoc, FloatingApproxEqualsWithAtol) { + CheckFloatApproxEqualsWithAtol(); + CheckFloatApproxEqualsWithAtol(); +} + TEST(TestPrimitiveAdHoc, FloatingSliceApproxEquals) { CheckSliceApproxEquals(); CheckSliceApproxEquals(); diff --git a/cpp/src/arrow/array/statistics_test.cc b/cpp/src/arrow/array/statistics_test.cc index 95199a9683b..250c4bb437a 100644 --- a/cpp/src/arrow/array/statistics_test.cc +++ b/cpp/src/arrow/array/statistics_test.cc @@ -148,8 +148,8 @@ TEST_F(TestArrayStatisticsEqualityDoubleValue, NaN) { TEST_F(TestArrayStatisticsEqualityDoubleValue, ApproximateEquals) { statistics1_.max = 0.5001f; statistics2_.max = 0.5; - ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(false))); - ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3))); + ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3))); + ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(true))); } } // namespace arrow diff --git a/cpp/src/arrow/chunked_array.cc b/cpp/src/arrow/chunked_array.cc index 988fc148632..32578ffd93f 100644 --- a/cpp/src/arrow/chunked_array.cc +++ b/cpp/src/arrow/chunked_array.cc @@ -98,33 +98,6 @@ DeviceAllocationTypeSet ChunkedArray::device_types() const { } return set; } - -bool ChunkedArray::Equals(const ChunkedArray& other, const EqualOptions& opts) const { - if (length_ != other.length()) { - return false; - } - if (null_count_ != other.null_count()) { - return false; - } - // We cannot toggle check_metadata here yet, so we don't check it - if (!type_->Equals(*other.type_, /*check_metadata=*/false)) { - return false; - } - - // Check contents of the underlying arrays. This checks for equality of - // the underlying data independently of the chunk size. - return internal::ApplyBinaryChunked( - *this, other, - [&](const Array& left_piece, const Array& right_piece, - int64_t ARROW_ARG_UNUSED(position)) { - if (!left_piece.Equals(right_piece, opts)) { - return Status::Invalid("Unequal piece"); - } - return Status::OK(); - }) - .ok(); -} - namespace { bool mayHaveNaN(const arrow::DataType& type) { @@ -142,19 +115,10 @@ bool mayHaveNaN(const arrow::DataType& type) { } // namespace -bool ChunkedArray::Equals(const std::shared_ptr& other, - const EqualOptions& opts) const { - if (!other) { - return false; - } - if (this == other.get() && !mayHaveNaN(*type_)) { +bool ChunkedArray::Equals(const ChunkedArray& other, const EqualOptions& opts) const { + if (this == &other && !mayHaveNaN(*type_)) { return true; } - return Equals(*other.get(), opts); -} - -bool ChunkedArray::ApproxEquals(const ChunkedArray& other, - const EqualOptions& equal_options) const { if (length_ != other.length()) { return false; } @@ -172,7 +136,7 @@ bool ChunkedArray::ApproxEquals(const ChunkedArray& other, *this, other, [&](const Array& left_piece, const Array& right_piece, int64_t ARROW_ARG_UNUSED(position)) { - if (!left_piece.ApproxEquals(right_piece, equal_options)) { + if (!left_piece.Equals(right_piece, opts)) { return Status::Invalid("Unequal piece"); } return Status::OK(); @@ -180,6 +144,19 @@ bool ChunkedArray::ApproxEquals(const ChunkedArray& other, .ok(); } +bool ChunkedArray::Equals(const std::shared_ptr& other, + const EqualOptions& opts) const { + if (!other) { + return false; + } + return Equals(*other.get(), opts); +} + +bool ChunkedArray::ApproxEquals(const ChunkedArray& other, + const EqualOptions& equal_options) const { + return Equals(other, equal_options.use_atol(true)); +} + Result> ChunkedArray::GetScalar(int64_t index) const { const auto loc = chunk_resolver_.Resolve(index); if (loc.chunk_index >= static_cast(chunks_.size())) { diff --git a/cpp/src/arrow/chunked_array_test.cc b/cpp/src/arrow/chunked_array_test.cc index b3944fd1b19..689ef57c59a 100644 --- a/cpp/src/arrow/chunked_array_test.cc +++ b/cpp/src/arrow/chunked_array_test.cc @@ -182,6 +182,18 @@ TEST_F(TestChunkedArray, EqualsSameAddressWithNaNs) { ASSERT_TRUE(chunked_array_without_nan2->Equals(chunked_array_without_nan2)); } +TEST_F(TestChunkedArray, ApproxEquals) { + auto chunk_1 = ArrayFromJSON(float64(), R"([0.0, 0.1, 0.5])"); + auto chunk_2 = ArrayFromJSON(float64(), R"([0.0, 0.1, 0.5001])"); + ASSERT_OK_AND_ASSIGN(auto chunked_array_1, ChunkedArray::Make({chunk_1})); + ASSERT_OK_AND_ASSIGN(auto chunked_array_2, ChunkedArray::Make({chunk_2})); + auto options = EqualOptions::Defaults().atol(1e-3); + + ASSERT_FALSE(chunked_array_1->Equals(chunked_array_2)); + ASSERT_TRUE(chunked_array_1->Equals(chunked_array_2, options.use_atol(true))); + ASSERT_TRUE(chunked_array_1->ApproxEquals(*chunked_array_2, options)); +} + TEST_F(TestChunkedArray, SliceEquals) { random::RandomArrayGenerator gen(42); diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index a7df9efdbbb..6ece1cb444c 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -1155,9 +1155,8 @@ bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& o bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx, int64_t left_end_idx, int64_t right_start_idx, const EqualOptions& options) { - const bool floating_approximate = false; return ArrayRangeEquals(left, right, left_start_idx, left_end_idx, right_start_idx, - options, floating_approximate); + options, options.use_atol()); } bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_start_idx, @@ -1169,8 +1168,7 @@ bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_ } bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) { - const bool floating_approximate = false; - return ArrayEquals(left, right, opts, floating_approximate); + return ArrayEquals(left, right, opts, opts.use_atol()); } bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) { @@ -1179,8 +1177,7 @@ bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions } bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) { - const bool floating_approximate = false; - return ScalarEquals(left, right, options, floating_approximate); + return ScalarEquals(left, right, options, options.use_atol()); } bool ScalarApproxEquals(const Scalar& left, const Scalar& right, diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index ec7dc8bda18..4d2282c982a 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -60,6 +60,9 @@ class EqualOptions { } /// Whether the "atol" property is used in the comparison. + /// + /// This option only affects the Equals methods + /// and has no effect on ApproxEquals methods. bool use_atol() const { return use_atol_; } /// Return a new EqualOptions object with the "use_atol" property changed. @@ -99,7 +102,7 @@ class EqualOptions { double atol_ = kDefaultAbsoluteTolerance; bool nans_equal_ = false; bool signed_zeros_equal_ = true; - bool use_atol_ = true; + bool use_atol_ = false; std::ostream* diff_sink_ = NULLPTR; }; diff --git a/cpp/src/arrow/record_batch_test.cc b/cpp/src/arrow/record_batch_test.cc index 3dc847bb96c..0572883441f 100644 --- a/cpp/src/arrow/record_batch_test.cc +++ b/cpp/src/arrow/record_batch_test.cc @@ -136,7 +136,10 @@ TEST_F(TestRecordBatch, ApproxEqualOptions) { EXPECT_FALSE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(false))); EXPECT_FALSE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(true))); - EXPECT_TRUE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(true).atol(0.1))); + auto options = EqualOptions::Defaults().nans_equal(true).atol(0.1); + EXPECT_FALSE(b1->Equals(*b2, false, options)); + EXPECT_TRUE(b1->Equals(*b2, false, options.use_atol(true))); + EXPECT_TRUE(b1->ApproxEquals(*b2, options)); } TEST_F(TestRecordBatch, Validate) { diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 6938bc0d887..66f2daf7b9e 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -387,6 +387,14 @@ class TestRealScalar : public ::testing::Test { ASSERT_FALSE(scalar_zero_->ApproxEquals(*scalar_neg_zero_, options)); } + void TestUseAtol() { + auto options = EqualOptions::Defaults().atol(0.2f); + + ASSERT_FALSE(scalar_val_->Equals(*scalar_other_, options)); + ASSERT_TRUE(scalar_val_->Equals(*scalar_other_, options.use_atol(true))); + ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options)); + } + void TestStructOf() { auto ty = struct_({field("float", type_)}); @@ -522,6 +530,8 @@ TYPED_TEST(TestRealScalar, SignedZeroEquals) { this->TestSignedZeroEquals(); } TYPED_TEST(TestRealScalar, ApproxEquals) { this->TestApproxEquals(); } +TYPED_TEST(TestRealScalar, UseAtol) { this->TestUseAtol(); } + TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); } TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); }