From f4560d994cc8bd462e5288d2f428370b03671b05 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 25 Aug 2017 15:25:20 -0700 Subject: [PATCH 1/6] added Decimal JSON support, Java roundtrip unit tests --- .../vector/file/json/JsonFileReader.java | 14 +++- .../vector/file/json/JsonFileWriter.java | 16 +++- .../arrow/vector/util/DecimalUtility.java | 20 ++++- .../arrow/vector/file/BaseFileTest.java | 82 ++++++++++++++----- .../arrow/vector/file/json/TestJSONFile.java | 32 ++++++++ 5 files changed, 134 insertions(+), 30 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java index 484a82fdaab..71685d13589 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java @@ -39,6 +39,7 @@ import org.apache.arrow.vector.BufferBacked; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.Float4Vector; import org.apache.arrow.vector.Float8Vector; @@ -72,6 +73,7 @@ import org.apache.arrow.vector.schema.ArrowVectorType; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.DictionaryUtility; import org.apache.commons.codec.DecoderException; import org.apache.commons.codec.binary.Hex; @@ -235,14 +237,16 @@ private void readVector(Field field, FieldVector vector) throws JsonParseExcepti nextFieldIs(vectorType.getName()); readToken(START_ARRAY); ValueVector valueVector = (ValueVector) innerVector; - valueVector.allocateNew(); - Mutator mutator = valueVector.getMutator(); int innerVectorCount = vectorType.equals(OFFSET) ? count + 1 : count; + valueVector.setInitialCapacity(innerVectorCount); + valueVector.allocateNew(); + for (int i = 0; i < innerVectorCount; i++) { parser.nextToken(); setValueFromParser(valueVector, i); } + Mutator mutator = valueVector.getMutator(); mutator.setValueCount(innerVectorCount); readToken(END_ARRAY); } @@ -312,6 +316,12 @@ private void setValueFromParser(ValueVector valueVector, int i) throws IOExcepti case FLOAT8: ((Float8Vector) valueVector).getMutator().set(i, parser.readValueAs(Double.class)); break; + case DECIMAL: { + DecimalVector decimalVector = ((DecimalVector) valueVector); + byte[] value = decodeHexSafe(parser.readValueAs(String.class)); + DecimalUtility.writeByteArrayToArrowBuf(value, decimalVector.getBuffer(), i); + } + break; case VARBINARY: ((VarBinaryVector) valueVector).getMutator().setSafe(i, decodeHexSafe(parser.readValueAs(String.class))); break; diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java index a2229cef231..5e558ef6cc3 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java @@ -26,10 +26,12 @@ import java.util.Set; import com.google.common.collect.ImmutableList; +import io.netty.buffer.ArrowBuf; import org.apache.arrow.vector.BitVector; import org.apache.arrow.vector.BufferBacked; import org.apache.arrow.vector.DateDayVector; import org.apache.arrow.vector.DateMilliVector; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.TimeMicroVector; import org.apache.arrow.vector.TimeMilliVector; @@ -54,6 +56,7 @@ import com.fasterxml.jackson.core.util.DefaultPrettyPrinter; import com.fasterxml.jackson.core.util.DefaultPrettyPrinter.NopIndenter; import com.fasterxml.jackson.databind.MappingJsonFactory; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.DictionaryUtility; import org.apache.commons.codec.binary.Hex; @@ -233,9 +236,16 @@ private void writeValueToGenerator(ValueVector valueVector, int i) throws IOExce case BIT: generator.writeNumber(((BitVector) valueVector).getAccessor().get(i)); break; - case VARBINARY: - String hexString = Hex.encodeHexString(((VarBinaryVector) valueVector).getAccessor().get(i)); - generator.writeObject(hexString); + case VARBINARY: { + String hexString = Hex.encodeHexString(((VarBinaryVector) valueVector).getAccessor().get(i)); + generator.writeObject(hexString); + } + break; + case DECIMAL: { + ArrowBuf bytebuf = ((DecimalVector) valueVector).getAccessor().get(i); + String hexString = Hex.encodeHexString(DecimalUtility.getByteArrayFromArrowBuf(bytebuf, 0)); + generator.writeObject(hexString); + } break; default: // TODO: each type diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index 4b11b368dff..3b1d0beef5c 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -19,13 +19,10 @@ package org.apache.arrow.vector.util; import io.netty.buffer.ArrowBuf; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.UnpooledByteBufAllocator; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; -import java.util.Arrays; public class DecimalUtility { @@ -150,14 +147,29 @@ public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int sta return new BigDecimal(unscaledValue, scale); } + public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index) { + final byte[] value = new byte[DECIMAL_BYTE_LENGTH]; + final int startIndex = index * DECIMAL_BYTE_LENGTH; + bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); + return value; + } + public static void writeBigDecimalToArrowBuf(BigDecimal value, ArrowBuf bytebuf, int index) { final byte[] bytes = value.unscaledValue().toByteArray(); + final int padValue = value.signum() == -1 ? 0xFF : 0; + writeByteArrayToArrowBuf(bytes, bytebuf, index, padValue); + } + + public static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index) { + writeByteArrayToArrowBuf(bytes, bytebuf, index, 0); + } + + private static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index, int padValue) { final int startIndex = index * DECIMAL_BYTE_LENGTH; if (bytes.length > DECIMAL_BYTE_LENGTH) { throw new UnsupportedOperationException("Decimal size greater than 16 bytes"); } final int padLength = DECIMAL_BYTE_LENGTH - bytes.length; - final int padValue = value.signum() == -1 ? 0xFF : 0; for (int i = 0; i < padLength; i++) { bytebuf.setByte(startIndex + i, padValue); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java index 732bd98b7c6..9f117ceaaf9 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java @@ -18,6 +18,8 @@ package org.apache.arrow.vector.file; +import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.List; @@ -27,6 +29,7 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.NullableDateMilliVector; +import org.apache.arrow.vector.NullableDecimalVector; import org.apache.arrow.vector.NullableIntVector; import org.apache.arrow.vector.NullableTimeMilliVector; import org.apache.arrow.vector.NullableVarCharVector; @@ -56,7 +59,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DateUtility; -import org.apache.arrow.vector.util.DictionaryUtility; +import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.Text; import org.joda.time.DateTimeZone; import org.joda.time.LocalDateTime; @@ -314,7 +317,6 @@ protected void validateFlatDictionary(VectorSchemaRoot root, DictionaryProvider Assert.assertEquals(1, accessor.getObject(4)); Assert.assertEquals(0, accessor.getObject(5)); - FieldVector vector2 = root.getVector("sizes"); Assert.assertNotNull(vector2); @@ -408,28 +410,66 @@ protected void validateNestedDictionary(VectorSchemaRoot root, DictionaryProvide Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); } - protected void validateNestedDictionary(ListVector vector, DictionaryProvider provider) { - Assert.assertNotNull(vector); - Assert.assertNull(vector.getField().getDictionary()); - Field nestedField = vector.getField().getChildren().get(0); + private void writeIntAsDecimal(NullableDecimalVector vector, long value, int index) { + ArrowType.Decimal type = (ArrowType.Decimal) vector.getField().getType(); + BigDecimal decimalValue = new BigDecimal(BigInteger.valueOf(value), type.getScale()); + DecimalUtility.writeBigDecimalToArrowBuf(decimalValue, vector.getBuffer(), index); + vector.getValidityVector().getMutator().setToOne(index); + } - DictionaryEncoding encoding = nestedField.getDictionary(); - Assert.assertNotNull(encoding); - Assert.assertEquals(2L, encoding.getId()); - Assert.assertEquals(new ArrowType.Int(32, true), encoding.getIndexType()); + protected VectorSchemaRoot writeDecimalData(BufferAllocator bufferAllocator) { + NullableDecimalVector decimalVector1 = new NullableDecimalVector("decimal1", bufferAllocator, 10, 3); + NullableDecimalVector decimalVector2 = new NullableDecimalVector("decimal2", bufferAllocator, 4, 2); + NullableDecimalVector decimalVector3 = new NullableDecimalVector("decimal3", bufferAllocator, 16, 8); - ListVector.Accessor accessor = vector.getAccessor(); - Assert.assertEquals(3, accessor.getValueCount()); - Assert.assertEquals(Arrays.asList(0, 1), accessor.getObject(0)); - Assert.assertEquals(Arrays.asList(0), accessor.getObject(1)); - Assert.assertEquals(Arrays.asList(1), accessor.getObject(2)); + int count = 10; + decimalVector1.allocateNew(count); + decimalVector2.allocateNew(count); + decimalVector3.allocateNew(count); - Dictionary dictionary = provider.lookup(2L); - Assert.assertNotNull(dictionary); - NullableVarCharVector.Accessor dictionaryAccessor = ((NullableVarCharVector) dictionary.getVector()).getAccessor(); - Assert.assertEquals(2, dictionaryAccessor.getValueCount()); - Assert.assertEquals(new Text("foo"), dictionaryAccessor.getObject(0)); - Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); + for (int i = 0; i < count; i++) { + writeIntAsDecimal(decimalVector1, i, i); + writeIntAsDecimal(decimalVector2, i * (1 << 10), i); + //long blah = i * Long.valueOf(String.format("%0" + 10 + "d", 0).replace("0", "1")); + //System.out.println("*** Int val: " + blah); (Long.valueOf(String.format("%0" + 16 + "d", 0).replace("0", "1"))), i); + writeIntAsDecimal(decimalVector3, i * 1111111111111111L, i); + } + + decimalVector1.getMutator().setValueCount(count); + decimalVector2.getMutator().setValueCount(count); + decimalVector3.getMutator().setValueCount(count); + + List fields = ImmutableList.of(decimalVector1.getField(), decimalVector2.getField(), decimalVector3.getField()); + List vectors = ImmutableList.of(decimalVector1, decimalVector2, decimalVector3); + return new VectorSchemaRoot(fields, vectors, count); + } + + protected void validateDecimalData(VectorSchemaRoot root) { + NullableDecimalVector decimalVector1 = (NullableDecimalVector) root.getVector("decimal1"); + NullableDecimalVector decimalVector2 = (NullableDecimalVector) root.getVector("decimal2"); + NullableDecimalVector decimalVector3 = (NullableDecimalVector) root.getVector("decimal3"); + int count = 10; + Assert.assertEquals(count, root.getRowCount()); + + for (int i = 0; i < count; i++) { + // Verify decimal 1 vector + BigDecimal readValue = decimalVector1.getAccessor().getObject(i); + ArrowType.Decimal type = (ArrowType.Decimal) decimalVector1.getField().getType(); + BigDecimal genValue = new BigDecimal(BigInteger.valueOf(i), type.getScale()); + Assert.assertEquals(genValue, readValue); + + // Verify decimal 2 vector + readValue = decimalVector2.getAccessor().getObject(i); + type = (ArrowType.Decimal) decimalVector2.getField().getType(); + genValue = new BigDecimal(BigInteger.valueOf(i * (1 << 10)), type.getScale()); + Assert.assertEquals(genValue, readValue); + + // Verify decimal 3 vector + readValue = decimalVector3.getAccessor().getObject(i); + type = (ArrowType.Decimal) decimalVector3.getField().getType(); + genValue = new BigDecimal(BigInteger.valueOf(i * 1111111111111111L), type.getScale()); + Assert.assertEquals(genValue, readValue); + } } protected void writeData(int count, MapVector parent) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java index 24b2138386d..b7c06327291 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java @@ -237,6 +237,38 @@ public void testWriteReadNestedDictionaryJSON() throws IOException { } } + @Test + public void testWriteReadDecimalJSON() throws IOException { + File file = new File("target/mytest_decimal.json"); + + // write + try ( + BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE) + ) { + + try (VectorSchemaRoot root = writeDecimalData(vectorAllocator)) { + printVectors(root.getFieldVectors()); + validateDecimalData(root); + writeJSON(file, root, null); + } + } + + // read + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + ) { + JsonFileReader reader = new JsonFileReader(file, readerAllocator); + Schema schema = reader.start(); + LOGGER.debug("reading schema: " + schema); + + // initialize vectors + try (VectorSchemaRoot root = reader.read();) { + validateDecimalData(root); + } + reader.close(); + } + } + @Test public void testSetStructLength() throws IOException { File file = new File("../../integration/data/struct_example.json"); From da11b4f6873b87c67f25353133f453158f8d9107 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 25 Aug 2017 16:10:03 -0700 Subject: [PATCH 2/6] removed debug line --- .../test/java/org/apache/arrow/vector/file/BaseFileTest.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java index 9f117ceaaf9..1dd86b62f28 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java @@ -430,8 +430,6 @@ protected VectorSchemaRoot writeDecimalData(BufferAllocator bufferAllocator) { for (int i = 0; i < count; i++) { writeIntAsDecimal(decimalVector1, i, i); writeIntAsDecimal(decimalVector2, i * (1 << 10), i); - //long blah = i * Long.valueOf(String.format("%0" + 10 + "d", 0).replace("0", "1")); - //System.out.println("*** Int val: " + blah); (Long.valueOf(String.format("%0" + 16 + "d", 0).replace("0", "1"))), i); writeIntAsDecimal(decimalVector3, i * 1111111111111111L, i); } From c5e8fba1b61af6a72f15a897388d0a08e85857c9 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 29 Aug 2017 11:41:48 -0700 Subject: [PATCH 3/6] minor tweaks to JsonFileWriter --- .../org/apache/arrow/vector/file/json/JsonFileWriter.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java index 5e558ef6cc3..04e44379e5d 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileWriter.java @@ -238,13 +238,13 @@ private void writeValueToGenerator(ValueVector valueVector, int i) throws IOExce break; case VARBINARY: { String hexString = Hex.encodeHexString(((VarBinaryVector) valueVector).getAccessor().get(i)); - generator.writeObject(hexString); + generator.writeString(hexString); } break; case DECIMAL: { - ArrowBuf bytebuf = ((DecimalVector) valueVector).getAccessor().get(i); - String hexString = Hex.encodeHexString(DecimalUtility.getByteArrayFromArrowBuf(bytebuf, 0)); - generator.writeObject(hexString); + ArrowBuf bytebuf = valueVector.getDataBuffer(); + String hexString = Hex.encodeHexString(DecimalUtility.getByteArrayFromArrowBuf(bytebuf, i)); + generator.writeString(hexString); } break; default: From 10cac9cc53a7142c872fead7a7a7411b54aa6193 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 29 Aug 2017 17:18:06 -0700 Subject: [PATCH 4/6] added vector API for set and setSafe with BigDecimal --- .../codegen/templates/AbstractFieldWriter.java | 7 +++++++ .../main/codegen/templates/ComplexWriters.java | 11 +++++++++++ .../codegen/templates/FixedValueVectors.java | 17 ++++++++++++++--- .../codegen/templates/NullableValueVectors.java | 13 +++++++++++++ .../arrow/vector/util/DecimalUtility.java | 5 ++--- .../apache/arrow/vector/TestDecimalVector.java | 3 +-- .../apache/arrow/vector/file/BaseFileTest.java | 14 +++----------- 7 files changed, 51 insertions(+), 19 deletions(-) diff --git a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java index da8e4f54ec2..853f67fd0dd 100644 --- a/java/vector/src/main/codegen/templates/AbstractFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractFieldWriter.java @@ -53,6 +53,7 @@ public void endList() { <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> + <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> @Override public void write(${name}Holder holder) { fail("${name}"); @@ -62,6 +63,12 @@ public void write(${name}Holder holder) { fail("${name}"); } + <#if minor.class == "Decimal"> + public void write${minor.class}(${friendlyType} value) { + fail("${name}"); + } + + public void writeNull() { diff --git a/java/vector/src/main/codegen/templates/ComplexWriters.java b/java/vector/src/main/codegen/templates/ComplexWriters.java index 8ebecf3e1de..fe099bede35 100644 --- a/java/vector/src/main/codegen/templates/ComplexWriters.java +++ b/java/vector/src/main/codegen/templates/ComplexWriters.java @@ -24,6 +24,7 @@ <#assign eName = name /> <#assign javaType = (minor.javaType!type.javaType) /> <#assign fields = minor.fields!type.fields /> +<#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/impl/${eName}WriterImpl.java" /> <#include "/@includes/license.ftl" /> @@ -115,7 +116,13 @@ public void write(Nullable${minor.class}Holder h) { mutator.setSafe(idx()<#if mode == "Nullable">, 1<#list fields as field><#if field.include!true >, ${field.name}); vector.getMutator().setValueCount(idx()+1); } + <#if minor.class == "Decimal"> + public void write${minor.class}(${friendlyType} value) { + mutator.setSafe(idx(), value); + vector.getMutator().setValueCount(idx()+1); + } + <#if mode == "Nullable"> public void writeNull() { @@ -140,6 +147,10 @@ public interface ${eName}Writer extends BaseWriter { public void write(${minor.class}Holder h); public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ); +<#if minor.class == "Decimal"> + + public void write${minor.class}(${friendlyType} value); + } diff --git a/java/vector/src/main/codegen/templates/FixedValueVectors.java b/java/vector/src/main/codegen/templates/FixedValueVectors.java index 9747d421c41..9831231c510 100644 --- a/java/vector/src/main/codegen/templates/FixedValueVectors.java +++ b/java/vector/src/main/codegen/templates/FixedValueVectors.java @@ -403,7 +403,7 @@ public void get(int index, Nullable${minor.class}Holder holder) { @Override public ${friendlyType} getObject(int index) { - return org.apache.arrow.vector.util.DecimalUtility.getBigDecimalFromArrowBuf(data, ${type.width} * index, scale); + return DecimalUtility.getBigDecimalFromArrowBuf(data, index, scale); } <#else> @@ -596,10 +596,10 @@ void set(int index, Nullable${minor.class}Holder holder){ set(index, holder.start, holder.buffer); } - public void setSafe(int index, Nullable${minor.class}Holder holder){ + public void setSafe(int index, Nullable${minor.class}Holder holder){ setSafe(index, holder.start, holder.buffer); } - public void setSafe(int index, ${minor.class}Holder holder){ + public void setSafe(int index, ${minor.class}Holder holder){ setSafe(index, holder.start, holder.buffer); } @@ -614,6 +614,17 @@ public void set(int index, int start, ArrowBuf buffer){ data.setBytes(index * ${type.width}, buffer, start, ${type.width}); } + public void set(int index, ${friendlyType} value){ + DecimalUtility.writeBigDecimalToArrowBuf(value, data, index); + } + + public void setSafe(int index, ${friendlyType} value){ + while(index >= getValueCapacity()) { + reAlloc(); + } + set(index, value); + } + <#else> protected void set(int index, ${minor.class}Holder holder){ set(index, holder.start, holder.buffer); diff --git a/java/vector/src/main/codegen/templates/NullableValueVectors.java b/java/vector/src/main/codegen/templates/NullableValueVectors.java index a4313332563..319c61c8624 100644 --- a/java/vector/src/main/codegen/templates/NullableValueVectors.java +++ b/java/vector/src/main/codegen/templates/NullableValueVectors.java @@ -723,6 +723,19 @@ public void setSafe(int index, ${minor.javaType!type.javaType} value) { setCount++; } + + <#if minor.class == "Decimal"> + public void set(int index, ${friendlyType} value) { + bits.getMutator().setToOne(index); + values.getMutator().set(index, value); + } + + public void setSafe(int index, ${friendlyType} value) { + bits.getMutator().setSafeToOne(index); + values.getMutator().setSafe(index, value); + setCount++; + } + @Override public void setValueCount(int valueCount) { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index 3b1d0beef5c..ec9d3eb755f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -133,8 +133,9 @@ public static StringBuilder toStringWithZeroes(long number, int desiredLength) { return str; } - public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int startIndex, int scale) { + public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, int scale) { byte[] value = new byte[DECIMAL_BYTE_LENGTH]; + final int startIndex = index * DECIMAL_BYTE_LENGTH; bytebuf.getBytes(startIndex, value, 0, DECIMAL_BYTE_LENGTH); BigInteger unscaledValue = new BigInteger(value); return new BigDecimal(unscaledValue, scale); @@ -176,5 +177,3 @@ private static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int bytebuf.setBytes(startIndex + padLength, bytes, 0, bytes.length); } } - - diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index 774fbe084f1..00f83e145cb 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -55,8 +55,7 @@ public void test() { for (int i = 0; i < intValues.length; i++) { BigDecimal decimal = new BigDecimal(BigInteger.valueOf(intValues[i]), scale); values[i] = decimal; - decimalVector.getMutator().setIndexDefined(i); - DecimalUtility.writeBigDecimalToArrowBuf(decimal, decimalVector.getBuffer(), i); + decimalVector.getMutator().setSafe(i, decimal); } decimalVector.getMutator().setValueCount(intValues.length); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java index 1dd86b62f28..c05d5904977 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/BaseFileTest.java @@ -59,7 +59,6 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.DateUtility; -import org.apache.arrow.vector.util.DecimalUtility; import org.apache.arrow.vector.util.Text; import org.joda.time.DateTimeZone; import org.joda.time.LocalDateTime; @@ -410,13 +409,6 @@ protected void validateNestedDictionary(VectorSchemaRoot root, DictionaryProvide Assert.assertEquals(new Text("bar"), dictionaryAccessor.getObject(1)); } - private void writeIntAsDecimal(NullableDecimalVector vector, long value, int index) { - ArrowType.Decimal type = (ArrowType.Decimal) vector.getField().getType(); - BigDecimal decimalValue = new BigDecimal(BigInteger.valueOf(value), type.getScale()); - DecimalUtility.writeBigDecimalToArrowBuf(decimalValue, vector.getBuffer(), index); - vector.getValidityVector().getMutator().setToOne(index); - } - protected VectorSchemaRoot writeDecimalData(BufferAllocator bufferAllocator) { NullableDecimalVector decimalVector1 = new NullableDecimalVector("decimal1", bufferAllocator, 10, 3); NullableDecimalVector decimalVector2 = new NullableDecimalVector("decimal2", bufferAllocator, 4, 2); @@ -428,9 +420,9 @@ protected VectorSchemaRoot writeDecimalData(BufferAllocator bufferAllocator) { decimalVector3.allocateNew(count); for (int i = 0; i < count; i++) { - writeIntAsDecimal(decimalVector1, i, i); - writeIntAsDecimal(decimalVector2, i * (1 << 10), i); - writeIntAsDecimal(decimalVector3, i * 1111111111111111L, i); + decimalVector1.getMutator().setSafe(i, new BigDecimal(BigInteger.valueOf(i), 3)); + decimalVector2.getMutator().setSafe(i, new BigDecimal(BigInteger.valueOf(i * (1 << 10)), 2)); + decimalVector3.getMutator().setSafe(i, new BigDecimal(BigInteger.valueOf(i * 1111111111111111L), 8)); } decimalVector1.getMutator().setValueCount(count); From 31b7ec18d2365750b10d5338082df2eeda0adbe4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 1 Sep 2017 14:46:52 -0700 Subject: [PATCH 5/6] Added check that BigDecimal precision and scale matches that of the vector --- .../codegen/templates/FixedValueVectors.java | 1 + .../arrow/vector/util/DecimalUtility.java | 50 +++++++++++++++++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/java/vector/src/main/codegen/templates/FixedValueVectors.java b/java/vector/src/main/codegen/templates/FixedValueVectors.java index 9831231c510..ffd8cad02e2 100644 --- a/java/vector/src/main/codegen/templates/FixedValueVectors.java +++ b/java/vector/src/main/codegen/templates/FixedValueVectors.java @@ -615,6 +615,7 @@ public void set(int index, int start, ArrowBuf buffer){ } public void set(int index, ${friendlyType} value){ + DecimalUtility.checkPrecisionAndScale(value, precision, scale); DecimalUtility.writeBigDecimalToArrowBuf(value, data, index); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java index ec9d3eb755f..033ae6c0991 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/DecimalUtility.java @@ -19,6 +19,8 @@ package org.apache.arrow.vector.util; import io.netty.buffer.ArrowBuf; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.types.pojo.ArrowType; import java.math.BigDecimal; import java.math.BigInteger; @@ -66,7 +68,7 @@ public class DecimalUtility { public static final int DECIMAL_BYTE_LENGTH = 16; - /* + /** * Simple function that returns the static precomputed * power of ten, instead of using Math.pow */ @@ -75,7 +77,7 @@ public static long getPowerOfTen(int power) { return scale_long_constants[(power)]; } - /* + /** * Math.pow returns a double and while multiplying with large digits * in the decimal data type we encounter noise. So instead of multiplying * with Math.pow we use the static constants to perform the multiplication @@ -100,7 +102,8 @@ public static long adjustScaleDivide(long input, int factor) { } } - /* Returns a string representation of the given integer + /** + * Returns a string representation of the given integer * If the length of the given integer is less than the * passed length, this function will prepend zeroes to the string */ @@ -133,6 +136,10 @@ public static StringBuilder toStringWithZeroes(long number, int desiredLength) { return str; } + /** + * Read an ArrowType.Decimal at the given value index in the ArrowBuf and convert to a BigDecimal + * with the given scale. + */ public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, int scale) { byte[] value = new byte[DECIMAL_BYTE_LENGTH]; final int startIndex = index * DECIMAL_BYTE_LENGTH; @@ -141,13 +148,21 @@ public static BigDecimal getBigDecimalFromArrowBuf(ArrowBuf bytebuf, int index, return new BigDecimal(unscaledValue, scale); } - public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int start, int scale) { + /** + * Read an ArrowType.Decimal from the ByteBuffer and convert to a BigDecimal with the given + * scale. + */ + public static BigDecimal getBigDecimalFromByteBuffer(ByteBuffer bytebuf, int scale) { byte[] value = new byte[DECIMAL_BYTE_LENGTH]; bytebuf.get(value); BigInteger unscaledValue = new BigInteger(value); return new BigDecimal(unscaledValue, scale); } + /** + * Read an ArrowType.Decimal from the ArrowBuf at the given value index and return it as a byte + * array. + */ public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index) { final byte[] value = new byte[DECIMAL_BYTE_LENGTH]; final int startIndex = index * DECIMAL_BYTE_LENGTH; @@ -155,12 +170,39 @@ public static byte[] getByteArrayFromArrowBuf(ArrowBuf bytebuf, int index) { return value; } + /** + * Check that the BigDecimal scale equals the vectorScale and that the BigDecimal precision is + * less than or equal to the vectorPrecision. If not, then an UnsupportedOperationException is + * thrown, otherwise returns true. + */ + public static boolean checkPrecisionAndScale(BigDecimal value, int vectorPrecision, int vectorScale) { + if (value.scale() != vectorScale) { + throw new UnsupportedOperationException("BigDecimal scale must equal that in the Arrow vector: " + + value.scale() + " != " + vectorScale); + } + if (value.precision() > vectorPrecision) { + throw new UnsupportedOperationException("BigDecimal precision can not be greater than that in the Arrow vector: " + + value.precision() + " > " + vectorPrecision); + } + return true; + } + + /** + * Write the given BigDecimal to the ArrowBuf at the given value index. Will throw an + * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte + * width. + */ public static void writeBigDecimalToArrowBuf(BigDecimal value, ArrowBuf bytebuf, int index) { final byte[] bytes = value.unscaledValue().toByteArray(); final int padValue = value.signum() == -1 ? 0xFF : 0; writeByteArrayToArrowBuf(bytes, bytebuf, index, padValue); } + /** + * Write the given byte array to the ArrowBuf at the given value index. Will throw an + * UnsupportedOperationException if the decimal size is greater than the Decimal vector byte + * width. + */ public static void writeByteArrayToArrowBuf(byte[] bytes, ArrowBuf bytebuf, int index) { writeByteArrayToArrowBuf(bytes, bytebuf, index, 0); } From 28c1e3e6ac231216fe0a6f760ef2dac44b8668bb Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Fri, 1 Sep 2017 15:18:28 -0700 Subject: [PATCH 6/6] added test for BigDecimal precision and scale mismatch --- .../arrow/vector/TestDecimalVector.java | 78 +++++++++++++++---- 1 file changed, 62 insertions(+), 16 deletions(-) diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index 00f83e145cb..56d22932764 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -19,6 +19,7 @@ package org.apache.arrow.vector; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import java.math.BigDecimal; import java.math.BigInteger; @@ -27,6 +28,8 @@ import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.util.DecimalUtility; +import org.junit.After; +import org.junit.Before; import org.junit.Test; public class TestDecimalVector { @@ -43,26 +46,69 @@ public class TestDecimalVector { private int scale = 3; + private BufferAllocator allocator; + + @Before + public void init() { + allocator = new DirtyRootAllocator(Long.MAX_VALUE, (byte) 100); + } + + @After + public void terminate() throws Exception { + allocator.close(); + } + @Test - public void test() { - BufferAllocator allocator = new RootAllocator(Integer.MAX_VALUE); - NullableDecimalVector decimalVector = TestUtils.newVector(NullableDecimalVector.class, "decimal", new ArrowType.Decimal(10, scale), allocator); - try (NullableDecimalVector oldConstructor = new NullableDecimalVector("decimal", allocator, 10, scale);) { - assertEquals(decimalVector.getField().getType(), oldConstructor.getField().getType()); - } - decimalVector.allocateNew(); - BigDecimal[] values = new BigDecimal[intValues.length]; - for (int i = 0; i < intValues.length; i++) { - BigDecimal decimal = new BigDecimal(BigInteger.valueOf(intValues[i]), scale); - values[i] = decimal; - decimalVector.getMutator().setSafe(i, decimal); + public void testValuesWriteRead() { + try (NullableDecimalVector decimalVector = TestUtils.newVector(NullableDecimalVector.class, "decimal", new ArrowType.Decimal(10, scale), allocator);) { + + try (NullableDecimalVector oldConstructor = new NullableDecimalVector("decimal", allocator, 10, scale);) { + assertEquals(decimalVector.getField().getType(), oldConstructor.getField().getType()); + } + + decimalVector.allocateNew(); + BigDecimal[] values = new BigDecimal[intValues.length]; + for (int i = 0; i < intValues.length; i++) { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(intValues[i]), scale); + values[i] = decimal; + decimalVector.getMutator().setSafe(i, decimal); + } + + decimalVector.getMutator().setValueCount(intValues.length); + + for (int i = 0; i < intValues.length; i++) { + BigDecimal value = decimalVector.getAccessor().getObject(i); + assertEquals(values[i], value); + } } + } + + @Test + public void testBigDecimalDifferentScaleAndPrecision() { + try (NullableDecimalVector decimalVector = TestUtils.newVector(NullableDecimalVector.class, "decimal", new ArrowType.Decimal(4, 2), allocator);) { + decimalVector.allocateNew(); - decimalVector.getMutator().setValueCount(intValues.length); + // test BigDecimal with different scale + boolean hasError = false; + try { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(0), 3); + decimalVector.getMutator().setSafe(0, decimal); + } catch (UnsupportedOperationException ue) { + hasError = true; + } finally { + assertTrue(hasError); + } - for (int i = 0; i < intValues.length; i++) { - BigDecimal value = decimalVector.getAccessor().getObject(i); - assertEquals(values[i], value); + // test BigDecimal with larger precision than initialized + hasError = false; + try { + BigDecimal decimal = new BigDecimal(BigInteger.valueOf(12345), 2); + decimalVector.getMutator().setSafe(0, decimal); + } catch (UnsupportedOperationException ue) { + hasError = true; + } finally { + assertTrue(hasError); + } } } }