diff --git a/java/core/lance-jni/src/blocking_scanner.rs b/java/core/lance-jni/src/blocking_scanner.rs
index 8a3168e161c..21c39771372 100644
--- a/java/core/lance-jni/src/blocking_scanner.rs
+++ b/java/core/lance-jni/src/blocking_scanner.rs
@@ -18,12 +18,13 @@ use crate::error::{Error, Result};
 use crate::ffi::JNIEnvExt;
 use arrow::array::Float32Array;
 use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream};
-use arrow_schema::SchemaRef;
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
 use jni::objects::{JObject, JString};
 use jni::sys::{jboolean, jint, JNI_TRUE};
 use jni::{sys::jlong, JNIEnv};
 use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
-use lance_io::ffi::to_ffi_arrow_array_stream;
+use lance::dataset::ROW_ID;
+use lance_io::ffi::to_ffi_jni_arrow_array_stream;
 use lance_linalg::distance::DistanceType;
 
 use crate::{
@@ -53,7 +54,22 @@ impl BlockingScanner {
     pub fn schema(&self) -> Result<SchemaRef> {
         let res = RT.block_on(self.inner.schema())?;
-        Ok(res)
+        // Java has no unsigned 64-bit integer, so report the _rowid column as Int64.
+        let mut new_fields = Vec::new();
+        for field in res.fields() {
+            if field.name() == ROW_ID {
+                let new_field = match field.data_type() {
+                    DataType::UInt64 => {
+                        Field::new(field.name().clone(), DataType::Int64, field.is_nullable())
+                    }
+                    _ => field.as_ref().clone(),
+                };
+                new_fields.push(new_field);
+            } else {
+                new_fields.push(field.as_ref().clone());
+            }
+        }
+        Ok(Arc::new(Schema::new(new_fields)))
     }
 
     pub fn count_rows(&self) -> Result<u64> {
         let res = RT.block_on(self.inner.count_rows())?;
@@ -269,7 +284,7 @@ fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -> Result<()> {
         unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?;
         scanner_guard.open_stream()?
     };
-    let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone())?;
+    let ffi_stream = to_ffi_jni_arrow_array_stream(record_batch_stream, RT.handle().clone())?;
     unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) }
     Ok(())
 }
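Note on the schema rewrite above: Java has no unsigned 64-bit integer, so the JNI layer reports the UInt64 `_rowid` field as signed Int64. For illustration only (not part of the patch), the same field mapping expressed with Arrow's Java POJO types; `RowIdFieldMapping` and `toJniVisible` are hypothetical names:

import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;

class RowIdFieldMapping {
    // Mirrors the Rust-side rewrite: a UInt64 "_rowid" field becomes Int64,
    // keeping the same name and nullability; all other fields pass through.
    static Field toJniVisible(Field field) {
        if ("_rowid".equals(field.getName()) && field.getType() instanceof ArrowType.Int) {
            ArrowType.Int intType = (ArrowType.Int) field.getType();
            if (intType.getBitWidth() == 64 && !intType.getIsSigned()) {
                return new Field(field.getName(),
                        new FieldType(field.isNullable(), new ArrowType.Int(64, true), null),
                        field.getChildren());
            }
        }
        return field;
    }
}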
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java
new file mode 100644
index 00000000000..76054fef1e3
--- /dev/null
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceConstant.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.lancedb.lance.spark;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
+public class LanceConstant {
+  public static final String ROW_ID = "_rowid";
+  public static final StructField ROW_ID_SPARK_TYPE = new StructField(
+      ROW_ID, DataTypes.LongType, true, Metadata.empty());
+}
\ No newline at end of file
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java
index 0bc5fcbbdd1..9e0d88433ae 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataSource.java
@@ -31,8 +31,9 @@ public class LanceDataSource implements SupportsCatalogOptions, DataSourceRegister {
 
   @Override
   public StructType inferSchema(CaseInsensitiveStringMap options) {
-    Optional<StructType> schema = LanceDatasetAdapter.getSchema(LanceConfig.from(options));
-    return schema.isPresent() ? schema.get() : null;
+    LanceConfig config = LanceConfig.from(options);
+    Optional<StructType> schema = LanceDatasetAdapter.getSchema(config);
+    return schema.orElse(null);
   }
 
   @Override
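LanceConstant centralizes the row-id column name and its Spark-side type. A small sketch (with a hypothetical `id` data column) of how the constant can be used to build an expected scan schema, for example in a test:

import com.lancedb.lance.spark.LanceConstant;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;

class ExpectedSchemaExample {
    // "id" is a hypothetical data column; _rowid comes from LanceConstant
    // and is appended as a nullable LongType field.
    static StructType expectedScanSchema() {
        return new StructType()
                .add("id", DataTypes.IntegerType, false)
                .add(LanceConstant.ROW_ID_SPARK_TYPE);
    }
}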
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java
index 702b3bdf42a..1de2a2b967f 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/LanceDataset.java
@@ -18,22 +18,40 @@
 import com.lancedb.lance.spark.read.LanceScanBuilder;
 import com.lancedb.lance.spark.write.SparkWrite;
 
+import org.apache.spark.sql.connector.catalog.MetadataColumn;
+import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns;
 import org.apache.spark.sql.connector.catalog.SupportsRead;
 import org.apache.spark.sql.connector.catalog.SupportsWrite;
 import org.apache.spark.sql.connector.catalog.TableCapability;
 import org.apache.spark.sql.connector.read.ScanBuilder;
 import org.apache.spark.sql.connector.write.LogicalWriteInfo;
 import org.apache.spark.sql.connector.write.WriteBuilder;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /**
  * Lance Spark Dataset.
  */
-public class LanceDataset implements SupportsRead, SupportsWrite {
+public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns {
   private static final Set<TableCapability> CAPABILITIES =
       ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE);
 
+  public static final MetadataColumn[] METADATA_COLUMNS = new MetadataColumn[]{
+      new MetadataColumn() {
+        @Override
+        public String name() {
+          return LanceConstant.ROW_ID;
+        }
+
+        @Override
+        public DataType dataType() {
+          return DataTypes.LongType;
+        }
+      }
+  };
+
   LanceConfig options;
   private final StructType sparkSchema;
 
@@ -72,4 +90,9 @@ public Set<TableCapability> capabilities() {
   public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) {
     return new SparkWrite.SparkWriteBuilder(sparkSchema, options);
   }
+
+  @Override
+  public MetadataColumn[] metadataColumns() {
+    return METADATA_COLUMNS;
+  }
 }
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java b/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java
index 6ccee2c79ef..7ca2ec8ee4f 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/SparkOptions.java
@@ -34,6 +34,7 @@ public class SparkOptions {
   private static final String max_row_per_file = "max_row_per_file";
   private static final String max_rows_per_group = "max_rows_per_group";
   private static final String max_bytes_per_file = "max_bytes_per_file";
+  private static final String batch_size = "batch_size";
 
   public static ReadOptions genReadOptionFromConfig(LanceConfig config) {
     ReadOptions.Builder builder = new ReadOptions.Builder();
@@ -86,4 +87,11 @@ private static Map<String, String> genStorageOptions(LanceConfig config) {
     return storageOptions;
   }
 
+  public static int getBatchSize(LanceConfig config) {
+    Map<String, String> options = config.getOptions();
+    if (options.containsKey(batch_size)) {
+      return Integer.parseInt(options.get(batch_size));
+    }
+    return 512;
+  }
 }
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
index ff87744b6c1..c029abaa122 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceDatasetAdapter.java
@@ -33,15 +33,14 @@
 import java.util.stream.Collectors;
 
 public class LanceDatasetAdapter {
-  private static final BufferAllocator allocator = new RootAllocator(
-      RootAllocator.configBuilder().from(RootAllocator.defaultConfig())
-          .maxAllocation(64 * 1024 * 1024).build());
+  private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
 
   public static Optional<StructType> getSchema(LanceConfig config) {
     String uri = config.getDatasetUri();
     ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
     try (Dataset dataset = Dataset.open(allocator, uri, options)) {
-      return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema()));
+      StructType actualSchema = ArrowUtils.fromArrowSchema(dataset.getSchema());
+      return Optional.of(actualSchema);
     } catch (IllegalArgumentException e) {
       // dataset not found
       return Optional.empty();
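Because `_rowid` is declared through SupportsMetadataColumns above, Spark treats it as a hidden metadata column: it is omitted from a plain `SELECT *` and materializes only when referenced by name. A usage sketch mirroring the read test; it assumes a session with the Lance catalog registered as `lance`, and the dataset path is hypothetical:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class RowIdQueryExample {
    static Dataset<Row> selectWithRowId(SparkSession spark) {
        String path = "/tmp/lance_db/test_table.lance"; // hypothetical URI
        // _rowid must be named explicitly; SELECT * alone omits metadata columns.
        return spark.sql("SELECT _rowid AS row_id, * FROM lance.`" + path + "`");
    }
}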
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java
index e71cf33b7e3..88e0a78682e 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/internal/LanceFragmentScanner.java
@@ -20,6 +20,7 @@
 import com.lancedb.lance.ipc.LanceScanner;
 import com.lancedb.lance.ipc.ScanOptions;
 import com.lancedb.lance.spark.LanceConfig;
+import com.lancedb.lance.spark.LanceConstant;
 import com.lancedb.lance.spark.read.LanceInputPartition;
 import com.lancedb.lance.spark.SparkOptions;
 import org.apache.arrow.memory.BufferAllocator;
@@ -55,9 +56,11 @@ public static LanceFragmentScanner create(int fragmentId,
       fragment = dataset.getFragments().get(fragmentId);
       ScanOptions.Builder scanOptions = new ScanOptions.Builder();
       scanOptions.columns(getColumnNames(inputPartition.getSchema()));
+      scanOptions.withRowId(getWithRowId(inputPartition.getSchema()));
       if (inputPartition.getWhereCondition().isPresent()) {
         scanOptions.filter(inputPartition.getWhereCondition().get());
       }
+      scanOptions.batchSize(SparkOptions.getBatchSize(config));
       scanner = fragment.newScan(scanOptions.build());
     } catch (Throwable t) {
       if (scanner != null) {
@@ -103,6 +106,13 @@ public void close() throws IOException {
   private static List<String> getColumnNames(StructType schema) {
     return Arrays.stream(schema.fields())
         .map(StructField::name)
+        .filter(name -> !name.equals(LanceConstant.ROW_ID))
         .collect(Collectors.toList());
   }
+
+  private static boolean getWithRowId(StructType schema) {
+    return Arrays.stream(schema.fields())
+        .map(StructField::name)
+        .anyMatch(name -> name.equals(LanceConstant.ROW_ID));
+  }
 }
\ No newline at end of file
diff --git a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
index 706b6144d19..0f9aaa0f0d4 100644
--- a/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
+++ b/java/spark/src/main/java/com/lancedb/lance/spark/write/LanceDataWriter.java
@@ -91,7 +91,8 @@ protected WriterFactory(StructType schema, LanceConfig config) {
 
     @Override
    public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
-      LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024);
+      int batchSize = SparkOptions.getBatchSize(config);
+      LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, batchSize);
       WriteParams params = SparkOptions.genWriteParamsFromConfig(config);
       Callable<List<FragmentMetadata>> fragmentCreator =
           () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params);
diff --git a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java
index fe5a82a6427..742cb908752 100644
--- a/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java
+++ b/java/spark/src/test/java/com/lancedb/lance/spark/read/SparkConnectorReadTest.java
@@ -55,6 +55,14 @@ static void tearDown() {
     }
   }
 
+  @Test
+  public void testMetadata() {
+    String path = LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName);
+    Dataset<Row> df = spark.sql("SELECT _rowid AS row_id, * FROM lance.`" + path + "`");
+    List<Row> rows = df.collectAsList();
+    System.out.println(rows);
+  }
+
   private void validateData(Dataset<Row> data, List<List<Object>> expectedValues) {
     List<Row> rows = data.collectAsList();
     assertEquals(expectedValues.size(), rows.size());
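With the changes above, one `batch_size` option now drives both the fragment scanner and the Arrow writer (defaulting to 512 when unset). A hypothetical reader invocation; exactly how the dataset URI is passed depends on LanceConfig's option handling and may differ from the `.load(uri)` shown here:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

class BatchSizeOptionExample {
    // "batch_size" lands in LanceConfig's options map, where
    // SparkOptions.getBatchSize reads it; unset, the default is 512.
    static Dataset<Row> readWithBatchSize(SparkSession spark, String datasetUri) {
        return spark.read().format("lance")
                .option("batch_size", "1024")
                .load(datasetUri);
    }
}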
diff --git a/rust/lance-io/src/ffi.rs b/rust/lance-io/src/ffi.rs
index 950977e3f11..0b8e3aa0e58 100644
--- a/rust/lance-io/src/ffi.rs
+++ b/rust/lance-io/src/ffi.rs
@@ -2,10 +2,11 @@
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
 use arrow::ffi_stream::FFI_ArrowArrayStream;
-use arrow_array::RecordBatch;
-use arrow_schema::{ArrowError, SchemaRef};
+use arrow_array::{Array, Int64Array, RecordBatch, RecordBatchReader, UInt64Array};
+use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
 use futures::StreamExt;
-use lance_core::Result;
+use lance_core::{Result, ROW_ID};
+use std::sync::Arc;
 
 use crate::stream::RecordBatchStream;
 
@@ -58,3 +59,94 @@ pub fn to_ffi_arrow_array_stream(
 
     Ok(reader)
 }
+
+/// Wrap a [`RecordBatchStream`] into an [`FFI_ArrowArrayStream`] for JNI calls,
+/// converting the `_rowid` column from UInt64 to Int64, since Java's only
+/// 64-bit integer type (`long`) is signed.
+pub fn to_ffi_jni_arrow_array_stream(
+    stream: impl RecordBatchStream + std::marker::Unpin + 'static,
+    handle: tokio::runtime::Handle,
+) -> Result<FFI_ArrowArrayStream> {
+    let schema = stream.schema();
+    let arrow_stream = JniRecordBatchIteratorAdaptor::new(stream, schema, handle);
+    let reader = FFI_ArrowArrayStream::new(Box::new(arrow_stream));
+
+    Ok(reader)
+}
+
+#[pin_project::pin_project]
+struct JniRecordBatchIteratorAdaptor<S: RecordBatchStream> {
+    schema: SchemaRef,
+
+    #[pin]
+    stream: S,
+
+    handle: tokio::runtime::Handle,
+}
+
+impl<S: RecordBatchStream> JniRecordBatchIteratorAdaptor<S> {
+    fn new(stream: S, schema: SchemaRef, handle: tokio::runtime::Handle) -> Self {
+        Self {
+            schema,
+            stream,
+            handle,
+        }
+    }
+}
+
+impl<S: RecordBatchStream + Unpin> arrow::record_batch::RecordBatchReader
+    for JniRecordBatchIteratorAdaptor<S>
+{
+    fn schema(&self) -> SchemaRef {
+        let mut new_fields = Vec::new();
+        for field in self.schema.fields() {
+            if field.name() == ROW_ID {
+                let new_field = match field.data_type() {
+                    DataType::UInt64 => {
+                        Field::new(field.name().clone(), DataType::Int64, field.is_nullable())
+                    }
+                    // Add more conversions as needed
+                    _ => field.as_ref().clone(), // Keep the original if no conversion is needed
+                };
+                new_fields.push(new_field);
+            } else {
+                new_fields.push(field.as_ref().clone());
+            }
+        }
+        Arc::new(Schema::new(new_fields))
+    }
+}
+
+impl<S: RecordBatchStream + Unpin> Iterator for JniRecordBatchIteratorAdaptor<S> {
+    type Item = std::result::Result<RecordBatch, ArrowError>;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.handle
+            .block_on(async { self.stream.next().await })
+            .map(|r| match r {
+                Ok(batch) => match batch.schema().index_of(ROW_ID) {
+                    Ok(index) => {
+                        let mut new_columns = batch.columns().to_vec();
+                        let uint64_array = batch
+                            .column(index)
+                            .as_any()
+                            .downcast_ref::<UInt64Array>()
+                            .unwrap();
+                        let mut int_values: Vec<i64> = Vec::with_capacity(uint64_array.len());
+                        for i in 0..uint64_array.len() {
+                            match uint64_array.value(i).try_into() {
+                                Ok(value) => int_values.push(value),
+                                Err(err) => return Err(ArrowError::ExternalError(Box::new(err))),
+                            };
+                        }
+                        let int_array = Int64Array::from(int_values);
+                        new_columns[index] = Arc::new(int_array);
+                        RecordBatch::try_new(self.schema(), new_columns)
+                    }
+                    // No _rowid column in this batch: pass it through unchanged.
+                    Err(_) => Ok(batch),
+                },
+                Err(err) => Err(ArrowError::ExternalError(Box::new(err))),
+            })
+    }
+}
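For context, the Java half of this handshake: the FFI_ArrowArrayStream that `inner_open_stream` writes into `stream_addr` is imported through Arrow's C Data Interface, and the resulting reader's schema reports `_rowid` as Int64. A minimal sketch; `nativeOpenStream` is a hypothetical stand-in for the actual JNI binding in lance-java:

import org.apache.arrow.c.ArrowArrayStream;
import org.apache.arrow.c.Data;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.ipc.ArrowReader;

class JniStreamConsumerSketch {
    // Stand-in for the JNI entry point that reaches inner_open_stream above.
    static native void nativeOpenStream(long scannerHandle, long streamAddress);

    static void consume(long scannerHandle) throws Exception {
        try (BufferAllocator allocator = new RootAllocator();
             ArrowArrayStream stream = ArrowArrayStream.allocateNew(allocator)) {
            // Rust writes an FFI_ArrowArrayStream into this address ...
            nativeOpenStream(scannerHandle, stream.memoryAddress());
            // ... and Java imports it as an ArrowReader.
            try (ArrowReader reader = Data.importArrayStream(allocator, stream)) {
                while (reader.loadNextBatch()) {
                    // process reader.getVectorSchemaRoot()
                }
            }
        }
    }
}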