diff --git a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java index da8e4f54ec2..853f67fd0dd 100644 --- a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -53,6 +53,7 @@ public void endList() { <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> + <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> @Override public void write(${name}Holder holder) { fail("${name}"); @@ -62,6 +63,12 @@ public void write(${name}Holder holder) { fail("${name}"); } + <#if minor.class == "Decimal"> + public void write${minor.class}(${friendlyType} value) { + fail("${name}"); + } + + public void writeNull() { diff --git a/java/vector/src/main/codegen/templates/ComplexWriters.java b/java/vector/src/main/codegen/templates/ComplexWriters.java index 8ebecf3e1de..fe099bede35 100644 --- a/java/vector/src/main/codegen/templates/ComplexWriters.java +++ b/java/vector/src/main/codegen/templates/ComplexWriters.java @@ -24,6 +24,7 @@ <#assign eName = name /> <#assign javaType = (minor.javaType!type.javaType) /> <#assign fields = minor.fields!type.fields /> +<#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/impl/${eName}WriterImpl.java" /> <#include "/@includes/license.ftl" /> @@ -115,7 +116,13 @@ public void write(Nullable${minor.class}Holder h) { mutator.setSafe(idx()<#if mode == "Nullable">, 1<#list fields as field><#if field.include!true >, ${field.name}); vector.getMutator().setValueCount(idx()+1); } + <#if minor.class == "Decimal"> + public void write${minor.class}(${friendlyType} value) { + mutator.setSafe(idx(), value); + vector.getMutator().setValueCount(idx()+1); + } + <#if mode == "Nullable"> public void writeNull() { @@ -140,6 +147,10 @@ public interface ${eName}Writer extends BaseWriter { public void write(${minor.class}Holder h); public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ); +<#if minor.class == "Decimal"> + + public void write${minor.class}(${friendlyType} value); + } diff --git a/java/vector/src/main/codegen/templates/FixedValueVectors.java b/java/vector/src/main/codegen/templates/FixedValueVectors.java index 9747d421c41..ffd8cad02e2 100644 --- a/java/vector/src/main/codegen/templates/FixedValueVectors.java +++ b/java/vector/src/main/codegen/templates/FixedValueVectors.java @@ -403,7 +403,7 @@ public void get(int index, Nullable${minor.class}Holder holder) { @Override public ${friendlyType} getObject(int index) { - return org.apache.arrow.vector.util.DecimalUtility.getBigDecimalFromArrowBuf(data, ${type.width} * index, scale); + return DecimalUtility.getBigDecimalFromArrowBuf(data, index, scale); } <#else> @@ -596,10 +596,10 @@ void set(int index, Nullable${minor.class}Holder holder){ set(index, holder.start, holder.buffer); } - public void setSafe(int index, Nullable${minor.class}Holder holder){ + public void setSafe(int index, Nullable${minor.class}Holder holder){ setSafe(index, holder.start, holder.buffer); } - public void setSafe(int index, ${minor.class}Holder holder){ + public void setSafe(int index, ${minor.class}Holder holder){ setSafe(index, holder.start, holder.buffer); } @@ -614,6 +614,18 @@ public void set(int index, int start, ArrowBuf buffer){ data.setBytes(index * ${type.width}, buffer, start, ${type.width}); } + public void set(int index, ${friendlyType} value){ + DecimalUtility.checkPrecisionAndScale(value, precision, scale); + DecimalUtility.writeBigDecimalToArrowBuf(value, data, index); + } + + public void setSafe(int index, ${friendlyType} value){ + while(index >= getValueCapacity()) { + reAlloc(); + } + set(index, value); + } + <#else> protected void set(int index, ${minor.class}Holder holder){ set(index, holder.start, holder.buffer); diff --git a/java/vector/src/main/codegen/templates/NullableValueVectors.java b/java/vector/src/main/codegen/templates/NullableValueVectors.java index a4313332563..319c61c8624 100644 --- a/java/vector/src/main/codegen/templates/NullableValueVectors.java +++ b/java/vector/src/main/codegen/templates/NullableValueVectors.java @@ -723,6 +723,19 @@ public void setSafe(int index, ${minor.javaType!type.javaType} value) { setCount++; } + + <#if minor.class == "Decimal"> + public void set(int index, ${friendlyType} value) { + bits.getMutator().setToOne(index); + values.getMutator().set(index, value); + } + + public void setSafe(int index, ${friendlyType} value) { + bits.getMutator().setSafeToOne(index); + values.getMutator().setSafe(index, value); + setCount++; + } + @Override public void setValueCount(int valueCount) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java index 484a82fdaab..71685d13589 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java @@ -39,6 +39,7 @@ import org.apache.arrow.vector.BufferBacked; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; @@ -72,6 +73,7 @@ import org.apache.arrow.vector.schema.ArrowVectorType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.DictionaryUtility; import org.apache.commons.codec.DecoderException; import org.apache.commons.codec.binary.Hex; @@ -235,14 +237,16 @@ private void readVector(Field field, FieldVector vector) throws JsonParseExcepti nextFieldIs(vectorType.getName()); readToken(START_ARRAY); ValueVector valueVector = (ValueVector) innerVector; - valueVector.allocateNew(); - Mutator mutator = valueVector.getMutator(); int innerVectorCount = vectorType.equals(OFFSET) ? count + 1 : count; + valueVector.setInitialCapacity(innerVectorCount); + valueVector.allocateNew(); + for (int i = 0; i < innerVectorCount; i++) { parser.nextToken(); setValueFromParser(valueVector, i); } + Mutator mutator = valueVector.getMutator(); mutator.setValueCount(innerVectorCount); readToken(END_ARRAY); } @@ -312,6 +316,12 @@ private void setValueFromParser(ValueVector valueVector, int i) throws IOExcepti case FLOAT8: ((Float8Vector) valueVector).getMutator().set(i, parser.readValueAs(Double.class)); break; + case DECIMAL: { + DecimalVector decimalVector = ((DecimalVector) valueVector); + byte[] value = decodeHexSafe(parser.readValueAs(String.class)); + DecimalUtility.writeByteArrayToArrowBuf(value, decimalVector.getBuffer(), i); + } + break; case VARBINARY: ((VarBinaryVector) valueVector).getMutator().setSafe(i, decodeHexSafe(parser.readValueAs(String.class))); break; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java index a2229cef231..04e44379e5d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java @@ -26,10 +26,12 @@ import java.util.Set; import com.google.common.collect.ImmutableList; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.BufferBacked; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.TimeMicroVector; import org.apache.arrow.vector.TimeMilliVector; @@ -54,6 +56,7 @@ import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; import com.fasterxml.jackson.core.util.DefaultPrettyPrinter.NopIndenter; import com.fasterxml.jackson.databind.MappingJsonFactory; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.DictionaryUtility; import org.apache.commons.codec.binary.Hex; @@ -233,9 +236,16 @@ private void writeValueToGenerator(ValueVector valueVector, int i) throws IOExce case BIT: generator.writeNumber(((BitVector) valueVector).getAccessor().get(i)); break; - case VARBINARY: - String hexString = Hex.encodeHexString(((VarBinaryVector) valueVector).getAccessor().get(i)); - generator.writeObject(hexString); + case VARBINARY: { + String hexString = Hex.encodeHexString(((VarBinaryVector) valueVector).getAccessor().get(i)); + generator.writeString(hexString); + } + break; + case DECIMAL: { + ArrowBuf bytebuf = valueVector.getDataBuffer(); + String hexString = Hex.encodeHexString(DecimalUtility.getByteArrayFromArrowBuf(bytebuf, i)); + generator.writeString(hexString); + } break; default: // TODO: each type diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index 4b11b368dff..033ae6c0991 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -19,13 +19,12 @@ package org.apache.arrow.vector.util; import io.netty.buffer.ArrowBuf; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.UnpooledByteBufAllocator; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.types.pojo.ArrowType; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; -import java.util.Arrays; public class DecimalUtility { @@ -69,7 +68,7 @@ public class DecimalUtility { public static final int DECIMAL_BYTE_LENGTH = 16; - /* + /** * Simple function that returns the static precomputed * power of ten, instead of using Math.pow */ @@ -78,7 +77,7 @@ public static long getPowerOfTen(int power) { return scale_long_constants[(power)]; } - /* + /** * Math.pow returns a double and while multiplying with large digits * in the decimal data type we encounter noise. So instead of multiplying * with Math.pow we use the static constants to perform the multiplication @@ -103,7 +102,8 @@ public static long adjustScaleDivide(long input, int factor) { } } - /* Returns a string representation of the given integer + /** + * Returns a string representation of the given integer * If the length of the given integer is less than the * passed length, this function will prepend zeroes to the string */ @@ -136,33 +136,86 @@ public static StringBuilder toStringWithZeroes(long number, int desiredLength) { return str; } - public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int startIndex, int scale) { + /** + * Read an ArrowType.Decimal at the given value index in the ArrowBuf and convert to a BigDecimal + * with the given scale. + */ + public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, int scale) { byte[] value = new byte[DECIMAL_BYTE_LENGTH]; + final int startIndex = index * DECIMAL_BYTE_LENGTH; bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); BigInteger unscaledValue = new BigInteger(value); return new BigDecimal(unscaledValue, scale); } - public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int start, int scale) { + /** + * Read an ArrowType.Decimal from the ByteBuffer and convert to a BigDecimal with the given + * scale. + */ + public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int scale) { byte[] value = new byte[DECIMAL_BYTE_LENGTH]; bytebuf.get(value); BigInteger unscaledValue = new BigInteger(value); return new BigDecimal(unscaledValue, scale); } + /** + * Read an ArrowType.Decimal from the ArrowBuf at the given value index and return it as a byte + * array. + */ + public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index) { + final byte[] value = new byte[DECIMAL_BYTE_LENGTH]; + final int startIndex = index * DECIMAL_BYTE_LENGTH; + bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); + return value; + } + + /** + * Check that the BigDecimal scale equals the vectorScale and that the BigDecimal precision is + * less than or equal to the vectorPrecision. If not, then an UnsupportedOperationException is + * thrown, otherwise returns true. + */ + public static boolean checkPrecisionAndScale(BigDecimal value, int vectorPrecision, int vectorScale) { + if (value.scale() != vectorScale) { + throw new UnsupportedOperationException("BigDecimal scale must equal that in the Arrow vector: " + + value.scale() + " != " + vectorScale); + } + if (value.precision() > vectorPrecision) { + throw new UnsupportedOperationException("BigDecimal precision can not be greater than that in the Arrow vector: " + + value.precision() + " > " + vectorPrecision); + } + return true; + } + + /** + * Write the given BigDecimal to the ArrowBuf at the given value index. Will throw an + * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte + * width. + */ public static void writeBigDecimalToArrowBuf(BigDecimal value, ArrowBuf bytebuf, int index) { final byte[] bytes = value.unscaledValue().toByteArray(); + final int padValue = value.signum() == -1 ? 0xFF : 0; + writeByteArrayToArrowBuf(bytes, bytebuf, index, padValue); + } + + /** + * Write the given byte array to the ArrowBuf at the given value index. Will throw an + * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte + * width. + */ + public static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index) { + writeByteArrayToArrowBuf(bytes, bytebuf, index, 0); + } + + private static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index, int padValue) { final int startIndex = index * DECIMAL_BYTE_LENGTH; if (bytes.length > DECIMAL_BYTE_LENGTH) { throw new UnsupportedOperationException("Decimal size greater than 16 bytes"); } final int padLength = DECIMAL_BYTE_LENGTH - bytes.length; - final int padValue = value.signum() == -1 ? 0xFF : 0; for (int i = 0; i < padLength; i++) { bytebuf.setByte(startIndex + i, padValue); } bytebuf.setBytes(startIndex + padLength, bytes, 0, bytes.length); } } - - diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index 774fbe084f1..56d22932764 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -19,6 +19,7 @@ package org.apache.arrow.vector; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import java.math.BigDecimal; import java.math.BigInteger; @@ -27,6 +28,8 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.util.DecimalUtility; +import org.junit.After; +import org.junit.Before; import org.junit.Test; public class TestDecimalVector { @@ -43,27 +46,69 @@ public class TestDecimalVector { private int scale = 3; + private BufferAllocator allocator; + + @Before + public void init() { + allocator = new DirtyRootAllocator(Long.MAX_VALUE, (byte) 100); + } + + @After + public void terminate() throws Exception { + allocator.close(); + } + @Test - public void test() { - BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - NullableDecimalVector decimalVector = TestUtils.newVector(NullableDecimalVector.class, "decimal", new ArrowType.Decimal(10, scale), allocator); - try (NullableDecimalVector oldConstructor = new NullableDecimalVector("decimal", allocator, 10, scale);) { - assertEquals(decimalVector.getField().getType(), oldConstructor.getField().getType()); - } - decimalVector.allocateNew(); - BigDecimal[] values = new BigDecimal[intValues.length]; - for (int i = 0; i < intValues.length; i++) { - BigDecimal decimal = new BigDecimal(BigInteger.valueOf(intValues[i]), scale); - values[i] = decimal; - decimalVector.getMutator().setIndexDefined(i); - DecimalUtility.writeBigDecimalToArrowBuf(decimal, decimalVector.getBuffer(), i); + public void testValuesWriteRead() { + try (NullableDecimalVector decimalVector = TestUtils.newVector(NullableDecimalVector.class, "decimal", new ArrowType.Decimal(10, scale), allocator);) { + + try (NullableDecimalVector oldConstructor = new NullableDecimalVector("decimal", allocator, 10, scale);) { + assertEquals(decimalVector.getField().getType(), oldConstructor.getField().getType()); + } + + decimalVector.allocateNew(); + BigDecimal[] values = new BigDecimal[intValues.length]; + for (int i = 0; i < intValues.length; i++) { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(intValues[i]), scale); + values[i] = decimal; + decimalVector.getMutator().setSafe(i, decimal); + } + + decimalVector.getMutator().setValueCount(intValues.length); + + for (int i = 0; i < intValues.length; i++) { + BigDecimal value = decimalVector.getAccessor().getObject(i); + assertEquals(values[i], value); + } } + } + + @Test + public void testBigDecimalDifferentScaleAndPrecision() { + try (NullableDecimalVector decimalVector = TestUtils.newVector(NullableDecimalVector.class, "decimal", new ArrowType.Decimal(4, 2), allocator);) { + decimalVector.allocateNew(); - decimalVector.getMutator().setValueCount(intValues.length); + // test BigDecimal with different scale + boolean hasError = false; + try { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(0), 3); + decimalVector.getMutator().setSafe(0, decimal); + } catch (UnsupportedOperationException ue) { + hasError = true; + } finally { + assertTrue(hasError); + } - for (int i = 0; i < intValues.length; i++) { - BigDecimal value = decimalVector.getAccessor().getObject(i); - assertEquals(values[i], value); + // test BigDecimal with larger precision than initialized + hasError = false; + try { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(12345), 2); + decimalVector.getMutator().setSafe(0, decimal); + } catch (UnsupportedOperationException ue) { + hasError = true; + } finally { + assertTrue(hasError); + } } } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java index 732bd98b7c6..c05d5904977 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java @@ -18,6 +18,8 @@ package org.apache.arrow.vector.file; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -27,6 +29,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.NullableDateMilliVector; +import org.apache.arrow.vector.NullableDecimalVector; import org.apache.arrow.vector.NullableIntVector; import org.apache.arrow.vector.NullableTimeMilliVector; import org.apache.arrow.vector.NullableVarCharVector; @@ -56,7 +59,6 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DateUtility; -import org.apache.arrow.vector.util.DictionaryUtility; import org.apache.arrow.vector.util.Text; import org.joda.time.DateTimeZone; import org.joda.time.LocalDateTime; @@ -314,7 +316,6 @@ protected void validateFlatDictionary(VectorSchemaRoot root, DictionaryProvider Assert.assertEquals(1, accessor.getObject(4)); Assert.assertEquals(0, accessor.getObject(5)); - FieldVector vector2 = root.getVector("sizes"); Assert.assertNotNull(vector2); @@ -408,28 +409,57 @@ protected void validateNestedDictionary(VectorSchemaRoot root, DictionaryProvide Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); } - protected void validateNestedDictionary(ListVector vector, DictionaryProvider provider) { - Assert.assertNotNull(vector); - Assert.assertNull(vector.getField().getDictionary()); - Field nestedField = vector.getField().getChildren().get(0); + protected VectorSchemaRoot writeDecimalData(BufferAllocator bufferAllocator) { + NullableDecimalVector decimalVector1 = new NullableDecimalVector("decimal1", bufferAllocator, 10, 3); + NullableDecimalVector decimalVector2 = new NullableDecimalVector("decimal2", bufferAllocator, 4, 2); + NullableDecimalVector decimalVector3 = new NullableDecimalVector("decimal3", bufferAllocator, 16, 8); - DictionaryEncoding encoding = nestedField.getDictionary(); - Assert.assertNotNull(encoding); - Assert.assertEquals(2L, encoding.getId()); - Assert.assertEquals(new ArrowType.Int(32, true), encoding.getIndexType()); + int count = 10; + decimalVector1.allocateNew(count); + decimalVector2.allocateNew(count); + decimalVector3.allocateNew(count); - ListVector.Accessor accessor = vector.getAccessor(); - Assert.assertEquals(3, accessor.getValueCount()); - Assert.assertEquals(Arrays.asList(0, 1), accessor.getObject(0)); - Assert.assertEquals(Arrays.asList(0), accessor.getObject(1)); - Assert.assertEquals(Arrays.asList(1), accessor.getObject(2)); + for (int i = 0; i < count; i++) { + decimalVector1.getMutator().setSafe(i, new BigDecimal(BigInteger.valueOf(i), 3)); + decimalVector2.getMutator().setSafe(i, new BigDecimal(BigInteger.valueOf(i * (1 << 10)), 2)); + decimalVector3.getMutator().setSafe(i, new BigDecimal(BigInteger.valueOf(i * 1111111111111111L), 8)); + } - Dictionary dictionary = provider.lookup(2L); - Assert.assertNotNull(dictionary); - NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) dictionary.getVector()).getAccessor(); - Assert.assertEquals(2, dictionaryAccessor.getValueCount()); - Assert.assertEquals(new Text("foo"), dictionaryAccessor.getObject(0)); - Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); + decimalVector1.getMutator().setValueCount(count); + decimalVector2.getMutator().setValueCount(count); + decimalVector3.getMutator().setValueCount(count); + + List fields = ImmutableList.of(decimalVector1.getField(), decimalVector2.getField(), decimalVector3.getField()); + List vectors = ImmutableList.of(decimalVector1, decimalVector2, decimalVector3); + return new VectorSchemaRoot(fields, vectors, count); + } + + protected void validateDecimalData(VectorSchemaRoot root) { + NullableDecimalVector decimalVector1 = (NullableDecimalVector) root.getVector("decimal1"); + NullableDecimalVector decimalVector2 = (NullableDecimalVector) root.getVector("decimal2"); + NullableDecimalVector decimalVector3 = (NullableDecimalVector) root.getVector("decimal3"); + int count = 10; + Assert.assertEquals(count, root.getRowCount()); + + for (int i = 0; i < count; i++) { + // Verify decimal 1 vector + BigDecimal readValue = decimalVector1.getAccessor().getObject(i); + ArrowType.Decimal type = (ArrowType.Decimal) decimalVector1.getField().getType(); + BigDecimal genValue = new BigDecimal(BigInteger.valueOf(i), type.getScale()); + Assert.assertEquals(genValue, readValue); + + // Verify decimal 2 vector + readValue = decimalVector2.getAccessor().getObject(i); + type = (ArrowType.Decimal) decimalVector2.getField().getType(); + genValue = new BigDecimal(BigInteger.valueOf(i * (1 << 10)), type.getScale()); + Assert.assertEquals(genValue, readValue); + + // Verify decimal 3 vector + readValue = decimalVector3.getAccessor().getObject(i); + type = (ArrowType.Decimal) decimalVector3.getField().getType(); + genValue = new BigDecimal(BigInteger.valueOf(i * 1111111111111111L), type.getScale()); + Assert.assertEquals(genValue, readValue); + } } protected void writeData(int count, MapVector parent) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java index 24b2138386d..b7c06327291 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java @@ -237,6 +237,38 @@ public void testWriteReadNestedDictionaryJSON() throws IOException { } } + @Test + public void testWriteReadDecimalJSON() throws IOException { + File file = new File("target/mytest_decimal.json"); + + // write + try ( + BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE) + ) { + + try (VectorSchemaRoot root = writeDecimalData(vectorAllocator)) { + printVectors(root.getFieldVectors()); + validateDecimalData(root); + writeJSON(file, root, null); + } + } + + // read + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ) { + JsonFileReader reader = new JsonFileReader(file, readerAllocator); + Schema schema = reader.start(); + LOGGER.debug("reading schema: " + schema); + + // initialize vectors + try (VectorSchemaRoot root = reader.read();) { + validateDecimalData(root); + } + reader.close(); + } + } + @Test public void testSetStructLength() throws IOException { File file = new File("../../integration/data/struct_example.json");