Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cpp/src/arrow/array/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,24 @@ void CheckApproxEquals() {
ASSERT_FALSE(b->ApproxEquals(a, EqualOptions().nans_equal(true)));
}

template <typename TYPE>
void CheckFloatApproxEqualsWithAtol() {
using c_type = typename TYPE::c_type;
auto type = TypeTraits<TYPE>::type_singleton();
std::shared_ptr<Array> a, b;
ArrayFromVector<TYPE>(type, {true}, {static_cast<c_type>(0.5)}, &a);
ArrayFromVector<TYPE>(type, {true}, {static_cast<c_type>(0.6)}, &b);
auto options = EqualOptions::Defaults().atol(0.2);

ASSERT_FALSE(a->Equals(b));
ASSERT_TRUE(a->Equals(b, options.use_atol(true)));
ASSERT_TRUE(a->ApproxEquals(b, options));

ASSERT_FALSE(a->RangeEquals(0, 1, 0, b, options));
ASSERT_TRUE(a->RangeEquals(0, 1, 0, b, options.use_atol(true)));
ASSERT_TRUE(ArrayRangeApproxEquals(*a, *b, 0, 1, 0, options));
}

template <typename TYPE>
void CheckSliceApproxEquals() {
using T = typename TYPE::c_type;
Expand Down Expand Up @@ -2272,6 +2290,11 @@ TEST(TestPrimitiveAdHoc, FloatingApproxEquals) {
CheckApproxEquals<DoubleType>();
}

TEST(TestPrimitiveAdHoc, FloatingApproxEqualsWithAtol) {
CheckFloatApproxEqualsWithAtol<FloatType>();
CheckFloatApproxEqualsWithAtol<DoubleType>();
}

TEST(TestPrimitiveAdHoc, FloatingSliceApproxEquals) {
CheckSliceApproxEquals<FloatType>();
CheckSliceApproxEquals<DoubleType>();
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/array/statistics_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ TEST_F(TestArrayStatisticsEqualityDoubleValue, NaN) {
TEST_F(TestArrayStatisticsEqualityDoubleValue, ApproximateEquals) {
statistics1_.max = 0.5001f;
statistics2_.max = 0.5;
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(false)));
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3)));
ASSERT_FALSE(statistics1_.Equals(statistics2_, options_.atol(1e-3)));
ASSERT_TRUE(statistics1_.Equals(statistics2_, options_.atol(1e-3).use_atol(true)));
}

} // namespace arrow
55 changes: 16 additions & 39 deletions cpp/src/arrow/chunked_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,33 +98,6 @@ DeviceAllocationTypeSet ChunkedArray::device_types() const {
}
return set;
}

bool ChunkedArray::Equals(const ChunkedArray& other, const EqualOptions& opts) const {
if (length_ != other.length()) {
return false;
}
if (null_count_ != other.null_count()) {
return false;
}
// We cannot toggle check_metadata here yet, so we don't check it
if (!type_->Equals(*other.type_, /*check_metadata=*/false)) {
return false;
}

// Check contents of the underlying arrays. This checks for equality of
// the underlying data independently of the chunk size.
return internal::ApplyBinaryChunked(
*this, other,
[&](const Array& left_piece, const Array& right_piece,
int64_t ARROW_ARG_UNUSED(position)) {
if (!left_piece.Equals(right_piece, opts)) {
return Status::Invalid("Unequal piece");
}
return Status::OK();
})
.ok();
}

namespace {

bool mayHaveNaN(const arrow::DataType& type) {
Expand All @@ -142,19 +115,10 @@ bool mayHaveNaN(const arrow::DataType& type) {

} // namespace

bool ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other,
const EqualOptions& opts) const {
if (!other) {
return false;
}
if (this == other.get() && !mayHaveNaN(*type_)) {
bool ChunkedArray::Equals(const ChunkedArray& other, const EqualOptions& opts) const {
if (this == &other && !mayHaveNaN(*type_)) {
return true;
}
return Equals(*other.get(), opts);
}

bool ChunkedArray::ApproxEquals(const ChunkedArray& other,
const EqualOptions& equal_options) const {
if (length_ != other.length()) {
return false;
}
Expand All @@ -172,14 +136,27 @@ bool ChunkedArray::ApproxEquals(const ChunkedArray& other,
*this, other,
[&](const Array& left_piece, const Array& right_piece,
int64_t ARROW_ARG_UNUSED(position)) {
if (!left_piece.ApproxEquals(right_piece, equal_options)) {
if (!left_piece.Equals(right_piece, opts)) {
return Status::Invalid("Unequal piece");
}
return Status::OK();
})
.ok();
}

bool ChunkedArray::Equals(const std::shared_ptr<ChunkedArray>& other,
const EqualOptions& opts) const {
if (!other) {
return false;
}
return Equals(*other.get(), opts);
}

bool ChunkedArray::ApproxEquals(const ChunkedArray& other,
const EqualOptions& equal_options) const {
return Equals(other, equal_options.use_atol(true));
}

Result<std::shared_ptr<Scalar>> ChunkedArray::GetScalar(int64_t index) const {
const auto loc = chunk_resolver_.Resolve(index);
if (loc.chunk_index >= static_cast<int64_t>(chunks_.size())) {
Expand Down
12 changes: 12 additions & 0 deletions cpp/src/arrow/chunked_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,18 @@ TEST_F(TestChunkedArray, EqualsSameAddressWithNaNs) {
ASSERT_TRUE(chunked_array_without_nan2->Equals(chunked_array_without_nan2));
}

TEST_F(TestChunkedArray, ApproxEquals) {
auto chunk_1 = ArrayFromJSON(float64(), R"([0.0, 0.1, 0.5])");
auto chunk_2 = ArrayFromJSON(float64(), R"([0.0, 0.1, 0.5001])");
ASSERT_OK_AND_ASSIGN(auto chunked_array_1, ChunkedArray::Make({chunk_1}));
ASSERT_OK_AND_ASSIGN(auto chunked_array_2, ChunkedArray::Make({chunk_2}));
auto options = EqualOptions::Defaults().atol(1e-3);

ASSERT_FALSE(chunked_array_1->Equals(chunked_array_2));
ASSERT_TRUE(chunked_array_1->Equals(chunked_array_2, options.use_atol(true)));
ASSERT_TRUE(chunked_array_1->ApproxEquals(*chunked_array_2, options));
}

TEST_F(TestChunkedArray, SliceEquals) {
random::RandomArrayGenerator gen(42);

Expand Down
9 changes: 3 additions & 6 deletions cpp/src/arrow/compare.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1155,9 +1155,8 @@ bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& o
bool ArrayRangeEquals(const Array& left, const Array& right, int64_t left_start_idx,
int64_t left_end_idx, int64_t right_start_idx,
const EqualOptions& options) {
const bool floating_approximate = false;
return ArrayRangeEquals(left, right, left_start_idx, left_end_idx, right_start_idx,
options, floating_approximate);
options, options.use_atol());
}

bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_start_idx,
Expand All @@ -1169,8 +1168,7 @@ bool ArrayRangeApproxEquals(const Array& left, const Array& right, int64_t left_
}

bool ArrayEquals(const Array& left, const Array& right, const EqualOptions& opts) {
const bool floating_approximate = false;
return ArrayEquals(left, right, opts, floating_approximate);
return ArrayEquals(left, right, opts, opts.use_atol());
}

bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions& opts) {
Expand All @@ -1179,8 +1177,7 @@ bool ArrayApproxEquals(const Array& left, const Array& right, const EqualOptions
}

bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) {
const bool floating_approximate = false;
return ScalarEquals(left, right, options, floating_approximate);
return ScalarEquals(left, right, options, options.use_atol());
}

bool ScalarApproxEquals(const Scalar& left, const Scalar& right,
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/arrow/compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class EqualOptions {
}

/// Whether the "atol" property is used in the comparison.
///
/// This option only affects the Equals methods
/// and has no effect on ApproxEquals methods.
bool use_atol() const { return use_atol_; }

/// Return a new EqualOptions object with the "use_atol" property changed.
Expand Down Expand Up @@ -99,7 +102,7 @@ class EqualOptions {
double atol_ = kDefaultAbsoluteTolerance;
bool nans_equal_ = false;
bool signed_zeros_equal_ = true;
bool use_atol_ = true;
bool use_atol_ = false;

std::ostream* diff_sink_ = NULLPTR;
};
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/arrow/record_batch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,10 @@ TEST_F(TestRecordBatch, ApproxEqualOptions) {
EXPECT_FALSE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(false)));
EXPECT_FALSE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(true)));

EXPECT_TRUE(b1->ApproxEquals(*b2, EqualOptions::Defaults().nans_equal(true).atol(0.1)));
auto options = EqualOptions::Defaults().nans_equal(true).atol(0.1);
EXPECT_FALSE(b1->Equals(*b2, false, options));
EXPECT_TRUE(b1->Equals(*b2, false, options.use_atol(true)));
EXPECT_TRUE(b1->ApproxEquals(*b2, options));
}

TEST_F(TestRecordBatch, Validate) {
Expand Down
10 changes: 10 additions & 0 deletions cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,14 @@ class TestRealScalar : public ::testing::Test {
ASSERT_FALSE(scalar_zero_->ApproxEquals(*scalar_neg_zero_, options));
}

void TestUseAtol() {
auto options = EqualOptions::Defaults().atol(0.2f);

ASSERT_FALSE(scalar_val_->Equals(*scalar_other_, options));
ASSERT_TRUE(scalar_val_->Equals(*scalar_other_, options.use_atol(true)));
ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
}

void TestStructOf() {
auto ty = struct_({field("float", type_)});

Expand Down Expand Up @@ -522,6 +530,8 @@ TYPED_TEST(TestRealScalar, SignedZeroEquals) { this->TestSignedZeroEquals(); }

TYPED_TEST(TestRealScalar, ApproxEquals) { this->TestApproxEquals(); }

TYPED_TEST(TestRealScalar, UseAtol) { this->TestUseAtol(); }

TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); }

TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); }
Expand Down
Loading