From d7ff8f4e346dd5a42061af67a7606283ed62a40a Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 15 May 2015 13:48:05 +0800
Subject: [PATCH 1/3] fix 7269

---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 15 +++++++--------
 .../sql/catalyst/analysis/CheckAnalysis.scala      |  4 ++--
 .../sql/catalyst/expressions/Expression.scala      |  6 ++++++
 .../expressions/namedExpressions.scala             |  5 +++++
 .../spark/sql/catalyst/planning/patterns.scala     |  5 +++--
 .../sql/hive/execution/SQLQuerySuite.scala         | 18 ++++++++++++++++++
 6 files changed, 41 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0b6e1d44b9c4d..394e7282da378 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types._
-import org.apache.spark.util.collection.OpenHashSet
 
 /**
  * A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -148,17 +147,17 @@ class Analyzer(
      * @param exprs the attributes in sequence
      * @return the attributes of non selected specified via bitmask (with the bit set to 1)
      */
-    private def buildNonSelectExprSet(bitmask: Int, exprs: Seq[Expression])
-    : OpenHashSet[Expression] = {
-      val set = new OpenHashSet[Expression](2)
+    private def buildNonSelectExprs(bitmask: Int, exprs: Seq[Expression])
+    : Seq[Expression] = {
+      val buffer = ArrayBuffer.empty[Expression]
 
       var bit = exprs.length - 1
       while (bit >= 0) {
-        if (((bitmask >> bit) & 1) == 0) set.add(exprs(bit))
+        if (((bitmask >> bit) & 1) == 0) buffer += exprs(bit)
         bit -= 1
       }
 
-      set
+      buffer
     }
 
     /*
@@ -197,10 +196,10 @@ class Analyzer(
 
         g.bitmasks.foreach { bitmask =>
           // get the non selected grouping attributes according to the bit mask
-          val nonSelectedGroupExprSet = buildNonSelectExprSet(bitmask, g.groupByExprs)
+          val nonSelectedGroupExprs = buildNonSelectExprs(bitmask, g.groupByExprs)
 
           val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
-            case x: Expression if nonSelectedGroupExprSet.contains(x) =>
+            case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
               // if the input attribute in the Invalid Grouping Expression set of for this group
               // replace it with constant null
               Literal.create(null, expr.dataType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index f104e742c90fe..06a0504359f6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -86,12 +86,12 @@ trait CheckAnalysis {
         case Aggregate(groupingExprs, aggregateExprs, child) =>
           def checkValidAggregateExpression(expr: Expression): Unit = expr match {
             case _: AggregateExpression => // OK
-            case e: Attribute if !groupingExprs.contains(e) =>
+            case e: Attribute if groupingExprs.find(_ semanticEquals e).isEmpty =>
               failAnalysis(
                 s"expression '${e.prettyString}' is neither present in the group by, " +
                   s"nor is it an aggregate function. " +
                   "Add to group by or wrap in first() if you don't care which value you get.")
-            case e if groupingExprs.contains(e) => // OK
+            case e if groupingExprs.find(_ semanticEquals e).isDefined => // OK
             case e if e.references.isEmpty => // OK
             case e => e.children.foreach(checkValidAggregateExpression)
           }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 0837a3179d897..04c209f674064 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -76,6 +76,12 @@ abstract class Expression extends TreeNode[Expression] {
       case u: UnresolvedAttribute => PrettyAttribute(u.name)
     }.toString
   }
+
+  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass &&
+    this.productIterator.zip(other.asInstanceOf[Product].productIterator).forall {
+      case (e1: Expression, e2: Expression) => e1 semanticEquals e2
+      case (i1, i2) => i1 == i2
+    }
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a9170589f8c6c..8609f441d7661 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -181,6 +181,11 @@ case class AttributeReference(
     case _ => false
   }
 
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case ar: AttributeReference => exprId == ar.exprId
+    case _ => false
+  }
+
   override def hashCode: Int = {
     // See http://stackoverflow.com/questions/113511/hash-code-implementation
     var h = 17
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index cd54d04814ea4..1dd75a8846303 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -159,9 +159,10 @@ object PartialAggregation {
           // Should trim aliases around `GetField`s. These aliases are introduced while
           // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
           // (Should we just turn `GetField` into a `NamedExpression`?)
+          val trimmed = e.transform { case Alias(g: ExtractValue, _) => g }
           namedGroupingExpressions
-            .get(e.transform { case Alias(g: ExtractValue, _) => g })
-            .map(_.toAttribute)
+            .find { case (k, v) => k semanticEquals trimmed }
+            .map(_._2.toAttribute)
             .getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]
 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 5c7152e2140db..b53b07fb5d7fb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -773,4 +773,22 @@ class SQLQuerySuite extends QueryTest {
         | select * from v2 order by key limit 1
       """.stripMargin), Row(0, 3))
   }
+
+  test("SPARK-7269 Check analysis failed in case in-sensitive") {
+    Seq(1, 2, 3).map { i =>
+      (i.toString, i.toString)
+    }.toDF("key", "value").registerTempTable("df_analysis")
+    sql("SELECT kEy from df_analysis group by key").collect()
+    sql("SELECT kEy+3 from df_analysis group by key+3").collect()
+    sql("SELECT kEy+3, a.kEy, A.kEy from df_analysis A group by key").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by cast(key+1 as int)").collect()
+    sql("SELECT cast(kEy+1 as Int) from df_analysis A group by key+1").collect()
+    sql("SELECT 2 from df_analysis A group by key+1").collect()
+    intercept[AnalysisException] {
+      sql("SELECT kEy+1 from df_analysis group by key+3")
+    }
+    intercept[AnalysisException] {
+      sql("SELECT cast(key+2 as Int) from df_analysis A group by cast(key+1 as int)")
+    }
+  }
 }
From cc0204564edc4b4ba59a88e744dac8d4166e7da2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Fri, 15 May 2015 14:40:16 +0800
Subject: [PATCH 2/3] consider elements length equal

---
 .../apache/spark/sql/catalyst/expressions/Expression.scala | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 04c209f674064..b981adba82882 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -77,11 +77,14 @@ abstract class Expression extends TreeNode[Expression] {
     }.toString
   }
 
-  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass &&
-    this.productIterator.zip(other.asInstanceOf[Product].productIterator).forall {
+  def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
+    val elements1 = this.productIterator.toSeq
+    val elements2 = other.asInstanceOf[Product].productIterator.toSeq
+    elements1.length == elements2.length && elements1.zip(elements2).forall {
       case (e1: Expression, e2: Expression) => e1 semanticEquals e2
       case (i1, i2) => i1 == i2
     }
+  }
 }
 
 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
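
Patch 2 above only tightens semanticEquals: both productIterators are materialized and their lengths compared before zipping. For two objects of the same class the arity will normally match anyway, so this is a cheap defensive check; the pitfall it guards against is that zip silently drops the longer side's tail. A self-contained illustration with plain Seqs, nothing Spark-specific (the ZipLengthSketch name is just for this example):

    object ZipLengthSketch {
      def main(args: Array[String]): Unit = {
        val a = Seq(1, 2, 3)
        val b = Seq(1, 2)

        // zip stops at the shorter collection, so the unmatched 3 is never compared
        println(a.zip(b).forall { case (x, y) => x == y })  // true
        // adding the length check restores the intended "all elements equal" meaning
        println(a.length == b.length &&
          a.zip(b).forall { case (x, y) => x == y })        // false
      }
    }
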
From e4a3cc7cdaf7e971734c0dcf01fe6316f8df51d9 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Sat, 16 May 2015 12:20:23 +0800
Subject: [PATCH 3/3] address comments

---
 .../sql/catalyst/analysis/Analyzer.scala           | 26 +++++--------------
 .../sql/catalyst/expressions/Expression.scala      |  4 ++++
 .../expressions/namedExpressions.scala             |  2 +-
 3 files changed, 11 insertions(+), 21 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 394e7282da378..dfa4215f2efe5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -141,25 +141,6 @@ class Analyzer(
   }
 
   object ResolveGroupingAnalytics extends Rule[LogicalPlan] {
-    /**
-     * Extract attribute set according to the grouping id
-     * @param bitmask bitmask to represent the selected of the attribute sequence
-     * @param exprs the attributes in sequence
-     * @return the attributes of non selected specified via bitmask (with the bit set to 1)
-     */
-    private def buildNonSelectExprs(bitmask: Int, exprs: Seq[Expression])
-    : Seq[Expression] = {
-      val buffer = ArrayBuffer.empty[Expression]
-
-      var bit = exprs.length - 1
-      while (bit >= 0) {
-        if (((bitmask >> bit) & 1) == 0) buffer += exprs(bit)
-        bit -= 1
-      }
-
-      buffer
-    }
-
     /*
      * GROUP BY a, b, c WITH ROLLUP
      * is equivalent to
@@ -196,7 +177,12 @@ class Analyzer(
 
         g.bitmasks.foreach { bitmask =>
           // get the non selected grouping attributes according to the bit mask
-          val nonSelectedGroupExprs = buildNonSelectExprs(bitmask, g.groupByExprs)
+          val nonSelectedGroupExprs = ArrayBuffer.empty[Expression]
+          var bit = g.groupByExprs.length - 1
+          while (bit >= 0) {
+            if (((bitmask >> bit) & 1) == 0) nonSelectedGroupExprs += g.groupByExprs(bit)
+            bit -= 1
+          }
 
           val substitution = (g.child.output :+ g.gid).map(expr => expr transformDown {
             case x: Expression if nonSelectedGroupExprs.find(_ semanticEquals x).isDefined =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index b981adba82882..8995002c677f9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -77,6 +77,10 @@ abstract class Expression extends TreeNode[Expression] {
     }.toString
   }
 
+  /**
+   * Returns true if 2 expressions are equal in semantic, which is similar to equals method
+   * but has different definition on some leaf expressions like AttributeReference.
+   */
   def semanticEquals(other: Expression): Boolean = this.getClass == other.getClass && {
     val elements1 = this.productIterator.toSeq
     val elements2 = other.asInstanceOf[Product].productIterator.toSeq
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 8609f441d7661..50be26d0b08b5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -182,7 +182,7 @@ case class AttributeReference(
   }
 
   override def semanticEquals(other: Expression): Boolean = other match {
-    case ar: AttributeReference => exprId == ar.exprId
+    case ar: AttributeReference => sameRef(ar)
     case _ => false
   }
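
Patch 3 also swaps the inline exprId comparison in AttributeReference.semanticEquals for a call to the pre-existing sameRef helper, presumably keeping the same "same reference" meaning while reusing the existing definition. Taken together, the three patches make GROUP BY analysis tolerant of case-insensitive resolution, which is what the new SQLQuerySuite test exercises. Below is a hedged end-to-end sketch of the same scenario outside the test harness, using the Spark 1.4-era API the test relies on; it assumes a HiveContext (case-insensitive resolution, as in the Hive test suite above), and the Spark7269Example name, app name, master and table name are illustrative only.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    object Spark7269Example {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setAppName("spark-7269-demo").setMaster("local[2]"))
        val sqlContext = new HiveContext(sc)
        import sqlContext.implicits._

        // Same toy data as the test: a two-column string table.
        sc.parallelize(1 to 3)
          .map(i => (i.toString, i.toString))
          .toDF("key", "value")
          .registerTempTable("df_analysis")

        // Before this series, such a query failed analysis with "expression 'kEy' is
        // neither present in the group by, nor is it an aggregate function"; with
        // semanticEquals the differently-capitalized reference matches the grouping
        // expression, since both resolve to the same attribute (same exprId).
        sqlContext.sql("SELECT kEy + 3 FROM df_analysis GROUP BY key + 3").show()

        sc.stop()
      }
    }
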