From 851921a0138feaa5e5b3a8c26add3b4eda2aedb3 Mon Sep 17 00:00:00 2001
From: lgbo-ustc
Date: Fri, 3 Jan 2025 15:11:41 +0800
Subject: [PATCH] replace from_json with get_json_object

---
 .../backendsapi/clickhouse/CHBackend.scala    |  11 ++
 .../backendsapi/clickhouse/CHRuleApi.scala    |   1 +
 .../BasicExpressionRewriteRule.scala          | 120 ++++++++++++++++++
 .../GlutenFunctionValidateSuite.scala         |  34 +++++
 4 files changed, 166 insertions(+)
 create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/BasicExpressionRewriteRule.scala

diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
index 962698759320..316c205d5851 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala
@@ -393,6 +393,17 @@ object CHBackendSettings extends BackendSettingsApi with Logging {
     )
   }
 
+  /**
+   * Whether `from_json(...)[key]` may be replaced with `get_json_object`. The flag is read from
+   * the runtime config `enable_replace_from_json_with_get_json_object` and is enabled by default.
+   */
+  def enableReplaceFromJsonWithGetJsonObject(): Boolean = {
+    SparkEnv.get.conf.getBoolean(
+      CHConf.runtimeConfig("enable_replace_from_json_with_get_json_object"),
+      defaultValue = true
+    )
+  }
+
   override def enableNativeWriteFiles(): Boolean = {
     GlutenConfig.get.enableNativeWriter.getOrElse(false)
   }
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
index c79931fa4e93..40344e96e768 100644
--- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala
@@ -63,6 +63,7 @@ object CHRuleApi {
     injector.injectResolutionRule(spark => new RewriteToDateExpresstionRule(spark))
     injector.injectResolutionRule(spark => new RewriteDateTimestampComparisonRule(spark))
     injector.injectResolutionRule(spark => new CollapseGetJsonObjectExpressionRule(spark))
+    injector.injectResolutionRule(spark => new RepalceFromJsonWithGetJsonObject(spark))
     injector.injectOptimizerRule(spark => new CommonSubexpressionEliminateRule(spark))
     injector.injectOptimizerRule(spark => new ExtendedColumnPruning(spark))
     injector.injectOptimizerRule(spark => CHAggregateFunctionRewriteRule(spark))
diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/BasicExpressionRewriteRule.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/BasicExpressionRewriteRule.scala
new file mode 100644
index 000000000000..5943582b1ee5
--- /dev/null
+++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/BasicExpressionRewriteRule.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.extension
+
+import org.apache.gluten.backendsapi.clickhouse.CHBackendSettings
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+/*
+ * This file includes some rules to replace expressions in a more efficient way.
+ */
+
+/**
+ * Tries to replace `from_json(json, schema)[key]` with `get_json_object(json, '$.key')` in the
+ * expressions of [[Project]] and [[Filter]] nodes.
+ *
+ * The rewrite only happens when the map values are strings, `from_json` has no parser options
+ * and `key` is a non-null string literal that is safe to embed in a JSON path. Everything else is
+ * left unchanged. The rule is guarded by
+ * `CHBackendSettings.enableReplaceFromJsonWithGetJsonObject()`.
+ */
+class RepalceFromJsonWithGetJsonObject(spark: SparkSession) extends Rule[LogicalPlan] with Logging {
+
+  // Characters with a special meaning in the JSON path syntax of `get_json_object`. A key that
+  // contains one of them cannot be expressed as `$.key` without changing which field is read.
+  private val jsonPathSpecialChars: Set[Char] = Set('.', '[', ']', '*', '\'', '"')
+
+  /** Returns `plan` untouched when the rule is disabled or `plan` is not resolved yet. */
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    if (!CHBackendSettings.enableReplaceFromJsonWithGetJsonObject() || !plan.resolved) {
+      plan
+    } else {
+      visitPlan(plan)
+    }
+  }
+
+  /**
+   * Rewrites the project list of every [[Project]] and the condition of every [[Filter]] found
+   * in `plan`. Other nodes only have their children visited.
+   */
+  def visitPlan(plan: LogicalPlan): LogicalPlan = {
+    val newPlan = plan match {
+      case project: Project =>
+        val newProjectList =
+          project.projectList.map(expr => visitExpression(expr).asInstanceOf[NamedExpression])
+        project.copy(projectList = newProjectList, child = visitPlan(project.child))
+      case filter: Filter =>
+        val newCondition = visitExpression(filter.condition)
+        Filter(newCondition, visitPlan(filter.child))
+      case other =>
+        other.withNewChildren(other.children.map(visitPlan))
+    }
+    // Some plan nodes have tags, we need to copy the tags to the new ones.
+    newPlan.copyTagsFrom(plan)
+    newPlan
+  }
+
+  /**
+   * Returns `expr` with every rewritable `from_json(...)[key]` replaced by `get_json_object`.
+   * The JSON input of a replaced expression is visited as well.
+   */
+  def visitExpression(expr: Expression): Expression = {
+    expr match {
+      case getMapValue: GetMapValue =>
+        jsonInputAndPath(getMapValue) match {
+          case Some((json, path)) =>
+            GetJsonObject(visitExpression(json), Literal(UTF8String.fromString(path), StringType))
+          case None =>
+            getMapValue.withNewChildren(getMapValue.children.map(visitExpression))
+        }
+      case literal: Literal => literal
+      case attr: Attribute => attr
+      case other =>
+        other.withNewChildren(other.children.map(visitExpression))
+    }
+  }
+
+  /**
+   * Returns the JSON input of `from_json` and the JSON path for `key` when `getMapValue` is a
+   * string-literal lookup on a `from_json` map with string values and the rewrite is safe.
+   */
+  private def jsonInputAndPath(getMapValue: GetMapValue): Option[(Expression, String)] = {
+    (getMapValue.child, getMapValue.key) match {
+      // `key: UTF8String` does not match null, so a null literal key is never rewritten.
+      case (fromJson: JsonToStructs, Literal(key: UTF8String, _: StringType)) =>
+        fromJson.dataType match {
+          case MapType(_, _: StringType, _)
+              if fromJson.options.isEmpty && isPlainJsonKey(key.toString) =>
+            Some((fromJson.child, s"$$.${key.toString}"))
+          case _ => None
+        }
+      case _ => None
+    }
+  }
+
+  /** Whether `key` can be appended to `$.` without being parsed as JSON path syntax. */
+  private def isPlainJsonKey(key: String): Boolean = {
+    key.nonEmpty && !key.exists(jsonPathSpecialChars)
+  }
+}
diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
index f84557e6e97c..84c92d1e047b 100644
--- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
+++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenFunctionValidateSuite.scala
@@ -995,4 +995,38 @@ class GlutenFunctionValidateSuite extends GlutenClickHouseWholeStageTransformerS
     }
     compareResultsAgainstVanillaSpark(sql, true, checkProjects, false)
   }
+
+  test("GLUTEN-8406 replace from_json with get_json_object") {
+    withTable("test_8406") {
+      spark.sql("create table test_8406(x string) using parquet")
+      val insertSql =
+        """
+          |insert into test_8406 values
+          |('{"a":1}'),
+          |('{"a":2'),
+          |('{"b":3}'),
+          |('{"a":"5"}'),
+          |('{"a":{"x":1}}')
+          |""".stripMargin
+      spark.sql(insertSql)
+      val sql =
+        """
+          |select from_json(x, 'Map<String, String>')['a'] from test_8406
+          |""".stripMargin
+      compareResultsAgainstVanillaSpark(sql, true, { _ => })
+    }
+  }
+
+  test("GLUTEN-8406 from_json with a key that is not a plain json path name") {
+    withTable("test_8406_keys") {
+      spark.sql("create table test_8406_keys(x string) using parquet")
+      spark.sql("""insert into test_8406_keys values ('{"a.b":"1","a":{"b":"2"}}')""")
+      val sql =
+        """
+          |select from_json(x, 'Map<String, String>')['a.b'] from test_8406_keys
+          |""".stripMargin
+      // `$.a.b` would read the nested field instead of the top-level key "a.b".
+      assert(spark.sql(sql).collect().map(_.getString(0)).toSeq == Seq("1"))
+    }
+  }
 }