Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql

import scala.util.control.NonFatal

import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.types.StructType

trait CometSQLQueryTestHelper {
Copy link
Copy Markdown
Member Author

@viirya viirya Jul 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This copies from Spark SQLQueryTestHelper with a little changes to accommodate difference between Spark 3.3/3.4/3.5/4.0.


private val notIncludedMsg = "[not included in comparison]"
private val clsName = this.getClass.getCanonicalName
protected val emptySchema: String = StructType(Seq.empty).catalogString

protected def replaceNotIncludedMsg(line: String): String = {
line
.replaceAll("#\\d+", "#x")
.replaceAll("plan_id=\\d+", "plan_id=x")
.replaceAll(s"Location.*$clsName/", s"Location $notIncludedMsg/{warehouse_dir}/")
.replaceAll(s"file:[^\\s,]*$clsName", s"file:$notIncludedMsg/{warehouse_dir}")
.replaceAll("Created By.*", s"Created By $notIncludedMsg")
.replaceAll("Created Time.*", s"Created Time $notIncludedMsg")
.replaceAll("Last Access.*", s"Last Access $notIncludedMsg")
.replaceAll("Partition Statistics\t\\d+", s"Partition Statistics\t$notIncludedMsg")
.replaceAll("\\*\\(\\d+\\) ", "*") // remove the WholeStageCodegen codegenStageIds
}

/** Executes a query and returns the result as (schema of the output, normalized output). */
protected def getNormalizedResult(session: SparkSession, sql: String): (String, Seq[String]) = {
// Returns true if the plan is supposed to be sorted.
def isSorted(plan: LogicalPlan): Boolean = plan match {
case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false
case _: DescribeCommandBase | _: DescribeColumnCommand | _: DescribeRelation |
_: DescribeColumn =>
true
case PhysicalOperation(_, _, Sort(_, true, _)) => true
case _ => plan.children.iterator.exists(isSorted)
}

val df = session.sql(sql)
val schema = df.schema.catalogString
// Get answer, but also get rid of the #1234 expression ids that show up in explain plans
val answer = SQLExecution.withNewExecutionId(df.queryExecution, Some(sql)) {
hiveResultString(df.queryExecution.executedPlan).map(replaceNotIncludedMsg)
}

// If the output is not pre-sorted, sort it.
if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted)
}

/**
* This method handles exceptions occurred during query execution as they may need special care
* to become comparable to the expected output.
*
* @param result
* a function that returns a pair of schema and output
*/
protected def handleExceptions(result: => (String, Seq[String])): (String, Seq[String]) = {
try {
result
} catch {
case e: SparkThrowable with Throwable if e.getErrorClass != null =>
(emptySchema, Seq(e.getClass.getName, e.getMessage))
case a: AnalysisException =>
// Do not output the logical plan tree which contains expression IDs.
// Also implement a crude way of masking expression IDs in the error message
// with a generic pattern "###".
val msg = a.getMessage
(emptySchema, Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x")))
case s: SparkException if s.getCause != null =>
// For a runtime exception, it is hard to match because its message contains
// information of stage, task ID, etc.
// To make result matching simpler, here we match the cause of the exception if it exists.
s.getCause match {
case e: SparkThrowable with Throwable if e.getErrorClass != null =>
(emptySchema, Seq(e.getClass.getName, e.getMessage))
case cause =>
(emptySchema, Seq(cause.getClass.getName, cause.getMessage))
}
case NonFatal(e) =>
// If there is an exception, put the exception class followed by the message.
(emptySchema, Seq(e.getClass.getName, e.getMessage))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ class CometTPCDSQuerySuite
override val tpcdsQueries: Seq[String] =
tpcdsAllQueries.filterNot(excludedTpcdsQueries.contains)
}
with TPCDSQueryTestSuite
with CometTPCDSQueryTestSuite
with ShimCometTPCDSQuerySuite {
override def sparkConf: SparkConf = {
val conf = super.sparkConf
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql

import java.io.File
import java.nio.file.{Files, Paths}

import scala.collection.JavaConverters._

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.util.{fileToString, resourceToString, stringToFile}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.TestSparkSession

/**
* Because we need to modify some methods of Spark `TPCDSQueryTestSuite` but they are private, we
* copy Spark `TPCDSQueryTestSuite`.
*/
class CometTPCDSQueryTestSuite extends QueryTest with TPCDSBase with CometSQLQueryTestHelper {

private val tpcdsDataPath = sys.env.get("SPARK_TPCDS_DATA")

private val regenGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"

// To make output results deterministic
override protected def sparkConf: SparkConf = super.sparkConf
.set(SQLConf.SHUFFLE_PARTITIONS.key, "1")

protected override def createSparkSession: TestSparkSession = {
new TestSparkSession(new SparkContext("local[1]", this.getClass.getSimpleName, sparkConf))
}

// We use SF=1 table data here, so we cannot use SF=100 stats
protected override val injectStats: Boolean = false

if (tpcdsDataPath.nonEmpty) {
val nonExistentTables = tableNames.filterNot { tableName =>
Files.exists(Paths.get(s"${tpcdsDataPath.get}/$tableName"))
}
if (nonExistentTables.nonEmpty) {
fail(
s"Non-existent TPCDS table paths found in ${tpcdsDataPath.get}: " +
nonExistentTables.mkString(", "))
}
}

protected val baseResourcePath: String = {
// use the same way as `SQLQueryTestSuite` to get the resource path
getWorkspaceFilePath(
"sql",
"core",
"src",
"test",
"resources",
"tpcds-query-results").toFile.getAbsolutePath
}

override def createTable(
spark: SparkSession,
tableName: String,
format: String = "parquet",
options: scala.Seq[String]): Unit = {
spark.sql(s"""
|CREATE TABLE `$tableName` (${tableColumns(tableName)})
|USING $format
|LOCATION '${tpcdsDataPath.get}/$tableName'
|${options.mkString("\n")}
""".stripMargin)
}

private def runQuery(query: String, goldenFile: File, conf: Map[String, String]): Unit = {
// This is `sortMergeJoinConf != conf` in Spark, i.e., it sorts results for other joins
// than sort merge join. But in some queries DataFusion sort returns correct results
// in terms of required sorting columns, but the results are not same as Spark in terms of
// order of irrelevant columns. So, we need to sort the results for all joins.
val shouldSortResults = true
withSQLConf(conf.toSeq: _*) {
try {
val (schema, output) = handleExceptions(getNormalizedResult(spark, query))
val queryString = query.trim
val outputString = output.mkString("\n").replaceAll("\\s+$", "")
if (regenGoldenFiles) {
val goldenOutput = {
s"-- Automatically generated by ${getClass.getSimpleName}\n\n" +
"-- !query schema\n" +
schema + "\n" +
"-- !query output\n" +
outputString +
"\n"
}
val parent = goldenFile.getParentFile
if (!parent.exists()) {
assert(parent.mkdirs(), "Could not create directory: " + parent)
}
stringToFile(goldenFile, goldenOutput)
}

// Read back the golden file.
val (expectedSchema, expectedOutput) = {
val goldenOutput = fileToString(goldenFile)
val segments = goldenOutput.split("-- !query.*\n")

// query has 3 segments, plus the header
assert(
segments.size == 3,
s"Expected 3 blocks in result file but got ${segments.size}. " +
"Try regenerate the result files.")

(segments(1).trim, segments(2).replaceAll("\\s+$", ""))
}

val notMatchedSchemaOutput = if (schema == emptySchema) {
// There might be exception. See `handleExceptions`.
s"Schema did not match\n$queryString\nOutput/Exception: $outputString"
} else {
s"Schema did not match\n$queryString"
}

assertResult(expectedSchema, notMatchedSchemaOutput) {
schema
}
if (shouldSortResults) {
val expectSorted = expectedOutput
.split("\n")
.sorted
.map(_.trim)
.mkString("\n")
.replaceAll("\\s+$", "")
val outputSorted = output.sorted.map(_.trim).mkString("\n").replaceAll("\\s+$", "")
assertResult(expectSorted, s"Result did not match\n$queryString") {
outputSorted
}
} else {
assertResult(expectedOutput, s"Result did not match\n$queryString") {
outputString
}
}
} catch {
case e: Throwable =>
val configs = conf.map { case (k, v) =>
s"$k=$v"
}
throw new Exception(s"${e.getMessage}\nError using configs:\n${configs.mkString("\n")}")
}
}
}

val sortMergeJoinConf: Map[String, String] = Map(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
SQLConf.PREFER_SORTMERGEJOIN.key -> "true")

val broadcastHashJoinConf: Map[String, String] = Map(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10485760")

val shuffledHashJoinConf: Map[String, String] = Map(
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
"spark.sql.join.forceApplyShuffledHashJoin" -> "true")

val allJoinConfCombinations: Seq[Map[String, String]] =
Seq(sortMergeJoinConf, broadcastHashJoinConf, shuffledHashJoinConf)

val joinConfs: Seq[Map[String, String]] = if (regenGoldenFiles) {
require(
!sys.env.contains("SPARK_TPCDS_JOIN_CONF"),
"'SPARK_TPCDS_JOIN_CONF' cannot be set together with 'SPARK_GENERATE_GOLDEN_FILES'")
Seq(sortMergeJoinConf)
} else {
sys.env
.get("SPARK_TPCDS_JOIN_CONF")
.map { s =>
val p = new java.util.Properties()
p.load(new java.io.StringReader(s))
Seq(p.asScala.toMap)
}
.getOrElse(allJoinConfCombinations)
}

assert(joinConfs.nonEmpty)
joinConfs.foreach(conf =>
require(
allJoinConfCombinations.contains(conf),
s"Join configurations [$conf] should be one of $allJoinConfCombinations"))

if (tpcdsDataPath.nonEmpty) {
tpcdsQueries.foreach { name =>
val queryString = resourceToString(
s"tpcds/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(name) {
val goldenFile = new File(s"$baseResourcePath/v1_4", s"$name.sql.out")
joinConfs.foreach { conf =>
System.gc() // Workaround for GitHub Actions memory limitation, see also SPARK-37368
runQuery(queryString, goldenFile, conf)
}
}
}

tpcdsQueriesV2_7_0.foreach { name =>
val queryString = resourceToString(
s"tpcds-v2.7.0/$name.sql",
classLoader = Thread.currentThread().getContextClassLoader)
test(s"$name-v2.7") {
val goldenFile = new File(s"$baseResourcePath/v2_7", s"$name.sql.out")
joinConfs.foreach { conf =>
System.gc() // SPARK-37368
runQuery(queryString, goldenFile, conf)
}
}
}
} else {
ignore("skipped because env `SPARK_TPCDS_DATA` is not set") {}
}
}