diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index e6ee3f79ee32..016664ed97e1 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -61,6 +61,7 @@ object CHRuleApi { (spark, parserInterface) => new GlutenClickhouseSqlParser(spark, parserInterface)) injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark)) injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark)) + injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark)) injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark)) injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark)) injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark)) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala index e1287c8b6d86..d900bc000c05 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala @@ -21,7 +21,9 @@ import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.sql.{DataFrame, Row, TestUtils} +import org.apache.spark.sql.catalyst.expressions.{Expression, GetJsonObject, Literal} import org.apache.spark.sql.catalyst.optimizer.{ConstantFolding, NullPropagation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} import org.apache.spark.sql.execution.datasources.v2.clickhouse.ClickHouseConfig import 
org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -90,7 +92,9 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS Row(1.011, 5, "{\"a\":\"b\", \"x\":{\"i\":1}}"), Row(1.011, 5, "{\"a\":\"b\", \"x\":{\"i\":2}}"), Row(1.011, 5, "{\"a\":1, \"x\":{\"i\":2}}"), - Row(1.0, 5, "{\"a\":\"{\\\"x\\\":5}\"}") + Row(1.0, 5, "{\"a\":\"{\\\"x\\\":5}\"}"), + Row(1.0, 6, "{\"a\":{\"y\": 5, \"z\": {\"m\":1, \"n\": {\"p\": \"k\"}}}"), + Row(1.0, 7, "{\"a\":[{\"y\": 5}, {\"z\":[{\"m\":1, \"n\":{\"p\":\"k\"}}]}]}") )) val dfParquet = spark.createDataFrame(data, schema) dfParquet @@ -268,6 +272,85 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS } } + test("GLUTEN-8304: Optimize nested get_json_object") { + def checkExpression(expr: Expression, path: String): Boolean = { + expr match { + case g: GetJsonObject + if g.path.isInstanceOf[Literal] && g.path.dataType.isInstanceOf[StringType] => + g.path.asInstanceOf[Literal].value.toString.equals(path) || g.children.exists( + c => checkExpression(c, path)) + case _ => + if (expr.children.isEmpty) { + false + } else { + expr.children.exists(c => checkExpression(c, path)) + } + } + } + def checkPlan(plan: LogicalPlan, path: String): Boolean = plan match { + case p: Project => + p.projectList.exists(x => checkExpression(x, path)) || checkPlan(p.child, path) + case f: Filter => + checkExpression(f.condition, path) || checkPlan(f.child, path) + case _ => + if (plan.children.isEmpty) { + false + } else { + plan.children.exists(c => checkPlan(c, path)) + } + } + def checkGetJsonObjectPath(df: DataFrame, path: String): Boolean = { + checkPlan(df.queryExecution.analyzed, path) + } + withSQLConf(("spark.gluten.sql.collapseGetJsonObject.enabled", "true")) { + runQueryAndCompare( + "select get_json_object(get_json_object(string_field1, '$.a'), '$.y') " + + " from json_test where int_field1 = 6") { + x => assert(checkGetJsonObjectPath(x, "$.a.y")) + } + 
runQueryAndCompare( + "select get_json_object(get_json_object(string_field1, '$[a]'), '$[y]') " + + " from json_test where int_field1 = 6") { + x => assert(checkGetJsonObjectPath(x, "$[a][y]")) + } + runQueryAndCompare( + "select get_json_object(get_json_object(get_json_object(string_field1, " + + "'$.a'), '$.y'), '$.z') from json_test where int_field1 = 6") { + x => assert(checkGetJsonObjectPath(x, "$.a.y.z")) + } + runQueryAndCompare( + "select get_json_object(get_json_object(get_json_object(string_field1, '$.a')," + + " string_field1), '$.z') from json_test where int_field1 = 6", + noFallBack = false + )(x => assert(checkGetJsonObjectPath(x, "$.a") && checkGetJsonObjectPath(x, "$.z"))) + runQueryAndCompare( + "select get_json_object(get_json_object(get_json_object(string_field1, " + + " string_field1), '$.a'), '$.z') from json_test where int_field1 = 6", + noFallBack = false + )(x => assert(checkGetJsonObjectPath(x, "$.a.z"))) + runQueryAndCompare( + "select get_json_object(get_json_object(get_json_object(" + + " substring(string_field1, 10), '$.a'), '$.z'), string_field1) " + + " from json_test where int_field1 = 6", + noFallBack = false + )(x => assert(checkGetJsonObjectPath(x, "$.a.z"))) + runQueryAndCompare( + "select get_json_object(get_json_object(string_field1, '$.a[0]'), '$.y') " + + " from json_test where int_field1 = 7") { + x => assert(checkGetJsonObjectPath(x, "$.a[0].y")) + } + runQueryAndCompare( + "select get_json_object(get_json_object(get_json_object(string_field1, " + + " '$.a[1]'), '$.z[1]'), '$.n') from json_test where int_field1 = 7") { + x => assert(checkGetJsonObjectPath(x, "$.a[1].z[1].n")) + } + runQueryAndCompare( + "select * from json_test where " + + " get_json_object(get_json_object(get_json_object(string_field1, '$.a'), " + + "'$.y'), '$.z') != null")(x => assert(checkGetJsonObjectPath(x, "$.a.y.z"))) + } + } + test("Test get_json_object 10") { runQueryAndCompare("SELECT get_json_object(string_field1, '$.12345') from json_test") { _ 
=> } runQueryAndCompare("SELECT get_json_object(string_field1, '$.123.abc') from json_test") { _ => } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 72d769c999e8..d36cb6f55305 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -58,6 +58,7 @@ object VeloxRuleApi { // Inject the regular Spark rules directly. injector.injectOptimizerRule(CollectRewriteRule.apply) injector.injectOptimizerRule(HLLRewriteRule.apply) + injector.injectOptimizerRule(CollapseGetJsonObjectExpressionRule.apply) injector.injectPostHocResolutionRule(ArrowConvertorRule.apply) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseGetJsonObjectExpressionRule.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseGetJsonObjectExpressionRule.scala new file mode 100644 index 000000000000..4c84f4214904 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/CollapseGetJsonObjectExpressionRule.scala @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar + +import org.apache.gluten.config.GlutenConfig + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * The rule is aimed to collapse nested `get_json_object` functions as one for optimization, e.g. + * get_json_object(get_json_object(d, '$.a'), '$.b') => get_json_object(d, '$.a.b'). And we should + * notice that this rule cannot be applied in some cases: + * - get_json_object(get_json_object('{"a":"{\\\"x\\\":5}"}', '$.a'), '$.x'), the json string has + * backslashes to escape quotes; + * - get_json_object(get_json_object('{"a.b": 0}', '$.a'), '$.b'), the json key contains a dot + * character (.) and it's the same as the collapsed json path; + */ +case class CollapseGetJsonObjectExpressionRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = { + if ( + plan.resolved + && GlutenConfig.get.enableCollapseNestedGetJsonObject + ) { + visitPlan(plan) + } else { + plan + } + } + + private def visitPlan(plan: LogicalPlan): LogicalPlan = plan match { + case p: Project => + var newProjectList = Seq.empty[NamedExpression] + p.projectList.foreach { + case a: Alias if a.child.isInstanceOf[GetJsonObject] => + newProjectList :+= optimizeNestedFunctions(a).asInstanceOf[NamedExpression] + case p => + newProjectList :+= p + } + val newChild = visitPlan(p.child) + Project(newProjectList, newChild) + case f: Filter => + val newCond = optimizeNestedFunctions(f.condition) + val newChild = visitPlan(f.child) + Filter(newCond, newChild) + case other => + val children = other.children.map(visitPlan) + plan.withNewChildren(children) + } + + private def optimizeNestedFunctions( + expr: Expression, + path: 
String = "", + isNested: Boolean = false): Expression = { + + def getPathLiteral(path: Expression): Option[String] = path match { + case l: Literal => + Option.apply(l.value.toString) + case _ => + Option.empty + } + + expr match { + case g: GetJsonObject => + val gPath = getPathLiteral(g.path).orNull + var newPath = "" + if (gPath != null) { + newPath = gPath.replace("$", "") + path + } + val res = optimizeNestedFunctions(g.json, newPath, isNested = true) + if (gPath != null) { + res + } else { + var newChildren = Seq.empty[Expression] + newChildren :+= res + newChildren :+= g.path + val newExpr = g.withNewChildren(newChildren) + if (path.nonEmpty) { + GetJsonObject(newExpr, Literal.apply("$" + path)) + } else { + newExpr + } + } + case _ => + val newChildren = expr.children.map(x => optimizeNestedFunctions(x, path)) + val newExpr = expr.withNewChildren(newChildren) + if (isNested && path.nonEmpty) { + val pathExpr = Literal.apply("$" + path) + GetJsonObject(newExpr, pathExpr) + } else { + newExpr + } + } + } +} diff --git a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala index b1337c92eec1..4513364648eb 100644 --- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala @@ -105,6 +105,9 @@ class GlutenConfig(conf: SQLConf) extends Logging { def enableRewriteDateTimestampComparison: Boolean = conf.getConf(ENABLE_REWRITE_DATE_TIMESTAMP_COMPARISON) + def enableCollapseNestedGetJsonObject: Boolean = + conf.getConf(ENABLE_COLLAPSE_GET_JSON_OBJECT) + def enableCHRewriteDateConversion: Boolean = conf.getConf(ENABLE_CH_REWRITE_DATE_CONVERSION) @@ -1966,6 +1969,13 @@ object GlutenConfig { .booleanConf .createWithDefault(true) + val ENABLE_COLLAPSE_GET_JSON_OBJECT = + buildConf("spark.gluten.sql.collapseGetJsonObject.enabled") + .internal() + .doc("Collapse nested get_json_object 
functions as one for optimization.") + .booleanConf + .createWithDefault(false) + val ENABLE_CH_REWRITE_DATE_CONVERSION = buildConf("spark.gluten.sql.columnar.backend.ch.rewrite.dateConversion") .internal()