From dc0732d73713da6b97750fe2121463c324c516d0 Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 17 Jun 2024 11:03:37 +0800 Subject: [PATCH 1/2] fixup r remove sort --- .../enumerated/EnumeratedTransform.scala | 1 + .../columnar/enumerated/RemoveSort.scala | 61 +++++++++++++++++++ .../org/apache/gluten/ras/path/Pattern.scala | 36 +++++++++-- .../apache/gluten/ras/rule/PatternSuite.scala | 30 ++++++++- 4 files changed, 120 insertions(+), 8 deletions(-) create mode 100644 gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala index 0b9dcc663246..9a54a101453f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/EnumeratedTransform.scala @@ -43,6 +43,7 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean) private val rules = List( new PushFilterToScan(RasOffload.validator), + RemoveSort, RemoveFilter ) diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala new file mode 100644 index 000000000000..5b5d5e541eb7 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveSort.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension.columnar.enumerated + +import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer} +import org.apache.gluten.extension.GlutenPlan +import org.apache.gluten.ras.path.Pattern._ +import org.apache.gluten.ras.path.Pattern.Matchers._ +import org.apache.gluten.ras.rule.{RasRule, Shape} +import org.apache.gluten.ras.rule.Shapes._ + +import org.apache.spark.sql.catalyst.expressions.SortOrder +import org.apache.spark.sql.execution.SparkPlan + +/** + * Removes unnecessary sort if its parent doesn't require for sorted input. + * + * TODO: Sort's removal could be made much simpler once output ordering is added as a physical + * property in RAS planer. + */ +object RemoveSort extends RasRule[SparkPlan] { + private val appliedTypes: Seq[Class[_ <: GlutenPlan]] = + List(classOf[HashAggregateExecBaseTransformer], classOf[ShuffledHashJoinExecTransformerBase]) + + override def shift(node: SparkPlan): Iterable[SparkPlan] = { + assert(node.isInstanceOf[GlutenPlan]) + val newChildren = node.requiredChildOrdering.zip(node.children).map { + case (Nil, sort: SortExecTransformer) => + // Parent doesn't ask for sorted input from this child but a sort op was somehow added. + // Remove it. + sort.child + case (req, child) => + // Parent asks for sorted input from this child. Do nothing but an assertion. + assert(SortOrder.orderingSatisfies(child.outputOrdering, req)) + child + } + val out = List(node.withNewChildren(newChildren)) + out + } + override def shape(): Shape[SparkPlan] = pattern( + branch2[SparkPlan]( + or(appliedTypes.map(clazz[SparkPlan](_)): _*), + _ >= 1, + _ => node(clazz(classOf[GlutenPlan])) + ).build() + ) +} diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala index e60a94717654..be8d2c935b5c 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala @@ -87,14 +87,34 @@ object Pattern { override def children(count: Int): Seq[Node[T]] = (0 until count).map(_ => ignore[T]) } - private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Seq[Node[T]]) + private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Branch.ChildrenFactory[T]) extends Node[T] { override def skip(): Boolean = false - override def abort(node: CanonicalNode[T]): Boolean = node.childrenCount != children.size + override def abort(node: CanonicalNode[T]): Boolean = + !children.acceptsChildrenCount(node.childrenCount) override def matches(node: CanonicalNode[T]): Boolean = matcher(node.self()) override def children(count: Int): Seq[Node[T]] = { - assert(count == children.size) - children + assert(children.acceptsChildrenCount(count)) + (0 until count).map(children.child) + } + } + + private object Branch { + trait ChildrenFactory[T <: AnyRef] { + def child(index: Int): Node[T] + def acceptsChildrenCount(count: Int): Boolean + } + + object ChildrenFactory { + case class Plain[T <: AnyRef](nodes: Seq[Node[T]]) extends ChildrenFactory[T] { + override def child(index: Int): Node[T] = nodes(index) + override def acceptsChildrenCount(count: Int): Boolean = nodes.size == count + } + + case class Func[T <: AnyRef](arity: Int => Boolean, func: Int => Node[T]) extends ChildrenFactory[T] { + override def child(index: Int): Node[T] = func(index) + override def acceptsChildrenCount(count: Int): Boolean = arity(count) + } } } @@ -102,8 +122,12 @@ object Pattern { def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]] def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher) def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] = - Branch(matcher, children.toSeq) - def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher, List.empty) + Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq)) + // Similar to #branch, but with unknown arity. + def branch2[T <: AnyRef](matcher: Matcher[T], arity: Int => Boolean, children: Int => Node[T]): Node[T] = + Branch(matcher, Branch.ChildrenFactory.Func(arity, children)) + def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = + Branch(matcher, Branch.ChildrenFactory.Plain(List.empty)) implicit class NodeImplicits[T <: AnyRef](node: Node[T]) { def build(): Pattern[T] = { diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala index 64b66bbaffae..dc7f5e883022 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/rule/PatternSuite.scala @@ -59,6 +59,29 @@ class PatternSuite extends AnyFunSuite { assert(pattern.matches(path, 1)) } + test("Match branch") { + val ras = + Ras[TestNode]( + PlanModelImpl, + CostModelImpl, + MetadataModelImpl, + PropertyModelImpl, + ExplainImpl, + RasRule.Factory.none()) + + val path1 = MockRasPath.mock(ras, Branch("n1", List())) + val path2 = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1)))) + val path3 = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1), Leaf("n3", 1)))) + + val pattern = + Pattern.branch2[TestNode](n => n.isInstanceOf[Branch], _ >= 1, _ => Pattern.any).build() + assert(!pattern.matches(path1, 1)) + assert(pattern.matches(path2, 1)) + assert(pattern.matches(path2, 2)) + assert(pattern.matches(path3, 1)) + assert(pattern.matches(path3, 2)) + } + test("Match unary") { val ras = Ras[TestNode]( @@ -231,17 +254,20 @@ object PatternSuite { case class Unary(name: String, child: TestNode) extends UnaryLike { override def selfCost(): Long = 1 - override def withNewChildren(child: TestNode): UnaryLike = copy(child = child) } case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike { override def selfCost(): Long = 1 - override def withNewChildren(left: TestNode, right: TestNode): BinaryLike = copy(left = left, right = right) } + case class Branch(name: String, children: Seq[TestNode]) extends TestNode { + override def selfCost(): Long = 1 + override def withNewChildren(children: Seq[TestNode]): TestNode = copy(children = children) + } + case class DummyGroup() extends LeafLike { override def makeCopy(): LeafLike = throw new UnsupportedOperationException() override def selfCost(): Long = throw new UnsupportedOperationException() From 91baa6250ea86dacb5e544ed14767467ee1073ee Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Mon, 17 Jun 2024 13:56:40 +0800 Subject: [PATCH 2/2] fixup --- .../main/scala/org/apache/gluten/ras/path/Pattern.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala index be8d2c935b5c..f54b031b0aef 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/Pattern.scala @@ -111,7 +111,8 @@ object Pattern { override def acceptsChildrenCount(count: Int): Boolean = nodes.size == count } - case class Func[T <: AnyRef](arity: Int => Boolean, func: Int => Node[T]) extends ChildrenFactory[T] { + case class Func[T <: AnyRef](arity: Int => Boolean, func: Int => Node[T]) + extends ChildrenFactory[T] { override def child(index: Int): Node[T] = func(index) override def acceptsChildrenCount(count: Int): Boolean = arity(count) } @@ -124,7 +125,10 @@ object Pattern { def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] = Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq)) // Similar to #branch, but with unknown arity. - def branch2[T <: AnyRef](matcher: Matcher[T], arity: Int => Boolean, children: Int => Node[T]): Node[T] = + def branch2[T <: AnyRef]( + matcher: Matcher[T], + arity: Int => Boolean, + children: Int => Node[T]): Node[T] = Branch(matcher, Branch.ChildrenFactory.Func(arity, children)) def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher, Branch.ChildrenFactory.Plain(List.empty))