diff --git a/java/core/lance-jni/src/blocking_dataset.rs b/java/core/lance-jni/src/blocking_dataset.rs index 7bfe0e0a602..b91477497d1 100644 --- a/java/core/lance-jni/src/blocking_dataset.rs +++ b/java/core/lance-jni/src/blocking_dataset.rs @@ -14,8 +14,11 @@ use crate::error::{Error, Result}; use crate::ffi::JNIEnvExt; +use crate::schema::convert_to_java_field; use crate::traits::{export_vec, import_vec, FromJObjectWithEnv, FromJString}; -use crate::utils::{extract_storage_options, extract_write_params, get_index_params, to_rust_map}; +use crate::utils::{ + extract_storage_options, extract_write_params, get_index_params, to_java_map, to_rust_map, +}; use crate::{traits::IntoJava, RT}; use arrow::array::RecordBatchReader; use arrow::datatypes::Schema; @@ -710,6 +713,41 @@ fn inner_get_fragment<'local>( Ok(obj) } +#[no_mangle] +pub extern "system" fn Java_com_lancedb_lance_Dataset_nativeGetLanceSchema<'local>( + mut env: JNIEnv<'local>, + java_dataset: JObject, +) -> JObject<'local> { + ok_or_throw!(env, inner_get_lance_schema(&mut env, java_dataset)) +} + +fn inner_get_lance_schema<'local>( + env: &mut JNIEnv<'local>, + java_dataset: JObject, +) -> Result> { + let schema = { + let dataset = + unsafe { env.get_rust_field::<_, _, BlockingDataset>(java_dataset, NATIVE_DATASET) }?; + dataset.inner.schema().clone() + }; + let jfield_list = env.new_object("java/util/ArrayList", "()V", &[])?; + for lance_field in schema.fields.iter() { + let java_field = convert_to_java_field(env, lance_field)?; + env.call_method( + &jfield_list, + "add", + "(Ljava/lang/Object;)Z", + &[JValue::Object(&java_field)], + )?; + } + let metadata = to_java_map(env, &schema.metadata)?; + Ok(env.new_object( + "com/lancedb/lance/schema/LanceSchema", + "(Ljava/util/List;Ljava/util/Map;)V", + &[JValue::Object(&jfield_list), JValue::Object(&metadata)], + )?) +} + #[no_mangle] pub extern "system" fn Java_com_lancedb_lance_Dataset_importFfiSchema( mut env: JNIEnv, diff --git a/java/core/lance-jni/src/lib.rs b/java/core/lance-jni/src/lib.rs index 4212137cde8..053bd8e9cc8 100644 --- a/java/core/lance-jni/src/lib.rs +++ b/java/core/lance-jni/src/lib.rs @@ -57,6 +57,7 @@ pub mod ffi; mod file_reader; mod file_writer; mod fragment; +mod schema; pub mod traits; pub mod utils; diff --git a/java/core/lance-jni/src/schema.rs b/java/core/lance-jni/src/schema.rs new file mode 100644 index 00000000000..0d8741a4824 --- /dev/null +++ b/java/core/lance-jni/src/schema.rs @@ -0,0 +1,419 @@ +use crate::error::{Error, Result}; +use crate::utils::to_java_map; +use arrow::datatypes::DataType; +use arrow_schema::{TimeUnit, UnionFields}; +use jni::objects::{JObject, JValue}; +use jni::sys::{jboolean, jint}; +use jni::JNIEnv; +use lance_core::datatypes::{Field, StorageClass}; + +pub fn convert_to_java_field<'local>( + env: &mut JNIEnv<'local>, + lance_field: &Field, +) -> Result> { + let name = env.new_string(&lance_field.name)?; + let children = convert_children_fields(env, lance_field)?; + let metadata = to_java_map(env, &lance_field.metadata)?; + let arrow_type = convert_arrow_type(env, &lance_field.data_type())?; + let storage_type = convert_storage_type(env, &lance_field.storage_class)?; + + let ctor_sig = "(IILjava/lang/String;".to_owned() + + "ZLorg/apache/arrow/vector/types/pojo/ArrowType;" + + "Lcom/lancedb/lance/schema/StorageType;" + + "Lorg/apache/arrow/vector/types/pojo/DictionaryEncoding;" + + "Ljava/util/Map;" + + "Ljava/util/List;Z)V"; + let field_obj = env.new_object( + "com/lancedb/lance/schema/LanceField", + ctor_sig.as_str(), + &[ + JValue::Int(lance_field.id as jint), + JValue::Int(lance_field.parent_id as jint), + JValue::Object(&JObject::from(name)), + JValue::Bool(lance_field.nullable as jboolean), + JValue::Object(&arrow_type), + JValue::Object(&storage_type), + JValue::Object(&JObject::null()), + JValue::Object(&metadata), + JValue::Object(&children), + JValue::Bool(lance_field.unenforced_primary_key as jboolean), + ], + )?; + + Ok(field_obj) +} + +fn convert_storage_type<'local>( + env: &mut JNIEnv<'local>, + storage_class: &StorageClass, +) -> Result> { + let jname = match storage_class { + StorageClass::Blob => env.new_string("BLOB")?, + _ => env.new_string("DEFAULT")?, + }; + + Ok(env + .call_static_method( + "com/lancedb/lance/schema/StorageType", + "valueOf", + "(Ljava/lang/String;)Lcom/lancedb/lance/schema/StorageType;", + &[JValue::Object(&JObject::from(jname))], + )? + .l()?) +} + +fn convert_children_fields<'local>( + env: &mut JNIEnv<'local>, + lance_field: &Field, +) -> Result> { + let children_list = env.new_object("java/util/ArrayList", "()V", &[])?; + for lance_field in lance_field.children.iter() { + let field = convert_to_java_field(env, lance_field)?; + env.call_method( + &children_list, + "add", + "(Ljava/lang/Object;)Z", + &[JValue::Object(&field)], + )?; + } + Ok(children_list) +} + +pub fn convert_arrow_type<'local>( + env: &mut JNIEnv<'local>, + arrow_type: &DataType, +) -> Result> { + match arrow_type { + DataType::Null => convert_null_type(env), + DataType::Boolean => convert_boolean_type(env), + DataType::Int8 => convert_int_type(env, 8, true), + DataType::Int16 => convert_int_type(env, 16, true), + DataType::Int32 => convert_int_type(env, 32, true), + DataType::Int64 => convert_int_type(env, 64, true), + DataType::UInt8 => convert_int_type(env, 8, false), + DataType::UInt16 => convert_int_type(env, 16, false), + DataType::UInt32 => convert_int_type(env, 32, false), + DataType::UInt64 => convert_int_type(env, 64, false), + DataType::Float16 => convert_floating_point_type(env, "HALF"), + DataType::Float32 => convert_floating_point_type(env, "SINGLE"), + DataType::Float64 => convert_floating_point_type(env, "DOUBLE"), + DataType::Utf8 => convert_utf8_type(env, false), + DataType::LargeUtf8 => convert_utf8_type(env, true), + DataType::Binary => convert_binary_type(env, false), + DataType::LargeBinary => convert_binary_type(env, true), + DataType::FixedSizeBinary(len) => convert_fixed_size_binary_type(env, *len), + DataType::Date32 => convert_date_type(env, "DAY"), + DataType::Date64 => convert_date_type(env, "MILLISECOND"), + DataType::Time32(unit) => convert_time_type(env, *unit, 32), + DataType::Time64(unit) => convert_time_type(env, *unit, 64), + DataType::Timestamp(unit, tz) => convert_timestamp_type(env, *unit, tz.as_deref()), + DataType::Duration(unit) => convert_duration_type(env, *unit), + DataType::Decimal128(precision, scale) => { + convert_decimal_type(env, *precision, *scale, 128) + } + DataType::Decimal256(precision, scale) => { + convert_decimal_type(env, *precision, *scale, 256) + } + DataType::List(..) => convert_list_type(env, false), + DataType::LargeList(..) => convert_list_type(env, true), + DataType::FixedSizeList(.., len) => convert_fixed_size_list_type(env, *len), + DataType::Struct(..) => convert_struct_type(env), + DataType::Union(fields, mode) => convert_union_type(env, fields, *mode), + DataType::Map(.., keys_sorted) => convert_map_type(env, *keys_sorted), + _ => Err(Error::input_error( + "ArrowSchema conversion error".to_string(), + )), + } +} + +fn convert_null_type<'local>(env: &mut JNIEnv<'local>) -> Result> { + Ok(env + .get_static_field( + "org/apache/arrow/vector/types/pojo/ArrowType$Null", + "INSTANCE", + "Lorg/apache/arrow/vector/types/pojo/ArrowType$Null;", + )? + .l()?) +} + +fn convert_boolean_type<'local>(env: &mut JNIEnv<'local>) -> Result> { + Ok(env + .get_static_field( + "org/apache/arrow/vector/types/pojo/ArrowType$Bool", + "INSTANCE", + "Lorg/apache/arrow/vector/types/pojo/ArrowType$Bool;", + )? + .l()?) +} + +fn convert_int_type<'local>( + env: &mut JNIEnv<'local>, + bit_width: i32, + is_signed: bool, +) -> Result> { + Ok(env.new_object( + "org/apache/arrow/vector/types/pojo/ArrowType$Int", + "(IZ)V", + &[ + JValue::Int(bit_width as jint), + JValue::Bool(is_signed as jboolean), + ], + )?) +} + +fn convert_floating_point_type<'local>( + env: &mut JNIEnv<'local>, + precision: &str, +) -> Result> { + let precision_enum = env + .get_static_field( + "org/apache/arrow/vector/types/FloatingPointPrecision", + precision, + "Lorg/apache/arrow/vector/types/FloatingPointPrecision;", + )? + .l()?; + + Ok(env.new_object( + "org/apache/arrow/vector/types/pojo/ArrowType$FloatingPoint", + "(Lorg/apache/arrow/vector/types/FloatingPointPrecision;)V", + &[JValue::Object(&precision_enum)], + )?) +} + +fn convert_utf8_type<'local>(env: &mut JNIEnv<'local>, is_large: bool) -> Result> { + let class_name = if is_large { + "org/apache/arrow/vector/types/pojo/ArrowType$LargeUtf8" + } else { + "org/apache/arrow/vector/types/pojo/ArrowType$Utf8" + }; + + convert_arrow_type_by_class_name(env, class_name) +} + +fn convert_binary_type<'local>( + env: &mut JNIEnv<'local>, + is_large: bool, +) -> Result> { + let class_name = if is_large { + "org/apache/arrow/vector/types/pojo/ArrowType$LargeBinary" + } else { + "org/apache/arrow/vector/types/pojo/ArrowType$Binary" + }; + + convert_arrow_type_by_class_name(env, class_name) +} + +fn convert_arrow_type_by_class_name<'local>( + env: &mut JNIEnv<'local>, + class_name: &str, +) -> Result> { + let class = env.find_class(class_name)?; + let field_sig = format!("L{};", class_name); + let instance = env.get_static_field(class, "INSTANCE", &field_sig)?.l()?; + Ok(instance) +} + +fn convert_fixed_size_binary_type<'local>( + env: &mut JNIEnv<'local>, + byte_width: i32, +) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$FixedSizeBinary")?; + Ok(env.new_object(class, "(I)V", &[JValue::Int(byte_width)])?) +} + +fn convert_date_type<'local>(env: &mut JNIEnv<'local>, unit: &str) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$Date")?; + let unit_enum = env + .get_static_field( + "org/apache/arrow/vector/types/DateUnit", + unit, + "Lorg/apache/arrow/vector/types/DateUnit;", + )? + .l()?; + + Ok(env.new_object( + class, + "(Lorg/apache/arrow/vector/types/DateUnit;)V", + &[JValue::Object(&unit_enum)], + )?) +} + +fn convert_time_type<'local>( + env: &mut JNIEnv<'local>, + unit: TimeUnit, + bit_width: i32, +) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$Time")?; + let unit_str = match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }; + + let unit_enum = env + .get_static_field( + "org/apache/arrow/vector/types/TimeUnit", + unit_str, + "Lorg/apache/arrow/vector/types/TimeUnit;", + )? + .l()?; + + Ok(env.new_object( + class, + "(Lorg/apache/arrow/vector/types/TimeUnit;I)V", + &[JValue::Object(&unit_enum), JValue::Int(bit_width)], + )?) +} + +fn convert_timestamp_type<'local>( + env: &mut JNIEnv<'local>, + unit: TimeUnit, + timezone: Option<&str>, +) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$Timestamp")?; + let unit_str = match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }; + + let unit_enum = env + .get_static_field( + "org/apache/arrow/vector/types/TimeUnit", + unit_str, + "Lorg/apache/arrow/vector/types/TimeUnit;", + )? + .l()?; + + let timezone_str = timezone.unwrap_or("-"); + let j_timezone = env.new_string(timezone_str)?; + + Ok(env.new_object( + class, + "(Lorg/apache/arrow/vector/types/TimeUnit;Ljava/lang/String;)V", + &[JValue::Object(&unit_enum), JValue::Object(&j_timezone)], + )?) +} + +fn convert_duration_type<'local>( + env: &mut JNIEnv<'local>, + unit: TimeUnit, +) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$Duration")?; + let unit_str = match unit { + TimeUnit::Second => "SECOND", + TimeUnit::Millisecond => "MILLISECOND", + TimeUnit::Microsecond => "MICROSECOND", + TimeUnit::Nanosecond => "NANOSECOND", + }; + + let unit_enum = env + .get_static_field( + "org/apache/arrow/vector/types/TimeUnit", + unit_str, + "Lorg/apache/arrow/vector/types/TimeUnit;", + )? + .l()?; + + Ok(env.new_object( + class, + "(Lorg/apache/arrow/vector/types/TimeUnit;)V", + &[JValue::Object(&unit_enum)], + )?) +} + +fn convert_decimal_type<'local>( + env: &mut JNIEnv<'local>, + precision: u8, + scale: i8, + bit_width: i32, +) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$Decimal")?; + Ok(env.new_object( + class, + "(III)V", + &[ + JValue::Int(precision as jint), + JValue::Int(scale as jint), + JValue::Int(bit_width), + ], + )?) +} + +fn convert_list_type<'local>(env: &mut JNIEnv<'local>, is_large: bool) -> Result> { + let class_name = if is_large { + "org/apache/arrow/vector/types/pojo/ArrowType$LargeList" + } else { + "org/apache/arrow/vector/types/pojo/ArrowType$List" + }; + + convert_arrow_type_by_class_name(env, class_name) +} + +fn convert_fixed_size_list_type<'local>( + env: &mut JNIEnv<'local>, + list_size: i32, +) -> Result> { + Ok(env.new_object( + "org/apache/arrow/vector/types/pojo/ArrowType$FixedSizeList", + "(I)V", + &[JValue::Int(list_size)], + )?) +} + +fn convert_struct_type<'local>(env: &mut JNIEnv<'local>) -> Result> { + Ok(env + .get_static_field( + "org/apache/arrow/vector/types/pojo/ArrowType$Struct", + "INSTANCE", + "Lorg/apache/arrow/vector/types/pojo/ArrowType$Struct;", + )? + .l()?) +} + +fn convert_union_type<'local>( + env: &mut JNIEnv<'local>, + fields: &UnionFields, + mode: arrow_schema::UnionMode, +) -> Result> { + let class = env.find_class("org/apache/arrow/vector/types/pojo/ArrowType$Union")?; + + let mode_str = match mode { + arrow_schema::UnionMode::Sparse => "SPARSE", + arrow_schema::UnionMode::Dense => "DENSE", + }; + let mode_enum = env + .get_static_field( + "org/apache/arrow/vector/types/UnionMode", + mode_str, + "Lorg/apache/arrow/vector/types/UnionMode;", + )? + .l()?; + + let jarray = env.new_int_array(fields.size() as jint)?; + + let mut rust_array = vec![0; fields.size()]; + for (i, (type_id, _)) in fields.iter().enumerate() { + rust_array[i] = type_id as i32; + } + env.set_int_array_region(&jarray, 0, &rust_array)?; + + Ok(env.new_object( + class, + "(Lorg/apache/arrow/vector/types/UnionMode;[I)V", + &[JValue::Object(&mode_enum), JValue::Object(&jarray)], + )?) +} + +fn convert_map_type<'local>( + env: &mut JNIEnv<'local>, + keys_sorted: bool, +) -> Result> { + Ok(env.new_object( + "org/apache/arrow/vector/types/pojo/ArrowType$Map", + "(Z)V", + &[JValue::Bool(keys_sorted as jboolean)], + )?) +} diff --git a/java/core/lance-jni/src/utils.rs b/java/core/lance-jni/src/utils.rs index 937fc0e44e6..7d46de27792 100644 --- a/java/core/lance-jni/src/utils.rs +++ b/java/core/lance-jni/src/utils.rs @@ -15,7 +15,7 @@ use std::sync::Arc; use arrow::array::Float32Array; -use jni::objects::{JMap, JObject, JString}; +use jni::objects::{JMap, JObject, JString, JValue}; use jni::JNIEnv; use lance::dataset::{WriteMode, WriteParams}; use lance::index::vector::{IndexFileVersion, StageParams, VectorIndexParams}; @@ -295,3 +295,21 @@ pub fn to_rust_map(env: &mut JNIEnv, jmap: &JMap) -> Result(map) }) } + +pub fn to_java_map<'local>( + env: &mut JNIEnv<'local>, + map: &HashMap, +) -> Result> { + let java_map = env.new_object("java/util/HashMap", "()V", &[])?; + for (k, v) in map { + let jkey = env.new_string(k)?; + let jval = env.new_string(v)?; + env.call_method( + &java_map, + "put", + "(Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;", + &[JValue::Object(&jkey), JValue::Object(&jval)], + )?; + } + Ok(java_map) +} diff --git a/java/core/src/main/java/com/lancedb/lance/Dataset.java b/java/core/src/main/java/com/lancedb/lance/Dataset.java index 3fb1fbda42b..191025c9756 100644 --- a/java/core/src/main/java/com/lancedb/lance/Dataset.java +++ b/java/core/src/main/java/com/lancedb/lance/Dataset.java @@ -19,6 +19,7 @@ import com.lancedb.lance.ipc.LanceScanner; import com.lancedb.lance.ipc.ScanOptions; import com.lancedb.lance.schema.ColumnAlteration; +import com.lancedb.lance.schema.LanceSchema; import com.lancedb.lance.schema.SqlExpressions; import org.apache.arrow.c.ArrowArrayStream; @@ -651,7 +652,7 @@ public List getFragments() { private native List getFragmentsNative(); /** - * Gets the schema of the dataset. + * Gets the arrow schema of the dataset. * * @return the arrow schema */ @@ -667,6 +668,20 @@ public Schema getSchema() { private native void importFfiSchema(long arrowSchemaMemoryAddress); + /** + * Get the {@link com.lancedb.lance.schema.LanceSchema} of the dataset with field ids. + * + * @return the LanceSchema + */ + public LanceSchema getLanceSchema() { + try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { + Preconditions.checkArgument(nativeDatasetHandle != 0, "Dataset is closed"); + return nativeGetLanceSchema(); + } + } + + private native LanceSchema nativeGetLanceSchema(); + /** @return all the created indexes names */ public List listIndexes() { try (LockManager.ReadLock readLock = lockManager.acquireReadLock()) { diff --git a/java/core/src/main/java/com/lancedb/lance/schema/LanceField.java b/java/core/src/main/java/com/lancedb/lance/schema/LanceField.java new file mode 100644 index 00000000000..17ec1e1c1b4 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/schema/LanceField.java @@ -0,0 +1,135 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.schema; + +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.DictionaryEncoding; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +public class LanceField { + + private final int id; + private final int parentId; + private final String name; + private final boolean nullable; + private final ArrowType type; + private final StorageType storageType; + private final DictionaryEncoding dictionaryEncoding; + private final Map metadata; + private final List children; + private final boolean isUnenforcedPrimaryKey; + + LanceField( + int id, + int parentId, + String name, + boolean nullable, + ArrowType type, + StorageType storageType, + DictionaryEncoding dictionaryEncoding, + Map metadata, + List children, + boolean isUnenforcedPrimaryKey) { + this.id = id; + this.parentId = parentId; + this.name = name; + this.nullable = nullable; + this.type = type; + this.storageType = storageType; + this.dictionaryEncoding = dictionaryEncoding; + this.metadata = metadata; + this.children = children; + this.isUnenforcedPrimaryKey = isUnenforcedPrimaryKey; + } + + public int getId() { + return id; + } + + public int getParentId() { + return parentId; + } + + public String getName() { + return name; + } + + public boolean isNullable() { + return nullable; + } + + public ArrowType getType() { + return type; + } + + public StorageType getStorageType() { + return storageType; + } + + public Optional getDictionaryEncoding() { + return Optional.ofNullable(dictionaryEncoding); + } + + public Map getMetadata() { + return metadata; + } + + public List getChildren() { + return children; + } + + public boolean isUnenforcedPrimaryKey() { + return isUnenforcedPrimaryKey; + } + + public Field asArrowField() { + List arrowChildren = + children.stream().map(LanceField::asArrowField).collect(Collectors.toList()); + return new Field( + name, new FieldType(nullable, type, dictionaryEncoding, metadata), arrowChildren); + } + + @Override + public String toString() { + return "LanceField{" + + "id=" + + id + + ", parentId=" + + parentId + + ", name='" + + name + + '\'' + + ", nullable=" + + nullable + + ", type=" + + type + + ", storageType=" + + storageType + + ", dictionaryEncoding=" + + dictionaryEncoding + + ", metadata=" + + metadata + + ", children=" + + children + + ", isUnenforcedPrimaryKey=" + + isUnenforcedPrimaryKey + + '}'; + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/schema/LanceSchema.java b/java/core/src/main/java/com/lancedb/lance/schema/LanceSchema.java new file mode 100644 index 00000000000..499c7633793 --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/schema/LanceSchema.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.schema; + +import org.apache.arrow.vector.types.pojo.Schema; + +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +public class LanceSchema { + + private final List fields; + private final Map metadata; + + LanceSchema(List fields, Map metadata) { + this.fields = fields; + this.metadata = metadata; + } + + public List fields() { + return fields; + } + + public Map metadata() { + return Collections.unmodifiableMap(metadata); + } + + public Schema asArrowSchema() { + return new Schema( + fields.stream().map(LanceField::asArrowField).collect(Collectors.toList()), metadata); + } + + // Builder class for LanceSchema + private static class Builder { + private List fields; + private Map metadata; + + Builder() {} + + public Builder withFields(List fields) { + this.fields = fields; + return this; + } + + public Builder withMetadata(Map metadata) { + this.metadata = metadata; + return this; + } + + public LanceSchema build() { + return new LanceSchema(fields, metadata); + } + } +} diff --git a/java/core/src/main/java/com/lancedb/lance/schema/StorageType.java b/java/core/src/main/java/com/lancedb/lance/schema/StorageType.java new file mode 100644 index 00000000000..829189ae0ca --- /dev/null +++ b/java/core/src/main/java/com/lancedb/lance/schema/StorageType.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.lancedb.lance.schema; + +public enum StorageType { + DEFAULT, + BLOB +} diff --git a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java index b11dbd411b1..f6da52e6bfe 100644 --- a/java/core/src/test/java/com/lancedb/lance/DatasetTest.java +++ b/java/core/src/test/java/com/lancedb/lance/DatasetTest.java @@ -836,4 +836,16 @@ void testDeleteConfigKeys() { assertFalse(currentConfig.containsKey("key1")); } } + + @Test + void testGetLanceSchema() { + String testMethodName = new Object() {}.getClass().getEnclosingMethod().getName(); + String datasetPath = tempDir.resolve(testMethodName).toString(); + try (RootAllocator allocator = new RootAllocator(Long.MAX_VALUE)) { + TestUtils.ComplexTestDataset testDataset = + new TestUtils.ComplexTestDataset(allocator, datasetPath); + dataset = testDataset.createEmptyDataset(); + assertEquals(testDataset.getSchema(), dataset.getLanceSchema().asArrowSchema()); + } + } } diff --git a/java/core/src/test/java/com/lancedb/lance/TestUtils.java b/java/core/src/test/java/com/lancedb/lance/TestUtils.java index dac8db2b628..445de6e4ce1 100644 --- a/java/core/src/test/java/com/lancedb/lance/TestUtils.java +++ b/java/core/src/test/java/com/lancedb/lance/TestUtils.java @@ -24,8 +24,12 @@ import org.apache.arrow.vector.ipc.ArrowFileReader; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.ipc.SeekableReadChannel; +import org.apache.arrow.vector.types.DateUnit; +import org.apache.arrow.vector.types.FloatingPointPrecision; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel; @@ -37,6 +41,7 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Optional; @@ -46,30 +51,22 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class TestUtils { - public static class SimpleTestDataset { - private final Schema schema = - new Schema( - Arrays.asList( - Field.nullable("id", new ArrowType.Int(32, true)), - Field.nullable("name", new ArrowType.Utf8())), - null); + private abstract static class TestDataset { private final BufferAllocator allocator; private final String datasetPath; - public SimpleTestDataset(BufferAllocator allocator, String datasetPath) { + public TestDataset(BufferAllocator allocator, String datasetPath) { this.allocator = allocator; this.datasetPath = datasetPath; } - public Schema getSchema() { - return schema; - } + public abstract Schema getSchema(); public Dataset createEmptyDataset() { Dataset dataset = - Dataset.create(allocator, datasetPath, schema, new WriteParams.Builder().build()); + Dataset.create(allocator, datasetPath, getSchema(), new WriteParams.Builder().build()); assertEquals(0, dataset.countRows()); - assertEquals(schema, dataset.getSchema()); + assertEquals(getSchema(), dataset.getSchema()); List fragments = dataset.getFragments(); assertEquals(0, fragments.size()); assertEquals(1, dataset.version()); @@ -87,7 +84,7 @@ public FragmentMetadata createNewFragment(int rowCount) { public List createNewFragment(int rowCount, int maxRowsPerFile) { List fragmentMetas; - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), allocator)) { root.allocateNew(); IntVector idVector = (IntVector) root.getVector("id"); VarCharVector nameVector = (VarCharVector) root.getVector("name"); @@ -117,7 +114,7 @@ public Dataset write(long version, int rowCount) { public Dataset writeSortByDataset(long version) { List fragmentMetas; - try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + try (VectorSchemaRoot root = VectorSchemaRoot.create(getSchema(), allocator)) { root.allocateNew(); IntVector idVector = (IntVector) root.getVector("id"); VarCharVector nameVector = (VarCharVector) root.getVector("name"); @@ -282,4 +279,74 @@ public static ByteBuffer getSubstraitByteBuffer(String substrait) { substraitExpression.put(decodedSubstrait); return substraitExpression; } + + public static class SimpleTestDataset extends TestDataset { + private static final Schema schema = + new Schema( + Arrays.asList( + Field.nullable("id", new ArrowType.Int(32, true)), + Field.nullable("name", new ArrowType.Utf8())), + null); + + public SimpleTestDataset(BufferAllocator allocator, String datasetPath) { + super(allocator, datasetPath); + } + + @Override + public Schema getSchema() { + return schema; + } + } + + public static class ComplexTestDataset extends TestDataset { + public static final Schema COMPLETE_SCHEMA = + new Schema( + Arrays.asList( + // basic scalar types + Field.nullable("null_col", ArrowType.Null.INSTANCE), + Field.nullable("bool_col", ArrowType.Bool.INSTANCE), + Field.nullable("int8_col", new ArrowType.Int(8, true)), + Field.nullable("uint32_col", new ArrowType.Int(32, false)), + Field.nullable( + "float64_col", new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)), + + // strings and binary types + Field.nullable("utf8_col", ArrowType.Utf8.INSTANCE), + Field.nullable("large_utf8_col", ArrowType.LargeUtf8.INSTANCE), + Field.nullable("binary_col", ArrowType.Binary.INSTANCE), + Field.nullable("fixed_binary_col", new ArrowType.FixedSizeBinary(16)), + + // time and date types + Field.notNullable("date32_col", new ArrowType.Date(DateUnit.DAY)), + Field.nullable( + "timestamp_col", new ArrowType.Timestamp(TimeUnit.MICROSECOND, "UTC")), + Field.nullable("time64_nano_col", new ArrowType.Time(TimeUnit.NANOSECOND, 64)), + + // decimals + Field.notNullable("decimal128_col", new ArrowType.Decimal(38, 10, 128)), + Field.nullable("decimal256_col", new ArrowType.Decimal(76, 20, 256)), + + // nested types + new Field( + "list_col", + FieldType.nullable(new ArrowType.List()), + Collections.singletonList(Field.nullable("item", new ArrowType.Int(32, true)))), + + // struct and union types + new Field( + "struct_col", + FieldType.nullable(new ArrowType.Struct()), + Arrays.asList( + Field.nullable("field1", ArrowType.Utf8.INSTANCE), + Field.nullable("field2", new ArrowType.Int(16, true)))))); + + public ComplexTestDataset(BufferAllocator allocator, String datasetPath) { + super(allocator, datasetPath); + } + + @Override + public Schema getSchema() { + return COMPLETE_SCHEMA; + } + } }