diff --git a/cpp/src/arrow/scalar-test.cc b/cpp/src/arrow/scalar-test.cc index 67af4fb2493..6bd20556a19 100644 --- a/cpp/src/arrow/scalar-test.cc +++ b/cpp/src/arrow/scalar-test.cc @@ -199,6 +199,73 @@ TEST(TestTimestampScalars, Basics) { ASSERT_FALSE(ts_val2.Equals(ts_null)); } +TEST(TestDurationScalars, Basics) { + auto type1 = duration(TimeUnit::MILLI); + auto type2 = duration(TimeUnit::SECOND); + + int64_t val1 = 1; + int64_t val2 = 2; + DurationScalar ts_val1(val1, type1); + DurationScalar ts_val2(val2, type2); + DurationScalar ts_null(val2, type1, false); + ASSERT_EQ(val1, ts_val1.value); + ASSERT_EQ(val2, ts_null.value); + + ASSERT_TRUE(ts_val1.type->Equals(*type1)); + ASSERT_TRUE(ts_val2.type->Equals(*type2)); + ASSERT_TRUE(ts_val1.is_valid); + ASSERT_FALSE(ts_null.is_valid); + ASSERT_TRUE(ts_null.type->Equals(*type1)); + + ASSERT_FALSE(ts_val1.Equals(ts_val2)); + ASSERT_FALSE(ts_val1.Equals(ts_null)); + ASSERT_FALSE(ts_val2.Equals(ts_null)); +} + +TEST(TestMonthIntervalScalars, Basics) { + auto type = month_interval(); + + int32_t val1 = 1; + int32_t val2 = 2; + MonthIntervalScalar ts_val1(val1, type); + MonthIntervalScalar ts_val2(val2, type); + MonthIntervalScalar ts_null(val2, type, false); + ASSERT_EQ(val1, ts_val1.value); + ASSERT_EQ(val2, ts_null.value); + + ASSERT_TRUE(ts_val1.type->Equals(*type)); + ASSERT_TRUE(ts_val2.type->Equals(*type)); + ASSERT_TRUE(ts_val1.is_valid); + ASSERT_FALSE(ts_null.is_valid); + ASSERT_TRUE(ts_null.type->Equals(*type)); + + ASSERT_FALSE(ts_val1.Equals(ts_val2)); + ASSERT_FALSE(ts_val1.Equals(ts_null)); + ASSERT_FALSE(ts_val2.Equals(ts_null)); +} + +TEST(TestDayTimeIntervalScalars, Basics) { + auto type = day_time_interval(); + + DayTimeIntervalType::DayMilliseconds val1 = {1, 1}; + DayTimeIntervalType::DayMilliseconds val2 = {2, 2}; + DayTimeIntervalScalar ts_val1(val1, type); + DayTimeIntervalScalar ts_val2(val2, type); + DayTimeIntervalScalar ts_null(val2, type, false); + ASSERT_EQ(val1, ts_val1.value); + ASSERT_EQ(val2, ts_null.value); + + ASSERT_TRUE(ts_val1.type->Equals(*type)); + ASSERT_TRUE(ts_val2.type->Equals(*type)); + ASSERT_TRUE(ts_val1.is_valid); + ASSERT_FALSE(ts_null.is_valid); + ASSERT_TRUE(ts_null.type->Equals(*type)); + + ASSERT_FALSE(ts_val1.Equals(ts_val2)); + ASSERT_FALSE(ts_val1.Equals(ts_null)); + ASSERT_FALSE(ts_val2.Equals(ts_null)); +} + // TODO test HalfFloatScalar } // namespace arrow diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 4bc9b92cd54..2e101b1ccc3 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -51,6 +51,30 @@ TimestampScalar::TimestampScalar(int64_t value, const std::shared_ptr& DCHECK_EQ(Type::TIMESTAMP, type->id()); } +DurationScalar::DurationScalar(int64_t value, const std::shared_ptr& type, + bool is_valid) + : internal::PrimitiveScalar{type, is_valid}, value(value) { + DCHECK_EQ(Type::DURATION, type->id()); +} + +MonthIntervalScalar::MonthIntervalScalar(int32_t value, + const std::shared_ptr& type, + bool is_valid) + : internal::PrimitiveScalar{type, is_valid}, value(value) { + DCHECK_EQ(Type::INTERVAL, type->id()); + DCHECK_EQ(IntervalType::MONTHS, + checked_cast(type.get())->interval_type()); +} + +DayTimeIntervalScalar::DayTimeIntervalScalar(DayTimeIntervalType::DayMilliseconds value, + const std::shared_ptr& type, + bool is_valid) + : internal::PrimitiveScalar{type, is_valid}, value(value) { + DCHECK_EQ(Type::INTERVAL, type->id()); + DCHECK_EQ(IntervalType::DAY_TIME, + checked_cast(type.get())->interval_type()); +} + FixedSizeBinaryScalar::FixedSizeBinaryScalar(const std::shared_ptr& value, const std::shared_ptr& type, bool is_valid) diff --git a/cpp/src/arrow/scalar.h b/cpp/src/arrow/scalar.h index 856660e2d2f..4f0589a2f5c 100644 --- a/cpp/src/arrow/scalar.h +++ b/cpp/src/arrow/scalar.h @@ -157,7 +157,7 @@ class ARROW_EXPORT MonthIntervalScalar : public internal::PrimitiveScalar { bool is_valid = true); }; -class ARROW_EXPORT DayTimeIntervalScalar : public Scalar { +class ARROW_EXPORT DayTimeIntervalScalar : public internal::PrimitiveScalar { public: DayTimeIntervalType::DayMilliseconds value; DayTimeIntervalScalar(DayTimeIntervalType::DayMilliseconds value, diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index a10175957fe..16f486f45f1 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -897,10 +897,10 @@ class ARROW_EXPORT DayTimeIntervalType : public IntervalType { struct DayMilliseconds { int32_t days; int32_t milliseconds; - bool operator==(DayMilliseconds other) { + bool operator==(DayMilliseconds other) const { return this->days == other.days && this->milliseconds == other.milliseconds; } - bool operator!=(DayMilliseconds other) { return !(*this == other); } + bool operator!=(DayMilliseconds other) const { return !(*this == other); } }; using c_type = DayMilliseconds; static_assert(sizeof(DayMilliseconds) == 8, diff --git a/cpp/src/arrow/visitor_inline.h b/cpp/src/arrow/visitor_inline.h index b9ade98c0a4..544763a2f74 100644 --- a/cpp/src/arrow/visitor_inline.h +++ b/cpp/src/arrow/visitor_inline.h @@ -264,6 +264,17 @@ template inline Status VisitScalarInline(const Scalar& scalar, VISITOR* visitor) { switch (scalar.type->id()) { ARROW_GENERATE_FOR_ALL_TYPES(SCALAR_VISIT_INLINE); + case Type::INTERVAL: { + const auto& interval_type = + internal::checked_cast(*scalar.type); + if (interval_type.interval_type() == IntervalType::MONTHS) { + return visitor->Visit(internal::checked_cast(scalar)); + } + if (interval_type.interval_type() == IntervalType::DAY_TIME) { + return visitor->Visit( + internal::checked_cast(scalar)); + } + } default: break; }