diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 084f8186c5e..e9bbe193ff3 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -1080,7 +1080,40 @@ impl fmt::Display for BinaryExpr { } } -// the type that both lhs and rhs can be casted to for the purpose of a string computation +/// Coercion rules for dictionary values (aka the type of the dictionary itself) +fn dictionary_value_coercion( + lhs_type: &DataType, + rhs_type: &DataType, +) -> Option { + numerical_coercion(lhs_type, rhs_type).or_else(|| string_coercion(lhs_type, rhs_type)) +} + +/// Coercion rules for Dictionaries: the type that both lhs and rhs +/// can be casted to for the purpose of a computation. +/// +/// It would likely be preferable to cast primitive values to +/// dictionaries, and thus avoid unpacking dictionary as well as doing +/// faster comparisons. However, the arrow compute kernels (e.g. eq) +/// don't have DictionaryArray support yet, so fall back to unpacking +/// the dictionaries +fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { + match (lhs_type, rhs_type) { + ( + DataType::Dictionary(_lhs_index_type, lhs_value_type), + DataType::Dictionary(_rhs_index_type, rhs_value_type), + ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), + (DataType::Dictionary(_index_type, value_type), _) => { + dictionary_value_coercion(value_type, rhs_type) + } + (_, DataType::Dictionary(_index_type, value_type)) => { + dictionary_value_coercion(lhs_type, value_type) + } + _ => None, + } +} + +/// Coercion rules for Strings: the type that both lhs and rhs can be +/// casted to for the purpose of a string computation fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; match (lhs_type, rhs_type) { @@ -1092,7 +1125,9 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option } } -/// coercion rule for numerical types +/// Coercion rule for numerical types: The type that both lhs and rhs +/// can be casted to for numerical calculation, while maintaining +/// maximum precision pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { use arrow::datatypes::DataType::*; @@ -1150,6 +1185,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option { return Some(lhs_type.clone()); } numerical_coercion(lhs_type, rhs_type) + .or_else(|| dictionary_coercion(lhs_type, rhs_type)) } // coercion rules that assume an ordered set, such as "less than". @@ -1160,16 +1196,13 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option return Some(lhs_type.clone()); } - match numerical_coercion(lhs_type, rhs_type) { - None => { - // strings are naturally ordered, and thus ordering can be applied to them. - string_coercion(lhs_type, rhs_type) - } - t => t, - } + numerical_coercion(lhs_type, rhs_type) + .or_else(|| string_coercion(lhs_type, rhs_type)) + .or_else(|| dictionary_coercion(lhs_type, rhs_type)) } -/// coercion rules for all binary operators +/// Coercion rules for all binary operators. Returns the output type +/// of applying `op` to an argument of `lhs_type` and `rhs_type`. fn common_binary_type( lhs_type: &DataType, op: &Operator, @@ -1526,8 +1559,8 @@ impl PhysicalExpr for CastExpr { } } -/// Returns a physical cast operation that casts `expr` to `cast_type` -/// if casting is needed. +/// Return a PhysicalExpression representing `expr` casted to +/// `cast_type`, if any casting is needed. /// /// Note that such casts may lose type information pub fn cast( @@ -1665,11 +1698,14 @@ impl PhysicalSortExpr { mod tests { use super::*; use crate::error::Result; - use arrow::array::{ - LargeStringArray, PrimitiveArray, PrimitiveArrayOps, StringArray, - Time64NanosecondArray, - }; use arrow::datatypes::*; + use arrow::{ + array::{ + LargeStringArray, PrimitiveArray, PrimitiveArrayOps, PrimitiveBuilder, + StringArray, StringDictionaryBuilder, Time64NanosecondArray, + }, + util::display::array_value_to_string, + }; // Create a binary expression without coercion. Used here when we do not want to coerce the expressions // to valid types. Usage can result in an execution (after plan) error. @@ -1772,11 +1808,13 @@ mod tests { // runs an end-to-end test of physical type coercion: // 1. construct a record batch with two columns of type A and B + // (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements) // 2. construct a physical expression of A OP B // 3. evaluate the expression // 4. verify that the resulting expression is of type C + // 5. verify that the results of evaluation are $VEC macro_rules! test_coercion { - ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ + ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{ let schema = Schema::new(vec![ Field::new("a", $A_TYPE, false), Field::new("b", $B_TYPE, false), @@ -1792,18 +1830,18 @@ mod tests { let expression = binary(col("a"), $OP, col("b"), &schema)?; // verify that the expression's type is correct - assert_eq!(expression.data_type(&schema)?, $TYPE); + assert_eq!(expression.data_type(&schema)?, $C_TYPE); // compute let result = expression.evaluate(&batch)?; // verify that the array's data_type is correct - assert_eq!(*result.data_type(), $TYPE); + assert_eq!(*result.data_type(), $C_TYPE); // verify that the data itself is downcastable let result = result .as_any() - .downcast_ref::<$TYPEARRAY>() + .downcast_ref::<$C_ARRAY>() .expect("failed to downcast"); // verify that the result itself is correct for (i, x) in $VEC.iter().enumerate() { @@ -1877,6 +1915,107 @@ mod tests { Ok(()) } + #[test] + fn test_dictionary_type_coersion() -> Result<()> { + use DataType::*; + + // TODO: In the future, this would ideally return Dictionary types and avoid unpacking + let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32)); + + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); + + let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + let rhs_type = Utf8; + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + + let lhs_type = Utf8; + let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); + assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); + + Ok(()) + } + + // Note it would be nice to use the same test_coercion macro as + // above, but sadly the type of the values of the dictionary are + // not encoded in the rust type of the DictionaryArray. Thus there + // is no way at the time of this writing to create a dictionary + // array using the `From` trait + #[test] + fn test_dictionary_type_to_array_coersion() -> Result<()> { + // Test string a string dictionary + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let string_type = DataType::Utf8; + + // build dictionary + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + dict_builder.append("one")?; + dict_builder.append_null()?; + dict_builder.append("three")?; + dict_builder.append("four")?; + let dict_array = dict_builder.finish(); + + let str_array = + StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]); + + let schema = Arc::new(Schema::new(vec![ + Field::new("dict", dict_type.clone(), true), + Field::new("str", string_type.clone(), true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(dict_array), Arc::new(str_array)], + )?; + + let expected = "false\n\n\ntrue"; + + // Test 1: dict = str + + // verify that we can construct the expression + let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?; + assert_eq!(expression.data_type(&schema)?, DataType::Boolean); + + // evaluate and verify the result type matched + let result = expression.evaluate(&batch)?; + assert_eq!(result.data_type(), &DataType::Boolean); + + // verify that the result itself is correct + assert_eq!(expected, array_to_string(&result)?); + + // Test 2: now test the other direction + // str = dict + + // verify that we can construct the expression + let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?; + assert_eq!(expression.data_type(&schema)?, DataType::Boolean); + + // evaluate and verify the result type matched + let result = expression.evaluate(&batch)?; + assert_eq!(result.data_type(), &DataType::Boolean); + + // verify that the result itself is correct + assert_eq!(expected, array_to_string(&result)?); + + Ok(()) + } + + // Convert the array to a newline delimited string of pretty printed values + fn array_to_string(array: &ArrayRef) -> Result { + let s = (0..array.len()) + .map(|i| array_value_to_string(array, i)) + .collect::, arrow::error::ArrowError>>()? + .join("\n"); + Ok(s) + } + #[test] fn test_coersion_error() -> Result<()> { let expr = diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 52027a4080b..5bf45251716 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -21,8 +21,8 @@ use std::sync::Arc; extern crate arrow; extern crate datafusion; -use arrow::record_batch::RecordBatch; use arrow::{array::*, datatypes::TimeUnit}; +use arrow::{datatypes::Int32Type, record_batch::RecordBatch}; use arrow::{ datatypes::{DataType, Field, Schema, SchemaRef}, util::display::array_value_to_string, @@ -930,14 +930,20 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) { /// Execute query and return result set as 2-d table of Vecs /// `result[row][column]` async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec> { - let plan = ctx.create_logical_plan(&sql).unwrap(); + let msg = format!("Creating logical plan for '{}'", sql); + let plan = ctx.create_logical_plan(&sql).expect(&msg); let logical_schema = plan.schema(); - let plan = ctx.optimize(&plan).unwrap(); + + let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); + let plan = ctx.optimize(&plan).expect(&msg); let optimized_logical_schema = plan.schema(); - let plan = ctx.create_physical_plan(&plan).unwrap(); + + let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); + let plan = ctx.create_physical_plan(&plan).expect(&msg); let physical_schema = plan.schema(); - let results = ctx.collect(plan).await.unwrap(); + let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); + let results = ctx.collect(plan).await.expect(&msg); assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); assert_eq!(logical_schema.as_ref(), physical_schema.as_ref()); @@ -1238,3 +1244,59 @@ async fn query_count_distinct() -> Result<()> { assert_eq!(expected, actual); Ok(()) } + +#[tokio::test] +async fn query_on_string_dictionary() -> Result<()> { + // Test to ensure DataFusion can operate on dictionary types + // Use StringDictionary (32 bit indexes = keys) + let field_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); + + let keys_builder = PrimitiveBuilder::::new(10); + let values_builder = StringBuilder::new(10); + let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); + + builder.append("one")?; + builder.append_null()?; + builder.append("three")?; + let array = Arc::new(builder.finish()); + + let data = RecordBatch::try_new(schema.clone(), vec![array])?; + + let table = MemTable::new(schema, vec![vec![data]])?; + let mut ctx = ExecutionContext::new(); + ctx.register_table("test", Box::new(table)); + + // Basic SELECT + let sql = "SELECT * FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]]; + assert_eq!(expected, actual); + + // basic filtering + let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one"], vec!["three"]]; + assert_eq!(expected, actual); + + // filtering with constant + let sql = "SELECT * FROM test WHERE d1 = 'three'"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["three"]]; + assert_eq!(expected, actual); + + // Expression evaluation + let sql = "SELECT concat(d1, '-foo') FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]]; + assert_eq!(expected, actual); + + // aggregation + let sql = "SELECT COUNT(d1) FROM test"; + let actual = execute(&mut ctx, sql).await; + let expected = vec![vec!["2"]]; + assert_eq!(expected, actual); + + Ok(()) +}