diff --git a/docs/velox-backend-aggregate-function-support.md b/docs/velox-backend-aggregate-function-support.md new file mode 100644 index 000000000000..ee6fa131cddf --- /dev/null +++ b/docs/velox-backend-aggregate-function-support.md @@ -0,0 +1,71 @@ +# Aggregate Functions Support Status + +**Out of 62 aggregate functions in Spark 3.5, Gluten currently fully supports 54 functions and partially supports 1 function.** + +## Aggregate Functions + +| Spark Functions | Spark Expressions | Status | Restrictions | +|-----------------------|------------------------------------|----------|----------------| +| any | BoolOr | S | | +| any_value | AnyValue | S | | +| approx_count_distinct | HyperLogLogPlusPlus | S | | +| approx_percentile | ApproximatePercentile | S | | +| array_agg | CollectList | S | | +| avg | Average | S | | +| bit_and | BitAndAgg | S | | +| bit_or | BitOrAgg | S | | +| bit_xor | BitXorAgg | S | | +| bitmap_construct_agg | BitmapConstructAgg | | | +| bitmap_or_agg | BitmapOrAgg | | | +| bool_and | BoolAnd | S | | +| bool_or | BoolOr | S | | +| collect_list | CollectList | S | | +| collect_set | CollectSet | S | | +| corr | Corr | S | | +| count | Count | S | | +| count_if | CountIf | S | | +| count_min_sketch | CountMinSketchAggExpressionBuilder | | | +| covar_pop | CovPopulation | S | | +| covar_samp | CovSample | S | | +| every | BoolAnd | S | | +| first | First | S | | +| first_value | First | S | | +| grouping | Grouping | S | | +| grouping_id | GroupingID | S | | +| histogram_numeric | HistogramNumeric | | | +| hll_sketch_agg | HllSketchAgg | | | +| hll_union_agg | HllUnionAgg | | | +| kurtosis | Kurtosis | S | | +| last | Last | S | | +| last_value | Last | S | | +| max | Max | S | | +| max_by | MaxBy | S | | +| mean | Average | S | | +| median | Median | S | | +| min | Min | S | | +| min_by | MinBy | S | | +| mode | Mode | | | +| percentile | Percentile | S | | +| percentile_approx | ApproximatePercentile | S | | +| regr_avgx | RegrAvgX | S | | +| regr_avgy | RegrAvgY | S | | +| regr_count | RegrCount | S | | +| regr_intercept | RegrIntercept | S | | +| regr_r2 | RegrR2 | S | | +| regr_slope | RegrSlope | S | | +| regr_sxx | RegrSXX | S | | +| regr_sxy | RegrSXY | S | | +| regr_syy | RegrSYY | S | | +| skewness | Skewness | S | | +| some | BoolOr | S | | +| std | StddevSamp | S | | +| stddev | StddevSamp | S | | +| stddev_pop | StddevPop | S | | +| stddev_samp | StddevSamp | S | | +| sum | Sum | S | | +| try_avg | TryAverageExpressionBuilder | S | | +| try_sum | TrySumExpressionBuilder | PS | | +| var_pop | VariancePop | S | | +| var_samp | VarianceSamp | S | | +| variance | VarianceSamp | S | | + diff --git a/docs/velox-backend-generator-function-support.md b/docs/velox-backend-generator-function-support.md new file mode 100644 index 000000000000..ec57535a93e6 --- /dev/null +++ b/docs/velox-backend-generator-function-support.md @@ -0,0 +1,16 @@ +# Generator Functions Support Status + +**Out of 7 generator functions in Spark 3.5, Gluten currently fully supports 4 functions.** + +## Generator Functions + +| Spark Functions | Spark Expressions | Status | Restrictions | +|-------------------|--------------------------|----------|----------------| +| explode | ExplodeExpressionBuilder | S | | +| explode_outer | ExplodeExpressionBuilder | | | +| inline | Inline | S | | +| inline_outer | Inline | | | +| posexplode | PosExplode | S | | +| posexplode_outer | PosExplode | | | +| stack | Stack | S | | + diff --git a/docs/velox-backend-scalar-function-support.md b/docs/velox-backend-scalar-function-support.md index ea8c58cac9cd..60eefb2ce329 100644 --- a/docs/velox-backend-scalar-function-support.md +++ b/docs/velox-backend-scalar-function-support.md @@ -1,13 +1,13 @@ # Scalar Functions Support Status -**Out of 357 scalar functions in Spark 3.5, Gluten currently fully supports 229 functions and partially supports 13 functions.** +**Out of 357 scalar functions in Spark 3.5, Gluten currently fully supports 220 functions and partially supports 20 functions.** ## Array Functions | Spark Functions | Spark Expressions | Status | Restrictions | |-------------------|---------------------|----------|----------------| | array | CreateArray | S | | -| array_append | ArrayAppend | | | +| array_append | ArrayAppend | S | | | array_compact | ArrayCompact | | | | array_contains | ArrayContains | S | | | array_distinct | ArrayDistinct | S | | @@ -21,12 +21,12 @@ | array_prepend | ArrayPrepend | | | | array_remove | ArrayRemove | S | | | array_repeat | ArrayRepeat | S | | -| array_union | ArrayUnion | | | +| array_union | ArrayUnion | S | | | arrays_overlap | ArraysOverlap | S | | | arrays_zip | ArraysZip | S | | | flatten | Flatten | S | | | get | Get | | | -| sequence | Sequence | S | | +| sequence | Sequence | | | | shuffle | Shuffle | S | | | slice | Slice | S | | | sort_array | SortArray | S | | @@ -41,7 +41,7 @@ | bit_get | BitwiseGet | S | | | getbit | BitwiseGet | S | | | shiftright | ShiftRight | S | | -| shiftrightunsigned | ShiftRightUnsigned | S | | +| shiftrightunsigned | ShiftRightUnsigned | | | | | | BitwiseOr | S | | | ~ | BitwiseNot | S | | @@ -51,7 +51,7 @@ |-------------------|---------------------|----------|----------------| | array_size | ArraySize | | | | cardinality | Size | S | | -| concat | Concat | S | | +| concat | Concat | PS | | | reverse | Reverse | S | | | size | Size | S | | @@ -110,7 +110,7 @@ | date_from_unix_date | DateFromUnixDate | S | | | date_part | DatePartExpressionBuilder | | | | date_sub | DateSub | S | | -| date_trunc | TruncTimestamp | S | | +| date_trunc | TruncTimestamp | | | | dateadd | DateAdd | S | | | datediff | DateDiff | S | | | datepart | DatePartExpressionBuilder | | | @@ -133,7 +133,7 @@ | make_ym_interval | MakeYMInterval | S | | | minute | Minute | S | | | month | Month | S | | -| months_between | MonthsBetween | S | | +| months_between | MonthsBetween | | | | next_day | NextDay | S | | | now | Now | | | | quarter | Quarter | S | | @@ -146,7 +146,7 @@ | to_timestamp | ParseToTimestamp | | | | to_timestamp_ltz | ParseToTimestampLTZExpressionBuilder | | | | to_timestamp_ntz | ParseToTimestampNTZExpressionBuilder | | | -| to_unix_timestamp | ToUnixTimestamp | S | | +| to_unix_timestamp | ToUnixTimestamp | PS | | | to_utc_timestamp | ToUTCTimestamp | S | | | trunc | TruncDate | | | | try_to_timestamp | TryToTimestampExpressionBuilder | | | @@ -177,7 +177,7 @@ | Spark Functions | Spark Expressions | Status | Restrictions | |-------------------|---------------------|----------|----------------| -| from_json | JsonToStructs | | | +| from_json | JsonToStructs | S | | | get_json_object | GetJsonObject | S | | | json_array_length | LengthOfJsonArray | S | | | json_object_keys | JsonObjectKeys | | | @@ -194,7 +194,7 @@ | exists | ArrayExists | S | | | filter | ArrayFilter | S | | | forall | ArrayForAll | S | | -| map_filter | MapFilter | | | +| map_filter | MapFilter | S | | | map_zip_with | MapZipWith | S | | | reduce | ArrayAggregate | S | | | transform | ArrayTransform | S | | @@ -208,10 +208,10 @@ |-------------------|---------------------|----------|----------------| | element_at | ElementAt | S | | | map | CreateMap | PS | | -| map_concat | MapConcat | | | +| map_concat | MapConcat | PS | | | map_contains_key | MapContainsKey | | | | map_entries | MapEntries | S | | -| map_from_arrays | MapFromArrays | S | | +| map_from_arrays | MapFromArrays | | | | map_from_entries | MapFromEntries | | | | map_keys | MapKeys | S | | | map_values | MapValues | S | | @@ -237,7 +237,7 @@ | atanh | Atanh | S | | | bin | Bin | S | | | bround | BRound | | | -| cbrt | Cbrt | S | | +| cbrt | Cbrt | | | | ceil | CeilExpressionBuilder | PS | | | ceiling | CeilExpressionBuilder | PS | | | conv | Conv | S | | @@ -250,13 +250,13 @@ | e | EulerNumber | S | | | exp | Exp | S | | | expm1 | Expm1 | S | | -| factorial | Factorial | S | | +| factorial | Factorial | | | | floor | FloorExpressionBuilder | PS | | | greatest | Greatest | S | | | hex | Hex | S | | | hypot | Hypot | S | | | least | Least | S | | -| ln | Log | S | | +| ln | Log | | | | log | Logarithm | S | | | log10 | Log10 | S | | | log1p | Log1p | S | | @@ -268,7 +268,7 @@ | positive | UnaryPositive | S | | | pow | Pow | S | | | power | Pow | S | | -| radians | ToRadians | S | | +| radians | ToRadians | | | | rand | Rand | S | | | randn | Randn | | | | random | Rand | S | | @@ -278,11 +278,11 @@ | shiftleft | ShiftLeft | S | | | sign | Signum | S | | | signum | Signum | S | | -| sin | Sin | S | | +| sin | Sin | | | | sinh | Sinh | S | | -| sqrt | Sqrt | S | | -| tan | Tan | S | | -| tanh | Tanh | S | | +| sqrt | Sqrt | | | +| tan | Tan | | | +| tanh | Tanh | | | | try_add | TryAdd | PS | | | try_divide | TryDivide | | | | try_multiply | TryMultiply | | | @@ -319,25 +319,25 @@ | user | CurrentUser | | | | uuid | Uuid | S | | | version | SparkVersion | S | | -| || | | | | +| || | | S | | ## Predicate Functions | Spark Functions | Spark Expressions | Status | Restrictions | |-------------------|---------------------|----------|------------------------| | ! | Not | S | | -| != | | | | +| != | | S | | | < | LessThan | S | | | <= | LessThanOrEqual | S | | | <=> | EqualNullSafe | S | | -| <> | | | | +| <> | | S | | | = | EqualTo | S | | | == | EqualTo | S | | | > | GreaterThan | S | | | >= | GreaterThanOrEqual | S | | | and | And | S | | -| between | | | | -| case | | | | +| between | | S | | +| case | | S | | | ilike | ILike | | | | in | In | PS | | | isnan | IsNaN | S | | @@ -357,21 +357,21 @@ | ascii | Ascii | S | | | base64 | Base64 | S | | | bit_length | BitLength | S | | -| btrim | StringTrimBoth | | | +| btrim | StringTrimBoth | S | | | char | Chr | S | | | char_length | Length | S | | | character_length | Length | S | | | chr | Chr | S | | | concat_ws | ConcatWs | S | | -| contains | ContainsExpressionBuilder | S | | +| contains | ContainsExpressionBuilder | PS | BinaryType unsupported | | decode | Decode | | | | elt | Elt | | | | encode | Encode | | | -| endswith | EndsWithExpressionBuilder | | | +| endswith | EndsWithExpressionBuilder | PS | BinaryType unsupported | | find_in_set | FindInSet | S | | | format_number | FormatNumber | | | -| format_string | FormatString | S | | -| initcap | InitCap | S | | +| format_string | FormatString | | | +| initcap | InitCap | | | | instr | StringInstr | S | | | lcase | Lower | S | | | left | Left | S | | @@ -380,14 +380,14 @@ | levenshtein | Levenshtein | S | | | locate | StringLocate | S | | | lower | Lower | S | | -| lpad | LPadExpressionBuilder | S | | +| lpad | LPadExpressionBuilder | PS | BinaryType unsupported | | ltrim | StringTrimLeft | S | | | luhn_check | Luhncheck | | | | mask | MaskExpressionBuilder | S | | | octet_length | OctetLength | | | | overlay | Overlay | S | | | position | StringLocate | S | | -| printf | FormatString | S | | +| printf | FormatString | | | | regexp_count | RegExpCount | | | | regexp_extract | RegExpExtract | PS | Lookaround unsupported | | regexp_extract_all | RegExpExtractAll | PS | Lookaround unsupported | @@ -397,16 +397,16 @@ | repeat | StringRepeat | S | | | replace | StringReplace | S | | | right | Right | | | -| rpad | RPadExpressionBuilder | S | | +| rpad | RPadExpressionBuilder | PS | BinaryType unsupported | | rtrim | StringTrimRight | S | | | sentences | Sentences | | | | soundex | SoundEx | S | | -| space | StringSpace | S | | +| space | StringSpace | | | | split | StringSplit | S | | | split_part | SplitPart | S | | -| startswith | StartsWithExpressionBuilder | | | -| substr | Substring | S | | -| substring | Substring | S | | +| startswith | StartsWithExpressionBuilder | PS | BinaryType unsupported | +| substr | Substring | PS | | +| substring | Substring | PS | | | substring_index | SubstringIndex | S | | | to_binary | ToBinary | | | | to_char | ToCharacter | | | @@ -432,8 +432,8 @@ | Spark Functions | Spark Expressions | Status | Restrictions | |-------------------|---------------------|----------|----------------| | parse_url | ParseUrl | | | -| url_decode | UrlDecode | PS | | -| url_encode | UrlEncode | PS | | +| url_decode | UrlDecode | S | | +| url_encode | UrlEncode | S | | ## XML Functions diff --git a/docs/velox-backend-support-progress.md b/docs/velox-backend-support-progress.md index ab2ca76dd8e8..38159b74ffe9 100644 --- a/docs/velox-backend-support-progress.md +++ b/docs/velox-backend-support-progress.md @@ -103,74 +103,8 @@ Please check the links below for the detailed support status of each category: [Scalar Functions Support Status](./velox-backend-scalar-function-support.md) -### Other Functions Support Status (To be updated) - -| Spark Functions | Velox/Presto Functions | Velox/Spark functions | Gluten | Restrictions | BOOLEAN | BYTE | SHORT | INT | LONG | FLOAT | DOUBLE | DATE | TIMESTAMP | STRING | DECIMAL | NULL | BINARY | CALENDAR | ARRAY | MAP | STRUCT | UDT | -|-----------------------|------------------------|-----------------------|--------|--------------|---------|------|-------|-----|------|-------|--------|------|-----------|--------|---------|------|--------|----------|-------|-----|--------|-----| -| bit_and | bitwise_and_agg | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| bit_or | | | S | | | | | | | | | | | | | | | | | | | | -| bit_xor | | bit_xor | S | | | | | | | | | | | | | | | | | | | | -| explode | | | | | | | | | | | | | | | | | | | | | | | -| explode_outer | | | | | | | | | | | | | | | | | | | | | | | -| get_map_value | | element_at | S | | | | | | | | | | | | | | | | | S | | | -| posexplode_outer | | | | | | | | | | | | | | | | | | | | | | | -| any | | | | | | | | | | | | | | | | | | | | | | | -| approx_count_distinct | approx_distinct | | S | | S | S | S | S | S | S | S | S | | S | | | | | | | | | -| approx_percentile | | | | | | | | | | | | | | | | | | | | | | | -| avg | avg | | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | -| bool_and | | | | | | | | | | | | | | | | | | | | | | | -| bool_or | | | | | | | | | | | | | | | | | | | | | | | -| collect_list | | | S | | | | | | | | | | | | | | | | | | | | -| collect_set | | | S | | | | | | | | | | | | | | | | | | | | -| corr | corr | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| count | count | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| count_if | count_if | | | | | S | S | S | S | S | | | | | | | | | | | | | -| count_min_sketch | | | | | | | | | | | | | | | | | | | | | | | -| covar_pop | covar_pop | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| covar_samp | covar_samp | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| every | | | | | | | | | | | | | | | | | | | | | | | -| first | | first | S | | | | | | | | | | | | | | | | | | | | -| first_value | | first_value | S | | | | | | | | | | | | | | | | | | | | -| grouping | | | | | | | | | | | | | | | | | | | | | | | -| grouping_id | | | | | | | | | | | | | | | | | | | | | | | -| kurtosis | kurtosis | kurtosis | S | | | | S | S | S | S | S | | | | | | | | | | | | -| last | | last | S | | | | | | | | | | | | | | | | | | | | -| last_value | | last_value | S | | | | | | | | | | | | | | | | | | | | -| max | max | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| max_by | | | S | | | | | | | | | | | | | | | | | | | | -| mean | avg | | S | ANSI OFF | | | | | | | | | | | | | | | | | | | -| min | min | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| min_by | | | S | | | | | | | | | | | | | | | | | | | | -| regr_avgx | regr_avgx | regr_avgx | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_avgy | regr_avgy | regr_avgy | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_count | regr_count | regr_count | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_r2 | regr_r2 | regr_r2 | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_intercept | regr_intercept | regr_intercept | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_slope | regr_slope | regr_slope | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_sxy | regr_sxy | regr_sxy | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_sxx | regr_sxx | regr_sxx | S | | | | S | S | S | S | S | | | | | | | | | | | | -| regr_syy | regr_syy | regr_syy | S | | | | S | S | S | S | S | | | | | | | | | | | | -| skewness | skewness | skewness | S | | | | S | S | S | S | S | | | | | | | | | | | | -| some | | | | | | | | | | | | | | | | | | | | | | | -| std | stddev | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| stddev | stddev | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| stddev_pop | stddev_pop | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| stddev_samp | stddev_samp | | S | | | | S | S | S | S | S | | | | | | | | | | | | -| sum | sum | | S | ANSI OFF | | S | S | S | S | S | | | | | | | | | | | | | -| var_pop | var_pop | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| var_samp | var_samp | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| variance | variance | | S | | | S | S | S | S | S | | | | | | | | | | | | | -| cume_dist | cume_dist | | S | | | | | | | | | | | | | | | | | | | | -| dense_rank | dense_rank | | S | | | | | | | | | | | | | | | | | | | | -| lag | | | S | | | | | | | | | | | | | | | | | | | | -| lead | | | S | | | | | | | | | | | | | | | | | | | | -| nth_value | nth_value | nth_value | PS | | | | | | | | | | | | | | | | | | | | -| ntile | ntile | ntile | S | | | | | | | | | | | | | | | | | | | | -| percent_rank | percent_rank | | S | | | | | | | | | | | | | | | | | | | | -| rank | rank | | S | | | | | | | | | | | | | | | | | | | | -| row_number | row_number | | S | | | | S | S | S | | | | | | | | | | | | | | -| inline | | | | | | | | | | | | | | | | | | | | | | | -| inline_outer | | | | | | | | | | | | | | | | | | | | | | | -| raise_error | | raise_error | S | | | | | | | | | | | | | | | | | | | | -| stack | | | S | | S | S | S | S | S | S | S | S | S | S | S | S | S | S | S | S | S | S | -| try_substract | | | S | | | | | | | | | | | | | | | | | | | | \ No newline at end of file +[Aggregate Functions Support Status](./velox-backend-aggregate-function-support.md) + +[Window Functions Support Status](./velox-backend-window-function-support.md) + +[Generator Functions Support Status](./velox-backend-generator-function-support.md) diff --git a/docs/velox-backend-window-function-support.md b/docs/velox-backend-window-function-support.md new file mode 100644 index 000000000000..85e6a7cbdfda --- /dev/null +++ b/docs/velox-backend-window-function-support.md @@ -0,0 +1,18 @@ +# Window Functions Support Status + +**Out of 9 window functions in Spark 3.5, Gluten currently fully supports 9 functions.** + +## Window Functions + +| Spark Functions | Spark Expressions | Status | Restrictions | +|-------------------|---------------------|----------|----------------| +| cume_dist | CumeDist | S | | +| dense_rank | DenseRank | S | | +| lag | Lag | S | | +| lead | Lead | S | | +| nth_value | NthValue | S | | +| ntile | NTile | S | | +| percent_rank | PercentRank | S | | +| rank | Rank | S | | +| row_number | RowNumber | S | | + diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 04a4e47e0b4b..8cdb41e4a805 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -161,23 +161,16 @@ object ExpressionConverter extends SQLConfHelper with Logging { return BackendsApiManager.getSparkPlanExecApiInstance.genHiveUDFTransformer( expr, attributeSeq) - case i: StaticInvoke => - val objectName = i.staticObject.getName.stripSuffix("$") - if (objectName.endsWith("UrlCodec")) { - val child = i.arguments.head - i.functionName match { - case "decode" => - return GenericExpressionTransformer( - ExpressionNames.URL_DECODE, - child.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), - i) - case "encode" => - return GenericExpressionTransformer( - ExpressionNames.URL_ENCODE, - child.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), - i) - } - } + case i @ StaticInvoke(_, _, "encode" | "decode", Seq(_, _), _, _, _, _) + if i.objectName.endsWith("UrlCodec") => + return GenericExpressionTransformer( + "url_" + i.functionName, + replaceWithExpressionTransformer0(i.arguments.head, attributeSeq, expressionsMap), + i) + case StaticInvoke(clz, _, functionName, _, _, _, _, _) => + throw new GlutenNotSupportException( + s"Not supported to transform StaticInvoke with object: ${clz.getName}, " + + s"function: $functionName") case _ => } diff --git a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala index 35fe9518cee3..ef97ca3e02b3 100644 --- a/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala +++ b/gluten-substrait/src/test/scala/org/apache/spark/sql/GlutenTestUtils.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql import org.apache.gluten.exception.GlutenException -import org.apache.spark.SparkContext -import org.apache.spark.TestUtils +import org.apache.spark.{SparkContext, TestUtils} import org.apache.spark.scheduler.SparkListener import org.apache.spark.sql.test.SQLTestUtils diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala index cbd39be2a9e8..180009d96df0 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxSQLQueryTestSettings.scala @@ -180,23 +180,6 @@ object VeloxSQLQueryTestSettings extends SQLQueryTestSettings { "postgreSQL/window_part2.sql", "postgreSQL/with.sql", "datetime-special.sql", - "arrayJoin.sql", - "binaryComparison.sql", - "booleanEquality.sql", - "caseWhenCoercion.sql", - "concat.sql", - "dateTimeOperations.sql", - "decimalPrecision.sql", - "division.sql", - "elt.sql", - "ifCoercion.sql", - "implicitTypeCasts.sql", - "inConversion.sql", - "mapZipWith.sql", - "promoteStrings.sql", - "stringCastAndExpressions.sql", - "widenSetOperationTypes.sql", - "windowFrameCoercion.sql", "timestamp-ltz.sql", "timestamp-ntz.sql", "timezone.sql", diff --git a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala index 9a5abfe386f7..cc431e3188a5 100644 --- a/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala +++ b/shims/spark34/src/main/scala/org/apache/gluten/sql/shims/spark34/Spark34Shims.scala @@ -87,7 +87,9 @@ class Spark34Shims extends SparkShims { Sig[Mask](ExpressionNames.MASK), Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT), Sig[CheckOverflowInTableInsert](ExpressionNames.CHECK_OVERFLOW_IN_TABLE_INSERT), - Sig[ArrayAppend](ExpressionNames.ARRAY_APPEND) + Sig[ArrayAppend](ExpressionNames.ARRAY_APPEND), + Sig[UrlEncode](ExpressionNames.URL_ENCODE), + Sig[UrlDecode](ExpressionNames.URL_DECODE) ) } diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala index 3734bbf77807..0cd86e583e00 100644 --- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala +++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala @@ -91,7 +91,9 @@ class Spark35Shims extends SparkShims { Sig[RoundCeil](ExpressionNames.CEIL), Sig[ArrayInsert](ExpressionNames.ARRAY_INSERT), Sig[CheckOverflowInTableInsert](ExpressionNames.CHECK_OVERFLOW_IN_TABLE_INSERT), - Sig[ArrayAppend](ExpressionNames.ARRAY_APPEND) + Sig[ArrayAppend](ExpressionNames.ARRAY_APPEND), + Sig[UrlEncode](ExpressionNames.URL_ENCODE), + Sig[UrlDecode](ExpressionNames.URL_DECODE) ) } diff --git a/tools/scripts/gen-function-support-docs.py b/tools/scripts/gen-function-support-docs.py index aaf7bc940655..c06f200ef92b 100644 --- a/tools/scripts/gen-function-support-docs.py +++ b/tools/scripts/gen-function-support-docs.py @@ -502,15 +502,38 @@ expression[StructsToCsv]("to_csv") ''' +FUNCTION_CATEGORIES = ['scalar', 'aggregate', 'window', 'generator'] + +STATIC_INVOKES = { + "luhn_check": ("org.apache.spark.sql.catalyst.expressions.ExpressionImplUtils", "isLuhnNumber"), + "base64": ("org.apache.spark.sql.catalyst.expressions.Base64", "encode"), + "contains": ("org.apache.spark.unsafe.array.ByteArrayMethods", "contains"), + "startsWith": ("org.apache.spark.unsafe.array.ByteArrayMethods", "startsWith"), + "endsWith": ("org.apache.spark.unsafe.array.ByteArrayMethods", "endsWith"), + "lpad": ("org.apache.spark.unsafe.array.ByteArrayMethods", "lpad"), + "rpad": ("org.apache.spark.unsafe.array.ByteArrayMethods", "rpad"), +} + # Known Restrictions in Gluten. LOOKAROUND_UNSUPPORTED = 'Lookaround unsupported' +BINARY_TYPE_UNSUPPORTED = 'BinaryType unsupported' GLUTEN_RESTRICTIONS = { - 'regexp': LOOKAROUND_UNSUPPORTED, - 'regexp_like': LOOKAROUND_UNSUPPORTED, - 'rlike': LOOKAROUND_UNSUPPORTED, - 'regexp_extract': LOOKAROUND_UNSUPPORTED, - 'regexp_extract_all': LOOKAROUND_UNSUPPORTED, - 'regexp_replace': LOOKAROUND_UNSUPPORTED + 'scalar': { + 'regexp': LOOKAROUND_UNSUPPORTED, + 'regexp_like': LOOKAROUND_UNSUPPORTED, + 'rlike': LOOKAROUND_UNSUPPORTED, + 'regexp_extract': LOOKAROUND_UNSUPPORTED, + 'regexp_extract_all': LOOKAROUND_UNSUPPORTED, + 'regexp_replace': LOOKAROUND_UNSUPPORTED, + 'contains': BINARY_TYPE_UNSUPPORTED, + 'startswith': BINARY_TYPE_UNSUPPORTED, + 'endswith': BINARY_TYPE_UNSUPPORTED, + 'lpad': BINARY_TYPE_UNSUPPORTED, + 'rpad': BINARY_TYPE_UNSUPPORTED + }, + 'aggregate': {}, + 'window': {}, + 'generator': {} } SPARK_FUNCTION_GROUPS = { @@ -553,15 +576,30 @@ 'xml_funcs': "XML Functions"} FUNCTION_GROUPS = {'scalar': SCALAR_FUNCTION_GROUPS, - 'agg': {'agg_funcs': 'Aggregate Functions'}, + 'aggregate': {'agg_funcs': 'Aggregate Functions'}, 'window': {'window_funcs': 'Window Functions'}, 'generator': {'generator_funcs': "Generator Functions"}} +FUNCTION_SUITE_PACKAGE = 'org.apache.spark.sql.' +FUNCTION_SUITES = { + 'scalar': {'GlutenSQLQueryTestSuite', 'GlutenDataFrameSessionWindowingSuite', 'GlutenDataFrameTimeWindowingSuite', + 'GlutenMiscFunctionsSuite', 'GlutenDateFunctionsSuite', 'GlutenDataFrameFunctionsSuite', + 'GlutenBitmapExpressionsQuerySuite', 'GlutenMathFunctionsSuite', 'GlutenColumnExpressionSuite', + 'GlutenStringFunctionsSuite', 'GlutenXPathFunctionsSuite', 'GlutenSQLQuerySuite'}, + 'aggregate': {'GlutenSQLQueryTestSuite', 'GlutenApproxCountDistinctForIntervalsQuerySuite', + 'GlutenBitmapExpressionsQuerySuite', + 'GlutenDataFrameAggregateSuite'}, + # All window functions are supported. + 'window': {}, + 'generator': {'GlutenGeneratorFunctionSuite'} +} + def create_spark_function_map(): exprs = list(map(lambda x: x if x[-1] != ',' else x[:-1], map(lambda x: x.strip(), - filter(lambda x: 'expression' in x, SPARK35_EXPRESSION_MAPPINGS.split('\n'))))) + filter(lambda x: 'expression' in x, + SPARK35_EXPRESSION_MAPPINGS.split('\n'))))) func_map = {} expression_pattern = 'expression[GeneratorOuter]*\[([\w0-9]+)\]\("([^\s]+)".*' @@ -589,7 +627,8 @@ def create_spark_function_map(): def generate_function_list(): jinfos = jvm.org.apache.spark.sql.api.python.PythonSQLUtils.listBuiltinFunctionInfos() - infos = [["!=", '', 'predicate_funcs'], ["<>", "", "predicate_funcs"], ['between', '', 'predicate_funcs'], + infos = [["!=", '', 'predicate_funcs'], ["<>", "", "predicate_funcs"], + ['between', '', 'predicate_funcs'], ['case', '', 'predicate_funcs'], ["||", '', 'misc_funcs']] for jinfo in filter(lambda x: x.getGroup() in SPARK_FUNCTION_GROUPS, jinfos): infos.append([jinfo.getName(), jinfo.getClassName().split('.')[-1], jinfo.getGroup()]) @@ -608,61 +647,72 @@ def generate_function_list(): group_functions[groupname].append(name) if groupname in SCALAR_FUNCTION_GROUPS: - scalar_functions.append(name) + functions['scalar'].add(name) elif groupname == 'agg_funcs': - agg_functions.append(name) + functions['aggregate'].add(name) elif groupname == 'window_funcs': - window_functions.append(name) + functions['window'].add(name) elif groupname == 'generator_funcs': - generator_functions.append(name) + functions['generator'].add(name) else: - logging.log(logging.WARNING, f"No matching group name for function {name}: " + groupname) + logging.log(logging.WARNING, + f"No matching group name for function {name}: " + groupname) def parse_logs(log_file): - generator_functions = ['explode', 'explode_outer', 'inline', 'inline_outer', 'posexplode', 'posexplode_outer', - 'stack'] + # "<>", "!=", "between", "case", and "||" are hard coded in spark and there's no corresponding functions. + builtin_functions = ['<>', '!=', 'between', 'case', '||'] + function_names = all_function_names.copy() + for f in builtin_functions: + function_names.remove(f) + + print(function_names) + + generator_functions = ['explode', 'explode_outer', 'inline', 'inline_outer', 'posexplode', + 'posexplode_outer', 'stack'] + + # unknown functions are not in the all_function_names list. Perhaps spark implemented this function but did not + # expose it to the user for current version. + support_list = {'scalar': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()}, + 'aggregate': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()}, + 'generator': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()}, + 'window': {'partial': set(), 'unsupported': set(), 'unsupported_expr': set(), 'unknown': set()}} - scalar_support_list = {'partial': set(), 'unsupported': set()} - agg_support_list = {'partial': set(), 'unsupported': set()} - window_support_list = {'partial': set(), 'unsupported': set()} - generator_support_list = {'partial': set(), 'unsupported': set()} - try_to_binary_funcs = set(['unhex', 'encode', 'unbase64']) + try_to_binary_funcs = {'unhex', 'encode', 'unbase64'} unresolved = [] def filter_fallback_reasons(): - f = open(log_file, 'r') - lines = f.readlines() - lines - ll = [] + with open(log_file, 'r') as f: + lines = f.readlines() + + validation_logs = [] # Filter validation logs. for l in lines: - if ( - 'Validation failed for plan:' in l or 'Validation failed due to' in l or 'Validation failed at file' in l or l.startswith( - ' - ') or l.startswith(' |- ')) and 'Native validation failed:' not in l: - ll.append(l) + if l.startswith(' - ') and 'Native validation failed:' not in l or l.startswith(' |- '): + validation_logs.append(l) # Extract fallback reasons. - al = [] - for l in ll: + fallback_reasons = set() + for l in validation_logs: if 'due to:' in l: - al.append(l.split('due to:')[-1].strip()) + fallback_reasons.add(l.split('due to:')[-1].strip()) elif 'reason:' in l: - al.append(l.split('reason:')[-1].strip()) + fallback_reasons.add(l.split('reason:')[-1].strip()) else: - al.append(l) - al = sorted(set(al)) + fallback_reasons.add(l) + fallback_reasons = sorted(fallback_reasons) # Remove udf. - return list(filter(lambda x: 'Not supported python udf' not in x and 'Not supported scala udf' not in x, al)) + return list(filter(lambda x: 'Not supported python udf' not in x and 'Not supported scala udf' not in x, + fallback_reasons)) def function_name_tuple(function_name): return ( function_name, None if function_name not in function_to_classname else function_to_classname[function_name]) - def notFound(r): + def function_not_found(r): logging.log(logging.WARNING, f"No function name or class name found in: {r}") unresolved.append(r) @@ -672,14 +722,22 @@ def notFound(r): for item in jexpression_mappings: gluten_expressions[item._1()] = item._2() - for f in scalar_functions: - if f not in gluten_expressions.values() and function_to_classname[f] not in gluten_expressions.keys(): - scalar_support_list['unsupported'].add(function_name_tuple(f)) + for category in FUNCTION_CATEGORIES: + if category == 'scalar': + for f in functions[category]: + # TODO: Remove this filter as it may exclude supported expressions, such as + # RuntimeReplaceable and Builder. + if f not in builtin_functions and f not in gluten_expressions.values() and function_to_classname[ + f] not in gluten_expressions.keys(): + logging.log(logging.WARNING, f"Function not found in gluten expressions: {f}") + support_list[category]['unsupported'].add(function_name_tuple(f)) - for f in GLUTEN_RESTRICTIONS.keys(): - scalar_support_list['partial'].add(function_name_tuple(f)) + for f in GLUTEN_RESTRICTIONS[category].keys(): + support_list[category]['partial'].add(function_name_tuple(f)) for r in filter_fallback_reasons(): + ############## Scalar functions ############## + # Not supported: Expression not in ExpressionMappings. if 'Not supported to map spark function name to substrait function name' in r: pattern = r"class name: ([\w0-9]+)." @@ -689,41 +747,73 @@ def notFound(r): if match: class_name = match.group(1) if class_name in classname_to_function: - scalar_support_list['unsupported'].add((classname_to_function[class_name], class_name)) + function_name = classname_to_function[class_name] + if function_name in function_names: + support_list['scalar']['unsupported'].add((function_name, class_name)) + else: + support_list['scalar']['unknown'].add((function_name, class_name)) else: - logging.log(logging.INFO, f"No function name for class: {class_name}. Adding class name") - scalar_support_list['unsupported'].add((None, class_name)) + logging.log(logging.INFO, + f"No function name for class: {class_name}. Adding class name") + support_list['scalar']['unsupported_expr'].add(class_name) else: - notFound(r) + function_not_found(r) - elif 'Not support expression' in r: - pattern = r"Not support expression ([\w0-9]+)" + # Not supported: Function not registered in Velox. + elif 'Scalar function name not registered:' in r: + pattern = r"Scalar function name not registered:\s+([\w0-9]+)" - # Extract class name + # Extract the function name match = re.search(pattern, r) if match: - class_name = match.group(1) - if class_name in classname_to_function: - scalar_support_list['unsupported'].add((classname_to_function[class_name], class_name)) + function_name = match.group(1) + if function_name in function_names: + support_list['scalar']['unsupported'].add(function_name_tuple(function_name)) else: - logging.log(logging.INFO, f"No function name for class: {class_name}. Adding class name") - scalar_support_list['unsupported'].add((None, class_name)) + support_list['scalar']['unknown'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) - elif 'Scalar function name not registered:' in r: - pattern = r"Scalar function name not registered:\s+([\w0-9]+)" + # Partially supported: Function registered in Velox but not registered with specific arguments. + elif 'not registered with arguments:' in r: + pattern = r"Scalar function ([\w0-9]+) not registered with arguments:" # Extract the function name match = re.search(pattern, r) if match: function_name = match.group(1) - scalar_support_list['unsupported'].add(function_name_tuple(function_name)) + if function_name in function_names: + support_list['scalar']['partial'].add(function_name_tuple(function_name)) + else: + support_list['scalar']['unknown'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) + + # Not supported: Special case for unsupported expressions. + elif 'Not support expression' in r: + pattern = r"Not support expression ([\w0-9]+)" + # Extract class name + match = re.search(pattern, r) + + if match: + class_name = match.group(1) + if class_name in classname_to_function: + function_name = classname_to_function[class_name] + if function_name in function_names: + support_list['scalar']['unsupported'].add((function_name, class_name)) + else: + support_list['scalar']['unknown'].add((function_name, class_name)) + else: + logging.log(logging.INFO, + f"No function name for class: {class_name}. Adding class name") + support_list['scalar']['unsupported_expr'].add(class_name) + else: + function_not_found(r) + + # Not supported: Special case for unsupported functions. elif 'Function is not supported:' in r: pattern = r"Function is not supported:\s+([\w0-9]+)" @@ -732,34 +822,45 @@ def notFound(r): if match: function_name = match.group(1) - scalar_support_list['unsupported'].add(function_name_tuple(function_name)) + if function_name in function_names: + support_list['scalar']['unsupported'].add(function_name_tuple(function_name)) + else: + support_list['scalar']['unknown'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) - elif 'not registered with arguments:' in r: - pattern = r"Scalar function ([\w0-9]+) not registered with arguments:" + ############## Aggregate functions ############## + elif 'Could not find a valid substrait mapping' in r: + pattern = r"Could not find a valid substrait mapping name for ([\w0-9]+)\(" # Extract the function name match = re.search(pattern, r) if match: function_name = match.group(1) - scalar_support_list['partial'].add(function_name_tuple(function_name)) + if function_name in function_names: + support_list['aggregate']['unsupported'].add(function_name_tuple(function_name)) + else: + support_list['aggregate']['unknown'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) - elif 'Could not find a valid substrait mapping' in r: - pattern = r"Could not find a valid substrait mapping name for ([\w0-9]+)\(" + elif 'Unsupported aggregate mode' in r: + pattern = r"Unsupported aggregate mode: [\w]+ for ([\w0-9]+)" # Extract the function name match = re.search(pattern, r) if match: function_name = match.group(1) - agg_support_list['unsupported'].add(function_name_tuple(function_name)) + if function_name in function_names: + support_list['aggregate']['partial'].add(function_name_tuple(function_name)) + else: + support_list['aggregate']['unknown'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) + ############## Generator functions ############## elif 'Velox backend does not support this generator:' in r: pattern = r"Velox backend does not support this generator:\s+([\w0-9]+)" @@ -770,14 +871,15 @@ def notFound(r): class_name = match.group(1) function_name = class_name.lower() if function_name not in generator_functions: - generator_support_list['unsupported'].add((None, class_name)) + support_list['generator']['unknown'].add((None, class_name)) elif 'outer: true' in r: - generator_support_list['unsupported'].add((function_name + '_outer', None)) + support_list['generator']['unsupported'].add((function_name + '_outer', None)) else: - generator_support_list['unsupported'].add(function_name_tuple(function_name)) + support_list['generator']['unsupported'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) + ############## Special judgements ############## elif 'try_eval' in r and ' is not supported' in r: pattern = r"try_eval\((\w+)\) is not supported" match = re.search(pattern, r) @@ -790,48 +892,56 @@ def notFound(r): function_name = 'try_to_binary' p = function_name_tuple(function_name) if len(try_to_binary_funcs) == 0: - if p in scalar_support_list['partial']: - scalar_support_list['partial'].remove(p) - scalar_support_list['unsupported'].add(p) + if p in support_list['scalar']['partial']: + support_list['scalar']['partial'].remove(p) + support_list['scalar']['unsupported'].add(p) elif 'add' in function_name: function_name = 'try_add' - scalar_support_list['partial'].add(function_name_tuple(function_name)) + support_list['scalar']['partial'].add(function_name_tuple(function_name)) else: - notFound(r) + function_not_found(r) elif 'Pattern is not string literal for regexp_extract' == r: function_name = 'regexp_extract' - scalar_support_list['partial'].add(function_name_tuple(function_name)) + support_list['scalar']['partial'].add(function_name_tuple(function_name)) elif 'Pattern is not string literal for regexp_extract_all' == r: function_name = 'regexp_extract_all' - scalar_support_list['partial'].add(function_name_tuple(function_name)) + support_list['scalar']['partial'].add(function_name_tuple(function_name)) else: unresolved.append(r) - return scalar_support_list, agg_support_list, window_support_list, generator_support_list, unresolved + return support_list, unresolved -def generate_function_doc(category, function_support_list, output): - num_unsupported = len(list(filter(lambda x: x[0] is not None, function_support_list['unsupported']))) - num_unsupported_expression = len( - list(filter(lambda x: x[0] is None and x[1] is not None, function_support_list['unsupported']))) - num_partially_supported = len(list(filter(lambda x: x[0] is not None, function_support_list['partial']))) - num_supported = len(scalar_functions) - num_unsupported - num_partially_supported +def generate_function_doc(category, output): + def support_str(num_functions): + return f"{num_functions} functions" if num_functions > 1 else f"{num_functions} function" - logging.log(logging.WARNING, f'Number of {category} functions: {len(scalar_functions)}') - logging.log(logging.WARNING, f'Number of unsupported {category} functions: {num_unsupported}') - logging.log(logging.WARNING, f'Number of unsupported {category} expressions: {num_unsupported_expression}') - logging.log(logging.WARNING, f'Number of partially supported {category} function: {num_partially_supported}') + num_unsupported = len(list(filter(lambda x: x[0] is not None, support_list[category]['unsupported']))) + num_unsupported_expression = len(support_list[category]['unsupported_expr']) + num_unknown_function = len(support_list[category]['unknown']) + num_partially_supported = len(list(filter(lambda x: x[0] is not None, support_list[category]['partial']))) + num_supported = len(functions[category]) - num_unsupported - num_partially_supported + + logging.log(logging.WARNING, f'Number of {category} functions: {len(functions[category])}') logging.log(logging.WARNING, f'Number of fully supported {category} function: {num_supported}') + logging.log(logging.WARNING, f'Number of unsupported {category} functions: {num_unsupported}') + logging.log(logging.WARNING, + f'Number of partially supported {category} function: {num_partially_supported}') + logging.log(logging.WARNING, + f'Number of unsupported {category} expressions: {num_unsupported_expression}') + logging.log(logging.WARNING, + f'Number of unknown {category} function: {num_unknown_function}. List: {support_list[category]["unknown"]}') headers = ['Spark Functions', 'Spark Expressions', 'Status', 'Restrictions'] + partially_supports = '.' if not num_partially_supported else f' and partially supports {support_str(num_partially_supported)}.' lines = f'''# {category.capitalize()} Functions Support Status -**Out of {len(scalar_functions)} {category} functions in Spark 3.5, Gluten currently fully supports {num_supported} functions and partially supports {num_partially_supported} functions.** +**Out of {len(functions[category])} {category} functions in Spark 3.5, Gluten currently fully supports {support_str(num_supported)}{partially_supports}** ''' @@ -842,12 +952,12 @@ def generate_function_doc(category, function_support_list, output): for f in sorted(group_functions[g]): classname = '' if f not in spark_function_map else spark_function_map[f] support = None - for item in function_support_list['partial']: + for item in support_list[category]['partial']: if item[0] and item[0] == f or item[1] and item[1] == classname: support = 'PS' break if support is None: - for item in function_support_list['unsupported']: + for item in support_list[category]['unsupported']: if item[0] and item[0] == f or item[1] and item[1] == classname: support = '' break @@ -857,7 +967,8 @@ def generate_function_doc(category, function_support_list, output): f = '|' elif f == '||': f = '||' - data.append([f, classname, support, '' if f not in GLUTEN_RESTRICTIONS else GLUTEN_RESTRICTIONS[f]]) + data.append([f, classname, support, + '' if f not in GLUTEN_RESTRICTIONS[category] else GLUTEN_RESTRICTIONS[category][f]]) table = tabulate.tabulate(data, headers, tablefmt="github") lines += table + '\n\n' @@ -865,14 +976,25 @@ def generate_function_doc(category, function_support_list, output): fd.write(lines) -def run_GlutenSQLQueryTestSuite(): +def run_test_suites(categories): log4j_properties_file = os.path.abspath( os.path.join(os.path.dirname(os.path.abspath(__file__)), 'log4j2.properties')) + + suite_list = [] + for category in categories: + if FUNCTION_SUITES[category]: + suite_list.append(','.join([FUNCTION_SUITE_PACKAGE + name for name in FUNCTION_SUITES[category]])) + suites = ','.join(suite_list) + + if not suites: + logging.log(logging.WARNING, "No test suites to run.") + return + command = [ "mvn", "test", "-Pspark-3.5", "-Pspark-ut", "-Pbackends-velox", f"-DargLine=-Dspark.test.home={spark_home} -Dlog4j2.configurationFile=file:{log4j_properties_file}", - "-DwildcardSuites=org.apache.spark.sql.GlutenSQLQueryTestSuite", + f"-DwildcardSuites={suites}", "-Dtest=none", "-Dsurefire.failIfNoSpecifiedTests=false" ] @@ -880,23 +1002,41 @@ def run_GlutenSQLQueryTestSuite(): subprocess.Popen(command, cwd=gluten_home).wait() +def get_maven_project_version(): + result = subprocess.run( + ['mvn', 'help:evaluate', '-Dexpression=project.version', '-q', '-DforceStdout'], + capture_output=True, + text=True, + cwd=gluten_home + ) + if result.returncode == 0: + version = result.stdout.strip() + return version + else: + raise RuntimeError(f"Error running Maven command: {result.stderr}") + + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument("--spark_home", type=str, required=True, help="Directory to spark source code for the newest supported spark version in Gluten. " "It's required the spark project has been built from source.") - parser.add_argument("--skip_run_test_suite", action='store_true', + parser.add_argument("--skip_test_suite", action='store_true', help="Whether to run test suite. Set to False to skip running the test suite.") + parser.add_argument("--categories", type=str, default=','.join(FUNCTION_CATEGORIES), + help="Use comma-separated string to specify the function categories to generate the docs. " + "Default is all categories.") args = parser.parse_args() spark_home = args.spark_home findspark.init(spark_home) gluten_home = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), '../../')) - if not args.skip_run_test_suite: - run_GlutenSQLQueryTestSuite() + if not args.skip_test_suite: + run_test_suites(args.categories.split(',')) - gluten_jar = os.path.join(gluten_home, 'package', 'target', 'gluten-package-1.5.0-SNAPSHOT.jar') + gluten_version = get_maven_project_version() + gluten_jar = os.path.join(gluten_home, 'package', 'target', f'gluten-package-{gluten_version}.jar') if not os.path.exists(gluten_jar): raise Exception(f"Gluten jar not found at {gluten_jar}") @@ -910,10 +1050,7 @@ def run_GlutenSQLQueryTestSuite(): # Generate the function list to the global variables. all_function_names = [] - scalar_functions = [] - agg_functions = [] - window_functions = [] - generator_functions = [] + functions = {'scalar': set(), 'aggregate': set(), 'window': set(), 'generator': set()} classname_to_function = {} function_to_classname = {} group_functions = {} @@ -921,8 +1058,10 @@ def run_GlutenSQLQueryTestSuite(): spark_function_map = create_spark_function_map() - scalar_support_list, agg_support_list, window_support_list, generator_support_list, unresolved = parse_logs( - os.path.join(gluten_home, 'gluten-ut', 'spark35', 'target', 'gen-function-support-docs-tests.log')) + # support_list, unresolved = parse_logs( + # os.path.join(gluten_home, 'gluten-ut', 'spark35', 'target', 'gen-function-support-docs-tests.log')) + support_list, unresolved = parse_logs('/Users/rong/workspace/log/tmp5.log') - generate_function_doc('scalar', scalar_support_list, - os.path.join(gluten_home, 'docs', 'velox-backend-scalar-function-support.md')) + for category in args.categories.split(','): + generate_function_doc(category, + os.path.join(gluten_home, 'docs', f'velox-backend-{category}-function-support.md'))