diff --git a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala index d01a23712e96..eb4e069be22e 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/GlutenPlugin.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.util.JavaUtils import org.apache.spark.softaffinity.SoftAffinityListener import org.apache.spark.sql.execution.ui.{GlutenSQLAppStatusListener, GlutenUIUtils} -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{SparkConfigUtil, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.SPARK_SESSION_EXTENSIONS import org.apache.spark.task.TaskResources import org.apache.spark.util.SparkResourceUtil @@ -137,13 +137,13 @@ private[gluten] class GlutenDriverPlugin extends DriverPlugin with Logging { private def setPredefinedConfigs(conf: SparkConf): Unit = { // Spark SQL extensions - val extensions = if (conf.contains(SPARK_SESSION_EXTENSIONS.key)) { - s"${conf.get(SPARK_SESSION_EXTENSIONS.key)}," + - s"${GlutenSessionExtensions.GLUTEN_SESSION_EXTENSION_NAME}" - } else { - s"${GlutenSessionExtensions.GLUTEN_SESSION_EXTENSION_NAME}" + val extensionSeq = + SparkConfigUtil.getEntryValue(conf, SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty) + if (!extensionSeq.toSet.contains(GlutenSessionExtensions.GLUTEN_SESSION_EXTENSION_NAME)) { + conf.set( + SPARK_SESSION_EXTENSIONS.key, + (extensionSeq :+ GlutenSessionExtensions.GLUTEN_SESSION_EXTENSION_NAME).mkString(",")) } - conf.set(SPARK_SESSION_EXTENSIONS.key, extensions) // adaptive custom cost evaluator class val enableGlutenCostEvaluator = conf.getBoolean( diff --git a/shims/common/src/main/scala/org/apache/spark/sql/internal/SparkConfigUtil.scala b/shims/common/src/main/scala/org/apache/spark/sql/internal/SparkConfigUtil.scala new file mode 100644 index 000000000000..945174073e48 --- 
/** Shim for reading typed Spark configuration values.
  *
  * Lives in the `org.apache.spark.sql.internal` package so callers outside Spark
  * (e.g. Gluten's driver plugin) can read a [[ConfigEntry]] with its declared type
  * via `SparkConf.get(entry)` — NOTE(review): that overload appears to be
  * Spark-internal API, which is presumably why this shim exists; confirm against
  * the Spark version in use.
  */
object SparkConfigUtil {

  /** Returns the value of `entry` from `conf`, typed as the entry declares
    * (including any default the entry defines).
    *
    * @param conf  the Spark configuration to read from
    * @param entry the typed config entry to resolve
    */
  def getEntryValue[E](conf: SparkConf, entry: ConfigEntry[E]): E = conf.get(entry)
}