From 774acef9908bb30f5144b9769aad40d4ac642789 Mon Sep 17 00:00:00 2001
From: Vipul Vaibhaw
Date: Tue, 18 Jun 2024 18:06:28 +0530
Subject: [PATCH 01/18] Add a test case to reproduce the remainder bug

---
 .../org/apache/comet/CometExpressionSuite.scala | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index ec45c984b9..49920ab316 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -856,6 +856,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("remainder") {
+    withTempDir { dir =>
+      // Create a DataFrame with a negative-zero double divisor
+      val df = Seq((-21840, -0.0)).toDF("c90", "c1")
+
+      // Write the DataFrame to a Parquet file
+      val path = new Path(dir.toURI.toString, "remainder_test.parquet").toString
+      df.write.mode("overwrite").parquet(path)
+
+      withParquetTable(path, "t") {
+        checkSparkAnswerAndOperator("SELECT c90, c1, c90 % c1 FROM t")
+      }
+    }
+  }
+
   test("abs Overflow ansi mode") {
 
     def testAbsAnsiOverflow[T <: Product: ClassTag: TypeTag](data: Seq[T]): Unit = {

From 45b876be511fd2c64502c0f4f3807c8edeba3a19 Mon Sep 17 00:00:00 2001
From: Vipul Vaibhaw
Date: Tue, 18 Jun 2024 18:12:13 +0530
Subject: [PATCH 02/18] Remove comments from the remainder test and comment
 out the other tests

---
 .../apache/comet/CometExpressionSuite.scala | 1613 ++++++++---------
 1 file changed, 805 insertions(+), 808 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 49920ab316..645f5f8480 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -50,818 +50,815 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("decimals divide by zero") {
-    // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal divide operation
-    assume(isSpark34Plus)
-
-    Seq(true, false).foreach { dictionary =>
-      withSQLConf(
-        SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false",
-        "parquet.enable.dictionary" -> dictionary.toString) {
-        withTempPath { dir =>
-          val data = makeDecimalRDD(10, DecimalType(18, 10), dictionary)
-          data.write.parquet(dir.getCanonicalPath)
-          readParquetFile(dir.getCanonicalPath) { df =>
-            {
-              val decimalLiteral = Decimal(0.00)
-              val cometDf = df.select($"dec" / decimalLiteral, $"dec" % decimalLiteral)
-              checkSparkAnswerAndOperator(cometDf)
-            }
-          }
-        }
-      }
-    }
-  }
-
-  test("bitwise shift with different left/right types") {
-    Seq(false, true).foreach { dictionary =>
-      withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-        val table = "test"
-        withTable(table) {
-          sql(s"create table $table(col1 long, col2 int) using parquet")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(1111, 2)")
-          sql(s"insert into $table values(3333, 4)")
-          sql(s"insert into $table values(5555, 6)")
-
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table")
-          checkSparkAnswerAndOperator(
-            s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table")
-        }
-      }
-    }
-  }
-
-  test("basic data type support") {
-    Seq(true, false).foreach { dictionaryEnabled =>
-      withTempDir { dir =>
-        val path = new Path(dir.toURI.toString, "test.parquet")
-        makeParquetFileAllTypes(path, dictionaryEnabled = 
dictionaryEnabled, 10000) - withParquetTable(path.toString, "tbl") { - // TODO: enable test for unsigned ints - checkSparkAnswerAndOperator( - "select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " + - "_18, _19, _20 FROM tbl WHERE _2 > 100") - } - } - } - } - - test("null literals") { - val batchSize = 1000 - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) - withParquetTable(path.toString, "tbl") { - val sqlString = "SELECT _4 + null, _15 - null, _16 * null FROM tbl" - val df2 = sql(sqlString) - val rows = df2.collect() - assert(rows.length == batchSize) - assert(rows.forall(_ == Row(null, null, null))) - - checkSparkAnswerAndOperator(sqlString) - } - } - } - } - - test("date and timestamp type literals") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "tbl") { - checkSparkAnswerAndOperator( - "SELECT _4 FROM tbl WHERE " + - "_20 > CAST('2020-01-01' AS DATE) AND _18 < CAST('2020-01-01' AS TIMESTAMP)") - } - } - } - } - - test("dictionary arithmetic") { - // TODO: test ANSI mode - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") { - withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { - checkSparkAnswerAndOperator("SELECT _1 + _2, _1 - _2, _1 * _2, _1 / _2, _1 % _2 FROM tbl") - } - } - } - - test("dictionary arithmetic with scalar") { - withSQLConf("parquet.enable.dictionary" -> "true") { - withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { - checkSparkAnswerAndOperator("SELECT _1 + 1, _1 - 1, _1 * 2, _1 / 2, _1 % 2 FROM tbl") - } - } - } - - test("string type and substring") { - withParquetTable((0 until 5).map(i => (i.toString, (i + 100).toString)), "tbl") { - checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") - checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, -2) FROM tbl") - checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 2) FROM tbl") - checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, -2) FROM tbl") - checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 10) FROM tbl") - checkSparkAnswerAndOperator("SELECT _1, substring(_2, 0, 0) FROM tbl") - checkSparkAnswerAndOperator("SELECT _1, substring(_2, 1, 0) FROM tbl") - } - } - - test("substring with start < 1") { - withTempPath { _ => - withTable("t") { - sql("create table t (col string) using parquet") - sql("insert into t values('123456')") - checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) - checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) - } - } - } - - test("string with coalesce") { - withParquetTable( - (0 until 10).map(i => (i.toString, if (i > 5) None else Some((i + 100).toString))), - "tbl") { - checkSparkAnswerAndOperator( - "SELECT coalesce(_1), coalesce(_1, 1), coalesce(null, _1), coalesce(null, 1), coalesce(_2, _1), coalesce(null) FROM tbl") - } - } - - test("substring with dictionary") { - val data = (0 until 1000) - .map(_ % 5) // reduce value space to trigger dictionary encoding - .map(i => (i.toString, (i + 100).toString)) - withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") - } - } - - test("string_space") { - withParquetTable((0 until 5).map(i => (i, 
i + 1)), "tbl") { - checkSparkAnswerAndOperator("SELECT space(_1), space(_2) FROM tbl") - } - } - - test("string_space with dictionary") { - val data = (0 until 1000).map(i => Tuple1(i % 5)) - - withSQLConf("parquet.enable.dictionary" -> "true") { - withParquetTable(data, "tbl") { - checkSparkAnswerAndOperator("SELECT space(_1) FROM tbl") - } - } - } - - test("hour, minute, second") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "part-r-0.parquet") - val expected = makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - readParquetFile(path.toString) { df => - val query = df.select(expr("hour(_1)"), expr("minute(_1)"), expr("second(_1)")) - - checkAnswer( - query, - expected.map { - case None => - Row(null, null, null) - case Some(i) => - val timestamp = new java.sql.Timestamp(i).toLocalDateTime - val hour = timestamp.getHour - val minute = timestamp.getMinute - val second = timestamp.getSecond - - Row(hour, minute, second) - }) - } - } - } - } - - test("hour on int96 timestamp column") { - import testImplicits._ - - val N = 100 - val ts = "2020-01-01 01:02:03.123456" - Seq(true, false).foreach { dictionaryEnabled => - Seq(false, true).foreach { conversionEnabled => - withSQLConf( - SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", - SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { - withTempPath { path => - Seq - .tabulate(N)(_ => ts) - .toDF("ts1") - .select($"ts1".cast("timestamp").as("ts")) - .repartition(1) - .write - .option("parquet.enable.dictionary", dictionaryEnabled) - .parquet(path.getCanonicalPath) - - checkAnswer( - spark.read.parquet(path.getCanonicalPath).select(expr("hour(ts)")), - Seq.tabulate(N)(_ => Row(1))) - } - } - } - } - } - - test("cast timestamp and timestamp_ntz") { - withSQLConf( - SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", - CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "timetbl") { - checkSparkAnswerAndOperator( - "SELECT " + - "cast(_2 as timestamp) tz_millis, " + - "cast(_3 as timestamp) ntz_millis, " + - "cast(_4 as timestamp) tz_micros, " + - "cast(_5 as timestamp) ntz_micros " + - " from timetbl") - } - } - } - } - } - - test("cast timestamp and timestamp_ntz to string") { - // TODO: make the test pass for Spark 3.2 & 3.3 - assume(isSpark34Plus) - - withSQLConf( - SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", - CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 2001) - withParquetTable(path.toString, "timetbl") { - checkSparkAnswerAndOperator( - "SELECT " + - "cast(_2 as string) tz_millis, " + - "cast(_3 as string) ntz_millis, " + - "cast(_4 as string) tz_micros, " + - "cast(_5 as string) ntz_micros " + - " from timetbl") - } - } - } - } - } - - test("cast timestamp and timestamp_ntz to long, date") { - // TODO: make the test pass for Spark 3.2 & 3.3 - assume(isSpark34Plus) - - withSQLConf( - SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", - CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - Seq(true, false).foreach { dictionaryEnabled => - 
withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "timetbl") { - checkSparkAnswerAndOperator( - "SELECT " + - "cast(_2 as long) tz_millis, " + - "cast(_4 as long) tz_micros, " + - "cast(_2 as date) tz_millis_to_date, " + - "cast(_3 as date) ntz_millis_to_date, " + - "cast(_4 as date) tz_micros_to_date, " + - "cast(_5 as date) ntz_micros_to_date " + - " from timetbl") - } - } - } - } - } - - test("trunc") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "date_trunc.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "tbl") { - Seq("YEAR", "YYYY", "YY", "QUARTER", "MON", "MONTH", "MM", "WEEK").foreach { format => - checkSparkAnswerAndOperator(s"SELECT trunc(_20, '$format') from tbl") - } - } - } - } - } - - test("trunc with format array") { - val numRows = 1000 - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet") - makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) - withParquetTable(path.toString, "dateformattbl") { - checkSparkAnswerAndOperator( - "SELECT " + - "dateformat, _7, " + - "trunc(_7, dateformat) " + - " from dateformattbl ") - } - } - } - } - - test("date_trunc") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "timetbl") { - Seq( - "YEAR", - "YYYY", - "YY", - "MON", - "MONTH", - "MM", - "QUARTER", - "WEEK", - "DAY", - "DD", - "HOUR", - "MINUTE", - "SECOND", - "MILLISECOND", - "MICROSECOND").foreach { format => - checkSparkAnswerAndOperator( - "SELECT " + - s"date_trunc('$format', _0), " + - s"date_trunc('$format', _1), " + - s"date_trunc('$format', _2), " + - s"date_trunc('$format', _4) " + - " from timetbl") - } - } - } - } - } - - test("date_trunc with timestamp_ntz") { - assume(!isSpark32, "timestamp functions for timestamp_ntz have incorrect behavior in 3.2") - withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - withParquetTable(path.toString, "timetbl") { - Seq( - "YEAR", - "YYYY", - "YY", - "MON", - "MONTH", - "MM", - "QUARTER", - "WEEK", - "DAY", - "DD", - "HOUR", - "MINUTE", - "SECOND", - "MILLISECOND", - "MICROSECOND").foreach { format => - checkSparkAnswerAndOperator( - "SELECT " + - s"date_trunc('$format', _3), " + - s"date_trunc('$format', _5) " + - " from timetbl") - } - } - } - } - } - } - - test("date_trunc with format array") { - assume(isSpark33Plus, "TimestampNTZ is supported in Spark 3.3+, See SPARK-36182") - withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - val numRows = 1000 - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") - makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) - withParquetTable(path.toString, "timeformattbl") { - checkSparkAnswerAndOperator( - 
"SELECT " + - "format, _0, _1, _2, _3, _4, _5, " + - "date_trunc(format, _0), " + - "date_trunc(format, _1), " + - "date_trunc(format, _2), " + - "date_trunc(format, _3), " + - "date_trunc(format, _4), " + - "date_trunc(format, _5) " + - " from timeformattbl ") - } - } - } - } - } - - test("date_trunc on int96 timestamp column") { - import testImplicits._ - - val N = 100 - val ts = "2020-01-01 01:02:03.123456" - Seq(true, false).foreach { dictionaryEnabled => - Seq(false, true).foreach { conversionEnabled => - withSQLConf( - SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", - SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { - withTempPath { path => - Seq - .tabulate(N)(_ => ts) - .toDF("ts1") - .select($"ts1".cast("timestamp").as("ts")) - .repartition(1) - .write - .option("parquet.enable.dictionary", dictionaryEnabled) - .parquet(path.getCanonicalPath) - - withParquetTable(path.toString, "int96timetbl") { - Seq( - "YEAR", - "YYYY", - "YY", - "MON", - "MONTH", - "MM", - "QUARTER", - "WEEK", - "DAY", - "DD", - "HOUR", - "MINUTE", - "SECOND", - "MILLISECOND", - "MICROSECOND").foreach { format => - checkSparkAnswer( - "SELECT " + - s"date_trunc('$format', ts )" + - " from int96timetbl") - } - } - } - } - } - } - } - - test("charvarchar") { - Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - val table = "char_tbl4" - withTable(table) { - val view = "str_view" - withView(view) { - sql(s"""create temporary view $view as select c, v from values - | (null, null), (null, null), - | (null, 'S'), (null, 'S'), - | ('N', 'N '), ('N', 'N '), - | ('Ne', 'Sp'), ('Ne', 'Sp'), - | ('Net ', 'Spa '), ('Net ', 'Spa '), - | ('NetE', 'Spar'), ('NetE', 'Spar'), - | ('NetEa ', 'Spark '), ('NetEa ', 'Spark '), - | ('NetEas ', 'Spark'), ('NetEas ', 'Spark'), - | ('NetEase', 'Spark-'), ('NetEase', 'Spark-') t(c, v);""".stripMargin) - sql( - s"create table $table(c7 char(7), c8 char(8), v varchar(6), s string) using parquet;") - sql(s"insert into $table select c, c, v, c from $view;") - val df = sql(s"""select substring(c7, 2), substring(c8, 2), - | substring(v, 3), substring(s, 2) from $table;""".stripMargin) - - val expected = Row(" ", " ", "", "") :: - Row(null, null, "", null) :: Row(null, null, null, null) :: - Row("e ", "e ", "", "e") :: Row("et ", "et ", "a ", "et ") :: - Row("etE ", "etE ", "ar", "etE") :: - Row("etEa ", "etEa ", "ark ", "etEa ") :: - Row("etEas ", "etEas ", "ark", "etEas ") :: - Row("etEase", "etEase ", "ark-", "etEase") :: Nil - checkAnswer(df, expected ::: expected) - } - } - } - } - } - - test("char varchar over length values") { - Seq("char", "varchar").foreach { typ => - withTempPath { dir => - withTable("t") { - sql("select '123456' as col").write.format("parquet").save(dir.toString) - sql(s"create table t (col $typ(2)) using parquet location '$dir'") - sql("insert into t values('1')") - checkSparkAnswerAndOperator(sql("select substring(col, 1) from t")) - checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) - checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) - } - } - } - } - - test("like (LikeSimplification enabled)") { - val table = "names" - withTable(table) { - sql(s"create table $table(id int, name varchar(20)) using parquet") - sql(s"insert into $table values(1,'James Smith')") - sql(s"insert into $table values(2,'Michael Rose')") - sql(s"insert into $table values(3,'Robert Williams')") - sql(s"insert into $table values(4,'Rames Rose')") - sql(s"insert 
into $table values(5,'Rames rose')") - - // Filter column having values 'Rames _ose', where any character matches for '_' - val query = sql(s"select id from $table where name like 'Rames _ose'") - checkAnswer(query, Row(4) :: Row(5) :: Nil) - - // Filter rows that contains 'rose' in 'name' column - val queryContains = sql(s"select id from $table where name like '%rose%'") - checkAnswer(queryContains, Row(5) :: Nil) - - // Filter rows that starts with 'R' following by any characters - val queryStartsWith = sql(s"select id from $table where name like 'R%'") - checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) - - // Filter rows that ends with 's' following by any characters - val queryEndsWith = sql(s"select id from $table where name like '%s'") - checkAnswer(queryEndsWith, Row(3) :: Nil) - } - } - - test("like with custom escape") { - val table = "names" - withTable(table) { - sql(s"create table $table(id int, name varchar(20)) using parquet") - sql(s"insert into $table values(1,'James Smith')") - sql(s"insert into $table values(2,'Michael_Rose')") - sql(s"insert into $table values(3,'Robert_R_Williams')") - - // Filter column having values that include underscores - val queryDefaultEscape = sql("select id from names where name like '%\\_%'") - checkSparkAnswerAndOperator(queryDefaultEscape) - - val queryCustomEscape = sql("select id from names where name like '%$_%' escape '$'") - checkAnswer(queryCustomEscape, Row(2) :: Row(3) :: Nil) - - } - } - - test("contains") { - assume(!isSpark32) - - val table = "names" - withTable(table) { - sql(s"create table $table(id int, name varchar(20)) using parquet") - sql(s"insert into $table values(1,'James Smith')") - sql(s"insert into $table values(2,'Michael Rose')") - sql(s"insert into $table values(3,'Robert Williams')") - sql(s"insert into $table values(4,'Rames Rose')") - sql(s"insert into $table values(5,'Rames rose')") - - // Filter rows that contains 'rose' in 'name' column - val queryContains = sql(s"select id from $table where contains (name, 'rose')") - checkAnswer(queryContains, Row(5) :: Nil) - } - } - - test("startswith") { - assume(!isSpark32) - - val table = "names" - withTable(table) { - sql(s"create table $table(id int, name varchar(20)) using parquet") - sql(s"insert into $table values(1,'James Smith')") - sql(s"insert into $table values(2,'Michael Rose')") - sql(s"insert into $table values(3,'Robert Williams')") - sql(s"insert into $table values(4,'Rames Rose')") - sql(s"insert into $table values(5,'Rames rose')") - - // Filter rows that starts with 'R' following by any characters - val queryStartsWith = sql(s"select id from $table where startswith (name, 'R')") - checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) - } - } - - test("endswith") { - assume(!isSpark32) - - val table = "names" - withTable(table) { - sql(s"create table $table(id int, name varchar(20)) using parquet") - sql(s"insert into $table values(1,'James Smith')") - sql(s"insert into $table values(2,'Michael Rose')") - sql(s"insert into $table values(3,'Robert Williams')") - sql(s"insert into $table values(4,'Rames Rose')") - sql(s"insert into $table values(5,'Rames rose')") - - // Filter rows that ends with 's' following by any characters - val queryEndsWith = sql(s"select id from $table where endswith (name, 's')") - checkAnswer(queryEndsWith, Row(3) :: Nil) - } - } - - test("add overflow (ANSI disable)") { - // Enabling ANSI will cause native engine failure, but as we cannot catch - // native error now, we cannot test it here. 
- withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { - withParquetTable(Seq((Int.MaxValue, 1)), "tbl") { - checkSparkAnswerAndOperator("SELECT _1 + _2 FROM tbl") - } - } - } - - test("divide by zero (ANSI disable)") { - // Enabling ANSI will cause native engine failure, but as we cannot catch - // native error now, we cannot test it here. - withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { - withParquetTable(Seq((1, 0, 1.0, 0.0)), "tbl") { - checkSparkAnswerAndOperator("SELECT _1 / _2, _3 / _4 FROM tbl") - } - } - } - - test("decimals arithmetic and comparison") { - // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal reminder operation - assume(isSpark34Plus) - - def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = { - val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded - spark - .range(num) - .map(_ % div) - // Parquet doesn't allow column names with spaces, have to add an alias here. - // Minus 500 here so that negative decimals are also tested. - .select( - (($"value" - 500) / 100.0) cast decimal as Symbol("dec1"), - (($"value" - 600) / 100.0) cast decimal as Symbol("dec2")) - .coalesce(1) - } - - Seq(true, false).foreach { dictionary => - Seq(16, 1024).foreach { batchSize => - withSQLConf( - CometConf.COMET_BATCH_SIZE.key -> batchSize.toString, - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false", - "parquet.enable.dictionary" -> dictionary.toString) { - var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) - // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the - // decimal RDD contains all null values and should be able to read back from Parquet. - - if (!SQLConf.get.ansiEnabled) { - combinations = combinations ++ Seq((1, 1)) - } - - for ((precision, scale) <- combinations) { - withTempPath { dir => - val data = makeDecimalRDD(10, DecimalType(precision, scale), dictionary) - data.write.parquet(dir.getCanonicalPath) - readParquetFile(dir.getCanonicalPath) { df => - { - val decimalLiteral1 = Decimal(1.00) - val decimalLiteral2 = Decimal(123.456789) - val cometDf = df.select( - $"dec1" + $"dec2", - $"dec1" - $"dec2", - $"dec1" % $"dec2", - $"dec1" >= $"dec1", - $"dec1" === "1.0", - $"dec1" + decimalLiteral1, - $"dec1" - decimalLiteral1, - $"dec1" + decimalLiteral2, - $"dec1" - decimalLiteral2) - - checkAnswer( - cometDf, - data - .select( - $"dec1" + $"dec2", - $"dec1" - $"dec2", - $"dec1" % $"dec2", - $"dec1" >= $"dec1", - $"dec1" === "1.0", - $"dec1" + decimalLiteral1, - $"dec1" - decimalLiteral1, - $"dec1" + decimalLiteral2, - $"dec1" - decimalLiteral2) - .collect() - .toSeq) - } - } - } - } - } - } - } - } - - test("scalar decimal arithmetic operations") { - assume(isSpark34Plus) - withTable("tbl") { - withSQLConf(CometConf.COMET_ENABLED.key -> "true") { - sql("CREATE TABLE tbl (a INT) USING PARQUET") - sql("INSERT INTO tbl VALUES (0)") - - val combinations = Seq((7, 3), (18, 10), (38, 4)) - for ((precision, scale) <- combinations) { - for (op <- Seq("+", "-", "*", "/", "%")) { - val left = s"CAST(1.00 AS DECIMAL($precision, $scale))" - val right = s"CAST(123.45 AS DECIMAL($precision, $scale))" - - withSQLConf( - "spark.sql.optimizer.excludedRules" -> - "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { - - checkSparkAnswerAndOperator(s"SELECT $left $op $right FROM tbl") - } - } - } - } - } - } - - test("cast decimals to int") { - Seq(16, 1024).foreach { batchSize => - withSQLConf( - CometConf.COMET_BATCH_SIZE.key -> 
batchSize.toString, - SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") { - var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) - // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the - // decimal RDD contains all null values and should be able to read back from Parquet. - - if (!SQLConf.get.ansiEnabled) { - combinations = combinations ++ Seq((1, 1)) - } - - for ((precision, scale) <- combinations; useDictionary <- Seq(false)) { - withTempPath { dir => - val data = makeDecimalRDD(10, DecimalType(precision, scale), useDictionary) - data.write.parquet(dir.getCanonicalPath) - readParquetFile(dir.getCanonicalPath) { df => - { - val cometDf = df.select($"dec".cast("int")) - - // `data` is not read from Parquet, so it doesn't go Comet exec. - checkAnswer(cometDf, data.select($"dec".cast("int")).collect().toSeq) - } - } - } - } - } - } - } - - test("various math scalar functions") { - Seq("true", "false").foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary) { - withParquetTable( - (0 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)), - "tbl", - withDictionary = dictionary.toBoolean) { - checkSparkAnswerWithTol( - "SELECT abs(_1), acos(_2), asin(_1), atan(_2), atan2(_1, _2), cos(_1) FROM tbl") - checkSparkAnswerWithTol( - "SELECT exp(_1), ln(_2), log10(_1), log2(_1), pow(_1, _2) FROM tbl") - // TODO: comment in the round tests once supported - // checkSparkAnswerWithTol("SELECT round(_1), round(_2) FROM tbl") - checkSparkAnswerWithTol("SELECT signum(_1), sin(_1), sqrt(_1) FROM tbl") - checkSparkAnswerWithTol("SELECT tan(_1) FROM tbl") - } - } - } - } - - test("abs") { - Seq(true, false).foreach { dictionaryEnabled => - withTempDir { dir => - val path = new Path(dir.toURI.toString, "test.parquet") - makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100) - withParquetTable(path.toString, "tbl") { - Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col => - checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") - } - } - } - } - } + // test("decimals divide by zero") { + // // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal divide operation + // assume(isSpark34Plus) + + // Seq(true, false).foreach { dictionary => + // withSQLConf( + // SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false", + // "parquet.enable.dictionary" -> dictionary.toString) { + // withTempPath { dir => + // val data = makeDecimalRDD(10, DecimalType(18, 10), dictionary) + // data.write.parquet(dir.getCanonicalPath) + // readParquetFile(dir.getCanonicalPath) { df => + // { + // val decimalLiteral = Decimal(0.00) + // val cometDf = df.select($"dec" / decimalLiteral, $"dec" % decimalLiteral) + // checkSparkAnswerAndOperator(cometDf) + // } + // } + // } + // } + // } + // } + + // test("bitwise shift with different left/right types") { + // Seq(false, true).foreach { dictionary => + // withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + // val table = "test" + // withTable(table) { + // sql(s"create table $table(col1 long, col2 int) using parquet") + // sql(s"insert into $table values(1111, 2)") + // sql(s"insert into $table values(1111, 2)") + // sql(s"insert into $table values(3333, 4)") + // sql(s"insert into $table values(5555, 6)") + + // checkSparkAnswerAndOperator( + // s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table") + // checkSparkAnswerAndOperator( + // s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table") + // } + // } + // } + // } + + // test("basic 
data type support") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "test.parquet") + // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "tbl") { + // // TODO: enable test for unsigned ints + // checkSparkAnswerAndOperator( + // "select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " + + // "_18, _19, _20 FROM tbl WHERE _2 > 100") + // } + // } + // } + // } + + // test("null literals") { + // val batchSize = 1000 + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "test.parquet") + // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) + // withParquetTable(path.toString, "tbl") { + // val sqlString = "SELECT _4 + null, _15 - null, _16 * null FROM tbl" + // val df2 = sql(sqlString) + // val rows = df2.collect() + // assert(rows.length == batchSize) + // assert(rows.forall(_ == Row(null, null, null))) + + // checkSparkAnswerAndOperator(sqlString) + // } + // } + // } + // } + + // test("date and timestamp type literals") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "test.parquet") + // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "tbl") { + // checkSparkAnswerAndOperator( + // "SELECT _4 FROM tbl WHERE " + + // "_20 > CAST('2020-01-01' AS DATE) AND _18 < CAST('2020-01-01' AS TIMESTAMP)") + // } + // } + // } + // } + + // test("dictionary arithmetic") { + // // TODO: test ANSI mode + // withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") { + // withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { + // checkSparkAnswerAndOperator("SELECT _1 + _2, _1 - _2, _1 * _2, _1 / _2, _1 % _2 FROM tbl") + // } + // } + // } + + // test("dictionary arithmetic with scalar") { + // withSQLConf("parquet.enable.dictionary" -> "true") { + // withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { + // checkSparkAnswerAndOperator("SELECT _1 + 1, _1 - 1, _1 * 2, _1 / 2, _1 % 2 FROM tbl") + // } + // } + // } + + // test("string type and substring") { + // withParquetTable((0 until 5).map(i => (i.toString, (i + 100).toString)), "tbl") { + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, -2) FROM tbl") + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 2) FROM tbl") + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, -2) FROM tbl") + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 10) FROM tbl") + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 0, 0) FROM tbl") + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 1, 0) FROM tbl") + // } + // } + + // test("substring with start < 1") { + // withTempPath { _ => + // withTable("t") { + // sql("create table t (col string) using parquet") + // sql("insert into t values('123456')") + // checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) + // checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) + // } + // } + // } + + // test("string with coalesce") { + // withParquetTable( + // (0 until 10).map(i => (i.toString, if (i > 5) None else Some((i + 100).toString))), + // "tbl") { + // checkSparkAnswerAndOperator( + // "SELECT coalesce(_1), coalesce(_1, 
1), coalesce(null, _1), coalesce(null, 1), coalesce(_2, _1), coalesce(null) FROM tbl") + // } + // } + + // test("substring with dictionary") { + // val data = (0 until 1000) + // .map(_ % 5) // reduce value space to trigger dictionary encoding + // .map(i => (i.toString, (i + 100).toString)) + // withParquetTable(data, "tbl") { + // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") + // } + // } + + // test("string_space") { + // withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { + // checkSparkAnswerAndOperator("SELECT space(_1), space(_2) FROM tbl") + // } + // } + + // test("string_space with dictionary") { + // val data = (0 until 1000).map(i => Tuple1(i % 5)) + + // withSQLConf("parquet.enable.dictionary" -> "true") { + // withParquetTable(data, "tbl") { + // checkSparkAnswerAndOperator("SELECT space(_1) FROM tbl") + // } + // } + // } + + // test("hour, minute, second") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "part-r-0.parquet") + // val expected = makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + // readParquetFile(path.toString) { df => + // val query = df.select(expr("hour(_1)"), expr("minute(_1)"), expr("second(_1)")) + + // checkAnswer( + // query, + // expected.map { + // case None => + // Row(null, null, null) + // case Some(i) => + // val timestamp = new java.sql.Timestamp(i).toLocalDateTime + // val hour = timestamp.getHour + // val minute = timestamp.getMinute + // val second = timestamp.getSecond + + // Row(hour, minute, second) + // }) + // } + // } + // } + // } + + // test("hour on int96 timestamp column") { + // import testImplicits._ + + // val N = 100 + // val ts = "2020-01-01 01:02:03.123456" + // Seq(true, false).foreach { dictionaryEnabled => + // Seq(false, true).foreach { conversionEnabled => + // withSQLConf( + // SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", + // SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { + // withTempPath { path => + // Seq + // .tabulate(N)(_ => ts) + // .toDF("ts1") + // .select($"ts1".cast("timestamp").as("ts")) + // .repartition(1) + // .write + // .option("parquet.enable.dictionary", dictionaryEnabled) + // .parquet(path.getCanonicalPath) + + // checkAnswer( + // spark.read.parquet(path.getCanonicalPath).select(expr("hour(ts)")), + // Seq.tabulate(N)(_ => Row(1))) + // } + // } + // } + // } + // } + + // test("cast timestamp and timestamp_ntz") { + // withSQLConf( + // SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + // CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "timetbl") { + // checkSparkAnswerAndOperator( + // "SELECT " + + // "cast(_2 as timestamp) tz_millis, " + + // "cast(_3 as timestamp) ntz_millis, " + + // "cast(_4 as timestamp) tz_micros, " + + // "cast(_5 as timestamp) ntz_micros " + + // " from timetbl") + // } + // } + // } + // } + // } + + // test("cast timestamp and timestamp_ntz to string") { + // // TODO: make the test pass for Spark 3.2 & 3.3 + // assume(isSpark34Plus) + + // withSQLConf( + // SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + // CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + // Seq(true, false).foreach { dictionaryEnabled => + 
// withTempDir { dir => + // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 2001) + // withParquetTable(path.toString, "timetbl") { + // checkSparkAnswerAndOperator( + // "SELECT " + + // "cast(_2 as string) tz_millis, " + + // "cast(_3 as string) ntz_millis, " + + // "cast(_4 as string) tz_micros, " + + // "cast(_5 as string) ntz_micros " + + // " from timetbl") + // } + // } + // } + // } + // } + + // test("cast timestamp and timestamp_ntz to long, date") { + // // TODO: make the test pass for Spark 3.2 & 3.3 + // assume(isSpark34Plus) + + // withSQLConf( + // SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + // CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "timetbl") { + // checkSparkAnswerAndOperator( + // "SELECT " + + // "cast(_2 as long) tz_millis, " + + // "cast(_4 as long) tz_micros, " + + // "cast(_2 as date) tz_millis_to_date, " + + // "cast(_3 as date) ntz_millis_to_date, " + + // "cast(_4 as date) tz_micros_to_date, " + + // "cast(_5 as date) ntz_micros_to_date " + + // " from timetbl") + // } + // } + // } + // } + // } + + // test("trunc") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "date_trunc.parquet") + // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "tbl") { + // Seq("YEAR", "YYYY", "YY", "QUARTER", "MON", "MONTH", "MM", "WEEK").foreach { format => + // checkSparkAnswerAndOperator(s"SELECT trunc(_20, '$format') from tbl") + // } + // } + // } + // } + // } + + // test("trunc with format array") { + // val numRows = 1000 + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet") + // makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + // withParquetTable(path.toString, "dateformattbl") { + // checkSparkAnswerAndOperator( + // "SELECT " + + // "dateformat, _7, " + + // "trunc(_7, dateformat) " + + // " from dateformattbl ") + // } + // } + // } + // } + + // test("date_trunc") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "timetbl") { + // Seq( + // "YEAR", + // "YYYY", + // "YY", + // "MON", + // "MONTH", + // "MM", + // "QUARTER", + // "WEEK", + // "DAY", + // "DD", + // "HOUR", + // "MINUTE", + // "SECOND", + // "MILLISECOND", + // "MICROSECOND").foreach { format => + // checkSparkAnswerAndOperator( + // "SELECT " + + // s"date_trunc('$format', _0), " + + // s"date_trunc('$format', _1), " + + // s"date_trunc('$format', _2), " + + // s"date_trunc('$format', _4) " + + // " from timetbl") + // } + // } + // } + // } + // } + + // test("date_trunc with timestamp_ntz") { + // assume(!isSpark32, "timestamp functions for timestamp_ntz have incorrect behavior in 3.2") + // withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir 
=> + // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + // withParquetTable(path.toString, "timetbl") { + // Seq( + // "YEAR", + // "YYYY", + // "YY", + // "MON", + // "MONTH", + // "MM", + // "QUARTER", + // "WEEK", + // "DAY", + // "DD", + // "HOUR", + // "MINUTE", + // "SECOND", + // "MILLISECOND", + // "MICROSECOND").foreach { format => + // checkSparkAnswerAndOperator( + // "SELECT " + + // s"date_trunc('$format', _3), " + + // s"date_trunc('$format', _5) " + + // " from timetbl") + // } + // } + // } + // } + // } + // } + + // test("date_trunc with format array") { + // assume(isSpark33Plus, "TimestampNTZ is supported in Spark 3.3+, See SPARK-36182") + // withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + // val numRows = 1000 + // Seq(true, false).foreach { dictionaryEnabled => + // withTempDir { dir => + // val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") + // makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + // withParquetTable(path.toString, "timeformattbl") { + // checkSparkAnswerAndOperator( + // "SELECT " + + // "format, _0, _1, _2, _3, _4, _5, " + + // "date_trunc(format, _0), " + + // "date_trunc(format, _1), " + + // "date_trunc(format, _2), " + + // "date_trunc(format, _3), " + + // "date_trunc(format, _4), " + + // "date_trunc(format, _5) " + + // " from timeformattbl ") + // } + // } + // } + // } + // } + + // test("date_trunc on int96 timestamp column") { + // import testImplicits._ + + // val N = 100 + // val ts = "2020-01-01 01:02:03.123456" + // Seq(true, false).foreach { dictionaryEnabled => + // Seq(false, true).foreach { conversionEnabled => + // withSQLConf( + // SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", + // SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { + // withTempPath { path => + // Seq + // .tabulate(N)(_ => ts) + // .toDF("ts1") + // .select($"ts1".cast("timestamp").as("ts")) + // .repartition(1) + // .write + // .option("parquet.enable.dictionary", dictionaryEnabled) + // .parquet(path.getCanonicalPath) + + // withParquetTable(path.toString, "int96timetbl") { + // Seq( + // "YEAR", + // "YYYY", + // "YY", + // "MON", + // "MONTH", + // "MM", + // "QUARTER", + // "WEEK", + // "DAY", + // "DD", + // "HOUR", + // "MINUTE", + // "SECOND", + // "MILLISECOND", + // "MICROSECOND").foreach { format => + // checkSparkAnswer( + // "SELECT " + + // s"date_trunc('$format', ts )" + + // " from int96timetbl") + // } + // } + // } + // } + // } + // } + // } + + // test("charvarchar") { + // Seq(false, true).foreach { dictionary => + // withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + // val table = "char_tbl4" + // withTable(table) { + // val view = "str_view" + // withView(view) { + // sql(s"""create temporary view $view as select c, v from values + // | (null, null), (null, null), + // | (null, 'S'), (null, 'S'), + // | ('N', 'N '), ('N', 'N '), + // | ('Ne', 'Sp'), ('Ne', 'Sp'), + // | ('Net ', 'Spa '), ('Net ', 'Spa '), + // | ('NetE', 'Spar'), ('NetE', 'Spar'), + // | ('NetEa ', 'Spark '), ('NetEa ', 'Spark '), + // | ('NetEas ', 'Spark'), ('NetEas ', 'Spark'), + // | ('NetEase', 'Spark-'), ('NetEase', 'Spark-') t(c, v);""".stripMargin) + // sql( + // s"create table $table(c7 char(7), c8 char(8), v varchar(6), s string) using parquet;") + // sql(s"insert into $table select c, c, v, c from $view;") + // val df = 
sql(s"""select substring(c7, 2), substring(c8, 2), + // | substring(v, 3), substring(s, 2) from $table;""".stripMargin) + + // val expected = Row(" ", " ", "", "") :: + // Row(null, null, "", null) :: Row(null, null, null, null) :: + // Row("e ", "e ", "", "e") :: Row("et ", "et ", "a ", "et ") :: + // Row("etE ", "etE ", "ar", "etE") :: + // Row("etEa ", "etEa ", "ark ", "etEa ") :: + // Row("etEas ", "etEas ", "ark", "etEas ") :: + // Row("etEase", "etEase ", "ark-", "etEase") :: Nil + // checkAnswer(df, expected ::: expected) + // } + // } + // } + // } + // } + + // test("char varchar over length values") { + // Seq("char", "varchar").foreach { typ => + // withTempPath { dir => + // withTable("t") { + // sql("select '123456' as col").write.format("parquet").save(dir.toString) + // sql(s"create table t (col $typ(2)) using parquet location '$dir'") + // sql("insert into t values('1')") + // checkSparkAnswerAndOperator(sql("select substring(col, 1) from t")) + // checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) + // checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) + // } + // } + // } + // } + + // test("like (LikeSimplification enabled)") { + // val table = "names" + // withTable(table) { + // sql(s"create table $table(id int, name varchar(20)) using parquet") + // sql(s"insert into $table values(1,'James Smith')") + // sql(s"insert into $table values(2,'Michael Rose')") + // sql(s"insert into $table values(3,'Robert Williams')") + // sql(s"insert into $table values(4,'Rames Rose')") + // sql(s"insert into $table values(5,'Rames rose')") + + // // Filter column having values 'Rames _ose', where any character matches for '_' + // val query = sql(s"select id from $table where name like 'Rames _ose'") + // checkAnswer(query, Row(4) :: Row(5) :: Nil) + + // // Filter rows that contains 'rose' in 'name' column + // val queryContains = sql(s"select id from $table where name like '%rose%'") + // checkAnswer(queryContains, Row(5) :: Nil) + + // // Filter rows that starts with 'R' following by any characters + // val queryStartsWith = sql(s"select id from $table where name like 'R%'") + // checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) + + // // Filter rows that ends with 's' following by any characters + // val queryEndsWith = sql(s"select id from $table where name like '%s'") + // checkAnswer(queryEndsWith, Row(3) :: Nil) + // } + // } + + // test("like with custom escape") { + // val table = "names" + // withTable(table) { + // sql(s"create table $table(id int, name varchar(20)) using parquet") + // sql(s"insert into $table values(1,'James Smith')") + // sql(s"insert into $table values(2,'Michael_Rose')") + // sql(s"insert into $table values(3,'Robert_R_Williams')") + + // // Filter column having values that include underscores + // val queryDefaultEscape = sql("select id from names where name like '%\\_%'") + // checkSparkAnswerAndOperator(queryDefaultEscape) + + // val queryCustomEscape = sql("select id from names where name like '%$_%' escape '$'") + // checkAnswer(queryCustomEscape, Row(2) :: Row(3) :: Nil) + + // } + // } + + // test("contains") { + // assume(!isSpark32) + + // val table = "names" + // withTable(table) { + // sql(s"create table $table(id int, name varchar(20)) using parquet") + // sql(s"insert into $table values(1,'James Smith')") + // sql(s"insert into $table values(2,'Michael Rose')") + // sql(s"insert into $table values(3,'Robert Williams')") + // sql(s"insert into $table values(4,'Rames Rose')") + // sql(s"insert 
into $table values(5,'Rames rose')") + + // // Filter rows that contains 'rose' in 'name' column + // val queryContains = sql(s"select id from $table where contains (name, 'rose')") + // checkAnswer(queryContains, Row(5) :: Nil) + // } + // } + + // test("startswith") { + // assume(!isSpark32) + + // val table = "names" + // withTable(table) { + // sql(s"create table $table(id int, name varchar(20)) using parquet") + // sql(s"insert into $table values(1,'James Smith')") + // sql(s"insert into $table values(2,'Michael Rose')") + // sql(s"insert into $table values(3,'Robert Williams')") + // sql(s"insert into $table values(4,'Rames Rose')") + // sql(s"insert into $table values(5,'Rames rose')") + + // // Filter rows that starts with 'R' following by any characters + // val queryStartsWith = sql(s"select id from $table where startswith (name, 'R')") + // checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) + // } + // } + + // test("endswith") { + // assume(!isSpark32) + + // val table = "names" + // withTable(table) { + // sql(s"create table $table(id int, name varchar(20)) using parquet") + // sql(s"insert into $table values(1,'James Smith')") + // sql(s"insert into $table values(2,'Michael Rose')") + // sql(s"insert into $table values(3,'Robert Williams')") + // sql(s"insert into $table values(4,'Rames Rose')") + // sql(s"insert into $table values(5,'Rames rose')") + + // // Filter rows that ends with 's' following by any characters + // val queryEndsWith = sql(s"select id from $table where endswith (name, 's')") + // checkAnswer(queryEndsWith, Row(3) :: Nil) + // } + // } + + // test("add overflow (ANSI disable)") { + // // Enabling ANSI will cause native engine failure, but as we cannot catch + // // native error now, we cannot test it here. + // withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + // withParquetTable(Seq((Int.MaxValue, 1)), "tbl") { + // checkSparkAnswerAndOperator("SELECT _1 + _2 FROM tbl") + // } + // } + // } + + // test("divide by zero (ANSI disable)") { + // // Enabling ANSI will cause native engine failure, but as we cannot catch + // // native error now, we cannot test it here. + // withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + // withParquetTable(Seq((1, 0, 1.0, 0.0)), "tbl") { + // checkSparkAnswerAndOperator("SELECT _1 / _2, _3 / _4 FROM tbl") + // } + // } + // } + + // test("decimals arithmetic and comparison") { + // // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal reminder operation + // assume(isSpark34Plus) + + // def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = { + // val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded + // spark + // .range(num) + // .map(_ % div) + // // Parquet doesn't allow column names with spaces, have to add an alias here. + // // Minus 500 here so that negative decimals are also tested. + // .select( + // (($"value" - 500) / 100.0) cast decimal as Symbol("dec1"), + // (($"value" - 600) / 100.0) cast decimal as Symbol("dec2")) + // .coalesce(1) + // } + + // Seq(true, false).foreach { dictionary => + // Seq(16, 1024).foreach { batchSize => + // withSQLConf( + // CometConf.COMET_BATCH_SIZE.key -> batchSize.toString, + // SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false", + // "parquet.enable.dictionary" -> dictionary.toString) { + // var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) + // // If ANSI mode is on, the combination (1, 1) will cause a runtime error. 
Otherwise, the + // // decimal RDD contains all null values and should be able to read back from Parquet. + + // if (!SQLConf.get.ansiEnabled) { + // combinations = combinations ++ Seq((1, 1)) + // } + + // for ((precision, scale) <- combinations) { + // withTempPath { dir => + // val data = makeDecimalRDD(10, DecimalType(precision, scale), dictionary) + // data.write.parquet(dir.getCanonicalPath) + // readParquetFile(dir.getCanonicalPath) { df => + // { + // val decimalLiteral1 = Decimal(1.00) + // val decimalLiteral2 = Decimal(123.456789) + // val cometDf = df.select( + // $"dec1" + $"dec2", + // $"dec1" - $"dec2", + // $"dec1" % $"dec2", + // $"dec1" >= $"dec1", + // $"dec1" === "1.0", + // $"dec1" + decimalLiteral1, + // $"dec1" - decimalLiteral1, + // $"dec1" + decimalLiteral2, + // $"dec1" - decimalLiteral2) + + // checkAnswer( + // cometDf, + // data + // .select( + // $"dec1" + $"dec2", + // $"dec1" - $"dec2", + // $"dec1" % $"dec2", + // $"dec1" >= $"dec1", + // $"dec1" === "1.0", + // $"dec1" + decimalLiteral1, + // $"dec1" - decimalLiteral1, + // $"dec1" + decimalLiteral2, + // $"dec1" - decimalLiteral2) + // .collect() + // .toSeq) + // } + // } + // } + // } + // } + // } + // } + // } + + // test("scalar decimal arithmetic operations") { + // assume(isSpark34Plus) + // withTable("tbl") { + // withSQLConf(CometConf.COMET_ENABLED.key -> "true") { + // sql("CREATE TABLE tbl (a INT) USING PARQUET") + // sql("INSERT INTO tbl VALUES (0)") + + // val combinations = Seq((7, 3), (18, 10), (38, 4)) + // for ((precision, scale) <- combinations) { + // for (op <- Seq("+", "-", "*", "/", "%")) { + // val left = s"CAST(1.00 AS DECIMAL($precision, $scale))" + // val right = s"CAST(123.45 AS DECIMAL($precision, $scale))" + + // withSQLConf( + // "spark.sql.optimizer.excludedRules" -> + // "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + + // checkSparkAnswerAndOperator(s"SELECT $left $op $right FROM tbl") + // } + // } + // } + // } + // } + // } + + // test("cast decimals to int") { + // Seq(16, 1024).foreach { batchSize => + // withSQLConf( + // CometConf.COMET_BATCH_SIZE.key -> batchSize.toString, + // SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") { + // var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) + // // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the + // // decimal RDD contains all null values and should be able to read back from Parquet. + + // if (!SQLConf.get.ansiEnabled) { + // combinations = combinations ++ Seq((1, 1)) + // } + + // for ((precision, scale) <- combinations; useDictionary <- Seq(false)) { + // withTempPath { dir => + // val data = makeDecimalRDD(10, DecimalType(precision, scale), useDictionary) + // data.write.parquet(dir.getCanonicalPath) + // readParquetFile(dir.getCanonicalPath) { df => + // { + // val cometDf = df.select($"dec".cast("int")) + + // // `data` is not read from Parquet, so it doesn't go Comet exec. 
+  //             checkAnswer(cometDf, data.select($"dec".cast("int")).collect().toSeq)
+  //           }
+  //         }
+  //       }
+  //     }
+  //   }
+  // }
+
+  // test("various math scalar functions") {
+  //   Seq("true", "false").foreach { dictionary =>
+  //     withSQLConf("parquet.enable.dictionary" -> dictionary) {
+  //       withParquetTable(
+  //         (0 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
+  //         "tbl",
+  //         withDictionary = dictionary.toBoolean) {
+  //         checkSparkAnswerWithTol(
+  //           "SELECT abs(_1), acos(_2), asin(_1), atan(_2), atan2(_1, _2), cos(_1) FROM tbl")
+  //         checkSparkAnswerWithTol(
+  //           "SELECT exp(_1), ln(_2), log10(_1), log2(_1), pow(_1, _2) FROM tbl")
+  //         // TODO: comment in the round tests once supported
+  //         // checkSparkAnswerWithTol("SELECT round(_1), round(_2) FROM tbl")
+  //         checkSparkAnswerWithTol("SELECT signum(_1), sin(_1), sqrt(_1) FROM tbl")
+  //         checkSparkAnswerWithTol("SELECT tan(_1) FROM tbl")
+  //       }
+  //     }
+  //   }
+  // }
+
+  // test("abs") {
+  //   Seq(true, false).foreach { dictionaryEnabled =>
+  //     withTempDir { dir =>
+  //       val path = new Path(dir.toURI.toString, "test.parquet")
+  //       makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
+  //       withParquetTable(path.toString, "tbl") {
+  //         Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col =>
+  //           checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
+  //         }
+  //       }
+  //     }
+  //   }
+  // }
 
   test("remainder") {
     withTempDir { dir =>
-      // Create a DataFrame with a negative-zero double divisor
       val df = Seq((-21840, -0.0)).toDF("c90", "c1")
-
-      // Write the DataFrame to a Parquet file
       val path = new Path(dir.toURI.toString, "remainder_test.parquet").toString
       df.write.mode("overwrite").parquet(path)

From 90494da86edf70628e2863f1bbe675764f631e4d Mon Sep 17 00:00:00 2001
From: Vipul Vaibhaw
Date: Tue, 18 Jun 2024 18:12:56 +0530
Subject: [PATCH 03/18] bug fix

---
 .../apache/comet/CometExpressionSuite.scala | 1610 ++++++++---------
 1 file changed, 805 insertions(+), 805 deletions(-)

diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 645f5f8480..2b7f8cd9dc 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -50,811 +50,811 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  // test("decimals divide by zero") {
-  //   // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal divide operation
-  //   assume(isSpark34Plus)
-
-  //   Seq(true, false).foreach { dictionary =>
-  //     withSQLConf(
-  //       SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false",
-  //       "parquet.enable.dictionary" -> dictionary.toString) {
-  //       withTempPath { dir =>
-  //         val data = makeDecimalRDD(10, DecimalType(18, 10), dictionary)
-  //         data.write.parquet(dir.getCanonicalPath)
-  //         readParquetFile(dir.getCanonicalPath) { df =>
-  //           {
-  //             val decimalLiteral = Decimal(0.00)
-  //             val cometDf = df.select($"dec" / decimalLiteral, $"dec" % decimalLiteral)
-  //             checkSparkAnswerAndOperator(cometDf)
-  //           }
-  //         }
-  //       }
-  //     }
-  //   }
-  // }
-
-  // test("bitwise shift with different left/right types") {
-  //   Seq(false, true).foreach { dictionary =>
-  //     withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
-  //       val table = "test"
-  //       withTable(table) {
-  //         sql(s"create table $table(col1 long, col2 int) using parquet")
-  //         sql(s"insert into $table values(1111, 2)")
-  //         sql(s"insert into $table values(1111, 2)")
-  //         sql(s"insert into $table 
values(3333, 4)") - // sql(s"insert into $table values(5555, 6)") - - // checkSparkAnswerAndOperator( - // s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table") - // checkSparkAnswerAndOperator( - // s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table") - // } - // } - // } - // } - - // test("basic data type support") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "test.parquet") - // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "tbl") { - // // TODO: enable test for unsigned ints - // checkSparkAnswerAndOperator( - // "select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " + - // "_18, _19, _20 FROM tbl WHERE _2 > 100") - // } - // } - // } - // } - - // test("null literals") { - // val batchSize = 1000 - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "test.parquet") - // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) - // withParquetTable(path.toString, "tbl") { - // val sqlString = "SELECT _4 + null, _15 - null, _16 * null FROM tbl" - // val df2 = sql(sqlString) - // val rows = df2.collect() - // assert(rows.length == batchSize) - // assert(rows.forall(_ == Row(null, null, null))) - - // checkSparkAnswerAndOperator(sqlString) - // } - // } - // } - // } - - // test("date and timestamp type literals") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "test.parquet") - // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "tbl") { - // checkSparkAnswerAndOperator( - // "SELECT _4 FROM tbl WHERE " + - // "_20 > CAST('2020-01-01' AS DATE) AND _18 < CAST('2020-01-01' AS TIMESTAMP)") - // } - // } - // } - // } - - // test("dictionary arithmetic") { - // // TODO: test ANSI mode - // withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") { - // withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { - // checkSparkAnswerAndOperator("SELECT _1 + _2, _1 - _2, _1 * _2, _1 / _2, _1 % _2 FROM tbl") - // } - // } - // } - - // test("dictionary arithmetic with scalar") { - // withSQLConf("parquet.enable.dictionary" -> "true") { - // withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { - // checkSparkAnswerAndOperator("SELECT _1 + 1, _1 - 1, _1 * 2, _1 / 2, _1 % 2 FROM tbl") - // } - // } - // } - - // test("string type and substring") { - // withParquetTable((0 until 5).map(i => (i.toString, (i + 100).toString)), "tbl") { - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, -2) FROM tbl") - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 2) FROM tbl") - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, -2) FROM tbl") - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 10) FROM tbl") - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 0, 0) FROM tbl") - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 1, 0) FROM tbl") - // } - // } - - // test("substring with start < 1") { - // withTempPath { _ => - // withTable("t") { - // sql("create table t (col string) using parquet") - // sql("insert into t values('123456')") - // checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) - // 
checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) - // } - // } - // } - - // test("string with coalesce") { - // withParquetTable( - // (0 until 10).map(i => (i.toString, if (i > 5) None else Some((i + 100).toString))), - // "tbl") { - // checkSparkAnswerAndOperator( - // "SELECT coalesce(_1), coalesce(_1, 1), coalesce(null, _1), coalesce(null, 1), coalesce(_2, _1), coalesce(null) FROM tbl") - // } - // } - - // test("substring with dictionary") { - // val data = (0 until 1000) - // .map(_ % 5) // reduce value space to trigger dictionary encoding - // .map(i => (i.toString, (i + 100).toString)) - // withParquetTable(data, "tbl") { - // checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") - // } - // } - - // test("string_space") { - // withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { - // checkSparkAnswerAndOperator("SELECT space(_1), space(_2) FROM tbl") - // } - // } - - // test("string_space with dictionary") { - // val data = (0 until 1000).map(i => Tuple1(i % 5)) - - // withSQLConf("parquet.enable.dictionary" -> "true") { - // withParquetTable(data, "tbl") { - // checkSparkAnswerAndOperator("SELECT space(_1) FROM tbl") - // } - // } - // } - - // test("hour, minute, second") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "part-r-0.parquet") - // val expected = makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - // readParquetFile(path.toString) { df => - // val query = df.select(expr("hour(_1)"), expr("minute(_1)"), expr("second(_1)")) - - // checkAnswer( - // query, - // expected.map { - // case None => - // Row(null, null, null) - // case Some(i) => - // val timestamp = new java.sql.Timestamp(i).toLocalDateTime - // val hour = timestamp.getHour - // val minute = timestamp.getMinute - // val second = timestamp.getSecond - - // Row(hour, minute, second) - // }) - // } - // } - // } - // } - - // test("hour on int96 timestamp column") { - // import testImplicits._ - - // val N = 100 - // val ts = "2020-01-01 01:02:03.123456" - // Seq(true, false).foreach { dictionaryEnabled => - // Seq(false, true).foreach { conversionEnabled => - // withSQLConf( - // SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", - // SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { - // withTempPath { path => - // Seq - // .tabulate(N)(_ => ts) - // .toDF("ts1") - // .select($"ts1".cast("timestamp").as("ts")) - // .repartition(1) - // .write - // .option("parquet.enable.dictionary", dictionaryEnabled) - // .parquet(path.getCanonicalPath) - - // checkAnswer( - // spark.read.parquet(path.getCanonicalPath).select(expr("hour(ts)")), - // Seq.tabulate(N)(_ => Row(1))) - // } - // } - // } - // } - // } - - // test("cast timestamp and timestamp_ntz") { - // withSQLConf( - // SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", - // CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "timetbl") { - // checkSparkAnswerAndOperator( - // "SELECT " + - // "cast(_2 as timestamp) tz_millis, " + - // "cast(_3 as timestamp) ntz_millis, " + - // "cast(_4 as timestamp) tz_micros, " + - // "cast(_5 as timestamp) ntz_micros " + - // " from timetbl") - // } - // } - // } - // } - // } - 
- // test("cast timestamp and timestamp_ntz to string") { - // // TODO: make the test pass for Spark 3.2 & 3.3 - // assume(isSpark34Plus) - - // withSQLConf( - // SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", - // CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 2001) - // withParquetTable(path.toString, "timetbl") { - // checkSparkAnswerAndOperator( - // "SELECT " + - // "cast(_2 as string) tz_millis, " + - // "cast(_3 as string) ntz_millis, " + - // "cast(_4 as string) tz_micros, " + - // "cast(_5 as string) ntz_micros " + - // " from timetbl") - // } - // } - // } - // } - // } - - // test("cast timestamp and timestamp_ntz to long, date") { - // // TODO: make the test pass for Spark 3.2 & 3.3 - // assume(isSpark34Plus) - - // withSQLConf( - // SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", - // CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "timetbl") { - // checkSparkAnswerAndOperator( - // "SELECT " + - // "cast(_2 as long) tz_millis, " + - // "cast(_4 as long) tz_micros, " + - // "cast(_2 as date) tz_millis_to_date, " + - // "cast(_3 as date) ntz_millis_to_date, " + - // "cast(_4 as date) tz_micros_to_date, " + - // "cast(_5 as date) ntz_micros_to_date " + - // " from timetbl") - // } - // } - // } - // } - // } - - // test("trunc") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "date_trunc.parquet") - // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "tbl") { - // Seq("YEAR", "YYYY", "YY", "QUARTER", "MON", "MONTH", "MM", "WEEK").foreach { format => - // checkSparkAnswerAndOperator(s"SELECT trunc(_20, '$format') from tbl") - // } - // } - // } - // } - // } - - // test("trunc with format array") { - // val numRows = 1000 - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet") - // makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) - // withParquetTable(path.toString, "dateformattbl") { - // checkSparkAnswerAndOperator( - // "SELECT " + - // "dateformat, _7, " + - // "trunc(_7, dateformat) " + - // " from dateformattbl ") - // } - // } - // } - // } - - // test("date_trunc") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "timetbl") { - // Seq( - // "YEAR", - // "YYYY", - // "YY", - // "MON", - // "MONTH", - // "MM", - // "QUARTER", - // "WEEK", - // "DAY", - // "DD", - // "HOUR", - // "MINUTE", - // "SECOND", - // "MILLISECOND", - // "MICROSECOND").foreach { format => - // checkSparkAnswerAndOperator( - // "SELECT " + - // s"date_trunc('$format', _0), " + - // s"date_trunc('$format', _1), " + - // s"date_trunc('$format', _2), " + - // s"date_trunc('$format', _4) " + - // " from timetbl") - 
// } - // } - // } - // } - // } - - // test("date_trunc with timestamp_ntz") { - // assume(!isSpark32, "timestamp functions for timestamp_ntz have incorrect behavior in 3.2") - // withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") - // makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) - // withParquetTable(path.toString, "timetbl") { - // Seq( - // "YEAR", - // "YYYY", - // "YY", - // "MON", - // "MONTH", - // "MM", - // "QUARTER", - // "WEEK", - // "DAY", - // "DD", - // "HOUR", - // "MINUTE", - // "SECOND", - // "MILLISECOND", - // "MICROSECOND").foreach { format => - // checkSparkAnswerAndOperator( - // "SELECT " + - // s"date_trunc('$format', _3), " + - // s"date_trunc('$format', _5) " + - // " from timetbl") - // } - // } - // } - // } - // } - // } - - // test("date_trunc with format array") { - // assume(isSpark33Plus, "TimestampNTZ is supported in Spark 3.3+, See SPARK-36182") - // withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { - // val numRows = 1000 - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") - // makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) - // withParquetTable(path.toString, "timeformattbl") { - // checkSparkAnswerAndOperator( - // "SELECT " + - // "format, _0, _1, _2, _3, _4, _5, " + - // "date_trunc(format, _0), " + - // "date_trunc(format, _1), " + - // "date_trunc(format, _2), " + - // "date_trunc(format, _3), " + - // "date_trunc(format, _4), " + - // "date_trunc(format, _5) " + - // " from timeformattbl ") - // } - // } - // } - // } - // } - - // test("date_trunc on int96 timestamp column") { - // import testImplicits._ - - // val N = 100 - // val ts = "2020-01-01 01:02:03.123456" - // Seq(true, false).foreach { dictionaryEnabled => - // Seq(false, true).foreach { conversionEnabled => - // withSQLConf( - // SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", - // SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { - // withTempPath { path => - // Seq - // .tabulate(N)(_ => ts) - // .toDF("ts1") - // .select($"ts1".cast("timestamp").as("ts")) - // .repartition(1) - // .write - // .option("parquet.enable.dictionary", dictionaryEnabled) - // .parquet(path.getCanonicalPath) - - // withParquetTable(path.toString, "int96timetbl") { - // Seq( - // "YEAR", - // "YYYY", - // "YY", - // "MON", - // "MONTH", - // "MM", - // "QUARTER", - // "WEEK", - // "DAY", - // "DD", - // "HOUR", - // "MINUTE", - // "SECOND", - // "MILLISECOND", - // "MICROSECOND").foreach { format => - // checkSparkAnswer( - // "SELECT " + - // s"date_trunc('$format', ts )" + - // " from int96timetbl") - // } - // } - // } - // } - // } - // } - // } - - // test("charvarchar") { - // Seq(false, true).foreach { dictionary => - // withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { - // val table = "char_tbl4" - // withTable(table) { - // val view = "str_view" - // withView(view) { - // sql(s"""create temporary view $view as select c, v from values - // | (null, null), (null, null), - // | (null, 'S'), (null, 'S'), - // | ('N', 'N '), ('N', 'N '), - // | ('Ne', 'Sp'), ('Ne', 'Sp'), - // | ('Net ', 'Spa '), ('Net ', 'Spa '), - // | ('NetE', 'Spar'), ('NetE', 'Spar'), - // | ('NetEa ', 'Spark '), ('NetEa 
', 'Spark '), - // | ('NetEas ', 'Spark'), ('NetEas ', 'Spark'), - // | ('NetEase', 'Spark-'), ('NetEase', 'Spark-') t(c, v);""".stripMargin) - // sql( - // s"create table $table(c7 char(7), c8 char(8), v varchar(6), s string) using parquet;") - // sql(s"insert into $table select c, c, v, c from $view;") - // val df = sql(s"""select substring(c7, 2), substring(c8, 2), - // | substring(v, 3), substring(s, 2) from $table;""".stripMargin) - - // val expected = Row(" ", " ", "", "") :: - // Row(null, null, "", null) :: Row(null, null, null, null) :: - // Row("e ", "e ", "", "e") :: Row("et ", "et ", "a ", "et ") :: - // Row("etE ", "etE ", "ar", "etE") :: - // Row("etEa ", "etEa ", "ark ", "etEa ") :: - // Row("etEas ", "etEas ", "ark", "etEas ") :: - // Row("etEase", "etEase ", "ark-", "etEase") :: Nil - // checkAnswer(df, expected ::: expected) - // } - // } - // } - // } - // } - - // test("char varchar over length values") { - // Seq("char", "varchar").foreach { typ => - // withTempPath { dir => - // withTable("t") { - // sql("select '123456' as col").write.format("parquet").save(dir.toString) - // sql(s"create table t (col $typ(2)) using parquet location '$dir'") - // sql("insert into t values('1')") - // checkSparkAnswerAndOperator(sql("select substring(col, 1) from t")) - // checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) - // checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) - // } - // } - // } - // } - - // test("like (LikeSimplification enabled)") { - // val table = "names" - // withTable(table) { - // sql(s"create table $table(id int, name varchar(20)) using parquet") - // sql(s"insert into $table values(1,'James Smith')") - // sql(s"insert into $table values(2,'Michael Rose')") - // sql(s"insert into $table values(3,'Robert Williams')") - // sql(s"insert into $table values(4,'Rames Rose')") - // sql(s"insert into $table values(5,'Rames rose')") - - // // Filter column having values 'Rames _ose', where any character matches for '_' - // val query = sql(s"select id from $table where name like 'Rames _ose'") - // checkAnswer(query, Row(4) :: Row(5) :: Nil) - - // // Filter rows that contains 'rose' in 'name' column - // val queryContains = sql(s"select id from $table where name like '%rose%'") - // checkAnswer(queryContains, Row(5) :: Nil) - - // // Filter rows that starts with 'R' following by any characters - // val queryStartsWith = sql(s"select id from $table where name like 'R%'") - // checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) - - // // Filter rows that ends with 's' following by any characters - // val queryEndsWith = sql(s"select id from $table where name like '%s'") - // checkAnswer(queryEndsWith, Row(3) :: Nil) - // } - // } - - // test("like with custom escape") { - // val table = "names" - // withTable(table) { - // sql(s"create table $table(id int, name varchar(20)) using parquet") - // sql(s"insert into $table values(1,'James Smith')") - // sql(s"insert into $table values(2,'Michael_Rose')") - // sql(s"insert into $table values(3,'Robert_R_Williams')") - - // // Filter column having values that include underscores - // val queryDefaultEscape = sql("select id from names where name like '%\\_%'") - // checkSparkAnswerAndOperator(queryDefaultEscape) - - // val queryCustomEscape = sql("select id from names where name like '%$_%' escape '$'") - // checkAnswer(queryCustomEscape, Row(2) :: Row(3) :: Nil) - - // } - // } - - // test("contains") { - // assume(!isSpark32) - - // val table = "names" - // withTable(table) 
{ - // sql(s"create table $table(id int, name varchar(20)) using parquet") - // sql(s"insert into $table values(1,'James Smith')") - // sql(s"insert into $table values(2,'Michael Rose')") - // sql(s"insert into $table values(3,'Robert Williams')") - // sql(s"insert into $table values(4,'Rames Rose')") - // sql(s"insert into $table values(5,'Rames rose')") - - // // Filter rows that contains 'rose' in 'name' column - // val queryContains = sql(s"select id from $table where contains (name, 'rose')") - // checkAnswer(queryContains, Row(5) :: Nil) - // } - // } - - // test("startswith") { - // assume(!isSpark32) - - // val table = "names" - // withTable(table) { - // sql(s"create table $table(id int, name varchar(20)) using parquet") - // sql(s"insert into $table values(1,'James Smith')") - // sql(s"insert into $table values(2,'Michael Rose')") - // sql(s"insert into $table values(3,'Robert Williams')") - // sql(s"insert into $table values(4,'Rames Rose')") - // sql(s"insert into $table values(5,'Rames rose')") - - // // Filter rows that starts with 'R' following by any characters - // val queryStartsWith = sql(s"select id from $table where startswith (name, 'R')") - // checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) - // } - // } - - // test("endswith") { - // assume(!isSpark32) - - // val table = "names" - // withTable(table) { - // sql(s"create table $table(id int, name varchar(20)) using parquet") - // sql(s"insert into $table values(1,'James Smith')") - // sql(s"insert into $table values(2,'Michael Rose')") - // sql(s"insert into $table values(3,'Robert Williams')") - // sql(s"insert into $table values(4,'Rames Rose')") - // sql(s"insert into $table values(5,'Rames rose')") - - // // Filter rows that ends with 's' following by any characters - // val queryEndsWith = sql(s"select id from $table where endswith (name, 's')") - // checkAnswer(queryEndsWith, Row(3) :: Nil) - // } - // } - - // test("add overflow (ANSI disable)") { - // // Enabling ANSI will cause native engine failure, but as we cannot catch - // // native error now, we cannot test it here. - // withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { - // withParquetTable(Seq((Int.MaxValue, 1)), "tbl") { - // checkSparkAnswerAndOperator("SELECT _1 + _2 FROM tbl") - // } - // } - // } - - // test("divide by zero (ANSI disable)") { - // // Enabling ANSI will cause native engine failure, but as we cannot catch - // // native error now, we cannot test it here. - // withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { - // withParquetTable(Seq((1, 0, 1.0, 0.0)), "tbl") { - // checkSparkAnswerAndOperator("SELECT _1 / _2, _3 / _4 FROM tbl") - // } - // } - // } - - // test("decimals arithmetic and comparison") { - // // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal reminder operation - // assume(isSpark34Plus) - - // def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = { - // val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded - // spark - // .range(num) - // .map(_ % div) - // // Parquet doesn't allow column names with spaces, have to add an alias here. - // // Minus 500 here so that negative decimals are also tested. 
- // .select( - // (($"value" - 500) / 100.0) cast decimal as Symbol("dec1"), - // (($"value" - 600) / 100.0) cast decimal as Symbol("dec2")) - // .coalesce(1) - // } - - // Seq(true, false).foreach { dictionary => - // Seq(16, 1024).foreach { batchSize => - // withSQLConf( - // CometConf.COMET_BATCH_SIZE.key -> batchSize.toString, - // SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false", - // "parquet.enable.dictionary" -> dictionary.toString) { - // var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) - // // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the - // // decimal RDD contains all null values and should be able to read back from Parquet. - - // if (!SQLConf.get.ansiEnabled) { - // combinations = combinations ++ Seq((1, 1)) - // } - - // for ((precision, scale) <- combinations) { - // withTempPath { dir => - // val data = makeDecimalRDD(10, DecimalType(precision, scale), dictionary) - // data.write.parquet(dir.getCanonicalPath) - // readParquetFile(dir.getCanonicalPath) { df => - // { - // val decimalLiteral1 = Decimal(1.00) - // val decimalLiteral2 = Decimal(123.456789) - // val cometDf = df.select( - // $"dec1" + $"dec2", - // $"dec1" - $"dec2", - // $"dec1" % $"dec2", - // $"dec1" >= $"dec1", - // $"dec1" === "1.0", - // $"dec1" + decimalLiteral1, - // $"dec1" - decimalLiteral1, - // $"dec1" + decimalLiteral2, - // $"dec1" - decimalLiteral2) - - // checkAnswer( - // cometDf, - // data - // .select( - // $"dec1" + $"dec2", - // $"dec1" - $"dec2", - // $"dec1" % $"dec2", - // $"dec1" >= $"dec1", - // $"dec1" === "1.0", - // $"dec1" + decimalLiteral1, - // $"dec1" - decimalLiteral1, - // $"dec1" + decimalLiteral2, - // $"dec1" - decimalLiteral2) - // .collect() - // .toSeq) - // } - // } - // } - // } - // } - // } - // } - // } - - // test("scalar decimal arithmetic operations") { - // assume(isSpark34Plus) - // withTable("tbl") { - // withSQLConf(CometConf.COMET_ENABLED.key -> "true") { - // sql("CREATE TABLE tbl (a INT) USING PARQUET") - // sql("INSERT INTO tbl VALUES (0)") - - // val combinations = Seq((7, 3), (18, 10), (38, 4)) - // for ((precision, scale) <- combinations) { - // for (op <- Seq("+", "-", "*", "/", "%")) { - // val left = s"CAST(1.00 AS DECIMAL($precision, $scale))" - // val right = s"CAST(123.45 AS DECIMAL($precision, $scale))" - - // withSQLConf( - // "spark.sql.optimizer.excludedRules" -> - // "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { - - // checkSparkAnswerAndOperator(s"SELECT $left $op $right FROM tbl") - // } - // } - // } - // } - // } - // } - - // test("cast decimals to int") { - // Seq(16, 1024).foreach { batchSize => - // withSQLConf( - // CometConf.COMET_BATCH_SIZE.key -> batchSize.toString, - // SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") { - // var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) - // // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the - // // decimal RDD contains all null values and should be able to read back from Parquet. 
- - // if (!SQLConf.get.ansiEnabled) { - // combinations = combinations ++ Seq((1, 1)) - // } - - // for ((precision, scale) <- combinations; useDictionary <- Seq(false)) { - // withTempPath { dir => - // val data = makeDecimalRDD(10, DecimalType(precision, scale), useDictionary) - // data.write.parquet(dir.getCanonicalPath) - // readParquetFile(dir.getCanonicalPath) { df => - // { - // val cometDf = df.select($"dec".cast("int")) - - // // `data` is not read from Parquet, so it doesn't go Comet exec. - // checkAnswer(cometDf, data.select($"dec".cast("int")).collect().toSeq) - // } - // } - // } - // } - // } - // } - // } - - // test("various math scalar functions") { - // Seq("true", "false").foreach { dictionary => - // withSQLConf("parquet.enable.dictionary" -> dictionary) { - // withParquetTable( - // (0 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)), - // "tbl", - // withDictionary = dictionary.toBoolean) { - // checkSparkAnswerWithTol( - // "SELECT abs(_1), acos(_2), asin(_1), atan(_2), atan2(_1, _2), cos(_1) FROM tbl") - // checkSparkAnswerWithTol( - // "SELECT exp(_1), ln(_2), log10(_1), log2(_1), pow(_1, _2) FROM tbl") - // // TODO: comment in the round tests once supported - // // checkSparkAnswerWithTol("SELECT round(_1), round(_2) FROM tbl") - // checkSparkAnswerWithTol("SELECT signum(_1), sin(_1), sqrt(_1) FROM tbl") - // checkSparkAnswerWithTol("SELECT tan(_1) FROM tbl") - // } - // } - // } - // } - - // test("abs") { - // Seq(true, false).foreach { dictionaryEnabled => - // withTempDir { dir => - // val path = new Path(dir.toURI.toString, "test.parquet") - // makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100) - // withParquetTable(path.toString, "tbl") { - // Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col => - // checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl") - // } - // } - // } - // } - // } + test("decimals divide by zero") { + // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal divide operation + assume(isSpark34Plus) + + Seq(true, false).foreach { dictionary => + withSQLConf( + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false", + "parquet.enable.dictionary" -> dictionary.toString) { + withTempPath { dir => + val data = makeDecimalRDD(10, DecimalType(18, 10), dictionary) + data.write.parquet(dir.getCanonicalPath) + readParquetFile(dir.getCanonicalPath) { df => + { + val decimalLiteral = Decimal(0.00) + val cometDf = df.select($"dec" / decimalLiteral, $"dec" % decimalLiteral) + checkSparkAnswerAndOperator(cometDf) + } + } + } + } + } + } + + test("bitwise shift with different left/right types") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "test" + withTable(table) { + sql(s"create table $table(col1 long, col2 int) using parquet") + sql(s"insert into $table values(1111, 2)") + sql(s"insert into $table values(1111, 2)") + sql(s"insert into $table values(3333, 4)") + sql(s"insert into $table values(5555, 6)") + + checkSparkAnswerAndOperator( + s"SELECT shiftright(col1, 2), shiftright(col1, col2) FROM $table") + checkSparkAnswerAndOperator( + s"SELECT shiftleft(col1, 2), shiftleft(col1, col2) FROM $table") + } + } + } + } + + test("basic data type support") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "tbl") { + // TODO: enable test for 
unsigned ints + checkSparkAnswerAndOperator( + "select _1, _2, _3, _4, _5, _6, _7, _8, _13, _14, _15, _16, _17, " + + "_18, _19, _20 FROM tbl WHERE _2 > 100") + } + } + } + } + + test("null literals") { + val batchSize = 1000 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, batchSize) + withParquetTable(path.toString, "tbl") { + val sqlString = "SELECT _4 + null, _15 - null, _16 * null FROM tbl" + val df2 = sql(sqlString) + val rows = df2.collect() + assert(rows.length == batchSize) + assert(rows.forall(_ == Row(null, null, null))) + + checkSparkAnswerAndOperator(sqlString) + } + } + } + } + + test("date and timestamp type literals") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "test.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "tbl") { + checkSparkAnswerAndOperator( + "SELECT _4 FROM tbl WHERE " + + "_20 > CAST('2020-01-01' AS DATE) AND _18 < CAST('2020-01-01' AS TIMESTAMP)") + } + } + } + } + + test("dictionary arithmetic") { + // TODO: test ANSI mode + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false", "parquet.enable.dictionary" -> "true") { + withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { + checkSparkAnswerAndOperator("SELECT _1 + _2, _1 - _2, _1 * _2, _1 / _2, _1 % _2 FROM tbl") + } + } + } + + test("dictionary arithmetic with scalar") { + withSQLConf("parquet.enable.dictionary" -> "true") { + withParquetTable((0 until 10).map(i => (i % 5, i % 3)), "tbl") { + checkSparkAnswerAndOperator("SELECT _1 + 1, _1 - 1, _1 * 2, _1 / 2, _1 % 2 FROM tbl") + } + } + } + + test("string type and substring") { + withParquetTable((0 until 5).map(i => (i.toString, (i + 100).toString)), "tbl") { + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, -2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, -2) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, -2, 10) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 0, 0) FROM tbl") + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 1, 0) FROM tbl") + } + } + + test("substring with start < 1") { + withTempPath { _ => + withTable("t") { + sql("create table t (col string) using parquet") + sql("insert into t values('123456')") + checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) + checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) + } + } + } + + test("string with coalesce") { + withParquetTable( + (0 until 10).map(i => (i.toString, if (i > 5) None else Some((i + 100).toString))), + "tbl") { + checkSparkAnswerAndOperator( + "SELECT coalesce(_1), coalesce(_1, 1), coalesce(null, _1), coalesce(null, 1), coalesce(_2, _1), coalesce(null) FROM tbl") + } + } + + test("substring with dictionary") { + val data = (0 until 1000) + .map(_ % 5) // reduce value space to trigger dictionary encoding + .map(i => (i.toString, (i + 100).toString)) + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT _1, substring(_2, 2, 2) FROM tbl") + } + } + + test("string_space") { + withParquetTable((0 until 5).map(i => (i, i + 1)), "tbl") { + checkSparkAnswerAndOperator("SELECT space(_1), space(_2) FROM tbl") + } + } 
+ + test("string_space with dictionary") { + val data = (0 until 1000).map(i => Tuple1(i % 5)) + + withSQLConf("parquet.enable.dictionary" -> "true") { + withParquetTable(data, "tbl") { + checkSparkAnswerAndOperator("SELECT space(_1) FROM tbl") + } + } + } + + test("hour, minute, second") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + readParquetFile(path.toString) { df => + val query = df.select(expr("hour(_1)"), expr("minute(_1)"), expr("second(_1)")) + + checkAnswer( + query, + expected.map { + case None => + Row(null, null, null) + case Some(i) => + val timestamp = new java.sql.Timestamp(i).toLocalDateTime + val hour = timestamp.getHour + val minute = timestamp.getMinute + val second = timestamp.getSecond + + Row(hour, minute, second) + }) + } + } + } + } + + test("hour on int96 timestamp column") { + import testImplicits._ + + val N = 100 + val ts = "2020-01-01 01:02:03.123456" + Seq(true, false).foreach { dictionaryEnabled => + Seq(false, true).foreach { conversionEnabled => + withSQLConf( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", + SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { + withTempPath { path => + Seq + .tabulate(N)(_ => ts) + .toDF("ts1") + .select($"ts1".cast("timestamp").as("ts")) + .repartition(1) + .write + .option("parquet.enable.dictionary", dictionaryEnabled) + .parquet(path.getCanonicalPath) + + checkAnswer( + spark.read.parquet(path.getCanonicalPath).select(expr("hour(ts)")), + Seq.tabulate(N)(_ => Row(1))) + } + } + } + } + } + + test("cast timestamp and timestamp_ntz") { + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "timetbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "cast(_2 as timestamp) tz_millis, " + + "cast(_3 as timestamp) ntz_millis, " + + "cast(_4 as timestamp) tz_micros, " + + "cast(_5 as timestamp) ntz_micros " + + " from timetbl") + } + } + } + } + } + + test("cast timestamp and timestamp_ntz to string") { + // TODO: make the test pass for Spark 3.2 & 3.3 + assume(isSpark34Plus) + + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 2001) + withParquetTable(path.toString, "timetbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "cast(_2 as string) tz_millis, " + + "cast(_3 as string) ntz_millis, " + + "cast(_4 as string) tz_micros, " + + "cast(_5 as string) ntz_micros " + + " from timetbl") + } + } + } + } + } + + test("cast timestamp and timestamp_ntz to long, date") { + // TODO: make the test pass for Spark 3.2 & 3.3 + assume(isSpark34Plus) + + withSQLConf( + SESSION_LOCAL_TIMEZONE.key -> "Asia/Kathmandu", + CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + 
makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "timetbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "cast(_2 as long) tz_millis, " + + "cast(_4 as long) tz_micros, " + + "cast(_2 as date) tz_millis_to_date, " + + "cast(_3 as date) ntz_millis_to_date, " + + "cast(_4 as date) tz_micros_to_date, " + + "cast(_5 as date) ntz_micros_to_date " + + " from timetbl") + } + } + } + } + } + + test("trunc") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "date_trunc.parquet") + makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "tbl") { + Seq("YEAR", "YYYY", "YY", "QUARTER", "MON", "MONTH", "MM", "WEEK").foreach { format => + checkSparkAnswerAndOperator(s"SELECT trunc(_20, '$format') from tbl") + } + } + } + } + } + + test("trunc with format array") { + val numRows = 1000 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "date_trunc_with_format.parquet") + makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + withParquetTable(path.toString, "dateformattbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "dateformat, _7, " + + "trunc(_7, dateformat) " + + " from dateformattbl ") + } + } + } + } + + test("date_trunc") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "timetbl") { + Seq( + "YEAR", + "YYYY", + "YY", + "MON", + "MONTH", + "MM", + "QUARTER", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND").foreach { format => + checkSparkAnswerAndOperator( + "SELECT " + + s"date_trunc('$format', _0), " + + s"date_trunc('$format', _1), " + + s"date_trunc('$format', _2), " + + s"date_trunc('$format', _4) " + + " from timetbl") + } + } + } + } + } + + test("date_trunc with timestamp_ntz") { + assume(!isSpark32, "timestamp functions for timestamp_ntz have incorrect behavior in 3.2") + withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc.parquet") + makeRawTimeParquetFile(path, dictionaryEnabled = dictionaryEnabled, 10000) + withParquetTable(path.toString, "timetbl") { + Seq( + "YEAR", + "YYYY", + "YY", + "MON", + "MONTH", + "MM", + "QUARTER", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND").foreach { format => + checkSparkAnswerAndOperator( + "SELECT " + + s"date_trunc('$format', _3), " + + s"date_trunc('$format', _5) " + + " from timetbl") + } + } + } + } + } + } + + test("date_trunc with format array") { + assume(isSpark33Plus, "TimestampNTZ is supported in Spark 3.3+, See SPARK-36182") + withSQLConf(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key -> "true") { + val numRows = 1000 + Seq(true, false).foreach { dictionaryEnabled => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "timestamp_trunc_with_format.parquet") + makeDateTimeWithFormatTable(path, dictionaryEnabled = dictionaryEnabled, numRows) + withParquetTable(path.toString, "timeformattbl") { + checkSparkAnswerAndOperator( + "SELECT " + + "format, _0, _1, _2, _3, _4, _5, " + + "date_trunc(format, _0), " + + 
"date_trunc(format, _1), " + + "date_trunc(format, _2), " + + "date_trunc(format, _3), " + + "date_trunc(format, _4), " + + "date_trunc(format, _5) " + + " from timeformattbl ") + } + } + } + } + } + + test("date_trunc on int96 timestamp column") { + import testImplicits._ + + val N = 100 + val ts = "2020-01-01 01:02:03.123456" + Seq(true, false).foreach { dictionaryEnabled => + Seq(false, true).foreach { conversionEnabled => + withSQLConf( + SQLConf.PARQUET_OUTPUT_TIMESTAMP_TYPE.key -> "INT96", + SQLConf.PARQUET_INT96_TIMESTAMP_CONVERSION.key -> conversionEnabled.toString) { + withTempPath { path => + Seq + .tabulate(N)(_ => ts) + .toDF("ts1") + .select($"ts1".cast("timestamp").as("ts")) + .repartition(1) + .write + .option("parquet.enable.dictionary", dictionaryEnabled) + .parquet(path.getCanonicalPath) + + withParquetTable(path.toString, "int96timetbl") { + Seq( + "YEAR", + "YYYY", + "YY", + "MON", + "MONTH", + "MM", + "QUARTER", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND").foreach { format => + checkSparkAnswer( + "SELECT " + + s"date_trunc('$format', ts )" + + " from int96timetbl") + } + } + } + } + } + } + } + + test("charvarchar") { + Seq(false, true).foreach { dictionary => + withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + val table = "char_tbl4" + withTable(table) { + val view = "str_view" + withView(view) { + sql(s"""create temporary view $view as select c, v from values + | (null, null), (null, null), + | (null, 'S'), (null, 'S'), + | ('N', 'N '), ('N', 'N '), + | ('Ne', 'Sp'), ('Ne', 'Sp'), + | ('Net ', 'Spa '), ('Net ', 'Spa '), + | ('NetE', 'Spar'), ('NetE', 'Spar'), + | ('NetEa ', 'Spark '), ('NetEa ', 'Spark '), + | ('NetEas ', 'Spark'), ('NetEas ', 'Spark'), + | ('NetEase', 'Spark-'), ('NetEase', 'Spark-') t(c, v);""".stripMargin) + sql( + s"create table $table(c7 char(7), c8 char(8), v varchar(6), s string) using parquet;") + sql(s"insert into $table select c, c, v, c from $view;") + val df = sql(s"""select substring(c7, 2), substring(c8, 2), + | substring(v, 3), substring(s, 2) from $table;""".stripMargin) + + val expected = Row(" ", " ", "", "") :: + Row(null, null, "", null) :: Row(null, null, null, null) :: + Row("e ", "e ", "", "e") :: Row("et ", "et ", "a ", "et ") :: + Row("etE ", "etE ", "ar", "etE") :: + Row("etEa ", "etEa ", "ark ", "etEa ") :: + Row("etEas ", "etEas ", "ark", "etEas ") :: + Row("etEase", "etEase ", "ark-", "etEase") :: Nil + checkAnswer(df, expected ::: expected) + } + } + } + } + } + + test("char varchar over length values") { + Seq("char", "varchar").foreach { typ => + withTempPath { dir => + withTable("t") { + sql("select '123456' as col").write.format("parquet").save(dir.toString) + sql(s"create table t (col $typ(2)) using parquet location '$dir'") + sql("insert into t values('1')") + checkSparkAnswerAndOperator(sql("select substring(col, 1) from t")) + checkSparkAnswerAndOperator(sql("select substring(col, 0) from t")) + checkSparkAnswerAndOperator(sql("select substring(col, -1) from t")) + } + } + } + } + + test("like (LikeSimplification enabled)") { + val table = "names" + withTable(table) { + sql(s"create table $table(id int, name varchar(20)) using parquet") + sql(s"insert into $table values(1,'James Smith')") + sql(s"insert into $table values(2,'Michael Rose')") + sql(s"insert into $table values(3,'Robert Williams')") + sql(s"insert into $table values(4,'Rames Rose')") + sql(s"insert into $table values(5,'Rames rose')") + + // Filter column having values 'Rames 
_ose', where any character matches for '_' + val query = sql(s"select id from $table where name like 'Rames _ose'") + checkAnswer(query, Row(4) :: Row(5) :: Nil) + + // Filter rows that contains 'rose' in 'name' column + val queryContains = sql(s"select id from $table where name like '%rose%'") + checkAnswer(queryContains, Row(5) :: Nil) + + // Filter rows that starts with 'R' following by any characters + val queryStartsWith = sql(s"select id from $table where name like 'R%'") + checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) + + // Filter rows that ends with 's' following by any characters + val queryEndsWith = sql(s"select id from $table where name like '%s'") + checkAnswer(queryEndsWith, Row(3) :: Nil) + } + } + + test("like with custom escape") { + val table = "names" + withTable(table) { + sql(s"create table $table(id int, name varchar(20)) using parquet") + sql(s"insert into $table values(1,'James Smith')") + sql(s"insert into $table values(2,'Michael_Rose')") + sql(s"insert into $table values(3,'Robert_R_Williams')") + + // Filter column having values that include underscores + val queryDefaultEscape = sql("select id from names where name like '%\\_%'") + checkSparkAnswerAndOperator(queryDefaultEscape) + + val queryCustomEscape = sql("select id from names where name like '%$_%' escape '$'") + checkAnswer(queryCustomEscape, Row(2) :: Row(3) :: Nil) + + } + } + + test("contains") { + assume(!isSpark32) + + val table = "names" + withTable(table) { + sql(s"create table $table(id int, name varchar(20)) using parquet") + sql(s"insert into $table values(1,'James Smith')") + sql(s"insert into $table values(2,'Michael Rose')") + sql(s"insert into $table values(3,'Robert Williams')") + sql(s"insert into $table values(4,'Rames Rose')") + sql(s"insert into $table values(5,'Rames rose')") + + // Filter rows that contains 'rose' in 'name' column + val queryContains = sql(s"select id from $table where contains (name, 'rose')") + checkAnswer(queryContains, Row(5) :: Nil) + } + } + + test("startswith") { + assume(!isSpark32) + + val table = "names" + withTable(table) { + sql(s"create table $table(id int, name varchar(20)) using parquet") + sql(s"insert into $table values(1,'James Smith')") + sql(s"insert into $table values(2,'Michael Rose')") + sql(s"insert into $table values(3,'Robert Williams')") + sql(s"insert into $table values(4,'Rames Rose')") + sql(s"insert into $table values(5,'Rames rose')") + + // Filter rows that starts with 'R' following by any characters + val queryStartsWith = sql(s"select id from $table where startswith (name, 'R')") + checkAnswer(queryStartsWith, Row(3) :: Row(4) :: Row(5) :: Nil) + } + } + + test("endswith") { + assume(!isSpark32) + + val table = "names" + withTable(table) { + sql(s"create table $table(id int, name varchar(20)) using parquet") + sql(s"insert into $table values(1,'James Smith')") + sql(s"insert into $table values(2,'Michael Rose')") + sql(s"insert into $table values(3,'Robert Williams')") + sql(s"insert into $table values(4,'Rames Rose')") + sql(s"insert into $table values(5,'Rames rose')") + + // Filter rows that ends with 's' following by any characters + val queryEndsWith = sql(s"select id from $table where endswith (name, 's')") + checkAnswer(queryEndsWith, Row(3) :: Nil) + } + } + + test("add overflow (ANSI disable)") { + // Enabling ANSI will cause native engine failure, but as we cannot catch + // native error now, we cannot test it here. 
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withParquetTable(Seq((Int.MaxValue, 1)), "tbl") { + checkSparkAnswerAndOperator("SELECT _1 + _2 FROM tbl") + } + } + } + + test("divide by zero (ANSI disable)") { + // Enabling ANSI will cause native engine failure, but as we cannot catch + // native error now, we cannot test it here. + withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { + withParquetTable(Seq((1, 0, 1.0, 0.0)), "tbl") { + checkSparkAnswerAndOperator("SELECT _1 / _2, _3 / _4 FROM tbl") + } + } + } + + test("decimals arithmetic and comparison") { + // TODO: enable Spark 3.2 & 3.3 tests after supporting decimal reminder operation + assume(isSpark34Plus) + + def makeDecimalRDD(num: Int, decimal: DecimalType, useDictionary: Boolean): DataFrame = { + val div = if (useDictionary) 5 else num // narrow the space to make it dictionary encoded + spark + .range(num) + .map(_ % div) + // Parquet doesn't allow column names with spaces, have to add an alias here. + // Minus 500 here so that negative decimals are also tested. + .select( + (($"value" - 500) / 100.0) cast decimal as Symbol("dec1"), + (($"value" - 600) / 100.0) cast decimal as Symbol("dec2")) + .coalesce(1) + } + + Seq(true, false).foreach { dictionary => + Seq(16, 1024).foreach { batchSize => + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> batchSize.toString, + SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false", + "parquet.enable.dictionary" -> dictionary.toString) { + var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37)) + // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the + // decimal RDD contains all null values and should be able to read back from Parquet. + + if (!SQLConf.get.ansiEnabled) { + combinations = combinations ++ Seq((1, 1)) + } + + for ((precision, scale) <- combinations) { + withTempPath { dir => + val data = makeDecimalRDD(10, DecimalType(precision, scale), dictionary) + data.write.parquet(dir.getCanonicalPath) + readParquetFile(dir.getCanonicalPath) { df => + { + val decimalLiteral1 = Decimal(1.00) + val decimalLiteral2 = Decimal(123.456789) + val cometDf = df.select( + $"dec1" + $"dec2", + $"dec1" - $"dec2", + $"dec1" % $"dec2", + $"dec1" >= $"dec1", + $"dec1" === "1.0", + $"dec1" + decimalLiteral1, + $"dec1" - decimalLiteral1, + $"dec1" + decimalLiteral2, + $"dec1" - decimalLiteral2) + + checkAnswer( + cometDf, + data + .select( + $"dec1" + $"dec2", + $"dec1" - $"dec2", + $"dec1" % $"dec2", + $"dec1" >= $"dec1", + $"dec1" === "1.0", + $"dec1" + decimalLiteral1, + $"dec1" - decimalLiteral1, + $"dec1" + decimalLiteral2, + $"dec1" - decimalLiteral2) + .collect() + .toSeq) + } + } + } + } + } + } + } + } + + test("scalar decimal arithmetic operations") { + assume(isSpark34Plus) + withTable("tbl") { + withSQLConf(CometConf.COMET_ENABLED.key -> "true") { + sql("CREATE TABLE tbl (a INT) USING PARQUET") + sql("INSERT INTO tbl VALUES (0)") + + val combinations = Seq((7, 3), (18, 10), (38, 4)) + for ((precision, scale) <- combinations) { + for (op <- Seq("+", "-", "*", "/", "%")) { + val left = s"CAST(1.00 AS DECIMAL($precision, $scale))" + val right = s"CAST(123.45 AS DECIMAL($precision, $scale))" + + withSQLConf( + "spark.sql.optimizer.excludedRules" -> + "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { + + checkSparkAnswerAndOperator(s"SELECT $left $op $right FROM tbl") + } + } + } + } + } + } + + test("cast decimals to int") { + Seq(16, 1024).foreach { batchSize => + withSQLConf( + CometConf.COMET_BATCH_SIZE.key -> 
batchSize.toString,
+        SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "false") {
+        var combinations = Seq((5, 2), (1, 0), (18, 10), (18, 17), (19, 0), (38, 37))
+        // If ANSI mode is on, the combination (1, 1) will cause a runtime error. Otherwise, the
+        // decimal RDD contains all null values and should be able to read back from Parquet.
+
+        if (!SQLConf.get.ansiEnabled) {
+          combinations = combinations ++ Seq((1, 1))
+        }
+
+        for ((precision, scale) <- combinations; useDictionary <- Seq(false)) {
+          withTempPath { dir =>
+            val data = makeDecimalRDD(10, DecimalType(precision, scale), useDictionary)
+            data.write.parquet(dir.getCanonicalPath)
+            readParquetFile(dir.getCanonicalPath) { df =>
+              {
+                val cometDf = df.select($"dec".cast("int"))
+
+                // `data` is not read from Parquet, so it doesn't go Comet exec.
+                checkAnswer(cometDf, data.select($"dec".cast("int")).collect().toSeq)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
+  test("various math scalar functions") {
+    Seq("true", "false").foreach { dictionary =>
+      withSQLConf("parquet.enable.dictionary" -> dictionary) {
+        withParquetTable(
+          (0 until 5).map(i => (i.toDouble + 0.3, i.toDouble + 0.8)),
+          "tbl",
+          withDictionary = dictionary.toBoolean) {
+          checkSparkAnswerWithTol(
+            "SELECT abs(_1), acos(_2), asin(_1), atan(_2), atan2(_1, _2), cos(_1) FROM tbl")
+          checkSparkAnswerWithTol(
+            "SELECT exp(_1), ln(_2), log10(_1), log2(_1), pow(_1, _2) FROM tbl")
+          // TODO: comment in the round tests once supported
+          // checkSparkAnswerWithTol("SELECT round(_1), round(_2) FROM tbl")
+          checkSparkAnswerWithTol("SELECT signum(_1), sin(_1), sqrt(_1) FROM tbl")
+          checkSparkAnswerWithTol("SELECT tan(_1) FROM tbl")
+        }
+      }
+    }
+  }
+
+  test("abs") {
+    Seq(true, false).foreach { dictionaryEnabled =>
+      withTempDir { dir =>
+        val path = new Path(dir.toURI.toString, "test.parquet")
+        makeParquetFileAllTypes(path, dictionaryEnabled = dictionaryEnabled, 100)
+        withParquetTable(path.toString, "tbl") {
+          Seq(2, 3, 4, 5, 6, 7, 15, 16, 17).foreach { col =>
+            checkSparkAnswerAndOperator(s"SELECT abs(_${col}) FROM tbl")
+          }
+        }
+      }
+    }
+  }
 
   test("remainder") {
     withTempDir { dir =>

From d82ffbf84b40d09d61e879c995f622a1130c64ce Mon Sep 17 00:00:00 2001
From: Vipul Vaibhaw
Date: Fri, 21 Jun 2024 13:10:21 +0530
Subject: [PATCH 04/18] modulo expression wrapper to improve spark's compatibility
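
In non-ANSI mode Spark returns NULL for a remainder whenever the divisor is
zero, and its zero check treats -0.0 the same as 0.0, while DataFusion lowers
% to an IEEE-style fmod for which a -0.0 divisor yields NaN. That is exactly
the case the remainder test above exercises:

    SELECT c90 % c1 FROM t   -- Spark: NULL when c1 is -0.0; fmod: NaN

Wrap the native modulo in a ModuloExpr that rewrites the result to NULL when
the Float64 divisor is -0.0 and otherwise delegates to the stock BinaryExpr.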
---
 .../execution/datafusion/expressions/mod.rs  |   1 +
 .../datafusion/expressions/modulo.rs         | 191 ++++++++++++++++++
 core/src/execution/datafusion/planner.rs     |  10 +-
 .../apache/comet/CometExpressionSuite.scala  |   2 +-
 4 files changed, 201 insertions(+), 3 deletions(-)
 create mode 100644 core/src/execution/datafusion/expressions/modulo.rs

diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs
index 5d5f58e0c2..a19f80af4e 100644
--- a/core/src/execution/datafusion/expressions/mod.rs
+++ b/core/src/execution/datafusion/expressions/mod.rs
@@ -33,6 +33,7 @@ pub mod avg_decimal;
 pub mod bloom_filter_might_contain;
 pub mod correlation;
 pub mod covariance;
+pub mod modulo;
 pub mod negative;
 pub mod stats;
 pub mod stddev;
diff --git a/core/src/execution/datafusion/expressions/modulo.rs b/core/src/execution/datafusion/expressions/modulo.rs
new file mode 100644
index 0000000000..27724124c4
--- /dev/null
+++ b/core/src/execution/datafusion/expressions/modulo.rs
@@ -0,0 +1,191 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::any::Any;
+use std::fmt;
+use std::sync::Arc;
+
+use arrow::datatypes::Schema;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::error::Result;
+use datafusion::physical_plan::expressions::BinaryExpr;
+use datafusion::physical_plan::{ColumnarValue, PhysicalExpr};
+use datafusion_common::ScalarValue;
+use datafusion_expr::Operator;
+
+#[derive(Debug, Hash, Clone)]
+pub struct ModuloExpr {
+    left: Arc<dyn PhysicalExpr>,
+    op: Operator,
+    right: Arc<dyn PhysicalExpr>,
+}
+
+impl ModuloExpr {
+    pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self {
+        Self {
+            left,
+            op: Operator::Modulo,
+            right,
+        }
+    }
+}
+
+impl fmt::Display for ModuloExpr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(f, "{} % {}", self.left, self.right)
+    }
+}
+
+impl PartialEq<dyn Any> for ModuloExpr {
+    fn eq(&self, other: &(dyn Any + 'static)) -> bool {
+        if let Some(other) = other.downcast_ref::<ModuloExpr>() {
+            self.left.eq(&other.left) && self.op == other.op && self.right.eq(&other.right)
+        } else {
+            false
+        }
+    }
+}
+
+impl PhysicalExpr for ModuloExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let rhs = self.right.evaluate(batch)?;
+
+        // Following to match spark's behavior for modulo with -0.0
+        match rhs {
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => {
+                if val == -0.0 {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
+                }
+            }
+            ColumnarValue::Array(arr) => {
+                if arr.data_type() == &arrow::datatypes::DataType::Float64 {
+                    let float64_array = arr
+                        .as_any()
+                        .downcast_ref::<arrow_array::Float64Array>()
+                        .unwrap();
+
+                    for i in 0..float64_array.len() {
+                        if float64_array.value(i) == -0.0 {
+                            return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
+                        }
+                    }
+                }
+            }
+            _ => {}
+        }
+        // Otherwise, use the BinaryExpr implementation for modulo
+        let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right));
+        binary_expr.evaluate(batch)
+    }
+
+    fn evaluate_selection(
+        &self,
+        batch: &RecordBatch,
+        selection: &arrow_array::BooleanArray,
+    ) -> Result<ColumnarValue> {
+        let rhs = self.right.evaluate_selection(batch, selection)?;
+
+        // Check if the right operand is a ScalarValue of Float64 and equal to -0.0
+        match rhs {
+            ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => {
+                if val == -0.0 {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
+                }
+            }
+            ColumnarValue::Array(arr) => {
+                if arr.data_type() == &arrow::datatypes::DataType::Float64 {
+                    let float64_array = arr
+                        .as_any()
+                        .downcast_ref::<arrow_array::Float64Array>()
+                        .unwrap();
+
+                    for i in 0..float64_array.len() {
+                        if float64_array.value(i) == -0.0 {
+                            return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None)));
+                        }
+                    }
+                }
+            }
+            _ => {}
+        }
+
+        // Otherwise, use the BinaryExpr implementation for modulo
+        let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right));
+        binary_expr.evaluate_selection(batch, selection)
+    }
+
+    fn evaluate_bounds(
+        &self,
+        _children: &[&datafusion_expr::interval_arithmetic::Interval],
+    ) -> Result<datafusion_expr::interval_arithmetic::Interval> {
+        datafusion_common::not_impl_err!("Not implemented for {self}")
+    }
+
+    fn propagate_constraints(
+        &self,
+        _interval: &datafusion_expr::interval_arithmetic::Interval,
+        _children: &[&datafusion_expr::interval_arithmetic::Interval],
+    ) -> Result<Option<Vec<datafusion_expr::interval_arithmetic::Interval>>> {
+        Ok(Some(vec![]))
+    }
+
+    fn get_properties(
+        &self,
+        _children: &[datafusion_expr::sort_properties::ExprProperties],
+    ) -> Result<datafusion_expr::sort_properties::ExprProperties> {
+        Ok(datafusion_expr::sort_properties::ExprProperties::new_unknown())
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> Result<arrow::datatypes::DataType> {
+        let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right));
+        binary_expr.data_type(input_schema)
+    }
+
+    fn nullable(&self, input_schema: &Schema) -> Result<bool> {
+        let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right));
+        binary_expr.nullable(input_schema)
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.left, &self.right]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        if children.len() == 2 {
+            Ok(Arc::new(ModuloExpr::new(
+                Arc::clone(&children[0]),
+                Arc::clone(&children[1]),
+            )))
+        } else {
+            Err(datafusion::error::DataFusionError::Internal(
+                "Invalid number of children".to_string(),
+            ))
+        }
+    }
+
+    fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
+        let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right));
+        binary_expr.dyn_hash(state)
+    }
+}
diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs
index cd9822d669..954cafdd1b 100644
--- a/core/src/execution/datafusion/planner.rs
+++ b/core/src/execution/datafusion/planner.rs
@@ -97,7 +97,7 @@ use crate::{
     },
 };
 
-use super::expressions::{abs::CometAbsFunc, EvalMode};
+use super::expressions::{abs::CometAbsFunc, modulo::ModuloExpr, EvalMode};
 
 // For clippy error on type_complexity.
 type ExecResult<T> = Result<T, ExecutionError>;
 
@@ -681,7 +681,13 @@ impl PhysicalPlanner {
                     data_type,
                 )))
             }
-            _ => Ok(Arc::new(BinaryExpr::new(left, op, right))),
+            _ => {
+                // Improves compatibility with Spark
+                if op == DataFusionOperator::Modulo {
+                    return Ok(Arc::new(ModuloExpr::new(left, right)));
+                }
+                Ok(Arc::new(BinaryExpr::new(left, op, right)))
+            }
         }
     }
 
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 2b7f8cd9dc..940bda6e8f 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -858,7 +858,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("remainder") {
     withTempDir { dir =>
-      val df = Seq((-21840, -0.0)).toDF("c90", "c1")
+      val df = Seq((21840, -0.0)).toDF("c90", "c1")
       val path = new Path(dir.toURI.toString, "remainder_test.parquet").toString
       df.write.mode("overwrite").parquet(path)
 

From b658eac452e3b866c01efe1bbbb4304f1e7eef8f Mon Sep 17 00:00:00 2001
From: Vipul Vaibhaw
Date: Sat, 22 Jun 2024 08:12:48 +0530
Subject: [PATCH 05/18] changed the logic, modified nullifwhenprimitive function
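
Handle the -0.0 divisor on the JVM side instead of in a native wrapper.
nullIfWhenPrimitive in QueryPlanSerde already rewrites a zero divisor to NULL
before the plan is serialized, so for DoubleType it now also compares against
a -0.0 literal; the divisor c1 is effectively rewritten to

    IF(c1 = 0.0 OR c1 = -0.0, NULL, c1)

This keeps the null check per row, whereas the ModuloExpr wrapper collapsed
the whole batch to a single NULL as soon as any divisor element was -0.0, and
it lets the native engine keep using the stock BinaryExpr for modulo.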
type ExecResult = Result; @@ -681,7 +681,13 @@ impl PhysicalPlanner { data_type, ))) } - _ => Ok(Arc::new(BinaryExpr::new(left, op, right))), + _ => { + // Improves compatibility with Spark + if op == DataFusionOperator::Modulo { + return Ok(Arc::new(ModuloExpr::new(left, right))); + } + Ok(Arc::new(BinaryExpr::new(left, op, right))) + } } } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 2b7f8cd9dc..940bda6e8f 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -858,7 +858,7 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("remainder") { withTempDir { dir => - val df = Seq((-21840, -0.0)).toDF("c90", "c1") + val df = Seq((21840, -0.0)).toDF("c90", "c1") val path = new Path(dir.toURI.toString, "remainder_test.parquet").toString df.write.mode("overwrite").parquet(path) From b658eac452e3b866c01efe1bbbb4304f1e7eef8f Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Sat, 22 Jun 2024 08:12:48 +0530 Subject: [PATCH 05/18] changed the logic, modified nullifwhenprimitive function --- .../execution/datafusion/expressions/mod.rs | 1 - .../datafusion/expressions/modulo.rs | 191 ------------------ core/src/execution/datafusion/planner.rs | 10 +- .../apache/comet/serde/QueryPlanSerde.scala | 10 +- .../apache/comet/CometExpressionSuite.scala | 4 +- 5 files changed, 14 insertions(+), 202 deletions(-) delete mode 100644 core/src/execution/datafusion/expressions/modulo.rs diff --git a/core/src/execution/datafusion/expressions/mod.rs b/core/src/execution/datafusion/expressions/mod.rs index a19f80af4e..5d5f58e0c2 100644 --- a/core/src/execution/datafusion/expressions/mod.rs +++ b/core/src/execution/datafusion/expressions/mod.rs @@ -33,7 +33,6 @@ pub mod avg_decimal; pub mod bloom_filter_might_contain; pub mod correlation; pub mod covariance; -pub mod modulo; pub mod negative; pub mod stats; pub mod stddev; diff --git a/core/src/execution/datafusion/expressions/modulo.rs b/core/src/execution/datafusion/expressions/modulo.rs deleted file mode 100644 index 27724124c4..0000000000 --- a/core/src/execution/datafusion/expressions/modulo.rs +++ /dev/null @@ -1,191 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use std::any::Any; -use std::fmt; -use std::sync::Arc; - -use arrow::datatypes::Schema; -use datafusion::arrow::record_batch::RecordBatch; -use datafusion::error::Result; -use datafusion::physical_plan::expressions::BinaryExpr; -use datafusion::physical_plan::{ColumnarValue, PhysicalExpr}; -use datafusion_common::ScalarValue; -use datafusion_expr::Operator; - -#[derive(Debug, Hash, Clone)] -pub struct ModuloExpr { - left: Arc<dyn PhysicalExpr>, - op: Operator, - right: Arc<dyn PhysicalExpr>, -} - -impl ModuloExpr { - pub fn new(left: Arc<dyn PhysicalExpr>, right: Arc<dyn PhysicalExpr>) -> Self { - Self { - left, - op: Operator::Modulo, - right, - } - } -} - -impl fmt::Display for ModuloExpr { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{} % {}", self.left, self.right) - } -} - -impl PartialEq<dyn Any> for ModuloExpr { - fn eq(&self, other: &(dyn Any + 'static)) -> bool { - if let Some(other) = other.downcast_ref::<ModuloExpr>() { - self.left.eq(&other.left) && self.op == other.op && self.right.eq(&other.right) - } else { - false - } - } -} - -impl PhysicalExpr for ModuloExpr { - fn as_any(&self) -> &dyn Any { - self - } - - fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> { - let rhs = self.right.evaluate(batch)?; - - // Following to match spark's behavior for modulo with -0.0 - match rhs { - ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => { - if val == -0.0 { - return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); - } - } - ColumnarValue::Array(arr) => { - if arr.data_type() == &arrow::datatypes::DataType::Float64 { - let float64_array = arr - .as_any() - .downcast_ref::<arrow_array::Float64Array>() - .unwrap(); - - for i in 0..float64_array.len() { - if float64_array.value(i) == -0.0 { - return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); - } - } - } - } - _ => {} - } - // Otherwise, use the BinaryExpr implementation for modulo - let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right)); - binary_expr.evaluate(batch) - } - - fn evaluate_selection( - &self, - batch: &RecordBatch, - selection: &arrow_array::BooleanArray, - ) -> Result<ColumnarValue> { - let rhs = self.right.evaluate_selection(batch, selection)?; - - // Check if the right operand is a ScalarValue of Float64 and equal to -0.0 - match rhs { - ColumnarValue::Scalar(ScalarValue::Float64(Some(val))) => { - if val == -0.0 { - return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); - } - } - ColumnarValue::Array(arr) => { - if arr.data_type() == &arrow::datatypes::DataType::Float64 { - let float64_array = arr - .as_any() - .downcast_ref::<arrow_array::Float64Array>() - .unwrap(); - - for i in 0..float64_array.len() { - if float64_array.value(i) == -0.0 { - return Ok(ColumnarValue::Scalar(ScalarValue::Float64(None))); - } - } - } - } - _ => {} - } - - // Otherwise, use the BinaryExpr implementation for modulo - let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right)); - binary_expr.evaluate_selection(batch, selection) - } - - fn evaluate_bounds( - &self, - _children: &[&datafusion_expr::interval_arithmetic::Interval], - ) -> Result<datafusion_expr::interval_arithmetic::Interval> { - datafusion_common::not_impl_err!("Not implemented for {self}") - } - - fn propagate_constraints( - &self, - _interval: &datafusion_expr::interval_arithmetic::Interval, - _children: &[&datafusion_expr::interval_arithmetic::Interval], - ) -> Result<Option<Vec<datafusion_expr::interval_arithmetic::Interval>>> { - Ok(Some(vec![])) - } - - fn get_properties( - &self, - _children: &[datafusion_expr::sort_properties::ExprProperties], - ) -> Result<datafusion_expr::sort_properties::ExprProperties> { - Ok(datafusion_expr::sort_properties::ExprProperties::new_unknown()) - } - - fn data_type(&self, input_schema: &Schema) -> Result<arrow::datatypes::DataType> { - let
binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right)); - binary_expr.data_type(input_schema) - } - - fn nullable(&self, input_schema: &Schema) -> Result<bool> { - let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right)); - binary_expr.nullable(input_schema) - } - - fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> { - vec![&self.left, &self.right] - } - - fn with_new_children( - self: Arc<Self>, - children: Vec<Arc<dyn PhysicalExpr>>, - ) -> Result<Arc<dyn PhysicalExpr>> { - if children.len() == 2 { - Ok(Arc::new(ModuloExpr::new( - Arc::clone(&children[0]), - Arc::clone(&children[1]), - ))) - } else { - Err(datafusion::error::DataFusionError::Internal( - "Invalid number of children".to_string(), - )) - } - } - - fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) { - let binary_expr = BinaryExpr::new(Arc::clone(&self.left), self.op, Arc::clone(&self.right)); - binary_expr.dyn_hash(state) - } -} diff --git a/core/src/execution/datafusion/planner.rs b/core/src/execution/datafusion/planner.rs index 954cafdd1b..cd9822d669 100644 --- a/core/src/execution/datafusion/planner.rs +++ b/core/src/execution/datafusion/planner.rs @@ -97,7 +97,7 @@ use crate::{ }, }; -use super::expressions::{abs::CometAbsFunc, modulo::ModuloExpr, EvalMode}; +use super::expressions::{abs::CometAbsFunc, EvalMode}; // For clippy error on type_complexity. type ExecResult<T> = Result<T, ExecutionError>; @@ -681,13 +681,7 @@ impl PhysicalPlanner { data_type, ))) } - _ => { - // Match Spark's behavior for modulo with a -0.0 divisor - if op == DataFusionOperator::Modulo { - return Ok(Arc::new(ModuloExpr::new(left, right))); - } - Ok(Arc::new(BinaryExpr::new(left, op, right))) - } + _ => Ok(Arc::new(BinaryExpr::new(left, op, right))), } } diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index c1c8b5c56b..21b42bba51 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2240,7 +2240,15 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def nullIfWhenPrimitive(expression: Expression): Expression = if (isPrimitive(expression)) { val zero = Literal.default(expression.dataType) - If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression) + val negZero = Literal.create(-0.0, DoubleType) + if (expression.dataType == DoubleType) { + If( + Or(EqualTo(expression, zero), EqualTo(expression, negZero)), + Literal.create(null, expression.dataType), + expression) + } else { + If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression) + } } else { expression } diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 940bda6e8f..31cb883fe5 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -858,7 +858,9 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { test("remainder") { withTempDir { dir => - val df = Seq((21840, -0.0)).toDF("c90", "c1") + val df = + Seq((21840, -0.0), (21840, 5.0)) + .toDF("c90", "c1") val path = new Path(dir.toURI.toString, "remainder_test.parquet").toString df.write.mode("overwrite").parquet(path)
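Patch 05 above replaces the native ModuloExpr with a Spark-side rewrite: nullIfWhenPrimitive now null-guards a -0.0 divisor explicitly, because a plain EqualTo(divisor, 0) guard is not enough once the native engine's float equality distinguishes -0.0 from 0.0 (the DataFusion issue referenced later in the series). A condensed, self-contained sketch of that rewrite follows — the helper name nullIfZeroDivisor is hypothetical, and the real code also checks isPrimitive:

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.DoubleType

    // Sketch only: null out a divisor that is 0 or -0.0 so the native engine
    // never evaluates `x % -0.0` itself. This mirrors patch 05, which still
    // hardcodes DoubleType; patch 07 below generalizes it to FloatType.
    def nullIfZeroDivisor(divisor: Expression): Expression = {
      val zero = Literal.default(divisor.dataType)
      if (divisor.dataType == DoubleType) {
        val negZero = Literal.create(-0.0, DoubleType)
        If(
          Or(EqualTo(divisor, zero), EqualTo(divisor, negZero)),
          Literal.create(null, divisor.dataType),
          divisor)
      } else {
        If(EqualTo(divisor, zero), Literal.create(null, divisor.dataType), divisor)
      }
    }

With this shape, Spark's own semantics decide the result (x % 0 is null in non-ANSI mode) and the native side only ever sees a non-zero divisor or a null.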
From 3585f14c15b3deda3b0a185cac71f4889a469282 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Tue, 25 Jun 2024 11:00:05 +0530 Subject: [PATCH 06/18] adding tests for float and decimal, code refactor --- .../org/apache/comet/CometExpressionSuite.scala | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 31cb883fe5..95ec36b757 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -857,15 +857,14 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("remainder") { - withTempDir { dir => - val df = - Seq((21840, -0.0), (21840, 5.0)) - .toDF("c90", "c1") - val path = new Path(dir.toURI.toString, "remainder_test.parquet").toString - df.write.mode("overwrite").parquet(path) - - withParquetTable(path, "t") { - checkSparkAnswerAndOperator("SELECT c90, c1, c90 % c1 FROM t") + val testCases = Seq( + (Seq((21840, -0.0), (21840, 5.0)), "t"), + (Seq((Decimal(21840, 10, 0), Decimal(-0.0, 10, 0))), "t"), + (Seq((21840.0f, -0.0f), (21840.0f, 5.0f)), "t")) + + testCases.foreach { case (data, tableName) => + withParquetTable(data, tableName) { + checkSparkAnswerAndOperator("SELECT _1, _2, _1 % _2 FROM t") } } } From 06febb369001781ec488c87f6dfe0e0c1c0fbcdf Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Tue, 25 Jun 2024 13:55:54 +0530 Subject: [PATCH 07/18] adding more tests and code refactor --- .../apache/comet/serde/QueryPlanSerde.scala | 4 ++-- .../apache/comet/CometExpressionSuite.scala | 18 ++++++++++-------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 21b42bba51..5020000e05 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2240,8 +2240,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def nullIfWhenPrimitive(expression: Expression): Expression = if (isPrimitive(expression)) { val zero = Literal.default(expression.dataType) - val negZero = Literal.create(-0.0, DoubleType) - if (expression.dataType == DoubleType) { + val negZero = UnaryMinus(Literal.default(expression.dataType)) + if (expression.dataType == DoubleType || expression.dataType == FloatType) { If( Or(EqualTo(expression, zero), EqualTo(expression, negZero)), Literal.create(null, expression.dataType), diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 95ec36b757..6d2243d2dd 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -857,15 +857,17 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } test("remainder") { - val testCases = Seq( - (Seq((21840, -0.0), (21840, 5.0)), "t"), - (Seq((Decimal(21840, 10, 0), Decimal(-0.0, 10, 0))), "t"), - (Seq((21840.0f, -0.0f), (21840.0f, 5.0f)), "t")) + val query = "SELECT _1, _2, _1 % _2 FROM t" + withParquetTable(Seq((21840, -0.0), (21840, 5.0)), "t") { + checkSparkAnswerAndOperator(query) + } - testCases.foreach { case (data, tableName) => - withParquetTable(data, tableName) { - checkSparkAnswerAndOperator("SELECT _1, _2, _1 % _2 FROM t") - } + withParquetTable(Seq((Decimal(21840, 10, 0), Decimal(-0.0, 10, 0))), "t") {
checkSparkAnswerAndOperator(query) + } + + withParquetTable(Seq((21840.0f, -0.0f), (21840.0f, 5.0f)), "t") { + checkSparkAnswerAndOperator(query) } } From 447e3272370732186db47bb0fbddd2f99b03aa00 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Tue, 25 Jun 2024 14:36:28 +0530 Subject: [PATCH 08/18] code refactor --- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 5020000e05..d8d161c65d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -2240,7 +2240,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def nullIfWhenPrimitive(expression: Expression): Expression = if (isPrimitive(expression)) { val zero = Literal.default(expression.dataType) - val negZero = UnaryMinus(Literal.default(expression.dataType)) + val negZero = UnaryMinus(zero) if (expression.dataType == DoubleType || expression.dataType == FloatType) { If( Or(EqualTo(expression, zero), EqualTo(expression, negZero)), From 20ac2ba60bedc56b6187372dd2ce9b91515e5592 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Thu, 27 Jun 2024 21:56:34 +0530 Subject: [PATCH 09/18] handle neg zero in equalto --- .../apache/comet/serde/QueryPlanSerde.scala | 55 +++++++++++++++---- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index d8d161c65d..7287272351 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -855,9 +855,50 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None case EqualTo(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) + val zero = Literal.default(left.dataType) + val negZero = UnaryMinus(zero) + if (left.dataType == DoubleType || left.dataType == FloatType) { + if (right == negZero) { + return Some( + ExprOuterClass.Expr + .newBuilder() + .setEq( + ExprOuterClass.Equal + .newBuilder() + .setLeft(exprToProtoInternal(left, inputs).get) + .setRight(exprToProtoInternal(Abs(right).child, inputs).get)) + .build()) + } else if (left == negZero) { + return Some( + ExprOuterClass.Expr + .newBuilder() + .setEq( + ExprOuterClass.Equal + .newBuilder() + .setLeft(exprToProtoInternal(Abs(left).child, inputs).get) + .setRight(exprToProtoInternal(right, inputs).get)) + .build()) + } else { + Some( + ExprOuterClass.Expr + .newBuilder() + .setEq( + ExprOuterClass.Equal + .newBuilder() + .setLeft(exprToProtoInternal(left, inputs).get) + .setRight(exprToProtoInternal(right, inputs).get)) + .build()) + } + } + var leftExpr, rightExpr: Option[Expr] = None + if (left.dataType == DoubleType || left.dataType == FloatType) { + leftExpr = exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs) + rightExpr = exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs) + } else { + leftExpr = exprToProtoInternal(left, inputs) + rightExpr = exprToProtoInternal(right, inputs) + } if (leftExpr.isDefined && rightExpr.isDefined) { val builder = ExprOuterClass.Equal.newBuilder() builder.setLeft(leftExpr.get) @@ -2240,15 +2281,7 @@ object 
QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def nullIfWhenPrimitive(expression: Expression): Expression = if (isPrimitive(expression)) { val zero = Literal.default(expression.dataType) - val negZero = UnaryMinus(zero) - if (expression.dataType == DoubleType || expression.dataType == FloatType) { - If( - Or(EqualTo(expression, zero), EqualTo(expression, negZero)), - Literal.create(null, expression.dataType), - expression) - } else { - If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression) - } + If(EqualTo(expression, zero), Literal.create(null, expression.dataType), expression) } else { expression } From 00e44ed10ca037f34b810ac961fa3c8a385393a9 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Thu, 27 Jun 2024 22:01:59 +0530 Subject: [PATCH 10/18] code refactor and adding test case for equalto --- .../apache/comet/serde/QueryPlanSerde.scala | 89 +++++++++---------- .../apache/comet/CometExpressionSuite.scala | 24 +++++ 2 files changed, 64 insertions(+), 49 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7287272351..26eefef2ea 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -857,61 +857,52 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case EqualTo(left, right) => val zero = Literal.default(left.dataType) val negZero = UnaryMinus(zero) - if (left.dataType == DoubleType || left.dataType == FloatType) { - if (right == negZero) { - return Some( - ExprOuterClass.Expr - .newBuilder() - .setEq( - ExprOuterClass.Equal - .newBuilder() - .setLeft(exprToProtoInternal(left, inputs).get) - .setRight(exprToProtoInternal(Abs(right).child, inputs).get)) - .build()) - } else if (left == negZero) { - return Some( - ExprOuterClass.Expr - .newBuilder() - .setEq( - ExprOuterClass.Equal - .newBuilder() - .setLeft(exprToProtoInternal(Abs(left).child, inputs).get) - .setRight(exprToProtoInternal(right, inputs).get)) - .build()) - } else { - Some( - ExprOuterClass.Expr + + def buildEqualExpr(leftExpr: Expr, rightExpr: Expr): ExprOuterClass.Expr = { + ExprOuterClass.Expr + .newBuilder() + .setEq( + ExprOuterClass.Equal .newBuilder() - .setEq( - ExprOuterClass.Equal - .newBuilder() - .setLeft(exprToProtoInternal(left, inputs).get) - .setRight(exprToProtoInternal(right, inputs).get)) - .build()) - } + .setLeft(leftExpr) + .setRight(rightExpr)) + .build() } - var leftExpr, rightExpr: Option[Expr] = None if (left.dataType == DoubleType || left.dataType == FloatType) { - leftExpr = exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs) - rightExpr = exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs) - } else { - leftExpr = exprToProtoInternal(left, inputs) - rightExpr = exprToProtoInternal(right, inputs) + (left, right) match { + case (`negZero`, _) => + return Some( + buildEqualExpr( + exprToProtoInternal(Abs(left).child, inputs).get, + exprToProtoInternal(right, inputs).get)) + case (_, `negZero`) => + return Some( + buildEqualExpr( + exprToProtoInternal(left, inputs).get, + exprToProtoInternal(Abs(right).child, inputs).get)) + case _ => + Some( + buildEqualExpr( + exprToProtoInternal(left, inputs).get, + exprToProtoInternal(right, inputs).get)) + } } - if (leftExpr.isDefined && rightExpr.isDefined) { - val builder = ExprOuterClass.Equal.newBuilder() - 
builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setEq(builder) - .build()) - } else { - withInfo(expr, left, right) - None + val (leftExpr, rightExpr) = + if (left.dataType == DoubleType || left.dataType == FloatType) { + ( + exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs), + exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs)) + } else { + (exprToProtoInternal(left, inputs), exprToProtoInternal(right, inputs)) + } + + (leftExpr, rightExpr) match { + case (Some(l), Some(r)) => Some(buildEqualExpr(l, r)) + case _ => + withInfo(expr, left, right) + None } case Not(EqualTo(left, right)) => diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 6d2243d2dd..b6b6a393e3 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -856,6 +856,30 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } + test("zero equality") { + withParquetTable( + Seq( + (-0.0, 0.0), + (0.0, -0.0), + (-0.0, -0.0), + (0.0, 0.0), + (1.0, 2.0), + (1.0, 1.0), + (1.0, 0.0), + (0.0, 1.0), + (-0.0, 1.0), + (1.0, -0.0), + (1.0, -1.0), + (-1.0, 1.0), + (-1.0, -0.0), + (-1.0, -1.0), + (-1.0, 0.0), + (0.0, -1.0)), + "t") { + checkSparkAnswerAndOperator("SELECT _1 == _2 FROM t") + } + } + test("remainder") { val query = "SELECT _1, _2, _1 % _2 FROM t" withParquetTable(Seq((21840, -0.0), (21840, 5.0)), "t") { From cd532843ef991f18917416f1a8c7879e036e6cac Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Fri, 28 Jun 2024 01:50:54 +0530 Subject: [PATCH 11/18] bug fix --- .../apache/comet/serde/QueryPlanSerde.scala | 28 +++++++++++-------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 26eefef2ea..2bfbd9112a 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -870,22 +870,28 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } if (left.dataType == DoubleType || left.dataType == FloatType) { + // make sure left or right is not null (left, right) match { case (`negZero`, _) => return Some( buildEqualExpr( - exprToProtoInternal(Abs(left).child, inputs).get, - exprToProtoInternal(right, inputs).get)) + exprToProtoInternal(Abs(left).child, inputs).getOrElse(return None), + exprToProtoInternal(right, inputs).getOrElse(return None))) case (_, `negZero`) => return Some( buildEqualExpr( - exprToProtoInternal(left, inputs).get, - exprToProtoInternal(Abs(right).child, inputs).get)) + exprToProtoInternal(left, inputs).getOrElse(return None), + exprToProtoInternal(Abs(right).child, inputs).getOrElse(return None))) case _ => + if ((left.nullable && !right.nullable) && + (left != zero && right != zero)) { + withInfo(expr, left, right) + return None + } Some( buildEqualExpr( - exprToProtoInternal(left, inputs).get, - exprToProtoInternal(right, inputs).get)) + exprToProtoInternal(left, inputs).getOrElse(return None), + exprToProtoInternal(right, inputs).getOrElse(return None))) } } @@ -898,11 +904,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim (exprToProtoInternal(left, inputs), exprToProtoInternal(right, 
inputs)) } - (leftExpr, rightExpr) match { - case (Some(l), Some(r)) => Some(buildEqualExpr(l, r)) - case _ => - withInfo(expr, left, right) - None + if (leftExpr.isDefined && rightExpr.isDefined) { + Some(buildEqualExpr(leftExpr.get, rightExpr.get)) + } else { + withInfo(expr, left, right) + None } case Not(EqualTo(left, right)) => From 81f4e62a1a84b3db4a122572a9786d0cdf68877d Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Fri, 28 Jun 2024 10:07:27 +0530 Subject: [PATCH 12/18] resolving pr comments --- .../apache/comet/serde/QueryPlanSerde.scala | 81 +++++++++++-------- 1 file changed, 46 insertions(+), 35 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 2bfbd9112a..7ae25c8bcf 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -855,62 +855,73 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None case EqualTo(left, right) => + // this is a workaround for handling -0.0 in double and float + // until https://github.com/apache/datafusion/issues/11108 is fixed val zero = Literal.default(left.dataType) val negZero = UnaryMinus(zero) - def buildEqualExpr(leftExpr: Expr, rightExpr: Expr): ExprOuterClass.Expr = { - ExprOuterClass.Expr - .newBuilder() - .setEq( - ExprOuterClass.Equal + def buildEqualExpr( + leftExpr: Option[Expr], + rightExpr: Option[Expr]): Option[ExprOuterClass.Expr] = { + + if (leftExpr.isDefined && rightExpr.isDefined) { + Some( + ExprOuterClass.Expr .newBuilder() - .setLeft(leftExpr) - .setRight(rightExpr)) - .build() + .setEq( + ExprOuterClass.Equal + .newBuilder() + .setLeft(leftExpr.get) + .setRight(rightExpr.get)) + .build()) + } else { + withInfo(expr, left, right) + None + } } - if (left.dataType == DoubleType || left.dataType == FloatType) { - // make sure left or right is not null + if (left.dataType == DoubleType || + left.dataType == FloatType || + right.dataType == DoubleType || + right.dataType == FloatType) { (left, right) match { + case (`negZero`, `negZero`) => + return buildEqualExpr( + exprToProtoInternal(Abs(left).child, inputs), + exprToProtoInternal(Abs(right).child, inputs)) case (`negZero`, _) => - return Some( - buildEqualExpr( - exprToProtoInternal(Abs(left).child, inputs).getOrElse(return None), - exprToProtoInternal(right, inputs).getOrElse(return None))) + return buildEqualExpr( + exprToProtoInternal(Abs(left).child, inputs), + exprToProtoInternal(right, inputs)) case (_, `negZero`) => - return Some( - buildEqualExpr( - exprToProtoInternal(left, inputs).getOrElse(return None), - exprToProtoInternal(Abs(right).child, inputs).getOrElse(return None))) + return buildEqualExpr( + exprToProtoInternal(left, inputs), + exprToProtoInternal(Abs(right).child, inputs)) case _ => if ((left.nullable && !right.nullable) && (left != zero && right != zero)) { withInfo(expr, left, right) return None } - Some( - buildEqualExpr( - exprToProtoInternal(left, inputs).getOrElse(return None), - exprToProtoInternal(right, inputs).getOrElse(return None))) + buildEqualExpr( + exprToProtoInternal(left, inputs), + exprToProtoInternal(right, inputs)) } } - val (leftExpr, rightExpr) = - if (left.dataType == DoubleType || left.dataType == FloatType) { - ( - exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs), - exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs)) - } else { - (exprToProtoInternal(left, inputs), exprToProtoInternal(right, inputs)) - } - - if (leftExpr.isDefined && rightExpr.isDefined) { - Some(buildEqualExpr(leftExpr.get, rightExpr.get)) - } else { - withInfo(expr, left, right) - None - } + val leftExpr = if (left.dataType == DoubleType || left.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs) + } else { + exprToProtoInternal(left, inputs) + } + val rightExpr = if (right.dataType == DoubleType || right.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs) + } else { + exprToProtoInternal(right, inputs) + } + buildEqualExpr(leftExpr, rightExpr) + case Not(EqualTo(left, right)) => val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(right, inputs)
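The comments introduced in patch 12 name the root cause: the JVM and a total-order float comparison disagree about -0.0. The facts that drive the workaround can be reproduced in a plain Scala REPL (illustrative only, not from the patch):

    21840.0 % -0.0                        // NaN: IEEE-754 remainder by any zero divisor
    -0.0 == 0.0                           // true: primitive comparison, which Spark's zero guard relies on
    java.lang.Double.compare(-0.0, 0.0)   // -1: a total ordering distinguishes the two zeros

Spark's null guard catches a -0.0 divisor because primitive == treats it as zero; an engine whose equality kernel follows the total ordering (the behavior tracked in apache/datafusion#11108) does not, so -0.0 has to be normalized to 0.0 or matched explicitly before the comparison is handed off to the native side.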
From 9696aaf63f29fc2b71410545238159ed444bad94 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Fri, 28 Jun 2024 10:34:11 +0530 Subject: [PATCH 13/18] handling left and right zero and negzero separately --- .../apache/comet/serde/QueryPlanSerde.scala | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 7ae25c8bcf..68e679e15b 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -857,8 +857,11 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim case EqualTo(left, right) => // this is a workaround for handling -0.0 in double and float // until https://github.com/apache/datafusion/issues/11108 is fixed - val zero = Literal.default(left.dataType) + val leftZero = Literal.default(left.dataType) + val rightZero = Literal.default(right.dataType) // create negzero based on double or float - val negZero = UnaryMinus(zero) + val negZeroLeft = UnaryMinus(leftZero) + val negZeroRight = UnaryMinus(rightZero) def buildEqualExpr( @@ -888,21 +888,21 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim (left, right) match { - case (`negZero`, `negZero`) => + case (`negZeroLeft`, `negZeroRight`) => return buildEqualExpr( exprToProtoInternal(Abs(left).child, inputs), exprToProtoInternal(Abs(right).child, inputs)) - case (`negZero`, _) => + case (`negZeroLeft`, _) => return buildEqualExpr( exprToProtoInternal(Abs(left).child, inputs), exprToProtoInternal(right, inputs)) - case (_, `negZero`) => + case (_, `negZeroRight`) => return buildEqualExpr( exprToProtoInternal(left, inputs), exprToProtoInternal(Abs(right).child, inputs)) case _ => if ((left.nullable && !right.nullable) && - (left != leftZero && right != rightZero)) { withInfo(expr, left, right) return None } @@ -910,12 +913,12 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } val leftExpr = if (left.dataType == DoubleType || left.dataType == FloatType) { - exprToProtoInternal(If(EqualTo(left, negZero), zero, left), inputs) + exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) } else { exprToProtoInternal(left, inputs) } val rightExpr = if (right.dataType == DoubleType || right.dataType == FloatType) { - exprToProtoInternal(If(EqualTo(right, negZero), zero, right), inputs) + exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) } else { exprToProtoInternal(right, inputs) }
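One subtlety in the pattern match above: backquoted patterns such as case (`negZeroLeft`, _) compare Catalyst expression trees structurally, so they only fire when an operand literally is UnaryMinus(Literal.default(...)); an already constant-folded Literal(-0.0) is a different tree and falls through to case _. A minimal illustration (plain Scala, not from the patch):

    import org.apache.spark.sql.catalyst.expressions.{Literal, UnaryMinus}
    import org.apache.spark.sql.types.DoubleType

    val negZero = UnaryMinus(Literal.default(DoubleType))
    negZero == UnaryMinus(Literal(0.0, DoubleType))  // true: structurally identical trees
    negZero == Literal(-0.0, DoubleType)             // false: a folded constant is a different node

That is why the case _ branch, with its If(EqualTo(..., negZero...), zero, ...) normalization, is likely the path most optimized plans actually take.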
From 14f493aa25e2c820d5600494501b3d01b384a7ea Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Fri, 28 Jun 2024 10:38:40 +0530 Subject: [PATCH 14/18] removing unnecessary comments --- spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 68e679e15b..4d592edc1d 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -859,7 +859,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // until https://github.com/apache/datafusion/issues/11108 is fixed val leftZero = Literal.default(left.dataType) val rightZero = Literal.default(right.dataType) - // create negzero based on double or float val negZeroLeft = UnaryMinus(leftZero) val negZeroRight = UnaryMinus(rightZero) From 2878d7175052f48405e22629c861fa22309dd668 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Sat, 29 Jun 2024 12:01:37 +0530 Subject: [PATCH 15/18] bug fixes --- .../apache/comet/serde/QueryPlanSerde.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 4d592edc1d..fa6542f73f 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -865,7 +865,6 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim def buildEqualExpr( leftExpr: Option[Expr], rightExpr: Option[Expr]): Option[ExprOuterClass.Expr] = { - if (leftExpr.isDefined && rightExpr.isDefined) { Some( ExprOuterClass.Expr @@ -881,11 +880,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim None } } - if ((left.dataType == DoubleType && - right.dataType == DoubleType) || - (left.dataType == FloatType && - right.dataType == FloatType)) { + if ((left.dataType == DoubleType && + right.dataType == DoubleType) || + (left.dataType == FloatType && + right.dataType == FloatType)) { (left, right) match { case (`negZeroLeft`, `negZeroRight`) => return buildEqualExpr( @@ -900,8 +898,14 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim exprToProtoInternal(left, inputs), exprToProtoInternal(Abs(right).child, inputs)) case _ => + val doubleNan = Literal(Double.NaN, DoubleType) + val floatNan = Literal(Float.NaN, FloatType) + if ((left.nullable && !right.nullable) && + (left != negZeroLeft && right != negZeroRight) && + (left != leftZero && right != rightZero) && + (left != doubleNan && right != doubleNan) && + (left != floatNan && right != floatNan)) { withInfo(expr, left, right) return None } From 79b0a4b13a309a4ced0990e7fa489a2c63ff44ba Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Sun, 30 Jun 2024 21:30:21 +0530 Subject: [PATCH 16/18] bug fixes --- .../apache/comet/serde/QueryPlanSerde.scala | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index fa6542f73f..f2870c9d3e 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -822,8 +822,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } None - case rem @ Remainder(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + case rem @ Remainder(left, right, _) if supportedDataType(left.dataType) => val leftExpr = exprToProtoInternal(left, inputs) val rightExpr = exprToProtoInternal(nullIfWhenPrimitive(right), inputs) @@ -905,7 +904,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim (left != negZeroLeft && right != negZeroRight) && (left != leftZero && right != rightZero) && (left != doubleNan && right != doubleNan) && - (left != floatNan && right != floatNan)) { + (left != floatNan && right != floatNan) && isSpark34Plus) { withInfo(expr, left, right) return None } @@ -915,16 +914,18 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } } - val leftExpr = if (left.dataType == DoubleType || left.dataType == FloatType) { - exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) - } else { - exprToProtoInternal(left, inputs) - } - val rightExpr = if (right.dataType == DoubleType || right.dataType == FloatType) { - exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) - } else { - exprToProtoInternal(right, inputs) - } + val leftExpr = + if (left.dataType == DoubleType || left.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(left, negZeroLeft), leftZero, left), inputs) + } else { + exprToProtoInternal(left, inputs) + } + val rightExpr = + if (right.dataType == DoubleType || right.dataType == FloatType) { + exprToProtoInternal(If(EqualTo(right, negZeroRight), rightZero, right), inputs) + } else { + exprToProtoInternal(right, inputs) + } buildEqualExpr(leftExpr, rightExpr) From 6e843da54bd7bd553aee4ca10d26e35c43ac244b Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Sun, 30 Jun 2024 21:37:09 +0530 Subject: [PATCH 17/18] bug fix --- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index f2870c9d3e..65c535f475 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -904,7 +904,7 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim (left != negZeroLeft && right != negZeroRight) && (left != leftZero && right != rightZero) && (left != doubleNan && right != doubleNan) && - (left != floatNan && right != floatNan) && isSpark34Plus) { + (left != floatNan && right != floatNan)) { withInfo(expr, left, right) return None } From fd70aa7f604ff8c01a0ab36a80cdc20125e0915e Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Wed, 24 Jul 2024 16:18:38 +0530 Subject: [PATCH 18/18] adding comments to explain if statements --- .../src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala | 2 ++ .../src/test/scala/org/apache/comet/CometExpressionSuite.scala | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index aec686b8fc..e9e4623edb 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -1030,6 +1030,8 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim val doubleNan = Literal(Double.NaN, DoubleType) val floatNan = Literal(Float.NaN, FloatType) + // Ensure neither left nor right is -0.0 or 0.0 or NaN + // Also return None if one side is nullable and the other is not if ((left.nullable && !right.nullable) && (left != negZeroLeft && right != negZeroRight) && (left != leftZero && right != rightZero) && diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala index 09c1d81c21..0402477d02 100644 --- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala @@ -1794,4 +1794,4 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper { } } } -} \ No newline at end of file +}
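Taken together, the series leaves the remainder and equality paths agreeing with Spark. A compact illustration of the intended end state — assuming a SparkSession with spark.implicits._ in scope and Comet enabled; the expected values follow Spark's non-ANSI semantics rather than captured output:

    val df = Seq((21840.0, -0.0), (21840.0, 5.0)).toDF("a", "b")
    df.selectExpr("a", "b", "a % b").show()
    // 21840.0 % -0.0 -> null  (zero divisor is null-guarded; -0.0 == 0.0 on the JVM)
    // 21840.0 %  5.0 -> 0.0   (21840 is an exact multiple of 5)

The new "zero equality" test likewise pins down `_1 == _2` for every mix of 0.0, -0.0 and ordinary doubles, so a regression in either direction surfaces either as a plan falling back to Spark or as a wrong answer.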