diff --git a/cpp/src/arrow/array/diff.cc b/cpp/src/arrow/array/diff.cc index dc6bee95f60..7cc1a8e69ca 100644 --- a/cpp/src/arrow/array/diff.cc +++ b/cpp/src/arrow/array/diff.cc @@ -35,7 +35,6 @@ #include "arrow/util/logging.h" #include "arrow/util/range.h" #include "arrow/util/string.h" -#include "arrow/util/variant.h" #include "arrow/util/visibility.h" #include "arrow/vendored/datetime.h" #include "arrow/visitor_inline.h" @@ -89,83 +88,49 @@ static UnitSlice GetView(const UnionArray& array, int64_t index) { return UnitSlice{&array, index}; } -struct NullTag { - constexpr bool operator==(const NullTag& other) const { return true; } - constexpr bool operator!=(const NullTag& other) const { return false; } -}; - -template -class NullOr { - public: - using VariantType = util::variant; +using ValueComparator = std::function; - NullOr() : variant_(NullTag{}) {} - explicit NullOr(T t) : variant_(std::move(t)) {} - - template - NullOr(const ArrayType& array, int64_t index) { - if (array.IsNull(index)) { - variant_.emplace(NullTag{}); - } else { - variant_.emplace(GetView(array, index)); - } +struct ValueComparatorVisitor { + template + Status Visit(const T&) { + using ArrayType = typename TypeTraits::ArrayType; + out = [](const Array& base, int64_t base_index, const Array& target, + int64_t target_index) { + return (GetView(checked_cast(base), base_index) == + GetView(checked_cast(target), target_index)); + }; + return Status::OK(); } - bool operator==(const NullOr& other) const { return variant_ == other.variant_; } - bool operator!=(const NullOr& other) const { return variant_ != other.variant_; } + Status Visit(const NullType&) { return Status::NotImplemented("null type"); } - private: - VariantType variant_; -}; - -template -using ViewType = decltype(GetView(std::declval(), 0)); + Status Visit(const ExtensionType&) { return Status::NotImplemented("extension type"); } -template -class ViewGenerator { - public: - using View = ViewType; - - explicit ViewGenerator(const Array& array) - : array_(checked_cast(array)) { - DCHECK_EQ(array.null_count(), 0); + Status Visit(const DictionaryType&) { + return Status::NotImplemented("dictionary type"); } - View operator()(int64_t index) const { return GetView(array_, index); } + ValueComparator Create(const DataType& type) { + DCHECK_OK(VisitTypeInline(type, this)); + return out; + } - private: - const ArrayType& array_; + ValueComparator out; }; -template -internal::LazyRange> MakeViewRange(const Array& array) { - using Generator = ViewGenerator; - return internal::LazyRange(Generator(array), array.length()); +ValueComparator GetValueComparator(const DataType& type) { + ValueComparatorVisitor type_visitor; + return type_visitor.Create(type); } -template -class NullOrViewGenerator { - public: - using View = ViewType; - - explicit NullOrViewGenerator(const Array& array) - : array_(checked_cast(array)) {} - - NullOr operator()(int64_t index) const { - return array_.IsNull(index) ? NullOr() : NullOr(GetView(array_, index)); +// represents an intermediate state in the comparison of two arrays +struct EditPoint { + int64_t base, target; + bool operator==(EditPoint other) const { + return base == other.base && target == other.target; } - - private: - const ArrayType& array_; }; -template -internal::LazyRange> MakeNullOrViewRange( - const Array& array) { - using Generator = NullOrViewGenerator; - return internal::LazyRange(Generator(array), array.length()); -} - /// A generic sequence difference algorithm, based on /// /// E. W. Myers, "An O(ND) difference algorithm and its variations," @@ -181,17 +146,35 @@ internal::LazyRange> MakeNullOrViewRange( /// representation is minimal in the common case where the sequences differ only slightly, /// since most of the elements are shared between base and target and are represented /// implicitly. -template class QuadraticSpaceMyersDiff { public: - // represents an intermediate state in the comparison of two arrays - struct EditPoint { - Iterator base, target; + QuadraticSpaceMyersDiff(const Array& base, const Array& target, MemoryPool* pool) + : base_(base), + target_(target), + pool_(pool), + value_comparator_(GetValueComparator(*base.type())), + base_begin_(0), + base_end_(base.length()), + target_begin_(0), + target_end_(target.length()), + endpoint_base_({ExtendFrom({base_begin_, target_begin_}).base}), + insert_({true}) { + if ((base_end_ - base_begin_ == target_end_ - target_begin_) && + endpoint_base_[0] == base_end_) { + // trivial case: base == target + finish_index_ = 0; + } + } - bool operator==(EditPoint other) const { - return base == other.base && target == other.target; + bool ValuesEqual(int64_t base_index, int64_t target_index) const { + bool base_null = base_.IsNull(base_index); + bool target_null = target_.IsNull(target_index); + if (base_null || target_null) { + // If only one is null, then this is false, otherwise true + return base_null && target_null; } - }; + return value_comparator_(base_, base_index, target_, target_index); + } // increment the position within base (the element pointed to was deleted) // then extend maximally @@ -215,29 +198,13 @@ class QuadraticSpaceMyersDiff { // present in both sequences) EditPoint ExtendFrom(EditPoint p) const { for (; p.base != base_end_ && p.target != target_end_; ++p.base, ++p.target) { - if (*p.base != *p.target) { + if (!ValuesEqual(p.base, p.target)) { break; } } return p; } - QuadraticSpaceMyersDiff(Iterator base_begin, Iterator base_end, Iterator target_begin, - Iterator target_end) - : base_begin_(base_begin), - base_end_(base_end), - target_begin_(target_begin), - target_end_(target_end), - endpoint_base_({ExtendFrom({base_begin_, target_begin_}).base}), - insert_({true}) { - if (std::distance(base_begin_, base_end_) == - std::distance(target_begin_, target_end_) && - endpoint_base_[0] == base_end_) { - // trivial case: base == target - finish_index_ = 0; - } - } - // beginning of a range for storing per-edit state in endpoint_base_ and insert_ int64_t StorageOffset(int64_t edit_count) const { return edit_count * (edit_count + 1) / 2; @@ -342,98 +309,56 @@ class QuadraticSpaceMyersDiff { {field("insert", boolean()), field("run_length", int64())}); } + Result> Diff() { + while (!Done()) { + Next(); + } + return GetEdits(pool_); + } + private: + const Array& base_; + const Array& target_; + MemoryPool* pool_; + ValueComparator value_comparator_; int64_t finish_index_ = -1; int64_t edit_count_ = 0; - Iterator base_begin_, base_end_; - Iterator target_begin_, target_end_; + int64_t base_begin_, base_end_; + int64_t target_begin_, target_end_; // each element of endpoint_base_ is the furthest position in base reachable given an // edit_count and (# insertions) - (# deletions). Each bit of insert_ records whether // the corresponding furthest position was reached via an insertion or a deletion // (followed by a run of shared elements). See StorageOffset for the // layout of these vectors - std::vector endpoint_base_; + std::vector endpoint_base_; std::vector insert_; }; -struct DiffImpl { - Status Visit(const NullType&) { - bool insert = base_.length() < target_.length(); - auto run_length = std::min(base_.length(), target_.length()); - auto edit_count = std::max(base_.length(), target_.length()) - run_length; - - TypedBufferBuilder insert_builder(pool_); - RETURN_NOT_OK(insert_builder.Resize(edit_count + 1)); - insert_builder.UnsafeAppend(false); - TypedBufferBuilder run_length_builder(pool_); - RETURN_NOT_OK(run_length_builder.Resize(edit_count + 1)); - run_length_builder.UnsafeAppend(run_length); - if (edit_count > 0) { - insert_builder.UnsafeAppend(edit_count, insert); - run_length_builder.UnsafeAppend(edit_count, 0); - } - - std::shared_ptr insert_buf, run_length_buf; - RETURN_NOT_OK(insert_builder.Finish(&insert_buf)); - RETURN_NOT_OK(run_length_builder.Finish(&run_length_buf)); - - ARROW_ASSIGN_OR_RAISE( - out_, - StructArray::Make({std::make_shared(edit_count + 1, insert_buf), - std::make_shared(edit_count + 1, run_length_buf)}, - {field("insert", boolean()), field("run_length", int64())})); - return Status::OK(); - } - - template - Status Visit(const T&) { - using ArrayType = typename TypeTraits::ArrayType; - if (base_.null_count() == 0 && target_.null_count() == 0) { - auto base = MakeViewRange(base_); - auto target = MakeViewRange(target_); - ARROW_ASSIGN_OR_RAISE(out_, - Diff(base.begin(), base.end(), target.begin(), target.end())); - } else { - auto base = MakeNullOrViewRange(base_); - auto target = MakeNullOrViewRange(target_); - ARROW_ASSIGN_OR_RAISE(out_, - Diff(base.begin(), base.end(), target.begin(), target.end())); - } - return Status::OK(); - } - - Status Visit(const ExtensionType&) { - auto base = checked_cast(base_).storage(); - auto target = checked_cast(target_).storage(); - ARROW_ASSIGN_OR_RAISE(out_, arrow::Diff(*base, *target, pool_)); - return Status::OK(); - } - - Status Visit(const DictionaryType& t) { - return Status::NotImplemented("diffing arrays of type ", t); - } - - Result> Diff() { - RETURN_NOT_OK(VisitTypeInline(*base_.type(), this)); - return out_; - } - - template - Result> Diff(Iterator base_begin, Iterator base_end, - Iterator target_begin, Iterator target_end) { - QuadraticSpaceMyersDiff impl(base_begin, base_end, target_begin, - target_end); - while (!impl.Done()) { - impl.Next(); - } - return impl.GetEdits(pool_); - } - - const Array& base_; - const Array& target_; - MemoryPool* pool_; - std::shared_ptr out_; -}; +Result> NullDiff(const Array& base, const Array& target, + MemoryPool* pool) { + bool insert = base.length() < target.length(); + auto run_length = std::min(base.length(), target.length()); + auto edit_count = std::max(base.length(), target.length()) - run_length; + + TypedBufferBuilder insert_builder(pool); + RETURN_NOT_OK(insert_builder.Resize(edit_count + 1)); + insert_builder.UnsafeAppend(false); + TypedBufferBuilder run_length_builder(pool); + RETURN_NOT_OK(run_length_builder.Resize(edit_count + 1)); + run_length_builder.UnsafeAppend(run_length); + if (edit_count > 0) { + insert_builder.UnsafeAppend(edit_count, insert); + run_length_builder.UnsafeAppend(edit_count, 0); + } + + std::shared_ptr insert_buf, run_length_buf; + RETURN_NOT_OK(insert_builder.Finish(&insert_buf)); + RETURN_NOT_OK(run_length_builder.Finish(&run_length_buf)); + + return StructArray::Make({std::make_shared(edit_count + 1, insert_buf), + std::make_shared(edit_count + 1, run_length_buf)}, + {field("insert", boolean()), field("run_length", int64())}); +} Result> Diff(const Array& base, const Array& target, MemoryPool* pool) { @@ -441,7 +366,17 @@ Result> Diff(const Array& base, const Array& target return Status::TypeError("only taking the diff of like-typed arrays is supported."); } - return DiffImpl{base, target, pool, nullptr}.Diff(); + if (base.type()->id() == Type::NA) { + return NullDiff(base, target, pool); + } else if (base.type()->id() == Type::EXTENSION) { + auto base_storage = checked_cast(base).storage(); + auto target_storage = checked_cast(target).storage(); + return Diff(*base_storage, *target_storage, pool); + } else if (base.type()->id() == Type::DICTIONARY) { + return Status::NotImplemented("diffing arrays of type ", *base.type()); + } else { + return QuadraticSpaceMyersDiff(base, target, pool).Diff(); + } } using Formatter = std::function; diff --git a/cpp/src/arrow/array/diff.h b/cpp/src/arrow/array/diff.h index 7c091ee5dae..e0b85fb90f1 100644 --- a/cpp/src/arrow/array/diff.h +++ b/cpp/src/arrow/array/diff.h @@ -54,7 +54,7 @@ class MemoryPool; /// \return an edit script array which can be applied to base to produce target ARROW_EXPORT Result> Diff(const Array& base, const Array& target, - MemoryPool* pool); + MemoryPool* pool = default_memory_pool()); /// \brief visitor interface for easy traversal of an edit script /// diff --git a/cpp/src/arrow/array/diff_test.cc b/cpp/src/arrow/array/diff_test.cc index 4917d4524d1..827c8894ac6 100644 --- a/cpp/src/arrow/array/diff_test.cc +++ b/cpp/src/arrow/array/diff_test.cc @@ -586,6 +586,9 @@ TEST_F(DiffTest, DictionaryDiffFormatter) { )"; ASSERT_EQ(formatted.str(), formatted_expected_indices); + // Note: Diff doesn't work at the moment with dictionary arrays + ASSERT_RAISES(NotImplemented, Diff(*base_, *target_)); + // differing dictionaries target_dict = ArrayFromJSON(utf8(), R"(["b", "c", "a"])"); target_indices = base_indices;