diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml index 7df0aa0697..8376afbc6e 100644 --- a/.github/workflows/pr_build_linux.yml +++ b/.github/workflows/pr_build_linux.yml @@ -250,6 +250,7 @@ jobs: - name: "expressions" value: | org.apache.comet.CometExpressionSuite + org.apache.comet.CometSqlFileTestSuite org.apache.comet.CometExpressionCoverageSuite org.apache.comet.CometHashExpressionSuite org.apache.comet.CometTemporalExpressionSuite diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml index acf052bf13..3a64c0051f 100644 --- a/.github/workflows/pr_build_macos.yml +++ b/.github/workflows/pr_build_macos.yml @@ -193,6 +193,7 @@ jobs: - name: "expressions" value: | org.apache.comet.CometExpressionSuite + org.apache.comet.CometSqlFileTestSuite org.apache.comet.CometExpressionCoverageSuite org.apache.comet.CometHashExpressionSuite org.apache.comet.CometTemporalExpressionSuite diff --git a/spark/src/test/resources/sql-tests/expressions/arithmetic.sql b/spark/src/test/resources/sql-tests/expressions/arithmetic.sql new file mode 100644 index 0000000000..16ac7b2429 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/arithmetic.sql @@ -0,0 +1,38 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. 
See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- negative +statement +CREATE TABLE test_neg(col1 int) USING parquet + +statement +INSERT INTO test_neg VALUES(1), (2), (3), (3) + +query +SELECT negative(col1), -(col1) FROM test_neg + +-- integral division overflow +statement +CREATE TABLE test_div(c1 long, c2 short) USING parquet + +statement +INSERT INTO test_div VALUES(-9223372036854775808, -1) + +query +SELECT c1 div c2 FROM test_div ORDER BY c1 diff --git a/spark/src/test/resources/sql-tests/expressions/bitwise.sql b/spark/src/test/resources/sql-tests/expressions/bitwise.sql new file mode 100644 index 0000000000..067836a63b --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/bitwise.sql @@ -0,0 +1,50 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- Setup +statement +CREATE TABLE test(col1 int, col2 int) USING parquet + +statement +INSERT INTO test VALUES(1111, 2) + +statement +INSERT INTO test VALUES(1111, 2) + +statement +INSERT INTO test VALUES(3333, 4) + +statement +INSERT INTO test VALUES(5555, 6) + +-- Queries +query +SELECT col1 & col2, col1 | col2, col1 ^ col2 FROM test + +query +SELECT col1 & 1234, col1 | 1234, col1 ^ 1234 FROM test + +query +SELECT shiftright(col1, 2), shiftright(col1, col2) FROM test + +query +SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM test + +query +SELECT ~(11), ~col1, ~col2 FROM test diff --git a/spark/src/test/resources/sql-tests/expressions/boolean.sql b/spark/src/test/resources/sql-tests/expressions/boolean.sql new file mode 100644 index 0000000000..00688f5fa8 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/boolean.sql @@ -0,0 +1,41 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- compare true/false to negative zero +statement +CREATE TABLE test(col1 boolean, col2 float) USING parquet + +statement +INSERT INTO test VALUES(true, -0.0) + +statement +INSERT INTO test VALUES(false, -0.0) + +query +SELECT col1, negative(col2), cast(col1 as float), col1 = negative(col2) FROM test + +-- not +statement +CREATE TABLE test_not(col1 int, col2 boolean) USING parquet + +statement +INSERT INTO test_not VALUES(1, false), (2, true), (3, true), (3, false) + +query +SELECT col1, col2, NOT(col2), !(col2) FROM test_not diff --git a/spark/src/test/resources/sql-tests/expressions/datetime.sql b/spark/src/test/resources/sql-tests/expressions/datetime.sql new file mode 100644 index 0000000000..93c615d5b3 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/datetime.sql @@ -0,0 +1,28 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- DatePart functions +statement +CREATE TABLE test_dt(col timestamp) USING parquet + +statement +INSERT INTO test_dt VALUES (timestamp('2024-06-15 10:30:00')), (timestamp('1900-01-01')), (null) + +query +SELECT col, year(col), month(col), day(col), weekday(col), dayofweek(col), dayofyear(col), weekofyear(col), quarter(col) FROM test_dt diff --git a/spark/src/test/resources/sql-tests/expressions/hash.sql b/spark/src/test/resources/sql-tests/expressions/hash.sql new file mode 100644 index 0000000000..0629bc0a77 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/hash.sql @@ -0,0 +1,28 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- hash functions +statement +CREATE TABLE test(col string, a int, b float) USING parquet + +statement +INSERT INTO test VALUES ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999), ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) + +query +SELECT md5(col), md5(cast(a as string)), md5(cast(b as string)), hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) FROM test diff --git a/spark/src/test/resources/sql-tests/expressions/in_set.sql b/spark/src/test/resources/sql-tests/expressions/in_set.sql new file mode 100644 index 0000000000..1cb0f248a6 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/in_set.sql @@ -0,0 +1,38 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- ConfigMatrix: spark.sql.optimizer.inSetConversionThreshold=100,0 +-- ConfigMatrix: parquet.enable.dictionary=false,true + +-- test in(set)/not in(set) +statement +CREATE TABLE names(id int, name varchar(20)) USING parquet + +statement +INSERT INTO names VALUES(1, 'James'), (1, 'Jones'), (2, 'Smith'), (3, 'Smith'), (NULL, 'Jones'), (4, NULL) + +query +SELECT * FROM names WHERE id in (1, 2, 4, NULL) + +query +SELECT * FROM names WHERE name in ('Smith', 'Brown', NULL) + +query +SELECT * FROM names WHERE id not in (1) + +query spark_answer_only +SELECT * FROM names WHERE name not in ('Smith', 'Brown', NULL) diff --git a/spark/src/test/resources/sql-tests/expressions/parquet_default_values.sql b/spark/src/test/resources/sql-tests/expressions/parquet_default_values.sql new file mode 100644 index 0000000000..bb1b003c79 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/parquet_default_values.sql @@ -0,0 +1,29 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- parquet default values +statement +CREATE TABLE t1(col1 boolean) USING parquet + +statement +INSERT INTO t1 VALUES(true) + +statement +ALTER TABLE t1 ADD COLUMN col2 string DEFAULT 'hello' + +query +SELECT * FROM t1 diff --git a/spark/src/test/resources/sql-tests/expressions/string.sql b/spark/src/test/resources/sql-tests/expressions/string.sql new file mode 100644 index 0000000000..15c0107fae --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/string.sql @@ -0,0 +1,49 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- substring with start < 1 +statement +CREATE TABLE t(col string) USING parquet + +statement +INSERT INTO t VALUES('123456') + +query +SELECT substring(col, 0) FROM t + +query +SELECT substring(col, -1) FROM t + +-- md5 +statement +CREATE TABLE test_md5(col String) USING parquet + +statement +INSERT INTO test_md5 VALUES ('test1'), ('test1'), ('test2'), ('test2'), (NULL), ('') + +query +SELECT md5(col) FROM test_md5 + +-- unhex +statement +CREATE TABLE unhex_table(col string) USING parquet + +statement +INSERT INTO unhex_table VALUES ('537061726B2053514C'), ('737472696E67'), ('\0'), (''), ('###'), ('G123'), ('hello'), ('A1B'), ('0A1B') + +query +SELECT unhex(col) FROM unhex_table diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index fe5ea77a89..576af73745 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -131,31 +131,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("compare true/false to negative zero") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col1 boolean, col2 float) using parquet") - sql(s"insert into $table values(true, -0.0)") - sql(s"insert into $table values(false, -0.0)") - - checkSparkAnswerAndOperator( - s"SELECT col1, negative(col2), cast(col1 as float), col1 = negative(col2) FROM $table") - } - } - } - } - - test("parquet default values") { - withTable("t1") { - sql("create table t1(col1 boolean) using parquet") - sql("insert into t1 values(true)") - sql("alter table t1 add column col2 string default 'hello'") - checkSparkAnswerAndOperator("select * from t1") - } - } - test("decimals divide by zero") { Seq(true, false).foreach { dictionary => withSQLConf( @@ -176,16 
+151,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("Integral Division Overflow Handling Matches Spark Behavior") { - withTable("t1") { - val value = Long.MinValue - sql("create table t1(c1 long, c2 short) using parquet") - sql(s"insert into t1 values($value, -1)") - val res = sql("select c1 div c2 from t1 order by c1") - checkSparkAnswerAndOperator(res) - } - } - test("basic data type support") { // this test requires native_comet scan due to unsigned u8/u16 issue withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_NATIVE_COMET) { @@ -502,17 +467,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("substring with start < 1") { - withTempPath { _ => - withTable("t") { - sql("create table t (col string) using parquet") - sql("insert into t values('123456')") - checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) - checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) - } - } - } - test("substring with dictionary") { val data = (0 until 1000) .map(_ % 5) // reduce value space to trigger dictionary encoding @@ -1557,20 +1511,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("md5") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col String) using parquet") - sql( - s"insert into $table values ('test1'), ('test1'), ('test2'), ('test2'), (NULL), ('')") - checkSparkAnswerAndOperator(s"select md5(col) FROM $table") - } - } - } - } - test("hex") { // https://github.com/apache/datafusion-comet/issues/1441 assume(!usingDataSourceExec) @@ -1589,26 +1529,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("unhex") { - val table = "unhex_table" - withTable(table) { - sql(s"create table $table(col string) using parquet") - 
- sql(s"""INSERT INTO $table VALUES - |('537061726B2053514C'), - |('737472696E67'), - |('\\0'), - |(''), - |('###'), - |('G123'), - |('hello'), - |('A1B'), - |('0A1B')""".stripMargin) - - checkSparkAnswerAndOperator(s"SELECT unhex(col) FROM $table") - } - } - test("EqualNullSafe should preserve comet filter") { Seq("true", "false").foreach(b => withParquetTable( @@ -1633,58 +1553,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { }) } - test("test in(set)/not in(set)") { - Seq("100", "0").foreach { inSetThreshold => - Seq(false, true).foreach { dictionary => - withSQLConf( - SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> inSetThreshold, - "parquet.enable.dictionary" -> dictionary.toString) { - val table = "names" - withTable(table) { - sql(s"create table $table(id int, name varchar(20)) using parquet") - sql( - s"insert into $table values(1, 'James'), (1, 'Jones'), (2, 'Smith'), (3, 'Smith')," + - "(NULL, 'Jones'), (4, NULL)") - - checkSparkAnswerAndOperator(s"SELECT * FROM $table WHERE id in (1, 2, 4, NULL)") - checkSparkAnswerAndOperator( - s"SELECT * FROM $table WHERE name in ('Smith', 'Brown', NULL)") - - // TODO: why with not in, the plan is only `LocalTableScan`? 
- checkSparkAnswerAndOperator(s"SELECT * FROM $table WHERE id not in (1)") - checkSparkAnswer(s"SELECT * FROM $table WHERE name not in ('Smith', 'Brown', NULL)") - } - } - } - } - } - - test("not") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col1 int, col2 boolean) using parquet") - sql(s"insert into $table values(1, false), (2, true), (3, true), (3, false)") - checkSparkAnswerAndOperator(s"SELECT col1, col2, NOT(col2), !(col2) FROM $table") - } - } - } - } - - test("negative") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col1 int) using parquet") - sql(s"insert into $table values(1), (2), (3), (3)") - checkSparkAnswerAndOperator(s"SELECT negative(col1), -(col1) FROM $table") - } - } - } - } - test("basic arithmetic") { withSQLConf("parquet.enable.dictionary" -> "false") { withParquetTable((1 until 10).map(i => (i, i + 1)), "tbl", false) { @@ -1719,21 +1587,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } - test("DatePart functions: Year/Month/DayOfMonth/DayOfWeek/DayOfYear/WeekOfYear/Quarter") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col timestamp) using parquet") - sql(s"insert into $table values (now()), (timestamp('1900-01-01')), (null)") - checkSparkAnswerAndOperator( - "SELECT col, year(col), month(col), day(col), weekday(col), " + - s" dayofweek(col), dayofyear(col), weekofyear(col), quarter(col) FROM $table") - } - } - } - } - test("from_unixtime") { Seq(false, true).foreach { dictionary => withSQLConf( @@ -1985,31 +1838,6 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } 
} - test("hash functions") { - Seq(true, false).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "test" - withTable(table) { - sql(s"create table $table(col string, a int, b float) using parquet") - sql(s""" - |insert into $table values - |('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) - |, ('Spark SQL ', 10, 1.2), (NULL, NULL, NULL), ('', 0, 0.0), ('苹果手机', NULL, 3.999999) - |""".stripMargin) - checkSparkAnswerAndOperator(""" - |select - |md5(col), md5(cast(a as string)), md5(cast(b as string)), - |hash(col), hash(col, 1), hash(col, 0), hash(col, a, b), hash(b, a, col), - |xxhash64(col), xxhash64(col, 1), xxhash64(col, 0), xxhash64(col, a, b), xxhash64(b, a, col), - |sha2(col, 0), sha2(col, 256), sha2(col, 224), sha2(col, 384), sha2(col, 512), sha2(col, 128), sha2(col, -1), - |sha1(col), sha1(cast(a as string)), sha1(cast(b as string)) - |from test - |""".stripMargin) - } - } - } - } - test("remainder function") { def withAnsiMode(enabled: Boolean)(f: => Unit): Unit = { withSQLConf( diff --git a/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala new file mode 100644 index 0000000000..8d382f5173 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/CometSqlFileTestSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import java.io.File + +import org.scalactic.source.Position +import org.scalatest.Tag + +import org.apache.spark.sql.CometTestBase +import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper + +class CometSqlFileTestSuite extends CometTestBase with AdaptiveSparkPlanHelper { + + override protected def test(testName: String, testTags: Tag*)(testFun: => Any)(implicit + pos: Position): Unit = { + super.test(testName, testTags: _*) { + withSQLConf(CometConf.COMET_NATIVE_SCAN_IMPL.key -> CometConf.SCAN_AUTO) { + testFun + } + } + } + + private val testResourceDir = { + val url = getClass.getClassLoader.getResource("sql-tests") + assert(url != null, "Could not find sql-tests resource directory") + new File(url.toURI) + } + + private def discoverTestFiles(dir: File): Seq[File] = { + if (!dir.exists()) return Seq.empty + val files = dir.listFiles().toSeq + val sqlFiles = files.filter(f => f.isFile && f.getName.endsWith(".sql")) + val subDirFiles = files.filter(_.isDirectory).flatMap(discoverTestFiles) + sqlFiles ++ subDirFiles + } + + /** Generate all config combinations from a ConfigMatrix specification. 
*/ + private def configCombinations( + matrix: Seq[(String, Seq[String])]): Seq[Seq[(String, String)]] = { + if (matrix.isEmpty) return Seq(Seq.empty) + val (key, values) = matrix.head + val rest = configCombinations(matrix.tail) + for { + value <- values + combo <- rest + } yield (key, value) +: combo + } + + private def runTestFile(file: SqlTestFile): Unit = { + val allConfigs = file.configs + withSQLConf(allConfigs: _*) { + withTable(file.tables: _*) { + file.records.foreach { + case SqlStatement(sql) => + spark.sql(sql) + case SqlQuery(sql, mode) => + mode match { + case CheckOperator => + checkSparkAnswerAndOperator(sql) + case SparkAnswerOnly => + checkSparkAnswer(sql) + case WithTolerance(tol) => + checkSparkAnswerWithTolerance(sql, tol) + } + } + } + } + } + + // Discover and register all .sql test files + discoverTestFiles(testResourceDir).foreach { file => + val relativePath = testResourceDir.toURI.relativize(file.toURI).getPath + val parsed = SqlFileTestParser.parse(file) + val combinations = configCombinations(parsed.configMatrix) + + if (combinations.size <= 1) { + // No matrix or single combination + test(s"sql-file: $relativePath") { + val effectiveConfigs = parsed.configs ++ combinations.headOption.getOrElse(Seq.empty) + runTestFile(parsed.copy(configs = effectiveConfigs)) + } + } else { + // Multiple combinations: generate one test per combination + combinations.foreach { matrixConfigs => + val label = matrixConfigs.map { case (k, v) => s"$k=$v" }.mkString(", ") + test(s"sql-file: $relativePath [$label]") { + runTestFile(parsed.copy(configs = parsed.configs ++ matrixConfigs)) + } + } + } + } +} diff --git a/spark/src/test/scala/org/apache/comet/SqlFileTestParser.scala b/spark/src/test/scala/org/apache/comet/SqlFileTestParser.scala new file mode 100644 index 0000000000..f2cbf56544 --- /dev/null +++ b/spark/src/test/scala/org/apache/comet/SqlFileTestParser.scala @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * 
or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.comet + +import java.io.File + +import scala.io.Source + +/** A record in a SQL test file: either a statement (DDL/DML) or a query (SELECT). */ +sealed trait SqlTestRecord + +/** A SQL statement to execute (CREATE TABLE, INSERT, etc.). */ +case class SqlStatement(sql: String) extends SqlTestRecord + +/** A SQL query whose results are compared between Spark and Comet. */ +case class SqlQuery(sql: String, mode: QueryMode = CheckOperator) extends SqlTestRecord + +sealed trait QueryMode +case object CheckOperator extends QueryMode +case object SparkAnswerOnly extends QueryMode +case class WithTolerance(tol: Double) extends QueryMode + +/** + * Parsed representation of a .sql test file. + * + * @param configs + * Spark SQL configs to set for this test file. + * @param configMatrix + * Ordered sequence of (config key, candidate values) pairs. The test will run once per combination. + * @param records + * Ordered list of statements and queries. + * @param tables + * Table names extracted from CREATE TABLE statements (for cleanup). 
+ */ +case class SqlTestFile( + configs: Seq[(String, String)], + configMatrix: Seq[(String, Seq[String])], + records: Seq[SqlTestRecord], + tables: Seq[String]) + +object SqlFileTestParser { + + private val ConfigPattern = """--\s*Config:\s*(.+)=(.+)""".r + private val ConfigMatrixPattern = """--\s*ConfigMatrix:\s*(.+)=(.+)""".r + private val CreateTablePattern = """(?i)CREATE\s+TABLE\s+(\w+)""".r.unanchored + + def parse(file: File): SqlTestFile = { + val source = Source.fromFile(file) + try { + parse(source.getLines().toSeq) + } finally { + source.close() + } + } + + def parse(lines: Seq[String]): SqlTestFile = { + var configs = Seq.empty[(String, String)] + var configMatrix = Seq.empty[(String, Seq[String])] + val records = Seq.newBuilder[SqlTestRecord] + val tables = Seq.newBuilder[String] + + var i = 0 + while (i < lines.length) { + val line = lines(i).trim + + line match { + case ConfigPattern(key, value) => + configs :+= (key.trim -> value.trim) + i += 1 + + case ConfigMatrixPattern(key, values) => + configMatrix :+= (key.trim -> values.split(",").map(_.trim).toSeq) + i += 1 + + case "statement" => + i += 1 + val (sql, nextIdx) = collectSql(lines, i) + // Extract table names for cleanup + CreateTablePattern.findFirstMatchIn(sql).foreach(m => tables += m.group(1)) + records += SqlStatement(sql) + i = nextIdx + + case s if s.startsWith("query") => + val mode = parseQueryMode(s) + i += 1 + val (sql, nextIdx) = collectSql(lines, i) + records += SqlQuery(sql, mode) + i = nextIdx + + case _ => + // Skip blank lines and comments + i += 1 + } + } + + SqlTestFile(configs, configMatrix, records.result(), tables.result()) + } + + private def parseQueryMode(directive: String): QueryMode = { + val parts = directive.split("\\s+") + if (parts.length == 1) return CheckOperator + parts(1) match { + case "spark_answer_only" => SparkAnswerOnly + case s if s.startsWith("tolerance=") => + WithTolerance(s.stripPrefix("tolerance=").toDouble) + case _ => CheckOperator + } + } + + 
/** Collect SQL lines until a blank line or end of file. */ + private def collectSql(lines: Seq[String], start: Int): (String, Int) = { + val sb = new StringBuilder + var i = start + while (i < lines.length && lines(i).trim.nonEmpty) { + if (sb.nonEmpty) sb.append("\n") + sb.append(lines(i)) + i += 1 + } + (sb.toString, i) + } +}