diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 8136488b6afa..0ec44a207a43 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -32,6 +32,7 @@ import org.apache.gluten.vectorized.{BlockOutputStream, CHColumnarBatchSerialize import org.apache.spark.ShuffleDependency import org.apache.spark.internal.Logging +import org.apache.spark.memory.SparkMemoryUtil import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper, HashPartitioningWrapper} @@ -49,7 +50,7 @@ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.{CHExecUtil, PushDownUtil} @@ -559,9 +560,11 @@ class CHSparkPlanExecApi extends SparkPlanExecApi with Logging { val batches = countsAndBytes.map(_._2) val totalBatchesSize = batches.map(_.length).sum val rawSize = dataSize.value - if (rawSize >= BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES) { + if (rawSize >= GlutenConfig.get.maxBroadcastTableSize) { throw new GlutenException( - s"Cannot broadcast the table that is larger than 8GB: $rawSize bytes") + "Cannot broadcast the table that is larger than " + + s"${SparkMemoryUtil.bytesToString(GlutenConfig.get.maxBroadcastTableSize)}: " + + s"${SparkMemoryUtil.bytesToString(rawSize)}") } if ((rawSize == 0 && totalBatchesSize != 0) || totalBatchesSize < 0) { throw new GlutenException( diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 77cb2d68bed1..0dc83e98d48f 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -28,6 +28,7 @@ import org.apache.gluten.vectorized.{ColumnarBatchSerializer, ColumnarBatchSeria import org.apache.spark.{ShuffleDependency, SparkException} import org.apache.spark.api.python.{ColumnarArrowEvalPythonExec, PullOutArrowEvalPythonPreProjectHelper} +import org.apache.spark.memory.SparkMemoryUtil import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper} @@ -41,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowEvalPythonExec @@ -649,9 +650,11 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { .filter(_.getNumRows != 0) .collect val rawSize = serialized.map(_.getSerialized.length).sum - if (rawSize >= BroadcastExchangeExec.MAX_BROADCAST_TABLE_BYTES) { + if (rawSize >= GlutenConfig.get.maxBroadcastTableSize) { throw new SparkException( - s"Cannot broadcast the table that is larger than 8GB: ${rawSize >> 30} GB") + "Cannot broadcast the table that is larger than " + + s"${SparkMemoryUtil.bytesToString(GlutenConfig.get.maxBroadcastTableSize)}: " + + s"${SparkMemoryUtil.bytesToString(rawSize)}") } numOutputRows += serialized.map(_.getNumRows).sum dataSize += rawSize diff --git a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala index 449db38e24d5..22d269cc25b3 100644 --- a/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/config/GlutenConfig.scala @@ -17,7 +17,7 @@ package org.apache.gluten.config import org.apache.spark.internal.Logging -import org.apache.spark.network.util.ByteUnit +import org.apache.spark.network.util.{ByteUnit, JavaUtils} import org.apache.spark.sql.internal.{GlutenConfigUtil, SQLConf, SQLConfProvider} import com.google.common.collect.ImmutableList @@ -364,6 +364,9 @@ class GlutenConfig(conf: SQLConf) extends Logging { def enableColumnarRange: Boolean = getConf(COLUMNAR_RANGE_ENABLED) def enableColumnarCollectLimit: Boolean = getConf(COLUMNAR_COLLECT_LIMIT_ENABLED) def getSupportedFlattenedExpressions: String = getConf(GLUTEN_SUPPORTED_FLATTENED_FUNCTIONS) + + def maxBroadcastTableSize: Long = + JavaUtils.byteStringAsBytes(conf.getConfString(SPARK_MAX_BROADCAST_TABLE_SIZE, "8GB")) } object GlutenConfig { @@ -442,6 +445,7 @@ object GlutenConfig { val SPARK_SHUFFLE_SPILL_DISK_WRITE_BUFFER_SIZE = "spark.shuffle.spill.diskWriteBufferSize" val SPARK_SHUFFLE_SPILL_COMPRESS = "spark.shuffle.spill.compress" val SPARK_SHUFFLE_SPILL_COMPRESS_DEFAULT: Boolean = true + val SPARK_MAX_BROADCAST_TABLE_SIZE = "spark.sql.maxBroadcastTableSize" def get: GlutenConfig = { new GlutenConfig(SQLConf.get)