diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index e16479110671d..5a71ab91db1a3 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -61,6 +61,7 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; +use datafusion::arrow::array::GenericListArray; use datafusion::common::plan_err; use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; @@ -1058,7 +1059,7 @@ pub async fn from_substrait_rex( } } -fn from_substrait_type(dt: &substrait::proto::Type) -> Result { +pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { match &dt.kind { Some(s_kind) => match s_kind { r#type::Kind::Bool(_) => Ok(DataType::Boolean), @@ -1138,7 +1139,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { from_substrait_type(list.r#type.as_ref().ok_or_else(|| { substrait_datafusion_err!("List type must have inner type") })?)?; - let field = Arc::new(Field::new("list_item", inner_type, true)); + let field = Arc::new(Field::new_list_field(inner_type, true)); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), @@ -1278,6 +1279,45 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { s, ) } + Some(LiteralType::List(l)) => { + let elements = l + .values + .iter() + .map(from_substrait_literal) + .collect::>>()?; + if elements.is_empty() { + return substrait_err!( + "Empty list must be encoded as EmptyList literal type, not List" + ); + } + let element_type = elements[0].data_type(); + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => ScalarValue::List(ScalarValue::new_list( + elements.as_slice(), + &element_type, + )), + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(elements.as_slice(), &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::EmptyList(l)) => { + let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?; + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => { + ScalarValue::List(ScalarValue::new_list(&[], &element_type)) + } + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), }; @@ -1361,7 +1401,24 @@ fn from_substrait_null(null_type: &Type) -> Result { d.precision as u8, d.scale as i8, )), - _ => not_impl_err!("Unsupported Substrait type: {kind:?}"), + r#type::Kind::List(l) => { + let field = Field::new_list_field( + from_substrait_type(l.r#type.clone().unwrap().as_ref())?, + true, + ); + match l.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::List(Arc::new( + GenericListArray::new_null(field.into(), 1), + ))), + LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null(field.into(), 1), + ))), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ), + } + } + _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), } } else { not_impl_err!("Null type without kind is not supported") diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6f0738c38df5d..bfdffdc3a260f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -30,6 +30,7 @@ use datafusion::{ scalar::ScalarValue, }; +use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{exec_err, internal_err, not_impl_err}; use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] @@ -42,6 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::List; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::{CrossRel, ExchangeRel}; @@ -1100,7 +1102,7 @@ pub fn to_substrait_rex( ))), }) } - Expr::Literal(value) => to_substrait_literal(value), + Expr::Literal(value) => to_substrait_literal_expr(value), Expr::Alias(Alias { expr, .. }) => { to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } @@ -1526,8 +1528,9 @@ fn make_substrait_like_expr( }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = - to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; + let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( + escape_char.map(|c| c.to_string()), + ))?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1683,7 +1686,7 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal(value: &ScalarValue) -> Result { let (literal_type, type_variation_reference) = match value { ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), @@ -1741,15 +1744,50 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }), DECIMAL_128_TYPE_REF, ), + ScalarValue::List(l) if !value.is_null() => ( + convert_array_to_literal_list(l)?, + DEFAULT_CONTAINER_TYPE_REF, + ), + ScalarValue::LargeList(l) if !value.is_null() => { + (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF) + } _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), }; + Ok(Literal { + nullable: true, + type_variation_reference, + literal_type: Some(literal_type), + }) +} + +fn convert_array_to_literal_list( + array: &GenericListArray, +) -> Result { + assert_eq!(array.len(), 1); + let nested_array = array.value(0); + + let values = (0..nested_array.len()) + .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) + .collect::>>()?; + + if values.is_empty() { + let et = match to_substrait_type(array.data_type())? { + substrait::proto::Type { + kind: Some(r#type::Kind::List(lt)), + } => lt.as_ref().to_owned(), + _ => unreachable!(), + }; + Ok(LiteralType::EmptyList(et)) + } else { + Ok(LiteralType::List(List { values })) + } +} + +fn to_substrait_literal_expr(value: &ScalarValue) -> Result { + let literal = to_substrait_literal(value)?; Ok(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: true, - type_variation_reference, - literal_type: Some(literal_type), - })), + rex_type: Some(RexType::Literal(literal)), }) } @@ -1937,6 +1975,10 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { })), })) } + ScalarValue::List(l) => Ok(LiteralType::Null(to_substrait_type(l.data_type())?)), + ScalarValue::LargeList(l) => { + Ok(LiteralType::Null(to_substrait_type(l.data_type())?)) + } // TODO: Extend support for remaining data types _ => not_impl_err!("Unsupported literal: {v:?}"), } @@ -2016,7 +2058,9 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { - use crate::logical_plan::consumer::from_substrait_literal; + use crate::logical_plan::consumer::{from_substrait_literal, from_substrait_type}; + use datafusion::arrow::array::GenericListArray; + use datafusion::arrow::datatypes::Field; use super::*; @@ -2054,22 +2098,87 @@ mod test { round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ))))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ), + )))?; + Ok(()) } fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait = to_substrait_literal(&scalar)?; - let Expression { - rex_type: Some(RexType::Literal(substrait_literal)), - } = substrait - else { - panic!("Expected Literal expression, got {substrait:?}"); - }; - + let substrait_literal = to_substrait_literal(&scalar)?; let roundtrip_scalar = from_substrait_literal(&substrait_literal)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) } + + #[test] + fn round_trip_types() -> Result<()> { + round_trip_type(DataType::Boolean)?; + round_trip_type(DataType::Int8)?; + round_trip_type(DataType::UInt8)?; + round_trip_type(DataType::Int16)?; + round_trip_type(DataType::UInt16)?; + round_trip_type(DataType::Int32)?; + round_trip_type(DataType::UInt32)?; + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float32)?; + round_trip_type(DataType::Float64)?; + round_trip_type(DataType::Timestamp(TimeUnit::Second, None))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, None))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, None))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, None))?; + round_trip_type(DataType::Date32)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Binary)?; + round_trip_type(DataType::FixedSizeBinary(10))?; + round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::LargeUtf8)?; + round_trip_type(DataType::Decimal128(10, 2))?; + round_trip_type(DataType::Decimal256(30, 2))?; + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + + Ok(()) + } + + fn round_trip_type(dt: DataType) -> Result<()> { + println!("Checking round trip of {dt:?}"); + + let substrait = to_substrait_type(&dt)?; + let roundtrip_dt = from_substrait_type(&substrait)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 4c7dc87145852..02371063ef131 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -665,6 +665,16 @@ async fn all_type_literal() -> Result<()> { .await } +#[tokio::test] +async fn roundtrip_literal_list() -> Result<()> { + assert_expected_plan( + "SELECT [[1,2,3], [], NULL, [NULL]] FROM data", + "Projection: List([[1, 2, 3], [], , []])\ + \n TableScan: data projection=[]", + ) + .await +} + /// Construct a plan that cast columns. Only those SQL types are supported for now. #[tokio::test] async fn new_test_grammar() -> Result<()> {