diff --git a/datafusion/physical-plan/src/unnest.rs b/datafusion/physical-plan/src/unnest.rs
index 06288a1f70419..0615e6738a1fc 100644
--- a/datafusion/physical-plan/src/unnest.rs
+++ b/datafusion/physical-plan/src/unnest.rs
@@ -36,7 +36,7 @@ use arrow::compute::kernels::zip::zip;
 use arrow::compute::{cast, is_not_null, kernels, sum};
 use arrow::datatypes::{DataType, Int64Type, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
-use arrow_array::{Int64Array, Scalar, StructArray};
+use arrow_array::{new_null_array, Int64Array, Scalar, StructArray};
 use arrow_ord::cmp::lt;
 use datafusion_common::{
     exec_datafusion_err, exec_err, internal_err, HashMap, HashSet, Result, UnnestOptions,
@@ -453,16 +453,36 @@ fn list_unnest_at_level(
 
     // Create the take indices array for other columns
     let take_indices = create_take_indicies(unnested_length, total_length);
-
-    // Dimension of arrays in batch is untouched, but the values are repeated
-    // as the side effect of unnesting
-    let ret = repeat_arrs_from_indices(batch, &take_indices)?;
     unnested_temp_arrays
         .into_iter()
         .zip(list_unnest_specs.iter())
         .for_each(|(flatten_arr, unnesting)| {
             temp_unnested_arrs.insert(*unnesting, flatten_arr);
         });
+
+    let repeat_mask: Vec<bool> = batch
+        .iter()
+        .enumerate()
+        .map(|(i, _)| {
+            // Check if the column is needed in future levels (levels below the current one)
+            let needed_in_future_levels = list_type_unnests.iter().any(|unnesting| {
+                unnesting.index_in_input_schema == i && unnesting.depth < level_to_unnest
+            });
+
+            // Check if the column is involved in unnesting at any level
+            let is_involved_in_unnesting = list_type_unnests
+                .iter()
+                .any(|unnesting| unnesting.index_in_input_schema == i);
+
+            // Repeat columns needed in future levels or not unnested.
+            needed_in_future_levels || !is_involved_in_unnesting
+        })
+        .collect();
+
+    // Dimension of arrays in batch is untouched, but the values are repeated
+    // as the side effect of unnesting
+    let ret = repeat_arrs_from_indices(batch, &take_indices, &repeat_mask)?;
+
     Ok((ret, total_length))
 }
 struct UnnestingResult {
@@ -859,8 +879,11 @@ fn create_take_indicies(
     builder.finish()
 }
 
-/// Create the batch given an arrays and a `indices` array
-/// that is used by the take kernel to copy values.
+/// Create a batch of arrays based on an input `batch` and an `indices` array.
+/// The `indices` array is used by the take kernel to repeat values in the arrays
+/// that are marked with `true` in the `repeat_mask`. Arrays marked with `false`
+/// in the `repeat_mask` are not repeated: each is replaced with an all-null
+/// placeholder of the same data type and the same length as the input array.
 ///
 /// For example if we have the following batch:
 ///
@@ -890,14 +913,46 @@ fn create_take_indicies(
 /// c2: 'a', 'b', 'c', 'c', 'c', null, 'd', 'd'
 /// ```
 ///
+/// The `repeat_mask` determines whether an array's values are repeated or replaced with nulls.
+/// For example, if the `repeat_mask` is:
+///
+/// ```ignore
+/// [true, false]
+/// ```
+///
+/// The final batch will look like:
+///
+/// ```ignore
+/// c1: 1, null, 2, 3, 4, null, 5, 6 // Repeated using `indices`
+/// c2: null, ..., null // All-null placeholder; keeps the input length
+/// ```
+///
+/// Returns an internal error if `repeat_mask` and `batch` differ in length.
 fn repeat_arrs_from_indices(
     batch: &[ArrayRef],
     indices: &PrimitiveArray<Int64Type>,
+    repeat_mask: &[bool],
 ) -> Result<Vec<Arc<dyn Array>>> {
+    // `zip` would silently drop trailing columns on a length mismatch
+    if batch.len() != repeat_mask.len() {
+        return internal_err!(
+            "repeat_mask has {} entries but batch has {} columns",
+            repeat_mask.len(),
+            batch.len()
+        );
+    }
     batch
         .iter()
-        .map(|arr| Ok(kernels::take::take(arr, indices, None)?))
-        .collect::<Result<_>>()
+        .zip(repeat_mask.iter())
+        .map(|(arr, &repeat)| {
+            if repeat {
+                Ok(kernels::take::take(arr, indices, None)?)
+            } else {
+                // All-null stand-in with the input column's type and length (not expanded)
+                Ok(new_null_array(arr.data_type(), arr.len()))
+            }
+        })
+        .collect()
 }
 
 #[cfg(test)]
diff --git a/datafusion/sqllogictest/test_files/unnest.slt b/datafusion/sqllogictest/test_files/unnest.slt
index 8ebed5b25ca92..2e1b8b87cc429 100644
--- a/datafusion/sqllogictest/test_files/unnest.slt
+++ b/datafusion/sqllogictest/test_files/unnest.slt
@@ -853,3 +853,9 @@ select unnest(u.column5), j.* except(column2, column3) from unnest_table u join
 1 2 1
 3 4 2
 NULL NULL 4
+
+## Issue: https://github.com/apache/datafusion/issues/13237
+query I
+select count(*) from (select unnest(range(0, 100000)) id) t inner join (select unnest(range(0, 100000)) id) t1 on t.id = t1.id;
+----
+100000