Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 302 additions & 0 deletions rust/lance-namespace/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,9 +248,12 @@ pub fn convert_json_arrow_field(json_field: &JsonArrowField) -> Result<Field> {

/// Convert JsonArrowDataType to Arrow DataType
pub fn convert_json_arrow_type(json_type: &JsonArrowDataType) -> Result<DataType> {
use std::sync::Arc;

let type_name = json_type.r#type.to_lowercase();

match type_name.as_str() {
// Primitive types
"null" => Ok(DataType::Null),
"bool" | "boolean" => Ok(DataType::Boolean),
"int8" => Ok(DataType::Int8),
Expand All @@ -261,10 +264,108 @@ pub fn convert_json_arrow_type(json_type: &JsonArrowDataType) -> Result<DataType
"uint32" => Ok(DataType::UInt32),
"int64" => Ok(DataType::Int64),
"uint64" => Ok(DataType::UInt64),
"float16" => Ok(DataType::Float16),
"float32" => Ok(DataType::Float32),
"float64" => Ok(DataType::Float64),

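// Decimal types - `length` encodes both values as: precision * 1000 + scale
// Scale is an i8 and may be negative, so add 128 before dividing to keep a
// negative scale from lowering the result: precision = (encoded + 128) / 1000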
"decimal32" => {
let encoded = json_type.length.unwrap_or(0);
let precision = ((encoded + 128) / 1000) as u8;
let scale = (encoded - precision as i64 * 1000) as i8;
Ok(DataType::Decimal32(precision, scale))
}
"decimal64" => {
let encoded = json_type.length.unwrap_or(0);
let precision = ((encoded + 128) / 1000) as u8;
let scale = (encoded - precision as i64 * 1000) as i8;
Ok(DataType::Decimal64(precision, scale))
}
"decimal128" => {
let encoded = json_type.length.unwrap_or(0);
let precision = ((encoded + 128) / 1000) as u8;
let scale = (encoded - precision as i64 * 1000) as i8;
Ok(DataType::Decimal128(precision, scale))
}
"decimal256" => {
let encoded = json_type.length.unwrap_or(0);
let precision = ((encoded + 128) / 1000) as u8;
let scale = (encoded - precision as i64 * 1000) as i8;
Ok(DataType::Decimal256(precision, scale))
}

// Date/Time types
"date32" => Ok(DataType::Date32),
"date64" => Ok(DataType::Date64),
"timestamp" => Ok(DataType::Timestamp(
arrow::datatypes::TimeUnit::Microsecond,
None,
)),
"duration" => Ok(DataType::Duration(arrow::datatypes::TimeUnit::Microsecond)),

// String and Binary types
"utf8" => Ok(DataType::Utf8),
"large_utf8" => Ok(DataType::LargeUtf8),
"binary" => Ok(DataType::Binary),
"large_binary" => Ok(DataType::LargeBinary),
"fixed_size_binary" => {
let size = json_type.length.unwrap_or(0) as i32;
Ok(DataType::FixedSizeBinary(size))
}

// Nested types
"list" => {
let inner = json_type
.fields
.as_ref()
.and_then(|f| f.first())
.ok_or_else(|| Error::namespace("list type missing inner field"))?;
Ok(DataType::List(Arc::new(convert_json_arrow_field(inner)?)))
}
"large_list" => {
let inner = json_type
.fields
.as_ref()
.and_then(|f| f.first())
.ok_or_else(|| Error::namespace("large_list type missing inner field"))?;
Ok(DataType::LargeList(Arc::new(convert_json_arrow_field(
inner,
)?)))
}
"fixed_size_list" => {
let inner = json_type
.fields
.as_ref()
.and_then(|f| f.first())
.ok_or_else(|| Error::namespace("fixed_size_list type missing inner field"))?;
let size = json_type.length.unwrap_or(0) as i32;
Ok(DataType::FixedSizeList(
Arc::new(convert_json_arrow_field(inner)?),
size,
))
}
"struct" => {
let fields = json_type
.fields
.as_ref()
.ok_or_else(|| Error::namespace("struct type missing fields"))?;
let arrow_fields: Result<Vec<Field>> =
fields.iter().map(convert_json_arrow_field).collect();
Ok(DataType::Struct(arrow_fields?.into()))
}
"map" => {
let entries = json_type
.fields
.as_ref()
.and_then(|f| f.first())
.ok_or_else(|| Error::namespace("map type missing entries field"))?;
Ok(DataType::Map(
Arc::new(convert_json_arrow_field(entries)?),
false,
))
}

_ => Err(Error::namespace(format!(
"Unsupported Arrow type: {}",
type_name
Expand Down Expand Up @@ -524,4 +625,205 @@ mod tests {
let float16 = arrow_type_to_json(&DataType::Float16).unwrap();
assert_eq!(float16.r#type, "float16");
}

/// Round-trips each Arrow type listed below through `arrow_type_to_json` and
/// back through `convert_json_arrow_type`, expecting to recover the original.
#[test]
fn test_json_arrow_type_roundtrip() {
    use arrow::datatypes::Field;

    let micros = arrow::datatypes::TimeUnit::Microsecond;
    // Child field shared by the list-like cases; only type and nullability vary.
    let item =
        |data_type: DataType, nullable: bool| Arc::new(Field::new("item", data_type, nullable));

    let two_column_struct = DataType::Struct(
        vec![
            Field::new("a", DataType::Int64, false),
            Field::new("b", DataType::Utf8, true),
        ]
        .into(),
    );

    let map_entries = Field::new(
        "entries",
        DataType::Struct(
            vec![
                Field::new("keys", DataType::Utf8, false),
                Field::new("values", DataType::Int32, true),
            ]
            .into(),
        ),
        false,
    );
    let unsorted_map = DataType::Map(Arc::new(map_entries), false);

    let mut samples: Vec<DataType> = Vec::new();

    // Scalars
    samples.extend([
        DataType::Null,
        DataType::Boolean,
        DataType::Int8,
        DataType::UInt8,
        DataType::Int16,
        DataType::UInt16,
        DataType::Int32,
        DataType::UInt32,
        DataType::Int64,
        DataType::UInt64,
        DataType::Float16,
        DataType::Float32,
        DataType::Float64,
        DataType::Utf8,
        DataType::LargeUtf8,
        DataType::Binary,
        DataType::LargeBinary,
        DataType::Date32,
        DataType::Date64,
        DataType::FixedSizeBinary(16),
    ]);

    // Decimals, covering both positive and negative scales
    samples.extend([
        DataType::Decimal32(10, -2),
        DataType::Decimal32(9, 3),
        DataType::Decimal64(18, -5),
        DataType::Decimal64(10, 4),
        DataType::Decimal128(9, -2),
        DataType::Decimal128(38, 10),
        DataType::Decimal256(38, 10),
        DataType::Decimal256(76, -10),
    ]);

    // Timestamp and Duration
    samples.push(DataType::Timestamp(micros, None));
    samples.push(DataType::Duration(micros));

    // Nested types, then the map
    samples.push(DataType::List(item(DataType::Int32, true)));
    samples.push(DataType::LargeList(item(DataType::Utf8, true)));
    samples.push(DataType::FixedSizeList(item(DataType::Float32, false), 128));
    samples.push(two_column_struct);
    samples.push(unsorted_map);

    for original in samples.iter() {
        let json = match arrow_type_to_json(original) {
            Ok(json) => json,
            Err(e) => panic!("arrow_type_to_json failed for {:?}: {}", original, e),
        };
        let restored = match convert_json_arrow_type(&json) {
            Ok(restored) => restored,
            Err(e) => panic!("convert_json_arrow_type failed for {:?}: {}", original, e),
        };
        assert_eq!(
            &restored, original,
            "Roundtrip mismatch for {:?}: got {:?}",
            original, restored
        );
    }
}

#[test]
fn test_decimal_negative_scale_roundtrip() {
    // Focused decimal round-trip: two negative-scale cases plus one
    // positive-scale Decimal256 case.
    let decimals = [
        DataType::Decimal32(10, -2),
        DataType::Decimal128(9, -2),
        DataType::Decimal256(38, 10),
    ];

    for expected in decimals {
        let encoded = arrow_type_to_json(&expected).unwrap();
        let decoded = convert_json_arrow_type(&encoded).unwrap();
        assert_eq!(
            decoded, expected,
            "Decimal roundtrip failed for {:?}",
            expected
        );
    }
}

#[test]
fn test_schema_with_metadata_roundtrip() {
    // Schema-level metadata must be present in the JSON form and still be
    // there after converting back to an Arrow schema.
    let metadata: HashMap<String, String> = [("key1", "value1"), ("key2", "value2")]
        .into_iter()
        .map(|(k, v)| (k.to_string(), v.to_string()))
        .collect();

    let columns = vec![
        Field::new("id", DataType::Int32, false),
        Field::new("name", DataType::Utf8, true),
    ];
    let arrow_schema = ArrowSchema::new_with_metadata(columns, metadata.clone());

    let json_schema = arrow_schema_to_json(&arrow_schema).unwrap();
    let json_metadata = json_schema.metadata.as_ref().unwrap();
    assert_eq!(json_metadata, &metadata);

    let restored = convert_json_arrow_schema(&json_schema).unwrap();
    assert_eq!(restored.metadata(), &metadata);
}

#[test]
fn test_dictionary_type_unwraps_to_value_type() {
    // A dictionary with Int32 keys and Utf8 values is expected to serialize
    // under the name of its value type.
    let key_type = Box::new(DataType::Int32);
    let value_type = Box::new(DataType::Utf8);
    let dictionary = DataType::Dictionary(key_type, value_type);

    let serialized = arrow_type_to_json(&dictionary).unwrap();

    assert_eq!(serialized.r#type, "utf8");
}

#[test]
fn test_map_keys_sorted_unsupported() {
    // Build the map bottom-up: key/value struct -> "entries" field -> map.
    let key_value_struct = DataType::Struct(
        vec![
            Field::new("keys", DataType::Utf8, false),
            Field::new("values", DataType::Int32, true),
        ]
        .into(),
    );
    let entries = Arc::new(Field::new("entries", key_value_struct, false));
    let keys_sorted = true;
    let sorted_map = DataType::Map(entries, keys_sorted);

    // Serialization is expected to fail with an error naming the flag.
    let result = arrow_type_to_json(&sorted_map);
    assert!(result.is_err());
    let message = result.unwrap_err().to_string();
    assert!(message.contains("keys_sorted=true"));
}

#[test]
fn test_unsupported_types_error() {
    // True when serializing the given type to JSON yields an error.
    let is_rejected = |data_type: &DataType| arrow_type_to_json(data_type).is_err();
    // Nullable Int32 child field shared by the list-view cases.
    let int_item = || Arc::new(Field::new("item", DataType::Int32, true));

    // RunEndEncoded
    let run_ends = Arc::new(Field::new("run_ends", DataType::Int32, false));
    let values = Arc::new(Field::new("values", DataType::Utf8, true));
    let run_end_encoded = DataType::RunEndEncoded(run_ends, values);
    assert!(is_rejected(&run_end_encoded));

    // ListView
    let list_view = DataType::ListView(int_item());
    assert!(is_rejected(&list_view));

    // LargeListView
    let large_list_view = DataType::LargeListView(int_item());
    assert!(is_rejected(&large_list_view));

    // Utf8View / BinaryView
    assert!(is_rejected(&DataType::Utf8View));
    assert!(is_rejected(&DataType::BinaryView));
}

#[test]
fn test_large_list_roundtrip() {
    // LargeList of nullable Float64 items: check the serialized type name,
    // then that deserializing returns the same type.
    let expected = DataType::LargeList(Arc::new(Field::new("item", DataType::Float64, true)));

    let serialized = arrow_type_to_json(&expected).unwrap();
    assert_eq!(serialized.r#type, "large_list");

    let restored = convert_json_arrow_type(&serialized).unwrap();
    assert_eq!(restored, expected);
}

#[test]
fn test_field_with_metadata_roundtrip() {
    // Field-level metadata must come back unchanged after schema -> JSON -> schema.
    let field_meta = HashMap::from([("custom_key".to_string(), "custom_val".to_string())]);

    let column = Field::new("col", DataType::Int64, false).with_metadata(field_meta.clone());
    let original = ArrowSchema::new(vec![column]);

    let as_json = arrow_schema_to_json(&original).unwrap();
    let restored = convert_json_arrow_schema(&as_json).unwrap();

    let restored_meta = restored.field(0).metadata();
    assert_eq!(restored_meta, &field_meta);
}

#[test]
fn test_nested_list_with_field_metadata() {
    // Metadata on a list's child field must appear on the first nested field
    // of the serialized JSON type.
    let meta = HashMap::from([("encoding".to_string(), "delta".to_string())]);

    let child = Field::new("item", DataType::Int32, true).with_metadata(meta.clone());
    let list_of_ints = DataType::List(Arc::new(child));

    let serialized = arrow_type_to_json(&list_of_ints).unwrap();
    let nested = serialized.fields.as_ref().unwrap();
    let first_child = &nested[0];

    assert_eq!(first_child.metadata.as_ref().unwrap(), &meta);
}
}
Loading