diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0a92d0d1b2ea..50545b581954 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -714,6 +714,9 @@ message Union{ } message ScalarListValue{ + // encode null explicitly to distinguish a list with a null value + // from a list with no values) + bool is_null = 3; Field field = 1; repeated ScalarValue values = 2; } @@ -768,7 +771,7 @@ message ScalarValue{ //Literal Date32 value always has a unit of day int32 date_32_value = 14; ScalarListValue list_value = 17; - ScalarType null_list_value = 18; + //WAS: ScalarType null_list_value = 18; Decimal128 decimal128_value = 20; int64 date_64_value = 21; @@ -825,17 +828,6 @@ enum PrimitiveScalarType{ TIME64 = 27; } -message ScalarType{ - oneof datatype{ - PrimitiveScalarType scalar = 1; - ScalarListType list = 2; - } -} - -message ScalarListType{ - repeated string field_names = 3; - PrimitiveScalarType deepest_type = 2; -} // Broke out into multiple message types so that type // metadata did not need to be in separate message diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 3eeb30edf649..9db9fc2933fa 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -95,10 +95,6 @@ impl Error { Error::MissingRequiredField(field.into()) } - fn at_least_one(field: impl Into) -> Error { - Error::AtLeastOneValue(field.into()) - } - fn unknown(name: impl Into, value: i32) -> Error { Error::UnknownEnumVariant { name: name.into(), @@ -559,56 +555,6 @@ impl TryFrom<&i32> for protobuf::AggregateFunction { } } -impl TryFrom<&protobuf::scalar_type::Datatype> for DataType { - type Error = Error; - - fn try_from( - scalar_type: &protobuf::scalar_type::Datatype, - ) -> Result { - use protobuf::scalar_type::Datatype; - - Ok(match scalar_type { - Datatype::Scalar(scalar_type) => { - protobuf::PrimitiveScalarType::try_from(scalar_type)?.into() - } - Datatype::List(protobuf::ScalarListType { - deepest_type, - field_names, - }) => { - if field_names.is_empty() { - return Err(Error::at_least_one("field_names")); - } - let field_type = - protobuf::PrimitiveScalarType::try_from(deepest_type)?.into(); - // Because length is checked above it is safe to unwrap .last() - let mut scalar_type = DataType::List(Box::new(Field::new( - field_names.last().unwrap().as_str(), - field_type, - true, - ))); - // Iterate over field names in reverse order except for the last item in the vector - for name in field_names.iter().rev().skip(1) { - let new_datatype = DataType::List(Box::new(Field::new( - name.as_str(), - scalar_type, - true, - ))); - scalar_type = new_datatype; - } - scalar_type - } - }) - } -} - -impl TryFrom<&protobuf::ScalarType> for DataType { - type Error = Error; - - fn try_from(scalar: &protobuf::ScalarType) -> Result { - scalar.datatype.as_ref().required("datatype") - } -} - impl TryFrom<&protobuf::Schema> for Schema { type Error = Error; @@ -676,36 +622,6 @@ impl TryFrom<&protobuf::PrimitiveScalarType> for ScalarValue { } } -impl TryFrom<&protobuf::ScalarListType> for DataType { - type Error = Error; - fn try_from(scalar: &protobuf::ScalarListType) -> Result { - use protobuf::PrimitiveScalarType; - - let protobuf::ScalarListType { - deepest_type, - field_names, - } = scalar; - - let depth = field_names.len(); - if depth == 0 { - return Err(Error::at_least_one("field_names")); - } - - let mut curr_type = Self::List(Box::new(Field::new( - // Since checked vector is not empty above this is safe to unwrap - field_names.last().unwrap(), - PrimitiveScalarType::try_from(deepest_type)?.into(), - true, - ))); - // Iterates over field names in reverse order except for the last item in the vector - for name in field_names.iter().rev().skip(1) { - let temp_curr_type = Self::List(Box::new(Field::new(name, curr_type, true))); - curr_type = temp_curr_type; - } - Ok(curr_type) - } -} - impl TryFrom<&protobuf::ScalarValue> for ScalarValue { type Error = Error; @@ -734,23 +650,23 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { Value::Date32Value(v) => Self::Date32(Some(*v)), Value::ListValue(scalar_list) => { let protobuf::ScalarListValue { + is_null, values, - field: opt_field, + field, } = &scalar_list; - let field = opt_field.as_ref().required("field")?; + let field: Field = field.as_ref().required("field")?; let field = Box::new(field); - let typechecked_values: Vec = values - .iter() - .map(|val| val.try_into()) - .collect::, _>>()?; + let values: Result, Error> = + values.iter().map(|val| val.try_into()).collect(); + let values = values?; - Self::List(Some(typechecked_values), field) - } - Value::NullListValue(v) => { - let field = Field::new("item", v.try_into()?, true); - Self::List(None, Box::new(field)) + validate_list_values(field.as_ref(), &values)?; + + let values = if *is_null { None } else { Some(values) }; + + Self::List(values, field) } Value::NullValue(v) => { let null_type_enum = protobuf::PrimitiveScalarType::try_from(v)?; @@ -840,6 +756,23 @@ impl TryFrom<&protobuf::ScalarValue> for ScalarValue { } } +/// Ensures that all `values` are of type DataType::List and have the +/// same type as field +fn validate_list_values(field: &Field, values: &[ScalarValue]) -> Result<(), Error> { + for value in values { + let field_type = field.data_type(); + let value_type = value.get_datatype(); + + if field_type != &value_type { + return Err(proto_error(format!( + "Expected field type {:?}, got scalar of type: {:?}", + field_type, value_type + ))); + } + } + Ok(()) +} + pub fn parse_expr( proto: &protobuf::LogicalExprNode, registry: &dyn FunctionRegistry, diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index e3b6c848a2b1..ef6aa1f172e4 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -320,7 +320,7 @@ mod roundtrip_tests { Some(vec![]), Box::new(vec![Field::new("item", DataType::Int16, true)]), ), - // Should fail due to inconsistent types + // Should fail due to inconsistent types in the list ScalarValue::new_list( Some(vec![ ScalarValue::Int16(None), @@ -335,6 +335,13 @@ mod roundtrip_tests { ]), DataType::List(new_box_field("item", DataType::Int16, true)), ), + ScalarValue::new_list( + Some(vec![ + ScalarValue::Float32(None), + ScalarValue::Float32(Some(32.0)), + ]), + DataType::Int16, + ), ScalarValue::new_list( Some(vec![ ScalarValue::new_list( @@ -369,15 +376,20 @@ mod roundtrip_tests { ]; for test_case in should_fail_on_seralize.into_iter() { - let res: std::result::Result< - super::protobuf::ScalarValue, - super::to_proto::Error, - > = (&test_case).try_into(); - assert!( - res.is_err(), - "The value {:?} should not have been able to serialize. Serialized to :{:?}", - test_case, res - ); + let proto: Result = + (&test_case).try_into(); + + // Validation is also done on read, so if serialization passed + // also try to convert back to ScalarValue + if let Ok(proto) = proto { + let res: Result = (&proto).try_into(); + assert!( + res.is_err(), + "The value {:?} unexpectedly serialized without error:{:?}", + test_case, + res + ); + } } } @@ -482,14 +494,11 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - DataType::List(new_box_field("level1", DataType::Float32, true)), + DataType::Float32, ), ScalarValue::new_list( Some(vec![ - ScalarValue::new_list( - None, - DataType::List(new_box_field("level2", DataType::Float32, true)), - ), + ScalarValue::new_list(None, DataType::Float32), ScalarValue::new_list( Some(vec![ ScalarValue::Float32(Some(-213.1)), @@ -498,14 +507,10 @@ mod roundtrip_tests { ScalarValue::Float32(Some(2.0)), ScalarValue::Float32(Some(1.0)), ]), - DataType::List(new_box_field("level2", DataType::Float32, true)), + DataType::Float32, ), ]), - DataType::List(new_box_field( - "level1", - DataType::List(new_box_field("level2", DataType::Float32, true)), - true, - )), + DataType::List(new_box_field("item", DataType::Float32, true)), ), ScalarValue::Dictionary( Box::new(DataType::Int32), @@ -576,125 +581,12 @@ mod roundtrip_tests { DataType::Utf8, DataType::LargeUtf8, // Recursive list tests - DataType::List(new_box_field("Level1", DataType::Boolean, true)), - DataType::List(new_box_field( - "Level1", - DataType::List(new_box_field("Level2", DataType::Date32, true)), - true, - )), - ]; - - let should_fail: Vec = vec![ - DataType::Null, - DataType::Float16, - // Add more timestamp tests - DataType::Timestamp(TimeUnit::Millisecond, None), - DataType::Date64, - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Microsecond), - DataType::Time32(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Second), - DataType::Time64(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Nanosecond), - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::DayTime), - DataType::Binary, - DataType::FixedSizeBinary(0), - DataType::FixedSizeBinary(1234), - DataType::FixedSizeBinary(-432), - DataType::LargeBinary, - DataType::Decimal128(123, 234), - // Recursive list tests - DataType::List(new_box_field("Level1", DataType::Binary, true)), + DataType::List(new_box_field("level1", DataType::Boolean, true)), DataType::List(new_box_field( "Level1", - DataType::List(new_box_field( - "Level2", - DataType::FixedSizeBinary(53), - false, - )), + DataType::List(new_box_field("level2", DataType::Date32, true)), true, )), - // Fixed size lists - DataType::FixedSizeList(new_box_field("Level1", DataType::Binary, true), 4), - DataType::FixedSizeList( - new_box_field( - "Level1", - DataType::List(new_box_field( - "Level2", - DataType::FixedSizeBinary(53), - false, - )), - true, - ), - 41, - ), - // Struct Testing - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]), - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - Field::new( - "nested_struct", - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]), - true, - ), - ]), - DataType::Union( - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ], - vec![0, 2, 3], - UnionMode::Dense, - ), - DataType::Union( - vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - Field::new( - "nested_struct", - DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ]), - true, - ), - ], - vec![1, 2, 3], - UnionMode::Sparse, - ), - DataType::Dictionary( - Box::new(DataType::Utf8), - Box::new(DataType::Struct(vec![ - Field::new("nullable", DataType::Boolean, false), - Field::new("name", DataType::Utf8, false), - Field::new("datatype", DataType::Binary, false), - ])), - ), - DataType::Dictionary( - Box::new(DataType::Decimal128(10, 50)), - Box::new(DataType::FixedSizeList( - new_box_field("Level1", DataType::Binary, true), - 4, - )), - ), ]; for test_case in should_pass.into_iter() { @@ -703,22 +595,6 @@ mod roundtrip_tests { let roundtrip: Field = (&proto).try_into().unwrap(); assert_eq!(format!("{:?}", field), format!("{:?}", roundtrip)); } - - let mut success: Vec = Vec::new(); - for test_case in should_fail.into_iter() { - let proto: std::result::Result< - super::protobuf::ScalarType, - super::to_proto::Error, - > = (&Field::new("item", test_case.clone(), true)).try_into(); - if proto.is_ok() { - success.push(test_case) - } - } - assert!( - success.is_empty(), - "These should have resulted in an error but completed successfully: {:?}", - success - ); } #[test] diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 47b779fffc74..519ace6eb32d 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -101,27 +101,6 @@ impl std::fmt::Display for Error { } } -impl Error { - fn inconsistent_list_typing(type1: &DataType, type2: &DataType) -> Self { - Self::InconsistentListTyping(type1.to_owned(), type2.to_owned()) - } - - fn inconsistent_list_designated(value: &ScalarValue, designated: &DataType) -> Self { - Self::InconsistentListDesignated { - value: value.to_owned(), - designated: designated.to_owned(), - } - } - - fn invalid_scalar_type(data_type: &DataType) -> Self { - Self::InvalidScalarType(data_type.to_owned()) - } - - fn invalid_time_unit(time_unit: &TimeUnit) -> Self { - Self::InvalidTimeUnit(time_unit.to_owned()) - } -} - impl TryFrom<&Field> for protobuf::Field { type Error = Error; @@ -980,115 +959,30 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { Value::LargeUtf8Value(s.to_owned()) }) } - scalar::ScalarValue::List(value, boxed_field) => match value { - Some(values) => { - if values.is_empty() { - protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - protobuf::ScalarListValue { - field: Some(boxed_field.as_ref().try_into()?), - values: Vec::new(), - }, - )), - } - } else { - let scalar_type = match boxed_field.data_type() { - DataType::List(field) => field.as_ref().data_type(), - unsupported => { - return Err(Error::General(format!("Proto serialization error: {:?} not supported to convert to DataType::List", unsupported))); - } - }; + scalar::ScalarValue::List(values, boxed_field) => { + let is_null = values.is_none(); - let type_checked_values: Vec = values - .iter() - .map(|scalar| match (scalar, scalar_type) { - ( - scalar::ScalarValue::List(_, list_type), - DataType::List(field), - ) => { - if let DataType::List(list_field) = - list_type.data_type() - { - let scalar_datatype = field.data_type(); - let list_datatype = list_field.data_type(); - if std::mem::discriminant(list_datatype) - != std::mem::discriminant(scalar_datatype) - { - return Err(Error::inconsistent_list_typing( - list_datatype, - scalar_datatype, - )); - } - scalar.try_into() - } else { - Err(Error::inconsistent_list_designated( - scalar, - boxed_field.data_type(), - )) - } - } - (scalar::ScalarValue::Boolean(_), DataType::Boolean) => { - scalar.try_into() - } - (scalar::ScalarValue::Float32(_), DataType::Float32) => { - scalar.try_into() - } - (scalar::ScalarValue::Float64(_), DataType::Float64) => { - scalar.try_into() - } - (scalar::ScalarValue::Int8(_), DataType::Int8) => { - scalar.try_into() - } - (scalar::ScalarValue::Int16(_), DataType::Int16) => { - scalar.try_into() - } - (scalar::ScalarValue::Int32(_), DataType::Int32) => { - scalar.try_into() - } - (scalar::ScalarValue::Int64(_), DataType::Int64) => { - scalar.try_into() - } - (scalar::ScalarValue::UInt8(_), DataType::UInt8) => { - scalar.try_into() - } - (scalar::ScalarValue::UInt16(_), DataType::UInt16) => { - scalar.try_into() - } - (scalar::ScalarValue::UInt32(_), DataType::UInt32) => { - scalar.try_into() - } - (scalar::ScalarValue::UInt64(_), DataType::UInt64) => { - scalar.try_into() - } - (scalar::ScalarValue::Utf8(_), DataType::Utf8) => { - scalar.try_into() - } - ( - scalar::ScalarValue::LargeUtf8(_), - DataType::LargeUtf8, - ) => scalar.try_into(), - _ => Err(Error::inconsistent_list_designated( - scalar, - boxed_field.data_type(), - )), - }) - .collect::, _>>()?; - protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::ListValue( - protobuf::ScalarListValue { - field: Some(boxed_field.as_ref().try_into()?), - values: type_checked_values, - }, - )), - } - } - } - None => protobuf::ScalarValue { - value: Some(protobuf::scalar_value::Value::NullListValue( - boxed_field.as_ref().try_into()?, + let values = if let Some(values) = values.as_ref() { + values + .iter() + .map(|v| v.try_into()) + .collect::, _>>()? + } else { + vec![] + }; + + let field = boxed_field.as_ref().try_into()?; + + protobuf::ScalarValue { + value: Some(protobuf::scalar_value::Value::ListValue( + protobuf::ScalarListValue { + is_null, + field: Some(field), + values, + }, )), - }, - }, + } + } datafusion::scalar::ScalarValue::Date32(val) => { create_proto_scalar(val, PrimitiveScalarType::Date32, |s| { Value::Date32Value(*s) @@ -1335,128 +1229,6 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { } } -impl TryFrom<&Field> for protobuf::ScalarType { - type Error = Error; - - fn try_from(value: &Field) -> Result { - let datatype = protobuf::scalar_type::Datatype::try_from(value.data_type())?; - Ok(Self { - datatype: Some(datatype), - }) - } -} - -impl TryFrom<&DataType> for protobuf::scalar_type::Datatype { - type Error = Error; - - fn try_from(val: &DataType) -> Result { - use protobuf::PrimitiveScalarType; - - let scalar_value = match val { - DataType::Boolean => Self::Scalar(PrimitiveScalarType::Bool as i32), - DataType::Int8 => Self::Scalar(PrimitiveScalarType::Int8 as i32), - DataType::Int16 => Self::Scalar(PrimitiveScalarType::Int16 as i32), - DataType::Int32 => Self::Scalar(PrimitiveScalarType::Int32 as i32), - DataType::Int64 => Self::Scalar(PrimitiveScalarType::Int64 as i32), - DataType::UInt8 => Self::Scalar(PrimitiveScalarType::Uint8 as i32), - DataType::UInt16 => Self::Scalar(PrimitiveScalarType::Uint16 as i32), - DataType::UInt32 => Self::Scalar(PrimitiveScalarType::Uint32 as i32), - DataType::UInt64 => Self::Scalar(PrimitiveScalarType::Uint64 as i32), - DataType::Float32 => Self::Scalar(PrimitiveScalarType::Float32 as i32), - DataType::Float64 => Self::Scalar(PrimitiveScalarType::Float64 as i32), - DataType::Date32 => Self::Scalar(PrimitiveScalarType::Date32 as i32), - DataType::Time64(time_unit) => match time_unit { - TimeUnit::Microsecond => { - Self::Scalar(PrimitiveScalarType::TimestampMicrosecond as i32) - } - TimeUnit::Nanosecond => { - Self::Scalar(PrimitiveScalarType::TimestampNanosecond as i32) - } - _ => { - return Err(Error::invalid_time_unit(time_unit)); - } - }, - DataType::Utf8 => Self::Scalar(PrimitiveScalarType::Utf8 as i32), - DataType::LargeUtf8 => Self::Scalar(PrimitiveScalarType::LargeUtf8 as i32), - DataType::List(field_type) => { - let mut field_names: Vec = Vec::new(); - let mut curr_field = field_type.as_ref(); - field_names.push(curr_field.name().to_owned()); - // For each nested field check nested datatype, since datafusion scalars only - // support recursive lists with a leaf scalar type - // any other compound types are errors. - - while let DataType::List(nested_field_type) = curr_field.data_type() { - curr_field = nested_field_type.as_ref(); - field_names.push(curr_field.name().to_owned()); - if !is_valid_scalar_type_no_list_check(curr_field.data_type()) { - return Err(Error::invalid_scalar_type(curr_field.data_type())); - } - } - let deepest_datatype = curr_field.data_type(); - if !is_valid_scalar_type_no_list_check(deepest_datatype) { - return Err(Error::invalid_scalar_type(deepest_datatype)); - } - let pb_deepest_type: PrimitiveScalarType = match deepest_datatype { - DataType::Boolean => PrimitiveScalarType::Bool, - DataType::Int8 => PrimitiveScalarType::Int8, - DataType::Int16 => PrimitiveScalarType::Int16, - DataType::Int32 => PrimitiveScalarType::Int32, - DataType::Int64 => PrimitiveScalarType::Int64, - DataType::UInt8 => PrimitiveScalarType::Uint8, - DataType::UInt16 => PrimitiveScalarType::Uint16, - DataType::UInt32 => PrimitiveScalarType::Uint32, - DataType::UInt64 => PrimitiveScalarType::Uint64, - DataType::Float32 => PrimitiveScalarType::Float32, - DataType::Float64 => PrimitiveScalarType::Float64, - DataType::Date32 => PrimitiveScalarType::Date32, - DataType::Time64(time_unit) => match time_unit { - TimeUnit::Microsecond => { - PrimitiveScalarType::TimestampMicrosecond - } - TimeUnit::Nanosecond => PrimitiveScalarType::TimestampNanosecond, - _ => { - return Err(Error::invalid_time_unit(time_unit)); - } - }, - - DataType::Utf8 => PrimitiveScalarType::Utf8, - DataType::LargeUtf8 => PrimitiveScalarType::LargeUtf8, - _ => { - return Err(Error::invalid_scalar_type(val)); - } - }; - Self::List(protobuf::ScalarListType { - field_names, - deepest_type: pb_deepest_type as i32, - }) - } - DataType::Null - | DataType::Float16 - | DataType::Timestamp(_, _) - | DataType::Date64 - | DataType::Time32(_) - | DataType::Duration(_) - | DataType::Interval(_) - | DataType::Binary - | DataType::FixedSizeBinary(_) - | DataType::LargeBinary - | DataType::FixedSizeList(_, _) - | DataType::LargeList(_) - | DataType::Struct(_) - | DataType::Union(_, _, _) - | DataType::Dictionary(_, _) - | DataType::Map(_, _) - | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => { - return Err(Error::invalid_scalar_type(val)); - } - }; - - Ok(scalar_value) - } -} - impl From<&TimeUnit> for protobuf::TimeUnit { fn from(val: &TimeUnit) -> Self { match val { @@ -1489,29 +1261,3 @@ fn create_proto_scalar protobuf::scalar_value::Value>( )), } } - -// Does not check if list subtypes are valid -fn is_valid_scalar_type_no_list_check(datatype: &DataType) -> bool { - match datatype { - DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 - | DataType::LargeUtf8 - | DataType::Utf8 - | DataType::Date32 => true, - DataType::Time64(time_unit) => { - matches!(time_unit, TimeUnit::Microsecond | TimeUnit::Nanosecond) - } - - DataType::List(_) => true, - _ => false, - } -}