-
Notifications
You must be signed in to change notification settings - Fork 305
fix: Supported nested types in HashJoin #735
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1e93be4
9b334e0
bd3e92d
c1ab4f4
c4f2eee
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,6 +25,7 @@ import org.scalatest.Tag | |
| import org.apache.spark.sql.CometTestBase | ||
| import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometBroadcastHashJoinExec} | ||
| import org.apache.spark.sql.internal.SQLConf | ||
| import org.apache.spark.sql.types.Decimal | ||
|
|
||
| import org.apache.comet.CometConf | ||
| import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus | ||
|
|
@@ -197,6 +198,66 @@ class CometJoinSuite extends CometTestBase { | |
| } | ||
| } | ||
|
|
||
| test("HashJoin struct key") { | ||
| withSQLConf( | ||
| "spark.sql.join.forceApplyShuffledHashJoin" -> "true", | ||
| SQLConf.PREFER_SORTMERGEJOIN.key -> "false", | ||
| SQLConf.ADAPTIVE_AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1", | ||
| SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { | ||
|
|
||
| def manyTypes(idx: Int, v: Int) = | ||
| ( | ||
| idx, | ||
| v, | ||
| v.toLong, | ||
| v.toFloat, | ||
| v.toDouble, | ||
| v.toString, | ||
| v % 2 == 0, | ||
| v.toString().getBytes, | ||
| Decimal(v)) | ||
|
|
||
| withParquetTable((0 until 10).map(i => manyTypes(i, i % 5)), "tbl_a") { | ||
| withParquetTable((0 until 10).map(i => manyTypes(i, i % 10)), "tbl_b") { | ||
| // Full join: struct key | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for adding the test. I wonder if we should make this more comprehensive to cover structs containing different types, nulls, and nested structs? Also, what happens with structs containing unsupported types such as array and map? Do we still fall back for those? It would be good to have a test for this case as well.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good call, I will add more test cases. I created #797 to make it easier to create nulls of struct type.
Do you mean unsupported by comet here? I think the answer for
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Good point about map not being supported by Spark. I think we should fall back for array for now because we don't really support array yet.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We do fall back currently. Is there some way to add a test for that, and is such a test even desired?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think we could improve our test framework to make it easier to test for fallback, but it is possible with code like this (must be used after calling collect on a DataFrame). I think we can improve the tests in a future PR. |
||
| val df1 = | ||
| sql( | ||
| "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b " + | ||
| "ON named_struct('1', tbl_a._2) = named_struct('1', tbl_b._1)") | ||
| checkSparkAnswerAndOperator(df1) | ||
|
|
||
| // Full join: struct key with nulls | ||
| val df2 = | ||
| sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b " + | ||
| "ON IF(tbl_a._1 > 5, named_struct('2', tbl_a._2), NULL) = IF(tbl_b._2 > 5, named_struct('2', tbl_b._1), NULL)") | ||
| checkSparkAnswerAndOperator(df2) | ||
|
|
||
| // Full join: struct key with nulls in the struct | ||
| val df3 = | ||
| sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b " + | ||
| "ON named_struct('2', IF(tbl_a._1 > 5, tbl_a._2, NULL)) = named_struct('2', IF(tbl_b._2 > 5, tbl_b._1, NULL))") | ||
| checkSparkAnswerAndOperator(df3) | ||
|
|
||
| // Full join: nested structs | ||
| val df4 = | ||
| sql("SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b " + | ||
| "ON named_struct('1', named_struct('2', tbl_a._2)) = named_struct('1', named_struct('2', tbl_b._1))") | ||
| checkSparkAnswerAndOperator(df4) | ||
|
|
||
| val columnCount = manyTypes(0, 0).productArity | ||
| def key(tbl: String) = | ||
| (1 to columnCount).map(i => s"${tbl}._$i").mkString("struct(", ", ", ")") | ||
| // Using several different types in the struct key | ||
| val df5 = | ||
| sql( | ||
| "SELECT /*+ SHUFFLE_HASH(tbl_b) */ * FROM tbl_a FULL JOIN tbl_b " + | ||
| s"ON ${key("tbl_a")} = ${key("tbl_b")}") | ||
| checkSparkAnswerAndOperator(df5) | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| test("HashJoin with join filter") { | ||
| withSQLConf( | ||
| SQLConf.PREFER_SORTMERGEJOIN.key -> "false", | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.