Implement `TypeSignature::get_possible_types` for more signature kinds.

* Uniform, Numeric, String: one candidate per allowed data type, with that
  type repeated `arg_count` times.
* Variadic: one single-argument candidate per allowed data type.
* Coercible: each logical type is expanded to its concrete Arrow data types
  by the new private helper `get_data_types`, and the candidates are the
  cartesian product of those per-argument lists. `get_data_types` returns an
  empty list for native types it does not handle yet (see the TODO).
* Any, VariadicAny, ArraySignature and UserDefined still return no candidates.
* `STRINGS` gains `DataType::Utf8View`, so `TypeSignature::String` lists it.

diff --git a/datafusion/expr-common/src/signature.rs b/datafusion/expr-common/src/signature.rs
index be69d3b809c09..3846fae5de5dc 100644
--- a/datafusion/expr-common/src/signature.rs
+++ b/datafusion/expr-common/src/signature.rs
@@ -18,8 +18,10 @@
 //! Signature module contains foundational types that are used to represent signatures, types,
 //! and return types of functions in DataFusion.
 
+use crate::type_coercion::aggregates::{NUMERICS, STRINGS};
 use arrow::datatypes::DataType;
-use datafusion_common::types::LogicalTypeRef;
+use datafusion_common::types::{LogicalTypeRef, NativeType};
+use itertools::Itertools;
 
 /// Constant that is used as a placeholder for any valid timezone.
 /// This is used where a function can accept a timestamp type with any
@@ -258,17 +260,66 @@ impl TypeSignature {
                 .iter()
                 .flat_map(|type_sig| type_sig.get_possible_types())
                 .collect(),
+            TypeSignature::Uniform(arg_count, types) => types
+                .iter()
+                .cloned()
+                .map(|data_type| vec![data_type; *arg_count])
+                .collect(),
+            TypeSignature::Coercible(types) => types
+                .iter()
+                .map(|logical_type| get_data_types(logical_type.native()))
+                .multi_cartesian_product()
+                .collect(),
+            TypeSignature::Variadic(types) => types
+                .iter()
+                .cloned()
+                .map(|data_type| vec![data_type])
+                .collect(),
+            TypeSignature::Numeric(arg_count) => NUMERICS
+                .iter()
+                .cloned()
+                .map(|numeric_type| vec![numeric_type; *arg_count])
+                .collect(),
+            TypeSignature::String(arg_count) => STRINGS
+                .iter()
+                .cloned()
+                .map(|string_type| vec![string_type; *arg_count])
+                .collect(),
             // TODO: Implement for other types
-            TypeSignature::Uniform(_, _)
-            | TypeSignature::Coercible(_)
-            | TypeSignature::Any(_)
-            | TypeSignature::Variadic(_)
+            TypeSignature::Any(_)
             | TypeSignature::VariadicAny
-            | TypeSignature::UserDefined
             | TypeSignature::ArraySignature(_)
-            | TypeSignature::Numeric(_)
-            | TypeSignature::String(_) => vec![],
+            | TypeSignature::UserDefined => vec![],
+        }
+    }
+}
+
+fn get_data_types(native_type: &NativeType) -> Vec<DataType> {
+    match native_type {
+        NativeType::Null => vec![DataType::Null],
+        NativeType::Boolean => vec![DataType::Boolean],
+        NativeType::Int8 => vec![DataType::Int8],
+        NativeType::Int16 => vec![DataType::Int16],
+        NativeType::Int32 => vec![DataType::Int32],
+        NativeType::Int64 => vec![DataType::Int64],
+        NativeType::UInt8 => vec![DataType::UInt8],
+        NativeType::UInt16 => vec![DataType::UInt16],
+        NativeType::UInt32 => vec![DataType::UInt32],
+        NativeType::UInt64 => vec![DataType::UInt64],
+        NativeType::Float16 => vec![DataType::Float16],
+        NativeType::Float32 => vec![DataType::Float32],
+        NativeType::Float64 => vec![DataType::Float64],
+        NativeType::Date => vec![DataType::Date32, DataType::Date64],
+        NativeType::Binary => vec![
+            DataType::Binary,
+            DataType::LargeBinary,
+            DataType::BinaryView,
+        ],
+        NativeType::String => {
+            vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]
         }
+        // TODO: support other native types
+        _ => vec![],
     }
 }
 
@@ -417,6 +468,8 @@ impl Signature {
 
 #[cfg(test)]
 mod tests {
+    use datafusion_common::types::{logical_int64, logical_string};
+
     use super::*;
 
     #[test]
@@ -515,5 +568,65 @@ mod tests {
                 vec![DataType::Utf8]
             ]
         );
+
+        let type_signature =
+            TypeSignature::Uniform(2, vec![DataType::Float32, DataType::Int64]);
+        let possible_types = type_signature.get_possible_types();
+        assert_eq!(
+            possible_types,
+            vec![
+                vec![DataType::Float32, DataType::Float32],
+                vec![DataType::Int64, DataType::Int64]
+            ]
+        );
+
+        let type_signature =
+            TypeSignature::Coercible(vec![logical_string(), logical_int64()]);
+        let possible_types = type_signature.get_possible_types();
+        assert_eq!(
+            possible_types,
+            vec![
+                vec![DataType::Utf8, DataType::Int64],
+                vec![DataType::LargeUtf8, DataType::Int64],
+                vec![DataType::Utf8View, DataType::Int64]
+            ]
+        );
+
+        let type_signature =
+            TypeSignature::Variadic(vec![DataType::Int32, DataType::Int64]);
+        let possible_types = type_signature.get_possible_types();
+        assert_eq!(
+            possible_types,
+            vec![vec![DataType::Int32], vec![DataType::Int64]]
+        );
+
+        let type_signature = TypeSignature::Numeric(2);
+        let possible_types = type_signature.get_possible_types();
+        assert_eq!(
+            possible_types,
+            vec![
+                vec![DataType::Int8, DataType::Int8],
+                vec![DataType::Int16, DataType::Int16],
+                vec![DataType::Int32, DataType::Int32],
+                vec![DataType::Int64, DataType::Int64],
+                vec![DataType::UInt8, DataType::UInt8],
+                vec![DataType::UInt16, DataType::UInt16],
+                vec![DataType::UInt32, DataType::UInt32],
+                vec![DataType::UInt64, DataType::UInt64],
+                vec![DataType::Float32, DataType::Float32],
+                vec![DataType::Float64, DataType::Float64]
+            ]
+        );
+
+        let type_signature = TypeSignature::String(2);
+        let possible_types = type_signature.get_possible_types();
+        assert_eq!(
+            possible_types,
+            vec![
+                vec![DataType::Utf8, DataType::Utf8],
+                vec![DataType::LargeUtf8, DataType::LargeUtf8],
+                vec![DataType::Utf8View, DataType::Utf8View]
+            ]
+        );
     }
 }
diff --git a/datafusion/expr-common/src/type_coercion/aggregates.rs b/datafusion/expr-common/src/type_coercion/aggregates.rs
index fee75f9e45959..384d688cc27ed 100644
--- a/datafusion/expr-common/src/type_coercion/aggregates.rs
+++ b/datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -23,7 +23,8 @@ use arrow::datatypes::{
 
 use datafusion_common::{internal_err, plan_err, Result};
 
-pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];
+pub static STRINGS: &[DataType] =
+    &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];
 
 pub static SIGNED_INTEGERS: &[DataType] = &[
     DataType::Int8,