diff --git a/cpp/src/arrow/dataset/filter.cc b/cpp/src/arrow/dataset/filter.cc index 780ad296b0f..bb5d7b3a855 100644 --- a/cpp/src/arrow/dataset/filter.cc +++ b/cpp/src/arrow/dataset/filter.cc @@ -78,6 +78,8 @@ struct Comparison { }; }; +Result Compare(const Scalar& lhs, const Scalar& rhs); + struct CompareVisitor { template using ScalarType = typename TypeTraits::ScalarType; @@ -144,7 +146,11 @@ struct CompareVisitor { } Status Visit(const DictionaryType&) { - return Status::NotImplemented("comparison of scalars of type ", *lhs_.type); + ARROW_ASSIGN_OR_RAISE(auto lhs, + checked_cast(lhs_).GetEncodedValue()); + ARROW_ASSIGN_OR_RAISE(auto rhs, + checked_cast(rhs_).GetEncodedValue()); + return Compare(*lhs, *rhs).Value(&result_); } // defer comparison to ScalarType::value @@ -170,7 +176,7 @@ struct CompareVisitor { // Compare two scalars // if either is null, return is null -// TODO(bkietz) extract this to scalar.h +// TODO(bkietz) extract this to the scalar comparison kernels Result Compare(const Scalar& lhs, const Scalar& rhs) { if (!lhs.type->Equals(*rhs.type)) { return Status::TypeError("Cannot compare scalars of differing type: ", *lhs.type, diff --git a/cpp/src/arrow/dataset/filter_test.cc b/cpp/src/arrow/dataset/filter_test.cc index 3650e85aecd..c8327607da1 100644 --- a/cpp/src/arrow/dataset/filter_test.cc +++ b/cpp/src/arrow/dataset/filter_test.cc @@ -355,6 +355,27 @@ TEST_F(ExpressionsTest, ImplicitCast) { InsertImplicitCasts("nope"_ == 0.0, *schema_)); } +TEST_F(ExpressionsTest, ImplicitCastToDict) { + auto dict_type = dictionary(int8(), float64()); + ASSERT_OK_AND_ASSIGN(auto filter, + InsertImplicitCasts("a"_ == 1.5, Schema({field("a", dict_type)}))); + + auto encoded_scalar = std::make_shared( + DictionaryScalar::ValueType{MakeScalar(0), + ArrayFromJSON(float64(), "[1.5]")}, + dict_type); + + ASSERT_EQ(E{filter}, E{"a"_ == encoded_scalar}); + + for (int8_t i = 0; i < 5; ++i) { + auto partition_scalar = std::make_shared( + DictionaryScalar::ValueType{ + MakeScalar(i), ArrayFromJSON(float64(), "[0.0, 0.5, 1.0, 1.5, 2.0]")}, + dict_type); + ASSERT_EQ(E{filter->Assume("a"_ == partition_scalar)}, E{scalar(i == 3)}); + } +} + TEST_F(FilterTest, ImplicitCast) { ASSERT_OK_AND_ASSIGN(auto filter, InsertImplicitCasts("a"_ >= "1", Schema({field("a", int32())}))); diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index e12c29206cc..8c42f5049c5 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -517,53 +517,52 @@ Status CastImpl(const ScalarType& from, StringScalar* to) { }); } +struct CastImplVisitor { + Status NotImplemented() { + return Status::NotImplemented("cast to ", *to_type_, " from ", *from_.type); + } + + const Scalar& from_; + const std::shared_ptr& to_type_; + Scalar* out_; +}; + template -struct FromTypeVisitor { +struct FromTypeVisitor : CastImplVisitor { using ToScalar = typename TypeTraits::ScalarType; + FromTypeVisitor(const Scalar& from, const std::shared_ptr& to_type, + Scalar* out) + : CastImplVisitor{from, to_type, out} {} + template Status Visit(const FromType&) { return CastImpl(checked_cast::ScalarType&>(from_), - out_); + checked_cast(out_)); } // identity cast only for parameter free types template typename std::enable_if::is_parameter_free, Status>::type Visit( const ToType&) { - out_->value = checked_cast(from_).value; + checked_cast(out_)->value = checked_cast(from_).value; return Status::OK(); } - // null to any - Status Visit(const NullType&) { - return Status::Invalid("attempting to cast scalar of type null to ", *to_type_); - } - - Status Visit(const SparseUnionType&) { - return Status::NotImplemented("cast to ", *to_type_); - } - Status Visit(const DenseUnionType&) { - return Status::NotImplemented("cast to ", *to_type_); - } - Status Visit(const DictionaryType&) { - return Status::NotImplemented("cast to ", *to_type_); - } - Status Visit(const ExtensionType&) { - return Status::NotImplemented("cast to ", *to_type_); - } - - const Scalar& from_; - const std::shared_ptr& to_type_; - ToScalar* out_; + Status Visit(const NullType&) { return NotImplemented(); } + Status Visit(const SparseUnionType&) { return NotImplemented(); } + Status Visit(const DenseUnionType&) { return NotImplemented(); } + Status Visit(const DictionaryType&) { return NotImplemented(); } + Status Visit(const ExtensionType&) { return NotImplemented(); } }; -struct ToTypeVisitor { +struct ToTypeVisitor : CastImplVisitor { + ToTypeVisitor(const Scalar& from, const std::shared_ptr& to_type, Scalar* out) + : CastImplVisitor{from, to_type, out} {} + template Status Visit(const ToType&) { - using ToScalar = typename TypeTraits::ScalarType; - FromTypeVisitor unpack_from_type{from_, to_type_, - checked_cast(out_)}; + FromTypeVisitor unpack_from_type{from_, to_type_, out_}; return VisitTypeInline(*from_.type, &unpack_from_type); } @@ -574,22 +573,16 @@ struct ToTypeVisitor { return Status::OK(); } - Status Visit(const SparseUnionType&) { - return Status::NotImplemented("cast from ", *from_.type); - } - Status Visit(const DenseUnionType&) { - return Status::NotImplemented("cast from ", *from_.type); - } - Status Visit(const DictionaryType&) { - return Status::NotImplemented("cast from ", *from_.type); - } - Status Visit(const ExtensionType&) { - return Status::NotImplemented("cast from ", *from_.type); + Status Visit(const DictionaryType& dict_type) { + auto& out = checked_cast(out_)->value; + ARROW_ASSIGN_OR_RAISE(auto cast_value, from_.CastTo(dict_type.value_type())); + ARROW_ASSIGN_OR_RAISE(out.dictionary, MakeArrayFromScalar(*cast_value, 1)); + return Int32Scalar(0).CastTo(dict_type.index_type()).Value(&out.index); } - const Scalar& from_; - const std::shared_ptr& to_type_; - Scalar* out_; + Status Visit(const SparseUnionType&) { return NotImplemented(); } + Status Visit(const DenseUnionType&) { return NotImplemented(); } + Status Visit(const ExtensionType&) { return NotImplemented(); } }; } // namespace diff --git a/cpp/src/arrow/scalar_test.cc b/cpp/src/arrow/scalar_test.cc index 3bd7e67c6fc..8530bea5259 100644 --- a/cpp/src/arrow/scalar_test.cc +++ b/cpp/src/arrow/scalar_test.cc @@ -33,6 +33,7 @@ namespace arrow { using internal::checked_cast; +using internal::checked_pointer_cast; TEST(TestNullScalar, Basics) { NullScalar scalar; @@ -47,7 +48,7 @@ TEST(TestNullScalar, Basics) { template class TestNumericScalar : public ::testing::Test { public: - TestNumericScalar() {} + TestNumericScalar() = default; }; TYPED_TEST_SUITE(TestNumericScalar, NumericArrowTypes); @@ -198,7 +199,7 @@ TEST(TestStringScalar, MakeScalar) { ASSERT_EQ(StringScalar("three"), *three); // test Array.GetScalar - auto arr = ArrayFromJSON(utf8(), "[null, \"one\", \"two\"]"); + auto arr = ArrayFromJSON(utf8(), R"([null, "one", "two"])"); ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(0)); ASSERT_OK_AND_ASSIGN(auto one, arr->GetScalar(1)); ASSERT_OK_AND_ASSIGN(auto two, arr->GetScalar(2)); @@ -221,7 +222,7 @@ TEST(TestFixedSizeBinaryScalar, Basics) { // test Array.GetScalar auto ty = fixed_size_binary(3); - auto arr = ArrayFromJSON(ty, "[null, \"one\", \"two\"]"); + auto arr = ArrayFromJSON(ty, R"([null, "one", "two"])"); ASSERT_OK_AND_ASSIGN(auto null, arr->GetScalar(0)); ASSERT_OK_AND_ASSIGN(auto one, arr->GetScalar(1)); ASSERT_OK_AND_ASSIGN(auto two, arr->GetScalar(2)); @@ -581,9 +582,9 @@ TEST(TestStructScalar, FieldAccess) { } TEST(TestDictionaryScalar, Basics) { - auto CheckIndexType = [&](const std::shared_ptr& index_ty) { + for (auto index_ty : all_dictionary_index_types()) { auto ty = dictionary(index_ty, utf8()); - auto dict = ArrayFromJSON(utf8(), "[\"alpha\", \"beta\", \"gamma\"]"); + auto dict = ArrayFromJSON(utf8(), R"(["alpha", "beta", "gamma"])"); DictionaryScalar::ValueType alpha; ASSERT_OK_AND_ASSIGN(alpha.index, MakeScalar(index_ty, 0)); @@ -621,10 +622,34 @@ TEST(TestDictionaryScalar, Basics) { ASSERT_TRUE(first->Equals(scalar_gamma)); ASSERT_TRUE(second->Equals(scalar_alpha)); ASSERT_TRUE(last->Equals(scalar_null)); - }; + } +} - for (auto ty : all_dictionary_index_types()) { - CheckIndexType(ty); +TEST(TestDictionaryScalar, Cast) { + for (auto index_ty : all_dictionary_index_types()) { + auto ty = dictionary(index_ty, utf8()); + auto dict = checked_pointer_cast( + ArrayFromJSON(utf8(), R"(["alpha", "beta", "gamma"])")); + + for (int64_t i = 0; i < dict->length(); ++i) { + auto alpha = MakeScalar(dict->GetString(i)); + ASSERT_OK_AND_ASSIGN(auto cast_alpha, alpha->CastTo(ty)); + ASSERT_OK_AND_ASSIGN( + auto roundtripped_alpha, + checked_cast(*cast_alpha).GetEncodedValue()); + + ASSERT_OK_AND_ASSIGN(auto i_scalar, MakeScalar(index_ty, i)); + auto alpha_dict = DictionaryScalar({i_scalar, dict}, ty); + ASSERT_OK_AND_ASSIGN( + auto encoded_alpha, + checked_cast(alpha_dict).GetEncodedValue()); + + AssertScalarsEqual(*alpha, *roundtripped_alpha); + AssertScalarsEqual(*encoded_alpha, *roundtripped_alpha); + + // dictionaries differ, though encoded values are identical + ASSERT_FALSE(alpha_dict.Equals(cast_alpha)); + } } } @@ -641,7 +666,7 @@ TEST(TestSparseUnionScalar, Basics) { // test Array.GetScalar std::vector> children{ - ArrayFromJSON(utf8(), "[\"alpha\", \"\", \"beta\", null, \"gamma\"]"), + ArrayFromJSON(utf8(), R"(["alpha", "", "beta", null, "gamma"])"), ArrayFromJSON(uint64(), "[1, 2, 11, 22, null]")}; auto type_ids = ArrayFromJSON(int8(), "[0, 1, 0, 0, 1]"); @@ -697,7 +722,7 @@ TEST(TestDenseUnionScalar, Basics) { // test Array.GetScalar std::vector> children = { - ArrayFromJSON(utf8(), "[\"alpha\", \"beta\", null]"), + ArrayFromJSON(utf8(), R"(["alpha", "beta", null])"), ArrayFromJSON(uint64(), "[2, 3]")}; auto type_ids = ArrayFromJSON(int8(), "[0, 1, 0, 0, 1]"); diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index dabdf410587..a16b6dccd41 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -1068,6 +1068,12 @@ def test_partitioning_factory_dictionary(mockfs): expected = pa.array(['xxx'] * 5 + ['yyy'] * 5).dictionary_encode() assert actual.equals(expected) + # ARROW-9345 ensure filtering on the partition field works + table = factory.finish().to_table(filter=ds.field('key') == 'xxx') + actual = table.column('key').chunk(0) + expected = expected.slice(0, 5) + assert actual.equals(expected) + def test_partitioning_function(): schema = pa.schema([("year", pa.int16()), ("month", pa.int8())])