diff --git a/rust/datafusion/src/physical_plan/hash_join.rs b/rust/datafusion/src/physical_plan/hash_join.rs index 9ac7447a8ab..d6b3b937004 100644 --- a/rust/datafusion/src/physical_plan/hash_join.rs +++ b/rust/datafusion/src/physical_plan/hash_join.rs @@ -81,6 +81,14 @@ pub struct HashJoinExec { build_side: Arc>>, } +/// Information about the index and placement (left or right) of the columns +struct ColumnIndex { + /// Index of the column + index: usize, + /// Whether the column is at the left or right side + is_left: bool, +} + impl HashJoinExec { /// Tries to create a new [HashJoinExec]. /// # Error @@ -116,6 +124,36 @@ impl HashJoinExec { build_side: Arc::new(Mutex::new(None)), }) } + + /// Calculates column indices and left/right placement on input / output schemas and jointype + fn column_indices_from_schema(&self) -> ArrowResult> { + let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { + JoinType::Inner | JoinType::Left => { + (true, self.left.schema(), self.right.schema()) + } + JoinType::Right => (false, self.right.schema(), self.left.schema()), + }; + let mut column_indices = Vec::with_capacity(self.schema.fields().len()); + for field in self.schema.fields() { + let (is_primary, index) = match primary_schema.index_of(field.name()) { + Ok(i) => Ok((true, i)), + Err(_) => { + match secondary_schema.index_of(field.name()) { + Ok(i) => Ok((false, i)), + _ => Err(DataFusionError::Internal( + format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() + )) + } + } + }.map_err(DataFusionError::into_arrow_external_error)?; + + let is_left = + is_primary && primary_is_left || !is_primary && !primary_is_left; + column_indices.push(ColumnIndex { index, is_left }); + } + + Ok(column_indices) + } } #[async_trait] @@ -202,12 +240,15 @@ impl ExecutionPlan for HashJoinExec { .iter() .map(|on| on.1.clone()) .collect::>(); + + let column_indices = self.column_indices_from_schema()?; Ok(Box::pin(HashJoinStream { schema: self.schema.clone(), on_right, join_type: self.join_type, left_data, right: stream, + column_indices, })) } } @@ -252,6 +293,8 @@ struct HashJoinStream { left_data: JoinLeftData, /// right right: SendableRecordBatchStream, + /// Information of index and left / right placement of columns + column_indices: Vec, } impl RecordBatchStream for HashJoinStream { @@ -269,18 +312,13 @@ fn build_batch_from_indices( schema: &Schema, left: &Vec, right: &RecordBatch, - join_type: &JoinType, indices: &[(JoinIndex, RightIndex)], + column_indices: &Vec, ) -> ArrowResult { if left.is_empty() { todo!("Create empty record batch"); } - let (primary_is_left, primary_schema, secondary_schema) = match join_type { - JoinType::Inner | JoinType::Left => (true, left[0].schema(), right.schema()), - JoinType::Right => (false, right.schema(), left[0].schema()), - }; - // build the columns of the new [RecordBatch]: // 1. pick whether the column is from the left or right // 2. based on the pick, `take` items from the different recordBatches @@ -288,28 +326,12 @@ fn build_batch_from_indices( let right_indices = indices.iter().map(|(_, join_index)| join_index).collect(); - for field in schema.fields() { - // pick the column (left or right) based on the field name. - let (is_primary, column_index) = match primary_schema.index_of(field.name()) { - Ok(i) => Ok((true, i)), - Err(_) => { - match secondary_schema.index_of(field.name()) { - Ok(i) => Ok((false, i)), - _ => Err(DataFusionError::Internal( - format!("During execution, the column {} was not found in neither the left or right side of the join", field.name()).to_string() - )) - } - } - }.map_err(DataFusionError::into_arrow_external_error)?; - - let is_left = - (is_primary && primary_is_left) || (!is_primary && !primary_is_left); - - let array = if is_left { + for column_index in column_indices { + let array = if column_index.is_left { // Note that we take `.data_ref()` to gather the [ArrayData] of each array. let arrays = left .iter() - .map(|batch| batch.column(column_index).data_ref().as_ref()) + .map(|batch| batch.column(column_index.index).data_ref().as_ref()) .collect::>(); let mut mutable = MutableArrayData::new(arrays, true, indices.len()); @@ -323,7 +345,7 @@ fn build_batch_from_indices( make_array(Arc::new(mutable.freeze())) } else { // use the right indices - let array = right.column(column_index); + let array = right.column(column_index.index); compute::take(array.as_ref(), &right_indices, None)? }; columns.push(array); @@ -396,12 +418,13 @@ fn build_batch( batch: &RecordBatch, left_data: &JoinLeftData, on_right: &HashSet, - join_type: &JoinType, + join_type: JoinType, schema: &Schema, + column_indices: &Vec, ) -> ArrowResult { let indices = build_join_indexes(&left_data.0, &batch, join_type, on_right).unwrap(); - build_batch_from_indices(schema, &left_data.1, batch, join_type, &indices) + build_batch_from_indices(schema, &left_data.1, batch, &indices, column_indices) } /// returns a vector with (index from left, index from right). @@ -434,7 +457,7 @@ fn build_batch( fn build_join_indexes( left: &JoinHashMap, right: &RecordBatch, - join_type: &JoinType, + join_type: JoinType, right_on: &HashSet, ) -> Result> { let keys_values = right_on @@ -531,8 +554,9 @@ impl Stream for HashJoinStream { &batch, &self.left_data, &self.on_right, - &self.join_type, + self.join_type, &self.schema, + &self.column_indices, )), other => other, })