From e6a3ccefcc4ed0fbb02358628b1c438055dadcee Mon Sep 17 00:00:00 2001 From: Felix Loesing Date: Thu, 28 Aug 2025 18:47:04 -0700 Subject: [PATCH] [GLUTEN-9671][VL] Fix broadcast exchange stackoverflow due to Kryo serialization (#10541) This pull request introduces a safer and more robust approach for handling Spark's BroadcastMode during serialization. The main improvement is the introduction of a new SafeBroadcastMode abstraction and related utilities, which help avoid serialization issues that caused a StackOverflowError during broadcast exchanges. The serialization of BroadcastMode was introduced in an earlier PR, which caused the issue we observed. HashedRelationBroadcastMode embeds Catalyst expression trees, which are not safe to Kryo-serialize when running with spark.kryo.referenceTracking=false (default internally). With this change, the broadcast payload now contains only primitives and byte arrays (no Catalyst trees). For bound keys, we serialize just column ordinals (+ null-aware flag) and for computed keys (e.g., upper(col)), we serialize the key expressions once as Java bytes and deserialize only where needed to build projections. 
(cherry picked from commit 91c52e15f16593747e918145258ebe1408cb8ea2) --- .../velox/VeloxSparkPlanExecApi.scala | 2 +- .../sql/execution/BroadcastModeUtils.scala | 134 ++++++++++++++++++ .../spark/sql/execution/BroadcastUtils.scala | 4 +- .../execution/ColumnarBuildSideRelation.scala | 57 ++++++-- .../UnsafeColumnarBuildSideRelation.scala | 91 ++++++++++-- 5 files changed, 261 insertions(+), 27 deletions(-) create mode 100644 backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index aaa6836f145f..a3b20062ac11 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -678,7 +678,7 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { dataSize += rawSize if (useOffheapBroadcastBuildRelation) { TaskResources.runUnsafe { - new UnsafeColumnarBuildSideRelation(child.output, serialized.flatMap(_.getSerialized), mode) + UnsafeColumnarBuildSideRelation(child.output, serialized.flatMap(_.getSerialized), mode) } } else { ColumnarBuildSideRelation(child.output, serialized.flatMap(_.getSerialized), mode) diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala new file mode 100644 index 000000000000..d0ae9a6832a8 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastModeUtils.scala @@ -0,0 +1,134 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression} +import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, IOException, ObjectInputStream, ObjectOutputStream} + +/** + * Provides serialization-safe representations of BroadcastMode to avoid issues with circular + * references in complex expression trees during Kryo serialization. + */ +sealed trait SafeBroadcastMode extends Serializable + +/** Safe representation of IdentityBroadcastMode */ +case object IdentitySafeBroadcastMode extends SafeBroadcastMode + +/** + * Safe wrapper for HashedRelationBroadcastMode. Stores only column ordinals instead of full + * BoundReference expressions. + */ +final case class HashSafeBroadcastMode(ordinals: Array[Int], isNullAware: Boolean) + extends SafeBroadcastMode + +/** + * Safe wrapper for HashedRelationBroadcastMode when keys are not simple BoundReferences. Stores key + * expressions as serialized Java bytes. 
+ */ +final case class HashExprSafeBroadcastMode(exprBytes: Array[Byte], isNullAware: Boolean) + extends SafeBroadcastMode + +object BroadcastModeUtils extends Logging { + + /** + * Converts a BroadcastMode to its SafeBroadcastMode equivalent. Uses ordinals for simple + * BoundReferences, otherwise serializes the expressions. + */ + private[execution] def toSafe(mode: BroadcastMode): SafeBroadcastMode = mode match { + case IdentityBroadcastMode => + IdentitySafeBroadcastMode + case HashedRelationBroadcastMode(keys, isNullAware) => + // Fast path: all keys are already BoundReference(i, ..,..). + val ords = keys.collect { case BoundReference(ord, _, _) => ord } + if (ords.size == keys.size) { + HashSafeBroadcastMode(ords.toArray, isNullAware) + } else { + // Fallback: store the key expressions as Java-serialized bytes. + HashExprSafeBroadcastMode(serializeExpressions(keys), isNullAware) + } + + case other => + throw new IllegalArgumentException(s"Unsupported BroadcastMode: $other") + } + + /** Converts a SafeBroadcastMode to its BroadcastMode equivalent. 
*/ + private[execution] def fromSafe(safe: SafeBroadcastMode, output: Seq[Attribute]): BroadcastMode = + safe match { + case IdentitySafeBroadcastMode => + IdentityBroadcastMode + + case HashSafeBroadcastMode(ords, isNullAware) => + val bound = ords.map(i => BoundReference(i, output(i).dataType, output(i).nullable)).toSeq + HashedRelationBroadcastMode(bound, isNullAware) + + case HashExprSafeBroadcastMode(bytes, isNullAware) => + HashedRelationBroadcastMode(deserializeExpressions(bytes), isNullAware) + } + + // Helpers for expression serialization (used in HashExprSafeBroadcastMode) + private[execution] def serializeExpressions(keys: Seq[Expression]): Array[Byte] = { + val bos = new ByteArrayOutputStream() + var oos: ObjectOutputStream = null + try { + oos = new ObjectOutputStream(bos) + oos.writeObject(keys) + oos.flush() + bos.toByteArray + } catch { + case e @ (_: IOException | _: ClassNotFoundException | _: ClassCastException) => + logError( + s"Failed to serialize expressions for BroadcastMode. Expression count: ${keys.length}", + e) + throw new RuntimeException("Failed to serialize expressions for BroadcastMode", e) + case e: Exception => + logError( + s"Unexpected error during expression serialization. Expression count: ${keys.length}", + e) + throw e + } finally { + if (oos != null) oos.close() + bos.close() + } + } + + private[execution] def deserializeExpressions(bytes: Array[Byte]): Seq[Expression] = { + val bis = new ByteArrayInputStream(bytes) + var ois: ObjectInputStream = null + try { + ois = new ObjectInputStream(bis) + ois.readObject().asInstanceOf[Seq[Expression]] + } catch { + case e @ (_: IOException | _: ClassNotFoundException | _: ClassCastException) => + logError( + s"Failed to deserialize expressions for BroadcastMode. Data size: ${bytes.length} bytes", + e) + throw new RuntimeException("Failed to deserialize expressions for BroadcastMode", e) + case e: Exception => + logError( + s"Unexpected error during expression deserialization. 
Data size: ${bytes.length} bytes", + e) + throw e + } finally { + if (ois != null) ois.close() + bis.close() + } + } +} diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala index 155227db4145..342c9694f0f7 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/BroadcastUtils.scala @@ -108,7 +108,7 @@ object BroadcastUtils { result.getSerialized } if (useOffheapBuildRelation) { - new UnsafeColumnarBuildSideRelation( + UnsafeColumnarBuildSideRelation( SparkShimLoader.getSparkShims.attributesFromStruct(schema), serialized, mode) @@ -134,7 +134,7 @@ object BroadcastUtils { result.getSerialized } if (useOffheapBuildRelation) { - new UnsafeColumnarBuildSideRelation( + UnsafeColumnarBuildSideRelation( SparkShimLoader.getSparkShims.attributesFromStruct(schema), serialized, mode) diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala index cd49ed30ea62..59a9cb2b00fc 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala @@ -26,11 +26,9 @@ import org.apache.gluten.utils.ArrowAbiUtil import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow} import 
org.apache.spark.sql.catalyst.plans.physical.BroadcastMode -import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode -import org.apache.spark.sql.execution.joins.BuildSideRelation -import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.utils.SparkArrowUtil import org.apache.spark.sql.vectorized.ColumnarBatch @@ -40,17 +38,56 @@ import org.apache.arrow.c.ArrowSchema import scala.collection.JavaConverters.asScalaIteratorConverter +object ColumnarBuildSideRelation { + // Keep constructor with BroadcastMode for compatibility + def apply( + output: Seq[Attribute], + batches: Array[Array[Byte]], + mode: BroadcastMode): ColumnarBuildSideRelation = { + val boundMode = mode match { + case HashedRelationBroadcastMode(keys, isNullAware) => + // Bind each key to the build-side output so simple cols become BoundReference + val boundKeys: Seq[Expression] = + keys.map(k => BindReferences.bindReference(k, AttributeSeq(output))) + HashedRelationBroadcastMode(boundKeys, isNullAware) + case m => + m // IdentityBroadcastMode, etc. + } + new ColumnarBuildSideRelation(output, batches, BroadcastModeUtils.toSafe(boundMode)) + } +} + case class ColumnarBuildSideRelation( output: Seq[Attribute], batches: Array[Array[Byte]], - mode: BroadcastMode) + safeBroadcastMode: SafeBroadcastMode) extends BuildSideRelation { - private def transformProjection: UnsafeProjection = { - mode match { - case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k) - case IdentityBroadcastMode => UnsafeProjection.create(output, output) - } + // Rebuild the real BroadcastMode on demand; never serialize it. + @transient override lazy val mode: BroadcastMode = + BroadcastModeUtils.fromSafe(safeBroadcastMode, output) + + // If we stored expression bytes, deserialize once and cache locally (not serialized). 
+ @transient private lazy val exprKeysFromBytes: Option[Seq[Expression]] = safeBroadcastMode match { + case HashExprSafeBroadcastMode(bytes, _) => + Some(BroadcastModeUtils.deserializeExpressions(bytes)) + case _ => None + } + + private def transformProjection: UnsafeProjection = safeBroadcastMode match { + case IdentitySafeBroadcastMode => + UnsafeProjection.create(output, output) + case HashSafeBroadcastMode(ords, _) => + val bound = ords.map(i => BoundReference(i, output(i).dataType, output(i).nullable)) + UnsafeProjection.create(bound) + case HashExprSafeBroadcastMode(_, _) => + exprKeysFromBytes match { + case Some(keys) => UnsafeProjection.create(keys) + case None => + throw new IllegalStateException( + "Failed to deserialize expressions for HashExprSafeBroadcastMode" + ) + } } override def deserialized: Iterator[ColumnarBatch] = { diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala index c0ef884e73a3..80e92a1537e8 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/unsafe/UnsafeColumnarBuildSideRelation.scala @@ -28,8 +28,9 @@ import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeCo import org.apache.spark.annotation.Experimental import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, IdentityBroadcastMode} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSeq, BindReferences, BoundReference, Expression, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode +import 
org.apache.spark.sql.execution.{BroadcastModeUtils, HashExprSafeBroadcastMode, HashSafeBroadcastMode, IdentitySafeBroadcastMode, SafeBroadcastMode} import org.apache.spark.sql.execution.joins.{BuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.utils.SparkArrowUtil @@ -45,6 +46,44 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import scala.collection.JavaConverters.asScalaIteratorConverter +object UnsafeColumnarBuildSideRelation { + // Keep constructors with BroadcastMode for compatibility + def apply( + output: Seq[Attribute], + batches: UnsafeBytesBufferArray, + mode: BroadcastMode): UnsafeColumnarBuildSideRelation = { + val boundMode = mode match { + case HashedRelationBroadcastMode(keys, isNullAware) => + // Bind each key to the build-side output so simple cols become BoundReference + val boundKeys: Seq[Expression] = + keys.map(k => BindReferences.bindReference(k, AttributeSeq(output))) + HashedRelationBroadcastMode(boundKeys, isNullAware) + case m => + m // IdentityBroadcastMode, etc. + } + new UnsafeColumnarBuildSideRelation(output, batches, BroadcastModeUtils.toSafe(boundMode)) + } + def apply( + output: Seq[Attribute], + bytesBufferArray: Array[Array[Byte]], + mode: BroadcastMode): UnsafeColumnarBuildSideRelation = { + val boundMode = mode match { + case HashedRelationBroadcastMode(keys, isNullAware) => + // Bind each key to the build-side output so simple cols become BoundReference + val boundKeys: Seq[Expression] = + keys.map(k => BindReferences.bindReference(k, AttributeSeq(output))) + HashedRelationBroadcastMode(boundKeys, isNullAware) + case m => + m // IdentityBroadcastMode, etc. + } + new UnsafeColumnarBuildSideRelation( + output, + bytesBufferArray, + BroadcastModeUtils.toSafe(boundMode) + ) + } +} + /** * A broadcast relation that is built using off-heap memory. It will avoid the on-heap memory OOM. 
* @@ -59,18 +98,33 @@ import scala.collection.JavaConverters.asScalaIteratorConverter case class UnsafeColumnarBuildSideRelation( private var output: Seq[Attribute], private var batches: UnsafeBytesBufferArray, - var mode: BroadcastMode) + var safeBroadcastMode: SafeBroadcastMode) extends BuildSideRelation with Externalizable with Logging with KryoSerializable { + // Rebuild the real BroadcastMode on demand; never serialize it. + @transient override lazy val mode: BroadcastMode = + BroadcastModeUtils.fromSafe(safeBroadcastMode, output) + + // If we stored expression bytes, deserialize once and cache locally (not serialized). + @transient private lazy val exprKeysFromBytes: Option[Seq[Expression]] = safeBroadcastMode match { + case HashExprSafeBroadcastMode(bytes, _) => + Some(BroadcastModeUtils.deserializeExpressions(bytes)) + case _ => None + } + /** needed for serialization. */ def this() = { this(null, null.asInstanceOf[UnsafeBytesBufferArray], null) } - def this(output: Seq[Attribute], bytesBufferArray: Array[Array[Byte]], mode: BroadcastMode) = { + def this( + output: Seq[Attribute], + bytesBufferArray: Array[Array[Byte]], + safeMode: SafeBroadcastMode + ) = { this( output, UnsafeBytesBufferArray( @@ -78,7 +132,7 @@ case class UnsafeColumnarBuildSideRelation( bytesBufferArray.map(_.length), bytesBufferArray.map(_.length.toLong).sum ), - mode + safeMode ) val batchesSize = bytesBufferArray.length for (i <- 0 until batchesSize) { @@ -89,7 +143,7 @@ case class UnsafeColumnarBuildSideRelation( override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeObject(output) - out.writeObject(mode) + out.writeObject(safeBroadcastMode) out.writeInt(batches.arraySize) out.writeObject(batches.bytesBufferLengths) out.writeLong(batches.totalBytes) @@ -101,7 +155,7 @@ case class UnsafeColumnarBuildSideRelation( override def write(kryo: Kryo, out: Output): Unit = Utils.tryOrIOException { kryo.writeObject(out, output.toList) - 
kryo.writeClassAndObject(out, mode) + kryo.writeClassAndObject(out, safeBroadcastMode) out.writeInt(batches.arraySize) kryo.writeObject(out, batches.bytesBufferLengths) out.writeLong(batches.totalBytes) @@ -113,7 +167,7 @@ case class UnsafeColumnarBuildSideRelation( override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { output = in.readObject().asInstanceOf[Seq[Attribute]] - mode = in.readObject().asInstanceOf[BroadcastMode] + safeBroadcastMode = in.readObject().asInstanceOf[SafeBroadcastMode] val totalArraySize = in.readInt() val bytesBufferLengths = in.readObject().asInstanceOf[Array[Int]] val totalBytes = in.readLong() @@ -137,7 +191,7 @@ case class UnsafeColumnarBuildSideRelation( override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException { output = kryo.readObject(in, classOf[List[_]]).asInstanceOf[Seq[Attribute]] - mode = kryo.readClassAndObject(in).asInstanceOf[BroadcastMode] + safeBroadcastMode = kryo.readClassAndObject(in).asInstanceOf[SafeBroadcastMode] val totalArraySize = in.readInt() val bytesBufferLengths = kryo.readObject(in, classOf[Array[Int]]) val totalBytes = in.readLong() @@ -152,11 +206,20 @@ case class UnsafeColumnarBuildSideRelation( } } - private def transformProjection: UnsafeProjection = { - mode match { - case HashedRelationBroadcastMode(k, _) => UnsafeProjection.create(k) - case IdentityBroadcastMode => UnsafeProjection.create(output, output) - } + private def transformProjection: UnsafeProjection = safeBroadcastMode match { + case IdentitySafeBroadcastMode => + UnsafeProjection.create(output, output) + case HashSafeBroadcastMode(ords, _) => + val bound = ords.map(i => BoundReference(i, output(i).dataType, output(i).nullable)) + UnsafeProjection.create(bound) + case HashExprSafeBroadcastMode(_, _) => + exprKeysFromBytes match { + case Some(keys) => UnsafeProjection.create(keys) + case None => + throw new IllegalStateException( + "Failed to deserialize expressions for HashExprSafeBroadcastMode" + ) + } } 
override def deserialized: Iterator[ColumnarBatch] = {