Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion datafusion/src/field_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field};
use crate::error::{DataFusionError, Result};
use crate::scalar::ScalarValue;

/// Returns the field access indexed by `key` from a [`DataType::List`]
/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`]
/// # Error
/// Errors if
/// * the `data_type` is not a Struct or,
Expand All @@ -39,6 +39,25 @@ pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Fiel
Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
}
}
(DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
if s.is_empty() {
Err(DataFusionError::Plan(
"Struct based indexed access requires a non empty string".to_string(),
))
} else {
let field = fields.iter().find(|f| f.name() == s);
match field {
None => Err(DataFusionError::Plan(format!(
"Field {} not found in struct",
s
))),
Some(f) => Ok(f.clone()),
}
}
}
(DataType::Struct(_), _) => Err(DataFusionError::Plan(
"Only utf8 strings are valid as an indexed field in a struct".to_string(),
)),
(DataType::List(_), _) => Err(DataFusionError::Plan(
"Only ints are valid as an indexed field in a list".to_string(),
)),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ pub enum Expr {
IsNull(Box<Expr>),
/// arithmetic negation of an expression, the operand must be of a signed numeric data type
Negative(Box<Expr>),
/// Returns the field of a [`ListArray`] by key
/// Returns the field of a [`ListArray`] or [`StructArray`] by key
GetIndexedField {
/// the expression to take the field from
expr: Box<Expr>,
Expand Down
167 changes: 155 additions & 12 deletions datafusion/src/physical_plan/expressions/get_indexed_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ use crate::{
field_util::get_indexed_field as get_data_type_field,
physical_plan::{ColumnarValue, PhysicalExpr},
};
use arrow::array::ListArray;
use arrow::array::{ListArray, StructArray};
use std::fmt::Debug;

/// expression to get a field of a struct array.
Expand Down Expand Up @@ -81,7 +81,7 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let arg = self.arg.evaluate(batch)?;
match arg {
ColumnarValue::Array(array) => match (array.data_type(), &self.key) {
(DataType::List(_), _) if self.key.is_null() => {
(DataType::List(_) | DataType::Struct(_), _) if self.key.is_null() => {
let scalar_null: ScalarValue = array.data_type().try_into()?;
Ok(ColumnarValue::Scalar(scalar_null))
}
Expand All @@ -100,6 +100,13 @@ impl PhysicalExpr for GetIndexedFieldExpr {
let iter = concat(vec.as_slice()).unwrap();
Ok(ColumnarValue::Array(iter))
}
(DataType::Struct(_), ScalarValue::Utf8(Some(k))) => {
let as_struct_array = array.as_any().downcast_ref::<StructArray>().unwrap();
match as_struct_array.column_by_name(k) {
None => Err(DataFusionError::Execution(format!("get indexed field {} not found in struct", k))),
Some(col) => Ok(ColumnarValue::Array(col.clone()))
}
}
(dt, key) => Err(DataFusionError::NotImplemented(format!("get indexed field is only possible on lists with int64 indexes. Tried {} with {} index", dt, key))),
},
ColumnarValue::Scalar(_) => Err(DataFusionError::NotImplemented(
Expand All @@ -112,18 +119,16 @@ impl PhysicalExpr for GetIndexedFieldExpr {
#[cfg(test)]
mod tests {
use super::*;
use crate::arrow::array::GenericListArray;
use crate::error::Result;
use crate::physical_plan::expressions::{col, lit};
use arrow::array::{ListBuilder, StringBuilder};
use arrow::array::{
Int64Array, Int64Builder, ListBuilder, StringBuilder, StructArray, StructBuilder,
};
use arrow::{array::StringArray, datatypes::Field};

fn get_indexed_field_test(
list_of_lists: Vec<Vec<Option<&str>>>,
index: i64,
expected: Vec<Option<&str>>,
) -> Result<()> {
let schema = list_schema("l");
let builder = StringBuilder::new(3);
fn build_utf8_lists(list_of_lists: Vec<Vec<Option<&str>>>) -> GenericListArray<i32> {
let builder = StringBuilder::new(list_of_lists.len());
let mut lb = ListBuilder::new(builder);
for values in list_of_lists {
let builder = lb.values();
Expand All @@ -137,9 +142,18 @@ mod tests {
lb.append(true).unwrap();
}

let expr = col("l", &schema).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lb.finish())])?;
lb.finish()
}

fn get_indexed_field_test(
list_of_lists: Vec<Vec<Option<&str>>>,
index: i64,
expected: Vec<Option<&str>>,
) -> Result<()> {
let schema = list_schema("l");
let list_col = build_utf8_lists(list_of_lists);
let expr = col("l", &schema).unwrap();
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_col)])?;
let key = ScalarValue::Int64(Some(index));
let expr = Arc::new(GetIndexedFieldExpr::new(expr, key));
let result = expr.evaluate(&batch)?.into_array(batch.num_rows());
Expand Down Expand Up @@ -222,4 +236,133 @@ mod tests {
let expr = col("l", &schema).unwrap();
get_indexed_field_test_failure(schema, expr, ScalarValue::Int8(Some(0)), "This feature is not implemented: get indexed field is only possible on lists with int64 indexes. Tried List(Field { name: \"item\", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: None }) with 0 index")
}

/// Builds a [`StructArray`] with one row per entry of `list_of_tuples`.
///
/// The child builders are hard-wired to an `Int64` builder at index 0 and a
/// list-of-strings builder at index 1, so `fields` is expected to describe
/// exactly that layout. A `None` integer or string is appended as a null;
/// every struct row and every list entry is appended with a validity flag
/// of `true`.
fn build_struct(
    fields: Vec<Field>,
    list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
) -> StructArray {
    let capacity = list_of_tuples.len();
    let int_child = Int64Array::builder(capacity);
    let list_child = ListBuilder::new(StringBuilder::new(capacity));
    let mut struct_builder =
        StructBuilder::new(fields, vec![Box::new(int_child), Box::new(list_child)]);

    for (maybe_int, strings) in list_of_tuples {
        // Child 0: the integer value of this row.
        let ints = struct_builder.field_builder::<Int64Builder>(0).unwrap();
        if let Some(value) = maybe_int {
            ints.append_value(value).unwrap();
        } else {
            ints.append_null().unwrap();
        }
        // Close the struct slot for this row.
        struct_builder.append(true).unwrap();

        // Child 1: the list of optional strings of this row.
        let lists = struct_builder
            .field_builder::<ListBuilder<StringBuilder>>(1)
            .unwrap();
        for maybe_str in strings {
            if let Some(text) = maybe_str {
                lists.values().append_value(text).unwrap();
            } else {
                lists.values().append_null().unwrap();
            }
        }
        lists.append(true).unwrap();
    }

    struct_builder.finish()
}

/// Checks field access through [`GetIndexedFieldExpr`] on a struct column `s`
/// that has an `Int64` field `foo` and a list-of-`Utf8` field `bar`.
///
/// Three things are asserted, in order:
/// * `s["foo"]` equals `expected_ints`,
/// * `s["bar"]` equals the list column rebuilt from the second element of
///   every tuple in `list_of_tuples`,
/// * `s["bar"][i]` equals `expected_strings[i]` for every index `i`.
fn get_indexed_field_mixed_test(
    list_of_tuples: Vec<(Option<i64>, Vec<Option<&str>>)>,
    expected_strings: Vec<Vec<Option<&str>>>,
    expected_ints: Vec<Option<i64>>,
) -> Result<()> {
    let column_name = "s";
    let item_field = Field::new("item", DataType::Utf8, true);
    let fields = vec![
        Field::new("foo", DataType::Int64, true),
        Field::new("bar", DataType::List(Box::new(item_field)), true),
    ];
    let struct_type = DataType::Struct(fields.clone());
    let schema = Schema::new(vec![Field::new(column_name, struct_type, true)]);
    let struct_array = build_struct(fields, list_of_tuples.clone());

    let struct_expr = col(column_name, &schema).unwrap();
    let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
    let row_count = batch.num_rows();

    // s["foo"]: the integer child column.
    let foo_expr = Arc::new(GetIndexedFieldExpr::new(
        struct_expr.clone(),
        ScalarValue::Utf8(Some("foo".to_string())),
    ));
    let foo_column = foo_expr.evaluate(&batch)?.into_array(row_count);
    let actual_ints = foo_column
        .as_any()
        .downcast_ref::<Int64Array>()
        .expect("failed to downcast to Int64Array");
    assert_eq!(&Int64Array::from(expected_ints), actual_ints);

    // s["bar"]: the list child column.
    let bar_expr = Arc::new(GetIndexedFieldExpr::new(
        struct_expr,
        ScalarValue::Utf8(Some("bar".to_string())),
    ));
    let bar_column = bar_expr.evaluate(&batch)?.into_array(row_count);
    let actual_lists = bar_column
        .as_any()
        .downcast_ref::<ListArray>()
        .unwrap_or_else(|| panic!("failed to downcast to ListArray : {:?}", bar_column));
    let expected_lists = build_utf8_lists(
        list_of_tuples
            .into_iter()
            .map(|(_, strings)| strings)
            .collect(),
    );
    assert_eq!(&expected_lists, actual_lists);

    // s["bar"][i]: index into the list produced by the struct access.
    for (position, expected_column) in expected_strings.into_iter().enumerate() {
        let nested_expr = Arc::new(GetIndexedFieldExpr::new(
            bar_expr.clone(),
            ScalarValue::Int64(Some(position as i64)),
        ));
        let nested_column = nested_expr.evaluate(&batch)?.into_array(row_count);
        let actual_strings = nested_column
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap_or_else(|| {
                panic!("failed to downcast to StringArray : {:?}", nested_column)
            });
        assert_eq!(&StringArray::from(expected_column), actual_strings);
    }

    Ok(())
}

/// Runs the mixed struct/list access checks on three rows, one of which has a
/// null integer, with a null string at a different position in each row's list.
#[test]
fn get_indexed_field_struct() -> Result<()> {
    let expected_ints = vec![Some(10), Some(15), None];

    // Entry `i` holds element `i` of every row's list, i.e. the row lists
    // below transposed.
    let expected_list = vec![
        vec![Some("a"), None, Some("e")],
        vec![Some("b"), Some("c"), None],
        vec![None, Some("d"), Some("f")],
    ];

    let list_of_structs = vec![
        (Some(10), vec![Some("a"), Some("b"), None]),
        (Some(15), vec![None, Some("c"), Some("d")]),
        (None, vec![Some("e"), None, Some("f")]),
    ];

    get_indexed_field_mixed_test(list_of_structs, expected_list, expected_ints)
}
}
44 changes: 44 additions & 0 deletions datafusion/tests/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5385,3 +5385,47 @@ async fn query_nested_get_indexed_field() -> Result<()> {
assert_eq!(expected, actual);
Ok(())
}

#[tokio::test]
async fn query_nested_get_indexed_field_on_struct() -> Result<()> {
let mut ctx = ExecutionContext::new();
let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true)));
// Nested schema of { "some_struct": { "bar": [i64] } }
let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)];
let schema = Arc::new(Schema::new(vec![Field::new(
"some_struct",
DataType::Struct(struct_fields.clone()),
false,
)]));

let builder = PrimitiveBuilder::<Int64Type>::new(3);
let nested_lb = ListBuilder::new(builder);
let mut sb = StructBuilder::new(struct_fields, vec![Box::new(nested_lb)]);
for int_vec in vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7], vec![8, 9, 10, 11]] {
let lb = sb.field_builder::<ListBuilder<Int64Builder>>(0).unwrap();
for int in int_vec {
lb.values().append_value(int).unwrap();
}
lb.append(true).unwrap();
}
let data = RecordBatch::try_new(schema.clone(), vec![Arc::new(sb.finish())])?;
let table = MemTable::try_new(schema, vec![vec![data]])?;
let table_a = Arc::new(table);

ctx.register_table("structs", table_a)?;

// Access the nested list field `bar` of the struct column by name
let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is so cool!

let actual = execute(&mut ctx, sql).await;
let expected = vec![
vec!["[0, 1, 2, 3]"],
vec!["[4, 5, 6, 7]"],
vec!["[8, 9, 10, 11]"],
];
assert_eq!(expected, actual);
let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3";
let actual = execute(&mut ctx, sql).await;
let expected = vec![vec!["0"], vec!["4"], vec!["8"]];
assert_eq!(expected, actual);
Ok(())
}