From 8301e301e872941d1fa74ec7d727a3c35818ad6b Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 18:42:42 -0500 Subject: [PATCH 01/16] feat: replace namedstruct with ScalarUDF --- datafusion/core/src/physical_planner.rs | 13 +- datafusion/optimizer/Cargo.toml | 1 + .../optimizer/src/analyzer/rewrite_expr.rs | 16 +- .../src/expressions/get_indexed_field.rs | 490 +++++++++--------- .../physical-expr/src/expressions/mod.rs | 1 - datafusion/physical-expr/src/planner.rs | 17 +- datafusion/proto/proto/datafusion.proto | 17 +- datafusion/proto/src/generated/pbjson.rs | 219 -------- datafusion/proto/src/generated/prost.rs | 31 +- .../proto/src/physical_plan/from_proto.rs | 22 - .../proto/src/physical_plan/to_proto.rs | 20 - .../tests/cases/roundtrip_physical_plan.rs | 35 +- 12 files changed, 278 insertions(+), 604 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 6d49287debb44..3d6b8054c83f2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -207,10 +207,13 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { let expr = create_physical_name(expr, false)?; Ok(format!("{expr} IS NOT UNKNOWN")) } - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let expr = create_physical_name(expr, false)?; - let name = match field { - GetFieldAccess::NamedStructField { name } => format!("{expr}[{name}]"), + Expr::GetIndexedField(GetIndexedField { expr: _, field }) => { + match field { + GetFieldAccess::NamedStructField { name: _ } => { + unreachable!( + "NamedStructField should have been rewritten in OperatorToFunction" + ) + } GetFieldAccess::ListIndex { key: _ } => { unreachable!( "ListIndex should have been rewritten in OperatorToFunction" @@ -226,8 +229,6 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { ) } }; - - Ok(name) } Expr::ScalarFunction(fun) => { // function should be resolved during `AnalyzerRule`s diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index f497f2ec86027..3034f28f8688b 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,6 +45,7 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } datafusion-functions-array = { workspace = true, optional = true } datafusion-physical-expr = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] } diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 99578e91183c0..04c8622cd33d3 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -30,11 +30,11 @@ use datafusion_common::{DFSchema, Result}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::rewrite_preserving_name; use datafusion_expr::utils::merge_schema; -use datafusion_expr::BuiltinScalarFunction; use datafusion_expr::GetFieldAccess; use datafusion_expr::GetIndexedField; #[cfg(feature = "array_expressions")] use datafusion_expr::{BinaryExpr, Operator, ScalarFunctionDefinition}; +use datafusion_expr::{BuiltinScalarFunction, ScalarUDF}; use datafusion_expr::{Expr, LogicalPlan}; #[cfg(feature = "array_expressions")] use datafusion_functions_array::expr_fn::{array_append, array_concat, array_prepend}; @@ -137,6 +137,19 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { }) = expr { match field { + GetFieldAccess::NamedStructField { name, .. } => { + let expr = *expr.clone(); + let name = name.clone(); + let args = vec![expr, Expr::Literal(name)]; + return Ok(Transformed::yes(Expr::ScalarFunction( + ScalarFunction::new_udf( + Arc::new(ScalarUDF::new_from_impl( + datafusion_functions::core::r#struct::StructFunc::new(), + )), + args, + ), + ))); + } GetFieldAccess::ListIndex { ref key } => { let expr = *expr.clone(); let key = *key.clone(); @@ -159,7 +172,6 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { ScalarFunction::new(BuiltinScalarFunction::ArraySlice, args), ))); } - _ => {} } } diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs index 99b2279ba572f..6edfd41aeb45d 100644 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ b/datafusion/physical-expr/src/expressions/get_indexed_field.rs @@ -1,245 +1,245 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! get field of a `ListArray` - -use crate::PhysicalExpr; -use datafusion_common::exec_err; - -use crate::physical_expr::down_cast_any_ref; -use arrow::{ - array::{Array, Scalar, StringArray}, - datatypes::{DataType, Schema}, - record_batch::RecordBatch, -}; -use datafusion_common::{ - cast::{as_map_array, as_struct_array}, - Result, ScalarValue, -}; -use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; -use std::fmt::Debug; -use std::hash::{Hash, Hasher}; -use std::{any::Any, sync::Arc}; - -/// Access a sub field of a nested type, such as `Field` or `List` -#[derive(Clone, Hash, Debug)] -pub enum GetFieldAccessExpr { - /// Named field, For example `struct["name"]` - NamedStructField { name: ScalarValue }, -} - -impl std::fmt::Display for GetFieldAccessExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - match self { - GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), - } - } -} - -impl PartialEq for GetFieldAccessExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| match (self, x) { - ( - GetFieldAccessExpr::NamedStructField { name: lhs }, - GetFieldAccessExpr::NamedStructField { name: rhs }, - ) => lhs.eq(rhs), - }) - .unwrap_or(false) - } -} - -/// Expression to get a field of a struct array. -#[derive(Debug, Hash)] -pub struct GetIndexedFieldExpr { - /// The expression to find - arg: Arc, - /// The key statement - field: GetFieldAccessExpr, -} - -impl GetIndexedFieldExpr { - /// Create new [`GetIndexedFieldExpr`] - pub fn new(arg: Arc, field: GetFieldAccessExpr) -> Self { - Self { arg, field } - } - - /// Create a new [`GetIndexedFieldExpr`] for accessing the named field - pub fn new_field(arg: Arc, name: impl Into) -> Self { - Self::new( - arg, - GetFieldAccessExpr::NamedStructField { - name: ScalarValue::from(name.into()), - }, - ) - } - - /// Get the description of what field should be accessed - pub fn field(&self) -> &GetFieldAccessExpr { - &self.field - } - - /// Get the input expression - pub fn arg(&self) -> &Arc { - &self.arg - } - - fn schema_access(&self, _input_schema: &Schema) -> Result { - Ok(match &self.field { - GetFieldAccessExpr::NamedStructField { name } => { - GetFieldAccessSchema::NamedStructField { name: name.clone() } - } - }) - } -} - -impl std::fmt::Display for GetIndexedFieldExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "({}).{}", self.arg, self.field) - } -} - -impl PhysicalExpr for GetIndexedFieldExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, input_schema: &Schema) -> Result { - let arg_dt = self.arg.data_type(input_schema)?; - self.schema_access(input_schema)? - .get_accessed_field(&arg_dt) - .map(|f| f.data_type().clone()) - } - - fn nullable(&self, input_schema: &Schema) -> Result { - let arg_dt = self.arg.data_type(input_schema)?; - self.schema_access(input_schema)? - .get_accessed_field(&arg_dt) - .map(|f| f.is_nullable()) - } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; - match &self.field { - GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { - (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { - let map_array = as_map_array(array.as_ref())?; - let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); - let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; - let entries = arrow::compute::filter(map_array.entries(), &keys)?; - let entries_struct_array = as_struct_array(entries.as_ref())?; - Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) - } - (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { - let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(k) { - None => exec_err!( - "get indexed field {k} not found in struct"), - Some(col) => Ok(ColumnarValue::Array(col.clone())) - } - } - (DataType::Struct(_), name) => exec_err!( - "get indexed field is only possible on struct with utf8 indexes. \ - Tried with {name:?} index"), - (dt, name) => exec_err!( - "get indexed field is only possible on lists with int64 indexes or struct \ - with utf8 indexes. Tried {dt:?} with {name:?} index"), - }, - } - } - - fn children(&self) -> Vec> { - vec![self.arg.clone()] - } - - fn with_new_children( - self: Arc, - children: Vec>, - ) -> Result> { - Ok(Arc::new(GetIndexedFieldExpr::new( - children[0].clone(), - self.field.clone(), - ))) - } - - fn dyn_hash(&self, state: &mut dyn Hasher) { - let mut s = state; - self.hash(&mut s); - } -} - -impl PartialEq for GetIndexedFieldExpr { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| self.arg.eq(&x.arg) && self.field.eq(&x.field)) - .unwrap_or(false) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::expressions::col; - use arrow::array::ArrayRef; - use arrow::array::{BooleanArray, Int64Array, StructArray}; - use arrow::datatypes::Field; - use arrow::datatypes::Fields; - use datafusion_common::cast::as_boolean_array; - use datafusion_common::Result; - - #[test] - fn get_indexed_field_named_struct_field() -> Result<()> { - let schema = struct_schema(); - let boolean = BooleanArray::from(vec![false, false, true, true]); - let int = Int64Array::from(vec![42, 28, 19, 31]); - let struct_array = StructArray::from(vec![ - ( - Arc::new(Field::new("a", DataType::Boolean, true)), - Arc::new(boolean.clone()) as ArrayRef, - ), - ( - Arc::new(Field::new("b", DataType::Int64, true)), - Arc::new(int) as ArrayRef, - ), - ]); - let expr = col("str", &schema).unwrap(); - // only one row should be processed - let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; - let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); - let result = expr - .evaluate(&batch)? - .into_array(1) - .expect("Failed to convert to array"); - let result = - as_boolean_array(&result).expect("failed to downcast to BooleanArray"); - assert_eq!(boolean, result.clone()); - Ok(()) - } - - fn struct_schema() -> Schema { - Schema::new(vec![Field::new_struct( - "str", - Fields::from(vec![ - Field::new("a", DataType::Boolean, true), - Field::new("b", DataType::Int64, true), - ]), - true, - )]) - } -} +// // Licensed to the Apache Software Foundation (ASF) under one +// // or more contributor license agreements. See the NOTICE file +// // distributed with this work for additional information +// // regarding copyright ownership. The ASF licenses this file +// // to you under the Apache License, Version 2.0 (the +// // "License"); you may not use this file except in compliance +// // with the License. You may obtain a copy of the License at +// // +// // http://www.apache.org/licenses/LICENSE-2.0 +// // +// // Unless required by applicable law or agreed to in writing, +// // software distributed under the License is distributed on an +// // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// // KIND, either express or implied. See the License for the +// // specific language governing permissions and limitations +// // under the License. + +// //! get field of a `ListArray` + +// use crate::PhysicalExpr; +// use datafusion_common::exec_err; + +// use crate::physical_expr::down_cast_any_ref; +// use arrow::{ +// array::{Array, Scalar, StringArray}, +// datatypes::{DataType, Schema}, +// record_batch::RecordBatch, +// }; +// use datafusion_common::{ +// cast::{as_map_array, as_struct_array}, +// Result, ScalarValue, +// }; +// use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; +// use std::fmt::Debug; +// use std::hash::{Hash, Hasher}; +// use std::{any::Any, sync::Arc}; + +// /// Access a sub field of a nested type, such as `Field` or `List` +// #[derive(Clone, Hash, Debug)] +// pub enum GetFieldAccessExpr { +// /// Named field, For example `struct["name"]` +// NamedStructField { name: ScalarValue }, +// } + +// impl std::fmt::Display for GetFieldAccessExpr { +// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +// match self { +// GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), +// } +// } +// } + +// impl PartialEq for GetFieldAccessExpr { +// fn eq(&self, other: &dyn Any) -> bool { +// down_cast_any_ref(other) +// .downcast_ref::() +// .map(|x| match (self, x) { +// ( +// GetFieldAccessExpr::NamedStructField { name: lhs }, +// GetFieldAccessExpr::NamedStructField { name: rhs }, +// ) => lhs.eq(rhs), +// }) +// .unwrap_or(false) +// } +// } + +// /// Expression to get a field of a struct array. +// #[derive(Debug, Hash)] +// pub struct GetIndexedFieldExpr { +// /// The expression to find +// arg: Arc, +// /// The key statement +// field: GetFieldAccessExpr, +// } + +// impl GetIndexedFieldExpr { +// /// Create new [`GetIndexedFieldExpr`] +// pub fn new(arg: Arc, field: GetFieldAccessExpr) -> Self { +// Self { arg, field } +// } + +// /// Create a new [`GetIndexedFieldExpr`] for accessing the named field +// pub fn new_field(arg: Arc, name: impl Into) -> Self { +// Self::new( +// arg, +// GetFieldAccessExpr::NamedStructField { +// name: ScalarValue::from(name.into()), +// }, +// ) +// } + +// /// Get the description of what field should be accessed +// pub fn field(&self) -> &GetFieldAccessExpr { +// &self.field +// } + +// /// Get the input expression +// pub fn arg(&self) -> &Arc { +// &self.arg +// } + +// fn schema_access(&self, _input_schema: &Schema) -> Result { +// Ok(match &self.field { +// GetFieldAccessExpr::NamedStructField { name } => { +// GetFieldAccessSchema::NamedStructField { name: name.clone() } +// } +// }) +// } +// } + +// impl std::fmt::Display for GetIndexedFieldExpr { +// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { +// write!(f, "({}).{}", self.arg, self.field) +// } +// } + +// impl PhysicalExpr for GetIndexedFieldExpr { +// fn as_any(&self) -> &dyn Any { +// self +// } + +// fn data_type(&self, input_schema: &Schema) -> Result { +// let arg_dt = self.arg.data_type(input_schema)?; +// self.schema_access(input_schema)? +// .get_accessed_field(&arg_dt) +// .map(|f| f.data_type().clone()) +// } + +// fn nullable(&self, input_schema: &Schema) -> Result { +// let arg_dt = self.arg.data_type(input_schema)?; +// self.schema_access(input_schema)? +// .get_accessed_field(&arg_dt) +// .map(|f| f.is_nullable()) +// } + +// fn evaluate(&self, batch: &RecordBatch) -> Result { +// let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; +// match &self.field { +// GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { +// (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { +// let map_array = as_map_array(array.as_ref())?; +// let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); +// let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; +// let entries = arrow::compute::filter(map_array.entries(), &keys)?; +// let entries_struct_array = as_struct_array(entries.as_ref())?; +// Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) +// } +// (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { +// let as_struct_array = as_struct_array(&array)?; +// match as_struct_array.column_by_name(k) { +// None => exec_err!( +// "get indexed field {k} not found in struct"), +// Some(col) => Ok(ColumnarValue::Array(col.clone())) +// } +// } +// (DataType::Struct(_), name) => exec_err!( +// "get indexed field is only possible on struct with utf8 indexes. \ +// Tried with {name:?} index"), +// (dt, name) => exec_err!( +// "get indexed field is only possible on lists with int64 indexes or struct \ +// with utf8 indexes. Tried {dt:?} with {name:?} index"), +// }, +// } +// } + +// fn children(&self) -> Vec> { +// vec![self.arg.clone()] +// } + +// fn with_new_children( +// self: Arc, +// children: Vec>, +// ) -> Result> { +// Ok(Arc::new(GetIndexedFieldExpr::new( +// children[0].clone(), +// self.field.clone(), +// ))) +// } + +// fn dyn_hash(&self, state: &mut dyn Hasher) { +// let mut s = state; +// self.hash(&mut s); +// } +// } + +// impl PartialEq for GetIndexedFieldExpr { +// fn eq(&self, other: &dyn Any) -> bool { +// down_cast_any_ref(other) +// .downcast_ref::() +// .map(|x| self.arg.eq(&x.arg) && self.field.eq(&x.field)) +// .unwrap_or(false) +// } +// } + +// #[cfg(test)] +// mod tests { +// use super::*; +// use crate::expressions::col; +// use arrow::array::ArrayRef; +// use arrow::array::{BooleanArray, Int64Array, StructArray}; +// use arrow::datatypes::Field; +// use arrow::datatypes::Fields; +// use datafusion_common::cast::as_boolean_array; +// use datafusion_common::Result; + +// #[test] +// fn get_indexed_field_named_struct_field() -> Result<()> { +// let schema = struct_schema(); +// let boolean = BooleanArray::from(vec![false, false, true, true]); +// let int = Int64Array::from(vec![42, 28, 19, 31]); +// let struct_array = StructArray::from(vec![ +// ( +// Arc::new(Field::new("a", DataType::Boolean, true)), +// Arc::new(boolean.clone()) as ArrayRef, +// ), +// ( +// Arc::new(Field::new("b", DataType::Int64, true)), +// Arc::new(int) as ArrayRef, +// ), +// ]); +// let expr = col("str", &schema).unwrap(); +// // only one row should be processed +// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; +// let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); +// let result = expr +// .evaluate(&batch)? +// .into_array(1) +// .expect("Failed to convert to array"); +// let result = +// as_boolean_array(&result).expect("failed to downcast to BooleanArray"); +// assert_eq!(boolean, result.clone()); +// Ok(()) +// } + +// fn struct_schema() -> Schema { +// Schema::new(vec![Field::new_struct( +// "str", +// Fields::from(vec![ +// Field::new("a", DataType::Boolean, true), +// Field::new("b", DataType::Int64, true), +// ]), +// true, +// )]) +// } +// } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 26d649f572011..ea6c015345595 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -82,7 +82,6 @@ pub use binary::{binary, BinaryExpr}; pub use case::{case, CaseExpr}; pub use cast::{cast, cast_with_options, CastExpr}; pub use column::{col, Column, UnKnownColumn}; -pub use get_indexed_field::{GetFieldAccessExpr, GetIndexedFieldExpr}; pub use in_list::{in_list, InListExpr}; pub use is_not_null::{is_not_null, IsNotNullExpr}; pub use is_null::{is_null, IsNullExpr}; diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index e6022d383e467..241f01a4170ad 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,9 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::expressions::GetFieldAccessExpr; use crate::{ - expressions::{self, binary, like, Column, GetIndexedFieldExpr, Literal}, + expressions::{self, binary, like, Column, Literal}, functions, udf, PhysicalExpr, }; use arrow::datatypes::Schema; @@ -228,10 +227,12 @@ pub fn create_physical_expr( input_dfschema, execution_props, )?), - Expr::GetIndexedField(GetIndexedField { expr, field }) => { - let field = match field { - GetFieldAccess::NamedStructField { name } => { - GetFieldAccessExpr::NamedStructField { name: name.clone() } + Expr::GetIndexedField(GetIndexedField { expr: _, field }) => { + match field { + GetFieldAccess::NamedStructField { name: _ } => { + unreachable!( + "NamedStructField should be rewritten in OperatorToFunction" + ) } GetFieldAccess::ListIndex { key: _ } => { unreachable!("ListIndex should be rewritten in OperatorToFunction") @@ -244,10 +245,6 @@ pub fn create_physical_expr( unreachable!("ListRange should be rewritten in OperatorToFunction") } }; - Ok(Arc::new(GetIndexedFieldExpr::new( - create_physical_expr(expr, input_dfschema, execution_props)?, - field, - ))) } Expr::ScalarFunction(ScalarFunction { func_def, args }) => { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 24ba8b0102e7c..1b4de1ed08496 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1353,8 +1353,6 @@ message PhysicalExprNode { PhysicalScalarUdfNode scalar_udf = 16; PhysicalLikeExprNode like_expr = 18; - - PhysicalGetIndexedFieldExprNode get_indexed_field_expr = 19; } } @@ -1769,17 +1767,4 @@ message ColumnStats { Precision max_value = 2; Precision null_count = 3; Precision distinct_count = 4; -} - -message NamedStructFieldExpr { - ScalarValue name = 1; -} - -message PhysicalGetIndexedFieldExprNode { - PhysicalExprNode arg = 1; - oneof field { - NamedStructFieldExpr named_struct_field_expr = 2; - // 3 was list_index_expr - // 4 was list_range_expr - } -} +} \ No newline at end of file diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index bb8d40c63bdb7..aaef23c184e71 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -14671,97 +14671,6 @@ impl<'de> serde::Deserialize<'de> for NamedStructField { deserializer.deserialize_struct("datafusion.NamedStructField", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for NamedStructFieldExpr { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.name.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.NamedStructFieldExpr", len)?; - if let Some(v) = self.name.as_ref() { - struct_ser.serialize_field("name", v)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for NamedStructFieldExpr { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "name", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Name, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "name" => Ok(GeneratedField::Name), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = NamedStructFieldExpr; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.NamedStructFieldExpr") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut name__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Name => { - if name__.is_some() { - return Err(serde::de::Error::duplicate_field("name")); - } - name__ = map_.next_value()?; - } - } - } - Ok(NamedStructFieldExpr { - name: name__, - }) - } - } - deserializer.deserialize_struct("datafusion.NamedStructFieldExpr", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for NegativeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -17541,9 +17450,6 @@ impl serde::Serialize for PhysicalExprNode { physical_expr_node::ExprType::LikeExpr(v) => { struct_ser.serialize_field("likeExpr", v)?; } - physical_expr_node::ExprType::GetIndexedFieldExpr(v) => { - struct_ser.serialize_field("getIndexedFieldExpr", v)?; - } } } struct_ser.end() @@ -17585,8 +17491,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "scalarUdf", "like_expr", "likeExpr", - "get_indexed_field_expr", - "getIndexedFieldExpr", ]; #[allow(clippy::enum_variant_names)] @@ -17608,7 +17512,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { WindowExpr, ScalarUdf, LikeExpr, - GetIndexedFieldExpr, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -17647,7 +17550,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { "windowExpr" | "window_expr" => Ok(GeneratedField::WindowExpr), "scalarUdf" | "scalar_udf" => Ok(GeneratedField::ScalarUdf), "likeExpr" | "like_expr" => Ok(GeneratedField::LikeExpr), - "getIndexedFieldExpr" | "get_indexed_field_expr" => Ok(GeneratedField::GetIndexedFieldExpr), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -17787,13 +17689,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExprNode { return Err(serde::de::Error::duplicate_field("likeExpr")); } expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::LikeExpr) -; - } - GeneratedField::GetIndexedFieldExpr => { - if expr_type__.is_some() { - return Err(serde::de::Error::duplicate_field("getIndexedFieldExpr")); - } - expr_type__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_expr_node::ExprType::GetIndexedFieldExpr) ; } } @@ -17917,120 +17812,6 @@ impl<'de> serde::Deserialize<'de> for PhysicalExtensionNode { deserializer.deserialize_struct("datafusion.PhysicalExtensionNode", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PhysicalGetIndexedFieldExprNode { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.arg.is_some() { - len += 1; - } - if self.field.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", len)?; - if let Some(v) = self.arg.as_ref() { - struct_ser.serialize_field("arg", v)?; - } - if let Some(v) = self.field.as_ref() { - match v { - physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(v) => { - struct_ser.serialize_field("namedStructFieldExpr", v)?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for PhysicalGetIndexedFieldExprNode { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "arg", - "named_struct_field_expr", - "namedStructFieldExpr", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Arg, - NamedStructFieldExpr, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "arg" => Ok(GeneratedField::Arg), - "namedStructFieldExpr" | "named_struct_field_expr" => Ok(GeneratedField::NamedStructFieldExpr), - _ => Err(serde::de::Error::unknown_field(value, FIELDS)), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PhysicalGetIndexedFieldExprNode; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct datafusion.PhysicalGetIndexedFieldExprNode") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut arg__ = None; - let mut field__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Arg => { - if arg__.is_some() { - return Err(serde::de::Error::duplicate_field("arg")); - } - arg__ = map_.next_value()?; - } - GeneratedField::NamedStructFieldExpr => { - if field__.is_some() { - return Err(serde::de::Error::duplicate_field("namedStructFieldExpr")); - } - field__ = map_.next_value::<::std::option::Option<_>>()?.map(physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr) -; - } - } - } - Ok(PhysicalGetIndexedFieldExprNode { - arg: arg__, - field: field__, - }) - } - } - deserializer.deserialize_struct("datafusion.PhysicalGetIndexedFieldExprNode", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for PhysicalHashRepartition { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9742c55474b9f..bf99bfd312baf 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1813,7 +1813,7 @@ pub struct PhysicalExtensionNode { pub struct PhysicalExprNode { #[prost( oneof = "physical_expr_node::ExprType", - tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 19" + tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18" )] pub expr_type: ::core::option::Option, } @@ -1861,10 +1861,6 @@ pub mod physical_expr_node { ScalarUdf(super::PhysicalScalarUdfNode), #[prost(message, tag = "18")] LikeExpr(::prost::alloc::boxed::Box), - #[prost(message, tag = "19")] - GetIndexedFieldExpr( - ::prost::alloc::boxed::Box, - ), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -2521,31 +2517,6 @@ pub struct ColumnStats { #[prost(message, optional, tag = "4")] pub distinct_count: ::core::option::Option, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct NamedStructFieldExpr { - #[prost(message, optional, tag = "1")] - pub name: ::core::option::Option, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PhysicalGetIndexedFieldExprNode { - #[prost(message, optional, boxed, tag = "1")] - pub arg: ::core::option::Option<::prost::alloc::boxed::Box>, - #[prost(oneof = "physical_get_indexed_field_expr_node::Field", tags = "2")] - pub field: ::core::option::Option, -} -/// Nested message and enum types in `PhysicalGetIndexedFieldExprNode`. -pub mod physical_get_indexed_field_expr_node { - #[allow(clippy::derive_partial_eq_without_eq)] - #[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Field { - /// 3 was list_index_expr - /// 4 was list_range_expr - #[prost(message, tag = "2")] - NamedStructFieldExpr(super::NamedStructFieldExpr), - } -} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JoinType { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index d3b41f114fbac..162dd9da279e2 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -37,7 +37,6 @@ use datafusion::physical_plan::expressions::{ in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, NotExpr, TryCastExpr, }; -use datafusion::physical_plan::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ functions, ColumnStatistics, Partitioning, PhysicalExpr, Statistics, WindowExpr, @@ -384,27 +383,6 @@ pub fn parse_physical_expr( input_schema, )?, )), - ExprType::GetIndexedFieldExpr(get_indexed_field_expr) => { - let field = match &get_indexed_field_expr.field { - Some(protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(named_struct_field_expr)) => GetFieldAccessExpr::NamedStructField{ - name: convert_required!(named_struct_field_expr.name)?, - }, - None => - return Err(proto_error( - "Field must not be None", - )), - }; - - Arc::new(GetIndexedFieldExpr::new( - parse_required_physical_expr( - get_indexed_field_expr.arg.as_deref(), - registry, - "arg", - input_schema, - )?, - field, - )) - } }; Ok(pexpr) diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index da4e87b7a8535..d3b91572b0177 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -40,7 +40,6 @@ use datafusion::datasource::{ physical_plan::FileSinkConfig, }; use datafusion::logical_expr::BuiltinScalarFunction; -use datafusion::physical_expr::expressions::{GetFieldAccessExpr, GetIndexedFieldExpr}; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ @@ -552,25 +551,6 @@ impl TryFrom> for protobuf::PhysicalExprNode { }), )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let field = match expr.field() { - GetFieldAccessExpr::NamedStructField{name} => Some( - protobuf::physical_get_indexed_field_expr_node::Field::NamedStructFieldExpr(protobuf::NamedStructFieldExpr { - name: Some(ScalarValue::try_from(name)?) - }) - ), - }; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::GetIndexedFieldExpr( - Box::new(protobuf::PhysicalGetIndexedFieldExprNode { - arg: Some(Box::new(expr.arg().to_owned().try_into()?)), - field, - }), - ), - ), - }) } else { internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index a3c0b3eccd3c1..9ea1c1762a274 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -18,7 +18,7 @@ use arrow::csv::WriterBuilder; use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; -use datafusion::arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema}; +use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; use datafusion::datasource::file_format::parquet::ParquetSink; @@ -43,8 +43,7 @@ use datafusion::physical_plan::analyze::AnalyzeExec; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::expressions::{ binary, cast, col, in_list, like, lit, Avg, BinaryExpr, Column, DistinctCount, - GetFieldAccessExpr, GetIndexedFieldExpr, NotExpr, NthValue, PhysicalSortExpr, - StringAgg, Sum, + NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::functions; @@ -705,36 +704,6 @@ fn roundtrip_like() -> Result<()> { roundtrip_test(plan) } -#[test] -fn roundtrip_get_indexed_field_named_struct_field() -> Result<()> { - let fields = vec![ - Field::new("id", DataType::Int64, true), - Field::new_struct( - "arg", - Fields::from(vec![Field::new("name", DataType::Float64, true)]), - true, - ), - ]; - - let schema = Schema::new(fields); - let input = Arc::new(EmptyExec::new(Arc::new(schema.clone()))); - - let col_arg = col("arg", &schema)?; - let get_indexed_field_expr = Arc::new(GetIndexedFieldExpr::new( - col_arg, - GetFieldAccessExpr::NamedStructField { - name: ScalarValue::from("name"), - }, - )); - - let plan = Arc::new(ProjectionExec::try_new( - vec![(get_indexed_field_expr, "result".to_string())], - input, - )?); - - roundtrip_test(plan) -} - #[test] fn roundtrip_analyze() -> Result<()> { let field_a = Field::new("plan_type", DataType::Utf8, false); From dfb50275a0ec9d384007cb7ba77b18a6ca01436d Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 18:44:34 -0500 Subject: [PATCH 02/16] fix typo --- datafusion/core/src/physical_planner.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 3d6b8054c83f2..6b746c737e6a2 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -225,7 +225,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result { stride: _, } => { unreachable!( - "ListIndex should have been rewritten in OperatorToFunction" + "ListRange should have been rewritten in OperatorToFunction" ) } }; From 2543d709647161009c68bed095636a1157fa2987 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 18:45:38 -0500 Subject: [PATCH 03/16] delete indexed_field file --- .../src/expressions/get_indexed_field.rs | 245 ------------------ 1 file changed, 245 deletions(-) delete mode 100644 datafusion/physical-expr/src/expressions/get_indexed_field.rs diff --git a/datafusion/physical-expr/src/expressions/get_indexed_field.rs b/datafusion/physical-expr/src/expressions/get_indexed_field.rs deleted file mode 100644 index 6edfd41aeb45d..0000000000000 --- a/datafusion/physical-expr/src/expressions/get_indexed_field.rs +++ /dev/null @@ -1,245 +0,0 @@ -// // Licensed to the Apache Software Foundation (ASF) under one -// // or more contributor license agreements. See the NOTICE file -// // distributed with this work for additional information -// // regarding copyright ownership. The ASF licenses this file -// // to you under the Apache License, Version 2.0 (the -// // "License"); you may not use this file except in compliance -// // with the License. You may obtain a copy of the License at -// // -// // http://www.apache.org/licenses/LICENSE-2.0 -// // -// // Unless required by applicable law or agreed to in writing, -// // software distributed under the License is distributed on an -// // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// // KIND, either express or implied. See the License for the -// // specific language governing permissions and limitations -// // under the License. - -// //! get field of a `ListArray` - -// use crate::PhysicalExpr; -// use datafusion_common::exec_err; - -// use crate::physical_expr::down_cast_any_ref; -// use arrow::{ -// array::{Array, Scalar, StringArray}, -// datatypes::{DataType, Schema}, -// record_batch::RecordBatch, -// }; -// use datafusion_common::{ -// cast::{as_map_array, as_struct_array}, -// Result, ScalarValue, -// }; -// use datafusion_expr::{field_util::GetFieldAccessSchema, ColumnarValue}; -// use std::fmt::Debug; -// use std::hash::{Hash, Hasher}; -// use std::{any::Any, sync::Arc}; - -// /// Access a sub field of a nested type, such as `Field` or `List` -// #[derive(Clone, Hash, Debug)] -// pub enum GetFieldAccessExpr { -// /// Named field, For example `struct["name"]` -// NamedStructField { name: ScalarValue }, -// } - -// impl std::fmt::Display for GetFieldAccessExpr { -// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { -// match self { -// GetFieldAccessExpr::NamedStructField { name } => write!(f, "[{}]", name), -// } -// } -// } - -// impl PartialEq for GetFieldAccessExpr { -// fn eq(&self, other: &dyn Any) -> bool { -// down_cast_any_ref(other) -// .downcast_ref::() -// .map(|x| match (self, x) { -// ( -// GetFieldAccessExpr::NamedStructField { name: lhs }, -// GetFieldAccessExpr::NamedStructField { name: rhs }, -// ) => lhs.eq(rhs), -// }) -// .unwrap_or(false) -// } -// } - -// /// Expression to get a field of a struct array. -// #[derive(Debug, Hash)] -// pub struct GetIndexedFieldExpr { -// /// The expression to find -// arg: Arc, -// /// The key statement -// field: GetFieldAccessExpr, -// } - -// impl GetIndexedFieldExpr { -// /// Create new [`GetIndexedFieldExpr`] -// pub fn new(arg: Arc, field: GetFieldAccessExpr) -> Self { -// Self { arg, field } -// } - -// /// Create a new [`GetIndexedFieldExpr`] for accessing the named field -// pub fn new_field(arg: Arc, name: impl Into) -> Self { -// Self::new( -// arg, -// GetFieldAccessExpr::NamedStructField { -// name: ScalarValue::from(name.into()), -// }, -// ) -// } - -// /// Get the description of what field should be accessed -// pub fn field(&self) -> &GetFieldAccessExpr { -// &self.field -// } - -// /// Get the input expression -// pub fn arg(&self) -> &Arc { -// &self.arg -// } - -// fn schema_access(&self, _input_schema: &Schema) -> Result { -// Ok(match &self.field { -// GetFieldAccessExpr::NamedStructField { name } => { -// GetFieldAccessSchema::NamedStructField { name: name.clone() } -// } -// }) -// } -// } - -// impl std::fmt::Display for GetIndexedFieldExpr { -// fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { -// write!(f, "({}).{}", self.arg, self.field) -// } -// } - -// impl PhysicalExpr for GetIndexedFieldExpr { -// fn as_any(&self) -> &dyn Any { -// self -// } - -// fn data_type(&self, input_schema: &Schema) -> Result { -// let arg_dt = self.arg.data_type(input_schema)?; -// self.schema_access(input_schema)? -// .get_accessed_field(&arg_dt) -// .map(|f| f.data_type().clone()) -// } - -// fn nullable(&self, input_schema: &Schema) -> Result { -// let arg_dt = self.arg.data_type(input_schema)?; -// self.schema_access(input_schema)? -// .get_accessed_field(&arg_dt) -// .map(|f| f.is_nullable()) -// } - -// fn evaluate(&self, batch: &RecordBatch) -> Result { -// let array = self.arg.evaluate(batch)?.into_array(batch.num_rows())?; -// match &self.field { -// GetFieldAccessExpr::NamedStructField{name} => match (array.data_type(), name) { -// (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { -// let map_array = as_map_array(array.as_ref())?; -// let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); -// let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; -// let entries = arrow::compute::filter(map_array.entries(), &keys)?; -// let entries_struct_array = as_struct_array(entries.as_ref())?; -// Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) -// } -// (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { -// let as_struct_array = as_struct_array(&array)?; -// match as_struct_array.column_by_name(k) { -// None => exec_err!( -// "get indexed field {k} not found in struct"), -// Some(col) => Ok(ColumnarValue::Array(col.clone())) -// } -// } -// (DataType::Struct(_), name) => exec_err!( -// "get indexed field is only possible on struct with utf8 indexes. \ -// Tried with {name:?} index"), -// (dt, name) => exec_err!( -// "get indexed field is only possible on lists with int64 indexes or struct \ -// with utf8 indexes. Tried {dt:?} with {name:?} index"), -// }, -// } -// } - -// fn children(&self) -> Vec> { -// vec![self.arg.clone()] -// } - -// fn with_new_children( -// self: Arc, -// children: Vec>, -// ) -> Result> { -// Ok(Arc::new(GetIndexedFieldExpr::new( -// children[0].clone(), -// self.field.clone(), -// ))) -// } - -// fn dyn_hash(&self, state: &mut dyn Hasher) { -// let mut s = state; -// self.hash(&mut s); -// } -// } - -// impl PartialEq for GetIndexedFieldExpr { -// fn eq(&self, other: &dyn Any) -> bool { -// down_cast_any_ref(other) -// .downcast_ref::() -// .map(|x| self.arg.eq(&x.arg) && self.field.eq(&x.field)) -// .unwrap_or(false) -// } -// } - -// #[cfg(test)] -// mod tests { -// use super::*; -// use crate::expressions::col; -// use arrow::array::ArrayRef; -// use arrow::array::{BooleanArray, Int64Array, StructArray}; -// use arrow::datatypes::Field; -// use arrow::datatypes::Fields; -// use datafusion_common::cast::as_boolean_array; -// use datafusion_common::Result; - -// #[test] -// fn get_indexed_field_named_struct_field() -> Result<()> { -// let schema = struct_schema(); -// let boolean = BooleanArray::from(vec![false, false, true, true]); -// let int = Int64Array::from(vec![42, 28, 19, 31]); -// let struct_array = StructArray::from(vec![ -// ( -// Arc::new(Field::new("a", DataType::Boolean, true)), -// Arc::new(boolean.clone()) as ArrayRef, -// ), -// ( -// Arc::new(Field::new("b", DataType::Int64, true)), -// Arc::new(int) as ArrayRef, -// ), -// ]); -// let expr = col("str", &schema).unwrap(); -// // only one row should be processed -// let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?; -// let expr = Arc::new(GetIndexedFieldExpr::new_field(expr, "a")); -// let result = expr -// .evaluate(&batch)? -// .into_array(1) -// .expect("Failed to convert to array"); -// let result = -// as_boolean_array(&result).expect("failed to downcast to BooleanArray"); -// assert_eq!(boolean, result.clone()); -// Ok(()) -// } - -// fn struct_schema() -> Schema { -// Schema::new(vec![Field::new_struct( -// "str", -// Fields::from(vec![ -// Field::new("a", DataType::Boolean, true), -// Field::new("b", DataType::Int64, true), -// ]), -// true, -// )]) -// } -// } From 69a014ee754ccabaad9f4589240dfb61b87bb199 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 18:54:44 -0500 Subject: [PATCH 04/16] fix cargo check --- datafusion/physical-expr/src/expressions/mod.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index ea6c015345595..7c4ea07dfbcb2 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -23,7 +23,6 @@ mod case; mod cast; mod column; mod datum; -mod get_indexed_field; mod in_list; mod is_not_null; mod is_null; From d2873c78f0a05e855b9cc33663faf08ac52a9893 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 19:07:46 -0500 Subject: [PATCH 05/16] fix cargo check --- datafusion/optimizer/src/analyzer/rewrite_expr.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index 04c8622cd33d3..c9a357e1926d2 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -17,7 +17,6 @@ //! Analyzer rule for to replace operators with function calls (e.g `||` to array_concat`) -#[cfg(feature = "array_expressions")] use std::sync::Arc; use super::AnalyzerRule; From 3decd0023754483a87075d3a1968279211608a4f Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 19:29:55 -0500 Subject: [PATCH 06/16] cargo update in CLI --- datafusion-cli/Cargo.lock | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9afd78d1cc882..9ff66444f1f93 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1285,6 +1285,7 @@ dependencies = [ "chrono", "datafusion-common", "datafusion-expr", + "datafusion-functions", "datafusion-physical-expr", "hashbrown 0.14.3", "itertools", From 9b001714164a79b1c7408f56f7a64cf8e5c394e0 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Mon, 11 Mar 2024 23:00:18 -0500 Subject: [PATCH 07/16] feat: add getfield func --- datafusion/functions/src/core/getfield.rs | 165 ++++++++++++++++++ datafusion/functions/src/core/mod.rs | 7 +- datafusion/functions/src/core/struct.rs | 8 +- .../optimizer/src/analyzer/rewrite_expr.rs | 2 +- 4 files changed, 172 insertions(+), 10 deletions(-) create mode 100644 datafusion/functions/src/core/getfield.rs diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs new file mode 100644 index 0000000000000..11b4d5367b983 --- /dev/null +++ b/datafusion/functions/src/core/getfield.rs @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; +use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; +use std::any::Any; + +#[derive(Debug)] +pub struct GetFieldFunc { + signature: Signature, +} + +impl GetFieldFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } +} +impl Default for GetFieldFunc { + fn default() -> Self { + Self::new() + } +} + +// get_field(struct_array, field_name) +impl ScalarUDFImpl for GetFieldFunc { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "get_field" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> Result { + todo!() + } + + fn return_type_from_exprs( + &self, + args: &[Expr], + schema: &dyn ExprSchema, + _arg_types: &[DataType], + ) -> Result { + if args.len() != 2 { + return exec_err!( + "get_field function requires 2 arguments, got {}", + args.len() + ); + } + + match &args[0] { + Expr::Column(name) => { + let data_type = schema.data_type(name)?; + match data_type { + DataType::Struct(fields) => { + let field_name = match &args[1] { + Expr::Literal(ScalarValue::Utf8(name)) => { + name.as_ref().map(|x| x.as_str()) + } + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + let field = + fields.iter().find(|f| f.name() == field_name.unwrap()); + match field { + Some(field) => Ok(field.data_type().clone()), + None => { + exec_err!( + "get_field function can't find the field {} in the struct", field_name.unwrap() + ) + } + } + } + _ => { + exec_err!( + "get_field function requires the column to have struct type" + ) + } + } + } + _ => { + exec_err!( + "get_field function requires the first argument to be struct array" + ) + } + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + if args.len() != 2 { + return exec_err!( + "get_field function requires 2 arguments, got {}", + args.len() + ); + } + + match &args[0] { + ColumnarValue::Array(array) => { + let struct_array = match array + .as_any() + .downcast_ref::() + { + Some(struct_array) => struct_array, + None => { + return exec_err!( + "get_field function requires the first argument to be struct array" + ); + } + }; + match &args[1] { + ColumnarValue::Scalar(scalar) => { + let column_name = match scalar { + ScalarValue::Utf8(name) => name.as_ref().map(|x| x.as_str()), + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + Ok(ColumnarValue::Array( + struct_array + .column_by_name(column_name.unwrap()) + .unwrap() + .clone(), + )) + } + _ => { + exec_err!( + "get_field function requires the argument field_name to be a string" + ) + } + } + } + _ => { + exec_err!( + "get_field function requires the first argument to be struct array" + ) + } + } + } +} diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 3f13067a4a074..c866513de05af 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -18,10 +18,11 @@ //! "core" DataFusion functions mod arrowtypeof; +pub mod getfield; mod nullif; mod nvl; mod nvl2; -pub mod r#struct; +mod r#struct; // create UDFs make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); @@ -29,6 +30,7 @@ make_udf_function!(nvl::NVLFunc, NVL, nvl); make_udf_function!(nvl2::NVL2Func, NVL2, nvl2); make_udf_function!(arrowtypeof::ArrowTypeOfFunc, ARROWTYPEOF, arrow_typeof); make_udf_function!(r#struct::StructFunc, STRUCT, r#struct); +make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( @@ -36,5 +38,6 @@ export_functions!( (nvl, arg_1 arg_2, "returns value2 if value1 is NULL; otherwise it returns value1"), (nvl2, arg_1 arg_2 arg_3, "Returns value2 if value1 is not NULL; otherwise, it returns value3."), (arrow_typeof, arg_1, "Returns the Arrow type of the input expression."), - (r#struct, args, "Returns a struct with the given arguments") + (r#struct, args, "Returns a struct with the given arguments"), + (get_field, arg_1 arg_2, "Returns the value of the field with the given name from the struct") ); diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 6236f98794bba..406e402ccd850 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -61,7 +61,7 @@ fn struct_expr(args: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?)) } #[derive(Debug)] -pub struct StructFunc { +pub(super) struct StructFunc { signature: Signature, } @@ -73,12 +73,6 @@ impl StructFunc { } } -impl Default for StructFunc { - fn default() -> Self { - Self::new() - } -} - impl ScalarUDFImpl for StructFunc { fn as_any(&self) -> &dyn Any { self diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs index c9a357e1926d2..4a34039f2e108 100644 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ b/datafusion/optimizer/src/analyzer/rewrite_expr.rs @@ -143,7 +143,7 @@ impl TreeNodeRewriter for OperatorToFunctionRewriter { return Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf( Arc::new(ScalarUDF::new_from_impl( - datafusion_functions::core::r#struct::StructFunc::new(), + datafusion_functions::core::getfield::GetFieldFunc::new(), )), args, ), From 0cf493652a5e5949286e56ec1bc71934a01d69f7 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 00:06:27 -0500 Subject: [PATCH 08/16] fix struct fun --- datafusion/functions/src/core/getfield.rs | 102 ++++++++++++++++++++-- 1 file changed, 95 insertions(+), 7 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 11b4d5367b983..c454473552571 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -102,9 +102,72 @@ impl ScalarUDFImpl for GetFieldFunc { } } } + Expr::ScalarFunction(fun) => { + let index = match &args[1] { + Expr::Literal(ScalarValue::Utf8(name)) => { + name.as_ref().map(|x| x.as_str()).unwrap()[1..] + .parse::() + .unwrap() + } + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + if let Some(expr) = fun.args.get(index) { + match expr { + Expr::Literal(scalar) => { + println!("{:?}", scalar); + Ok(scalar.data_type().clone()) + } + _ => { + exec_err!( + "get_field function requires the first argument to be struct array 1" + ) + } + } + } else { + exec_err!( + "get_field function requires the first argument to be struct array 2" + ) + } + } + Expr::Literal(scalar) => match scalar { + ScalarValue::Struct(struct_array) => { + let field_name = match &args[1] { + Expr::Literal(ScalarValue::Utf8(name)) => { + name.as_ref().map(|x| x.as_str()) + } + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + println!("{}", field_name.unwrap()); + let field = struct_array + .fields() + .iter() + .find(|f| f.name() == field_name.unwrap()); + match field { + Some(field) => Ok(field.data_type().clone()), + None => { + exec_err!( + "get_field function can't find the field {} in the struct", field_name.unwrap() + ) + } + } + } + _ => { + exec_err!( + "get_field function requires the first argument to be struct array 3" + ) + } + }, _ => { exec_err!( - "get_field function requires the first argument to be struct array" + "get_field function requires the first argument to be struct array 4" ) } } @@ -127,7 +190,7 @@ impl ScalarUDFImpl for GetFieldFunc { Some(struct_array) => struct_array, None => { return exec_err!( - "get_field function requires the first argument to be struct array" + "get_field function requires the first argument to be struct array, 5" ); } }; @@ -155,11 +218,36 @@ impl ScalarUDFImpl for GetFieldFunc { } } } - _ => { - exec_err!( - "get_field function requires the first argument to be struct array" - ) - } + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Struct(struct_array) => { + let column_name = match &args[1] { + ColumnarValue::Scalar(scalar) => match scalar { + ScalarValue::Utf8(name) => name.as_ref().map(|x| x.as_str()), + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }, + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + Ok(ColumnarValue::Array( + struct_array + .column_by_name(column_name.unwrap()) + .unwrap() + .clone(), + )) + } + _ => { + exec_err!( + "get_field function requires the first argument to be struct array" + ) + } + }, } } } From 3050a2e5fdedce62f585e9e50872e7f22a48c0e6 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 10:00:19 -0500 Subject: [PATCH 09/16] stage commit --- datafusion/core/src/dataframe/mod.rs | 23 ++++++++++++++ datafusion/functions/src/core/getfield.rs | 37 ++++++++++++++++++++++- 2 files changed, 59 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 3bdf2af4552de..e22a662e2bdf1 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1591,6 +1591,29 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_my() -> Result<()> { + let ctx = SessionContext::new(); + + let create_table_query = r#" + CREATE TABLE values( + a INT, + b FLOAT, + c VARCHAR + ) AS VALUES + (1, 1.1, 'a'), + (2, 2.2, 'b'), + (3, 3.3, 'c') + "#; + ctx.sql(create_table_query).await?; + + let query = r#"select struct(a, b, c)['c1'] from values"#; + + let result = ctx.sql(query).await?; + result.show().await?; + Ok(()) + } + #[tokio::test] async fn test_array_agg_schema() -> Result<()> { let ctx = SessionContext::new(); diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index c454473552571..34ddad825c68e 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -95,9 +95,40 @@ impl ScalarUDFImpl for GetFieldFunc { } } } + DataType::Map(a , b) =>{ + match a.data_type() { + DataType::Struct(fields) => { + let field_name = match &args[1] { + Expr::Literal(ScalarValue::Utf8(name)) => { + name.as_ref().map(|x| x.as_str()) + } + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + let field = + fields.iter().find(|f| f.name() == field_name.unwrap()); + match field { + Some(field) => Ok(field.data_type().clone()), + None => { + exec_err!( + "get_field function can't find the field {} in the struct", field_name.unwrap() + ) + } + } + } + _ => { + exec_err!( + "get_field function requires the first argument to be struct array" + ) + } + } + } _ => { exec_err!( - "get_field function requires the column to have struct type" + "get_field function requires the first argument to be struct array" ) } } @@ -121,6 +152,10 @@ impl ScalarUDFImpl for GetFieldFunc { println!("{:?}", scalar); Ok(scalar.data_type().clone()) } + Expr::Column(name) => { + let data_type = schema.data_type(name)?; + Ok(data_type.clone()) + } _ => { exec_err!( "get_field function requires the first argument to be struct array 1" From c8345975050dacba2f8d2ab41b9864bdee034c12 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 10:33:35 -0500 Subject: [PATCH 10/16] fix test --- datafusion/functions/src/core/getfield.rs | 253 +++++----------------- 1 file changed, 52 insertions(+), 201 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 34ddad825c68e..ccd6f5a0a5ee8 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -16,8 +16,11 @@ // under the License. use arrow::datatypes::DataType; +use arrow_array::{Scalar, StringArray}; +use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{exec_err, ExprSchema, Result, ScalarValue}; -use datafusion_expr::{ColumnarValue, Expr}; +use datafusion_expr::field_util::GetFieldAccessSchema; +use datafusion_expr::{ColumnarValue, Expr, ExprSchemable}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; @@ -69,143 +72,19 @@ impl ScalarUDFImpl for GetFieldFunc { ); } - match &args[0] { - Expr::Column(name) => { - let data_type = schema.data_type(name)?; - match data_type { - DataType::Struct(fields) => { - let field_name = match &args[1] { - Expr::Literal(ScalarValue::Utf8(name)) => { - name.as_ref().map(|x| x.as_str()) - } - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - let field = - fields.iter().find(|f| f.name() == field_name.unwrap()); - match field { - Some(field) => Ok(field.data_type().clone()), - None => { - exec_err!( - "get_field function can't find the field {} in the struct", field_name.unwrap() - ) - } - } - } - DataType::Map(a , b) =>{ - match a.data_type() { - DataType::Struct(fields) => { - let field_name = match &args[1] { - Expr::Literal(ScalarValue::Utf8(name)) => { - name.as_ref().map(|x| x.as_str()) - } - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - let field = - fields.iter().find(|f| f.name() == field_name.unwrap()); - match field { - Some(field) => Ok(field.data_type().clone()), - None => { - exec_err!( - "get_field function can't find the field {} in the struct", field_name.unwrap() - ) - } - } - } - _ => { - exec_err!( - "get_field function requires the first argument to be struct array" - ) - } - } - } - _ => { - exec_err!( - "get_field function requires the first argument to be struct array" - ) - } - } - } - Expr::ScalarFunction(fun) => { - let index = match &args[1] { - Expr::Literal(ScalarValue::Utf8(name)) => { - name.as_ref().map(|x| x.as_str()).unwrap()[1..] - .parse::() - .unwrap() - } - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - if let Some(expr) = fun.args.get(index) { - match expr { - Expr::Literal(scalar) => { - println!("{:?}", scalar); - Ok(scalar.data_type().clone()) - } - Expr::Column(name) => { - let data_type = schema.data_type(name)?; - Ok(data_type.clone()) - } - _ => { - exec_err!( - "get_field function requires the first argument to be struct array 1" - ) - } - } - } else { - exec_err!( - "get_field function requires the first argument to be struct array 2" - ) - } - } - Expr::Literal(scalar) => match scalar { - ScalarValue::Struct(struct_array) => { - let field_name = match &args[1] { - Expr::Literal(ScalarValue::Utf8(name)) => { - name.as_ref().map(|x| x.as_str()) - } - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - println!("{}", field_name.unwrap()); - let field = struct_array - .fields() - .iter() - .find(|f| f.name() == field_name.unwrap()); - match field { - Some(field) => Ok(field.data_type().clone()), - None => { - exec_err!( - "get_field function can't find the field {} in the struct", field_name.unwrap() - ) - } - } - } - _ => { - exec_err!( - "get_field function requires the first argument to be struct array 3" - ) - } - }, + let name = match &args[1] { + Expr::Literal(name) => name, _ => { - exec_err!( - "get_field function requires the first argument to be struct array 4" - ) + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); } - } + }; + let access_schema = GetFieldAccessSchema::NamedStructField { name: name.clone() }; + let arg_dt = args[0].get_type(schema)?; + access_schema + .get_accessed_field(&arg_dt) + .map(|f| f.data_type().clone()) } fn invoke(&self, args: &[ColumnarValue]) -> Result { @@ -216,73 +95,45 @@ impl ScalarUDFImpl for GetFieldFunc { ); } - match &args[0] { - ColumnarValue::Array(array) => { - let struct_array = match array - .as_any() - .downcast_ref::() - { - Some(struct_array) => struct_array, - None => { - return exec_err!( - "get_field function requires the first argument to be struct array, 5" - ); - } - }; - match &args[1] { - ColumnarValue::Scalar(scalar) => { - let column_name = match scalar { - ScalarValue::Utf8(name) => name.as_ref().map(|x| x.as_str()), - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - Ok(ColumnarValue::Array( - struct_array - .column_by_name(column_name.unwrap()) - .unwrap() - .clone(), - )) - } - _ => { - exec_err!( - "get_field function requires the argument field_name to be a string" - ) - } - } + let arr; + let array = match &args[0] { + ColumnarValue::Array(array) => array, + ColumnarValue::Scalar(scalar) => { + arr = scalar.clone().to_array()?; + &arr } - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Struct(struct_array) => { - let column_name = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(name) => name.as_ref().map(|x| x.as_str()), - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }, - _ => { - return exec_err!( - "get_field function requires the argument field_name to be a string" - ); - } - }; - Ok(ColumnarValue::Array( - struct_array - .column_by_name(column_name.unwrap()) - .unwrap() - .clone(), - )) + }; + let name = match &args[1] { + ColumnarValue::Scalar(name) => name, + _ => { + return exec_err!( + "get_field function requires the argument field_name to be a string" + ); + } + }; + match (array.data_type(), name) { + (DataType::Map(_, _), ScalarValue::Utf8(Some(k))) => { + let map_array = as_map_array(array.as_ref())?; + let key_scalar = Scalar::new(StringArray::from(vec![k.clone()])); + let keys = arrow::compute::kernels::cmp::eq(&key_scalar, map_array.keys())?; + let entries = arrow::compute::filter(map_array.entries(), &keys)?; + let entries_struct_array = as_struct_array(entries.as_ref())?; + Ok(ColumnarValue::Array(entries_struct_array.column(1).clone())) } - _ => { - exec_err!( - "get_field function requires the first argument to be struct array" - ) + (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { + let as_struct_array = as_struct_array(&array)?; + match as_struct_array.column_by_name(&k) { + None => exec_err!( + "get indexed field {k} not found in struct"), + Some(col) => Ok(ColumnarValue::Array(col.clone())) + } } - }, - } + (DataType::Struct(_), name) => exec_err!( + "get indexed field is only possible on struct with utf8 indexes. \ + Tried with {name:?} index"), + (dt, name) => exec_err!( + "get indexed field is only possible on lists with int64 indexes or struct \ + with utf8 indexes. Tried {dt:?} with {name:?} index"), + } } } From f21f41f5f8b81c385a7272799d73c869830d1075 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 10:42:51 -0500 Subject: [PATCH 11/16] refresh CI --- datafusion/core/src/dataframe/mod.rs | 23 ----------------------- 1 file changed, 23 deletions(-) diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index e22a662e2bdf1..3bdf2af4552de 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -1591,29 +1591,6 @@ mod tests { Ok(()) } - #[tokio::test] - async fn test_my() -> Result<()> { - let ctx = SessionContext::new(); - - let create_table_query = r#" - CREATE TABLE values( - a INT, - b FLOAT, - c VARCHAR - ) AS VALUES - (1, 1.1, 'a'), - (2, 2.2, 'b'), - (3, 3.3, 'c') - "#; - ctx.sql(create_table_query).await?; - - let query = r#"select struct(a, b, c)['c1'] from values"#; - - let result = ctx.sql(query).await?; - result.show().await?; - Ok(()) - } - #[tokio::test] async fn test_array_agg_schema() -> Result<()> { let ctx = SessionContext::new(); From 3a39ec756e05a5320c2e2230f2878e6f1c68c5a6 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 10:51:04 -0500 Subject: [PATCH 12/16] resolve strange bug --- datafusion/proto/src/physical_plan/from_proto.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index dad950596907e..184c048c1bdd9 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -41,9 +41,8 @@ use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::WindowFunctionDefinition; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - in_list, BinaryExpr, CaseExpr, CastExpr, Column, GetFieldAccessExpr, - GetIndexedFieldExpr, IsNotNullExpr, IsNullExpr, LikeExpr, Literal, NegativeExpr, - NotExpr, TryCastExpr, + in_list, BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, LikeExpr, + Literal, NegativeExpr, NotExpr, TryCastExpr, }; use datafusion::physical_plan::windows::create_window_expr; use datafusion::physical_plan::{ From 25c75e85c890f7a7a4356748b55e1cd4c9aa9df4 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Tue, 12 Mar 2024 11:07:41 -0500 Subject: [PATCH 13/16] fix clippy --- datafusion/functions/src/core/getfield.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index ccd6f5a0a5ee8..9b63609d7c5a8 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -122,7 +122,7 @@ impl ScalarUDFImpl for GetFieldFunc { } (DataType::Struct(_), ScalarValue::Utf8(Some(k))) => { let as_struct_array = as_struct_array(&array)?; - match as_struct_array.column_by_name(&k) { + match as_struct_array.column_by_name(k) { None => exec_err!( "get indexed field {k} not found in struct"), Some(col) => Ok(ColumnarValue::Array(col.clone())) From d4fe597ea8eca2158fe0e500575b42f9bba8d234 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 10:32:26 -0500 Subject: [PATCH 14/16] use values_to_arrays --- datafusion/functions/src/core/getfield.rs | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 9b63609d7c5a8..f8f116b99e5be 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -95,14 +95,9 @@ impl ScalarUDFImpl for GetFieldFunc { ); } - let arr; - let array = match &args[0] { - ColumnarValue::Array(array) => array, - ColumnarValue::Scalar(scalar) => { - arr = scalar.clone().to_array()?; - &arr - } - }; + let arrays = ColumnarValue::values_to_arrays(args)?; + let array = arrays[0].clone(); + let name = match &args[1] { ColumnarValue::Scalar(name) => name, _ => { From 0680d2c19688f903b4bf8529efcfed772d0d7a65 Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 10:33:27 -0500 Subject: [PATCH 15/16] delete for merge --- .../optimizer/src/analyzer/rewrite_expr.rs | 364 ------------------ 1 file changed, 364 deletions(-) delete mode 100644 datafusion/optimizer/src/analyzer/rewrite_expr.rs diff --git a/datafusion/optimizer/src/analyzer/rewrite_expr.rs b/datafusion/optimizer/src/analyzer/rewrite_expr.rs deleted file mode 100644 index 4a34039f2e108..0000000000000 --- a/datafusion/optimizer/src/analyzer/rewrite_expr.rs +++ /dev/null @@ -1,364 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Analyzer rule for to replace operators with function calls (e.g `||` to array_concat`) - -use std::sync::Arc; - -use super::AnalyzerRule; - -use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNodeRewriter}; -#[cfg(feature = "array_expressions")] -use datafusion_common::{utils::list_ndims, DFSchemaRef}; -use datafusion_common::{DFSchema, Result}; -use datafusion_expr::expr::ScalarFunction; -use datafusion_expr::expr_rewriter::rewrite_preserving_name; -use datafusion_expr::utils::merge_schema; -use datafusion_expr::GetFieldAccess; -use datafusion_expr::GetIndexedField; -#[cfg(feature = "array_expressions")] -use datafusion_expr::{BinaryExpr, Operator, ScalarFunctionDefinition}; -use datafusion_expr::{BuiltinScalarFunction, ScalarUDF}; -use datafusion_expr::{Expr, LogicalPlan}; -#[cfg(feature = "array_expressions")] -use datafusion_functions_array::expr_fn::{array_append, array_concat, array_prepend}; - -#[derive(Default)] -pub struct OperatorToFunction {} - -impl OperatorToFunction { - pub fn new() -> Self { - Self {} - } -} - -impl AnalyzerRule for OperatorToFunction { - fn name(&self) -> &str { - "operator_to_function" - } - - fn analyze(&self, plan: LogicalPlan, _: &ConfigOptions) -> Result { - analyze_internal(&plan) - } -} - -fn analyze_internal(plan: &LogicalPlan) -> Result { - // optimize child plans first - let new_inputs = plan - .inputs() - .iter() - .map(|p| analyze_internal(p)) - .collect::>>()?; - - // get schema representing all available input fields. This is used for data type - // resolution only, so order does not matter here - let mut schema = merge_schema(new_inputs.iter().collect()); - - if let LogicalPlan::TableScan(ts) = plan { - let source_schema = - DFSchema::try_from_qualified_schema(&ts.table_name, &ts.source.schema())?; - schema.merge(&source_schema); - } - - let mut expr_rewrite = OperatorToFunctionRewriter { - #[cfg(feature = "array_expressions")] - schema: Arc::new(schema), - }; - - let new_expr = plan - .expressions() - .into_iter() - .map(|expr| { - // ensure names don't change: - // https://github.com/apache/arrow-datafusion/issues/3555 - rewrite_preserving_name(expr, &mut expr_rewrite) - }) - .collect::>>()?; - - plan.with_new_exprs(new_expr, new_inputs) -} - -pub(crate) struct OperatorToFunctionRewriter { - #[cfg(feature = "array_expressions")] - pub(crate) schema: DFSchemaRef, -} - -impl TreeNodeRewriter for OperatorToFunctionRewriter { - type Node = Expr; - - fn f_up(&mut self, expr: Expr) -> Result> { - #[cfg(feature = "array_expressions")] - if let Expr::BinaryExpr(BinaryExpr { - ref left, - op, - ref right, - }) = expr - { - if let Some(expr) = rewrite_array_concat_operator_to_func_for_column( - left.as_ref(), - op, - right.as_ref(), - self.schema.as_ref(), - )? - .or_else(|| { - rewrite_array_concat_operator_to_func(left.as_ref(), op, right.as_ref()) - }) { - // Convert &Box -> Expr - return Ok(Transformed::yes(expr)); - } - - // TODO: change OperatorToFunction to OperatoToArrayFunction and configure it with array_expressions feature - // after other array functions are udf-based - #[cfg(feature = "array_expressions")] - if let Some(expr) = rewrite_array_has_all_operator_to_func(left, op, right) { - return Ok(Transformed::yes(expr)); - } - } - - if let Expr::GetIndexedField(GetIndexedField { - ref expr, - ref field, - }) = expr - { - match field { - GetFieldAccess::NamedStructField { name, .. } => { - let expr = *expr.clone(); - let name = name.clone(); - let args = vec![expr, Expr::Literal(name)]; - return Ok(Transformed::yes(Expr::ScalarFunction( - ScalarFunction::new_udf( - Arc::new(ScalarUDF::new_from_impl( - datafusion_functions::core::getfield::GetFieldFunc::new(), - )), - args, - ), - ))); - } - GetFieldAccess::ListIndex { ref key } => { - let expr = *expr.clone(); - let key = *key.clone(); - let args = vec![expr, key]; - return Ok(Transformed::yes(Expr::ScalarFunction( - ScalarFunction::new(BuiltinScalarFunction::ArrayElement, args), - ))); - } - GetFieldAccess::ListRange { - start, - stop, - stride, - } => { - let expr = *expr.clone(); - let start = *start.clone(); - let stop = *stop.clone(); - let stride = *stride.clone(); - let args = vec![expr, start, stop, stride]; - return Ok(Transformed::yes(Expr::ScalarFunction( - ScalarFunction::new(BuiltinScalarFunction::ArraySlice, args), - ))); - } - } - } - - Ok(Transformed::no(expr)) - } -} - -// Note This rewrite is only done if the built in DataFusion `array_expressions` feature is enabled. -// Even if users implement their own array functions, those functions are not equal to the DataFusion -// udf based array functions, so this rewrite is not corrrect -#[cfg(feature = "array_expressions")] -fn rewrite_array_has_all_operator_to_func( - left: &Expr, - op: Operator, - right: &Expr, -) -> Option { - use super::array_has_all; - - if op != Operator::AtArrow && op != Operator::ArrowAt { - return None; - } - - match (left, right) { - // array1 @> array2 -> array_has_all(array1, array2) - // array1 <@ array2 -> array_has_all(array2, array1) - ( - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(left_fun), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(right_fun), - args: _right_args, - }), - ) if left_fun.name() == "make_array" && right_fun.name() == "make_array" => { - let left = left.clone(); - let right = right.clone(); - - let expr = if let Operator::ArrowAt = op { - array_has_all(right, left) - } else { - array_has_all(left, right) - }; - Some(expr) - } - _ => None, - } -} - -/// Summary of the logic below: -/// -/// 1) array || array -> array concat -/// -/// 2) array || scalar -> array append -/// -/// 3) scalar || array -> array prepend -/// -/// 4) (arry concat, array append, array prepend) || array -> array concat -/// -/// 5) (arry concat, array append, array prepend) || scalar -> array append -#[cfg(feature = "array_expressions")] -fn rewrite_array_concat_operator_to_func( - left: &Expr, - op: Operator, - right: &Expr, -) -> Option { - // Convert `Array StringConcat Array` to ScalarFunction::ArrayConcat - - if op != Operator::StringConcat { - return None; - } - - match (left, right) { - // Chain concat operator (a || b) || array, - // (arry concat, array append, array prepend) || array -> array concat - ( - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(left_fun), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(right_fun), - args: _right_args, - }), - ) if ["array_append", "array_prepend", "array_concat"] - .contains(&left_fun.name()) - && right_fun.name() == "make_array" => - { - Some(array_concat(vec![left.clone(), right.clone()])) - } - // Chain concat operator (a || b) || scalar, - // (arry concat, array append, array prepend) || scalar -> array append - ( - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(left_fun), - args: _left_args, - }), - _scalar, - ) if ["array_append", "array_prepend", "array_concat"] - .contains(&left_fun.name()) => - { - Some(array_append(left.clone(), right.clone())) - } - // array || array -> array concat - ( - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(left_fun), - args: _left_args, - }), - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(right_fun), - args: _right_args, - }), - ) if left_fun.name() == "make_array" && right_fun.name() == "make_array" => { - Some(array_concat(vec![left.clone(), right.clone()])) - } - // array || scalar -> array append - ( - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(left_fun), - args: _left_args, - }), - _right_scalar, - ) if left_fun.name() == "make_array" => { - Some(array_append(left.clone(), right.clone())) - } - // scalar || array -> array prepend - ( - _left_scalar, - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(right_fun), - args: _right_args, - }), - ) if right_fun.name() == "make_array" => { - Some(array_prepend(left.clone(), right.clone())) - } - - _ => None, - } -} - -/// Summary of the logic below: -/// -/// 1) (arry concat, array append, array prepend) || column -> (array append, array concat) -/// -/// 2) column1 || column2 -> (array prepend, array append, array concat) -#[cfg(feature = "array_expressions")] -fn rewrite_array_concat_operator_to_func_for_column( - left: &Expr, - op: Operator, - right: &Expr, - schema: &DFSchema, -) -> Result> { - if op != Operator::StringConcat { - return Ok(None); - } - - match (left, right) { - // Column cases: - // 1) array_prepend/append/concat || column - ( - Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::UDF(left_fun), - args: _left_args, - }), - Expr::Column(c), - ) if ["array_append", "array_prepend", "array_concat"] - .contains(&left_fun.name()) => - { - let d = schema.field_from_column(c)?.data_type(); - let ndim = list_ndims(d); - match ndim { - 0 => Ok(Some(array_append(left.clone(), right.clone()))), - _ => Ok(Some(array_concat(vec![left.clone(), right.clone()]))), - } - } - // 2) select column1 || column2 - (Expr::Column(c1), Expr::Column(c2)) => { - let d1 = schema.field_from_column(c1)?.data_type(); - let d2 = schema.field_from_column(c2)?.data_type(); - let ndim1 = list_ndims(d1); - let ndim2 = list_ndims(d2); - match (ndim1, ndim2) { - (0, _) => Ok(Some(array_prepend(left.clone(), right.clone()))), - (_, 0) => Ok(Some(array_append(left.clone(), right.clone()))), - _ => Ok(Some(array_concat(vec![left.clone(), right.clone()]))), - } - } - _ => Ok(None), - } -} From 61186fb3af27c4d8ab86dfdbd617041dee261afa Mon Sep 17 00:00:00 2001 From: Junhao Liu Date: Wed, 13 Mar 2024 11:20:01 -0500 Subject: [PATCH 16/16] use function_rewrite feature --- datafusion-cli/Cargo.lock | 32 +++++++++++------------ datafusion/functions-array/Cargo.toml | 1 + datafusion/functions-array/src/rewrite.rs | 10 +++++++ datafusion/functions/src/core/getfield.rs | 7 +---- datafusion/functions/src/core/mod.rs | 2 +- datafusion/optimizer/Cargo.toml | 1 - 6 files changed, 29 insertions(+), 24 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 9ff66444f1f93..5bc242af74e4b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -754,9 +754,9 @@ dependencies = [ [[package]] name = "blake3" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0231f06152bf547e9c2b5194f247cd97aacf6dcd8b15d8e5ec0663f64580da87" +checksum = "30cca6d3674597c30ddf2c587bf8d9d65c9a84d2326d941cc79c9842dfe0ef52" dependencies = [ "arrayref", "arrayvec", @@ -1271,6 +1271,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "itertools", "log", "paste", @@ -1285,7 +1286,6 @@ dependencies = [ "chrono", "datafusion-common", "datafusion-expr", - "datafusion-functions", "datafusion-physical-expr", "hashbrown 0.14.3", "itertools", @@ -2645,9 +2645,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.78" +version = "1.0.79" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" +checksum = "e835ff2298f5721608eb1a980ecaee1aef2c132bf95ecc026a11b7bf3c01c02e" dependencies = [ "unicode-ident", ] @@ -2774,9 +2774,9 @@ checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "reqwest" -version = "0.11.25" +version = "0.11.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eea5a9eb898d3783f17c6407670e3592fd174cb81a10e51d4c37f49450b9946" +checksum = "78bf93c4af7a8bb7d879d51cebe797356ff10ae8516ace542b5182d9dcac10b2" dependencies = [ "base64 0.21.7", "bytes", @@ -3330,20 +3330,20 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "system-configuration" -version = "0.6.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "658bc6ee10a9b4fcf576e9b0819d95ec16f4d2c02d39fd83ac1c8789785c4a42" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ - "bitflags 2.4.2", + "bitflags 1.3.2", "core-foundation", "system-configuration-sys", ] [[package]] name = "system-configuration-sys" -version = "0.6.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" dependencies = [ "core-foundation-sys", "libc", @@ -3384,18 +3384,18 @@ checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b" +checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.57" +version = "1.0.58" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" +checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", diff --git a/datafusion/functions-array/Cargo.toml b/datafusion/functions-array/Cargo.toml index ba7d9e26ecaf5..99239ffb3bdc3 100644 --- a/datafusion/functions-array/Cargo.toml +++ b/datafusion/functions-array/Cargo.toml @@ -44,6 +44,7 @@ arrow-schema = { workspace = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions = { workspace = true } itertools = { version = "0.12", features = ["use_std"] } log = { workspace = true } paste = "1.0.14" diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 368fad41af29b..a9e79f54a52d8 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -28,6 +28,7 @@ use datafusion_expr::expr_rewriter::FunctionRewrite; use datafusion_expr::{ BinaryExpr, BuiltinScalarFunction, Expr, GetFieldAccess, GetIndexedField, Operator, }; +use datafusion_functions::expr_fn::get_field; /// Rewrites expressions into function calls to array functions pub(crate) struct ArrayFunctionRewriter {} @@ -147,6 +148,15 @@ impl FunctionRewrite for ArrayFunctionRewriter { Transformed::yes(array_prepend(*left, *right)) } + Expr::GetIndexedField(GetIndexedField { + expr, + field: GetFieldAccess::NamedStructField { name }, + }) => { + let expr = *expr.clone(); + let name = Expr::Literal(name); + Transformed::yes(get_field(expr, name.clone())) + } + // expr[idx] ==> array_element(expr, idx) Expr::GetIndexedField(GetIndexedField { expr, diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index f8f116b99e5be..0a99cccf9e1c4 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -25,7 +25,7 @@ use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use std::any::Any; #[derive(Debug)] -pub struct GetFieldFunc { +pub(super) struct GetFieldFunc { signature: Signature, } @@ -36,11 +36,6 @@ impl GetFieldFunc { } } } -impl Default for GetFieldFunc { - fn default() -> Self { - Self::new() - } -} // get_field(struct_array, field_name) impl ScalarUDFImpl for GetFieldFunc { diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index c866513de05af..73cc4d18bf9f7 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -18,7 +18,7 @@ //! "core" DataFusion functions mod arrowtypeof; -pub mod getfield; +mod getfield; mod nullif; mod nvl; mod nvl2; diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 3034f28f8688b..f497f2ec86027 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -45,7 +45,6 @@ async-trait = { workspace = true } chrono = { workspace = true } datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } -datafusion-functions = { workspace = true } datafusion-functions-array = { workspace = true, optional = true } datafusion-physical-expr = { workspace = true } hashbrown = { version = "0.14", features = ["raw"] }