diff --git a/rust/lance-namespace/src/schema.rs b/rust/lance-namespace/src/schema.rs index 3f44847bcd4..69aa59a51e9 100644 --- a/rust/lance-namespace/src/schema.rs +++ b/rust/lance-namespace/src/schema.rs @@ -248,9 +248,12 @@ pub fn convert_json_arrow_field(json_field: &JsonArrowField) -> Result { /// Convert JsonArrowDataType to Arrow DataType pub fn convert_json_arrow_type(json_type: &JsonArrowDataType) -> Result { + use std::sync::Arc; + let type_name = json_type.r#type.to_lowercase(); match type_name.as_str() { + // Primitive types "null" => Ok(DataType::Null), "bool" | "boolean" => Ok(DataType::Boolean), "int8" => Ok(DataType::Int8), @@ -261,10 +264,108 @@ pub fn convert_json_arrow_type(json_type: &JsonArrowDataType) -> Result Ok(DataType::UInt32), "int64" => Ok(DataType::Int64), "uint64" => Ok(DataType::UInt64), + "float16" => Ok(DataType::Float16), "float32" => Ok(DataType::Float32), "float64" => Ok(DataType::Float64), + + // Decimal types - encoding: precision * 1000 + scale + // Decoding must handle negative scale: precision = ((encoded + 128) / 1000) + "decimal32" => { + let encoded = json_type.length.unwrap_or(0); + let precision = ((encoded + 128) / 1000) as u8; + let scale = (encoded - precision as i64 * 1000) as i8; + Ok(DataType::Decimal32(precision, scale)) + } + "decimal64" => { + let encoded = json_type.length.unwrap_or(0); + let precision = ((encoded + 128) / 1000) as u8; + let scale = (encoded - precision as i64 * 1000) as i8; + Ok(DataType::Decimal64(precision, scale)) + } + "decimal128" => { + let encoded = json_type.length.unwrap_or(0); + let precision = ((encoded + 128) / 1000) as u8; + let scale = (encoded - precision as i64 * 1000) as i8; + Ok(DataType::Decimal128(precision, scale)) + } + "decimal256" => { + let encoded = json_type.length.unwrap_or(0); + let precision = ((encoded + 128) / 1000) as u8; + let scale = (encoded - precision as i64 * 1000) as i8; + Ok(DataType::Decimal256(precision, scale)) + } + + // Date/Time types + "date32" => Ok(DataType::Date32), + "date64" => Ok(DataType::Date64), + "timestamp" => Ok(DataType::Timestamp( + arrow::datatypes::TimeUnit::Microsecond, + None, + )), + "duration" => Ok(DataType::Duration(arrow::datatypes::TimeUnit::Microsecond)), + + // String and Binary types "utf8" => Ok(DataType::Utf8), + "large_utf8" => Ok(DataType::LargeUtf8), "binary" => Ok(DataType::Binary), + "large_binary" => Ok(DataType::LargeBinary), + "fixed_size_binary" => { + let size = json_type.length.unwrap_or(0) as i32; + Ok(DataType::FixedSizeBinary(size)) + } + + // Nested types + "list" => { + let inner = json_type + .fields + .as_ref() + .and_then(|f| f.first()) + .ok_or_else(|| Error::namespace("list type missing inner field"))?; + Ok(DataType::List(Arc::new(convert_json_arrow_field(inner)?))) + } + "large_list" => { + let inner = json_type + .fields + .as_ref() + .and_then(|f| f.first()) + .ok_or_else(|| Error::namespace("large_list type missing inner field"))?; + Ok(DataType::LargeList(Arc::new(convert_json_arrow_field( + inner, + )?))) + } + "fixed_size_list" => { + let inner = json_type + .fields + .as_ref() + .and_then(|f| f.first()) + .ok_or_else(|| Error::namespace("fixed_size_list type missing inner field"))?; + let size = json_type.length.unwrap_or(0) as i32; + Ok(DataType::FixedSizeList( + Arc::new(convert_json_arrow_field(inner)?), + size, + )) + } + "struct" => { + let fields = json_type + .fields + .as_ref() + .ok_or_else(|| Error::namespace("struct type missing fields"))?; + let arrow_fields: Result> = + fields.iter().map(convert_json_arrow_field).collect(); + Ok(DataType::Struct(arrow_fields?.into())) + } + "map" => { + let entries = json_type + .fields + .as_ref() + .and_then(|f| f.first()) + .ok_or_else(|| Error::namespace("map type missing entries field"))?; + Ok(DataType::Map( + Arc::new(convert_json_arrow_field(entries)?), + false, + )) + } + _ => Err(Error::namespace(format!( "Unsupported Arrow type: {}", type_name @@ -524,4 +625,205 @@ mod tests { let float16 = arrow_type_to_json(&DataType::Float16).unwrap(); assert_eq!(float16.r#type, "float16"); } + + /// Verify that convert_json_arrow_type (deserialization) is the inverse of + /// arrow_type_to_json (serialization) for all supported types. + #[test] + fn test_json_arrow_type_roundtrip() { + use arrow::datatypes::Field; + + let cases: Vec = vec![ + // Scalars + DataType::Null, + DataType::Boolean, + DataType::Int8, + DataType::UInt8, + DataType::Int16, + DataType::UInt16, + DataType::Int32, + DataType::UInt32, + DataType::Int64, + DataType::UInt64, + DataType::Float16, + DataType::Float32, + DataType::Float64, + DataType::Utf8, + DataType::LargeUtf8, + DataType::Binary, + DataType::LargeBinary, + DataType::Date32, + DataType::Date64, + DataType::FixedSizeBinary(16), + // Decimal types with positive and negative scales + DataType::Decimal32(10, -2), + DataType::Decimal32(9, 3), + DataType::Decimal64(18, -5), + DataType::Decimal64(10, 4), + DataType::Decimal128(9, -2), + DataType::Decimal128(38, 10), + DataType::Decimal256(38, 10), + DataType::Decimal256(76, -10), + // Timestamp and Duration + DataType::Timestamp(arrow::datatypes::TimeUnit::Microsecond, None), + DataType::Duration(arrow::datatypes::TimeUnit::Microsecond), + // Nested + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + DataType::LargeList(Arc::new(Field::new("item", DataType::Utf8, true))), + DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, false)), 128), + DataType::Struct( + vec![ + Field::new("a", DataType::Int64, false), + Field::new("b", DataType::Utf8, true), + ] + .into(), + ), + // Map + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Int32, true), + ] + .into(), + ), + false, + )), + false, + ), + ]; + + for dt in &cases { + let json = arrow_type_to_json(dt) + .unwrap_or_else(|e| panic!("arrow_type_to_json failed for {:?}: {}", dt, e)); + let back = convert_json_arrow_type(&json) + .unwrap_or_else(|e| panic!("convert_json_arrow_type failed for {:?}: {}", dt, e)); + assert_eq!(&back, dt, "Roundtrip mismatch for {:?}: got {:?}", dt, back); + } + } + + #[test] + fn test_decimal_negative_scale_roundtrip() { + // Explicitly test the cases requested by reviewer + let cases = vec![ + DataType::Decimal32(10, -2), + DataType::Decimal128(9, -2), + DataType::Decimal256(38, 10), + ]; + for dt in &cases { + let json = arrow_type_to_json(dt).unwrap(); + let back = convert_json_arrow_type(&json).unwrap(); + assert_eq!(&back, dt, "Decimal roundtrip failed for {:?}", dt); + } + } + + #[test] + fn test_schema_with_metadata_roundtrip() { + let mut metadata = HashMap::new(); + metadata.insert("key1".to_string(), "value1".to_string()); + metadata.insert("key2".to_string(), "value2".to_string()); + + let arrow_schema = ArrowSchema::new_with_metadata( + vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ], + metadata.clone(), + ); + + let json_schema = arrow_schema_to_json(&arrow_schema).unwrap(); + assert_eq!(json_schema.metadata.as_ref().unwrap(), &metadata); + + let roundtrip = convert_json_arrow_schema(&json_schema).unwrap(); + assert_eq!(roundtrip.metadata(), &metadata); + } + + #[test] + fn test_dictionary_type_unwraps_to_value_type() { + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let json = arrow_type_to_json(&dict_type).unwrap(); + assert_eq!(json.r#type, "utf8"); + } + + #[test] + fn test_map_keys_sorted_unsupported() { + let map_type = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("keys", DataType::Utf8, false), + Field::new("values", DataType::Int32, true), + ] + .into(), + ), + false, + )), + true, // keys_sorted = true + ); + let result = arrow_type_to_json(&map_type); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("keys_sorted=true")); + } + + #[test] + fn test_unsupported_types_error() { + // RunEndEncoded + let ree = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + assert!(arrow_type_to_json(&ree).is_err()); + + // ListView + let lv = DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(arrow_type_to_json(&lv).is_err()); + + // LargeListView + let llv = DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(arrow_type_to_json(&llv).is_err()); + + // Utf8View / BinaryView + assert!(arrow_type_to_json(&DataType::Utf8View).is_err()); + assert!(arrow_type_to_json(&DataType::BinaryView).is_err()); + } + + #[test] + fn test_large_list_roundtrip() { + let inner_field = Field::new("item", DataType::Float64, true); + let large_list = DataType::LargeList(Arc::new(inner_field)); + + let json = arrow_type_to_json(&large_list).unwrap(); + assert_eq!(json.r#type, "large_list"); + + let back = convert_json_arrow_type(&json).unwrap(); + assert_eq!(back, large_list); + } + + #[test] + fn test_field_with_metadata_roundtrip() { + let mut field_meta = HashMap::new(); + field_meta.insert("custom_key".to_string(), "custom_val".to_string()); + + let field = Field::new("col", DataType::Int64, false).with_metadata(field_meta.clone()); + let schema = ArrowSchema::new(vec![field]); + + let json_schema = arrow_schema_to_json(&schema).unwrap(); + let roundtrip = convert_json_arrow_schema(&json_schema).unwrap(); + assert_eq!(roundtrip.field(0).metadata(), &field_meta); + } + + #[test] + fn test_nested_list_with_field_metadata() { + let mut meta = HashMap::new(); + meta.insert("encoding".to_string(), "delta".to_string()); + + let inner = Field::new("item", DataType::Int32, true).with_metadata(meta.clone()); + let list_type = DataType::List(Arc::new(inner)); + + let json = arrow_type_to_json(&list_type).unwrap(); + let fields = json.fields.as_ref().unwrap(); + assert_eq!(fields[0].metadata.as_ref().unwrap(), &meta); + } }