diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java index 04d50331b76..f274b748e55 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileReader.java @@ -69,6 +69,8 @@ public JsonFileReader(File inputFile, BufferAllocator allocator) throws JsonPars this.allocator = allocator; MappingJsonFactory jsonFactory = new MappingJsonFactory(); this.parser = jsonFactory.createParser(inputFile); + // Allow reading NaN for floating point values + this.parser.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, true); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java index 067fb25b8d8..1c9e1d38095 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/ipc/JsonFileWriter.java @@ -84,6 +84,8 @@ public JsonFileWriter(File outputFile, JSONWriteConfig config) throws IOExceptio prettyPrinter.indentArraysWith(NopIndenter.instance); this.generator.setPrettyPrinter(prettyPrinter); } + // Allow writing of floating point NaN values not as strings + this.generator.configure(JsonGenerator.Feature.QUOTE_NON_NUMERIC_NUMBERS, false); } public void start(Schema schema, DictionaryProvider provider) throws IOException { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java index 3514acaa242..d26385d7494 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java @@ -47,6 +47,7 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter; import org.apache.arrow.vector.complex.writer.BigIntWriter; import org.apache.arrow.vector.complex.writer.DateMilliWriter; +import org.apache.arrow.vector.complex.writer.Float4Writer; import org.apache.arrow.vector.complex.writer.IntWriter; import org.apache.arrow.vector.complex.writer.TimeMilliWriter; import org.apache.arrow.vector.complex.writer.TimeStampMilliTZWriter; @@ -95,10 +96,28 @@ public void tearDown() { DateTimeZone.setDefault(defaultTimezone); } + protected void writeData(int count, MapVector parent) { + ComplexWriter writer = new ComplexWriterImpl("root", parent); + MapWriter rootWriter = writer.rootAsMap(); + IntWriter intWriter = rootWriter.integer("int"); + BigIntWriter bigIntWriter = rootWriter.bigInt("bigInt"); + Float4Writer float4Writer = rootWriter.float4("float"); + for (int i = 0; i < count; i++) { + intWriter.setPosition(i); + intWriter.writeInt(i); + bigIntWriter.setPosition(i); + bigIntWriter.writeBigInt(i); + float4Writer.setPosition(i); + float4Writer.writeFloat4(i == 0 ? Float.NaN : i); + } + writer.setValueCount(count); + } + protected void validateContent(int count, VectorSchemaRoot root) { for (int i = 0; i < count; i++) { Assert.assertEquals(i, root.getVector("int").getObject(i)); Assert.assertEquals(Long.valueOf(i), root.getVector("bigInt").getObject(i)); + Assert.assertEquals(i == 0 ? Float.NaN : i, root.getVector("float").getObject(i)); } } @@ -454,20 +473,6 @@ protected void validateDecimalData(VectorSchemaRoot root) { } } - protected void writeData(int count, MapVector parent) { - ComplexWriter writer = new ComplexWriterImpl("root", parent); - MapWriter rootWriter = writer.rootAsMap(); - IntWriter intWriter = rootWriter.integer("int"); - BigIntWriter bigIntWriter = rootWriter.bigInt("bigInt"); - for (int i = 0; i < count; i++) { - intWriter.setPosition(i); - intWriter.writeInt(i); - bigIntWriter.setPosition(i); - bigIntWriter.writeBigInt(i); - } - writer.setValueCount(count); - } - public void validateUnionData(int count, VectorSchemaRoot root) { FieldReader unionReader = root.getVector("union").getReader(); for (int i = 0; i < count; i++) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java index 625717048bf..4f9093b8c02 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/ipc/TestJSONFile.java @@ -38,6 +38,33 @@ public class TestJSONFile extends BaseFileTest { private static final Logger LOGGER = LoggerFactory.getLogger(TestJSONFile.class); + @Test + public void testWriteRead() throws IOException { + File file = new File("target/mytest.json"); + int count = COUNT; + + // write + try (BufferAllocator originalVectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); + MapVector parent = MapVector.empty("parent", originalVectorAllocator)) { + writeData(count, parent); + writeJSON(file, new VectorSchemaRoot(parent.getChild("root")), null); + } + + // read + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); + JsonFileReader reader = new JsonFileReader(file, readerAllocator) + ) { + Schema schema = reader.start(); + LOGGER.debug("reading schema: " + schema); + + // initialize vectors + try (VectorSchemaRoot root = reader.read();) { + validateContent(count, root); + } + } + } + @Test public void testWriteReadComplexJSON() throws IOException { File file = new File("target/mytest_complex.json");