From 05875ba0f65ffec90299bee395a8b3771e353a2d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Apr 2016 06:54:04 +0000 Subject: [PATCH 1/2] Skip calling encoder.shift many times in update and do not need to do shift when delta is zero. --- .../sql/catalyst/encoders/ExpressionEncoder.scala | 10 +++++++--- .../aggregate/TypedAggregateExpression.scala | 11 +++++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 56d29cfbe1f66..63d8a3726d4dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -339,9 +339,13 @@ case class ExpressionEncoder[T]( * Returns a new encoder with input columns shifted by `delta` ordinals */ def shift(delta: Int): ExpressionEncoder[T] = { - copy(deserializer = deserializer transform { - case r: BoundReference => r.copy(ordinal = r.ordinal + delta) - }) + if (delta == 0) { + this + } else { + copy(deserializer = deserializer transform { + case r: BoundReference => r.copy(ordinal = r.ordinal + delta) + }) + } } protected val attrs = serializer.flatMap(_.collect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 9abae5357973f..10f2a4160e5c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -85,6 +85,9 @@ case class TypedAggregateExpression( .resolve(aggBufferAttributes, OuterScopes.outerScopes) .bind(aggBufferAttributes) + val bEncoderForMutableAggBuffer = bEncoder.shift(mutableAggBufferOffset) + val bEncoderForInputAggBuffer = bEncoder.shift(inputAggBufferOffset) + // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -118,7 +121,7 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) + val currentB = bEncoderForMutableAggBuffer.fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) val returned = bEncoder.toRow(merged) @@ -126,8 +129,8 @@ case class TypedAggregateExpression( } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) + val b1 = bEncoderForMutableAggBuffer.fromRow(buffer1) + val b2 = bEncoderForInputAggBuffer.fromRow(buffer2) val merged = aggregator.merge(b1, b2) val returned = bEncoder.toRow(merged) @@ -135,7 +138,7 @@ case class TypedAggregateExpression( } override def eval(buffer: InternalRow): Any = { - val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) + val b = bEncoderForMutableAggBuffer.fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result From faf1ddf87c454414624877fbc56f675fa03bb97f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Apr 2016 08:58:42 +0000 Subject: [PATCH 2/2] Revert it since it will be revamped soon. --- .../aggregate/TypedAggregateExpression.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 10f2a4160e5c4..9abae5357973f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -85,9 +85,6 @@ case class TypedAggregateExpression( .resolve(aggBufferAttributes, OuterScopes.outerScopes) .bind(aggBufferAttributes) - val bEncoderForMutableAggBuffer = bEncoder.shift(mutableAggBufferOffset) - val bEncoderForInputAggBuffer = bEncoder.shift(inputAggBufferOffset) - // Note: although this simply copies aggBufferAttributes, this common code can not be placed // in the superclass because that will lead to initialization ordering issues. override val inputAggBufferAttributes: Seq[AttributeReference] = @@ -121,7 +118,7 @@ case class TypedAggregateExpression( override def update(buffer: MutableRow, input: InternalRow): Unit = { val inputA = boundA.fromRow(input) - val currentB = bEncoderForMutableAggBuffer.fromRow(buffer) + val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val merged = aggregator.reduce(currentB, inputA) val returned = bEncoder.toRow(merged) @@ -129,8 +126,8 @@ case class TypedAggregateExpression( } override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = bEncoderForMutableAggBuffer.fromRow(buffer1) - val b2 = bEncoderForInputAggBuffer.fromRow(buffer2) + val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) + val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) val merged = aggregator.merge(b1, b2) val returned = bEncoder.toRow(merged) @@ -138,7 +135,7 @@ case class TypedAggregateExpression( } override def eval(buffer: InternalRow): Any = { - val b = bEncoderForMutableAggBuffer.fromRow(buffer) + val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) val result = cEncoder.toRow(aggregator.finish(b)) dataType match { case _: StructType => result