diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index db3c8a89f5e..2c28c66beb3 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -550,7 +550,7 @@ public Iterator iterator() { return vectors.iterator(); } - private ValueVector getVector(int index) { + /* Visibility widened from private to public: RangeEqualsVisitor.compareUnionVectors calls this once per slot. The type id for the slot is read from typeBuffer below. */ public ValueVector getVector(int index) { int type = typeBuffer.getByte(index * TYPE_WIDTH); switch (MinorType.values()[type]) { case NULL: diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java index ab5ce94e7e3..bcf8c64e0ce 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/ApproxEqualsVisitor.java @@ -109,8 +109,10 @@ public Boolean visit(BaseFixedWidthVector left, Range range) { } @Override - protected ApproxEqualsVisitor createInnerVisitor(ValueVector left, ValueVector right) { - return new ApproxEqualsVisitor(left, right, floatDiffFunction.clone(), doubleDiffFunction.clone()); + /* The override now takes the type comparator from the caller and hands it to the new visitor; the float and double diff functions are cloned for it. */ protected ApproxEqualsVisitor createInnerVisitor( + ValueVector left, ValueVector right, + BiFunction typeComparator) { + return new ApproxEqualsVisitor(left, right, floatDiffFunction.clone(), doubleDiffFunction.clone(), typeComparator); } private boolean float4ApproxEquals(Range range) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java index 49d6125a82c..f658af79ed0 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -24,7 +24,6 @@ import org.apache.arrow.util.Preconditions; import 
org.apache.arrow.vector.BaseFixedWidthVector; import org.apache.arrow.vector.BaseVariableWidthVector; -import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.NullVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; @@ -186,12 +185,9 @@ public Boolean visit(NullVector left, Range range) { return true; } - /** - * Creates a visitor to visit child vectors. - * It is used for complex vector types. - * @return the visitor for child vecors. - */ - protected RangeEqualsVisitor createInnerVisitor(ValueVector leftInner, ValueVector rightInner) { + /** Creates the visitor used to visit the child vectors of a complex vector, passing typeComparator on to the new visitor. */ protected RangeEqualsVisitor createInnerVisitor( + ValueVector leftInner, ValueVector rightInner, + BiFunction typeComparator) { return new RangeEqualsVisitor(leftInner, rightInner, typeComparator); } @@ -199,16 +195,23 @@ protected boolean compareUnionVectors(Range range) { UnionVector leftVector = (UnionVector) left; UnionVector rightVector = (UnionVector) right; - List leftChildren = leftVector.getChildrenFromFields(); - List rightChildren = rightVector.getChildrenFromFields(); - - if (leftChildren.size() != rightChildren.size()) { - return false; - } - - for (int k = 0; k < leftChildren.size(); k++) { - RangeEqualsVisitor visitor = createInnerVisitor(leftChildren.get(k), rightChildren.get(k)); - if (!visitor.rangeEquals(range)) { + /* One-slot window, re-pointed at slot i of each side on every iteration. */ Range subRange = new Range(0, 0, 1); + for (int i = 0; i < range.getLength(); i++) { + subRange.setLeftStart(range.getLeftStart() + i).setRightStart(range.getRightStart() + i); + ValueVector leftSubVector = leftVector.getVector(range.getLeftStart() + i); + ValueVector rightSubVector = rightVector.getVector(range.getRightStart() + i); + + /* A null from getVector on either side: the slots match only when both sides are null. */ if (leftSubVector == null || rightSubVector == null) { + if (leftSubVector == rightSubVector) { + continue; + } else { + return false; + } + } + /* The type check is tied to this slot's right sub-vector; the comparator's second argument is ignored. */ TypeEqualsVisitor typeVisitor = new TypeEqualsVisitor(rightSubVector); + RangeEqualsVisitor visitor = + createInnerVisitor(leftSubVector, 
rightSubVector, (left, right) -> typeVisitor.equals(left)); + if (!visitor.rangeEquals(subRange)) { return false; } } @@ -225,7 +228,8 @@ protected boolean compareStructVectors(Range range) { } for (String name : leftChildNames) { - RangeEqualsVisitor visitor = createInnerVisitor(leftVector.getChild(name), rightVector.getChild(name)); + RangeEqualsVisitor visitor = + createInnerVisitor(leftVector.getChild(name), rightVector.getChild(name), /*type comparator*/ null); if (!visitor.rangeEquals(range)) { return false; } @@ -304,7 +308,8 @@ protected boolean compareListVectors(Range range) { ListVector leftVector = (ListVector) left; ListVector rightVector = (ListVector) right; - RangeEqualsVisitor innerVisitor = createInnerVisitor(leftVector.getDataVector(), rightVector.getDataVector()); + RangeEqualsVisitor innerVisitor = + createInnerVisitor(leftVector.getDataVector(), rightVector.getDataVector(), /*type comparator*/ null); Range innerRange = new Range(); for (int i = 0; i < range.getLength(); i++) { @@ -350,7 +355,8 @@ protected boolean compareFixedSizeListVectors(Range range) { } int listSize = leftVector.getListSize(); - RangeEqualsVisitor innerVisitor = createInnerVisitor(leftVector.getDataVector(), rightVector.getDataVector()); + RangeEqualsVisitor innerVisitor = + createInnerVisitor(leftVector.getDataVector(), rightVector.getDataVector(), /*type comparator*/ null); Range innerRange = new Range(0, 0, listSize); for (int i = 0; i < range.getLength(); i++) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java index 183c8abac18..c35fd4107a9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java @@ -282,6 +282,57 @@ public void testUnionVectorRangeEquals() { } } + /** + * Test comparing two union 
vectors. + * The two vectors are different in total, but have a range with equal values. + */ + @Test + public void testUnionVectorSubRangeEquals() { + try (final UnionVector vector1 = new UnionVector("union", allocator, null, null); + final UnionVector vector2 = new UnionVector("union", allocator, null, null);) { + + final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder(); + uInt4Holder.value = 10; + uInt4Holder.isSet = 1; + + final NullableIntHolder intHolder = new NullableIntHolder(); + intHolder.value = 20; + intHolder.isSet = 1; + + vector1.setType(0, Types.MinorType.UINT4); + vector1.setSafe(0, uInt4Holder); + + vector1.setType(1, Types.MinorType.INT); + vector1.setSafe(1, intHolder); + + vector1.setType(2, Types.MinorType.INT); + vector1.setSafe(2, intHolder); + + vector1.setType(3, Types.MinorType.INT); + vector1.setSafe(3, intHolder); + + vector1.setValueCount(4); + + vector2.setType(0, Types.MinorType.UINT4); + vector2.setSafe(0, uInt4Holder); + + vector2.setType(1, Types.MinorType.INT); + vector2.setSafe(1, intHolder); + + vector2.setType(2, Types.MinorType.INT); + vector2.setSafe(2, intHolder); + + vector2.setType(3, Types.MinorType.UINT4); + vector2.setSafe(3, uInt4Holder); + + vector2.setValueCount(4); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector1, vector2); + /* Slot 3 is INT in vector1 but UINT4 in vector2, so the full four-slot range is expected to differ. */ assertFalse(visitor.rangeEquals(new Range(0, 0, 4))); + /* Slots 1 and 2 are set from the same INT holder in both vectors. */ assertTrue(visitor.rangeEquals(new Range(1, 1, 2))); + } + } + @Ignore @Test public void testEqualsWithOutTypeCheck() {