diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index 6220e518cad..a809c508a3f 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -675,20 +675,6 @@ public int hashCode(int index) { return getVector(index).hashCode(index); } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - if (to == null) { - return false; - } - Preconditions.checkArgument(index >= 0 && index < valueCount, - "index %s out of range[0, %s]:", index, valueCount - 1); - Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), - "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - - RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); - return this.accept(visitor); - } - @Override public boolean accept(RangeEqualsVisitor visitor) { return visitor.visit(this); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java index a1663731fe3..eba70e8527c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseFixedWidthVector.java @@ -882,21 +882,6 @@ public int hashCode(int index) { return ByteFunctionHelpers.hash(this.getDataBuffer(), start, end); } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - if (to == null) { - return false; - } - - Preconditions.checkArgument(index >= 0 && index < valueCount, - "index %s out of range[0, %s]:", index, valueCount - 1); - Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), - "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - - RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); - return this.accept(visitor); - } - @Override public boolean accept(RangeEqualsVisitor visitor) { return visitor.visit(this); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index b7aa8161bd0..c5b3569a2c5 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -1365,20 +1365,6 @@ public int hashCode(int index) { return ByteFunctionHelpers.hash(this.getDataBuffer(), start, end); } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - if (to == null) { - return false; - } - Preconditions.checkArgument(index >= 0 && index < valueCount, - "index %s out of range[0, %s]:", index, valueCount - 1); - Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), - "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - - RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); - return this.accept(visitor); - } - @Override public boolean accept(RangeEqualsVisitor visitor) { return visitor.visit(this); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java index aed24f259ee..af5cb651670 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ValueVector.java @@ -244,15 +244,6 @@ public interface ValueVector extends Closeable, Iterable { */ int hashCode(int index); - /** - * Check whether the element in index equals to the element in targetIndex from the target vector. - * @param index index to compare in this vector - * @param target target vector - * @param targetIndex index to compare in target vector - * @return true if equals, otherwise false. - */ - boolean equals(int index, ValueVector target, int targetIndex); - /** * Copy a cell value from a particular index in source vector to a particular * position in this vector. diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java b/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java index 4a9f30aed47..d24d7ff4d34 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ZeroVector.java @@ -251,11 +251,6 @@ public int hashCode(int index) { return 0; } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - return false; - } - @Override public void copyFrom(int fromIndex, int thisIndex, ValueVector from) { throw new UnsupportedOperationException(); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java index a5826e7e48f..537746eff6c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/RangeEqualsVisitor.java @@ -20,6 +20,7 @@ import java.util.List; import org.apache.arrow.memory.util.ByteFunctionHelpers; +import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseFixedWidthVector; import org.apache.arrow.vector.BaseVariableWidthVector; import org.apache.arrow.vector.FieldVector; @@ -37,46 +38,78 @@ public class RangeEqualsVisitor { protected final ValueVector right; - protected final int leftStart; - protected final int rightStart; - protected final int length; + protected int leftStart; + protected int rightStart; + protected int length; + + protected boolean typeCheckNeeded = true; /** * Constructs a new instance. */ - public RangeEqualsVisitor(ValueVector right, int leftStart, int rightStart, int length) { + public RangeEqualsVisitor(ValueVector right, int rightStart, int leftStart, int length, boolean typeCheckNeeded) { this.leftStart = leftStart; this.rightStart = rightStart; this.right = right; this.length = length; + this.typeCheckNeeded = typeCheckNeeded; + Preconditions.checkArgument(length >= 0, "length must be non negative"); + } + + /** + * Constructs a new instance. + */ + public RangeEqualsVisitor(ValueVector right, int leftStart, int rightStart, int length) { + this(right, rightStart, leftStart, length, true); + } + + /** + * Do some validation work, like type check and indices check. + */ + private boolean validate(ValueVector left) { + + if (!compareValueVector(left, right)) { + return false; + } + + Preconditions.checkArgument(leftStart >= 0, + "leftStart %s must be non negative.", leftStart); + Preconditions.checkArgument((leftStart + length) <= left.getValueCount(), + "(leftStart + length) %s out of range[0, %s].", 0, left.getValueCount()); + Preconditions.checkArgument(rightStart >= 0, + "rightStart %s must be non negative.", rightStart); + Preconditions.checkArgument((rightStart + length) <= right.getValueCount(), + "(rightStart + length) %s out of range[0, %s].", 0, right.getValueCount()); + + return true; } public boolean visit(BaseFixedWidthVector left) { - return compareBaseFixedWidthVectors(left); + return validate(left) && compareBaseFixedWidthVectors(left); } public boolean visit(BaseVariableWidthVector left) { - return compareBaseVariableWidthVectors(left); + return validate(left) && compareBaseVariableWidthVectors(left); } public boolean visit(ListVector left) { - return compareListVectors(left); + return validate(left) && compareListVectors(left); } public boolean visit(FixedSizeListVector left) { - return compareFixedSizeListVectors(left); + return validate(left) && compareFixedSizeListVectors(left); } public boolean visit(NonNullableStructVector left) { - return compareStructVectors(left); + return validate(left) && compareStructVectors(left); } public boolean visit(UnionVector left) { - return compareUnionVectors(left); + return validate(left) && compareUnionVectors(left); } public boolean visit(ZeroVector left) { - return compareValueVector(left, right); + return validate(left); } public boolean visit(ValueVector left) { @@ -84,15 +117,14 @@ public boolean visit(ValueVector left) { } protected boolean compareValueVector(ValueVector left, ValueVector right) { + if (!typeCheckNeeded) { + return true; + } return left.getField().getType().equals(right.getField().getType()); } protected boolean compareUnionVectors(UnionVector left) { - if (!compareValueVector(left, right)) { - return false; - } - UnionVector rightVector = (UnionVector) right; List leftChildren = left.getChildrenFromFields(); @@ -113,9 +145,6 @@ protected boolean compareUnionVectors(UnionVector left) { } protected boolean compareStructVectors(NonNullableStructVector left) { - if (!compareValueVector(left, right)) { - return false; - } NonNullableStructVector rightVector = (NonNullableStructVector) right; @@ -136,10 +165,6 @@ protected boolean compareStructVectors(NonNullableStructVector left) { protected boolean compareBaseFixedWidthVectors(BaseFixedWidthVector left) { - if (!compareValueVector(left, right)) { - return false; - } - for (int i = 0; i < length; i++) { int leftIndex = leftStart + i; int rightIndex = rightStart + i; @@ -152,14 +177,14 @@ protected boolean compareBaseFixedWidthVectors(BaseFixedWidthVector left) { int typeWidth = left.getTypeWidth(); if (!isNull) { - int startByteLeft = typeWidth * leftIndex; - int endByteLeft = typeWidth * (leftIndex + 1); + int startIndexLeft = typeWidth * leftIndex; + int endIndexLeft = typeWidth * (leftIndex + 1); - int startByteRight = typeWidth * rightIndex; - int endByteRight = typeWidth * (rightIndex + 1); + int startIndexRight = typeWidth * rightIndex; + int endIndexRight = typeWidth * (rightIndex + 1); - int ret = ByteFunctionHelpers.equal(left.getDataBuffer(), startByteLeft, endByteLeft, - right.getDataBuffer(), startByteRight, endByteRight); + int ret = ByteFunctionHelpers.equal(left.getDataBuffer(), startIndexLeft, endIndexLeft, + right.getDataBuffer(), startIndexRight, endIndexRight); if (ret == 0) { return false; @@ -170,9 +195,6 @@ protected boolean compareBaseFixedWidthVectors(BaseFixedWidthVector left) { } protected boolean compareBaseVariableWidthVectors(BaseVariableWidthVector left) { - if (!compareValueVector(left, right)) { - return false; - } for (int i = 0; i < length; i++) { int leftIndex = leftStart + i; @@ -186,14 +208,14 @@ protected boolean compareBaseVariableWidthVectors(BaseVariableWidthVector left) int offsetWidth = BaseVariableWidthVector.OFFSET_WIDTH; if (!isNull) { - final int startByteLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth); - final int endByteLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth); + final int startIndexLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth); + final int endIndexLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth); - final int startByteRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth); - final int endByteRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth); + final int startIndexRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth); + final int endIndexRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth); - int ret = ByteFunctionHelpers.equal(left.getDataBuffer(), startByteLeft, endByteLeft, - right.getDataBuffer(), startByteRight, endByteRight); + int ret = ByteFunctionHelpers.equal(left.getDataBuffer(), startIndexLeft, endIndexLeft, + right.getDataBuffer(), startIndexRight, endIndexRight); if (ret == 0) { return false; @@ -204,9 +226,6 @@ protected boolean compareBaseVariableWidthVectors(BaseVariableWidthVector left) } protected boolean compareListVectors(ListVector left) { - if (!compareValueVector(left, right)) { - return false; - } for (int i = 0; i < length; i++) { int leftIndex = leftStart + i; @@ -220,21 +239,21 @@ protected boolean compareListVectors(ListVector left) { int offsetWidth = BaseRepeatedValueVector.OFFSET_WIDTH; if (!isNull) { - final int startByteLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth); - final int endByteLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth); + final int startIndexLeft = left.getOffsetBuffer().getInt(leftIndex * offsetWidth); + final int endIndexLeft = left.getOffsetBuffer().getInt((leftIndex + 1) * offsetWidth); - final int startByteRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth); - final int endByteRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth); + final int startIndexRight = right.getOffsetBuffer().getInt(rightIndex * offsetWidth); + final int endIndexRight = right.getOffsetBuffer().getInt((rightIndex + 1) * offsetWidth); - if ((endByteLeft - startByteLeft) != (endByteRight - startByteRight)) { + if ((endIndexLeft - startIndexLeft) != (endIndexRight - startIndexRight)) { return false; } ValueVector leftDataVector = left.getDataVector(); ValueVector rightDataVector = ((ListVector)right).getDataVector(); - if (!leftDataVector.accept(new RangeEqualsVisitor(rightDataVector, startByteLeft, - startByteRight, (endByteLeft - startByteLeft)))) { + if (!leftDataVector.accept(new RangeEqualsVisitor(rightDataVector, startIndexLeft, + startIndexRight, (endIndexLeft - startIndexLeft)))) { return false; } } @@ -243,9 +262,6 @@ protected boolean compareListVectors(ListVector left) { } protected boolean compareFixedSizeListVectors(FixedSizeListVector left) { - if (!compareValueVector(left, right)) { - return false; - } if (left.getListSize() != ((FixedSizeListVector)right).getListSize()) { return false; @@ -263,26 +279,25 @@ protected boolean compareFixedSizeListVectors(FixedSizeListVector left) { int listSize = left.getListSize(); if (!isNull) { - final int startByteLeft = leftIndex * listSize; - final int endByteLeft = (leftIndex + 1) * listSize; + final int startIndexLeft = leftIndex * listSize; + final int endIndexLeft = (leftIndex + 1) * listSize; - final int startByteRight = rightIndex * listSize; - final int endByteRight = (rightIndex + 1) * listSize; + final int startIndexRight = rightIndex * listSize; + final int endIndexRight = (rightIndex + 1) * listSize; - if ((endByteLeft - startByteLeft) != (endByteRight - startByteRight)) { + if ((endIndexLeft - startIndexLeft) != (endIndexRight - startIndexRight)) { return false; } ValueVector leftDataVector = left.getDataVector(); ValueVector rightDataVector = ((FixedSizeListVector)right).getDataVector(); - if (!leftDataVector.accept(new RangeEqualsVisitor(rightDataVector, startByteLeft, startByteRight, - (endByteLeft - startByteLeft)))) { + if (!leftDataVector.accept(new RangeEqualsVisitor(rightDataVector, startIndexLeft, startIndexRight, + (endIndexLeft - startIndexLeft)))) { return false; } } } return true; } - } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java b/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java index 47071dd1958..dfaf45fa43f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/compare/VectorEqualsVisitor.java @@ -26,7 +26,11 @@ public class VectorEqualsVisitor extends RangeEqualsVisitor { public VectorEqualsVisitor(ValueVector right) { - super(Preconditions.checkNotNull(right), 0, 0, right.getValueCount()); + this(right, true); + } + + public VectorEqualsVisitor(ValueVector right, boolean typeCheckNeeded) { + super(Preconditions.checkNotNull(right), 0, 0, right.getValueCount(), typeCheckNeeded); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java index 6bdf81753b7..bcebae68966 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/FixedSizeListVector.java @@ -535,20 +535,6 @@ public int hashCode(int index) { return hash; } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - if (to == null) { - return false; - } - Preconditions.checkArgument(index >= 0 && index < valueCount, - "index %s out of range[0, %s]:", index, valueCount - 1); - Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), - "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - - RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); - return this.accept(visitor); - } - @Override public boolean accept(RangeEqualsVisitor visitor) { return visitor.visit(this); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index d6935dedb87..094b65874c9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -28,7 +28,6 @@ import org.apache.arrow.memory.BaseAllocator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.AddOrGetResult; import org.apache.arrow.vector.BitVectorHelper; import org.apache.arrow.vector.BufferBacked; @@ -427,20 +426,6 @@ public int hashCode(int index) { return hash; } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - if (to == null) { - return false; - } - Preconditions.checkArgument(index >= 0 && index < valueCount, - "index %s out of range[0, %s]:", index, valueCount - 1); - Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), - "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - - RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); - return this.accept(visitor); - } - private class TransferImpl implements TransferPair { ListVector to; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java index 995751ed0b3..54cff432766 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/NonNullableStructVector.java @@ -26,7 +26,6 @@ import java.util.Map; import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.DensityAwareVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; @@ -305,20 +304,6 @@ public boolean accept(RangeEqualsVisitor visitor) { return visitor.visit(this); } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - if (to == null) { - return false; - } - Preconditions.checkArgument(index >= 0 && index < valueCount, - "index %s out of range[0, %s]:", index, valueCount - 1); - Preconditions.checkArgument(toIndex >= 0 && toIndex < to.getValueCount(), - "index %s out of range[0, %s]:", index, to.getValueCount() - 1); - - RangeEqualsVisitor visitor = new RangeEqualsVisitor(to, index, toIndex, 1); - return this.accept(visitor); - } - @Override public boolean isNull(int index) { return false; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java index f6cf2d38416..e7d0727a3bd 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/DictionaryHashTable.java @@ -20,6 +20,7 @@ import java.util.Objects; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; /** * HashTable used for Dictionary encoding. It holds two vectors (the vector to encode and dictionary vector) @@ -140,7 +141,7 @@ public int getIndex(int indexInArray, ValueVector toEncode) { for (DictionaryHashTable.Entry e = table[index]; e != null ; e = e.next) { if ((e.hash == hash)) { int dictIndex = e.index; - if (dictionary.equals(dictIndex, toEncode, indexInArray)) { + if (toEncode.accept(new RangeEqualsVisitor(dictionary, dictIndex, indexInArray, 1, false))) { return dictIndex; } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index d344246fd59..3092492d7ac 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -36,6 +36,7 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.memory.util.ArrowBufPointer; +import org.apache.arrow.vector.compare.RangeEqualsVisitor; import org.apache.arrow.vector.compare.VectorEqualsVisitor; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.StructVector; @@ -2268,7 +2269,7 @@ public void testZeroVectorEquals() { @Test public void testZeroVectorNotEquals() { try (final IntVector intVector = new IntVector("int", allocator); - final ZeroVector zeroVector = new ZeroVector()) { + final ZeroVector zeroVector = new ZeroVector()) { VectorEqualsVisitor zeroVisitor = new VectorEqualsVisitor(zeroVector); assertFalse(intVector.accept(zeroVisitor)); @@ -2633,7 +2634,8 @@ public void testEqualsWithIndexOutOfRange() { vector2.setSafe(0, 1); vector2.setSafe(1, 2); - assertTrue(vector1.equals(3, vector2, 2)); + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector2, 3, 2, 1); + assertTrue(vector1.accept(visitor)); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java new file mode 100644 index 00000000000..847da35d56c --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/compare/TestRangeEqualsVisitor.java @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector.compare; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.nio.charset.Charset; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.ZeroVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.UnionVector; +import org.apache.arrow.vector.complex.impl.NullableStructWriter; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.holders.NullableIntHolder; +import org.apache.arrow.vector.holders.NullableUInt4Holder; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestRangeEqualsVisitor { + + private BufferAllocator allocator; + + @Before + public void init() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + private static final Charset utf8Charset = Charset.forName("UTF-8"); + private static final byte[] STR1 = "AAAAA1".getBytes(utf8Charset); + private static final byte[] STR2 = "BBBBBBBBB2".getBytes(utf8Charset); + private static final byte[] STR3 = "CCCC3".getBytes(utf8Charset); + + @After + public void terminate() throws Exception { + allocator.close(); + } + + @Test + public void testIntVectorEqualsWithNull() { + try (final IntVector vector1 = new IntVector("int", allocator); + final IntVector vector2 = new IntVector("int", allocator)) { + + vector1.allocateNew(2); + vector1.setValueCount(2); + vector2.allocateNew(2); + vector2.setValueCount(2); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + + vector2.setSafe(0, 1); + VectorEqualsVisitor visitor = new VectorEqualsVisitor(vector2); + + assertFalse(vector1.accept(visitor)); + } + } + + @Test + public void testBaseFixedWidthVectorRangeEqual() { + try (final IntVector vector1 = new IntVector("int", allocator); + final IntVector vector2 = new IntVector("int", allocator)) { + + vector1.allocateNew(5); + vector1.setValueCount(5); + vector2.allocateNew(5); + vector2.setValueCount(5); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + vector1.setSafe(2, 3); + vector1.setSafe(3, 4); + vector1.setSafe(4, 5); + + vector2.setSafe(0, 11); + vector2.setSafe(1, 2); + vector2.setSafe(2,3); + vector2.setSafe(3,4); + vector2.setSafe(4,55); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector2, 1, 1, 3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testBaseVariableVectorRangeEquals() { + try (final VarCharVector vector1 = new VarCharVector("varchar", allocator); + final VarCharVector vector2 = new VarCharVector("varchar", allocator)) { + + vector1.allocateNew(); + vector2.allocateNew(); + + // set some values + vector1.setSafe(0, STR1, 0, STR1.length); + vector1.setSafe(1, STR2, 0, STR2.length); + vector1.setSafe(2, STR3, 0, STR3.length); + vector1.setSafe(3, STR2, 0, STR2.length); + vector1.setSafe(4, STR1, 0, STR1.length); + vector1.setValueCount(5); + + vector2.setSafe(0, STR1, 0, STR1.length); + vector2.setSafe(1, STR2, 0, STR2.length); + vector2.setSafe(2, STR3, 0, STR3.length); + vector2.setSafe(3, STR2, 0, STR2.length); + vector2.setSafe(4, STR1, 0, STR1.length); + vector2.setValueCount(5); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector2, 1, 1, 3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testListVectorRangeEquals() { + try (final ListVector vector1 = ListVector.empty("list", allocator); + final ListVector vector2 = ListVector.empty("list", allocator);) { + + UnionListWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + //set some values + writeListVector(writer1, new int[] {1, 2}); + writeListVector(writer1, new int[] {3, 4}); + writeListVector(writer1, new int[] {5, 6}); + writeListVector(writer1, new int[] {7, 8}); + writeListVector(writer1, new int[] {9, 10}); + writer1.setValueCount(5); + + UnionListWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + //set some values + writeListVector(writer2, new int[] {0, 0}); + writeListVector(writer2, new int[] {3, 4}); + writeListVector(writer2, new int[] {5, 6}); + writeListVector(writer2, new int[] {7, 8}); + writeListVector(writer2, new int[] {0, 0}); + writer2.setValueCount(5); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector2, 1, 1, 3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testStructVectorRangeEquals() { + try (final StructVector vector1 = StructVector.empty("struct", allocator); + final StructVector vector2 = StructVector.empty("struct", allocator);) { + vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector2.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + + NullableStructWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + writeStructVector(writer1, 1, 10L); + writeStructVector(writer1, 2, 20L); + writeStructVector(writer1, 3, 30L); + writeStructVector(writer1, 4, 40L); + writeStructVector(writer1, 5, 50L); + writer1.setValueCount(5); + + NullableStructWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + writeStructVector(writer2, 0, 00L); + writeStructVector(writer2, 2, 20L); + writeStructVector(writer2, 3, 30L); + writeStructVector(writer2, 4, 40L); + writeStructVector(writer2, 0, 0L); + writer2.setValueCount(5); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector2, 1, 1, 3); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testUnionVectorRangeEquals() { + try (final UnionVector vector1 = new UnionVector("union", allocator, null); + final UnionVector vector2 = new UnionVector("union", allocator, null);) { + + final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder(); + uInt4Holder.value = 10; + uInt4Holder.isSet = 1; + + final NullableIntHolder intHolder = new NullableIntHolder(); + uInt4Holder.value = 20; + uInt4Holder.isSet = 1; + + vector1.setType(0, Types.MinorType.UINT4); + vector1.setSafe(0, uInt4Holder); + + vector1.setType(1, Types.MinorType.INT); + vector1.setSafe(1, intHolder); + + vector1.setType(2, Types.MinorType.INT); + vector1.setSafe(2, intHolder); + vector1.setValueCount(3); + + vector2.setType(0, Types.MinorType.UINT4); + vector2.setSafe(0, uInt4Holder); + + vector2.setType(1, Types.MinorType.INT); + vector2.setSafe(1, intHolder); + + vector2.setType(2, Types.MinorType.INT); + vector2.setSafe(2, intHolder); + vector2.setValueCount(3); + + RangeEqualsVisitor visitor = new RangeEqualsVisitor(vector2, 1, 1, 2); + assertTrue(vector1.accept(visitor)); + } + } + + @Test + public void testEqualsWithOutTypeCheck() { + try (final IntVector intVector = new IntVector("int", allocator); + final ZeroVector zeroVector = new ZeroVector()) { + + VectorEqualsVisitor zeroVisitor = new VectorEqualsVisitor(zeroVector, false); + assertTrue(intVector.accept(zeroVisitor)); + + VectorEqualsVisitor intVisitor = new VectorEqualsVisitor(intVector, false); + assertTrue(zeroVector.accept(intVisitor)); + } + } + + private void writeStructVector(NullableStructWriter writer, int value1, long value2) { + writer.start(); + writer.integer("f0").writeInt(value1); + writer.bigInt("f1").writeBigInt(value2); + writer.end(); + } + + private void writeListVector(UnionListWriter writer, int[] values) { + writer.startList(); + for (int v: values) { + writer.integer().writeInt(v); + } + writer.endList(); + } +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java index 792bd29903b..5f7522838d8 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/types/pojo/TestExtensionType.java @@ -34,7 +34,6 @@ import org.apache.arrow.vector.ExtensionTypeVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.FixedSizeBinaryVector; -import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowFileWriter; @@ -210,11 +209,6 @@ public int hashCode(int index) { return getUnderlyingVector().hashCode(index); } - @Override - public boolean equals(int index, ValueVector to, int toIndex) { - return getUnderlyingVector().equals(index, to, toIndex); - } - public void set(int index, UUID uuid) { ByteBuffer bb = ByteBuffer.allocate(16); bb.putLong(uuid.getMostSignificantBits());