diff --git a/java/vector/src/main/codegen/templates/MapWriters.java b/java/vector/src/main/codegen/templates/MapWriters.java index 14cc08d7db0..b89f91457e8 100644 --- a/java/vector/src/main/codegen/templates/MapWriters.java +++ b/java/vector/src/main/codegen/templates/MapWriters.java @@ -47,6 +47,7 @@ public class ${mode}MapWriter extends AbstractFieldWriter { protected final ${containerClass} container; + private int initialCapacity; private final Map fields = Maps.newHashMap(); public ${mode}MapWriter(${containerClass} container) { <#if mode == "Single"> @@ -55,6 +56,7 @@ public class ${mode}MapWriter extends AbstractFieldWriter { } this.container = container; + this.initialCapacity = 0; for (Field child : container.getField().getChildren()) { MinorType minorType = Types.getMinorTypeForArrowType(child.getType()); switch (minorType) { @@ -101,6 +103,11 @@ public int getValueCapacity() { return container.getValueCapacity(); } + public void setInitialCapacity(int initialCapacity) { + this.initialCapacity = initialCapacity; + container.setInitialCapacity(initialCapacity); + } + @Override public boolean isEmptyMap() { return 0 == container.size(); @@ -248,6 +255,9 @@ public void end() { writer = new PromotableWriter(v, container, getNullableMapWriterFactory()); vector = v; if (currentVector == null || currentVector != vector) { + if(this.initialCapacity > 0) { + vector.setInitialCapacity(this.initialCapacity); + } vector.allocateNewSafe(); } writer.setPosition(idx()); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index f81cd557a9d..856d60724b0 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -28,6 +28,10 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.SchemaChangeCallBack; +import org.apache.arrow.vector.NullableFloat8Vector; +import org.apache.arrow.vector.NullableFloat4Vector; +import org.apache.arrow.vector.NullableBigIntVector; +import org.apache.arrow.vector.NullableIntVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.NullableMapVector; @@ -38,6 +42,11 @@ import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.complex.impl.UnionReader; import org.apache.arrow.vector.complex.impl.UnionWriter; +import org.apache.arrow.vector.complex.impl.SingleMapWriter; +import org.apache.arrow.vector.complex.reader.IntReader; +import org.apache.arrow.vector.complex.reader.Float8Reader; +import org.apache.arrow.vector.complex.reader.Float4Reader; +import org.apache.arrow.vector.complex.reader.BigIntReader; import org.apache.arrow.vector.complex.reader.BaseReader.MapReader; import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter; @@ -834,4 +843,83 @@ public void complexCopierWithList() { innerMap = (JsonStringHashMap) object.get(3); assertEquals(2, innerMap.get("a")); } + + @Test + public void testSingleMapWriter1() { + /* initialize a SingleMapWriter with empty MapVector and then lazily + * create all vectors with expected initialCapacity. + */ + MapVector parent = MapVector.empty("parent", allocator); + SingleMapWriter singleMapWriter = new SingleMapWriter(parent); + + int initialCapacity = 1024; + singleMapWriter.setInitialCapacity(initialCapacity); + + IntWriter intWriter = singleMapWriter.integer("intField"); + BigIntWriter bigIntWriter = singleMapWriter.bigInt("bigIntField"); + Float4Writer float4Writer = singleMapWriter.float4("float4Field"); + Float8Writer float8Writer = singleMapWriter.float8("float8Field"); + ListWriter listWriter = singleMapWriter.list("listField"); + + int intValue = 100; + long bigIntValue = 10000; + float float4Value = 100.5f; + double float8Value = 100.375; + + for (int i = 0; i < initialCapacity; i++) { + singleMapWriter.start(); + + intWriter.writeInt(intValue + i); + bigIntWriter.writeBigInt(bigIntValue + (long)i); + float4Writer.writeFloat4(float4Value + (float)i); + float8Writer.writeFloat8(float8Value + (double)i); + + listWriter.setPosition(i); + listWriter.startList(); + listWriter.integer().writeInt(intValue + i); + listWriter.integer().writeInt(intValue + i + 1); + listWriter.integer().writeInt(intValue + i + 2); + listWriter.integer().writeInt(intValue + i + 3); + listWriter.endList(); + + singleMapWriter.end(); + } + + NullableIntVector intVector = (NullableIntVector)parent.getChild("intField"); + NullableBigIntVector bigIntVector = (NullableBigIntVector)parent.getChild("bigIntField"); + NullableFloat4Vector float4Vector = (NullableFloat4Vector)parent.getChild("float4Field"); + NullableFloat8Vector float8Vector = (NullableFloat8Vector)parent.getChild("float8Field"); + + assertEquals(initialCapacity, singleMapWriter.getValueCapacity()); + assertEquals(initialCapacity, intVector.getValueCapacity()); + assertEquals(initialCapacity, bigIntVector.getValueCapacity()); + assertEquals(initialCapacity, float4Vector.getValueCapacity()); + assertEquals(initialCapacity, float8Vector.getValueCapacity()); + + MapReader singleMapReader = new SingleMapReaderImpl(parent); + + IntReader intReader = singleMapReader.reader("intField"); + BigIntReader bigIntReader = singleMapReader.reader("bigIntField"); + Float4Reader float4Reader = singleMapReader.reader("float4Field"); + Float8Reader float8Reader = singleMapReader.reader("float8Field"); + UnionListReader listReader = (UnionListReader)singleMapReader.reader("listField"); + + for (int i = 0; i < initialCapacity; i++) { + intReader.setPosition(i); + bigIntReader.setPosition(i); + float4Reader.setPosition(i); + float8Reader.setPosition(i); + listReader.setPosition(i); + + assertEquals(intValue + i, intReader.readInteger().intValue()); + assertEquals(bigIntValue + (long)i, bigIntReader.readLong().longValue()); + assertEquals(float4Value + (float)i, float4Reader.readFloat().floatValue(), 0); + assertEquals(float8Value + (double)i, float8Reader.readDouble().doubleValue(), 0); + + for (int j = 0; j < 4; j++) { + listReader.next(); + assertEquals(intValue + i + j, listReader.reader().readInteger().intValue()); + } + } + } }