diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java index 0f249c71ca5..3bb368ba323 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorLoader.java @@ -73,7 +73,9 @@ private void loadBuffers( checkArgument(nodes.hasNext(), "no more field nodes for for field " + field + " and vector " + vector); ArrowFieldNode fieldNode = nodes.next(); - List bufferLayouts = TypeLayout.getTypeLayout(field.getType()).getBufferLayouts(); + List bufferLayouts = field.getDictionary() == null ? + TypeLayout.getTypeLayout(field.getType()).getBufferLayouts() : + TypeLayout.getTypeLayout(field.getDictionary().getIndexType()).getBufferLayouts(); List ownBuffers = new ArrayList<>(bufferLayouts.size()); for (int j = 0; j < bufferLayouts.size(); j++) { ownBuffers.add(buffers.next()); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java index 2263ea9904e..b3c6030bbe7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/VectorUnloader.java @@ -23,6 +23,7 @@ import org.apache.arrow.vector.BufferLayout.BufferType; import org.apache.arrow.vector.ipc.message.ArrowFieldNode; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; import io.netty.buffer.ArrowBuf; @@ -72,7 +73,11 @@ public ArrowRecordBatch getRecordBatch() { private void appendNodes(FieldVector vector, List nodes, List buffers) { nodes.add(new ArrowFieldNode(vector.getValueCount(), includeNullCount ? vector.getNullCount() : -1)); List fieldBuffers = vector.getFieldBuffers(); - List expectedBuffers = TypeLayout.getTypeLayout(vector.getField().getType()).getBufferTypes(); + Field field = vector.getField(); + List expectedBuffers = field.getDictionary() == null ? + TypeLayout.getTypeLayout(field.getType()).getBufferTypes() : + TypeLayout.getTypeLayout(field.getDictionary().getIndexType()).getBufferTypes(); + if (fieldBuffers.size() != expectedBuffers.size()) { throw new IllegalArgumentException(String.format( "wrong number of buffers for field %s in vector %s. found: %s", diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java index 945f5df2d98..1ecd623ee2d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/pojo/FieldType.java @@ -94,7 +94,8 @@ public Map getMetadata() { } public FieldVector createNewSingleVector(String name, BufferAllocator allocator, CallBack schemaCallBack) { - MinorType minorType = Types.getMinorTypeForArrowType(type); + MinorType minorType = dictionary == null ? + Types.getMinorTypeForArrowType(type) : Types.getMinorTypeForArrowType(dictionary.getIndexType()); return minorType.getNewVector(name, this, allocator, schemaCallBack); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodedVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodedVector.java new file mode 100644 index 00000000000..10df3e2ac1d --- /dev/null +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDictionaryEncodedVector.java @@ -0,0 +1,132 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.vector; + +import static java.nio.channels.Channels.newChannel; +import static java.util.Arrays.asList; +import static org.apache.arrow.vector.TestUtils.newVarCharVector; + +import java.io.*; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Collections; +import java.util.Vector; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.ArrowStreamReader; +import org.apache.arrow.vector.ipc.ArrowStreamWriter; +import org.apache.arrow.vector.types.pojo.*; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +public class TestDictionaryEncodedVector { + + private BufferAllocator allocator; + + @Before + public void init() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @After + public void terminate() throws Exception { + allocator.close(); + } + + @Test + public void testDicionaryEncodedVector() throws IOException { + + // Create the dictionary + VarCharVector dictVector = (VarCharVector) + FieldType.nullable(new ArrowType.Utf8()) + .createNewSingleVector("enum1_dict", allocator, null); + dictVector.allocateNew(); + dictVector.set(0, "foo".getBytes(StandardCharsets.UTF_8)); + dictVector.set(1, "bar".getBytes(StandardCharsets.UTF_8)); + dictVector.setValueCount(2); + + DictionaryEncoding enum1Encoding = new DictionaryEncoding( + 0, + true, + new ArrowType.Int(8, true) + ); + + DictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider( + new Dictionary(dictVector, enum1Encoding) + ); + + // Create the dictionary encoded vector + Schema schema = new Schema( + asList( + new Field( + "enum1", + true, + // this is the type of the decoded value + new ArrowType.Utf8(), + enum1Encoding, + Collections.emptyList()) + )); + + // This doesn't not work without the patch + // root.getVector("enum1") returns a VarCharVector instead of a TinyIntVector + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + TinyIntVector vector = (TinyIntVector) root.getVector("enum1"); + vector.allocateNew(); + + // This is the encoded vector + // TinyIntVector vector = new TinyIntVector("enum1", allocator); + // vector.allocateNew(); + vector.set(0, 1); + vector.set(1, 0); + vector.set(2, 1); + vector.setValueCount(3); + root.setRowCount(3); + + // Test round trip + ByteArrayOutputStream out = new ByteArrayOutputStream(); + // FileOutputStream out = new FileOutputStream("/tmp/dictionary_encoded.arrow"); + ArrowStreamWriter writer = new ArrowStreamWriter(root, dictionaryProvider, newChannel(out)); + + writer.start(); + writer.writeBatch(); + writer.close(); + + dictVector.close(); + root.close(); + + byte[] data = out.toByteArray(); + out.close(); + + // Read + ByteArrayInputStream in = new ByteArrayInputStream(data); + + // FileInputStream in = new FileInputStream("/tmp/dictionary_encoded.arrow"); + ArrowStreamReader reader = new ArrowStreamReader(in, allocator); + root = reader.getVectorSchemaRoot(); + reader.loadNextBatch(); + System.out.println(root.contentToTSVString()); + + reader.close(); + root.close(); + } +}