diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index b05005dad6a..c79dfd00863 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -518,7 +518,11 @@ private ValueVector getVector(int index) { } public Object getObject(int index) { - return getVector(index).getObject(index); + ValueVector vector = getVector(index); + if (vector != null) { + return vector.getObject(index); + } + return null; } public byte[] get(int index) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java index ccbdc9c19f4..082d2ba1744 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/dictionary/Dictionary.java @@ -22,6 +22,7 @@ import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.util.Validator; /** * A dictionary (integer to Value mapping) that is used to facilitate @@ -63,11 +64,21 @@ public boolean equals(Object o) { return false; } Dictionary that = (Dictionary) o; - return Objects.equals(encoding, that.encoding) && Objects.equals(dictionary, that.dictionary); + return Objects.equals(encoding, that.encoding) && compareFieldVector(dictionary, that.dictionary); } @Override public int hashCode() { return Objects.hash(encoding, dictionary); } + + //TODO after vector api support compare two vectors, this should be cleaned up + private boolean compareFieldVector(FieldVector vector1, FieldVector vector2) { + try { + Validator.compareFieldVectors(vector1, vector2); + } catch (IllegalArgumentException e) { + return false; + } + return true; + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java index e0bd218d47b..2d6391b33c1 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryVector.java @@ -20,6 +20,7 @@ import static org.apache.arrow.vector.TestUtils.newVarBinaryVector; import static org.apache.arrow.vector.TestUtils.newVarCharVector; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; import java.nio.charset.StandardCharsets; @@ -67,7 +68,7 @@ public void terminate() throws Exception { public void testEncodeStrings() { // Create a new value vector try (final VarCharVector vector = newVarCharVector("foo", allocator); - final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) { + final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) { vector.allocateNew(512, 5); // set some values @@ -85,13 +86,14 @@ public void testEncodeStrings() { dictionaryVector.setSafe(2, two, 0, two.length); dictionaryVector.setValueCount(3); - Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); try (final ValueVector encoded = (FieldVector) DictionaryEncoder.encode(vector, dictionary)) { // verify indices assertEquals(IntVector.class, encoded.getClass()); - IntVector index = ((IntVector)encoded); + IntVector index = ((IntVector) encoded); assertEquals(5, index.getValueCount()); assertEquals(0, index.get(0)); assertEquals(1, index.get(1)); @@ -102,9 +104,9 @@ public void testEncodeStrings() { // now run through the decoder and verify we get the original back try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) { assertEquals(vector.getClass(), decoded.getClass()); - assertEquals(vector.getValueCount(), ((VarCharVector)decoded).getValueCount()); + assertEquals(vector.getValueCount(), ((VarCharVector) decoded).getValueCount()); for (int i = 0; i < 5; i++) { - assertEquals(vector.getObject(i), ((VarCharVector)decoded).getObject(i)); + assertEquals(vector.getObject(i), ((VarCharVector) decoded).getObject(i)); } } } @@ -115,7 +117,7 @@ public void testEncodeStrings() { public void testEncodeLargeVector() { // Create a new value vector try (final VarCharVector vector = newVarCharVector("foo", allocator); - final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) { + final VarCharVector dictionaryVector = newVarCharVector("dict", allocator);) { vector.allocateNew(); int count = 10000; @@ -131,7 +133,8 @@ public void testEncodeLargeVector() { dictionaryVector.setSafe(2, two, 0, two.length); dictionaryVector.setValueCount(3); - Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); + Dictionary dictionary = + new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null)); try (final ValueVector encoded = (FieldVector) DictionaryEncoder.encode(vector, dictionary)) { @@ -156,14 +159,6 @@ public void testEncodeLargeVector() { } } - private void writeListVector(UnionListWriter writer, int[] values) { - writer.startList(); - for (int v: values) { - writer.integer().writeInt(v); - } - writer.endList(); - } - @Test public void testEncodeList() { // Create a new value vector @@ -218,13 +213,6 @@ public void testEncodeList() { } } - private void writeStructVector(NullableStructWriter writer, int value1, long value2) { - writer.start(); - writer.integer("f0").writeInt(value1); - writer.bigInt("f1").writeBigInt(value2); - writer.end(); - } - @Test public void testEncodeStruct() { // Create a new value vector @@ -406,4 +394,200 @@ public void testEncodeUnion() { } } } + + @Test + public void testIntEquals() { + //test Int + try (final IntVector vector1 = new IntVector("", allocator); + final IntVector vector2 = new IntVector("", allocator)) { + + Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); + Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); + + vector1.allocateNew(3); + vector1.setValueCount(3); + vector2.allocateNew(3); + vector2.setValueCount(3); + + vector1.setSafe(0, 1); + vector1.setSafe(1, 2); + vector1.setSafe(2, 3); + + vector2.setSafe(0, 1); + vector2.setSafe(1, 2); + vector2.setSafe(2, 0); + + assertFalse(dict1.equals(dict2)); + + vector2.setSafe(2, 3); + assertTrue(dict1.equals(dict2)); + } + } + + @Test + public void testVarcharEquals() { + try (final VarCharVector vector1 = new VarCharVector("", allocator); + final VarCharVector vector2 = new VarCharVector("", allocator)) { + + Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); + Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); + + vector1.allocateNew(); + vector1.setValueCount(3); + vector2.allocateNew(); + vector2.setValueCount(3); + + // set some values + vector1.setSafe(0, zero, 0, zero.length); + vector1.setSafe(1, one, 0, one.length); + vector1.setSafe(2, two, 0, two.length); + + vector2.setSafe(0, zero, 0, zero.length); + vector2.setSafe(1, one, 0, one.length); + vector2.setSafe(2, one, 0, one.length); + + assertFalse(dict1.equals(dict2)); + + vector2.setSafe(2, two, 0, two.length); + assertTrue(dict1.equals(dict2)); + } + } + + @Test + public void testVarBinaryEquals() { + try (final VarBinaryVector vector1 = new VarBinaryVector("", allocator); + final VarBinaryVector vector2 = new VarBinaryVector("", allocator)) { + + Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); + Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); + + vector1.allocateNew(); + vector1.setValueCount(3); + vector2.allocateNew(); + vector2.setValueCount(3); + + // set some values + vector1.setSafe(0, zero, 0, zero.length); + vector1.setSafe(1, one, 0, one.length); + vector1.setSafe(2, two, 0, two.length); + + vector2.setSafe(0, zero, 0, zero.length); + vector2.setSafe(1, one, 0, one.length); + vector2.setSafe(2, one, 0, one.length); + + assertFalse(dict1.equals(dict2)); + + vector2.setSafe(2, two, 0, two.length); + assertTrue(dict1.equals(dict2)); + } + } + + @Test + public void testListEquals() { + try (final ListVector vector1 = ListVector.empty("", allocator); + final ListVector vector2 = ListVector.empty("", allocator);) { + + Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); + Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); + + UnionListWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + //set some values + writeListVector(writer1, new int[] {1, 2}); + writeListVector(writer1, new int[] {3, 4}); + writeListVector(writer1, new int[] {5, 6}); + writer1.setValueCount(3); + + UnionListWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + //set some values + writeListVector(writer2, new int[] {1, 2}); + writeListVector(writer2, new int[] {3, 4}); + writeListVector(writer2, new int[] {5, 6}); + writer2.setValueCount(3); + + assertTrue(dict1.equals(dict2)); + } + } + + @Test + public void testStructEquals() { + try (final StructVector vector1 = StructVector.empty("", allocator); + final StructVector vector2 = StructVector.empty("", allocator);) { + vector1.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector1.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + vector2.addOrGet("f0", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class); + vector2.addOrGet("f1", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + + Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); + Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); + + NullableStructWriter writer1 = vector1.getWriter(); + writer1.allocate(); + + writeStructVector(writer1, 1, 10L); + writeStructVector(writer1, 2, 20L); + writer1.setValueCount(2); + + NullableStructWriter writer2 = vector2.getWriter(); + writer2.allocate(); + + writeStructVector(writer2, 1, 10L); + writeStructVector(writer2, 2, 20L); + writer2.setValueCount(2); + + assertTrue(dict1.equals(dict2)); + } + } + + @Test + public void testUnionEquals() { + try (final UnionVector vector1 = new UnionVector("", allocator, null); + final UnionVector vector2 = new UnionVector("", allocator, null);) { + + final NullableUInt4Holder uInt4Holder = new NullableUInt4Holder(); + uInt4Holder.value = 10; + uInt4Holder.isSet = 1; + + final NullableIntHolder intHolder = new NullableIntHolder(); + uInt4Holder.value = 20; + uInt4Holder.isSet = 1; + + vector1.setType(0, Types.MinorType.UINT4); + vector1.setSafe(0, uInt4Holder); + + vector1.setType(2, Types.MinorType.INT); + vector1.setSafe(2, intHolder); + vector1.setValueCount(3); + + vector2.setType(0, Types.MinorType.UINT4); + vector2.setSafe(0, uInt4Holder); + + vector2.setType(2, Types.MinorType.INT); + vector2.setSafe(2, intHolder); + vector2.setValueCount(3); + + Dictionary dict1 = new Dictionary(vector1, new DictionaryEncoding(1L, false, null)); + Dictionary dict2 = new Dictionary(vector2, new DictionaryEncoding(1L, false, null)); + + assertTrue(dict1.equals(dict2)); + } + } + + private void writeStructVector(NullableStructWriter writer, int value1, long value2) { + writer.start(); + writer.integer("f0").writeInt(value1); + writer.bigInt("f1").writeBigInt(value2); + writer.end(); + } + + private void writeListVector(UnionListWriter writer, int[] values) { + writer.startList(); + for (int v: values) { + writer.integer().writeInt(v); + } + writer.endList(); + } }