-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-10163: [Rust] [DataFusion] Add DictionaryArray coercion support #8463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1080,7 +1080,40 @@ impl fmt::Display for BinaryExpr { | |
| } | ||
| } | ||
|
|
||
| // the type that both lhs and rhs can be casted to for the purpose of a string computation | ||
| /// Coercion rules for dictionary values (aka the type of the dictionary itself) | ||
| fn dictionary_value_coercion( | ||
| lhs_type: &DataType, | ||
| rhs_type: &DataType, | ||
| ) -> Option<DataType> { | ||
| numerical_coercion(lhs_type, rhs_type).or_else(|| string_coercion(lhs_type, rhs_type)) | ||
| } | ||
|
|
||
| /// Coercion rules for Dictionaries: the type that both lhs and rhs | ||
| /// can be casted to for the purpose of a computation. | ||
| /// | ||
| /// It would likely be preferable to cast primitive values to | ||
| /// dictionaries, and thus avoid unpacking dictionary as well as doing | ||
| /// faster comparisons. However, the arrow compute kernels (e.g. eq) | ||
| /// don't have DictionaryArray support yet, so fall back to unpacking | ||
| /// the dictionaries | ||
| fn dictionary_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | ||
| match (lhs_type, rhs_type) { | ||
| ( | ||
| DataType::Dictionary(_lhs_index_type, lhs_value_type), | ||
| DataType::Dictionary(_rhs_index_type, rhs_value_type), | ||
| ) => dictionary_value_coercion(lhs_value_type, rhs_value_type), | ||
| (DataType::Dictionary(_index_type, value_type), _) => { | ||
| dictionary_value_coercion(value_type, rhs_type) | ||
| } | ||
| (_, DataType::Dictionary(_index_type, value_type)) => { | ||
| dictionary_value_coercion(lhs_type, value_type) | ||
| } | ||
| _ => None, | ||
| } | ||
| } | ||
|
|
||
| /// Coercion rules for Strings: the type that both lhs and rhs can be | ||
| /// casted to for the purpose of a string computation | ||
| fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | ||
| use arrow::datatypes::DataType::*; | ||
| match (lhs_type, rhs_type) { | ||
|
|
@@ -1092,7 +1125,9 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> | |
| } | ||
| } | ||
|
|
||
| /// coercion rule for numerical types | ||
| /// Coercion rule for numerical types: The type that both lhs and rhs | ||
| /// can be casted to for numerical calculation, while maintaining | ||
| /// maximum precision | ||
| pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | ||
| use arrow::datatypes::DataType::*; | ||
|
|
||
|
|
@@ -1150,6 +1185,7 @@ fn eq_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> { | |
| return Some(lhs_type.clone()); | ||
| } | ||
| numerical_coercion(lhs_type, rhs_type) | ||
| .or_else(|| dictionary_coercion(lhs_type, rhs_type)) | ||
| } | ||
|
|
||
| // coercion rules that assume an ordered set, such as "less than". | ||
|
|
@@ -1160,16 +1196,13 @@ fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> | |
| return Some(lhs_type.clone()); | ||
| } | ||
|
|
||
| match numerical_coercion(lhs_type, rhs_type) { | ||
| None => { | ||
| // strings are naturally ordered, and thus ordering can be applied to them. | ||
| string_coercion(lhs_type, rhs_type) | ||
| } | ||
| t => t, | ||
| } | ||
| numerical_coercion(lhs_type, rhs_type) | ||
| .or_else(|| string_coercion(lhs_type, rhs_type)) | ||
| .or_else(|| dictionary_coercion(lhs_type, rhs_type)) | ||
| } | ||
|
|
||
| /// coercion rules for all binary operators | ||
| /// Coercion rules for all binary operators. Returns the output type | ||
| /// of applying `op` to an argument of `lhs_type` and `rhs_type`. | ||
| fn common_binary_type( | ||
| lhs_type: &DataType, | ||
| op: &Operator, | ||
|
|
@@ -1526,8 +1559,8 @@ impl PhysicalExpr for CastExpr { | |
| } | ||
| } | ||
|
|
||
| /// Returns a physical cast operation that casts `expr` to `cast_type` | ||
| /// if casting is needed. | ||
| /// Return a PhysicalExpression representing `expr` casted to | ||
| /// `cast_type`, if any casting is needed. | ||
| /// | ||
| /// Note that such casts may lose type information | ||
| pub fn cast( | ||
|
|
@@ -1665,11 +1698,14 @@ impl PhysicalSortExpr { | |
| mod tests { | ||
| use super::*; | ||
| use crate::error::Result; | ||
| use arrow::array::{ | ||
| LargeStringArray, PrimitiveArray, PrimitiveArrayOps, StringArray, | ||
| Time64NanosecondArray, | ||
| }; | ||
| use arrow::datatypes::*; | ||
| use arrow::{ | ||
| array::{ | ||
| LargeStringArray, PrimitiveArray, PrimitiveArrayOps, PrimitiveBuilder, | ||
| StringArray, StringDictionaryBuilder, Time64NanosecondArray, | ||
| }, | ||
| util::display::array_value_to_string, | ||
| }; | ||
|
|
||
| // Create a binary expression without coercion. Used here when we do not want to coerce the expressions | ||
| // to valid types. Usage can result in an execution (after plan) error. | ||
|
|
@@ -1772,11 +1808,13 @@ mod tests { | |
|
|
||
| // runs an end-to-end test of physical type coercion: | ||
| // 1. construct a record batch with two columns of type A and B | ||
| // (*_ARRAY is the Rust Arrow array type, and *_TYPE is the DataType of the elements) | ||
| // 2. construct a physical expression of A OP B | ||
| // 3. evaluate the expression | ||
| // 4. verify that the resulting expression is of type C | ||
| // 5. verify that the results of evaluation are $VEC | ||
| macro_rules! test_coercion { | ||
| ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $TYPEARRAY:ident, $TYPE:expr, $VEC:expr) => {{ | ||
|
||
| ($A_ARRAY:ident, $A_TYPE:expr, $A_VEC:expr, $B_ARRAY:ident, $B_TYPE:expr, $B_VEC:expr, $OP:expr, $C_ARRAY:ident, $C_TYPE:expr, $VEC:expr) => {{ | ||
| let schema = Schema::new(vec![ | ||
| Field::new("a", $A_TYPE, false), | ||
| Field::new("b", $B_TYPE, false), | ||
|
|
@@ -1792,18 +1830,18 @@ mod tests { | |
| let expression = binary(col("a"), $OP, col("b"), &schema)?; | ||
|
|
||
| // verify that the expression's type is correct | ||
| assert_eq!(expression.data_type(&schema)?, $TYPE); | ||
| assert_eq!(expression.data_type(&schema)?, $C_TYPE); | ||
|
|
||
| // compute | ||
| let result = expression.evaluate(&batch)?; | ||
|
|
||
| // verify that the array's data_type is correct | ||
| assert_eq!(*result.data_type(), $TYPE); | ||
| assert_eq!(*result.data_type(), $C_TYPE); | ||
|
|
||
| // verify that the data itself is downcastable | ||
| let result = result | ||
| .as_any() | ||
| .downcast_ref::<$TYPEARRAY>() | ||
| .downcast_ref::<$C_ARRAY>() | ||
| .expect("failed to downcast"); | ||
| // verify that the result itself is correct | ||
| for (i, x) in $VEC.iter().enumerate() { | ||
|
|
@@ -1877,6 +1915,107 @@ mod tests { | |
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_dictionary_type_coersion() -> Result<()> { | ||
| use DataType::*; | ||
|
|
||
| // TODO: In the future, this would ideally return Dictionary types and avoid unpacking | ||
| let lhs_type = Dictionary(Box::new(Int8), Box::new(Int32)); | ||
| let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); | ||
| assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Int32)); | ||
|
|
||
| let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); | ||
| let rhs_type = Dictionary(Box::new(Int8), Box::new(Int16)); | ||
| assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), None); | ||
|
|
||
| let lhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); | ||
| let rhs_type = Utf8; | ||
| assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); | ||
|
|
||
| let lhs_type = Utf8; | ||
| let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8)); | ||
| assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8)); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| // Note it would be nice to use the same test_coercion macro as | ||
| // above, but sadly the type of the values of the dictionary are | ||
| // not encoded in the rust type of the DictionaryArray. Thus there | ||
| // is no way at the time of this writing to create a dictionary | ||
| // array using the `From` trait | ||
| #[test] | ||
| fn test_dictionary_type_to_array_coersion() -> Result<()> { | ||
| // Test string a string dictionary | ||
| let dict_type = | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); | ||
| let string_type = DataType::Utf8; | ||
|
|
||
| // build dictionary | ||
| let keys_builder = PrimitiveBuilder::<Int32Type>::new(10); | ||
| let values_builder = StringBuilder::new(10); | ||
| let mut dict_builder = StringDictionaryBuilder::new(keys_builder, values_builder); | ||
|
|
||
| dict_builder.append("one")?; | ||
| dict_builder.append_null()?; | ||
| dict_builder.append("three")?; | ||
| dict_builder.append("four")?; | ||
| let dict_array = dict_builder.finish(); | ||
|
|
||
| let str_array = | ||
| StringArray::from(vec![Some("not one"), Some("two"), None, Some("four")]); | ||
|
|
||
| let schema = Arc::new(Schema::new(vec![ | ||
| Field::new("dict", dict_type.clone(), true), | ||
| Field::new("str", string_type.clone(), true), | ||
| ])); | ||
|
|
||
| let batch = RecordBatch::try_new( | ||
| schema.clone(), | ||
| vec![Arc::new(dict_array), Arc::new(str_array)], | ||
| )?; | ||
|
|
||
| let expected = "false\n\n\ntrue"; | ||
|
|
||
| // Test 1: dict = str | ||
|
|
||
| // verify that we can construct the expression | ||
| let expression = binary(col("dict"), Operator::Eq, col("str"), &schema)?; | ||
| assert_eq!(expression.data_type(&schema)?, DataType::Boolean); | ||
|
|
||
| // evaluate and verify the result type matched | ||
| let result = expression.evaluate(&batch)?; | ||
| assert_eq!(result.data_type(), &DataType::Boolean); | ||
|
|
||
| // verify that the result itself is correct | ||
| assert_eq!(expected, array_to_string(&result)?); | ||
|
|
||
| // Test 2: now test the other direction | ||
| // str = dict | ||
|
|
||
| // verify that we can construct the expression | ||
| let expression = binary(col("str"), Operator::Eq, col("dict"), &schema)?; | ||
| assert_eq!(expression.data_type(&schema)?, DataType::Boolean); | ||
|
|
||
| // evaluate and verify the result type matched | ||
| let result = expression.evaluate(&batch)?; | ||
| assert_eq!(result.data_type(), &DataType::Boolean); | ||
|
|
||
| // verify that the result itself is correct | ||
| assert_eq!(expected, array_to_string(&result)?); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| // Convert the array to a newline delimited string of pretty printed values | ||
| fn array_to_string(array: &ArrayRef) -> Result<String> { | ||
| let s = (0..array.len()) | ||
| .map(|i| array_value_to_string(array, i)) | ||
| .collect::<std::result::Result<Vec<_>, arrow::error::ArrowError>>()? | ||
| .join("\n"); | ||
| Ok(s) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_coersion_error() -> Result<()> { | ||
| let expr = | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,8 +21,8 @@ use std::sync::Arc; | |
| extern crate arrow; | ||
| extern crate datafusion; | ||
|
|
||
| use arrow::record_batch::RecordBatch; | ||
| use arrow::{array::*, datatypes::TimeUnit}; | ||
| use arrow::{datatypes::Int32Type, record_batch::RecordBatch}; | ||
| use arrow::{ | ||
| datatypes::{DataType, Field, Schema, SchemaRef}, | ||
| util::display::array_value_to_string, | ||
|
|
@@ -930,14 +930,20 @@ fn register_alltypes_parquet(ctx: &mut ExecutionContext) { | |
| /// Execute query and return result set as 2-d table of Vecs | ||
| /// `result[row][column]` | ||
| async fn execute(ctx: &mut ExecutionContext, sql: &str) -> Vec<Vec<String>> { | ||
| let plan = ctx.create_logical_plan(&sql).unwrap(); | ||
| let msg = format!("Creating logical plan for '{}'", sql); | ||
| let plan = ctx.create_logical_plan(&sql).expect(&msg); | ||
jorgecarleitao marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| let logical_schema = plan.schema(); | ||
| let plan = ctx.optimize(&plan).unwrap(); | ||
|
|
||
| let msg = format!("Optimizing logical plan for '{}': {:?}", sql, plan); | ||
| let plan = ctx.optimize(&plan).expect(&msg); | ||
| let optimized_logical_schema = plan.schema(); | ||
| let plan = ctx.create_physical_plan(&plan).unwrap(); | ||
|
|
||
| let msg = format!("Creating physical plan for '{}': {:?}", sql, plan); | ||
| let plan = ctx.create_physical_plan(&plan).expect(&msg); | ||
| let physical_schema = plan.schema(); | ||
|
|
||
| let results = ctx.collect(plan).await.unwrap(); | ||
| let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); | ||
| let results = ctx.collect(plan).await.expect(&msg); | ||
|
|
||
| assert_eq!(logical_schema.as_ref(), optimized_logical_schema.as_ref()); | ||
| assert_eq!(logical_schema.as_ref(), physical_schema.as_ref()); | ||
|
|
@@ -1238,3 +1244,59 @@ async fn query_count_distinct() -> Result<()> { | |
| assert_eq!(expected, actual); | ||
| Ok(()) | ||
| } | ||
|
|
||
|
||
| #[tokio::test] | ||
| async fn query_on_string_dictionary() -> Result<()> { | ||
| // Test to ensure DataFusion can operate on dictionary types | ||
| // Use StringDictionary (32 bit indexes = keys) | ||
| let field_type = | ||
| DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); | ||
| let schema = Arc::new(Schema::new(vec![Field::new("d1", field_type, true)])); | ||
|
|
||
| let keys_builder = PrimitiveBuilder::<Int32Type>::new(10); | ||
| let values_builder = StringBuilder::new(10); | ||
| let mut builder = StringDictionaryBuilder::new(keys_builder, values_builder); | ||
|
|
||
| builder.append("one")?; | ||
| builder.append_null()?; | ||
| builder.append("three")?; | ||
| let array = Arc::new(builder.finish()); | ||
|
|
||
| let data = RecordBatch::try_new(schema.clone(), vec![array])?; | ||
|
|
||
| let table = MemTable::new(schema, vec![vec![data]])?; | ||
| let mut ctx = ExecutionContext::new(); | ||
| ctx.register_table("test", Box::new(table)); | ||
|
|
||
| // Basic SELECT | ||
| let sql = "SELECT * FROM test"; | ||
| let actual = execute(&mut ctx, sql).await; | ||
| let expected = vec![vec!["one"], vec!["NULL"], vec!["three"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| // basic filtering | ||
| let sql = "SELECT * FROM test WHERE d1 IS NOT NULL"; | ||
| let actual = execute(&mut ctx, sql).await; | ||
| let expected = vec![vec!["one"], vec!["three"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| // filtering with constant | ||
| let sql = "SELECT * FROM test WHERE d1 = 'three'"; | ||
| let actual = execute(&mut ctx, sql).await; | ||
| let expected = vec![vec!["three"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| // Expression evaluation | ||
| let sql = "SELECT concat(d1, '-foo') FROM test"; | ||
| let actual = execute(&mut ctx, sql).await; | ||
| let expected = vec![vec!["one-foo"], vec!["NULL"], vec!["three-foo"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| // aggregation | ||
| let sql = "SELECT COUNT(d1) FROM test"; | ||
| let actual = execute(&mut ctx, sql).await; | ||
| let expected = vec![vec!["2"]]; | ||
| assert_eq!(expected, actual); | ||
|
|
||
| Ok(()) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is just a refactor, it is not meant to change the semantics