From 48565f7cc1e6b87c93face137d79bf21dcdc3e85 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 22 May 2024 12:26:41 +0200 Subject: [PATCH 1/6] Add support for Substrait List/EmptyList literals Adds support for converting from DataFusion List/LargeList ScalarValues into Substrait List/EmptyList Literals and back --- .../substrait/src/logical_plan/consumer.rs | 57 +++++++++++++- .../substrait/src/logical_plan/producer.rs | 74 ++++++++++++++++--- .../tests/cases/roundtrip_logical_plan.rs | 12 +++ 3 files changed, 132 insertions(+), 11 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index e16479110671d..9620c90c6f04d 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -61,6 +61,7 @@ use substrait::proto::{ }; use substrait::proto::{FunctionArgument, SortField}; +use datafusion::arrow::array::GenericListArray; use datafusion::common::plan_err; use datafusion::logical_expr::expr::{InList, InSubquery, Sort}; use std::collections::HashMap; @@ -1138,7 +1139,7 @@ fn from_substrait_type(dt: &substrait::proto::Type) -> Result { from_substrait_type(list.r#type.as_ref().ok_or_else(|| { substrait_datafusion_err!("List type must have inner type") })?)?; - let field = Arc::new(Field::new("list_item", inner_type, true)); + let field = Arc::new(Field::new_list_field(inner_type, true)); match list.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => Ok(DataType::List(field)), LARGE_CONTAINER_TYPE_REF => Ok(DataType::LargeList(field)), @@ -1278,6 +1279,40 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { s, ) } + Some(LiteralType::List(l)) => { + let elements = l + .values + .iter() + .map(|el| from_substrait_literal(el)) + .collect::>>()?; + let element_type = elements[0].data_type(); + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => ScalarValue::List(ScalarValue::new_list( + elements.as_slice(), + &element_type, + )), + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(elements.as_slice(), &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } + Some(LiteralType::EmptyList(l)) => { + let element_type = from_substrait_type(l.r#type.clone().unwrap().as_ref())?; + match lit.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => { + ScalarValue::List(ScalarValue::new_list(&[], &element_type)) + } + LARGE_CONTAINER_TYPE_REF => ScalarValue::LargeList( + ScalarValue::new_large_list(&[], &element_type), + ), + others => { + return substrait_err!("Unknown type variation reference {others}"); + } + } + } Some(LiteralType::Null(ntype)) => from_substrait_null(ntype)?, _ => return not_impl_err!("Unsupported literal_type: {:?}", lit.literal_type), }; @@ -1361,7 +1396,25 @@ fn from_substrait_null(null_type: &Type) -> Result { d.precision as u8, d.scale as i8, )), - _ => not_impl_err!("Unsupported Substrait type: {kind:?}"), + r#type::Kind::List(l) => { + // let field = Field::new_list_field( + // from_substrait_type(l.r#type.clone().unwrap().as_ref())?, + // true, + // ); + // Ok(ScalarValue::List(Arc::new(GenericListArray::new_null( + // field.into(), + // 1, + // )))) + let field = Field::new_list_field( + from_substrait_type(l.r#type.clone().unwrap().as_ref())?, + true, + ); + Ok(ScalarValue::List(Arc::new(GenericListArray::new_null( + field.into(), + 1, + )))) + } + _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), } } else { not_impl_err!("Null type without kind is not supported") diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 6f0738c38df5d..31a979b4b30ee 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -30,6 +30,7 @@ use datafusion::{ scalar::ScalarValue, }; +use datafusion::arrow::array::{Array, AsArray, OffsetSizeTrait}; use datafusion::common::{exec_err, internal_err, not_impl_err}; use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] @@ -42,6 +43,7 @@ use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Opera use datafusion::prelude::Expr; use prost_types::Any as ProtoAny; use substrait::proto::exchange_rel::{ExchangeKind, RoundRobin, ScatterFields}; +use substrait::proto::expression::literal::List; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; use substrait::proto::{CrossRel, ExchangeRel}; @@ -1100,7 +1102,7 @@ pub fn to_substrait_rex( ))), }) } - Expr::Literal(value) => to_substrait_literal(value), + Expr::Literal(value) => to_substrait_literal_expr(value), Expr::Alias(Alias { expr, .. }) => { to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info) } @@ -1526,8 +1528,9 @@ fn make_substrait_like_expr( }; let expr = to_substrait_rex(ctx, expr, schema, col_ref_offset, extension_info)?; let pattern = to_substrait_rex(ctx, pattern, schema, col_ref_offset, extension_info)?; - let escape_char = - to_substrait_literal(&ScalarValue::Utf8(escape_char.map(|c| c.to_string())))?; + let escape_char = to_substrait_literal_expr(&ScalarValue::Utf8( + escape_char.map(|c| c.to_string()), + ))?; let arguments = vec![ FunctionArgument { arg_type: Some(ArgType::Value(expr)), @@ -1683,7 +1686,7 @@ fn to_substrait_bounds(window_frame: &WindowFrame) -> Result<(Bound, Bound)> { )) } -fn to_substrait_literal(value: &ScalarValue) -> Result { +fn to_substrait_literal(value: &ScalarValue) -> Result { let (literal_type, type_variation_reference) = match value { ScalarValue::Boolean(Some(b)) => (LiteralType::Boolean(*b), DEFAULT_TYPE_REF), ScalarValue::Int8(Some(n)) => (LiteralType::I8(*n as i32), DEFAULT_TYPE_REF), @@ -1741,15 +1744,64 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }), DECIMAL_128_TYPE_REF, ), + ScalarValue::List(l) if !value.is_null() => { + let values = + convert_array_to_literal_vec::(&(l.to_owned() as Arc))?; + let list = if values.is_empty() { + LiteralType::EmptyList(match to_substrait_type(&l.data_type())? { + substrait::proto::Type { + kind: Some(r#type::Kind::List(l)), + } => l.as_ref().to_owned(), + _ => unreachable!(), + }) + } else { + LiteralType::List(List { values }) + }; + (list, DEFAULT_CONTAINER_TYPE_REF) + } + ScalarValue::LargeList(l) if !value.is_null() => { + let values = + convert_array_to_literal_vec::(&(l.to_owned() as Arc))?; + let list = if values.is_empty() { + LiteralType::EmptyList(match to_substrait_type(&l.data_type())? { + substrait::proto::Type { + kind: Some(r#type::Kind::List(l)), + } => l.as_ref().to_owned(), + _ => unreachable!(), + }) + } else { + LiteralType::List(List { values }) + }; + (list, LARGE_CONTAINER_TYPE_REF) + } _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), }; + Ok(Literal { + nullable: true, + type_variation_reference, + literal_type: Some(literal_type), + }) +} + +fn convert_array_to_literal_vec( + array: &dyn Array, +) -> Result> { + // Adapted from ScalarValue::convert_array_to_scalar_vec to support both i32 and i64 + assert_eq!(array.len(), 1); + let nested_array = array.as_list::().value(0); + + let scalars = (0..nested_array.len()) + .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) + .collect::>>()?; + + Ok(scalars) +} + +fn to_substrait_literal_expr(value: &ScalarValue) -> Result { + let literal = to_substrait_literal(value)?; Ok(Expression { - rex_type: Some(RexType::Literal(Literal { - nullable: true, - type_variation_reference, - literal_type: Some(literal_type), - })), + rex_type: Some(RexType::Literal(literal)), }) } @@ -1937,6 +1989,10 @@ fn try_to_substrait_null(v: &ScalarValue) -> Result { })), })) } + ScalarValue::List(l) => Ok(LiteralType::Null(to_substrait_type(l.data_type())?)), + ScalarValue::LargeList(l) => { + Ok(LiteralType::Null(to_substrait_type(l.data_type())?)) + } // TODO: Extend support for remaining data types _ => not_impl_err!("Unsupported literal: {v:?}"), } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 4c7dc87145852..255208f592f14 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -665,6 +665,16 @@ async fn all_type_literal() -> Result<()> { .await } +#[tokio::test] +async fn roundtrip_literal_list() -> Result<()> { + assert_expected_plan( + "SELECT [[1,2,3], [], NULL, [NULL]] FROM data", + "Projection: List([[1, 2, 3], [], , []])\ + \n TableScan: data projection=[]", + ) + .await +} + /// Construct a plan that cast columns. Only those SQL types are supported for now. #[tokio::test] async fn new_test_grammar() -> Result<()> { @@ -885,10 +895,12 @@ async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; + println!("{proto:#?}"); let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; let plan2str = format!("{plan2:?}"); assert_eq!(expected_plan_str, &plan2str); + ctx.execute_logical_plan(plan2).await?.show().await?; Ok(()) } From c71e5771c396083cb5a34e398d780af6adbd04bc Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 22 May 2024 12:31:56 +0200 Subject: [PATCH 2/6] cleanup --- datafusion/substrait/src/logical_plan/consumer.rs | 8 -------- .../substrait/tests/cases/roundtrip_logical_plan.rs | 2 -- 2 files changed, 10 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 9620c90c6f04d..713fd536ae4e7 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1397,14 +1397,6 @@ fn from_substrait_null(null_type: &Type) -> Result { d.scale as i8, )), r#type::Kind::List(l) => { - // let field = Field::new_list_field( - // from_substrait_type(l.r#type.clone().unwrap().as_ref())?, - // true, - // ); - // Ok(ScalarValue::List(Arc::new(GenericListArray::new_null( - // field.into(), - // 1, - // )))) let field = Field::new_list_field( from_substrait_type(l.r#type.clone().unwrap().as_ref())?, true, diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index 255208f592f14..02371063ef131 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -895,12 +895,10 @@ async fn assert_expected_plan(sql: &str, expected_plan_str: &str) -> Result<()> let df = ctx.sql(sql).await?; let plan = df.into_optimized_plan()?; let proto = to_substrait_plan(&plan, &ctx)?; - println!("{proto:#?}"); let plan2 = from_substrait_plan(&ctx, &proto).await?; let plan2 = ctx.state().optimize(&plan2)?; let plan2str = format!("{plan2:?}"); assert_eq!(expected_plan_str, &plan2str); - ctx.execute_logical_plan(plan2).await?.show().await?; Ok(()) } From 799d672c23b39b37533197105c744cb2fe68e7c9 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 22 May 2024 14:23:07 +0200 Subject: [PATCH 3/6] fix test, add literal roundtrip tests for lists, and fix creating null large lists --- .../substrait/src/logical_plan/consumer.rs | 15 ++- .../substrait/src/logical_plan/producer.rs | 96 ++++++++++--------- 2 files changed, 63 insertions(+), 48 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 713fd536ae4e7..d8ecb6da5d6bc 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1401,10 +1401,17 @@ fn from_substrait_null(null_type: &Type) -> Result { from_substrait_type(l.r#type.clone().unwrap().as_ref())?, true, ); - Ok(ScalarValue::List(Arc::new(GenericListArray::new_null( - field.into(), - 1, - )))) + match l.type_variation_reference { + DEFAULT_CONTAINER_TYPE_REF => Ok(ScalarValue::List(Arc::new( + GenericListArray::new_null(field.into(), 1), + ))), + LARGE_CONTAINER_TYPE_REF => Ok(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null(field.into(), 1), + ))), + v => not_impl_err!( + "Unsupported Substrait type variation {v} of type {kind:?}" + ), + } } _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 31a979b4b30ee..eec4ef2e7f151 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -30,7 +30,7 @@ use datafusion::{ scalar::ScalarValue, }; -use datafusion::arrow::array::{Array, AsArray, OffsetSizeTrait}; +use datafusion::arrow::array::{Array, GenericListArray, OffsetSizeTrait}; use datafusion::common::{exec_err, internal_err, not_impl_err}; use datafusion::common::{substrait_err, DFSchemaRef}; #[allow(unused_imports)] @@ -1744,35 +1744,12 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }), DECIMAL_128_TYPE_REF, ), - ScalarValue::List(l) if !value.is_null() => { - let values = - convert_array_to_literal_vec::(&(l.to_owned() as Arc))?; - let list = if values.is_empty() { - LiteralType::EmptyList(match to_substrait_type(&l.data_type())? { - substrait::proto::Type { - kind: Some(r#type::Kind::List(l)), - } => l.as_ref().to_owned(), - _ => unreachable!(), - }) - } else { - LiteralType::List(List { values }) - }; - (list, DEFAULT_CONTAINER_TYPE_REF) - } + ScalarValue::List(l) if !value.is_null() => ( + convert_array_to_literal_list(l)?, + DEFAULT_CONTAINER_TYPE_REF, + ), ScalarValue::LargeList(l) if !value.is_null() => { - let values = - convert_array_to_literal_vec::(&(l.to_owned() as Arc))?; - let list = if values.is_empty() { - LiteralType::EmptyList(match to_substrait_type(&l.data_type())? { - substrait::proto::Type { - kind: Some(r#type::Kind::List(l)), - } => l.as_ref().to_owned(), - _ => unreachable!(), - }) - } else { - LiteralType::List(List { values }) - }; - (list, LARGE_CONTAINER_TYPE_REF) + (convert_array_to_literal_list(l)?, LARGE_CONTAINER_TYPE_REF) } _ => (try_to_substrait_null(value)?, DEFAULT_TYPE_REF), }; @@ -1784,18 +1761,27 @@ fn to_substrait_literal(value: &ScalarValue) -> Result { }) } -fn convert_array_to_literal_vec( - array: &dyn Array, -) -> Result> { - // Adapted from ScalarValue::convert_array_to_scalar_vec to support both i32 and i64 +fn convert_array_to_literal_list( + array: &GenericListArray, +) -> Result { assert_eq!(array.len(), 1); - let nested_array = array.as_list::().value(0); + let nested_array = array.value(0); - let scalars = (0..nested_array.len()) + let values = (0..nested_array.len()) .map(|i| to_substrait_literal(&ScalarValue::try_from_array(&nested_array, i)?)) .collect::>>()?; - Ok(scalars) + if values.is_empty() { + let et = match to_substrait_type(&array.data_type())? { + substrait::proto::Type { + kind: Some(r#type::Kind::List(lt)), + } => lt.as_ref().to_owned(), + _ => unreachable!(), + }; + Ok(LiteralType::EmptyList(et)) + } else { + Ok(LiteralType::List(List { values })) + } } fn to_substrait_literal_expr(value: &ScalarValue) -> Result { @@ -2073,6 +2059,8 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { use crate::logical_plan::consumer::from_substrait_literal; + use datafusion::arrow::array::GenericListArray; + use datafusion::arrow::datatypes::Field; use super::*; @@ -2110,20 +2098,40 @@ mod test { round_trip_literal(ScalarValue::UInt64(Some(u64::MIN)))?; round_trip_literal(ScalarValue::UInt64(Some(u64::MAX)))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(ScalarValue::new_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::List(Arc::new(GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ))))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[ScalarValue::Float32(Some(1.0))], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(ScalarValue::new_large_list( + &[], + &DataType::Float32, + )))?; + round_trip_literal(ScalarValue::LargeList(Arc::new( + GenericListArray::new_null( + Field::new_list_field(DataType::Float32, true).into(), + 1, + ), + )))?; + Ok(()) } fn round_trip_literal(scalar: ScalarValue) -> Result<()> { println!("Checking round trip of {scalar:?}"); - let substrait = to_substrait_literal(&scalar)?; - let Expression { - rex_type: Some(RexType::Literal(substrait_literal)), - } = substrait - else { - panic!("Expected Literal expression, got {substrait:?}"); - }; - + let substrait_literal = to_substrait_literal(&scalar)?; let roundtrip_scalar = from_substrait_literal(&substrait_literal)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) From 047d64349da81586dd31e01d1fa2b4628d147d10 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 22 May 2024 16:33:39 +0200 Subject: [PATCH 4/6] add unit testing for type roundtrips --- .../substrait/src/logical_plan/consumer.rs | 2 +- .../substrait/src/logical_plan/producer.rs | 47 ++++++++++++++++++- 2 files changed, 47 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index d8ecb6da5d6bc..f3bbb950f4211 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1059,7 +1059,7 @@ pub async fn from_substrait_rex( } } -fn from_substrait_type(dt: &substrait::proto::Type) -> Result { +pub(crate) fn from_substrait_type(dt: &substrait::proto::Type) -> Result { match &dt.kind { Some(s_kind) => match s_kind { r#type::Kind::Bool(_) => Ok(DataType::Boolean), diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index eec4ef2e7f151..d9a89cd668134 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2058,7 +2058,7 @@ fn substrait_field_ref(index: usize) -> Result { #[cfg(test)] mod test { - use crate::logical_plan::consumer::from_substrait_literal; + use crate::logical_plan::consumer::{from_substrait_literal, from_substrait_type}; use datafusion::arrow::array::GenericListArray; use datafusion::arrow::datatypes::Field; @@ -2136,4 +2136,49 @@ mod test { assert_eq!(scalar, roundtrip_scalar); Ok(()) } + + #[test] + fn round_trip_types() -> Result<()> { + round_trip_type(DataType::Boolean)?; + round_trip_type(DataType::Int8)?; + round_trip_type(DataType::UInt8)?; + round_trip_type(DataType::Int16)?; + round_trip_type(DataType::UInt16)?; + round_trip_type(DataType::Int32)?; + round_trip_type(DataType::UInt32)?; + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float32)?; + round_trip_type(DataType::Float64)?; + round_trip_type(DataType::Timestamp(TimeUnit::Second, None))?; + round_trip_type(DataType::Timestamp(TimeUnit::Millisecond, None))?; + round_trip_type(DataType::Timestamp(TimeUnit::Microsecond, None))?; + round_trip_type(DataType::Timestamp(TimeUnit::Nanosecond, None))?; + round_trip_type(DataType::Date32)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Binary)?; + round_trip_type(DataType::FixedSizeBinary(10))?; + round_trip_type(DataType::LargeBinary)?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::LargeUtf8)?; + round_trip_type(DataType::Decimal128(10, 2))?; + round_trip_type(DataType::Decimal256(30, 2))?; + round_trip_type(DataType::List( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + round_trip_type(DataType::LargeList( + Field::new_list_field(DataType::Int32, true).into(), + ))?; + + Ok(()) + } + + fn round_trip_type(dt: DataType) -> Result<()> { + println!("Checking round trip of {dt:?}"); + + let substrait = to_substrait_type(&dt)?; + let roundtrip_dt = from_substrait_type(&substrait)?; + assert_eq!(dt, roundtrip_dt); + Ok(()) + } } From 5e185a6e606c491499700d21ad6c48b2d24ce3e1 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 22 May 2024 16:38:56 +0200 Subject: [PATCH 5/6] fix clippy --- datafusion/substrait/src/logical_plan/consumer.rs | 2 +- datafusion/substrait/src/logical_plan/producer.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index f3bbb950f4211..75b8c6c14de5d 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1283,7 +1283,7 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { let elements = l .values .iter() - .map(|el| from_substrait_literal(el)) + .map(from_substrait_literal) .collect::>>()?; let element_type = elements[0].data_type(); match lit.type_variation_reference { diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index d9a89cd668134..bfdffdc3a260f 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -1772,7 +1772,7 @@ fn convert_array_to_literal_list( .collect::>>()?; if values.is_empty() { - let et = match to_substrait_type(&array.data_type())? { + let et = match to_substrait_type(array.data_type())? { substrait::proto::Type { kind: Some(r#type::Kind::List(lt)), } => lt.as_ref().to_owned(), From cdc525cef141e87c74521c146c54e69863b35084 Mon Sep 17 00:00:00 2001 From: Arttu Voutilainen Date: Wed, 22 May 2024 16:42:18 +0200 Subject: [PATCH 6/6] better error if a substrait literal list is empty --- datafusion/substrait/src/logical_plan/consumer.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index 75b8c6c14de5d..5a71ab91db1a3 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -1285,6 +1285,11 @@ pub(crate) fn from_substrait_literal(lit: &Literal) -> Result { .iter() .map(from_substrait_literal) .collect::>>()?; + if elements.is_empty() { + return substrait_err!( + "Empty list must be encoded as EmptyList literal type, not List" + ); + } let element_type = elements[0].data_type(); match lit.type_variation_reference { DEFAULT_CONTAINER_TYPE_REF => ScalarValue::List(ScalarValue::new_list(