diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java index 345fa592241..9592f3975ab 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DictionaryUtility.java @@ -49,16 +49,13 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se return field; } DictionaryEncoding encoding = field.getDictionary(); - List children = field.getChildren(); + List children; - List updatedChildren = new ArrayList<>(children.size()); - for (Field child : children) { - updatedChildren.add(toMessageFormat(child, provider, dictionaryIdsUsed)); - } ArrowType type; if (encoding == null) { type = field.getType(); + children = field.getChildren(); } else { long id = encoding.getId(); Dictionary dictionary = provider.lookup(id); @@ -66,10 +63,16 @@ public static Field toMessageFormat(Field field, DictionaryProvider provider, Se throw new IllegalArgumentException("Could not find dictionary with ID " + id); } type = dictionary.getVectorType(); + children = dictionary.getVector().getField().getChildren(); dictionaryIdsUsed.add(id); } + final List updatedChildren = new ArrayList<>(children.size()); + for (Field child : children) { + updatedChildren.add(toMessageFormat(child, provider, dictionaryIdsUsed)); + } + return new Field(field.getName(), new FieldType(field.isNullable(), type, encoding, field.getMetadata()), updatedChildren); } @@ -115,8 +118,10 @@ public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map fieldChildren = null; if (encoding == null) { type = field.getType(); + fieldChildren = updatedChildren; } else { // re-type the field for in-memory format type = encoding.getIndexType(); @@ -127,13 +132,14 @@ public static Field toMemoryFormat(Field field, BufferAllocator allocator, Map> dictionaryValues4 = new HashMap<>(); + 
dictionaryValues4.put("a", Arrays.asList(1, 2, 3)); + dictionaryValues4.put("b", Arrays.asList(4, 5, 6)); + setVector(dictionaryVector4, dictionaryValues4); dictionary1 = new Dictionary(dictionaryVector1, new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null)); @@ -126,6 +143,8 @@ public void init() { new DictionaryEncoding(/*id=*/2L, /*ordered=*/false, /*indexType=*/null)); dictionary3 = new Dictionary(dictionaryVector3, new DictionaryEncoding(/*id=*/1L, /*ordered=*/false, /*indexType=*/null)); + dictionary4 = new Dictionary(dictionaryVector4, + new DictionaryEncoding(/*id=*/3L, /*ordered=*/false, /*indexType=*/null)); } @After @@ -133,6 +152,7 @@ public void terminate() throws Exception { dictionaryVector1.close(); dictionaryVector2.close(); dictionaryVector3.close(); + dictionaryVector4.close(); allocator.close(); } @@ -305,6 +325,82 @@ public void testWriteReadWithDictionaries() throws IOException { } } + @Test + public void testWriteReadWithStructDictionaries() throws IOException { + DictionaryProvider.MapDictionaryProvider provider = + new DictionaryProvider.MapDictionaryProvider(); + provider.put(dictionary4); + + try (final StructVector vector = + newVector(StructVector.class, "D4", MinorType.STRUCT, allocator)) { + final Map> values = new HashMap<>(); + // Index: 0, 2, 1, 2, 1, 0, 0 + values.put("a", Arrays.asList(1, 3, 2, 3, 2, 1, 1)); + values.put("b", Arrays.asList(4, 6, 5, 6, 5, 4, 4)); + setVector(vector, values); + FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(vector, dictionary4); + + List fields = Arrays.asList(encodedVector.getField()); + List vectors = Collections2.asImmutableList(encodedVector); + try ( + VectorSchemaRoot root = + new VectorSchemaRoot(fields, vectors, encodedVector.getValueCount()); + ByteArrayOutputStream out = new ByteArrayOutputStream(); + ArrowFileWriter writer = new ArrowFileWriter(root, provider, newChannel(out));) { + + writer.start(); + writer.writeBatch(); + writer.end(); + + try ( 
+ SeekableReadChannel channel = new SeekableReadChannel( + new ByteArrayReadableSeekableByteChannel(out.toByteArray())); + ArrowFileReader reader = new ArrowFileReader(channel, allocator)) { + final VectorSchemaRoot readRoot = reader.getVectorSchemaRoot(); + final Schema readSchema = readRoot.getSchema(); + assertEquals(root.getSchema(), readSchema); + assertEquals(1, reader.getDictionaryBlocks().size()); + assertEquals(1, reader.getRecordBlocks().size()); + + reader.loadNextBatch(); + assertEquals(1, readRoot.getFieldVectors().size()); + assertEquals(1, reader.getDictionaryVectors().size()); + + // Read the encoded vector and check it + final FieldVector readEncoded = readRoot.getVector(0); + assertEquals(encodedVector.getValueCount(), readEncoded.getValueCount()); + assertTrue(new RangeEqualsVisitor(encodedVector, readEncoded) + .rangeEquals(new Range(0, 0, encodedVector.getValueCount()))); + + // Read the dictionary + final Map readDictionaryMap = reader.getDictionaryVectors(); + final Dictionary readDictionary = + readDictionaryMap.get(readEncoded.getField().getDictionary().getId()); + assertNotNull(readDictionary); + + // Assert the dictionary vector is correct + final FieldVector readDictionaryVector = readDictionary.getVector(); + assertEquals(dictionaryVector4.getValueCount(), readDictionaryVector.getValueCount()); + final BiFunction typeComparatorIgnoreName = + (v1, v2) -> new TypeEqualsVisitor(v1, false, true).equals(v2); + assertTrue("Dictionary vectors are not equal", + new RangeEqualsVisitor(dictionaryVector4, readDictionaryVector, + typeComparatorIgnoreName) + .rangeEquals(new Range(0, 0, dictionaryVector4.getValueCount()))); + + // Assert the decoded vector is correct + try (final ValueVector readVector = + DictionaryEncoder.decode(readEncoded, readDictionary)) { + assertEquals(vector.getValueCount(), readVector.getValueCount()); + assertTrue("Decoded vectors are not equal", + new RangeEqualsVisitor(vector, readVector, typeComparatorIgnoreName) + 
.rangeEquals(new Range(0, 0, vector.getValueCount()))); + } + } + } + } + } + @Test public void testEmptyStreamInFileIPC() throws IOException { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java index 3d389d86515..15d6a5cf993 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/testing/ValueVectorDataPopulator.java @@ -21,6 +21,8 @@ import java.nio.charset.StandardCharsets; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; import org.apache.arrow.vector.BigIntVector; import org.apache.arrow.vector.BitVector; @@ -60,8 +62,10 @@ import org.apache.arrow.vector.complex.FixedSizeListVector; import org.apache.arrow.vector.complex.LargeListVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.holders.IntervalDayHolder; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.FieldType; /** @@ -673,4 +677,32 @@ public static void setVector(FixedSizeListVector vector, List... values dataVector.setValueCount(curPos); vector.setValueCount(values.length); } + + /** + * Populate values for {@link StructVector}. 
+   * @param vector struct vector to fill; one nullable INT child per map key is obtained through addOrGet
+   * @param values child name mapped to that child's values in row order; a null element is a null slot
+   */
+  public static void setVector(StructVector vector, Map<String, List<Integer>> values) {
+    vector.allocateNew(); // unlike allocateNewSafe(), fails loudly if the buffers cannot be allocated
+    int valueCount = 0; // ends up as the length of the longest child list
+    for (final Entry<String, List<Integer>> entry : values.entrySet()) {
+      final IntVector child = vector.addOrGet(entry.getKey(),
+          FieldType.nullable(MinorType.INT.getType()), IntVector.class);
+      child.allocateNew();
+      final List<Integer> v = entry.getValue();
+      for (int i = 0; i < v.size(); i++) {
+        if (v.get(i) != null) {
+          // setSafe grows the child when the list outruns the default allocation; set would not.
+          child.setSafe(i, v.get(i));
+          // The struct slot is marked defined as soon as any child holds a value in that row.
+          vector.setIndexDefined(i);
+        } else {
+          child.setNull(i);
+        }
+      }
+      valueCount = Math.max(valueCount, v.size());
+    }
+    vector.setValueCount(valueCount);
+  }
 }