Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions java/core/lance-jni/src/blocking_scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@ use crate::error::{Error, Result};
use crate::ffi::JNIEnvExt;
use arrow::array::Float32Array;
use arrow::{ffi::FFI_ArrowSchema, ffi_stream::FFI_ArrowArrayStream};
use arrow_schema::SchemaRef;
use arrow_schema::{DataType, Field, Schema, SchemaRef};
use jni::objects::{JObject, JString};
use jni::sys::{jboolean, jint, JNI_TRUE};
use jni::{sys::jlong, JNIEnv};
use lance::dataset::scanner::{DatasetRecordBatchStream, Scanner};
use lance_io::ffi::to_ffi_arrow_array_stream;
use lance::dataset::ROW_ID;
use lance_io::ffi::to_ffi_jni_arrow_array_stream;
use lance_linalg::distance::DistanceType;

use crate::{
Expand Down Expand Up @@ -53,7 +54,21 @@ impl BlockingScanner {

pub fn schema(&self) -> Result<SchemaRef> {
let res = RT.block_on(self.inner.schema())?;
Ok(res)
let mut new_fields = Vec::new();
for field in res.clone().fields() {
if field.name() == ROW_ID {
let new_field = match field.data_type() {
DataType::UInt64 => {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to understand why making this change
Arrow in Java also have DataType:UInt64 but Spark cannot convert UInt64 to SparkType?

Could we modify the Spark Type to Arrow Type conversion directly? Like having an extension??

Field::new(field.name().clone(), DataType::Int64, field.is_nullable())
}
_ => field.as_ref().clone(),
};
new_fields.push(new_field);
} else {
new_fields.push(field.as_ref().clone());
}
}
Ok(Arc::new(Schema::new(new_fields)))
}

pub fn count_rows(&self) -> Result<u64> {
Expand Down Expand Up @@ -269,7 +284,7 @@ fn inner_open_stream(env: &mut JNIEnv, j_scanner: JObject, stream_addr: jlong) -
unsafe { env.get_rust_field::<_, _, BlockingScanner>(j_scanner, NATIVE_SCANNER) }?;
scanner_guard.open_stream()?
};
let ffi_stream = to_ffi_arrow_array_stream(record_batch_stream, RT.handle().clone())?;
let ffi_stream = to_ffi_jni_arrow_array_stream(record_batch_stream, RT.handle().clone())?;
unsafe { std::ptr::write_unaligned(stream_addr as *mut FFI_ArrowArrayStream, ffi_stream) }
Ok(())
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.lancedb.lance.spark;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
/** Constants shared across the Lance Spark connector. */
public class LanceConstant {
  /** Name of the Lance-generated row-id metadata column. */
  public static final String ROW_ID = "_rowid";

  /**
   * Spark-side field definition for the row-id column. Lance produces it as
   * an unsigned 64-bit integer, but Spark has no unsigned type, so it is
   * surfaced as a nullable Long.
   */
  public static final StructField ROW_ID_SPARK_TYPE = new StructField(
      ROW_ID, DataTypes.LongType, true, Metadata.empty());

  private LanceConstant() {
    // Utility class: constants only, no instances.
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,13 @@ public class LanceDataSource implements SupportsCatalogOptions, DataSourceRegist

@Override
public StructType inferSchema(CaseInsensitiveStringMap options) {
  // Returning null tells Spark the schema could not be inferred
  // (e.g. the dataset does not exist at the given URI).
  return LanceDatasetAdapter.getSchema(LanceConfig.from(options)).orElse(null);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,40 @@

import com.lancedb.lance.spark.read.LanceScanBuilder;
import com.lancedb.lance.spark.write.SparkWrite;
import org.apache.spark.sql.connector.catalog.MetadataColumn;
import org.apache.spark.sql.connector.catalog.SupportsMetadataColumns;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.SupportsWrite;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.connector.write.LogicalWriteInfo;
import org.apache.spark.sql.connector.write.WriteBuilder;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/**
* Lance Spark Dataset.
*/
public class LanceDataset implements SupportsRead, SupportsWrite {
public class LanceDataset implements SupportsRead, SupportsWrite, SupportsMetadataColumns {
private static final Set<TableCapability> CAPABILITIES =
ImmutableSet.of(TableCapability.BATCH_READ, TableCapability.BATCH_WRITE);

// Metadata columns exposed to Spark via SupportsMetadataColumns.
// Currently only the Lance row id, surfaced as a Long because Spark has no
// unsigned 64-bit type (the native _rowid column is UInt64).
public static final MetadataColumn[] METADATA_COLUMNS = new MetadataColumn[]{
    new MetadataColumn() {
      @Override
      public String name() {
        return LanceConstant.ROW_ID;
      }

      @Override
      public DataType dataType() {
        return DataTypes.LongType;
      }
    }
};

LanceConfig options;
private final StructType sparkSchema;

Expand Down Expand Up @@ -72,4 +90,9 @@ public Set<TableCapability> capabilities() {
public WriteBuilder newWriteBuilder(LogicalWriteInfo logicalWriteInfo) {
return new SparkWrite.SparkWriteBuilder(sparkSchema, options);
}

/** Exposes the Lance metadata columns (currently only the row id) to Spark. */
@Override
public MetadataColumn[] metadataColumns() {
  return METADATA_COLUMNS;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public class SparkOptions {
private static final String max_row_per_file = "max_row_per_file";
private static final String max_rows_per_group = "max_rows_per_group";
private static final String max_bytes_per_file = "max_bytes_per_file";
private static final String batch_size = "batch_size";

public static ReadOptions genReadOptionFromConfig(LanceConfig config) {
ReadOptions.Builder builder = new ReadOptions.Builder();
Expand Down Expand Up @@ -86,4 +87,11 @@ private static Map<String, String> genStorageOptions(LanceConfig config) {
return storageOptions;
}

/**
 * Reads the scan/write batch size from the dataset options.
 *
 * @param config Lance configuration holding the user-supplied options
 * @return the configured batch size, or 512 when the option is absent
 * @throws NumberFormatException if the configured value is not a valid integer
 */
public static int getBatchSize(LanceConfig config) {
  // getOrDefault avoids the separate containsKey lookup.
  return Integer.parseInt(config.getOptions().getOrDefault(batch_size, "512"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,14 @@
import java.util.stream.Collectors;

public class LanceDatasetAdapter {
private static final BufferAllocator allocator = new RootAllocator(
RootAllocator.configBuilder().from(RootAllocator.defaultConfig())
.maxAllocation(64 * 1024 * 1024).build());
// NOTE(review): unbounded root allocator — this replaces a previous 64 MiB
// maxAllocation cap; confirm the removal is intentional. Being static, this
// allocator lives for the JVM lifetime and is never closed.
private static final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);

public static Optional<StructType> getSchema(LanceConfig config) {
String uri = config.getDatasetUri();
ReadOptions options = SparkOptions.genReadOptionFromConfig(config);
try (Dataset dataset = Dataset.open(allocator, uri, options)) {
return Optional.of(ArrowUtils.fromArrowSchema(dataset.getSchema()));
StructType actualSchema = ArrowUtils.fromArrowSchema(dataset.getSchema());
return Optional.of(actualSchema);
} catch (IllegalArgumentException e) {
// dataset not found
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.lancedb.lance.ipc.LanceScanner;
import com.lancedb.lance.ipc.ScanOptions;
import com.lancedb.lance.spark.LanceConfig;
import com.lancedb.lance.spark.LanceConstant;
import com.lancedb.lance.spark.read.LanceInputPartition;
import com.lancedb.lance.spark.SparkOptions;
import org.apache.arrow.memory.BufferAllocator;
Expand Down Expand Up @@ -55,9 +56,11 @@ public static LanceFragmentScanner create(int fragmentId,
fragment = dataset.getFragments().get(fragmentId);
ScanOptions.Builder scanOptions = new ScanOptions.Builder();
scanOptions.columns(getColumnNames(inputPartition.getSchema()));
scanOptions.withRowId(getWithRowId(inputPartition.getSchema()));
if (inputPartition.getWhereCondition().isPresent()) {
scanOptions.filter(inputPartition.getWhereCondition().get());
}
scanOptions.batchSize(SparkOptions.getBatchSize(config));
scanner = fragment.newScan(scanOptions.build());
} catch (Throwable t) {
if (scanner != null) {
Expand Down Expand Up @@ -103,6 +106,13 @@ public void close() throws IOException {
/**
 * Column names requested from the Lance fragment, excluding the row-id
 * metadata column, which is requested separately via the withRowId flag.
 */
private static List<String> getColumnNames(StructType schema) {
  return Arrays.stream(schema.fields())
      .filter(field -> !field.name().equals(LanceConstant.ROW_ID))
      .map(StructField::name)
      .collect(Collectors.toList());
}

/** Whether the requested Spark schema includes the Lance row-id column. */
private static boolean getWithRowId(StructType schema) {
  return Arrays.stream(schema.fields())
      .anyMatch(field -> field.name().equals(LanceConstant.ROW_ID));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ protected WriterFactory(StructType schema, LanceConfig config) {

@Override
public DataWriter<InternalRow> createWriter(int partitionId, long taskId) {
LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, 1024);
int batch_size = SparkOptions.getBatchSize(config);
LanceArrowWriter arrowWriter = LanceDatasetAdapter.getArrowWriter(schema, batch_size);
WriteParams params = SparkOptions.genWriteParamsFromConfig(config);
Callable<FragmentMetadata> fragmentCreator
= () -> LanceDatasetAdapter.createFragment(config.getDatasetUri(), arrowWriter, params);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ static void tearDown() {
}
}

@Test
public void testMetadata() {
  String path = LanceConfig.getDatasetUri(dbPath, TestUtils.TestTable1Config.datasetName);
  // The _rowid metadata column must be selectable and aliasable via SQL.
  Dataset<Row> df = spark.sql("SELECT _rowid AS row_id, * FROM lance.`" + path + "`");
  // Materialize the result to prove the metadata column can actually be read,
  // and verify the alias is reflected in the output schema.
  List<Row> rows = df.collectAsList();
  assertEquals("row_id", df.schema().fields()[0].name());
  assertEquals((long) rows.size(), df.count());
}

private void validateData(Dataset<Row> data, List<List<Long>> expectedValues) {
List<Row> rows = data.collectAsList();
assertEquals(expectedValues.size(), rows.size());
Expand Down
90 changes: 87 additions & 3 deletions rust/lance-io/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
// SPDX-FileCopyrightText: Copyright The Lance Authors

use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow_array::RecordBatch;
use arrow_schema::{ArrowError, SchemaRef};
use arrow_array::{Array, Int64Array, RecordBatch, RecordBatchReader, UInt64Array};
use arrow_schema::{ArrowError, DataType, Field, Schema, SchemaRef};
use futures::StreamExt;
use lance_core::Result;
use lance_core::{Result, ROW_ID};
use std::sync::Arc;

use crate::stream::RecordBatchStream;

Expand Down Expand Up @@ -58,3 +59,86 @@ pub fn to_ffi_arrow_array_stream(

Ok(reader)
}

/// Wrap a [`RecordBatchStream`] into an [`FFI_ArrowArrayStream`] for JNI calls.
///
/// Unlike [`to_ffi_arrow_array_stream`], the returned stream rewrites the
/// `_rowid` column from `UInt64` to `Int64`, because Java has no unsigned
/// 64-bit integer type (Spark exposes the column as `Long`).
pub fn to_ffi_jni_arrow_array_stream(
    stream: impl RecordBatchStream + std::marker::Unpin + 'static,
    handle: tokio::runtime::Handle,
) -> Result<FFI_ArrowArrayStream> {
    let schema = stream.schema();
    // The adaptor performs the UInt64 -> Int64 conversion lazily, per batch.
    let arrow_stream = JniRecordBatchIteratorAdaptor::new(stream, schema, handle);
    let reader = FFI_ArrowArrayStream::new(Box::new(arrow_stream));

    Ok(reader)
}

/// Adapts an async [`RecordBatchStream`] to a blocking iterator suitable for
/// the JNI FFI boundary, converting the `_rowid` column from `UInt64` to
/// `Int64` on the way out (Java has no unsigned 64-bit integer type).
#[pin_project::pin_project]
struct JniRecordBatchIteratorAdaptor<S: RecordBatchStream> {
    // Original (unconverted) schema of the wrapped stream.
    schema: SchemaRef,
    #[pin]
    stream: S,
    // Tokio runtime handle used to block on the async stream from `next()`.
    handle: tokio::runtime::Handle,
}
impl<S: RecordBatchStream> JniRecordBatchIteratorAdaptor<S> {
    fn new(stream: S, schema: SchemaRef, handle: tokio::runtime::Handle) -> Self {
        Self {
            schema,
            stream,
            handle,
        }
    }
}
impl<S: RecordBatchStream + Unpin> arrow::record_batch::RecordBatchReader
for JniRecordBatchIteratorAdaptor<S>
{
fn schema(&self) -> SchemaRef {
let mut new_fields = Vec::new();
for field in self.schema.clone().fields() {
if field.name() == ROW_ID {
let new_field = match field.data_type() {
DataType::UInt64 => {
Field::new(field.name().clone(), DataType::Int64, field.is_nullable())
}
// Add more conversions as needed
_ => field.as_ref().clone(), // Keep the original if no conversion is needed
};
new_fields.push(new_field);
} else {
new_fields.push(field.as_ref().clone());
}
}
Arc::new(Schema::new(new_fields))
}
}
impl<S: RecordBatchStream + Unpin> Iterator for JniRecordBatchIteratorAdaptor<S> {
    type Item = std::result::Result<RecordBatch, ArrowError>;

    /// Blocks on the async stream for the next batch and, when a `_rowid`
    /// column of type `UInt64` is present, converts it to `Int64` so the
    /// Java side can consume it as a `Long`.
    fn next(&mut self) -> Option<Self::Item> {
        let batch = match self.handle.block_on(self.stream.next())? {
            Ok(batch) => batch,
            Err(err) => return Some(Err(ArrowError::ExternalError(Box::new(err)))),
        };
        let index = match batch.schema().index_of(ROW_ID) {
            Ok(index) => index,
            // No row-id column requested: pass the batch through untouched.
            Err(_) => return Some(Ok(batch)),
        };
        // Only convert when the column really is UInt64. The previous code
        // unwrapped this downcast and would panic on any other column type.
        let uint64_array = match batch.column(index).as_any().downcast_ref::<UInt64Array>() {
            Some(array) => array,
            None => return Some(Ok(batch)),
        };
        // NOTE(review): null slots are not treated specially here (matches the
        // prior behavior); `value(i)` on a null slot reads the raw buffer value.
        let mut int_values: Vec<i64> = Vec::with_capacity(uint64_array.len());
        for i in 0..uint64_array.len() {
            match i64::try_from(uint64_array.value(i)) {
                Ok(value) => int_values.push(value),
                // Row ids >= 2^63 cannot be represented as a Java long.
                Err(err) => return Some(Err(ArrowError::ExternalError(Box::new(err)))),
            }
        }
        let mut new_columns = batch.columns().to_vec();
        new_columns[index] = Arc::new(Int64Array::from(int_values));
        // Use the converted schema so the batch matches what `schema()` reports.
        Some(RecordBatch::try_new(self.schema(), new_columns))
    }
}