Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions java/vector/src/main/codegen/templates/MapWriters.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
public class ${mode}MapWriter extends AbstractFieldWriter {

protected final ${containerClass} container;
private int initialCapacity;
private final Map<String, FieldWriter> fields = Maps.newHashMap();
public ${mode}MapWriter(${containerClass} container) {
<#if mode == "Single">
Expand All @@ -55,6 +56,7 @@ public class ${mode}MapWriter extends AbstractFieldWriter {
}
</#if>
this.container = container;
this.initialCapacity = 0;
for (Field child : container.getField().getChildren()) {
MinorType minorType = Types.getMinorTypeForArrowType(child.getType());
switch (minorType) {
Expand Down Expand Up @@ -101,6 +103,11 @@ public int getValueCapacity() {
return container.getValueCapacity();
}

public void setInitialCapacity(int initialCapacity) {
this.initialCapacity = initialCapacity;
container.setInitialCapacity(initialCapacity);
}

@Override
public boolean isEmptyMap() {
return 0 == container.size();
Expand Down Expand Up @@ -248,6 +255,9 @@ public void end() {
writer = new PromotableWriter(v, container, getNullableMapWriterFactory());
vector = v;
if (currentVector == null || currentVector != vector) {
if(this.initialCapacity > 0) {
vector.setInitialCapacity(this.initialCapacity);
}
vector.allocateNewSafe();
}
writer.setPosition(idx());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.SchemaChangeCallBack;
import org.apache.arrow.vector.NullableFloat8Vector;
import org.apache.arrow.vector.NullableFloat4Vector;
import org.apache.arrow.vector.NullableBigIntVector;
import org.apache.arrow.vector.NullableIntVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.NullableMapVector;
Expand All @@ -38,6 +42,11 @@
import org.apache.arrow.vector.complex.impl.UnionListWriter;
import org.apache.arrow.vector.complex.impl.UnionReader;
import org.apache.arrow.vector.complex.impl.UnionWriter;
import org.apache.arrow.vector.complex.impl.SingleMapWriter;
import org.apache.arrow.vector.complex.reader.IntReader;
import org.apache.arrow.vector.complex.reader.Float8Reader;
import org.apache.arrow.vector.complex.reader.Float4Reader;
import org.apache.arrow.vector.complex.reader.BigIntReader;
import org.apache.arrow.vector.complex.reader.BaseReader.MapReader;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter;
Expand Down Expand Up @@ -834,4 +843,83 @@ public void complexCopierWithList() {
innerMap = (JsonStringHashMap<?, ?>) object.get(3);
assertEquals(2, innerMap.get("a"));
}

@Test
public void testSingleMapWriter1() {
/* initialize a SingleMapWriter with empty MapVector and then lazily
* create all vectors with expected initialCapacity.
*/
MapVector parent = MapVector.empty("parent", allocator);
SingleMapWriter singleMapWriter = new SingleMapWriter(parent);

int initialCapacity = 1024;
singleMapWriter.setInitialCapacity(initialCapacity);

IntWriter intWriter = singleMapWriter.integer("intField");
BigIntWriter bigIntWriter = singleMapWriter.bigInt("bigIntField");
Float4Writer float4Writer = singleMapWriter.float4("float4Field");
Float8Writer float8Writer = singleMapWriter.float8("float8Field");
ListWriter listWriter = singleMapWriter.list("listField");

int intValue = 100;
long bigIntValue = 10000;
float float4Value = 100.5f;
double float8Value = 100.375;

for (int i = 0; i < initialCapacity; i++) {
singleMapWriter.start();

intWriter.writeInt(intValue + i);
bigIntWriter.writeBigInt(bigIntValue + (long)i);
float4Writer.writeFloat4(float4Value + (float)i);
float8Writer.writeFloat8(float8Value + (double)i);

listWriter.setPosition(i);
listWriter.startList();
listWriter.integer().writeInt(intValue + i);
listWriter.integer().writeInt(intValue + i + 1);
listWriter.integer().writeInt(intValue + i + 2);
listWriter.integer().writeInt(intValue + i + 3);
listWriter.endList();

singleMapWriter.end();
}

NullableIntVector intVector = (NullableIntVector)parent.getChild("intField");
NullableBigIntVector bigIntVector = (NullableBigIntVector)parent.getChild("bigIntField");
NullableFloat4Vector float4Vector = (NullableFloat4Vector)parent.getChild("float4Field");
NullableFloat8Vector float8Vector = (NullableFloat8Vector)parent.getChild("float8Field");

assertEquals(initialCapacity, singleMapWriter.getValueCapacity());
assertEquals(initialCapacity, intVector.getValueCapacity());
assertEquals(initialCapacity, bigIntVector.getValueCapacity());
assertEquals(initialCapacity, float4Vector.getValueCapacity());
assertEquals(initialCapacity, float8Vector.getValueCapacity());

MapReader singleMapReader = new SingleMapReaderImpl(parent);

IntReader intReader = singleMapReader.reader("intField");
BigIntReader bigIntReader = singleMapReader.reader("bigIntField");
Float4Reader float4Reader = singleMapReader.reader("float4Field");
Float8Reader float8Reader = singleMapReader.reader("float8Field");
UnionListReader listReader = (UnionListReader)singleMapReader.reader("listField");

for (int i = 0; i < initialCapacity; i++) {
intReader.setPosition(i);
bigIntReader.setPosition(i);
float4Reader.setPosition(i);
float8Reader.setPosition(i);
listReader.setPosition(i);

assertEquals(intValue + i, intReader.readInteger().intValue());
assertEquals(bigIntValue + (long)i, bigIntReader.readLong().longValue());
assertEquals(float4Value + (float)i, float4Reader.readFloat().floatValue(), 0);
assertEquals(float8Value + (double)i, float8Reader.readDouble().doubleValue(), 0);

for (int j = 0; j < 4; j++) {
listReader.next();
assertEquals(intValue + i + j, listReader.reader().readInteger().intValue());
}
}
}
}