Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
293 changes: 199 additions & 94 deletions cpp/src/arrow/compare.cc

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions cpp/src/arrow/compare.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,12 @@ bool ARROW_EXPORT TypeEquals(const DataType& left, const DataType& right,
bool ARROW_EXPORT ScalarEquals(const Scalar& left, const Scalar& right,
const EqualOptions& options = EqualOptions::Defaults());

/// Returns true if scalars are approximately equal
/// \param[in] left a Scalar
/// \param[in] right a Scalar
/// \param[in] options comparison options (such as the absolute tolerance
/// `atol` and the `nans_equal` setting)
/// \return true if the two scalars compare approximately equal under
/// the given options
bool ARROW_EXPORT
ScalarApproxEquals(const Scalar& left, const Scalar& right,
const EqualOptions& options = EqualOptions::Defaults());

} // namespace arrow
14 changes: 8 additions & 6 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,28 +512,30 @@ struct ScalarUnary {
using OutValue = typename GetOutputType<OutType>::T;
using Arg0Value = typename GetViewType<Arg0Type>::T;

static void Array(KernelContext* ctx, const ArrayData& arg0, Datum* out) {
static void ExecArray(KernelContext* ctx, const ArrayData& arg0, Datum* out) {
ArrayIterator<Arg0Type> arg0_it(arg0);
OutputAdapter<OutType>::Write(ctx, out, [&]() -> OutValue {
return Op::template Call<OutValue, Arg0Value>(ctx, arg0_it());
});
}

static void Scalar(KernelContext* ctx, const Scalar& arg0, Datum* out) {
static void ExecScalar(KernelContext* ctx, const Scalar& arg0, Datum* out) {
Scalar* out_scalar = out->scalar().get();
if (arg0.is_valid) {
Arg0Value arg0_val = UnboxScalar<Arg0Type>::Unbox(arg0);
out_scalar->is_valid = true;
BoxScalar<OutType>::Box(Op::template Call<OutValue, Arg0Value>(ctx, arg0_val),
out->scalar().get());
out_scalar);
} else {
out->value = MakeNullScalar(arg0.type);
out_scalar->is_valid = false;
}
}

static void Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch[0].kind() == Datum::ARRAY) {
return Array(ctx, *batch[0].array(), out);
return ExecArray(ctx, *batch[0].array(), out);
} else {
return Scalar(ctx, *batch[0].scalar(), out);
return ExecScalar(ctx, *batch[0].scalar(), out);
}
}
};
Expand Down
6 changes: 3 additions & 3 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class TestBinaryArithmetic : public TestBase {
auto exp = MakeScalar(expected);

ASSERT_OK_AND_ASSIGN(auto actual, func(left, right, options_, nullptr));
AssertScalarsEqual(*exp, *actual.scalar(), /*verbose=*/true);
AssertScalarsApproxEqual(*exp, *actual.scalar(), /*verbose=*/true);
}

// (Scalar, Array)
Expand Down Expand Up @@ -144,8 +144,8 @@ class TestBinaryArithmetic : public TestBase {
const auto expected_scalar = *expected->GetScalar(i);
ASSERT_OK_AND_ASSIGN(
actual, func(*left->GetScalar(i), *right->GetScalar(i), options_, nullptr));
AssertScalarsEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
equal_options_);
AssertScalarsApproxEqual(*expected_scalar, *actual.scalar(), /*verbose=*/true,
equal_options_);
}
}

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class TestCast : public TestBase {
AssertArraysEqual(expected, *result, /*verbose=*/true);

if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) {
// ARROW-9194
// ARROW-10835
check_scalar = false;
}

Expand All @@ -111,7 +111,7 @@ class TestCast : public TestBase {
ASSERT_RAISES(Invalid, Cast(input, out_type, options));

if (input.type_id() == Type::DECIMAL || out_type->id() == Type::DECIMAL) {
// ARROW-9194
// ARROW-10835
check_scalar = false;
}

Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/scalar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ bool Scalar::Equals(const Scalar& other, const EqualOptions& options) const {
return ScalarEquals(*this, other, options);
}

// Approximate comparison: forwards to the free function ScalarApproxEquals
// with this scalar as the left-hand operand.
bool Scalar::ApproxEquals(const Scalar& other, const EqualOptions& options) const {
return ScalarApproxEquals(*this, other, options);
}

struct ScalarHashImpl {
static std::hash<std::string> string_hash;

Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ struct ARROW_EXPORT Scalar : public util::EqualityComparable<Scalar> {
bool Equals(const Scalar& other,
const EqualOptions& options = EqualOptions::Defaults()) const;

/// \brief Approximate equality comparison against another scalar
/// \param[in] other the scalar to compare with
/// \param[in] options comparison options
bool ApproxEquals(const Scalar& other,
const EqualOptions& options = EqualOptions::Defaults()) const;

struct ARROW_EXPORT Hash {
size_t operator()(const Scalar& scalar) const { return hash(scalar); }

Expand Down
201 changes: 201 additions & 0 deletions cpp/src/arrow/scalar_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
// specific language governing permissions and limitations
// under the License.

#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>
Expand Down Expand Up @@ -96,6 +98,12 @@ TYPED_TEST(TestNumericScalar, Basics) {
ASSERT_FALSE(one->Equals(ScalarType(2)));
ASSERT_TRUE(two->Equals(ScalarType(2)));
ASSERT_FALSE(two->Equals(ScalarType(3)));

ASSERT_TRUE(null->ApproxEquals(*null_value));
ASSERT_TRUE(one->ApproxEquals(ScalarType(1)));
ASSERT_FALSE(one->ApproxEquals(ScalarType(2)));
ASSERT_TRUE(two->ApproxEquals(ScalarType(2)));
ASSERT_FALSE(two->ApproxEquals(ScalarType(3)));
}

TYPED_TEST(TestNumericScalar, Hashing) {
Expand Down Expand Up @@ -127,6 +135,199 @@ TYPED_TEST(TestNumericScalar, MakeScalar) {
ASSERT_EQ(ScalarType(3), *three);
}

// Typed fixture checking Equals() and ApproxEquals() on floating-point
// scalars, both standalone and nested inside struct and list scalars.
// The test bodies are member functions so that each TYPED_TEST can simply
// forward to them.
template <typename T>
class TestRealScalar : public ::testing::Test {
public:
using CType = typename T::c_type;
using ScalarType = typename TypeTraits<T>::ScalarType;

// Builds the shared operands: the values 1 and 1.1, plus two separately
// constructed NaN scalars.
void SetUp() {
type_ = TypeTraits<T>::type_singleton();

scalar_val_ = std::make_shared<ScalarType>(static_cast<CType>(1));
ASSERT_TRUE(scalar_val_->is_valid);

scalar_other_ = std::make_shared<ScalarType>(static_cast<CType>(1.1));
ASSERT_TRUE(scalar_other_->is_valid);

const CType nan_value = std::numeric_limits<CType>::quiet_NaN();
scalar_nan_ = std::make_shared<ScalarType>(nan_value);
ASSERT_TRUE(scalar_nan_->is_valid);

const CType other_nan_value = std::numeric_limits<CType>::quiet_NaN();
scalar_other_nan_ = std::make_shared<ScalarType>(other_nan_value);
ASSERT_TRUE(scalar_other_nan_->is_valid);
}

// With default options a NaN scalar is expected to be unequal to everything,
// including itself; with nans_equal(true) NaN is expected to equal NaN only.
void TestNanEquals() {
EqualOptions options = EqualOptions::Defaults();
ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->Equals(*scalar_nan_, options));
ASSERT_FALSE(scalar_nan_->Equals(*scalar_other_nan_, options));

options = options.nans_equal(true);
ASSERT_FALSE(scalar_nan_->Equals(*scalar_val_, options));
ASSERT_TRUE(scalar_nan_->Equals(*scalar_nan_, options));
ASSERT_TRUE(scalar_nan_->Equals(*scalar_other_nan_, options));
}

// The two finite operands differ by roughly 0.1, so an atol of 0.05 must
// reject them and an atol of 0.15 must accept them. The NaN expectations
// depend only on nans_equal, not on the tolerance.
void TestApproxEquals() {
// The scalars are unequal with the small delta
EqualOptions options = EqualOptions::Defaults().atol(0.05);
ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options));
ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));

// After enlarging the delta, they become equal
options = options.atol(0.15);
ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));

// Large tolerance, NaNs now treated as equal to each other
options = options.nans_equal(true);
ASSERT_TRUE(scalar_val_->ApproxEquals(*scalar_other_, options));
ASSERT_TRUE(scalar_other_->ApproxEquals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));

// Back to the small tolerance while keeping nans_equal enabled
options = options.atol(0.05);
ASSERT_FALSE(scalar_val_->ApproxEquals(*scalar_other_, options));
ASSERT_FALSE(scalar_other_->ApproxEquals(*scalar_val_, options));
ASSERT_FALSE(scalar_nan_->ApproxEquals(*scalar_val_, options));
ASSERT_TRUE(scalar_nan_->ApproxEquals(*scalar_other_nan_, options));
}

// Same matrix of (atol, nans_equal) settings as above, applied to
// single-field struct scalars wrapping the operands. Equals() is expected
// to be unaffected by atol, while ApproxEquals() is expected to honor it.
void TestStructOf() {
auto ty = struct_({field("float", type_)});

StructScalar struct_val({scalar_val_}, ty);
StructScalar struct_other_val({scalar_other_}, ty);
StructScalar struct_nan({scalar_nan_}, ty);
StructScalar struct_other_nan({scalar_other_nan_}, ty);

// Small tolerance, NaNs not equal
EqualOptions options = EqualOptions::Defaults().atol(0.05);
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
ASSERT_FALSE(struct_nan.Equals(struct_nan, options));
ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options));
ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options));
ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options));

// Large tolerance, NaNs not equal
options = options.atol(0.15);
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
ASSERT_FALSE(struct_nan.Equals(struct_nan, options));
ASSERT_FALSE(struct_nan.Equals(struct_other_nan, options));
ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options));
ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_nan, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_other_nan, options));

// Large tolerance, NaNs equal
options = options.nans_equal(true);
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
ASSERT_TRUE(struct_nan.Equals(struct_nan, options));
ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options));
ASSERT_TRUE(struct_val.ApproxEquals(struct_other_val, options));
ASSERT_TRUE(struct_other_val.ApproxEquals(struct_val, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options));
ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));

// Small tolerance, NaNs equal
options = options.atol(0.05);
ASSERT_FALSE(struct_val.Equals(struct_other_val, options));
ASSERT_FALSE(struct_other_val.Equals(struct_val, options));
ASSERT_FALSE(struct_nan.Equals(struct_val, options));
ASSERT_TRUE(struct_nan.Equals(struct_nan, options));
ASSERT_TRUE(struct_nan.Equals(struct_other_nan, options));
ASSERT_FALSE(struct_val.ApproxEquals(struct_other_val, options));
ASSERT_FALSE(struct_other_val.ApproxEquals(struct_val, options));
ASSERT_FALSE(struct_nan.ApproxEquals(struct_val, options));
ASSERT_TRUE(struct_nan.ApproxEquals(struct_nan, options));
ASSERT_TRUE(struct_nan.ApproxEquals(struct_other_nan, options));
}

// Same matrix again for list scalars. Each list is [0, null, x] where only
// the last element differs between operands (1.0, 1.1 or NaN).
void TestListOf() {
auto ty = list(type_);

ListScalar list_val(ArrayFromJSON(type_, "[0, null, 1.0]"), ty);
ListScalar list_other_val(ArrayFromJSON(type_, "[0, null, 1.1]"), ty);
ListScalar list_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);
ListScalar list_other_nan(ArrayFromJSON(type_, "[0, null, NaN]"), ty);

// Small tolerance, NaNs not equal
EqualOptions options = EqualOptions::Defaults().atol(0.05);
ASSERT_TRUE(list_val.Equals(list_val, options));
ASSERT_FALSE(list_val.Equals(list_other_val, options));
ASSERT_FALSE(list_nan.Equals(list_val, options));
ASSERT_FALSE(list_nan.Equals(list_nan, options));
ASSERT_FALSE(list_nan.Equals(list_other_nan, options));
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options));

// Large tolerance, NaNs not equal
options = options.atol(0.15);
ASSERT_TRUE(list_val.Equals(list_val, options));
ASSERT_FALSE(list_val.Equals(list_other_val, options));
ASSERT_FALSE(list_nan.Equals(list_val, options));
ASSERT_FALSE(list_nan.Equals(list_nan, options));
ASSERT_FALSE(list_nan.Equals(list_other_nan, options));
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_nan, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_other_nan, options));

// Large tolerance, NaNs equal
options = options.nans_equal(true);
ASSERT_TRUE(list_val.Equals(list_val, options));
ASSERT_FALSE(list_val.Equals(list_other_val, options));
ASSERT_FALSE(list_nan.Equals(list_val, options));
ASSERT_TRUE(list_nan.Equals(list_nan, options));
ASSERT_TRUE(list_nan.Equals(list_other_nan, options));
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
ASSERT_TRUE(list_val.ApproxEquals(list_other_val, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options));
ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));

// Small tolerance, NaNs equal
options = options.atol(0.05);
ASSERT_TRUE(list_val.Equals(list_val, options));
ASSERT_FALSE(list_val.Equals(list_other_val, options));
ASSERT_FALSE(list_nan.Equals(list_val, options));
ASSERT_TRUE(list_nan.Equals(list_nan, options));
ASSERT_TRUE(list_nan.Equals(list_other_nan, options));
ASSERT_TRUE(list_val.ApproxEquals(list_val, options));
ASSERT_FALSE(list_val.ApproxEquals(list_other_val, options));
ASSERT_FALSE(list_nan.ApproxEquals(list_val, options));
ASSERT_TRUE(list_nan.ApproxEquals(list_nan, options));
ASSERT_TRUE(list_nan.ApproxEquals(list_other_nan, options));
}

protected:
// Arrow data type under test, set in SetUp()
std::shared_ptr<DataType> type_;
// Operands created in SetUp(): 1, 1.1, and two NaN scalars
std::shared_ptr<Scalar> scalar_val_, scalar_other_, scalar_nan_, scalar_other_nan_;
};

// Instantiate the fixture for each type in RealArrowTypes (declared outside
// this excerpt). Every test below forwards to the fixture member function
// of the matching name.
TYPED_TEST_SUITE(TestRealScalar, RealArrowTypes);

TYPED_TEST(TestRealScalar, NanEquals) { this->TestNanEquals(); }

TYPED_TEST(TestRealScalar, ApproxEquals) { this->TestApproxEquals(); }

TYPED_TEST(TestRealScalar, StructOf) { this->TestStructOf(); }

TYPED_TEST(TestRealScalar, ListOf) { this->TestListOf(); }

TEST(TestDecimal128Scalar, Basics) {
auto ty = decimal128(3, 2);
auto pi = Decimal128Scalar(Decimal128("3.14"), ty);
Expand Down
19 changes: 13 additions & 6 deletions cpp/src/arrow/testing/gtest_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,20 @@ void AssertArraysApproxEqual(const Array& expected, const Array& actual, bool ve

void AssertScalarsEqual(const Scalar& expected, const Scalar& actual, bool verbose,
const EqualOptions& options) {
std::stringstream diff;
// ARROW-8956, ScalarEquals returns false when both are null
if (!expected.is_valid && !actual.is_valid) {
// We consider both being null to be equal in this function
return;
}
if (!expected.Equals(actual, options)) {
std::stringstream diff;
if (verbose) {
diff << "Expected:\n" << expected.ToString();
diff << "\nActual:\n" << actual.ToString();
}
FAIL() << diff.str();
}
}

void AssertScalarsApproxEqual(const Scalar& expected, const Scalar& actual, bool verbose,
const EqualOptions& options) {
if (!expected.ApproxEquals(actual, options)) {
std::stringstream diff;
if (verbose) {
diff << "Expected:\n" << expected.ToString();
diff << "\nActual:\n" << actual.ToString();
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/testing/gtest_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ ARROW_TESTING_EXPORT void AssertArraysApproxEqual(
ARROW_TESTING_EXPORT void AssertScalarsEqual(
const Scalar& expected, const Scalar& actual, bool verbose = false,
const EqualOptions& options = EqualOptions::Defaults());
// Like AssertScalarsEqual, but the comparison uses Scalar::ApproxEquals.
// When `verbose` is true, the string forms of the expected and actual
// scalars are written to the diagnostic output on mismatch.
ARROW_TESTING_EXPORT void AssertScalarsApproxEqual(
const Scalar& expected, const Scalar& actual, bool verbose = false,
const EqualOptions& options = EqualOptions::Defaults());
ARROW_TESTING_EXPORT void AssertBatchesEqual(const RecordBatch& expected,
const RecordBatch& actual,
bool check_metadata = false);
Expand Down