Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ pub fn can_hash(data_type: &DataType) -> bool {
DataType::List(_) => true,
DataType::LargeList(_) => true,
DataType::FixedSizeList(_, _) => true,
DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
_ => false,
}
}
Expand Down
111 changes: 108 additions & 3 deletions datafusion/physical-plan/src/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1212,11 +1212,16 @@ fn eq_dyn_null(
right: &dyn Array,
null_equals_null: bool,
) -> Result<BooleanArray, ArrowError> {
// Nested datatypes cannot use the underlying not_distinct function and must use a special
// Nested datatypes cannot use the underlying not_distinct/eq function and must use a special
// implementation
// <https://github.com/apache/datafusion/issues/10749>
if left.data_type().is_nested() && null_equals_null {
return Ok(compare_op_for_nested(&Operator::Eq, &left, &right)?);
if left.data_type().is_nested() {
let op = if null_equals_null {
Operator::IsNotDistinctFrom
} else {
Operator::Eq
};
return Ok(compare_op_for_nested(&op, &left, &right)?);
}
match (left.data_type(), right.data_type()) {
_ if null_equals_null => not_distinct(&left, &right),
Expand Down Expand Up @@ -1546,6 +1551,8 @@ mod tests {

use arrow::array::{Date32Array, Int32Array, UInt32Builder, UInt64Builder};
use arrow::datatypes::{DataType, Field};
use arrow_array::StructArray;
use arrow_buffer::NullBuffer;
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err,
ScalarValue,
Expand Down Expand Up @@ -3844,6 +3851,104 @@ mod tests {
Ok(())
}

fn build_table_struct(
struct_name: &str,
field_name_and_values: (&str, &Vec<Option<i32>>),
nulls: Option<NullBuffer>,
) -> Arc<dyn ExecutionPlan> {
let (field_name, values) = field_name_and_values;
let inner_fields = vec![Field::new(field_name, DataType::Int32, true)];
let schema = Schema::new(vec![Field::new(
struct_name,
DataType::Struct(inner_fields.clone().into()),
nulls.is_some(),
)]);

let batch = RecordBatch::try_new(
Arc::new(schema),
vec![Arc::new(StructArray::new(
inner_fields.into(),
vec![Arc::new(Int32Array::from(values.clone()))],
nulls,
))],
)
.unwrap();
let schema_ref = batch.schema();
Arc::new(MemoryExec::try_new(&[vec![batch]], schema_ref, None).unwrap())
}

#[tokio::test]
async fn join_on_struct() -> Result<()> {
Comment thread
eejbyfeldt marked this conversation as resolved.
let task_ctx = Arc::new(TaskContext::default());
let left =
build_table_struct("n1", ("a", &vec![None, Some(1), Some(2), Some(3)]), None);
let right =
build_table_struct("n2", ("a", &vec![None, Some(1), Some(2), Some(4)]), None);
let on = vec![(
Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
)];

let (columns, batches) =
join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;

assert_eq!(columns, vec!["n1", "n2"]);

let expected = [
"+--------+--------+",
"| n1 | n2 |",
"+--------+--------+",
"| {a: } | {a: } |",
"| {a: 1} | {a: 1} |",
"| {a: 2} | {a: 2} |",
"+--------+--------+",
];
assert_batches_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_on_struct_with_nulls() -> Result<()> {
Comment thread
eejbyfeldt marked this conversation as resolved.
let task_ctx = Arc::new(TaskContext::default());
let left =
build_table_struct("n1", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
let right =
build_table_struct("n2", ("a", &vec![None]), Some(NullBuffer::new_null(1)));
let on = vec![(
Arc::new(Column::new_with_schema("n1", &left.schema())?) as _,
Arc::new(Column::new_with_schema("n2", &right.schema())?) as _,
)];

let (_, batches_null_eq) = join_collect(
left.clone(),
right.clone(),
on.clone(),
&JoinType::Inner,
true,
task_ctx.clone(),
)
.await?;

let expected_null_eq = [
"+----+----+",
"| n1 | n2 |",
"+----+----+",
"| | |",
"+----+----+",
];
assert_batches_eq!(expected_null_eq, &batches_null_eq);

let (_, batches_null_neq) =
join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?;

let expected_null_neq =
["+----+----+", "| n1 | n2 |", "+----+----+", "+----+----+"];
assert_batches_eq!(expected_null_neq, &batches_null_neq);

Ok(())
}

/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
Expand Down
52 changes: 52 additions & 0 deletions datafusion/sqllogictest/test_files/joins.slt
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,20 @@ AS VALUES
(44, 'x', 3),
(55, 'w', 3);

statement ok
CREATE TABLE join_t3(s3 struct<id INT>)
AS VALUES
(NULL),
(struct(1)),
(struct(2));

statement ok
CREATE TABLE join_t4(s4 struct<id INT>)
AS VALUES
(NULL),
(struct(2)),
(struct(3));

# Left semi anti join

statement ok
Expand Down Expand Up @@ -1336,6 +1350,44 @@ physical_plan
10)----------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
11)------------MemoryExec: partitions=1, partition_sizes=[1]

# Join on struct
query TT
explain select join_t3.s3, join_t4.s4
from join_t3
inner join join_t4 on join_t3.s3 = join_t4.s4
----
logical_plan
01)Inner Join: join_t3.s3 = join_t4.s4
02)--TableScan: join_t3 projection=[s3]
03)--TableScan: join_t4 projection=[s4]
physical_plan
01)CoalesceBatchesExec: target_batch_size=2
02)--HashJoinExec: mode=Partitioned, join_type=Inner, on=[(s3@0, s4@0)]
03)----CoalesceBatchesExec: target_batch_size=2
04)------RepartitionExec: partitioning=Hash([s3@0], 2), input_partitions=2
05)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
06)----------MemoryExec: partitions=1, partition_sizes=[1]
07)----CoalesceBatchesExec: target_batch_size=2
08)------RepartitionExec: partitioning=Hash([s4@0], 2), input_partitions=2
09)--------RepartitionExec: partitioning=RoundRobinBatch(2), input_partitions=1
10)----------MemoryExec: partitions=1, partition_sizes=[1]

query ??
select join_t3.s3, join_t4.s4
from join_t3
inner join join_t4 on join_t3.s3 = join_t4.s4
----
{id: 2} {id: 2}

# join with struct key and nulls
Comment thread
eejbyfeldt marked this conversation as resolved.
# Note that intersect or except applies `null_equals_null` as true for Join.
query ?
SELECT * FROM join_t3
EXCEPT
SELECT * FROM join_t4
----
{id: 1}

query TT
EXPLAIN
select count(*)
Expand Down