diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml
index fa91adfd87b..b1b9cdb6ec8 100644
--- a/java/dataset/pom.xml
+++ b/java/dataset/pom.xml
@@ -140,6 +140,12 @@
       <version>2.8.1</version>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>commons-io</groupId>
+      <artifactId>commons-io</artifactId>
+      <version>2.4</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index aa7d7670232..b3b5fe18c79 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -19,6 +19,8 @@
 #include "arrow/array.h"
 #include "arrow/array/concatenate.h"
+#include "arrow/c/bridge.h"
+#include "arrow/c/helpers.h"
 #include "arrow/dataset/api.h"
 #include "arrow/dataset/file_base.h"
 #include "arrow/filesystem/localfs.h"
@@ -176,6 +178,21 @@
   }
 };
 
+// Projects the input schema down to the named partition columns, failing if a
+// name does not resolve against the dataset schema.
+arrow::Result<std::shared_ptr<arrow::Schema>> SchemaFromColumnNames(
+    const std::shared_ptr<arrow::Schema>& input,
+    const std::vector<std::string>& column_names) {
+  std::vector<std::shared_ptr<arrow::Field>> columns;
+  for (arrow::FieldRef ref : column_names) {
+    auto maybe_field = ref.GetOne(*input);
+    if (maybe_field.ok()) {
+      columns.push_back(std::move(maybe_field).ValueOrDie());
+    } else {
+      return arrow::Status::Invalid("Partition column '", ref.ToString(), "' is not in dataset schema");
+    }
+  }
+
+  return schema(std::move(columns))->WithMetadata(input->metadata());
+}
 }  // namespace
 
 using arrow::dataset::jni::CreateGlobalClassReference;
@@ -229,7 +246,6 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
       GetMethodID(env, java_reservation_listener_class, "unreserve", "(J)V"));
 
   default_memory_pool_id = reinterpret_cast<jlong>(arrow::default_memory_pool());
-  return JNI_VERSION;
 
   JNI_METHOD_END(JNI_ERR)
 }
@@ -516,3 +532,49 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
   return CreateNativeRef(d);
   JNI_METHOD_END(-1L)
 }
+
+/*
+ * Class:     org_apache_arrow_dataset_file_JniWrapper
+ * Method:    writeFromScannerToFile
+ * Signature:
+ * (JJLjava/lang/String;[Ljava/lang/String;ILjava/lang/String;)V
+ */
+JNIEXPORT void JNICALL
+Java_org_apache_arrow_dataset_file_JniWrapper_writeFromScannerToFile(
+    JNIEnv* env, jobject, jlong c_arrow_array_stream_address,
+    jlong file_format_id, jstring uri, jobjectArray partition_columns,
+    jint max_partitions, jstring base_name_template) {
+  JNI_METHOD_START
+  JavaVM* vm;
+  if (env->GetJavaVM(&vm) != JNI_OK) {
+    JniThrow("Unable to get JavaVM instance");
+  }
+
+  auto* arrow_stream =
+      reinterpret_cast<struct ArrowArrayStream*>(c_arrow_array_stream_address);
+  std::shared_ptr<arrow::RecordBatchReader> reader =
+      JniGetOrThrow(arrow::ImportRecordBatchReader(arrow_stream));
+  std::shared_ptr<arrow::dataset::ScannerBuilder> scanner_builder =
+      arrow::dataset::ScannerBuilder::FromRecordBatchReader(reader);
+  JniAssertOkOrThrow(scanner_builder->Pool(arrow::default_memory_pool()));
+  auto scanner = JniGetOrThrow(scanner_builder->Finish());
+
+  std::shared_ptr<arrow::Schema> schema = reader->schema();
+
+  std::shared_ptr<arrow::dataset::FileFormat> file_format =
+      JniGetOrThrow(GetFileFormat(file_format_id));
+  arrow::dataset::FileSystemDatasetWriteOptions options;
+  std::string output_path;
+  auto filesystem = JniGetOrThrow(
+      arrow::fs::FileSystemFromUri(JStringToCString(env, uri), &output_path));
+  std::vector<std::string> partition_column_vector =
+      ToStringVector(env, partition_columns);
+  options.file_write_options = file_format->DefaultWriteOptions();
+  options.filesystem = filesystem;
+  options.base_dir = output_path;
+  options.basename_template = JStringToCString(env, base_name_template);
+  // Surface an invalid partition column as a Java exception rather than
+  // aborting the process via ValueOrDie().
+  options.partitioning = std::make_shared<arrow::dataset::HivePartitioning>(
+      JniGetOrThrow(SchemaFromColumnNames(schema, partition_column_vector)));
+  options.max_partitions = max_partitions;
+  JniAssertOkOrThrow(arrow::dataset::FileSystemDataset::Write(options, scanner));
+  JNI_METHOD_END()
+}
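Note: the native method above never sees Java vectors; the caller exports an ArrowReader over the C Data Interface and passes only a raw stream address, which the C++ side re-imports via arrow::ImportRecordBatchReader. A minimal sketch of the Java half of that handoff, assuming an already-open reader (the helper class and method name are illustrative, not part of this patch):

    import org.apache.arrow.c.ArrowArrayStream;
    import org.apache.arrow.c.Data;
    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.vector.ipc.ArrowReader;

    final class StreamHandoff {
      // Exports the reader into the C-ABI stream and returns the address the
      // native side will import; the stream must stay open while native code
      // consumes it (DatasetFileWriter keeps it open for the whole write call).
      static long exportForNative(BufferAllocator allocator, ArrowReader reader,
                                  ArrowArrayStream stream) {
        Data.exportArrayStream(allocator, reader, stream);
        return stream.memoryAddress();
      }
    }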
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java
new file mode 100644
index 00000000000..b2369b853ad
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/DatasetFileWriter.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.file;
+
+import org.apache.arrow.c.ArrowArrayStream;
+import org.apache.arrow.c.Data;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.ipc.ArrowReader;
+
+/**
+ * JNI-based utility to write datasets into files. It internally depends on the C++
+ * static method FileSystemDataset::Write.
+ */
+public class DatasetFileWriter {
+
+  /**
+   * Write the contents of an ArrowReader as a dataset.
+   *
+   * @param allocator the allocator used to export the reader over the C Data Interface
+   * @param reader the data source for writing
+   * @param format target file format
+   * @param uri target file URI
+   * @param partitionColumns columns used to partition output files; empty to disable partitioning
+   * @param maxPartitions maximum number of partitions to include in the written files
+   * @param baseNameTemplate file name template used for generated files, e.g. "data_{i}",
+   *                         where {i} is an auto-incremented ID unique across all written files
+   */
+  public static void write(BufferAllocator allocator, ArrowReader reader, FileFormat format, String uri,
+      String[] partitionColumns, int maxPartitions, String baseNameTemplate) {
+    try (final ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) {
+      Data.exportArrayStream(allocator, reader, stream);
+      JniWrapper.get().writeFromScannerToFile(stream.memoryAddress(),
+          format.id(), uri, partitionColumns, maxPartitions, baseNameTemplate);
+    }
+  }
+
+  /**
+   * Write the contents of an ArrowReader as a dataset, with default partitioning settings:
+   * no partition columns, at most 1024 partitions, and the base name template "data_{i}".
+   *
+   * @param allocator the allocator used to export the reader over the C Data Interface
+   * @param reader the data source for writing
+   * @param format target file format
+   * @param uri target file URI
+   */
+  public static void write(BufferAllocator allocator, ArrowReader reader, FileFormat format, String uri) {
+    write(allocator, reader, format, uri, new String[0], 1024, "data_{i}");
+  }
+}
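Note: end to end, copying an existing Parquet dataset through the new writer looks roughly like the sketch below, which distills the tests later in this patch (the URIs are placeholders; all classes are added here or already present in this module):

    import org.apache.arrow.dataset.file.DatasetFileWriter;
    import org.apache.arrow.dataset.file.FileFormat;
    import org.apache.arrow.dataset.file.FileSystemDatasetFactory;
    import org.apache.arrow.dataset.jni.NativeMemoryPool;
    import org.apache.arrow.dataset.scanner.ArrowScannerReader;
    import org.apache.arrow.dataset.scanner.ScanOptions;
    import org.apache.arrow.dataset.scanner.Scanner;
    import org.apache.arrow.dataset.source.Dataset;
    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.memory.RootAllocator;

    public class WriteRoundTrip {
      public static void main(String[] args) throws Exception {
        String sourceUri = "file:///tmp/input.parquet"; // placeholder
        String targetUri = "file:///tmp/output/";       // placeholder
        try (BufferAllocator allocator = new RootAllocator();
            FileSystemDatasetFactory factory = new FileSystemDatasetFactory(
                allocator, NativeMemoryPool.getDefault(), FileFormat.PARQUET, sourceUri);
            Dataset dataset = factory.finish();
            Scanner scanner = dataset.newScan(new ScanOptions(new String[0], 100));
            ArrowScannerReader reader = new ArrowScannerReader(scanner, allocator)) {
          // Unpartitioned write using the defaults (1024 partitions max, "data_{i}").
          DatasetFileWriter.write(allocator, reader, FileFormat.PARQUET, targetUri);
        }
      }
    }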
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
index 6e65803a333..18560a46a5c 100644
--- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java
@@ -45,4 +45,23 @@ private JniWrapper() {
    */
   public native long makeFileSystemDatasetFactory(String uri, int fileFormat);
 
+  /**
+   * Write the content of a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally
+   * depends on the C++ write API: FileSystemDataset::Write.
+   *
+   * @param streamAddress the ArrowArrayStream address
+   * @param fileFormat target file format (ID)
+   * @param uri target file URI
+   * @param partitionColumns columns used to partition output files
+   * @param maxPartitions maximum number of partitions to include in the written files
+   * @param baseNameTemplate file name template used for generated files, e.g. "data_{i}",
+   *                         where {i} is an auto-incremented ID unique across all written files
+   */
+  public native void writeFromScannerToFile(long streamAddress,
+                                            long fileFormat,
+                                            String uri,
+                                            String[] partitionColumns,
+                                            int maxPartitions,
+                                            String baseNameTemplate);
+
 }
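Note: the partition arguments map directly onto the FileSystemDatasetWriteOptions populated in jni_wrapper.cc above, which uses Hive-style directory partitioning. Through the public API, a partitioned call looks like this sketch (the column names are taken from the tests below and assume they exist in the reader's schema):

    // Partition by "id" and "name": output lands in key=value directories such
    // as id=1/name=a/data_0, with at most 100 partitions overall.
    DatasetFileWriter.write(allocator, reader, FileFormat.PARQUET, targetUri,
        new String[]{"id", "name"}, 100, "data_{i}");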
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ArrowScannerReader.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ArrowScannerReader.java
new file mode 100644
index 00000000000..417ba837a3b
--- /dev/null
+++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ArrowScannerReader.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.arrow.dataset.scanner;
+
+import java.io.IOException;
+import java.util.Iterator;
+
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.VectorLoader;
+import org.apache.arrow.vector.VectorUnloader;
+import org.apache.arrow.vector.ipc.ArrowReader;
+import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
+import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+/**
+ * An implementation of {@link ArrowReader} that reads
+ * the dataset from {@link Scanner}.
+ */
+public class ArrowScannerReader extends ArrowReader {
+  private final Scanner scanner;
+
+  private Iterator<? extends ScanTask> taskIterator;
+
+  private ScanTask currentTask = null;
+  private ArrowReader currentReader = null;
+
+  /**
+   * Constructs a scanner reader from a Scanner.
+   *
+   * @param scanner the Scanner to read the dataset from
+   * @param allocator the allocator for new buffers
+   */
+  public ArrowScannerReader(Scanner scanner, BufferAllocator allocator) {
+    super(allocator);
+    this.scanner = scanner;
+    this.taskIterator = scanner.scan().iterator();
+    if (taskIterator.hasNext()) {
+      currentTask = taskIterator.next();
+      currentReader = currentTask.execute();
+    }
+  }
+
+  @Override
+  protected void loadRecordBatch(ArrowRecordBatch batch) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  protected void loadDictionary(ArrowDictionaryBatch dictionaryBatch) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public boolean loadNextBatch() throws IOException {
+    if (currentReader == null) {
+      return false;
+    }
+    boolean result = currentReader.loadNextBatch();
+
+    if (!result) {
+      try {
+        currentTask.close();
+        currentReader.close();
+      } catch (Exception e) {
+        throw new IOException(e);
+      }
+
+      // Advance through scan tasks until one yields a batch or all are exhausted.
+      while (!result) {
+        if (!taskIterator.hasNext()) {
+          return false;
+        } else {
+          currentTask = taskIterator.next();
+          currentReader = currentTask.execute();
+          result = currentReader.loadNextBatch();
+        }
+      }
+    }
+
+    VectorLoader loader = new VectorLoader(this.getVectorSchemaRoot());
+    VectorUnloader unloader =
+        new VectorUnloader(currentReader.getVectorSchemaRoot());
+    try (ArrowRecordBatch recordBatch = unloader.getRecordBatch()) {
+      loader.load(recordBatch);
+    }
+    return true;
+  }
+
+  @Override
+  public long bytesRead() {
+    return 0L;
+  }
+
+  @Override
+  protected void closeReadSource() throws IOException {
+    try {
+      currentTask.close();
+      currentReader.close();
+      scanner.close();
+    } catch (Exception e) {
+      throw new IOException(e);
+    }
+  }
+
+  @Override
+  protected Schema readSchema() throws IOException {
+    return scanner.schema();
+  }
+}
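Note: ArrowScannerReader behaves like any other ArrowReader: each successful loadNextBatch() repopulates the reader's VectorSchemaRoot with the next batch, transparently advancing across scan tasks. A minimal consumption loop, assuming a scanner and allocator obtained as in the earlier sketch:

    import java.io.IOException;
    import org.apache.arrow.dataset.scanner.ArrowScannerReader;
    import org.apache.arrow.dataset.scanner.Scanner;
    import org.apache.arrow.memory.BufferAllocator;
    import org.apache.arrow.vector.VectorSchemaRoot;

    final class ScanConsumer {
      static void consume(Scanner scanner, BufferAllocator allocator) throws IOException {
        try (ArrowScannerReader reader = new ArrowScannerReader(scanner, allocator)) {
          VectorSchemaRoot root = reader.getVectorSchemaRoot();
          while (reader.loadNextBatch()) {
            // root now holds the current batch; the row count is per batch.
            System.out.println("read a batch of " + root.getRowCount() + " rows");
          }
        }
      }
    }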
+ */ + +package org.apache.arrow.dataset.file; + +import java.io.File; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; + +import org.apache.arrow.dataset.ParquetWriteSupport; +import org.apache.arrow.dataset.TestDataset; +import org.apache.arrow.dataset.jni.NativeMemoryPool; +import org.apache.arrow.dataset.scanner.ArrowScannerReader; +import org.apache.arrow.dataset.scanner.ScanOptions; +import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.source.Dataset; +import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorLoader; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.compare.VectorEqualsVisitor; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.commons.io.FileUtils; +import org.junit.Assert; +import org.junit.ClassRule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +public class TestDatasetFileWriter extends TestDataset { + + @ClassRule + public static final TemporaryFolder TMP = new TemporaryFolder(); + + public static final String AVRO_SCHEMA_USER = "user.avsc"; + + @Test + public void testParquetWriteSimple() throws Exception { + ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), + 1, "a", 2, "b", 3, "c", 2, "d"); + String sampleParquet = writeSupport.getOutputURI(); + ScanOptions options = new ScanOptions(new String[0], 100); + final File writtenFolder = TMP.newFolder(); + final String writtenParquet = writtenFolder.toURI().toString(); + try (FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, sampleParquet); + final Dataset dataset = factory.finish(); + final Scanner scanner = dataset.newScan(options); + final ArrowScannerReader reader = new ArrowScannerReader(scanner, rootAllocator()); + ) { + DatasetFileWriter.write(rootAllocator(), reader, FileFormat.PARQUET, writtenParquet); + assertParquetFileEquals(sampleParquet, Objects.requireNonNull(writtenFolder.listFiles())[0].toURI().toString()); + } + } + + @Test + public void testParquetWriteWithPartitions() throws Exception { + ParquetWriteSupport writeSupport = ParquetWriteSupport.writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), + 1, "a", 2, "b", 3, "c", 2, "d"); + String sampleParquet = writeSupport.getOutputURI(); + ScanOptions options = new ScanOptions(new String[0], 100); + final File writtenFolder = TMP.newFolder(); + final String writtenParquet = writtenFolder.toURI().toString(); + + try (FileSystemDatasetFactory factory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(), + FileFormat.PARQUET, sampleParquet); + final Dataset dataset = factory.finish(); + final Scanner scanner = dataset.newScan(options); + final ArrowScannerReader reader = new ArrowScannerReader(scanner, rootAllocator()); + ) { + DatasetFileWriter.write(rootAllocator(), reader, + FileFormat.PARQUET, writtenParquet, new String[]{"id", "name"}, + 100, "data_{i}"); + final Set expectedOutputFiles = new HashSet<>( + Arrays.asList("id=1/name=a/data_0", "id=2/name=b/data_0", "id=3/name=c/data_0", "id=2/name=d/data_0")); + final Set outputFiles = FileUtils.listFiles(writtenFolder, null, true) + .stream() + .map(file -> { + return writtenFolder.toURI().relativize(file.toURI()).toString(); + }) + 
+          .collect(Collectors.toSet());
+      Assert.assertEquals(expectedOutputFiles, outputFiles);
+    }
+  }
+
+  private void assertParquetFileEquals(String expectedURI, String actualURI) throws Exception {
+    final FileSystemDatasetFactory expectedFactory = new FileSystemDatasetFactory(
+        rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, expectedURI);
+    final FileSystemDatasetFactory actualFactory = new FileSystemDatasetFactory(
+        rootAllocator(), NativeMemoryPool.getDefault(), FileFormat.PARQUET, actualURI);
+    List<ArrowRecordBatch> expectedBatches = collectResultFromFactory(expectedFactory,
+        new ScanOptions(new String[0], 100));
+    List<ArrowRecordBatch> actualBatches = collectResultFromFactory(actualFactory,
+        new ScanOptions(new String[0], 100));
+    try (
+        VectorSchemaRoot expectVsr = VectorSchemaRoot.create(expectedFactory.inspect(), rootAllocator());
+        VectorSchemaRoot actualVsr = VectorSchemaRoot.create(actualFactory.inspect(), rootAllocator())) {
+
+      // fast-fail by comparing metadata
+      Assert.assertEquals(expectedBatches.toString(), actualBatches.toString());
+      // compare ArrowRecordBatches
+      Assert.assertEquals(expectedBatches.size(), actualBatches.size());
+      VectorLoader expectLoader = new VectorLoader(expectVsr);
+      VectorLoader actualLoader = new VectorLoader(actualVsr);
+      for (int i = 0; i < expectedBatches.size(); i++) {
+        expectLoader.load(expectedBatches.get(i));
+        actualLoader.load(actualBatches.get(i));
+        for (int j = 0; j < expectVsr.getFieldVectors().size(); j++) {
+          // Index columns with the column counter j, not the batch counter i.
+          FieldVector vector = expectVsr.getFieldVectors().get(j);
+          FieldVector otherVector = actualVsr.getFieldVectors().get(j);
+          // TODO: ARROW-18140 Use VectorSchemaRoot#equals() method to compare
+          Assert.assertTrue(VectorEqualsVisitor.vectorEquals(vector, otherVector));
+        }
+      }
+    } finally {
+      AutoCloseables.close(expectedBatches, actualBatches);
+    }
+  }
+}
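Note: assertParquetFileEquals compares roots column by column because VectorSchemaRoot did not yet offer an equals() (see the ARROW-18140 TODO above). The pairwise check reduces to the pattern below, sketched here for two roots already loaded with corresponding batches (the helper class is illustrative):

    import org.apache.arrow.vector.VectorSchemaRoot;
    import org.apache.arrow.vector.compare.VectorEqualsVisitor;

    final class RootCompare {
      // Returns true when both roots have the same column count and every
      // column pair compares equal (type and cell values) under VectorEqualsVisitor.
      static boolean rootsEqual(VectorSchemaRoot expected, VectorSchemaRoot actual) {
        if (expected.getFieldVectors().size() != actual.getFieldVectors().size()) {
          return false;
        }
        for (int j = 0; j < expected.getFieldVectors().size(); j++) {
          if (!VectorEqualsVisitor.vectorEquals(
              expected.getFieldVectors().get(j), actual.getFieldVectors().get(j))) {
            return false;
          }
        }
        return true;
      }
    }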