diff --git a/cpp/cmake_modules/san-config.cmake b/cpp/cmake_modules/san-config.cmake index 2e2807801fb..5eee6278009 100644 --- a/cpp/cmake_modules/san-config.cmake +++ b/cpp/cmake_modules/san-config.cmake @@ -35,14 +35,17 @@ endif() # - disable 'vptr' because of RTTI issues across shared libraries (?) # - disable 'alignment' because unaligned access is really OK on Nehalem and we do it # all over the place. -# - disable 'function' because it appears to give a false positive https://github.com/google/sanitizers/issues/911 +# - disable 'function' because it appears to give a false positive +# (https://github.com/google/sanitizers/issues/911) +# - disable 'float-divide-by-zero' on clang, which considers it UB +# (https://bugs.llvm.org/show_bug.cgi?id=17000#c1) # Note: GCC does not support the 'function' flag. if(${ARROW_USE_UBSAN}) if(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") set( CMAKE_CXX_FLAGS - "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function -fno-sanitize-recover=all" + "${CMAKE_CXX_FLAGS} -fsanitize=undefined -fno-sanitize=alignment,vptr,function,float-divide-by-zero -fno-sanitize-recover=all" ) elseif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL "5.1") diff --git a/cpp/src/arrow/compare.cc b/cpp/src/arrow/compare.cc index e0c23a31eac..421ec139242 100644 --- a/cpp/src/arrow/compare.cc +++ b/cpp/src/arrow/compare.cc @@ -862,7 +862,9 @@ class TypeEqualsVisitor { class ScalarEqualsVisitor { public: - explicit ScalarEqualsVisitor(const Scalar& right) : right_(right), result_(false) {} + explicit ScalarEqualsVisitor(const Scalar& right, + const EqualOptions& opts = EqualOptions::Defaults()) + : right_(right), result_(false), options_(opts) {} Status Visit(const NullScalar& left) { result_ = true; @@ -875,9 +877,26 @@ class ScalarEqualsVisitor { return Status::OK(); } + template + typename std::enable_if::value || + std::is_base_of::value, + Status>::type + Visit(const T& left_) { + const auto& right = checked_cast(right_); + if (options_.nans_equal()) { + result_ = right.value == left_.value || + (std::isnan(right.value) && std::isnan(left_.value)); + } else { + result_ = right.value == left_.value; + } + return Status::OK(); + } + template typename std::enable_if< - std::is_base_of, T>::value || + (std::is_base_of, T>::value && + !std::is_base_of::value && + !std::is_base_of::value) || std::is_base_of, T>::value, Status>::type Visit(const T& left_) { @@ -968,6 +987,7 @@ class ScalarEqualsVisitor { protected: const Scalar& right_; bool result_; + const EqualOptions options_; }; Status PrintDiff(const Array& left, const Array& right, std::ostream* os) { @@ -1386,7 +1406,7 @@ bool TypeEquals(const DataType& left, const DataType& right, bool check_metadata } } -bool ScalarEquals(const Scalar& left, const Scalar& right) { +bool ScalarEquals(const Scalar& left, const Scalar& right, const EqualOptions& options) { bool are_equal = false; if (&left == &right) { are_equal = true; @@ -1395,7 +1415,7 @@ bool ScalarEquals(const Scalar& left, const Scalar& right) { } else if (left.is_valid != right.is_valid) { are_equal = false; } else { - ScalarEqualsVisitor visitor(right); + ScalarEqualsVisitor visitor(right, options); auto error = VisitScalarInline(left, &visitor); DCHECK_OK(error); are_equal = visitor.result(); diff --git a/cpp/src/arrow/compare.h b/cpp/src/arrow/compare.h index abcf39a62e5..f7899b7c5c6 100644 --- a/cpp/src/arrow/compare.h +++ b/cpp/src/arrow/compare.h @@ -111,6 +111,8 @@ bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right, /// Returns true if scalars are equal /// \param[in] left a Scalar /// \param[in] right a Scalar -bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right); +/// \param[in] options comparison options +bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right, + const EqualOptions& options = EqualOptions::Defaults()); } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc index e56203bdfc3..ff6c6fab7c6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic.cc @@ -190,10 +190,6 @@ struct MultiplyChecked { struct Divide { template static enable_if_floating_point Call(KernelContext* ctx, Arg0 left, Arg1 right) { - if (ARROW_PREDICT_FALSE(right == 0)) { - ctx->SetStatus(Status::Invalid("divide by zero")); - return 0; - } return left / right; } diff --git a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc index ea24089a06d..9b3ed2a57e3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc @@ -144,7 +144,8 @@ class TestBinaryArithmetic : public TestBase { const auto expected_scalar = *expected->GetScalar(i); ASSERT_OK_AND_ASSIGN( actual, func(*left->GetScalar(i), *right->GetScalar(i), options_, nullptr)); - AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true); + AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true, + equal_options_); } } @@ -165,12 +166,17 @@ class TestBinaryArithmetic : public TestBase { void ValidateAndAssertApproxEqual(const std::shared_ptr& actual, const std::shared_ptr& expected) { ASSERT_OK(actual->ValidateFull()); - AssertArraysApproxEqual(*expected, *actual, /*verbose=*/true); + AssertArraysApproxEqual(*expected, *actual, /*verbose=*/true, equal_options_); } void SetOverflowCheck(bool value = true) { options_.check_overflow = value; } + void SetNansEqual(bool value = true) { + this->equal_options_ = equal_options_.nans_equal(value); + } + ArithmeticOptions options_ = ArithmeticOptions(); + EqualOptions equal_options_ = EqualOptions::Defaults(); }; template @@ -510,6 +516,9 @@ TYPED_TEST(TestBinaryArithmeticFloating, Div) { "[null, 0.1, 0.25, null, 0.2, 0.5]"); // Array with infinity this->AssertBinop(Divide, "[3.4, Inf, -Inf]", "[1, 2, 3]", "[3.4, Inf, -Inf]"); + // Array with NaN + this->SetNansEqual(true); + this->AssertBinop(Divide, "[3.4, NaN, 2.0]", "[1, 2, 2.0]", "[3.4, NaN, 1.0]"); // Scalar divides by scalar this->AssertBinop(Divide, 21.0F, 3.0F, 7.0F); } @@ -557,10 +566,17 @@ TYPED_TEST(TestBinaryArithmeticIntegral, DivideByZero) { } TYPED_TEST(TestBinaryArithmeticFloating, DivideByZero) { - for (auto check_overflow : {false, true}) { - this->SetOverflowCheck(check_overflow); - this->AssertBinopRaises(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0]", "divide by zero"); - } + this->SetOverflowCheck(true); + this->AssertBinopRaises(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0.0]", "divide by zero"); + this->AssertBinopRaises(Divide, "[3.0, 2.0, 0.0]", "[1.0, 1.0, 0.0]", "divide by zero"); + this->AssertBinopRaises(Divide, "[3.0, 2.0, -6.0]", "[1.0, 1.0, 0.0]", + "divide by zero"); + + this->SetOverflowCheck(false); + this->SetNansEqual(true); + this->AssertBinop(Divide, "[3.0, 2.0, 6.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, Inf]"); + this->AssertBinop(Divide, "[3.0, 2.0, 0.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, NaN]"); + this->AssertBinop(Divide, "[3.0, 2.0, -6.0]", "[1.0, 1.0, 0.0]", "[3.0, 2.0, -Inf]"); } TYPED_TEST(TestBinaryArithmeticSigned, DivideOverflowRaises) { diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index b953177a459..88e594e9b19 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -40,7 +40,9 @@ namespace arrow { using internal::checked_cast; using internal::checked_pointer_cast; -bool Scalar::Equals(const Scalar& other) const { return ScalarEquals(*this, other); } +bool Scalar::Equals(const Scalar& other, const EqualOptions& options) const { + return ScalarEquals(*this, other, options); +} struct ScalarHashImpl { static std::hash string_hash; diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index d15c44ce9db..4a007dd8782 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -28,6 +28,7 @@ #include #include +#include "arrow/compare.h" #include "arrow/result.h" #include "arrow/status.h" #include "arrow/type.h" @@ -61,7 +62,8 @@ struct ARROW_EXPORT Scalar : public util::EqualityComparable { using util::EqualityComparable::operator==; using util::EqualityComparable::Equals; - bool Equals(const Scalar& other) const; + bool Equals(const Scalar& other, + const EqualOptions& options = EqualOptions::Defaults()) const; struct ARROW_EXPORT Hash { size_t operator()(const Scalar& scalar) const { return hash(scalar); } diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index b2f55668977..75cd204e1b2 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -135,22 +135,24 @@ void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose) }); } -void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose) { +void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool verbose, + const EqualOptions& option) { return AssertArraysEqualWith( expected, actual, verbose, - [](const Array& expected, const Array& actual, std::stringstream* diff) { - return expected.ApproxEquals(actual, EqualOptions().diff_sink(diff)); + [&option](const Array& expected, const Array& actual, std::stringstream* diff) { + return expected.ApproxEquals(actual, option.diff_sink(diff)); }); } -void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose) { +void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose, + const EqualOptions& options) { std::stringstream diff; // ARROW-8956, ScalarEquals returns false when both are null if (!expected.is_valid && !actual.is_valid) { // We consider both being null to be equal in this function return; } - if (!expected.Equals(actual)) { + if (!expected.Equals(actual, options)) { if (verbose) { diff << "Expected:\n" << expected.ToString(); diff << "\nActual:\n" << actual.ToString(); diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index 1411e705bcf..fd72b5a88c3 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -164,12 +164,13 @@ std::vector AllTypeIds(); // If verbose is true, then the arrays will be pretty printed ARROW_TESTING_EXPORT void AssertArraysEqual(const Array& expected, const Array& actual, bool verbose = false); -ARROW_TESTING_EXPORT void AssertArraysApproxEqual(const Array& expected, - const Array& actual, - bool verbose = false); +ARROW_TESTING_EXPORT void AssertArraysApproxEqual( + const Array& expected, const Array& actual, bool verbose = false, + const EqualOptions& option = EqualOptions::Defaults()); // Returns true when values are both null -ARROW_TESTING_EXPORT void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, - bool verbose = false); +ARROW_TESTING_EXPORT void AssertScalarsEqual( + const Scalar& expected, const Scalar& actual, bool verbose = false, + const EqualOptions& options = EqualOptions::Defaults()); ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected, const RecordBatch& actual, bool check_metadata = false);