From dfcefbce3f06b98b5b05d2ab1d4704f25afea248 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 11:38:29 -0500 Subject: [PATCH 01/22] ARROW-14705: [C++] Implement more complete type unification --- cpp/src/arrow/dataset/discovery.cc | 2 +- cpp/src/arrow/dataset/discovery.h | 3 + cpp/src/arrow/type.cc | 214 +++++++++++++++++++++++++---- cpp/src/arrow/type.h | 20 +++ cpp/src/arrow/type_test.cc | 65 ++++++++- cpp/src/arrow/type_traits.h | 22 +++ 6 files changed, 296 insertions(+), 30 deletions(-) diff --git a/cpp/src/arrow/dataset/discovery.cc b/cpp/src/arrow/dataset/discovery.cc index 0f9d479b9d6..8b12f3ea815 100644 --- a/cpp/src/arrow/dataset/discovery.cc +++ b/cpp/src/arrow/dataset/discovery.cc @@ -43,7 +43,7 @@ Result> DatasetFactory::Inspect(InspectOptions options) return arrow::schema({}); } - return UnifySchemas(schemas); + return UnifySchemas(schemas, options.field_merge_options); } Result> DatasetFactory::Finish() { diff --git a/cpp/src/arrow/dataset/discovery.h b/cpp/src/arrow/dataset/discovery.h index 40c02051955..bd928ce57ff 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -58,6 +58,9 @@ struct InspectOptions { /// `kInspectAllFragments`. A value of `0` disables inspection of fragments /// altogether so only the partitioning schema will be inspected. int fragments = 1; + + /// Control how to unify types. 
+ Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults(); }; struct FinishOptions { diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 2a382662497..935d2e4ab1a 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -47,6 +47,8 @@ namespace arrow { +using internal::checked_cast; + constexpr Type::type NullType::type_id; constexpr Type::type ListType::type_id; constexpr Type::type LargeListType::type_id; @@ -216,27 +218,6 @@ std::shared_ptr GetPhysicalType(const std::shared_ptr& real_ return std::move(visitor.result); } -namespace { - -using internal::checked_cast; - -// Merges `existing` and `other` if one of them is of NullType, otherwise -// returns nullptr. -// - if `other` if of NullType or is nullable, the unified field will be nullable. -// - if `existing` is of NullType but other is not, the unified field will -// have `other`'s type and will be nullable -std::shared_ptr MaybePromoteNullTypes(const Field& existing, const Field& other) { - if (existing.type()->id() != Type::NA && other.type()->id() != Type::NA) { - return nullptr; - } - if (existing.type()->id() == Type::NA) { - return other.WithNullable(true)->WithMetadata(existing.metadata()); - } - // `other` must be null. 
- return existing.WithNullable(true); -} -} // namespace - Field::~Field() {} bool Field::HasMetadata() const { @@ -275,6 +256,176 @@ std::shared_ptr Field::WithNullable(const bool nullable) const { return std::make_shared(name_, type_, nullable, metadata_); } +namespace { +// Utilities for Field::MergeWith + +std::shared_ptr MakeSigned(const DataType& type) { + switch (type.id()) { + case Type::INT8: + case Type::UINT8: + return int8(); + case Type::INT16: + case Type::UINT16: + return int16(); + case Type::INT32: + case Type::UINT32: + return int32(); + case Type::INT64: + case Type::UINT64: + return int64(); + default: + DCHECK(false) << "unreachable"; + } + return std::shared_ptr(nullptr); +} +std::shared_ptr MakeBinary(const DataType& type) { + switch (type.id()) { + case Type::BINARY: + case Type::STRING: + return binary(); + case Type::LARGE_BINARY: + case Type::LARGE_STRING: + return large_binary(); + default: + DCHECK(false) << "unreachable"; + } + return std::shared_ptr(nullptr); +} + +std::shared_ptr MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options) { + bool promoted = false; + if (options.promote_nullability) { + if (promoted_type->id() == Type::NA) { + return other_type; + } else if (other_type->id() == Type::NA) { + return promoted_type; + } + } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) { + return nullptr; + } + + if (options.promote_integer_sign) { + if (is_unsigned_integer(promoted_type->id()) && is_signed_integer(other_type->id())) { + promoted = bit_width(other_type->id()) >= bit_width(promoted_type->id()); + promoted_type = MakeSigned(*promoted_type); + } else if (is_signed_integer(promoted_type->id()) && + is_unsigned_integer(other_type->id())) { + promoted = bit_width(promoted_type->id()) >= bit_width(other_type->id()); + other_type = MakeSigned(*other_type); + } + } + + if (options.promote_integer_float && + ((is_floating(promoted_type->id()) && 
is_integer(other_type->id())) || + (is_integer(promoted_type->id()) && is_floating(other_type->id())))) { + const int max_width = + std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); + if (max_width >= 64) { + promoted_type = float64(); + } else if (max_width >= 32) { + promoted_type = float32(); + } else { + promoted_type = float16(); + } + promoted = true; + } + + if (options.promote_numeric_width) { + const int max_width = + std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); + if (is_floating(promoted_type->id()) && is_floating(other_type->id())) { + if (max_width >= 64) { + promoted_type = float64(); + } else if (max_width >= 32) { + promoted_type = float32(); + } else { + promoted_type = float16(); + } + promoted = true; + } else if (is_signed_integer(promoted_type->id()) && + is_signed_integer(other_type->id())) { + if (max_width >= 64) { + promoted_type = int64(); + } else if (max_width >= 32) { + promoted_type = int32(); + } else if (max_width >= 16) { + promoted_type = int16(); + } else { + promoted_type = int8(); + } + promoted = true; + } else if (is_unsigned_integer(promoted_type->id()) && + is_unsigned_integer(other_type->id())) { + if (max_width >= 64) { + promoted_type = uint64(); + } else if (max_width >= 32) { + promoted_type = uint32(); + } else if (max_width >= 16) { + promoted_type = uint16(); + } else { + promoted_type = uint8(); + } + promoted = true; + } + } + + if (options.promote_binary) { + if (promoted_type->id() == Type::FIXED_SIZE_BINARY) { + promoted_type = binary(); + promoted = other_type->id() == Type::BINARY; + } + if (other_type->id() == Type::FIXED_SIZE_BINARY) { + other_type = binary(); + promoted = promoted_type->id() == Type::BINARY; + } + + if (is_string(promoted_type->id()) && is_binary(other_type->id())) { + promoted_type = MakeBinary(*promoted_type); + promoted = + offset_bit_width(promoted_type->id()) == offset_bit_width(other_type->id()); + } else if (is_binary(promoted_type->id()) 
&& is_string(other_type->id())) { + other_type = MakeBinary(*other_type); + promoted = + offset_bit_width(promoted_type->id()) == offset_bit_width(other_type->id()); + } + } + + if (options.promote_large) { + if ((promoted_type->id() == Type::STRING && other_type->id() == Type::LARGE_STRING) || + (promoted_type->id() == Type::LARGE_STRING && other_type->id() == Type::STRING)) { + promoted_type = large_utf8(); + promoted = true; + } else if ((promoted_type->id() == Type::BINARY && + other_type->id() == Type::LARGE_BINARY) || + (promoted_type->id() == Type::LARGE_BINARY && + other_type->id() == Type::BINARY)) { + promoted_type = large_binary(); + promoted = true; + } + } + + // TODO + // Date32 -> Date64 + // Timestamp units + // Time32 -> Time64 + // Decimal128 -> Decimal256 + // Integer -> Decimal + // Decimal -> Float + // List(A) -> List(B) + // List -> LargeList + // Unions? + // Dictionary: indices, values + // Struct: reconcile order, fields, types + // Map + // Fixed size list + // Duration units + + return promoted ? 
promoted_type : nullptr; +} +} // namespace + Result> Field::MergeWith(const Field& other, MergeOptions options) const { if (name() != other.name()) { @@ -286,14 +437,20 @@ Result> Field::MergeWith(const Field& other, return Copy(); } - if (options.promote_nullability) { - if (type()->Equals(other.type())) { - return Copy()->WithNullable(nullable() || other.nullable()); + auto promoted_type = MergeTypes(type_, other.type(), options); + if (promoted_type) { + bool nullable = nullable_; + if (options.promote_nullability) { + nullable = nullable || other.nullable() || type_->id() == Type::NA || + other.type()->id() == Type::NA; + } else if (nullable_ != other.nullable()) { + return Status::Invalid("Unable to merge: Field ", name(), + " has incompatible nullability: ", nullable_, " vs ", + other.nullable()); } - std::shared_ptr promoted = MaybePromoteNullTypes(*this, other); - if (promoted) return promoted; - } + return std::make_shared(name_, promoted_type, nullable, metadata_); + } return Status::Invalid("Unable to merge: Field ", name(), " has incompatible types: ", type()->ToString(), " vs ", other.type()->ToString()); @@ -1668,7 +1825,8 @@ class SchemaBuilder::Impl { if (policy_ == CONFLICT_REPLACE) { fields_[i] = field; } else if (policy_ == CONFLICT_MERGE) { - ARROW_ASSIGN_OR_RAISE(fields_[i], fields_[i]->MergeWith(field)); + ARROW_ASSIGN_OR_RAISE(fields_[i], + fields_[i]->MergeWith(field, field_merge_options_)); } return Status::OK(); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 463636b0537..cca9e5ea168 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -310,6 +310,26 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// nullable). bool promote_nullability = true; + /// Allow an integer, float, or decimal of a given bit width to be + /// promoted to an equivalent type of a greater bit width. 
+ bool promote_numeric_width = true; + + /// Allow an integer of a given bit width to be promoted to a + /// float of an equal or greater bit width. + bool promote_integer_float = true; + + /// Allow an unsigned integer of a given bit width to be promoted + /// to a signed integer of the same bit width. + bool promote_integer_sign = true; + + /// Allow a type to be promoted to the Large variant. + bool promote_large = true; + + /// Allow strings to be promoted to binary types. + bool promote_binary = true; + + // TODO: how do we want to handle decimal? + static MergeOptions Defaults() { return MergeOptions(); } }; diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index c7ac5f6c7f2..eeba2c95c15 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -978,6 +978,44 @@ class TestUnifySchemas : public TestSchema { << lhs_field->ToString() << " vs " << rhs_field->ToString(); } } + + void CheckUnify(const std::shared_ptr& field1, + const std::shared_ptr& field2, + const std::shared_ptr& expected) { + ARROW_SCOPED_TRACE("field2: ", field2->ToString()); + ARROW_SCOPED_TRACE("field1: ", field1->ToString()); + ASSERT_OK_AND_ASSIGN(auto merged1, field1->MergeWith(field2)); + ASSERT_OK_AND_ASSIGN(auto merged2, field2->MergeWith(field1)); + AssertFieldEqual(merged1, expected); + AssertFieldEqual(merged2, expected); + } + + void CheckUnify(const std::shared_ptr& left, + const std::shared_ptr& right, + const std::shared_ptr& expected) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckUnify(field1, field2, field("a", expected)); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", right, /*nullable=*/false); + CheckUnify(field1, field2, field("a", expected, /*nullable=*/false)); + + field1 = field("a", left); + field2 = field("a", right, /*nullable=*/false); + CheckUnify(field1, field2, field("a", expected, /*nullable=*/true)); + + field1 = field("a", left, /*nullable=*/false); + field2 = 
field("a", right); + CheckUnify(field1, field2, field("a", expected, /*nullable=*/true)); + } + + void CheckUnify(const std::shared_ptr& from, + const std::vector>& to) { + for (const auto& ty : to) { + CheckUnify(from, ty, ty); + } + } }; TEST_F(TestUnifySchemas, EmptyInput) { ASSERT_RAISES(Invalid, UnifySchemas({})); } @@ -1069,9 +1107,34 @@ TEST_F(TestUnifySchemas, MoreSchemas) { utf8_field->WithNullable(true)})); } +TEST_F(TestUnifySchemas, Numerics) { + CheckUnify(uint8(), {int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), + float32(), float64()}); + CheckUnify(int8(), {int16(), int32(), int64(), float32(), float64()}); + CheckUnify(uint16(), + {int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}); + CheckUnify(int16(), {int32(), int64(), float32(), float64()}); + CheckUnify(uint32(), {int32(), uint64(), int64(), float32(), float64()}); + CheckUnify(int32(), {int64(), float32(), float64()}); + CheckUnify(uint64(), {int64(), float64()}); + CheckUnify(int64(), {float64()}); + CheckUnify(float16(), {float32(), float64()}); + CheckUnify(float32(), {float64()}); + + // CheckUnifyFails(int8(), {uint8(), uint16(), uint32(), uint64()}); + // uint32 and int32/int64; uint64 and int64 +} + +TEST_F(TestUnifySchemas, Binary) { + CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}); + CheckUnify(binary(), {large_binary()}); + CheckUnify(fixed_size_binary(2), {binary(), large_binary()}); + CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary()); +} + TEST_F(TestUnifySchemas, IncompatibleTypes) { auto int32_field = field("f", int32()); - auto uint8_field = field("f", uint8(), false); + auto uint8_field = field("f", utf8(), false); auto schema1 = schema({int32_field}); auto schema2 = schema({uint8_field}); diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 4b4cb5d15d3..278997d8463 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -944,6 +944,28 @@ static inline bool 
is_large_binary_like(Type::type type_id) { return false; } +static inline bool is_binary(Type::type type_id) { + switch (type_id) { + case Type::BINARY: + case Type::LARGE_BINARY: + return true; + default: + break; + } + return false; +} + +static inline bool is_string(Type::type type_id) { + switch (type_id) { + case Type::STRING: + case Type::LARGE_STRING: + return true; + default: + break; + } + return false; +} + static inline bool is_dictionary(Type::type type_id) { return type_id == Type::DICTIONARY; } From 0b223be4811d5690bc64f89c4449e3d071f9edfa Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 12:02:15 -0500 Subject: [PATCH 02/22] ARROW-14705: [C++] Add remaining options --- cpp/src/arrow/type.cc | 46 +++++++++++++++++++++++++-- cpp/src/arrow/type.h | 56 +++++++++++++++++++++++++++----- cpp/src/arrow/type_test.cc | 65 +++++++++++++++++++++++--------------- 3 files changed, 131 insertions(+), 36 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 935d2e4ab1a..bd33ebe2ebe 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -256,6 +256,49 @@ std::shared_ptr Field::WithNullable(const bool nullable) const { return std::make_shared(name_, type_, nullable, metadata_); } +Field::MergeOptions Field::MergeOptions::Permissive() { + MergeOptions options = Defaults(); + options.promote_nullability = true; + options.promote_numeric_width = true; + options.promote_integer_float = true; + options.promote_integer_decimal = true; + options.promote_decimal_float = true; + options.increase_decimal_precision = true; + options.promote_date = true; + options.promote_time = true; + options.promote_duration_units = true; + options.promote_timestamp_units = true; + options.promote_nested = true; + options.promote_dictionary = true; + options.promote_integer_sign = true; + options.promote_large = true; + options.promote_binary = true; + return options; +} + +std::string Field::MergeOptions::ToString() const { + std::stringstream 
ss; + ss << "MergeOptions{"; + ss << "promote_nullability=" << (promote_nullability ? "true" : "false"); + ss << ", promote_numeric_width=" << (promote_numeric_width ? "true" : "false"); + ss << ", promote_integer_float=" << (promote_integer_float ? "true" : "false"); + ss << ", promote_integer_decimal=" << (promote_integer_decimal ? "true" : "false"); + ss << ", promote_decimal_float=" << (promote_decimal_float ? "true" : "false"); + ss << ", increase_decimal_precision=" + << (increase_decimal_precision ? "true" : "false"); + ss << ", promote_date=" << (promote_date ? "true" : "false"); + ss << ", promote_time=" << (promote_time ? "true" : "false"); + ss << ", promote_duration_units=" << (promote_duration_units ? "true" : "false"); + ss << ", promote_timestamp_units=" << (promote_timestamp_units ? "true" : "false"); + ss << ", promote_nested=" << (promote_nested ? "true" : "false"); + ss << ", promote_dictionary=" << (promote_dictionary ? "true" : "false"); + ss << ", promote_integer_sign=" << (promote_integer_sign ? "true" : "false"); + ss << ", promote_large=" << (promote_large ? "true" : "false"); + ss << ", promote_binary=" << (promote_binary ? "true" : "false"); + ss << '}'; + return ss.str(); +} + namespace { // Utilities for Field::MergeWith @@ -291,7 +334,6 @@ std::shared_ptr MakeBinary(const DataType& type) { } return std::shared_ptr(nullptr); } - std::shared_ptr MergeTypes(std::shared_ptr promoted_type, std::shared_ptr other_type, const Field::MergeOptions& options) { @@ -413,6 +455,7 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, // Decimal128 -> Decimal256 // Integer -> Decimal // Decimal -> Float + // Duration units // List(A) -> List(B) // List -> LargeList // Unions? @@ -420,7 +463,6 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, // Struct: reconcile order, fields, types // Map // Fixed size list - // Duration units return promoted ? 
promoted_type : nullptr; } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index cca9e5ea168..e4765e93f3d 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -303,7 +303,7 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// \brief Options that control the behavior of `MergeWith`. /// Options are to be added to allow type conversions, including integer /// widening, promotion from integer to float, or conversion to or from boolean. - struct MergeOptions { + struct MergeOptions : public util::ToStringOstreamable { /// If true, a Field of NullType can be unified with a Field of another type. /// The unified field will be of the other type and become nullable. /// Nullability will be promoted to the looser option (nullable if one is not @@ -312,25 +312,65 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// Allow an integer, float, or decimal of a given bit width to be /// promoted to an equivalent type of a greater bit width. - bool promote_numeric_width = true; + bool promote_numeric_width = false; /// Allow an integer of a given bit width to be promoted to a /// float of an equal or greater bit width. - bool promote_integer_float = true; + bool promote_integer_float = false; + + /// Allow an integer to be promoted to a decimal. + /// + /// May fail if the decimal has insufficient precision to + /// accomodate the integer. (See increase_decimal_precision.) + bool promote_integer_decimal = false; + + /// Allow a decimal to be promoted to a float. The float type will + /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). + bool promote_decimal_float = false; + + /// When promoting another type to a decimal, increase precision + /// (and possibly fail) to hold all possible values of the other type. + /// + /// For example: unifying int64 and decimal256(76, 70) will fail + /// if this is true since we need at least 19 digits to the left + /// of the decimal point but we are already at max precision. 
If + /// this is false, the unified type will be decimal128(38, 30). + bool increase_decimal_precision = false; + + /// Promote Date32 to Date64. + bool promote_date = false; + + /// Promote Time32 to Time64. + bool promote_time = false; + + /// Promote second to millisecond, etc. + bool promote_duration_units = false; + + /// Promote second to millisecond, etc. + bool promote_timestamp_units = false; + + /// Recursively merge nested types. + bool promote_nested = false; + + /// Promote dictionary index types to a common type, and unify the + /// value types. + bool promote_dictionary = false; /// Allow an unsigned integer of a given bit width to be promoted /// to a signed integer of the same bit width. - bool promote_integer_sign = true; + bool promote_integer_sign = false; + // TODO: does this include fixed size list? + // TODO: does this include fixed size binary? /// Allow a type to be promoted to the Large variant. - bool promote_large = true; + bool promote_large = false; /// Allow strings to be promoted to binary types. - bool promote_binary = true; - - // TODO: how do we want to handle decimal? + bool promote_binary = false; static MergeOptions Defaults() { return MergeOptions(); } + static MergeOptions Permissive(); + std::string ToString() const; }; /// \brief Merge the current field with a field of the same name. 
diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index eeba2c95c15..f7ea9b352be 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -981,39 +981,43 @@ class TestUnifySchemas : public TestSchema { void CheckUnify(const std::shared_ptr& field1, const std::shared_ptr& field2, - const std::shared_ptr& expected) { + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + ARROW_SCOPED_TRACE("options: ", options); ARROW_SCOPED_TRACE("field2: ", field2->ToString()); ARROW_SCOPED_TRACE("field1: ", field1->ToString()); - ASSERT_OK_AND_ASSIGN(auto merged1, field1->MergeWith(field2)); - ASSERT_OK_AND_ASSIGN(auto merged2, field2->MergeWith(field1)); + ASSERT_OK_AND_ASSIGN(auto merged1, field1->MergeWith(field2, options)); + ASSERT_OK_AND_ASSIGN(auto merged2, field2->MergeWith(field1, options)); AssertFieldEqual(merged1, expected); AssertFieldEqual(merged2, expected); } void CheckUnify(const std::shared_ptr& left, const std::shared_ptr& right, - const std::shared_ptr& expected) { + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { auto field1 = field("a", left); auto field2 = field("a", right); - CheckUnify(field1, field2, field("a", expected)); + CheckUnify(field1, field2, field("a", expected), options); field1 = field("a", left, /*nullable=*/false); field2 = field("a", right, /*nullable=*/false); - CheckUnify(field1, field2, field("a", expected, /*nullable=*/false)); + CheckUnify(field1, field2, field("a", expected, /*nullable=*/false), options); field1 = field("a", left); field2 = field("a", right, /*nullable=*/false); - CheckUnify(field1, field2, field("a", expected, /*nullable=*/true)); + CheckUnify(field1, field2, field("a", expected, /*nullable=*/true), options); field1 = field("a", left, /*nullable=*/false); field2 = field("a", right); - CheckUnify(field1, field2, field("a", expected, /*nullable=*/true)); + 
CheckUnify(field1, field2, field("a", expected, /*nullable=*/true), options); } void CheckUnify(const std::shared_ptr& from, - const std::vector>& to) { + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { for (const auto& ty : to) { - CheckUnify(from, ty, ty); + CheckUnify(from, ty, ty, options); } } }; @@ -1108,33 +1112,42 @@ TEST_F(TestUnifySchemas, MoreSchemas) { } TEST_F(TestUnifySchemas, Numerics) { - CheckUnify(uint8(), {int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), - float32(), float64()}); - CheckUnify(int8(), {int16(), int32(), int64(), float32(), float64()}); + CheckUnify(uint8(), + {int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), float32(), + float64()}, + Field::MergeOptions::Permissive()); + CheckUnify(int8(), {int16(), int32(), int64(), float32(), float64()}, + Field::MergeOptions::Permissive()); CheckUnify(uint16(), - {int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}); - CheckUnify(int16(), {int32(), int64(), float32(), float64()}); - CheckUnify(uint32(), {int32(), uint64(), int64(), float32(), float64()}); - CheckUnify(int32(), {int64(), float32(), float64()}); - CheckUnify(uint64(), {int64(), float64()}); - CheckUnify(int64(), {float64()}); - CheckUnify(float16(), {float32(), float64()}); - CheckUnify(float32(), {float64()}); + {int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}, + Field::MergeOptions::Permissive()); + CheckUnify(int16(), {int32(), int64(), float32(), float64()}, + Field::MergeOptions::Permissive()); + CheckUnify(uint32(), {int32(), uint64(), int64(), float32(), float64()}, + Field::MergeOptions::Permissive()); + CheckUnify(int32(), {int64(), float32(), float64()}, Field::MergeOptions::Permissive()); + CheckUnify(uint64(), {int64(), float64()}, Field::MergeOptions::Permissive()); + CheckUnify(int64(), {float64()}, Field::MergeOptions::Permissive()); + CheckUnify(float16(), {float32(), float64()}, 
Field::MergeOptions::Permissive()); + CheckUnify(float32(), {float64()}, Field::MergeOptions::Permissive()); // CheckUnifyFails(int8(), {uint8(), uint16(), uint32(), uint64()}); // uint32 and int32/int64; uint64 and int64 } TEST_F(TestUnifySchemas, Binary) { - CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}); - CheckUnify(binary(), {large_binary()}); - CheckUnify(fixed_size_binary(2), {binary(), large_binary()}); - CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary()); + CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}, + Field::MergeOptions::Permissive()); + CheckUnify(binary(), {large_binary()}, Field::MergeOptions::Permissive()); + CheckUnify(fixed_size_binary(2), {binary(), large_binary()}, + Field::MergeOptions::Permissive()); + CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), + Field::MergeOptions::Permissive()); } TEST_F(TestUnifySchemas, IncompatibleTypes) { auto int32_field = field("f", int32()); - auto uint8_field = field("f", utf8(), false); + auto uint8_field = field("f", uint8(), false); auto schema1 = schema({int32_field}); auto schema2 = schema({uint8_field}); From 7abdc13fba739dce5207e5ef2bda64d31a2c3270 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 12:22:11 -0500 Subject: [PATCH 03/22] ARROW-14705: [C++] Add expected failures --- cpp/src/arrow/type.cc | 4 +- cpp/src/arrow/type.h | 22 ++++---- cpp/src/arrow/type_test.cc | 111 +++++++++++++++++++++++++++++-------- 3 files changed, 101 insertions(+), 36 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index bd33ebe2ebe..f3848b33c09 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -413,7 +413,7 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, } } - if (options.promote_binary) { + if (options.promote_large) { if (promoted_type->id() == Type::FIXED_SIZE_BINARY) { promoted_type = binary(); promoted = other_type->id() == Type::BINARY; @@ -422,7 +422,9 @@ std::shared_ptr 
MergeTypes(std::shared_ptr promoted_type, other_type = binary(); promoted = promoted_type->id() == Type::BINARY; } + } + if (options.promote_binary) { if (is_string(promoted_type->id()) && is_binary(other_type->id())) { promoted_type = MakeBinary(*promoted_type); promoted = diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index e4765e93f3d..8c702ca5cf8 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -310,24 +310,24 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// nullable). bool promote_nullability = true; - /// Allow an integer, float, or decimal of a given bit width to be - /// promoted to an equivalent type of a greater bit width. - bool promote_numeric_width = false; + /// Allow a decimal to be promoted to a float. The float type will + /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). + bool promote_decimal_float = false; /// Allow an integer of a given bit width to be promoted to a /// float of an equal or greater bit width. bool promote_integer_float = false; + /// Allow an unsigned integer of a given bit width to be promoted + /// to a signed integer of the same bit width. + bool promote_integer_sign = false; + /// Allow an integer to be promoted to a decimal. /// /// May fail if the decimal has insufficient precision to /// accomodate the integer. (See increase_decimal_precision.) bool promote_integer_decimal = false; - /// Allow a decimal to be promoted to a float. The float type will - /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). - bool promote_decimal_float = false; - /// When promoting another type to a decimal, increase precision /// (and possibly fail) to hold all possible values of the other type. /// @@ -337,6 +337,10 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// this is false, the unified type will be decimal128(38, 30). 
bool increase_decimal_precision = false; + /// Allow an integer, float, or decimal of a given bit width to be + /// promoted to an equivalent type of a greater bit width. + bool promote_numeric_width = false; + /// Promote Date32 to Date64. bool promote_date = false; @@ -356,10 +360,6 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// value types. bool promote_dictionary = false; - /// Allow an unsigned integer of a given bit width to be promoted - /// to a signed integer of the same bit width. - bool promote_integer_sign = false; - // TODO: does this include fixed size list? // TODO: does this include fixed size binary? /// Allow a type to be promoted to the Large variant. diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index f7ea9b352be..bfb508042a0 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -992,6 +992,16 @@ class TestUnifySchemas : public TestSchema { AssertFieldEqual(merged2, expected); } + void CheckUnifyFails( + const std::shared_ptr& field1, const std::shared_ptr& field2, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + ARROW_SCOPED_TRACE("options: ", options); + ARROW_SCOPED_TRACE("field2: ", field2->ToString()); + ARROW_SCOPED_TRACE("field1: ", field1->ToString()); + ASSERT_RAISES(Invalid, field1->MergeWith(field2, options)); + ASSERT_RAISES(Invalid, field2->MergeWith(field1, options)); + } + void CheckUnify(const std::shared_ptr& left, const std::shared_ptr& right, const std::shared_ptr& expected, @@ -1013,6 +1023,14 @@ class TestUnifySchemas : public TestSchema { CheckUnify(field1, field2, field("a", expected, /*nullable=*/true), options); } + void CheckUnifyFails( + const std::shared_ptr& left, const std::shared_ptr& right, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckUnifyFails(field1, field2, options); + } + void CheckUnify(const std::shared_ptr& from, 
const std::vector>& to, const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { @@ -1020,6 +1038,24 @@ class TestUnifySchemas : public TestSchema { CheckUnify(from, ty, ty, options); } } + + void CheckUnifyFails( + const std::shared_ptr& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : to) { + CheckUnifyFails(from, ty, options); + } + } + + void CheckUnifyFails( + const std::vector>& from, + const std::vector>& to, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + for (const auto& ty : from) { + CheckUnifyFails(ty, to, options); + } + } }; TEST_F(TestUnifySchemas, EmptyInput) { ASSERT_RAISES(Invalid, UnifySchemas({})); } @@ -1111,38 +1147,65 @@ TEST_F(TestUnifySchemas, MoreSchemas) { utf8_field->WithNullable(true)})); } -TEST_F(TestUnifySchemas, Numerics) { +TEST_F(TestUnifySchemas, Numeric) { + auto options = Field::MergeOptions::Defaults(); + options.promote_numeric_width = true; + options.promote_integer_float = true; + options.promote_integer_sign = true; CheckUnify(uint8(), {int8(), uint16(), int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}, - Field::MergeOptions::Permissive()); - CheckUnify(int8(), {int16(), int32(), int64(), float32(), float64()}, - Field::MergeOptions::Permissive()); + options); + CheckUnify(int8(), {int16(), int32(), int64(), float32(), float64()}, options); CheckUnify(uint16(), {int16(), uint32(), int32(), uint64(), int64(), float32(), float64()}, - Field::MergeOptions::Permissive()); - CheckUnify(int16(), {int32(), int64(), float32(), float64()}, - Field::MergeOptions::Permissive()); - CheckUnify(uint32(), {int32(), uint64(), int64(), float32(), float64()}, - Field::MergeOptions::Permissive()); - CheckUnify(int32(), {int64(), float32(), float64()}, Field::MergeOptions::Permissive()); - CheckUnify(uint64(), {int64(), float64()}, Field::MergeOptions::Permissive()); - CheckUnify(int64(), {float64()}, 
Field::MergeOptions::Permissive()); - CheckUnify(float16(), {float32(), float64()}, Field::MergeOptions::Permissive()); - CheckUnify(float32(), {float64()}, Field::MergeOptions::Permissive()); - - // CheckUnifyFails(int8(), {uint8(), uint16(), uint32(), uint64()}); - // uint32 and int32/int64; uint64 and int64 + options); + CheckUnify(int16(), {int32(), int64(), float32(), float64()}, options); + CheckUnify(uint32(), {int32(), uint64(), int64(), float32(), float64()}, options); + CheckUnify(int32(), {int64(), float32(), float64()}, options); + CheckUnify(uint64(), {int64(), float64()}, options); + CheckUnify(int64(), {float64()}, options); + CheckUnify(float16(), {float32(), float64()}, options); + CheckUnify(float32(), {float64()}, options); + CheckUnify(uint64(), float32(), float64(), options); + CheckUnify(int64(), float32(), float64(), options); + + options.promote_integer_sign = false; + CheckUnify(uint8(), {uint16(), uint32(), uint64(), float32(), float64()}, options); + CheckUnifyFails(uint8(), {int8(), int16(), int32(), int64()}, options); + CheckUnify(uint16(), {uint32(), uint64(), float32(), float64()}, options); + CheckUnifyFails(uint16(), {int16(), int32(), int64()}, options); + CheckUnify(uint32(), {uint64(), float32(), float64()}, options); + CheckUnifyFails(uint32(), {int32(), int64()}, options); + CheckUnify(uint64(), {float64()}, options); + CheckUnifyFails(uint64(), {int64()}, options); + + options.promote_integer_sign = true; + options.promote_integer_float = false; + CheckUnifyFails(IntTypes(), FloatingPointTypes(), options); + + options.promote_integer_float = true; + options.promote_numeric_width = false; + CheckUnifyFails(int8(), {int16(), int32(), int64()}, options); + CheckUnifyFails(int16(), {int32(), int64()}, options); + CheckUnifyFails(int32(), {int64()}, options); } TEST_F(TestUnifySchemas, Binary) { - CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}, - Field::MergeOptions::Permissive()); - CheckUnify(binary(), 
{large_binary()}, Field::MergeOptions::Permissive()); - CheckUnify(fixed_size_binary(2), {binary(), large_binary()}, - Field::MergeOptions::Permissive()); - CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), - Field::MergeOptions::Permissive()); + auto options = Field::MergeOptions::Defaults(); + options.promote_large = true; + options.promote_binary = true; + CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}, options); + CheckUnify(binary(), {large_binary()}, options); + CheckUnify(fixed_size_binary(2), {binary(), large_binary()}, options); + CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), options); + + options.promote_large = false; + CheckUnifyFails({utf8(), binary()}, {large_utf8(), large_binary()}); + CheckUnifyFails(fixed_size_binary(2), BaseBinaryTypes()); + + options.promote_binary = false; + CheckUnifyFails(utf8(), {binary(), large_binary(), fixed_size_binary(2)}); } TEST_F(TestUnifySchemas, IncompatibleTypes) { From f993f90a05309ecb3b820435f7d3c88148e91a14 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 13:26:34 -0500 Subject: [PATCH 04/22] ARROW-14705: [C++] Add tests for unimplemented flags --- cpp/src/arrow/type.cc | 8 +-- cpp/src/arrow/type.h | 17 ++--- cpp/src/arrow/type_test.cc | 131 +++++++++++++++++++++++++++++++++++-- 3 files changed, 140 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index f3848b33c09..142beb0077b 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -266,8 +266,8 @@ Field::MergeOptions Field::MergeOptions::Permissive() { options.increase_decimal_precision = true; options.promote_date = true; options.promote_time = true; - options.promote_duration_units = true; - options.promote_timestamp_units = true; + options.promote_duration = true; + options.promote_timestamp = true; options.promote_nested = true; options.promote_dictionary = true; options.promote_integer_sign = true; @@ -288,8 +288,8 @@ std::string 
Field::MergeOptions::ToString() const { << (increase_decimal_precision ? "true" : "false"); ss << ", promote_date=" << (promote_date ? "true" : "false"); ss << ", promote_time=" << (promote_time ? "true" : "false"); - ss << ", promote_duration_units=" << (promote_duration_units ? "true" : "false"); - ss << ", promote_timestamp_units=" << (promote_timestamp_units ? "true" : "false"); + ss << ", promote_duration=" << (promote_duration ? "true" : "false"); + ss << ", promote_timestamp=" << (promote_timestamp ? "true" : "false"); ss << ", promote_nested=" << (promote_nested ? "true" : "false"); ss << ", promote_dictionary=" << (promote_dictionary ? "true" : "false"); ss << ", promote_integer_sign=" << (promote_integer_sign ? "true" : "false"); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 8c702ca5cf8..8dedbd49689 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -310,6 +310,9 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// nullable). bool promote_nullability = true; + /// Allow a decimal to be unified with another decimal of the same width. + bool promote_decimal = false; + /// Allow a decimal to be promoted to a float. The float type will /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). bool promote_decimal_float = false; @@ -332,9 +335,9 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// (and possibly fail) to hold all possible values of the other type. /// /// For example: unifying int64 and decimal256(76, 70) will fail - /// if this is true since we need at least 19 digits to the left - /// of the decimal point but we are already at max precision. If - /// this is false, the unified type will be decimal128(38, 30). + /// if this is true, since we need at least 19 digits to the left + /// of the decimal point, but we are already at max precision. If + /// this is false, the unified type will be decimal256(76, 70). 
bool increase_decimal_precision = false; /// Allow an integer, float, or decimal of a given bit width to be @@ -344,14 +347,14 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// Promote Date32 to Date64. bool promote_date = false; - /// Promote Time32 to Time64. + /// Promote Time32 to Time64, or Time32(SECOND) to Time32(MILLI), etc. bool promote_time = false; /// Promote second to millisecond, etc. - bool promote_duration_units = false; + bool promote_duration = false; /// Promote second to millisecond, etc. - bool promote_timestamp_units = false; + bool promote_timestamp = false; /// Recursively merge nested types. bool promote_nested = false; @@ -360,8 +363,6 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// value types. bool promote_dictionary = false; - // TODO: does this include fixed size list? - // TODO: does this include fixed size binary? /// Allow a type to be promoted to the Large variant. bool promote_large = false; diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index bfb508042a0..0c803c2ae58 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1171,13 +1171,15 @@ TEST_F(TestUnifySchemas, Numeric) { CheckUnify(int64(), float32(), float64(), options); options.promote_integer_sign = false; - CheckUnify(uint8(), {uint16(), uint32(), uint64(), float32(), float64()}, options); + CheckUnify(uint8(), {uint16(), uint32(), uint64()}, options); + CheckUnify(int8(), {int16(), int32(), int64()}, options); CheckUnifyFails(uint8(), {int8(), int16(), int32(), int64()}, options); - CheckUnify(uint16(), {uint32(), uint64(), float32(), float64()}, options); + CheckUnify(uint16(), {uint32(), uint64()}, options); + CheckUnify(int16(), {int32(), int64()}, options); CheckUnifyFails(uint16(), {int16(), int32(), int64()}, options); - CheckUnify(uint32(), {uint64(), float32(), float64()}, options); + CheckUnify(uint32(), {uint64()}, options); + CheckUnify(int32(), {int64()}, options); 
CheckUnifyFails(uint32(), {int32(), int64()}, options); - CheckUnify(uint64(), {float64()}, options); CheckUnifyFails(uint64(), {int64()}, options); options.promote_integer_sign = true; @@ -1191,6 +1193,127 @@ TEST_F(TestUnifySchemas, Numeric) { CheckUnifyFails(int32(), {int64()}, options); } +TEST_F(TestUnifySchemas, Decimal) { + auto options = Field::MergeOptions::Defaults(); + + options.promote_decimal = true; + CheckUnify(decimal128(3, 2), decimal128(5, 2), decimal128(5, 2), options); + CheckUnify(decimal128(3, 2), decimal128(5, 3), decimal128(5, 3), options); + CheckUnify(decimal128(3, 2), decimal128(5, 1), decimal128(6, 2), options); + CheckUnify(decimal128(3, 2), decimal128(5, -2), decimal128(9, 2), options); + CheckUnify(decimal128(3, -2), decimal128(5, -2), decimal128(5, -2), options); + + options.promote_decimal_float = true; + CheckUnify(decimal128(3, 2), {float32(), float64()}, options); + CheckUnify(decimal256(3, 2), {float32(), float64()}, options); + + options.promote_integer_decimal = true; + CheckUnify(int32(), decimal128(3, 2), decimal128(3, 2), options); + CheckUnify(int32(), decimal128(3, -2), decimal128(3, -2), options); + + options.increase_decimal_precision = true; + // int32() is essentially decimal128(10, 0) + CheckUnify(int32(), decimal128(3, 2), decimal128(12, 2), options); + CheckUnify(int32(), decimal128(3, -2), decimal128(10, 0), options); + + options.promote_numeric_width = true; + CheckUnify(decimal128(3, 2), decimal256(5, 2), decimal256(5, 2), options); + CheckUnify(int32(), decimal128(38, 37), decimal256(47, 37), options); +} + +TEST_F(TestUnifySchemas, Temporal) { + auto options = Field::MergeOptions::Defaults(); + + options.promote_date = true; + CheckUnify(date32(), {date64()}, options); + + options.promote_time = true; + CheckUnify(time32(TimeUnit::SECOND), + {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO), time64(TimeUnit::NANO)}, + options); + CheckUnify(time32(TimeUnit::MILLI), {time64(TimeUnit::MICRO), 
time64(TimeUnit::NANO)}, + options); + CheckUnify(time64(TimeUnit::MICRO), {time64(TimeUnit::NANO)}, options); + + options.promote_duration = true; + CheckUnify( + duration(TimeUnit::SECOND), + {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO), duration(TimeUnit::NANO)}, + options); + CheckUnify(duration(TimeUnit::MILLI), + {duration(TimeUnit::MICRO), duration(TimeUnit::NANO)}, options); + CheckUnify(duration(TimeUnit::MICRO), {duration(TimeUnit::NANO)}, options); + + options.promote_timestamp = true; + CheckUnify( + timestamp(TimeUnit::SECOND), + {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)}, + options); + CheckUnify(timestamp(TimeUnit::MILLI), + {timestamp(TimeUnit::MICRO), timestamp(TimeUnit::NANO)}, options); + CheckUnify(timestamp(TimeUnit::MICRO), {timestamp(TimeUnit::NANO)}, options); + + CheckUnifyFails(timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC"), + options); +} + +TEST_F(TestUnifySchemas, List) { + auto options = Field::MergeOptions::Defaults(); + options.promote_numeric_width = true; + + options.promote_large = true; + CheckUnify(list(int8()), {large_list(int8())}, options); + CheckUnify(fixed_size_list(int8(), 2), {list(int8()), large_list(int8())}, options); + + options.promote_nested = true; + CheckUnify(list(int8()), {list(int16()), list(int32()), list(int64())}, options); + CheckUnify(fixed_size_list(int8(), 2), {list(int16()), list(int32()), list(int64())}, + options); +} + +TEST_F(TestUnifySchemas, Map) { + auto options = Field::MergeOptions::Defaults(); + options.promote_nested = true; + options.promote_numeric_width = true; + + CheckUnify(map(int8(), int32()), + {map(int8(), int64()), map(int16(), int32()), map(int64(), int64())}, + options); +} + +TEST_F(TestUnifySchemas, Struct) { + auto options = Field::MergeOptions::Defaults(); + options.promote_nested = true; + options.promote_numeric_width = true; + options.promote_binary = true; + + CheckUnify(struct_({}), struct_({field("a", 
int8())}), struct_({field("a", int8())}), + options); + + CheckUnify(struct_({field("b", utf8())}), struct_({field("a", int8())}), + struct_({field("b", utf8()), field("a", int8())}), options); + + CheckUnify(struct_({field("b", utf8())}), struct_({field("b", binary())}), + struct_({field("b", binary())}), options); + + CheckUnify(struct_({field("a", int8()), field("b", utf8())}), + struct_({field("b", utf8()), field("a", int8())}), + struct_({field("a", int8()), field("b", utf8())}), options); +} + +TEST_F(TestUnifySchemas, Dictionary) { + auto options = Field::MergeOptions::Defaults(); + options.promote_dictionary = true; + options.promote_large = true; + + CheckUnify(dictionary(int8(), utf8()), + { + dictionary(int64(), utf8()), + dictionary(int8(), large_utf8()), + }, + options); +} + TEST_F(TestUnifySchemas, Binary) { auto options = Field::MergeOptions::Defaults(); options.promote_large = true; From 2dc904aa65cb9884cb1f116683335de04d181b7e Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 13:43:11 -0500 Subject: [PATCH 05/22] ARROW-14705: [C++] Implement dictionary merging --- cpp/src/arrow/type.cc | 18 ++++++++++++++++++ cpp/src/arrow/type.h | 5 +++++ cpp/src/arrow/type_test.cc | 14 ++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 142beb0077b..c6af833e72b 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -337,6 +337,8 @@ std::shared_ptr MakeBinary(const DataType& type) { std::shared_ptr MergeTypes(std::shared_ptr promoted_type, std::shared_ptr other_type, const Field::MergeOptions& options) { + if (promoted_type->Equals(*other_type)) return promoted_type; + bool promoted = false; if (options.promote_nullability) { if (promoted_type->id() == Type::NA) { @@ -348,6 +350,22 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, return nullptr; } + if (options.promote_dictionary && is_dictionary(promoted_type->id()) && + is_dictionary(other_type->id())) { + const 
auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_dictionary_ordered && left.ordered() != right.ordered()) { + return nullptr; + } + Field::MergeOptions index_options = options; + index_options.promote_integer_sign = true; + index_options.promote_numeric_width = true; + auto indices = MergeTypes(left.index_type(), right.index_type(), index_options); + auto values = MergeTypes(left.value_type(), right.value_type(), options); + auto ordered = left.ordered() && right.ordered(); + return (indices && values) ? dictionary(indices, values, ordered) : nullptr; + } + if (options.promote_integer_sign) { if (is_unsigned_integer(promoted_type->id()) && is_signed_integer(other_type->id())) { promoted = bit_width(other_type->id()) >= bit_width(promoted_type->id()); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 8dedbd49689..ac1d84a3083 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -363,6 +363,11 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// value types. bool promote_dictionary = false; + /// Allow merging ordered and non-ordered dictionaries, else + /// error. The result will be ordered if and only if both inputs + /// are ordered. + bool promote_dictionary_ordered = false; + /// Allow a type to be promoted to the Large variant. 
bool promote_large = false; diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 0c803c2ae58..a66d48d2cf2 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1269,6 +1269,8 @@ TEST_F(TestUnifySchemas, List) { CheckUnify(list(int8()), {list(int16()), list(int32()), list(int64())}, options); CheckUnify(fixed_size_list(int8(), 2), {list(int16()), list(int32()), list(int64())}, options); + + // TODO: test nonstandard field names } TEST_F(TestUnifySchemas, Map) { @@ -1312,6 +1314,18 @@ TEST_F(TestUnifySchemas, Dictionary) { dictionary(int8(), large_utf8()), }, options); + CheckUnify(dictionary(int8(), utf8(), /*ordered=*/true), + { + dictionary(int64(), utf8(), /*ordered=*/true), + dictionary(int8(), large_utf8(), /*ordered=*/true), + }, + options); + CheckUnifyFails(dictionary(int8(), utf8()), + dictionary(int8(), utf8(), /*ordered=*/true), options); + + options.promote_dictionary_ordered = true; + CheckUnify(dictionary(int8(), utf8()), dictionary(int8(), utf8(), /*ordered=*/true), + dictionary(int8(), utf8(), /*ordered=*/false), options); } TEST_F(TestUnifySchemas, Binary) { From daf3392ea394ef8920b030126f1eacf5b40f63cc Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 14:15:47 -0500 Subject: [PATCH 06/22] ARROW-14705: [C++] Implement decimals --- cpp/src/arrow/type.cc | 50 +++++++++++++++++++++++++++++++++++--- cpp/src/arrow/type.h | 13 +++------- cpp/src/arrow/type_test.cc | 15 ++++++------ 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index c6af833e72b..276149712b5 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -263,7 +263,6 @@ Field::MergeOptions Field::MergeOptions::Permissive() { options.promote_integer_float = true; options.promote_integer_decimal = true; options.promote_decimal_float = true; - options.increase_decimal_precision = true; options.promote_date = true; options.promote_time = true; options.promote_duration = 
true; @@ -284,8 +283,6 @@ std::string Field::MergeOptions::ToString() const { ss << ", promote_integer_float=" << (promote_integer_float ? "true" : "false"); ss << ", promote_integer_decimal=" << (promote_integer_decimal ? "true" : "false"); ss << ", promote_decimal_float=" << (promote_decimal_float ? "true" : "false"); - ss << ", increase_decimal_precision=" - << (increase_decimal_precision ? "true" : "false"); ss << ", promote_date=" << (promote_date ? "true" : "false"); ss << ", promote_time=" << (promote_time ? "true" : "false"); ss << ", promote_duration=" << (promote_duration ? "true" : "false"); @@ -363,9 +360,56 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, auto indices = MergeTypes(left.index_type(), right.index_type(), index_options); auto values = MergeTypes(left.value_type(), right.value_type(), options); auto ordered = left.ordered() && right.ordered(); + // TODO: make this return Result so we can report a more detailed error return (indices && values) ? dictionary(indices, values, ordered) : nullptr; } + if (options.promote_decimal_float) { + if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { + promoted_type = other_type; + promoted = true; + } else if (is_floating(promoted_type->id()) && is_decimal(other_type->id())) { + other_type = promoted_type; + promoted = true; + } + } + + if (options.promote_integer_decimal) { + if (is_integer(promoted_type->id()) && is_decimal(other_type->id())) { + promoted_type.swap(other_type); + } + + if (is_decimal(promoted_type->id()) && is_integer(other_type->id())) { + int32_t precision = 0; + if (!MaxDecimalDigitsForInteger(other_type->id()).Value(&precision).ok()) { + return nullptr; + } + // TODO: return result and use DecimalType::Make + other_type = promoted_type->id() == Type::DECIMAL128 ? 
decimal128(precision, 0) + : decimal256(precision, 0); + promoted = true; + } + } + + if (options.promote_decimal && is_decimal(promoted_type->id()) && + is_decimal(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_numeric_width && left.bit_width() != right.bit_width()) + return nullptr; + const int32_t max_scale = std::max(left.scale(), right.scale()); + const int32_t common_precision = + std::max(left.precision() + max_scale - left.scale(), + right.precision() + max_scale - right.scale()); + // TODO: return result and use DecimalType::Make + if (left.id() == Type::DECIMAL256 || right.id() == Type::DECIMAL256 || + (options.promote_numeric_width && + common_precision > BasicDecimal128::kMaxPrecision)) { + return decimal256(common_precision, max_scale); + } + return decimal128(common_precision, max_scale); + } + if (options.promote_integer_sign) { if (is_unsigned_integer(promoted_type->id()) && is_signed_integer(other_type->id())) { promoted = bit_width(other_type->id()) >= bit_width(promoted_type->id()); diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index ac1d84a3083..fa0679179a9 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -310,7 +310,9 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// nullable). bool promote_nullability = true; - /// Allow a decimal to be unified with another decimal of the same width. + /// Allow a decimal to be unified with another decimal of the same + /// width, adjusting scale and precision as appropriate. May fail + /// if the adjustment is not possible. bool promote_decimal = false; /// Allow a decimal to be promoted to a float. The float type will @@ -331,15 +333,6 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// accomodate the integer. (See increase_decimal_precision.) 
bool promote_integer_decimal = false; - /// When promoting another type to a decimal, increase precision - /// (and possibly fail) to hold all possible values of the other type. - /// - /// For example: unifying int64 and decimal256(76, 70) will fail - /// if this is true, since we need at least 19 digits to the left - /// of the decimal point, but we are already at max precision. If - /// this is false, the unified type will be decimal256(76, 70). - bool increase_decimal_precision = false; - /// Allow an integer, float, or decimal of a given bit width to be /// promoted to an equivalent type of a greater bit width. bool promote_numeric_width = false; diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index a66d48d2cf2..65ada1c5fd8 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1196,13 +1196,6 @@ TEST_F(TestUnifySchemas, Numeric) { TEST_F(TestUnifySchemas, Decimal) { auto options = Field::MergeOptions::Defaults(); - options.promote_decimal = true; - CheckUnify(decimal128(3, 2), decimal128(5, 2), decimal128(5, 2), options); - CheckUnify(decimal128(3, 2), decimal128(5, 3), decimal128(5, 3), options); - CheckUnify(decimal128(3, 2), decimal128(5, 1), decimal128(6, 2), options); - CheckUnify(decimal128(3, 2), decimal128(5, -2), decimal128(9, 2), options); - CheckUnify(decimal128(3, -2), decimal128(5, -2), decimal128(5, -2), options); - options.promote_decimal_float = true; CheckUnify(decimal128(3, 2), {float32(), float64()}, options); CheckUnify(decimal256(3, 2), {float32(), float64()}, options); @@ -1211,7 +1204,13 @@ TEST_F(TestUnifySchemas, Decimal) { CheckUnify(int32(), decimal128(3, 2), decimal128(3, 2), options); CheckUnify(int32(), decimal128(3, -2), decimal128(3, -2), options); - options.increase_decimal_precision = true; + options.promote_decimal = true; + CheckUnify(decimal128(3, 2), decimal128(5, 2), decimal128(5, 2), options); + CheckUnify(decimal128(3, 2), decimal128(5, 3), decimal128(5, 3), options); + 
CheckUnify(decimal128(3, 2), decimal128(5, 1), decimal128(6, 2), options); + CheckUnify(decimal128(3, 2), decimal128(5, -2), decimal128(9, 2), options); + CheckUnify(decimal128(3, -2), decimal128(5, -2), decimal128(5, -2), options); + // int32() is essentially decimal128(10, 0) CheckUnify(int32(), decimal128(3, 2), decimal128(12, 2), options); CheckUnify(int32(), decimal128(3, -2), decimal128(10, 0), options); From a2547f85218f5cd5293af47fcef8028418e3ecd7 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 14:22:39 -0500 Subject: [PATCH 07/22] ARROW-14705: [C++] Report better errors --- cpp/src/arrow/type.cc | 46 +++++++++++++++++++++++--------------- cpp/src/arrow/type_test.cc | 5 +++++ 2 files changed, 33 insertions(+), 18 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 276149712b5..1e27f53de14 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -331,9 +331,9 @@ std::shared_ptr MakeBinary(const DataType& type) { } return std::shared_ptr(nullptr); } -std::shared_ptr MergeTypes(std::shared_ptr promoted_type, - std::shared_ptr other_type, - const Field::MergeOptions& options) { +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options) { if (promoted_type->Equals(*other_type)) return promoted_type; bool promoted = false; @@ -344,7 +344,7 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, return promoted_type; } } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) { - return nullptr; + return Status::Invalid("Cannot merge type with null unless promote_nullability=true"); } if (options.promote_dictionary && is_dictionary(promoted_type->id()) && @@ -357,11 +357,17 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, Field::MergeOptions index_options = options; index_options.promote_integer_sign = true; index_options.promote_numeric_width = true; - auto indices = MergeTypes(left.index_type(), right.index_type(), 
index_options); - auto values = MergeTypes(left.value_type(), right.value_type(), options); + ARROW_ASSIGN_OR_RAISE( + auto indices, MergeTypes(left.index_type(), right.index_type(), index_options)); + ARROW_ASSIGN_OR_RAISE(auto values, + MergeTypes(left.value_type(), right.value_type(), options)); auto ordered = left.ordered() && right.ordered(); - // TODO: make this return Result so we can report a more detailed error - return (indices && values) ? dictionary(indices, values, ordered) : nullptr; + if (indices && values) { + return dictionary(indices, values, ordered); + } + return Status::Invalid( + "Cannot merge ordered and unordered dictionary unless " + "promote_dictionary_ordered=true"); } if (options.promote_decimal_float) { @@ -380,13 +386,10 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, } if (is_decimal(promoted_type->id()) && is_integer(other_type->id())) { - int32_t precision = 0; - if (!MaxDecimalDigitsForInteger(other_type->id()).Value(&precision).ok()) { - return nullptr; - } - // TODO: return result and use DecimalType::Make - other_type = promoted_type->id() == Type::DECIMAL128 ? 
decimal128(precision, 0) - : decimal256(precision, 0); + ARROW_ASSIGN_OR_RAISE(const int32_t precision, + MaxDecimalDigitsForInteger(other_type->id())); + ARROW_ASSIGN_OR_RAISE(other_type, + DecimalType::Make(promoted_type->id(), precision, 0)); promoted = true; } } @@ -405,9 +408,9 @@ std::shared_ptr MergeTypes(std::shared_ptr promoted_type, if (left.id() == Type::DECIMAL256 || right.id() == Type::DECIMAL256 || (options.promote_numeric_width && common_precision > BasicDecimal128::kMaxPrecision)) { - return decimal256(common_precision, max_scale); + return DecimalType::Make(Type::DECIMAL256, common_precision, max_scale); } - return decimal128(common_precision, max_scale); + return DecimalType::Make(Type::DECIMAL128, common_precision, max_scale); } if (options.promote_integer_sign) { @@ -543,7 +546,14 @@ Result> Field::MergeWith(const Field& other, return Copy(); } - auto promoted_type = MergeTypes(type_, other.type(), options); + auto maybe_promoted_type = MergeTypes(type_, other.type(), options); + if (!maybe_promoted_type.ok()) { + return maybe_promoted_type.status().WithMessage( + "Unable to merge: Field ", name(), + " has incompatible types: ", type()->ToString(), " vs ", other.type()->ToString(), + ": ", maybe_promoted_type.status().message()); + } + auto promoted_type = move(maybe_promoted_type).MoveValueUnsafe(); if (promoted_type) { bool nullable = nullable_; if (options.promote_nullability) { diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 65ada1c5fd8..6ad279c759b 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1215,9 +1215,14 @@ TEST_F(TestUnifySchemas, Decimal) { CheckUnify(int32(), decimal128(3, 2), decimal128(12, 2), options); CheckUnify(int32(), decimal128(3, -2), decimal128(10, 0), options); + CheckUnifyFails(decimal256(1, 0), decimal128(1, 0), options); + CheckUnifyFails(int64(), decimal128(38, 37), options); + options.promote_numeric_width = true; CheckUnify(decimal128(3, 2), decimal256(5, 2), 
decimal256(5, 2), options); CheckUnify(int32(), decimal128(38, 37), decimal256(47, 37), options); + + CheckUnifyFails(int64(), decimal256(76, 75), options); } TEST_F(TestUnifySchemas, Temporal) { From 34020dc442ad31ec6524f02d432a668afaff6851 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 14:24:55 -0500 Subject: [PATCH 08/22] ARROW-14705: [C++] Update TODOs --- cpp/src/arrow/type.cc | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 1e27f53de14..629a718fafc 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -517,19 +517,14 @@ Result> MergeTypes(std::shared_ptr promoted_ // TODO // Date32 -> Date64 + // Time32 -> Time64 (and units) // Timestamp units - // Time32 -> Time64 - // Decimal128 -> Decimal256 - // Integer -> Decimal - // Decimal -> Float // Duration units - // List(A) -> List(B) + // List(A) -> List(B) (incl. fixed) // List -> LargeList - // Unions? // Dictionary: indices, values // Struct: reconcile order, fields, types // Map - // Fixed size list return promoted ? 
promoted_type : nullptr; } From 651c4ccec2bc2e302bc5601c631033824f2ad872 Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 15:09:41 -0500 Subject: [PATCH 09/22] ARROW-14705: [C++] Implement temporal types --- cpp/src/arrow/type.cc | 60 +++++++++++++++++++++++++++++++++---- cpp/src/arrow/type_test.cc | 2 ++ cpp/src/arrow/type_traits.h | 11 +++++++ 3 files changed, 67 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 629a718fafc..442b4de365e 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -331,6 +331,16 @@ std::shared_ptr MakeBinary(const DataType& type) { } return std::shared_ptr(nullptr); } +TimeUnit::type CommonTimeUnit(TimeUnit::type left, TimeUnit::type right) { + if (left == TimeUnit::NANO || right == TimeUnit::NANO) { + return TimeUnit::NANO; + } else if (left == TimeUnit::MICRO || right == TimeUnit::MICRO) { + return TimeUnit::MICRO; + } else if (left == TimeUnit::MILLI || right == TimeUnit::MILLI) { + return TimeUnit::MILLI; + } + return TimeUnit::SECOND; +} Result> MergeTypes(std::shared_ptr promoted_type, std::shared_ptr other_type, const Field::MergeOptions& options) { @@ -370,6 +380,50 @@ Result> MergeTypes(std::shared_ptr promoted_ "promote_dictionary_ordered=true"); } + if (options.promote_date) { + if (promoted_type->id() == Type::DATE32 && other_type->id() == Type::DATE64) { + return date64(); + } + if (promoted_type->id() == Type::DATE64 && other_type->id() == Type::DATE32) { + return date64(); + } + } + + if (options.promote_duration) { + if (promoted_type->id() == Type::DURATION && other_type->id() == Type::DURATION) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + return duration(CommonTimeUnit(left.unit(), right.unit())); + } + } + + if (options.promote_time) { + if (is_time(promoted_type->id()) && is_time(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = 
checked_cast(*other_type); + const auto unit = CommonTimeUnit(left.unit(), right.unit()); + if (unit == TimeUnit::MICRO || unit == TimeUnit::NANO) { + return time64(unit); + } + return time32(unit); + } + } + + if (options.promote_timestamp) { + if (promoted_type->id() == Type::TIMESTAMP && other_type->id() == Type::TIMESTAMP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (left.timezone().empty() ^ right.timezone().empty()) { + return Status::Invalid( + "Cannot merge timestamp with timezone and timestamp without timezone"); + } + if (left.timezone() != right.timezone()) { + return Status::Invalid("Cannot merge timestamps with differing timezones"); + } + return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); + } + } + if (options.promote_decimal_float) { if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { promoted_type = other_type; @@ -404,7 +458,6 @@ Result> MergeTypes(std::shared_ptr promoted_ const int32_t common_precision = std::max(left.precision() + max_scale - left.scale(), right.precision() + max_scale - right.scale()); - // TODO: return result and use DecimalType::Make if (left.id() == Type::DECIMAL256 || right.id() == Type::DECIMAL256 || (options.promote_numeric_width && common_precision > BasicDecimal128::kMaxPrecision)) { @@ -516,13 +569,8 @@ Result> MergeTypes(std::shared_ptr promoted_ } // TODO - // Date32 -> Date64 - // Time32 -> Time64 (and units) - // Timestamp units - // Duration units // List(A) -> List(B) (incl. 
fixed) // List -> LargeList - // Dictionary: indices, values // Struct: reconcile order, fields, types // Map diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 6ad279c759b..94f25d1c7d6 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1259,6 +1259,8 @@ TEST_F(TestUnifySchemas, Temporal) { CheckUnifyFails(timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND, "UTC"), options); + CheckUnifyFails(timestamp(TimeUnit::SECOND, "America/New_York"), + timestamp(TimeUnit::SECOND, "UTC"), options); } TEST_F(TestUnifySchemas, List) { diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 278997d8463..aaccaa53c14 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -879,6 +879,17 @@ static inline bool is_decimal(Type::type type_id) { return false; } +static inline bool is_time(Type::type type_id) { + switch (type_id) { + case Type::TIME32: + case Type::TIME64: + return true; + default: + break; + } + return false; +} + static inline bool is_primitive(Type::type type_id) { switch (type_id) { case Type::BOOL: From 940cfe5db68ceac6e87b85650515da99d8d4d74f Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 17 Dec 2021 15:53:13 -0500 Subject: [PATCH 10/22] ARROW-14705: [C++] Implement list types --- cpp/src/arrow/type.cc | 59 ++++++++++++++++++++++++++++++++----- cpp/src/arrow/type_test.cc | 9 ++++-- cpp/src/arrow/type_traits.h | 11 +++++++ 3 files changed, 70 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 442b4de365e..9e64b5e6468 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -357,6 +357,7 @@ Result> MergeTypes(std::shared_ptr promoted_ return Status::Invalid("Cannot merge type with null unless promote_nullability=true"); } + // TODO: split these out if (options.promote_dictionary && is_dictionary(promoted_type->id()) && is_dictionary(other_type->id())) { const auto& left = checked_cast(*promoted_type); @@ -532,14 
+533,28 @@ Result> MergeTypes(std::shared_ptr promoted_ } if (options.promote_large) { - if (promoted_type->id() == Type::FIXED_SIZE_BINARY) { + if (promoted_type->id() == Type::FIXED_SIZE_BINARY && + is_base_binary_like(other_type->id())) { promoted_type = binary(); promoted = other_type->id() == Type::BINARY; } - if (other_type->id() == Type::FIXED_SIZE_BINARY) { + if (other_type->id() == Type::FIXED_SIZE_BINARY && + is_base_binary_like(promoted_type->id())) { other_type = binary(); promoted = promoted_type->id() == Type::BINARY; } + + if (promoted_type->id() == Type::FIXED_SIZE_LIST && + is_var_size_list(other_type->id())) { + promoted_type = + list(checked_cast(*promoted_type).value_field()); + promoted = other_type->Equals(*promoted_type); + } + if (other_type->id() == Type::FIXED_SIZE_LIST && + is_var_size_list(promoted_type->id())) { + other_type = list(checked_cast(*other_type).value_field()); + promoted = other_type->Equals(*promoted_type); + } } if (options.promote_binary) { @@ -557,20 +572,50 @@ Result> MergeTypes(std::shared_ptr promoted_ if (options.promote_large) { if ((promoted_type->id() == Type::STRING && other_type->id() == Type::LARGE_STRING) || (promoted_type->id() == Type::LARGE_STRING && other_type->id() == Type::STRING)) { - promoted_type = large_utf8(); - promoted = true; + return large_utf8(); } else if ((promoted_type->id() == Type::BINARY && other_type->id() == Type::LARGE_BINARY) || (promoted_type->id() == Type::LARGE_BINARY && other_type->id() == Type::BINARY)) { - promoted_type = large_binary(); + return large_binary(); + } + if ((promoted_type->id() == Type::LIST && other_type->id() == Type::LARGE_LIST) || + (promoted_type->id() == Type::LARGE_LIST && other_type->id() == Type::LIST)) { + promoted_type = + large_list(checked_cast(*promoted_type).value_field()); promoted = true; } } + if (options.promote_nested) { + if ((promoted_type->id() == Type::LIST && other_type->id() == Type::LIST) || + (promoted_type->id() == Type::LARGE_LIST && 
+ other_type->id() == Type::LARGE_LIST) || + (promoted_type->id() == Type::FIXED_SIZE_LIST && + other_type->id() == Type::FIXED_SIZE_LIST)) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + ARROW_ASSIGN_OR_RAISE(auto value_type, + MergeTypes(left.value_type(), right.value_type(), options)); + if (!value_type) return nullptr; + auto field = left.value_field()->WithType(std::move(value_type)); + if (promoted_type->id() == Type::LIST) { + return list(std::move(field)); + } else if (promoted_type->id() == Type::LARGE_LIST) { + return large_list(std::move(field)); + } + const auto left_size = + checked_cast(*promoted_type).list_size(); + const auto right_size = + checked_cast(*other_type).list_size(); + if (left_size == right_size) { + return fixed_size_list(std::move(field), left_size); + } + return Status::Invalid("Cannot merge fixed_size_list of different sizes"); + } + } + // TODO - // List(A) -> List(B) (incl. fixed) - // List -> LargeList // Struct: reconcile order, fields, types // Map diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 94f25d1c7d6..e021bb9ade2 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1266,6 +1266,9 @@ TEST_F(TestUnifySchemas, Temporal) { TEST_F(TestUnifySchemas, List) { auto options = Field::MergeOptions::Defaults(); options.promote_numeric_width = true; + CheckUnifyFails(fixed_size_list(int8(), 2), + {fixed_size_list(int8(), 3), list(int8()), large_list(int8())}, + options); options.promote_large = true; CheckUnify(list(int8()), {large_list(int8())}, options); @@ -1273,7 +1276,8 @@ TEST_F(TestUnifySchemas, List) { options.promote_nested = true; CheckUnify(list(int8()), {list(int16()), list(int32()), list(int64())}, options); - CheckUnify(fixed_size_list(int8(), 2), {list(int16()), list(int32()), list(int64())}, + CheckUnify(fixed_size_list(int8(), 2), + {fixed_size_list(int16(), 2), list(int16()), list(int32()), list(int64())}, 
options); // TODO: test nonstandard field names @@ -1340,7 +1344,8 @@ TEST_F(TestUnifySchemas, Binary) { options.promote_binary = true; CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}, options); CheckUnify(binary(), {large_binary()}, options); - CheckUnify(fixed_size_binary(2), {binary(), large_binary()}, options); + CheckUnify(fixed_size_binary(2), {fixed_size_binary(2), binary(), large_binary()}, + options); CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), options); options.promote_large = false; diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index aaccaa53c14..26ea9e79c0c 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -977,6 +977,17 @@ static inline bool is_string(Type::type type_id) { return false; } +static inline bool is_var_size_list(Type::type type_id) { + switch (type_id) { + case Type::LIST: + case Type::LARGE_LIST: + return true; + default: + break; + } + return false; +} + static inline bool is_dictionary(Type::type type_id) { return type_id == Type::DICTIONARY; } From 682dc98173cda3d91e08a04508f2ba9b64e87adc Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 09:34:24 -0500 Subject: [PATCH 11/22] ARROW-14705: [C++] Merge fixed_size_binary together --- cpp/src/arrow/type.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 9e64b5e6468..1452c5cb92f 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -558,6 +558,10 @@ Result> MergeTypes(std::shared_ptr promoted_ } if (options.promote_binary) { + if (promoted_type->id() == Type::FIXED_SIZE_BINARY && + other_type->id() == Type::FIXED_SIZE_BINARY) { + return binary(); + } if (is_string(promoted_type->id()) && is_binary(other_type->id())) { promoted_type = MakeBinary(*promoted_type); promoted = From b4a18a03d907812736a9d643b2db511bf08fc55b Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 09:41:59 -0500 Subject: [PATCH 12/22] 
ARROW-14705: [C++] Refactor --- cpp/src/arrow/type.cc | 107 +++++++++++++++++++++++------------------- 1 file changed, 58 insertions(+), 49 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 1452c5cb92f..ac657d16678 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -341,6 +341,36 @@ TimeUnit::type CommonTimeUnit(TimeUnit::type left, TimeUnit::type right) { } return TimeUnit::SECOND; } + +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options); + +Result> MergeDictionaryTypes( + std::shared_ptr promoted_type, std::shared_ptr other_type, + const Field::MergeOptions& options) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (!options.promote_dictionary_ordered && left.ordered() != right.ordered()) { + return Status::Invalid( + "Cannot merge ordered and unordered dictionary unless " + "promote_dictionary_ordered=true"); + } + Field::MergeOptions index_options = options; + index_options.promote_integer_sign = true; + index_options.promote_numeric_width = true; + ARROW_ASSIGN_OR_RAISE(auto indices, + MergeTypes(left.index_type(), right.index_type(), index_options)); + ARROW_ASSIGN_OR_RAISE(auto values, + MergeTypes(left.value_type(), right.value_type(), options)); + auto ordered = left.ordered() && right.ordered(); + if (indices && values) { + return dictionary(indices, values, ordered); + } else if (values) { + return Status::Invalid("Could not merge index types"); + } + return Status::Invalid("Could not merge value types"); +} Result> MergeTypes(std::shared_ptr promoted_type, std::shared_ptr other_type, const Field::MergeOptions& options) { @@ -357,28 +387,9 @@ Result> MergeTypes(std::shared_ptr promoted_ return Status::Invalid("Cannot merge type with null unless promote_nullability=true"); } - // TODO: split these out if (options.promote_dictionary && is_dictionary(promoted_type->id()) && 
is_dictionary(other_type->id())) { - const auto& left = checked_cast(*promoted_type); - const auto& right = checked_cast(*other_type); - if (!options.promote_dictionary_ordered && left.ordered() != right.ordered()) { - return nullptr; - } - Field::MergeOptions index_options = options; - index_options.promote_integer_sign = true; - index_options.promote_numeric_width = true; - ARROW_ASSIGN_OR_RAISE( - auto indices, MergeTypes(left.index_type(), right.index_type(), index_options)); - ARROW_ASSIGN_OR_RAISE(auto values, - MergeTypes(left.value_type(), right.value_type(), options)); - auto ordered = left.ordered() && right.ordered(); - if (indices && values) { - return dictionary(indices, values, ordered); - } - return Status::Invalid( - "Cannot merge ordered and unordered dictionary unless " - "promote_dictionary_ordered=true"); + return MergeDictionaryTypes(promoted_type, other_type, options); } if (options.promote_date) { @@ -390,39 +401,35 @@ Result> MergeTypes(std::shared_ptr promoted_ } } - if (options.promote_duration) { - if (promoted_type->id() == Type::DURATION && other_type->id() == Type::DURATION) { - const auto& left = checked_cast(*promoted_type); - const auto& right = checked_cast(*other_type); - return duration(CommonTimeUnit(left.unit(), right.unit())); - } + if (options.promote_duration && promoted_type->id() == Type::DURATION && + other_type->id() == Type::DURATION) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + return duration(CommonTimeUnit(left.unit(), right.unit())); } - if (options.promote_time) { - if (is_time(promoted_type->id()) && is_time(other_type->id())) { - const auto& left = checked_cast(*promoted_type); - const auto& right = checked_cast(*other_type); - const auto unit = CommonTimeUnit(left.unit(), right.unit()); - if (unit == TimeUnit::MICRO || unit == TimeUnit::NANO) { - return time64(unit); - } - return time32(unit); + if (options.promote_time && is_time(promoted_type->id()) && 
is_time(other_type->id())) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + const auto unit = CommonTimeUnit(left.unit(), right.unit()); + if (unit == TimeUnit::MICRO || unit == TimeUnit::NANO) { + return time64(unit); } + return time32(unit); } - if (options.promote_timestamp) { - if (promoted_type->id() == Type::TIMESTAMP && other_type->id() == Type::TIMESTAMP) { - const auto& left = checked_cast(*promoted_type); - const auto& right = checked_cast(*other_type); - if (left.timezone().empty() ^ right.timezone().empty()) { - return Status::Invalid( - "Cannot merge timestamp with timezone and timestamp without timezone"); - } - if (left.timezone() != right.timezone()) { - return Status::Invalid("Cannot merge timestamps with differing timezones"); - } - return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); + if (options.promote_timestamp && promoted_type->id() == Type::TIMESTAMP && + other_type->id() == Type::TIMESTAMP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + if (left.timezone().empty() ^ right.timezone().empty()) { + return Status::Invalid( + "Cannot merge timestamp with timezone and timestamp without timezone"); } + if (left.timezone() != right.timezone()) { + return Status::Invalid("Cannot merge timestamps with differing timezones"); + } + return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); } if (options.promote_decimal_float) { @@ -453,8 +460,10 @@ Result> MergeTypes(std::shared_ptr promoted_ is_decimal(other_type->id())) { const auto& left = checked_cast(*promoted_type); const auto& right = checked_cast(*other_type); - if (!options.promote_numeric_width && left.bit_width() != right.bit_width()) - return nullptr; + if (!options.promote_numeric_width && left.bit_width() != right.bit_width()) { + return Status::Invalid( + "Cannot promote decimal128 to decimal256 without promote_numeric_width=true"); 
+ } const int32_t max_scale = std::max(left.scale(), right.scale()); const int32_t common_precision = std::max(left.precision() + max_scale - left.scale(), From 057ded14c7a5947419027f595a2f942b8c956170 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 09:49:21 -0500 Subject: [PATCH 13/22] ARROW-14705: [C++] Implement map --- cpp/src/arrow/type.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index ac657d16678..6adb7937404 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -432,6 +432,21 @@ Result> MergeTypes(std::shared_ptr promoted_ return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); } + if (options.promote_nested && promoted_type->id() == Type::MAP && + other_type->id() == Type::MAP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + ARROW_ASSIGN_OR_RAISE(const auto key_type, + MergeTypes(left.key_type(), right.key_type(), options)); + ARROW_ASSIGN_OR_RAISE(const auto item_type, + MergeTypes(left.item_type(), right.item_type(), options)); + // TODO: need to actually merge the field nullability (here and dictionary etc) + // TODO: tests + return std::make_shared( + left.key_field()->WithType(key_type), left.item_field()->WithType(item_type), + /*keys_sorted=*/left.keys_sorted() && right.keys_sorted()); + } + if (options.promote_decimal_float) { if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { promoted_type = other_type; From 2de01edd1d9a959bfcd1a1ab11fbb1a2016d3ec4 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 10:34:25 -0500 Subject: [PATCH 14/22] ARROW-14705: [C++] Handle nonstandard field names --- cpp/src/arrow/type.cc | 50 +++++++++++++----------- cpp/src/arrow/type_fwd.h | 8 ++++ cpp/src/arrow/type_test.cc | 79 +++++++++++++++++++++++++------------- 3 files changed, 89 insertions(+), 48 deletions(-) diff --git a/cpp/src/arrow/type.cc 
b/cpp/src/arrow/type.cc index 6adb7937404..59ee42c605c 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -432,21 +432,6 @@ Result> MergeTypes(std::shared_ptr promoted_ return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); } - if (options.promote_nested && promoted_type->id() == Type::MAP && - other_type->id() == Type::MAP) { - const auto& left = checked_cast(*promoted_type); - const auto& right = checked_cast(*other_type); - ARROW_ASSIGN_OR_RAISE(const auto key_type, - MergeTypes(left.key_type(), right.key_type(), options)); - ARROW_ASSIGN_OR_RAISE(const auto item_type, - MergeTypes(left.item_type(), right.item_type(), options)); - // TODO: need to actually merge the field nullability (here and dictionary etc) - // TODO: tests - return std::make_shared( - left.key_field()->WithType(key_type), left.item_field()->WithType(item_type), - /*keys_sorted=*/left.keys_sorted() && right.keys_sorted()); - } - if (options.promote_decimal_float) { if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { promoted_type = other_type; @@ -623,23 +608,38 @@ Result> MergeTypes(std::shared_ptr promoted_ other_type->id() == Type::FIXED_SIZE_LIST)) { const auto& left = checked_cast(*promoted_type); const auto& right = checked_cast(*other_type); - ARROW_ASSIGN_OR_RAISE(auto value_type, - MergeTypes(left.value_type(), right.value_type(), options)); - if (!value_type) return nullptr; - auto field = left.value_field()->WithType(std::move(value_type)); + ARROW_ASSIGN_OR_RAISE( + auto value_field, + left.value_field()->MergeWith( + *right.value_field()->WithName(left.value_field()->name()), options)); if (promoted_type->id() == Type::LIST) { - return list(std::move(field)); + return list(std::move(value_field)); } else if (promoted_type->id() == Type::LARGE_LIST) { - return large_list(std::move(field)); + return large_list(std::move(value_field)); } const auto left_size = checked_cast(*promoted_type).list_size(); const auto right_size = 
checked_cast(*other_type).list_size(); if (left_size == right_size) { - return fixed_size_list(std::move(field), left_size); + return fixed_size_list(std::move(value_field), left_size); } return Status::Invalid("Cannot merge fixed_size_list of different sizes"); + } else if (promoted_type->id() == Type::MAP && other_type->id() == Type::MAP) { + const auto& left = checked_cast(*promoted_type); + const auto& right = checked_cast(*other_type); + // While we try to preserve nonstandard field names here, note that + // MapType comparisons ignore field name. See ARROW-7173, ARROW-14999. + ARROW_ASSIGN_OR_RAISE( + auto key_field, + left.key_field()->MergeWith( + *right.key_field()->WithName(left.key_field()->name()), options)); + ARROW_ASSIGN_OR_RAISE( + auto item_field, + left.item_field()->MergeWith( + *right.item_field()->WithName(left.item_field()->name()), options)); + return map(std::move(key_field), std::move(item_field), + /*keys_sorted=*/left.keys_sorted() && right.keys_sorted()); } } @@ -2628,6 +2628,12 @@ std::shared_ptr map(std::shared_ptr key_type, keys_sorted); } +std::shared_ptr map(std::shared_ptr key_field, + std::shared_ptr item_field, bool keys_sorted) { + return std::make_shared(std::move(key_field), std::move(item_field), + keys_sorted); +} + std::shared_ptr fixed_size_list(const std::shared_ptr& value_type, int32_t list_size) { return std::make_shared(value_type, list_size); diff --git a/cpp/src/arrow/type_fwd.h b/cpp/src/arrow/type_fwd.h index 45afd7af2e6..1e566fa9ebd 100644 --- a/cpp/src/arrow/type_fwd.h +++ b/cpp/src/arrow/type_fwd.h @@ -503,6 +503,14 @@ std::shared_ptr map(std::shared_ptr key_type, std::shared_ptr item_field, bool keys_sorted = false); +/// \brief Create a MapType instance from its key field and value field. +/// +/// The field override is provided to communicate nullability of the value. 
+ARROW_EXPORT +std::shared_ptr map(std::shared_ptr key_field, + std::shared_ptr item_field, + bool keys_sorted = false); + /// \brief Create a FixedSizeListType instance from its child Field type ARROW_EXPORT std::shared_ptr fixed_size_list(const std::shared_ptr& value_type, diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index e021bb9ade2..0fd4faeee75 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -979,17 +979,23 @@ class TestUnifySchemas : public TestSchema { } } + void CheckUnifyAsymmetric( + const std::shared_ptr& field1, const std::shared_ptr& field2, + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + ARROW_SCOPED_TRACE("options: ", options); + ARROW_SCOPED_TRACE("field2: ", field2->ToString()); + ARROW_SCOPED_TRACE("field1: ", field1->ToString()); + ASSERT_OK_AND_ASSIGN(auto merged, field1->MergeWith(field2, options)); + AssertFieldEqual(merged, expected); + } + void CheckUnify(const std::shared_ptr& field1, const std::shared_ptr& field2, const std::shared_ptr& expected, const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { - ARROW_SCOPED_TRACE("options: ", options); - ARROW_SCOPED_TRACE("field2: ", field2->ToString()); - ARROW_SCOPED_TRACE("field1: ", field1->ToString()); - ASSERT_OK_AND_ASSIGN(auto merged1, field1->MergeWith(field2, options)); - ASSERT_OK_AND_ASSIGN(auto merged2, field2->MergeWith(field1, options)); - AssertFieldEqual(merged1, expected); - AssertFieldEqual(merged2, expected); + CheckUnifyAsymmetric(field1, field2, expected, options); + CheckUnifyAsymmetric(field2, field1, expected, options); } void CheckUnifyFails( @@ -1263,6 +1269,24 @@ TEST_F(TestUnifySchemas, Temporal) { timestamp(TimeUnit::SECOND, "UTC"), options); } +TEST_F(TestUnifySchemas, Binary) { + auto options = Field::MergeOptions::Defaults(); + options.promote_large = true; + options.promote_binary = true; + CheckUnify(utf8(), {large_utf8(), binary(), 
large_binary()}, options); + CheckUnify(binary(), {large_binary()}, options); + CheckUnify(fixed_size_binary(2), {fixed_size_binary(2), binary(), large_binary()}, + options); + CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), options); + + options.promote_large = false; + CheckUnifyFails({utf8(), binary()}, {large_utf8(), large_binary()}); + CheckUnifyFails(fixed_size_binary(2), BaseBinaryTypes()); + + options.promote_binary = false; + CheckUnifyFails(utf8(), {binary(), large_binary(), fixed_size_binary(2)}); +} + TEST_F(TestUnifySchemas, List) { auto options = Field::MergeOptions::Defaults(); options.promote_numeric_width = true; @@ -1280,7 +1304,13 @@ TEST_F(TestUnifySchemas, List) { {fixed_size_list(int16(), 2), list(int16()), list(int32()), list(int64())}, options); - // TODO: test nonstandard field names + auto field1 = field("a", list(field("foo", int8(), /*nullable=*/false))); + CheckUnifyAsymmetric(field1, field("a", list(int8())), + field("a", list(field("foo", int8(), /*nullable=*/true))), + options); + CheckUnifyAsymmetric( + field1, field("a", list(field("bar", int16(), /*nullable=*/false))), + field("a", list(field("foo", int16(), /*nullable=*/false))), options); } TEST_F(TestUnifySchemas, Map) { @@ -1291,6 +1321,21 @@ TEST_F(TestUnifySchemas, Map) { CheckUnify(map(int8(), int32()), {map(int8(), int64()), map(int16(), int32()), map(int64(), int64())}, options); + + // Do not test field names, since MapType intentionally ignores them in comparisons + // See ARROW-7173, ARROW-14999 + auto ty = map(field("key", int8(), /*nullable=*/false), + field("value", int32(), /*nullable=*/false)); + CheckUnify(ty, map(int8(), int32()), + map(field("key", int8(), /*nullable=*/true), + field("value", int32(), /*nullable=*/true)), + options); + CheckUnify(ty, + map(field("key", int16(), /*nullable=*/false), + field("value", int64(), /*nullable=*/false)), + map(field("key", int16(), /*nullable=*/false), + field("value", int64(), /*nullable=*/false)), + 
options); } TEST_F(TestUnifySchemas, Struct) { @@ -1338,24 +1383,6 @@ TEST_F(TestUnifySchemas, Dictionary) { dictionary(int8(), utf8(), /*ordered=*/false), options); } -TEST_F(TestUnifySchemas, Binary) { - auto options = Field::MergeOptions::Defaults(); - options.promote_large = true; - options.promote_binary = true; - CheckUnify(utf8(), {large_utf8(), binary(), large_binary()}, options); - CheckUnify(binary(), {large_binary()}, options); - CheckUnify(fixed_size_binary(2), {fixed_size_binary(2), binary(), large_binary()}, - options); - CheckUnify(fixed_size_binary(2), fixed_size_binary(4), binary(), options); - - options.promote_large = false; - CheckUnifyFails({utf8(), binary()}, {large_utf8(), large_binary()}); - CheckUnifyFails(fixed_size_binary(2), BaseBinaryTypes()); - - options.promote_binary = false; - CheckUnifyFails(utf8(), {binary(), large_binary(), fixed_size_binary(2)}); -} - TEST_F(TestUnifySchemas, IncompatibleTypes) { auto int32_field = field("f", int32()); auto uint8_field = field("f", uint8(), false); From f09ab8d1e3ca2feeddc81a2f56eb6dc9dc609bce Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 10:45:09 -0500 Subject: [PATCH 15/22] ARROW-14705: [C++] Refactor --- cpp/src/arrow/type.cc | 98 +++++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 59ee42c605c..6d6870589a8 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -346,9 +346,10 @@ Result> MergeTypes(std::shared_ptr promoted_ std::shared_ptr other_type, const Field::MergeOptions& options); +// Merge two dictionary types, or else give an error. 
Result> MergeDictionaryTypes( - std::shared_ptr promoted_type, std::shared_ptr other_type, - const Field::MergeOptions& options) { + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { const auto& left = checked_cast(*promoted_type); const auto& right = checked_cast(*other_type); if (!options.promote_dictionary_ordered && left.ordered() != right.ordered()) { @@ -371,27 +372,11 @@ Result> MergeDictionaryTypes( } return Status::Invalid("Could not merge value types"); } -Result> MergeTypes(std::shared_ptr promoted_type, - std::shared_ptr other_type, - const Field::MergeOptions& options) { - if (promoted_type->Equals(*other_type)) return promoted_type; - - bool promoted = false; - if (options.promote_nullability) { - if (promoted_type->id() == Type::NA) { - return other_type; - } else if (other_type->id() == Type::NA) { - return promoted_type; - } - } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) { - return Status::Invalid("Cannot merge type with null unless promote_nullability=true"); - } - - if (options.promote_dictionary && is_dictionary(promoted_type->id()) && - is_dictionary(other_type->id())) { - return MergeDictionaryTypes(promoted_type, other_type, options); - } +// Merge temporal types based on options. Returns nullptr for non-temporal types. +Result> MaybeMergeTemporalTypes( + const std::shared_ptr& promoted_type, + const std::shared_ptr& other_type, const Field::MergeOptions& options) { if (options.promote_date) { if (promoted_type->id() == Type::DATE32 && other_type->id() == Type::DATE64) { return date64(); @@ -432,6 +417,14 @@ Result> MergeTypes(std::shared_ptr promoted_ return timestamp(CommonTimeUnit(left.unit(), right.unit()), left.timezone()); } + return nullptr; +} + +// Merge numeric types based on options. Returns nullptr for non-numeric types.
+Result> MaybeMergeNumericTypes( + std::shared_ptr promoted_type, std::shared_ptr other_type, + const Field::MergeOptions& options) { + bool promoted = false; if (options.promote_decimal_float) { if (is_decimal(promoted_type->id()) && is_floating(other_type->id())) { promoted_type = other_type; @@ -507,40 +500,66 @@ Result> MergeTypes(std::shared_ptr promoted_ std::max(bit_width(promoted_type->id()), bit_width(other_type->id())); if (is_floating(promoted_type->id()) && is_floating(other_type->id())) { if (max_width >= 64) { - promoted_type = float64(); + return float64(); } else if (max_width >= 32) { - promoted_type = float32(); - } else { - promoted_type = float16(); + return float32(); } - promoted = true; + return float16(); } else if (is_signed_integer(promoted_type->id()) && is_signed_integer(other_type->id())) { if (max_width >= 64) { - promoted_type = int64(); + return int64(); } else if (max_width >= 32) { - promoted_type = int32(); + return int32(); } else if (max_width >= 16) { - promoted_type = int16(); - } else { - promoted_type = int8(); + return int16(); } - promoted = true; + return int8(); } else if (is_unsigned_integer(promoted_type->id()) && is_unsigned_integer(other_type->id())) { if (max_width >= 64) { - promoted_type = uint64(); + return uint64(); } else if (max_width >= 32) { - promoted_type = uint32(); + return uint32(); } else if (max_width >= 16) { - promoted_type = uint16(); - } else { - promoted_type = uint8(); + return uint16(); } - promoted = true; + return uint8(); + } + } + + return promoted ? 
promoted_type : nullptr; +} + +Result> MergeTypes(std::shared_ptr promoted_type, + std::shared_ptr other_type, + const Field::MergeOptions& options) { + if (promoted_type->Equals(*other_type)) return promoted_type; + + bool promoted = false; + if (options.promote_nullability) { + if (promoted_type->id() == Type::NA) { + return other_type; + } else if (other_type->id() == Type::NA) { + return promoted_type; } + } else if (promoted_type->id() == Type::NA || other_type->id() == Type::NA) { + return Status::Invalid("Cannot merge type with null unless promote_nullability=true"); } + if (options.promote_dictionary && is_dictionary(promoted_type->id()) && + is_dictionary(other_type->id())) { + return MergeDictionaryTypes(promoted_type, other_type, options); + } + + ARROW_ASSIGN_OR_RAISE(auto maybe_promoted, + MaybeMergeTemporalTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + + ARROW_ASSIGN_OR_RAISE(maybe_promoted, + MaybeMergeNumericTypes(promoted_type, other_type, options)); + if (maybe_promoted) return maybe_promoted; + if (options.promote_large) { if (promoted_type->id() == Type::FIXED_SIZE_BINARY && is_base_binary_like(other_type->id())) { @@ -645,7 +664,6 @@ Result> MergeTypes(std::shared_ptr promoted_ // TODO // Struct: reconcile order, fields, types - // Map return promoted ? 
promoted_type : nullptr; } From fde4c52297a03c6591e9110b57f8ec8da46e111b Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 10:59:38 -0500 Subject: [PATCH 16/22] ARROW-14705: [C++] Implement structs --- cpp/src/arrow/type.cc | 16 +++++++++--- cpp/src/arrow/type_test.cc | 52 ++++++++++++++++++++++++++++++-------- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 6d6870589a8..6423fc10d8f 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -659,12 +659,22 @@ Result> MergeTypes(std::shared_ptr promoted_ *right.item_field()->WithName(left.item_field()->name()), options)); return map(std::move(key_field), std::move(item_field), /*keys_sorted=*/left.keys_sorted() && right.keys_sorted()); + } else if (promoted_type->id() == Type::STRUCT && other_type->id() == Type::STRUCT) { + SchemaBuilder builder(SchemaBuilder::CONFLICT_APPEND, options); + // Add the LHS fields. Duplicates will be preserved. + RETURN_NOT_OK(builder.AddFields(promoted_type->fields())); + + // Add the RHS fields. Duplicates will be merged, unless the field was + // already a duplicate, in which case we error (since we don't know which + // field to merge with). + builder.SetPolicy(SchemaBuilder::CONFLICT_MERGE); + RETURN_NOT_OK(builder.AddFields(other_type->fields())); + + ARROW_ASSIGN_OR_RAISE(auto schema, builder.Finish()); + return struct_(schema->fields()); } } - // TODO - // Struct: reconcile order, fields, types - return promoted ? 
promoted_type : nullptr; } } // namespace diff --git a/cpp/src/arrow/type_test.cc b/cpp/src/arrow/type_test.cc index 0fd4faeee75..b447c01833f 100644 --- a/cpp/src/arrow/type_test.cc +++ b/cpp/src/arrow/type_test.cc @@ -1029,6 +1029,30 @@ class TestUnifySchemas : public TestSchema { CheckUnify(field1, field2, field("a", expected, /*nullable=*/true), options); } + void CheckUnifyAsymmetric( + const std::shared_ptr& left, const std::shared_ptr& right, + const std::shared_ptr& expected, + const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { + auto field1 = field("a", left); + auto field2 = field("a", right); + CheckUnifyAsymmetric(field1, field2, field("a", expected), options); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", right, /*nullable=*/false); + CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/false), + options); + + field1 = field("a", left); + field2 = field("a", right, /*nullable=*/false); + CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/true), + options); + + field1 = field("a", left, /*nullable=*/false); + field2 = field("a", right); + CheckUnifyAsymmetric(field1, field2, field("a", expected, /*nullable=*/true), + options); + } + void CheckUnifyFails( const std::shared_ptr& left, const std::shared_ptr& right, const Field::MergeOptions& options = Field::MergeOptions::Defaults()) { @@ -1304,13 +1328,11 @@ TEST_F(TestUnifySchemas, List) { {fixed_size_list(int16(), 2), list(int16()), list(int32()), list(int64())}, options); - auto field1 = field("a", list(field("foo", int8(), /*nullable=*/false))); - CheckUnifyAsymmetric(field1, field("a", list(int8())), - field("a", list(field("foo", int8(), /*nullable=*/true))), + auto ty = list(field("foo", int8(), /*nullable=*/false)); + CheckUnifyAsymmetric(ty, list(int8()), list(field("foo", int8(), /*nullable=*/true)), options); - CheckUnifyAsymmetric( - field1, field("a", list(field("bar", int16(), /*nullable=*/false))), - 
field("a", list(field("foo", int16(), /*nullable=*/false))), options); + CheckUnifyAsymmetric(ty, list(field("bar", int16(), /*nullable=*/false)), + list(field("foo", int16(), /*nullable=*/false)), options); } TEST_F(TestUnifySchemas, Map) { @@ -1347,15 +1369,23 @@ TEST_F(TestUnifySchemas, Struct) { CheckUnify(struct_({}), struct_({field("a", int8())}), struct_({field("a", int8())}), options); - CheckUnify(struct_({field("b", utf8())}), struct_({field("a", int8())}), - struct_({field("b", utf8()), field("a", int8())}), options); + CheckUnifyAsymmetric(struct_({field("b", utf8())}), struct_({field("a", int8())}), + struct_({field("b", utf8()), field("a", int8())}), options); + CheckUnifyAsymmetric(struct_({field("a", int8())}), struct_({field("b", utf8())}), + struct_({field("a", int8()), field("b", utf8())}), options); CheckUnify(struct_({field("b", utf8())}), struct_({field("b", binary())}), struct_({field("b", binary())}), options); - CheckUnify(struct_({field("a", int8()), field("b", utf8())}), - struct_({field("b", utf8()), field("a", int8())}), - struct_({field("a", int8()), field("b", utf8())}), options); + CheckUnifyAsymmetric( + struct_({field("a", int8()), field("b", utf8()), field("a", int64())}), + struct_({field("b", binary())}), + struct_({field("a", int8()), field("b", binary()), field("a", int64())}), options); + + ASSERT_RAISES( + Invalid, + field("foo", struct_({field("a", int8()), field("b", utf8()), field("a", int64())})) + ->MergeWith(field("foo", struct_({field("a", int64())})), options)); } TEST_F(TestUnifySchemas, Dictionary) { From d14814fe18e53860e03927db3f9f249dc823be3c Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 11:07:42 -0500 Subject: [PATCH 17/22] ARROW-14705: [C++] Add options to discovery --- cpp/src/arrow/dataset/discovery.h | 3 ++- cpp/src/arrow/dataset/discovery_test.cc | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/dataset/discovery.h 
b/cpp/src/arrow/dataset/discovery.h index bd928ce57ff..382b23e4caa 100644 --- a/cpp/src/arrow/dataset/discovery.h +++ b/cpp/src/arrow/dataset/discovery.h @@ -59,7 +59,8 @@ struct InspectOptions { /// altogether so only the partitioning schema will be inspected. int fragments = 1; - /// Control how to unify types. + /// Control how to unify types. By default, types are merged strictly (the + /// type must match exactly, except nulls can be merged with other types). Field::MergeOptions field_merge_options = Field::MergeOptions::Defaults(); }; diff --git a/cpp/src/arrow/dataset/discovery_test.cc b/cpp/src/arrow/dataset/discovery_test.cc index a51b3c09971..8842d084b69 100644 --- a/cpp/src/arrow/dataset/discovery_test.cc +++ b/cpp/src/arrow/dataset/discovery_test.cc @@ -120,6 +120,12 @@ TEST_F(MockDatasetFactoryTest, UnifySchemas) { ASSERT_RAISES(Invalid, factory_->Inspect()); // Return the individual schema for closer inspection should not fail. AssertInspectSchemas({schema({i32, f64}), schema({f64, i32_fake})}); + + MakeFactory({schema({field("num", int32())}), schema({field("num", float64())})}); + ASSERT_RAISES(Invalid, factory_->Inspect()); + InspectOptions permissive_options; + permissive_options.field_merge_options = Field::MergeOptions::Permissive(); + AssertInspect(schema({field("num", float64())}), permissive_options); } class FileSystemDatasetFactoryTest : public DatasetFactoryTest { @@ -473,6 +479,12 @@ TEST(UnionDatasetFactoryTest, ConflictingSchemas) { auto i32_schema = schema({i32}); ASSERT_OK_AND_ASSIGN(auto dataset, factory->Finish(i32_schema)); EXPECT_EQ(*dataset->schema(), *i32_schema); + + // The user decided to allow merging the types. 
+ FinishOptions options; + options.inspect_options.field_merge_options = Field::MergeOptions::Permissive(); + ASSERT_OK_AND_ASSIGN(dataset, factory->Finish(options)); + EXPECT_EQ(*dataset->schema(), *schema({f64, i32})); } } // namespace dataset From fffb846f064cd0b9aa4e6f7ddb68a0116e89807e Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 12:08:16 -0500 Subject: [PATCH 18/22] ARROW-14705: [Python] Add basic bindings --- python/pyarrow/__init__.py | 1 + python/pyarrow/includes/libarrow.pxd | 9 ++++-- python/pyarrow/tests/test_schema.py | 4 +++ python/pyarrow/types.pxi | 42 ++++++++++++++++++++++++++-- 4 files changed, 51 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index 07ef7f4b078..69ea79b50cf 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -115,6 +115,7 @@ def show_versions(): DictionaryMemo, KeyValueMetadata, Field, + FieldMergeOptions, Schema, schema, unify_schemas, diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 1e6c741ac30..3202831f641 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -387,12 +387,16 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: int scale() cdef cppclass CField" arrow::Field": - cppclass CMergeOptions "arrow::Field::MergeOptions": + cppclass CMergeOptions "MergeOptions": + CMergeOptions() c_bool promote_nullability @staticmethod CMergeOptions Defaults() + @staticmethod + CMergeOptions Permissive() + const c_string& name() shared_ptr[CDataType] type() c_bool nullable() @@ -483,7 +487,8 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil: shared_ptr[CSchema] RemoveMetadata() CResult[shared_ptr[CSchema]] UnifySchemas( - const vector[shared_ptr[CSchema]]& schemas) + const vector[shared_ptr[CSchema]]& schemas, + CField.CMergeOptions field_merge_options) cdef cppclass PrettyPrintOptions: PrettyPrintOptions() diff --git 
a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py index f26eaaf5fc1..b208a995833 100644 --- a/python/pyarrow/tests/test_schema.py +++ b/python/pyarrow/tests/test_schema.py @@ -718,6 +718,10 @@ def test_schema_merge(): result = pa.unify_schemas((a, b, c)) assert result.equals(expected) + result = pa.unify_schemas( + [b, d], options=pa.FieldMergeOptions.permissive()) + assert result.equals(d) + def test_undecodable_metadata(): # ARROW-10214: undecodable metadata shouldn't fail repr() diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 32d70887aab..abd4c9e6a7a 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1118,6 +1118,33 @@ cdef KeyValueMetadata ensure_metadata(object meta, c_bool allow_none=False): return KeyValueMetadata(meta) +cdef class FieldMergeOptions(_Weakrefable): + """ + Options controlling how to merge the types of two fields. + + By default, types must match exactly, except the null type can be + merged with any other type. + + """ + + cdef: + CField.CMergeOptions c_options + + __slots__ = () + + def __init__(self, *): + self.c_options = CField.CMergeOptions.Defaults() + + @staticmethod + def permissive(): + """ + Allow merging generally compatible types (e.g. float64 and int64). + """ + cdef FieldMergeOptions options = FieldMergeOptions() + options.c_options = CField.CMergeOptions.Permissive() + return options + + cdef class Field(_Weakrefable): """ A named field, with a data type, nullability, and optional metadata. @@ -1783,13 +1810,13 @@ cdef class Schema(_Weakrefable): return self.__str__() -def unify_schemas(schemas): +def unify_schemas(schemas, *, options=None): """ Unify schemas by merging fields by name. The resulting schema will contain the union of fields from all schemas. Fields with the same name will be merged. Note that two fields with - different types will fail merging. + different types will fail merging by default. 
- The unified field will inherit the metadata from the schema where that field is first defined. @@ -1804,6 +1831,9 @@ def unify_schemas(schemas): schemas : list of Schema Schemas to merge into a single one. + options : FieldMergeOptions, optional + Options for merging duplicate fields. + Returns ------- Schema @@ -1816,10 +1846,16 @@ def unify_schemas(schemas): """ cdef: Schema schema + CField.CMergeOptions c_options vector[shared_ptr[CSchema]] c_schemas for schema in schemas: c_schemas.push_back(pyarrow_unwrap_schema(schema)) - return pyarrow_wrap_schema(GetResultValue(UnifySchemas(c_schemas))) + if options: + c_options = ( options).c_options + else: + c_options = CField.CMergeOptions.Defaults() + return pyarrow_wrap_schema( + GetResultValue(UnifySchemas(c_schemas, c_options))) cdef dict _type_cache = {} From e4fc346bc31646ffe2a67c129cc00abdfa79c8b6 Mon Sep 17 00:00:00 2001 From: David Li Date: Mon, 20 Dec 2021 13:24:05 -0500 Subject: [PATCH 19/22] ARROW-14705: [C++] Add missing export --- cpp/src/arrow/type.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index fa0679179a9..189b53b6ff3 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -303,7 +303,7 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// \brief Options that control the behavior of `MergeWith`. /// Options are to be added to allow type conversions, including integer /// widening, promotion from integer to float, or conversion to or from boolean. - struct MergeOptions : public util::ToStringOstreamable { + struct ARROW_EXPORT MergeOptions : public util::ToStringOstreamable { /// If true, a Field of NullType can be unified with a Field of another type. /// The unified field will be of the other type and become nullable. 
/// Nullability will be promoted to the looser option (nullable if one is not From ceac69299f157bde2cc8820c3fabe122517ba499 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 28 Dec 2021 10:33:21 -0500 Subject: [PATCH 20/22] ARROW-14705: [C++] Organize and document options --- cpp/src/arrow/type.cc | 16 +++++++++------- cpp/src/arrow/type.h | 37 +++++++++++++++++++++---------------- 2 files changed, 30 insertions(+), 23 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 6423fc10d8f..5a1c7822001 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -259,19 +259,21 @@ std::shared_ptr Field::WithNullable(const bool nullable) const { Field::MergeOptions Field::MergeOptions::Permissive() { MergeOptions options = Defaults(); options.promote_nullability = true; - options.promote_numeric_width = true; - options.promote_integer_float = true; - options.promote_integer_decimal = true; + options.promote_decimal = true; options.promote_decimal_float = true; + options.promote_integer_decimal = true; + options.promote_integer_float = true; + options.promote_integer_sign = true; + options.promote_numeric_width = true; + options.promote_binary = true; options.promote_date = true; - options.promote_time = true; options.promote_duration = true; + options.promote_time = true; options.promote_timestamp = true; - options.promote_nested = true; options.promote_dictionary = true; - options.promote_integer_sign = true; + options.promote_dictionary_ordered = false; options.promote_large = true; - options.promote_binary = true; + options.promote_nested = true; return options; } diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 189b53b6ff3..8afa40bc010 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -319,39 +319,40 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// not itself be promoted (e.g. Decimal128 + Float32 = Float32). 
bool promote_decimal_float = false; - /// Allow an integer of a given bit width to be promoted to a - /// float of an equal or greater bit width. - bool promote_integer_float = false; - - /// Allow an unsigned integer of a given bit width to be promoted - /// to a signed integer of the same bit width. - bool promote_integer_sign = false; - /// Allow an integer to be promoted to a decimal. /// /// May fail if the decimal has insufficient precision to /// accomodate the integer. (See increase_decimal_precision.) bool promote_integer_decimal = false; + /// Allow an integer of a given bit width to be promoted to a + /// float; the result will be a float of an equal or greater bit + /// width to both of the inputs. + bool promote_integer_float = false; + + /// Allow an unsigned integer of a given bit width to be promoted + /// to a signed integer of the equal or greater bit width. + bool promote_integer_sign = false; + /// Allow an integer, float, or decimal of a given bit width to be /// promoted to an equivalent type of a greater bit width. bool promote_numeric_width = false; + /// Allow strings to be promoted to binary types. + bool promote_binary = false; + /// Promote Date32 to Date64. bool promote_date = false; - /// Promote Time32 to Time64, or Time32(SECOND) to Time32(MILLI), etc. - bool promote_time = false; - /// Promote second to millisecond, etc. bool promote_duration = false; + /// Promote Time32 to Time64, or Time32(SECOND) to Time32(MILLI), etc. + bool promote_time = false; + /// Promote second to millisecond, etc. bool promote_timestamp = false; - /// Recursively merge nested types. - bool promote_nested = false; - /// Promote dictionary index types to a common type, and unify the /// value types. bool promote_dictionary = false; @@ -364,11 +365,15 @@ class ARROW_EXPORT Field : public detail::Fingerprintable { /// Allow a type to be promoted to the Large variant. bool promote_large = false; - /// Allow strings to be promoted to binary types. 
- bool promote_binary = false; + /// Recursively merge nested types. + bool promote_nested = false; + /// Get default options. Only NullType will be merged with other types. static MergeOptions Defaults() { return MergeOptions(); } + /// Get permissive options. All options are enabled, except + /// promote_dictionary_ordered. static MergeOptions Permissive(); + /// Get a human-readable representation of the options. std::string ToString() const; }; From 1bb664254c5185386f4f416454a287d3903e8527 Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 6 Jan 2022 08:52:08 -0500 Subject: [PATCH 21/22] ARROW-14705: [C++] Add missing header --- cpp/src/arrow/type.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 5a1c7822001..333eb0ac39c 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -36,6 +36,7 @@ #include "arrow/result.h" #include "arrow/status.h" #include "arrow/util/checked_cast.h" +#include "arrow/util/decimal.h" #include "arrow/util/hash_util.h" #include "arrow/util/hashing.h" #include "arrow/util/key_value_metadata.h" From 6992934370de972beb7cba25f9fada5d0fbaf88a Mon Sep 17 00:00:00 2001 From: David Li Date: Thu, 6 Jan 2022 13:02:19 -0500 Subject: [PATCH 22/22] ARROW-14705: [C++][Python] Add unification to ConcatenateTables --- cpp/src/arrow/table.cc | 23 +++++++++++++++- cpp/src/arrow/table.h | 12 ++++++--- cpp/src/arrow/table_test.cc | 42 ++++++++++++++++++++++++++++-- python/pyarrow/array.pxi | 4 +-- python/pyarrow/table.pxi | 12 +++++++-- python/pyarrow/tests/test_table.py | 9 +++++++ python/pyarrow/types.pxi | 1 - 7 files changed, 91 insertions(+), 12 deletions(-) diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index 7d7ad61bca5..48756f8f6fc 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -38,9 +38,15 @@ #include "arrow/type_fwd.h" #include "arrow/type_traits.h" #include "arrow/util/checked_cast.h" +// Get ARROW_COMPUTE definition +#include 
"arrow/util/config.h" #include "arrow/util/logging.h" #include "arrow/util/vector.h" +#ifdef ARROW_COMPUTE +#include "arrow/compute/cast.h" +#endif + namespace arrow { using internal::checked_cast; @@ -504,9 +510,24 @@ Result> PromoteTableToSchema(const std::shared_ptr continue; } +#ifdef ARROW_COMPUTE + if (!compute::CanCast(*current_field->type(), *field->type())) { + return Status::Invalid("Unable to promote field ", field->name(), + ": incompatible types: ", field->type()->ToString(), " vs ", + current_field->type()->ToString()); + } + compute::ExecContext ctx(pool); + auto options = compute::CastOptions::Safe(); + ARROW_ASSIGN_OR_RAISE(auto casted, compute::Cast(table->column(field_index), + field->type(), options, &ctx)); + columns.push_back(casted.chunked_array()); +#else return Status::Invalid("Unable to promote field ", field->name(), ": incompatible types: ", field->type()->ToString(), " vs ", - current_field->type()->ToString()); + current_field->type()->ToString(), + " (Arrow must be built with ARROW_COMPUTE " + "in order to cast incompatible types)"); +#endif } auto unseen_field_iter = std::find(fields_seen.begin(), fields_seen.end(), false); diff --git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index 1d6cdd56765..f23756c4849 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -293,14 +293,18 @@ Result> ConcatenateTables( /// \brief Promotes a table to conform to the given schema. /// -/// If a field in the schema does not have a corresponding column in the -/// table, a column of nulls will be added to the resulting table. -/// If the corresponding column is of type Null, it will be promoted to -/// the type specified by schema, with null values filled. +/// If a field in the schema does not have a corresponding column in +/// the table, a column of nulls will be added to the resulting table. +/// If the corresponding column is of type Null, it will be promoted +/// to the type specified by schema, with null values filled. 
If Arrow +/// was built with ARROW_COMPUTE, then the column will be casted to +/// the type specified by the schema. +/// /// Returns an error: /// - if the corresponding column's type is not compatible with the /// schema. /// - if there is a column in the table that does not exist in the schema. +/// - if the cast fails or casting would be required but is not available. /// /// \param[in] table the input Table /// \param[in] schema the target schema to promote to diff --git a/cpp/src/arrow/table_test.cc b/cpp/src/arrow/table_test.cc index 3f6589fdf94..c4dddacb28d 100644 --- a/cpp/src/arrow/table_test.cc +++ b/cpp/src/arrow/table_test.cc @@ -34,6 +34,7 @@ #include "arrow/testing/gtest_util.h" #include "arrow/testing/random.h" #include "arrow/type.h" +#include "arrow/util/config.h" #include "arrow/util/key_value_metadata.h" namespace arrow { @@ -417,8 +418,9 @@ TEST_F(TestPromoteTableToSchema, IncompatibleTypes) { // Invalid promotion: int32 to null. ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", null())}))); - // Invalid promotion: int32 to uint32. - ASSERT_RAISES(Invalid, PromoteTableToSchema(table, schema({field("field", uint32())}))); + // Invalid promotion: int32 to list. 
+ ASSERT_RAISES(Invalid, + PromoteTableToSchema(table, schema({field("field", list(int32()))}))); } TEST_F(TestPromoteTableToSchema, IncompatibleNullity) { @@ -517,6 +519,42 @@ TEST_F(ConcatenateTablesWithPromotionTest, Simple) { AssertTablesEqualUnorderedFields(*expected, *result); } +TEST_F(ConcatenateTablesWithPromotionTest, Unify) { + auto t1 = TableFromJSON(schema({field("f0", int32())}), {"[[0], [1]]"}); + auto t2 = TableFromJSON(schema({field("f0", int64())}), {"[[2], [3]]"}); + auto t3 = TableFromJSON(schema({field("f0", null())}), {"[[null], [null]]"}); + + auto expected_int64 = + TableFromJSON(schema({field("f0", int64())}), {"[[0], [1], [2], [3]]"}); + auto expected_null = + TableFromJSON(schema({field("f0", int32())}), {"[[0], [1], [null], [null]]"}); + + ConcatenateTablesOptions options; + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Schema at index 1 was different"), + ConcatenateTables({t1, t2}, options)); + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Schema at index 1 was different"), + ConcatenateTables({t1, t3}, options)); + + options.unify_schemas = true; + EXPECT_RAISES_WITH_MESSAGE_THAT(Invalid, + ::testing::HasSubstr("Field f0 has incompatible types"), + ConcatenateTables({t1, t2}, options)); + ASSERT_OK_AND_ASSIGN(auto actual, ConcatenateTables({t1, t3}, options)); + AssertTablesEqual(*expected_null, *actual, /*same_chunk_layout=*/false); + + options.field_merge_options.promote_numeric_width = true; +#ifdef ARROW_COMPUTE + ASSERT_OK_AND_ASSIGN(actual, ConcatenateTables({t1, t2}, options)); + AssertTablesEqual(*expected_int64, *actual, /*same_chunk_layout=*/false); +#else + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, ::testing::HasSubstr("must be built with ARROW_COMPUTE"), + ConcatenateTables({t1, t2}, options)); +#endif +} + TEST_F(TestTable, Slice) { const int64_t length = 10; diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index de0d3a74dfb..6e71b5a51aa 100644 --- 
a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1013,11 +1013,11 @@ cdef class Array(_PandasConvertible): Parameters ---------- indent : int, default 2 - How much to indent the internal items in the string to + How much to indent the internal items in the string to the right, by default ``2``. top_level_indent : int, default 0 How much to indent right the entire content of the array, - by default ``0``. + by default ``0``. window : int How many items to preview at the begin and end of the array when the arrays is bigger than the window. diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 5f42d71c7e3..e86782daf7f 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -2349,7 +2349,8 @@ def table(data, names=None, schema=None, metadata=None, nthreads=None): "Expected pandas DataFrame, python dictionary or list of arrays") -def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): +def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None, + FieldMergeOptions field_merge_options=None): """ Concatenate pyarrow.Table objects. @@ -2371,9 +2372,13 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): tables : iterable of pyarrow.Table objects Pyarrow tables to concatenate into a single Table. promote : bool, default False - If True, concatenate tables with null-filling and null type promotion. + If True, concatenate tables with null-filling and type promotion. + See field_merge_options for the type promotion behavior. memory_pool : MemoryPool, default None For memory allocations, if required, otherwise use default pool. + field_merge_options : FieldMergeOptions, default None + The type promotion options; by default, null and only null can + be unified with another type. 
""" cdef: vector[shared_ptr[CTable]] c_tables @@ -2386,6 +2391,9 @@ def concat_tables(tables, c_bool promote=False, MemoryPool memory_pool=None): for table in tables: c_tables.push_back(table.sp_table) + if field_merge_options: + options.field_merge_options = field_merge_options.c_options + with nogil: options.unify_schemas = promote c_result_table = GetResultValue( diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index e5e27332d13..010c2a10cf5 100644 --- a/python/pyarrow/tests/test_table.py +++ b/python/pyarrow/tests/test_table.py @@ -1189,6 +1189,15 @@ def test_concat_tables_with_promotion(): pa.array([None, None, 1.0, 2.0], type=pa.float32()), ], ["int64_field", "float_field"])) + t3 = pa.Table.from_arrays( + [pa.array([1, 2], type=pa.int32())], ["int64_field"]) + result = pa.concat_tables( + [t1, t3], promote=True, + field_merge_options=pa.FieldMergeOptions.permissive()) + assert result.equals(pa.Table.from_arrays([ + pa.array([1, 2, 1, 2], type=pa.int64()), + ], ["int64_field"])) + def test_concat_tables_with_promotion_error(): t1 = pa.Table.from_arrays( diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index abd4c9e6a7a..e81a1d5534e 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -1830,7 +1830,6 @@ def unify_schemas(schemas, *, options=None): ---------- schemas : list of Schema Schemas to merge into a single one. - options : FieldMergeOptions, optional Options for merging duplicate fields.