diff --git a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java index 6c2368117f7..01d5f493345 100644 --- a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -233,6 +233,15 @@ public MapWriter map(String name, boolean keysSorted) { } + <#if lowerName?contains("decimal")> + + @Override + public ${capName}Writer ${lowerName}(<#list minor.typeParams as typeParam>${typeParam.type} ${typeParam.name}<#if typeParam_has_next>, ) { + fail("${capName}(" + <#list minor.typeParams as typeParam>"${typeParam.name}: " + ${typeParam.name} + ", " + ")"); + return null; + } + + @Override public ${capName}Writer ${lowerName}(String name) { fail("${capName}"); diff --git a/java/vector/src/main/codegen/templates/BaseWriter.java b/java/vector/src/main/codegen/templates/BaseWriter.java index 35df256b324..718ee227f49 100644 --- a/java/vector/src/main/codegen/templates/BaseWriter.java +++ b/java/vector/src/main/codegen/templates/BaseWriter.java @@ -56,6 +56,9 @@ public interface StructWriter extends BaseWriter { <#if minor.typeParams?? > ${capName}Writer ${lowerName}(String name<#list minor.typeParams as typeParam>, ${typeParam.type} ${typeParam.name}); + <#if lowerName?contains("decimal")> + ${capName}Writer ${lowerName}(<#list minor.typeParams as typeParam>${typeParam.type} ${typeParam.name}<#if typeParam_has_next>, ); + ${capName}Writer ${lowerName}(String name); @@ -83,6 +86,9 @@ public interface ListWriter extends BaseWriter { <#assign upperName = minor.class?upper_case /> <#assign capName = minor.class?cap_first /> ${capName}Writer ${lowerName}(); + <#if lowerName?contains("decimal")> + ${capName}Writer ${lowerName}(<#list minor.typeParams as typeParam>${typeParam.type} ${typeParam.name}<#if typeParam_has_next>, ); + } diff --git a/java/vector/src/main/codegen/templates/UnionListWriter.java b/java/vector/src/main/codegen/templates/UnionListWriter.java index eeb964c055f..af8e19d15b3 100644 --- a/java/vector/src/main/codegen/templates/UnionListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionListWriter.java @@ -127,6 +127,12 @@ public void setPosition(int index) { public ${minor.class}Writer ${lowerName}(String name<#list minor.typeParams as typeParam>, ${typeParam.type} ${typeParam.name}) { return writer.${lowerName}(name<#list minor.typeParams as typeParam>, ${typeParam.name}); } + <#if lowerName?contains("decimal")> + @Override + public ${capName}Writer ${lowerName}(<#list minor.typeParams as typeParam>${typeParam.type} ${typeParam.name}<#if typeParam_has_next>, ) { + return writer.${lowerName}(<#list minor.typeParams as typeParam>${typeParam.name}<#if typeParam_has_next>, ); + } + @Override diff --git a/java/vector/src/main/codegen/templates/UnionMapWriter.java b/java/vector/src/main/codegen/templates/UnionMapWriter.java index 606f880377b..938f009047b 100644 --- a/java/vector/src/main/codegen/templates/UnionMapWriter.java +++ b/java/vector/src/main/codegen/templates/UnionMapWriter.java @@ -183,6 +183,29 @@ public Decimal256Writer decimal256() { } } + @Override + public DecimalWriter decimal(int scale, int precision) { + switch (mode) { + case KEY: + return entryWriter.decimal(MapVector.KEY_NAME, scale, precision); + case VALUE: + return entryWriter.decimal(MapVector.VALUE_NAME, scale, precision); + default: + return this; + } + } + + @Override + public Decimal256Writer decimal256(int scale, int precision) { + switch (mode) { + case KEY: + return entryWriter.decimal256(MapVector.KEY_NAME, scale, precision); + case VALUE: + return entryWriter.decimal256(MapVector.VALUE_NAME, scale, precision); + default: + return this; + } + } @Override public StructWriter struct() { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 7fd0def9673..cc698f04e19 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -289,6 +289,9 @@ private boolean requiresArrowType(MinorType type) { protected FieldWriter getWriter(MinorType type, ArrowType arrowType) { if (state == State.UNION) { if (requiresArrowType(type)) { + if (arrowType == null) { + arrowType = type.getType(); + } ((UnionWriter) writer).getWriter(type, arrowType); } else { ((UnionWriter) writer).getWriter(type); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java index ed099890e1b..31a11ce2386 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java @@ -693,7 +693,7 @@ public FieldWriter getNewFieldWriter(ValueVector vector) { return new DenseUnionWriter((DenseUnionVector) vector); } }, - MAP(null) { + MAP(new Map(false)) { @Override public FieldVector getNewVector( Field field, BufferAllocator allocator, CallBack schemaChangeCallback) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java index 213ffced273..ec56b08d407 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestMapVector.java @@ -22,14 +22,17 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.math.BigDecimal; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionMapWriter; import org.apache.arrow.vector.complex.reader.FieldReader; @@ -1204,4 +1207,66 @@ public void testMakeTransferPairPreserveNullability() { assertEquals(intField, vec.getField().getChildren().get(0)); assertEquals(intField, res.getField().getChildren().get(0)); } + + @Test + public void testSimpleMapVectorWithDecimals() { + try (final MapVector vector = MapVector.empty("v", allocator, false)) { + vector.allocateNew(); + UnionMapWriter mapWriter = vector.getWriter(); + + mapWriter.setPosition(0); + mapWriter.startMap(); + mapWriter.startEntry(); + mapWriter.key().decimal(1, 2).writeDecimal(BigDecimal.valueOf(10, 1)); + mapWriter.value().decimal(1, 2).writeDecimal(BigDecimal.valueOf(20, 1)); + mapWriter.endEntry(); + mapWriter.endMap(); + + vector.setValueCount(1); + assertEquals(vector.getChildrenFromFields().size(), 1); + StructVector structVector = (StructVector) vector.getChildrenFromFields().get(0); + assertEquals(structVector.getChildrenFromFields().size(), 2); + DecimalVector keyVector = (DecimalVector) structVector.getChildrenFromFields().get(0); + DecimalVector valueVector = (DecimalVector) structVector.getChildrenFromFields().get(1); + assertEquals(keyVector.getObject(0), BigDecimal.valueOf(10, 1)); + assertEquals(valueVector.getObject(0), BigDecimal.valueOf(20, 1)); + } + } + + @Test + public void testWritingDecimals() { + try (ListVector vector = ListVector.empty("v", allocator)) { + UnionListWriter listWriter = vector.getWriter(); + listWriter.allocate(); + + listWriter.setPosition(0); + listWriter.startList(); + listWriter.map().startMap(); + listWriter.map().startEntry(); + listWriter.map().key().integer().writeInt(10); + listWriter.map().value().integer().writeInt(20); + listWriter.map().endEntry(); + listWriter.map().startEntry(); + listWriter.map().key().decimal(1, 2).writeDecimal(BigDecimal.valueOf(2.0)); + listWriter.map().value().decimal(1, 2).writeDecimal(BigDecimal.valueOf(3.0)); + listWriter.map().endEntry(); + listWriter.map().endMap(); + listWriter.endList(); + + listWriter.setValueCount(1); + vector.setValueCount(1); + + assertEquals(vector.getChildrenFromFields().size(), 1); + MapVector mapVector = (MapVector) vector.getChildrenFromFields().get(0); + assertEquals(mapVector.getChildrenFromFields().size(), 1); + StructVector structVector = (StructVector) mapVector.getChildrenFromFields().get(0); + assertEquals(structVector.getChildrenFromFields().size(), 2); + assertEquals(structVector.getChildrenFromFields().get(0).getObject(0), 10); + assertEquals( + structVector.getChildrenFromFields().get(0).getObject(1), BigDecimal.valueOf(2.0)); + assertEquals(structVector.getChildrenFromFields().get(1).getObject(0), 20); + assertEquals( + structVector.getChildrenFromFields().get(1).getObject(1), BigDecimal.valueOf(3.0)); + } + } }