From 33640eca2e9d03fb4f2473ad35db8da9c713926c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Jun 2015 00:24:40 +0800 Subject: [PATCH 1/8] auto alias expressions in analyzer --- .../apache/spark/sql/catalyst/SqlParser.scala | 11 +--- .../sql/catalyst/analysis/Analyzer.scala | 55 +++++++++++-------- .../sql/catalyst/analysis/CheckAnalysis.scala | 9 +-- .../sql/catalyst/analysis/unresolved.scala | 14 +++++ .../catalyst/expressions/ExtractValue.scala | 16 ++++-- .../catalyst/plans/logical/LogicalPlan.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 1 - .../org/apache/spark/sql/DataFrame.scala | 6 +- .../org/apache/spark/sql/GroupedData.scala | 36 +++++------- .../scala/org/apache/spark/sql/TestData.scala | 1 - .../org/apache/spark/sql/hive/HiveQl.scala | 9 +-- 11 files changed, 74 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index da3a717f90058..79f526e823cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -99,13 +99,6 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { protected val WHERE = Keyword("WHERE") protected val WITH = Keyword("WITH") - protected def assignAliases(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - } - protected lazy val start: Parser[LogicalPlan] = start1 | insert | cte @@ -130,8 +123,8 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { val base = r.getOrElse(OneRowRelation) val withFilter = f.map(Filter(_, base)).getOrElse(base) val withProjection = g - .map(Aggregate(_, assignAliases(p), withFilter)) - .getOrElse(Project(assignAliases(p), withFilter)) + .map(Aggregate(_, p.map(UnresolvedAlias(_)), withFilter)) + .getOrElse(Project(p.map(UnresolvedAlias(_)), withFilter)) val withDistinct = d.map(_ => Distinct(withProjection)).getOrElse(withProjection) val withHaving = h.map(Filter(_, withDistinct)).getOrElse(withDistinct) val withOrder = o.map(_(withHaving)).getOrElse(withHaving) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 21b05760256b4..a87b9418a0dd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -70,6 +70,7 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: + ResolveAliases :: ResolveGroupingAnalytics :: ResolveSortReferences :: ResolveGenerate :: @@ -77,7 +78,6 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - TrimGroupingAliases :: typeCoercionRules ++ extendedResolutionRules : _*) ) @@ -131,13 +131,28 @@ class Analyzer( } } - /** - * Removes no-op Alias expressions from the plan. 
- */ - object TrimGroupingAliases extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Aggregate(groups, aggs, child) => - Aggregate(groups.map(_.transform { case Alias(c, _) => c }), aggs, child) + object ResolveAliases extends Rule[LogicalPlan] { + private def assignAliases(exprs: Seq[Expression]) = { + var i = -1 + exprs.map(_ transformDown { + case u @ UnresolvedAlias(child) => + child match { + case ne: NamedExpression => ne + case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() + case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) + case e if !e.resolved => u + case other => + i += 1 + Alias(other, s"c$i")() + } + }).asInstanceOf[Seq[NamedExpression]] + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case Aggregate(groups, aggs, child) if child.resolved => + Aggregate(groups, assignAliases(aggs), child) + case Project(projectList, child) if child.resolved => + Project(assignAliases(projectList), child) } } @@ -228,7 +243,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i@InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => + case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _, _) => i.copy(table = EliminateSubQueries(getTable(u))) case u: UnresolvedRelation => getTable(u) @@ -352,8 +367,12 @@ class Analyzer( q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + val result = withPosition(u) { + q.resolveChildren(nameParts, resolver).map { + case UnresolvedAlias(child) => child + case other => other + }.getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -586,19 +605,7 @@ class Analyzer( /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ private object AliasedGenerator { def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) - if g.resolved && - g.elementTypes.size > 1 && - java.util.regex.Pattern.matches("_c[0-9]+", name) => { - // Assume the default name given by parser is "_c[0-9]+", - // TODO in long term, move the naming logic from Parser to Analyzer. - // In projection, Parser gave default name for TGF as does for normal UDF, - // but the TGF probably have multiple output columns/names. - // e.g. SELECT explode(map(key, value)) FROM src; - // Let's simply ignore the default given name for this case. 
- Some((g, Nil)) - } - case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => + case Alias(g: Generator, name) if g.elementTypes.size > 1 => // If not given the default names, and the TGF with multiple output columns failAnalysis( s"""Expect multiple names given for ${g.getClass.getName}, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7fabd2bfc80ab..c5a1437be6d05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -95,14 +95,7 @@ trait CheckAnalysis { case e => e.children.foreach(checkValidAggregateExpression) } - val cleaned = aggregateExprs.map(_.transform { - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) - case Alias(g, _) => g - }) - - cleaned.foreach(checkValidAggregateExpression) + aggregateExprs.foreach(checkValidAggregateExpression) case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index c9d91425788a8..22f4845a499bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -206,3 +206,17 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def toString: String = s"$child[$extraction]" } + +case class UnresolvedAlias(child: Expression) extends NamedExpression with trees.UnaryNode[Expression] { + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") + override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") + override def exprId: ExprId = throw new UnresolvedException(this, "exprId") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def name: String = throw new UnresolvedException(this, "name") + + override lazy val resolved = false + + override def eval(input: Row = null): Any = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4aaabff15b6ee..741b0c3870603 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -94,16 +94,22 @@ trait ExtractValue extends UnaryExpression { self: Product => } +abstract class ExtractValueWithStruct extends ExtractValue { + self: Product => + + def field: StructField + override def foldable: Boolean = child.foldable + override def toString: String = s"$child.${field.name}" +} + /** * Returns the value of fields in the Struct `child`. 
*/ case class GetStructField(child: Expression, field: StructField, ordinal: Int) - extends ExtractValue { + extends ExtractValueWithStruct { override def dataType: DataType = field.dataType override def nullable: Boolean = child.nullable || field.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[InternalRow] @@ -118,12 +124,10 @@ case class GetArrayStructFields( child: Expression, field: StructField, ordinal: Int, - containsNull: Boolean) extends ExtractValue { + containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) override def nullable: Boolean = child.nullable - override def foldable: Boolean = child.foldable - override def toString: String = s"$child.${field.name}" override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a853e27c1212d..73e8a0c8effb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, EliminateSubQueries, Resolver} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index b4e008a6e8480..f201c8ea8a110 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -21,7 +21,6 @@ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 466258e76f9f6..d3c4489ec6215 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -629,11 +629,7 @@ class DataFrame private[sql]( @scala.annotation.varargs def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { - case Column(expr: NamedExpression) => expr - // Leave an unaliased explode with an empty list of names since the analzyer will generate the - // correct defaults after the nested expression's type has been resolved. 
- case Column(explode: Explode) => MultiAlias(explode, Nil) - case Column(expr: Expression) => Alias(expr, expr.prettyString)() + case Column(expr: Expression) => UnresolvedAlias(expr) } // When user continuously call `select`, speed up analysis by collapsing `Project` import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45b3e1bc627d5..7cc356f68c3a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,27 +70,24 @@ class GroupedData protected[sql]( groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - retainedExprs ++ aggExprs - } else { - aggExprs - } + groupingExprs ++ aggExprs + } else { + aggExprs + } + val aliasedAgg = aggregates.map(UnresolvedAlias(_)) groupType match { case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) } } @@ -112,10 +109,7 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map { c => - val a = f(c) - Alias(a, a.prettyString)() - }) + toDF(columnExprs.map(f)) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -169,8 +163,7 @@ class GroupedData protected[sql]( */ def agg(exprs: Map[String, String]): DataFrame = { toDF(exprs.map { case (colName, expr) => - val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.prettyString)() + strToExpr(expr)(df(colName).expr) }.toSeq) } @@ -224,10 +217,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - }) + toDF((expr +: exprs).map(_.expr)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 520a862ea0838..207d7a352c7b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import java.sql.Timestamp -import 
org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test._ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index ca4b80b51b23f..7c4620952ba4b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -415,13 +415,6 @@ private[hive] object HiveQl { throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") } - protected def nameExpressions(exprs: Seq[Expression]): Seq[NamedExpression] = { - exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"_c$i")() - } - } - protected def extractDbNameTableName(tableNameParts: Node): (Option[String], String) = { val (db, tableName) = tableNameParts.getChildren.map { case Token(part, Nil) => cleanIdentifier(part) } match { @@ -942,7 +935,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { val selectExpressions = - nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) + select.getChildren.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)).toSeq Seq( groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => From 39c1aef288e1cbd149dfd0d5359960952046992e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Jun 2015 09:45:35 +0800 Subject: [PATCH 2/8] small fix --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 1 + .../org/apache/spark/sql/catalyst/analysis/unresolved.scala | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a87b9418a0dd9..28040509a40e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -137,6 +137,7 @@ class Analyzer( exprs.map(_ transformDown { case u @ UnresolvedAlias(child) => child match { + case _: UnresolvedAttribute => u case ne: NamedExpression => ne case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 22f4845a499bb..341bf56b81b59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -207,7 +207,9 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def toString: String = s"$child[$extraction]" } -case class UnresolvedAlias(child: Expression) extends NamedExpression with trees.UnaryNode[Expression] { +case class UnresolvedAlias(child: Expression) extends NamedExpression + with trees.UnaryNode[Expression] { + override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") override def qualifiers: Seq[String] = throw new UnresolvedException(this, "qualifiers") override def exprId: ExprId = throw new UnresolvedException(this, 
"exprId") From 9f073597ff19fba455fec6cebb135cf8ad66f292 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 5 Jun 2015 17:48:36 +0800 Subject: [PATCH 3/8] address comments --- .../sql/catalyst/analysis/Analyzer.scala | 9 +++-- .../sql/catalyst/analysis/unresolved.scala | 3 ++ .../sql/catalyst/planning/patterns.scala | 11 +++--- .../org/apache/spark/sql/GroupedData.scala | 36 ++++++++++++------- 4 files changed, 37 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 28040509a40e1..19c6dc4d9a363 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -131,10 +131,15 @@ class Analyzer( } } + /** + * Replaces [[UnresolvedAlias]]s with concrete aliases. + */ object ResolveAliases extends Rule[LogicalPlan] { private def assignAliases(exprs: Seq[Expression]) = { var i = -1 - exprs.map(_ transformDown { + // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need + // to transform down the whole tree. + exprs.map { case u @ UnresolvedAlias(child) => child match { case _: UnresolvedAttribute => u @@ -146,7 +151,7 @@ class Analyzer( i += 1 Alias(other, s"c$i")() } - }).asInstanceOf[Seq[NamedExpression]] + } } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 341bf56b81b59..aee2c04b04a38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -207,6 +207,9 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) override def toString: String = s"$child[$extraction]" } +/** + * Holds the expression that has yet to be aliased. + */ case class UnresolvedAlias(child: Expression) extends NamedExpression with trees.UnaryNode[Expression] { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 3b6f8bfd9ff9b..8d79585e65a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -156,13 +156,10 @@ object PartialAggregation { partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression => - // Should trim aliases around `GetField`s. These aliases are introduced while - // resolving struct field accesses, because `GetField` is not a `NamedExpression`. - // (Should we just turn `GetField` into a `NamedExpression`?) 
- val trimmed = e.transform { case Alias(g: ExtractValue, _) => g } - namedGroupingExpressions.collectFirst { - case (expr, ne) if expr semanticEquals trimmed => ne.toAttribute - }.getOrElse(e) + namedGroupingExpressions + .find { case (k, v) => k semanticEquals e } + .map(_._2.toAttribute) + .getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] val partialComputation = namedGroupingExpressions.map(_._2) ++ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 7cc356f68c3a0..45b3e1bc627d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, Star} +import org.apache.spark.sql.catalyst.analysis.Star import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -70,24 +70,27 @@ class GroupedData protected[sql]( groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { + private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - groupingExprs ++ aggExprs - } else { - aggExprs - } + val retainedExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } + retainedExprs ++ aggExprs + } else { + aggExprs + } - val aliasedAgg = aggregates.map(UnresolvedAlias(_)) groupType match { case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) } } @@ -109,7 +112,10 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map(f)) + toDF(columnExprs.map { c => + val a = f(c) + Alias(a, a.prettyString)() + }) } private[this] def strToExpr(expr: String): (Expression => Expression) = { @@ -163,7 +169,8 @@ class GroupedData protected[sql]( */ def agg(exprs: Map[String, String]): DataFrame = { toDF(exprs.map { case (colName, expr) => - strToExpr(expr)(df(colName).expr) + val a = strToExpr(expr)(df(colName).expr) + Alias(a, a.prettyString)() }.toSeq) } @@ -217,7 +224,10 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) + toDF((expr +: exprs).map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + }) } /** From d18f401d7cbcf27fa3447bc5f5fe8e01662c13bd Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sat, 6 Jun 2015 11:23:58 +0800 Subject: [PATCH 4/8] refine --- .../sql/catalyst/analysis/Analyzer.scala | 14 +++---- .../catalyst/expressions/ExtractValue.scala | 2 - .../sql/catalyst/planning/patterns.scala | 7 ++-- 
.../catalyst/plans/logical/LogicalPlan.scala | 9 ++--- .../org/apache/spark/sql/GroupedData.scala | 37 ++++++++----------- 5 files changed, 28 insertions(+), 41 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 19c6dc4d9a363..7b1d6f8528ca1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -135,22 +135,20 @@ class Analyzer( * Replaces [[UnresolvedAlias]]s with concrete aliases. */ object ResolveAliases extends Rule[LogicalPlan] { - private def assignAliases(exprs: Seq[Expression]) = { - var i = -1 + private def assignAliases(exprs: Seq[NamedExpression]) = { // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need // to transform down the whole tree. - exprs.map { - case u @ UnresolvedAlias(child) => + exprs.zipWithIndex.map { + case (u @ UnresolvedAlias(child), i) => child match { case _: UnresolvedAttribute => u case ne: NamedExpression => ne case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)() case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil) case e if !e.resolved => u - case other => - i += 1 - Alias(other, s"c$i")() + case other => Alias(other, s"_c$i")() } + case (other, _) => other } } @@ -611,7 +609,7 @@ class Analyzer( /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */ private object AliasedGenerator { def unapply(e: Expression): Option[(Generator, Seq[String])] = e match { - case Alias(g: Generator, name) if g.elementTypes.size > 1 => + case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 => // If not given the default names, and the TGF with multiple output columns failAnalysis( s"""Expect multiple names given for ${g.getClass.getName}, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 741b0c3870603..f65a107924ec5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -98,7 +98,6 @@ abstract class ExtractValueWithStruct extends ExtractValue { self: Product => def field: StructField - override def foldable: Boolean = child.foldable override def toString: String = s"$child.${field.name}" } @@ -127,7 +126,6 @@ case class GetArrayStructFields( containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8d79585e65a36..179a348d5baac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -156,10 +156,9 @@ object PartialAggregation { partialEvaluations(new TreeNodeRef(e)).finalEvaluation case e: Expression => - namedGroupingExpressions - .find { case (k, 
v) => k semanticEquals e } - .map(_._2.toAttribute) - .getOrElse(e) + namedGroupingExpressions.collectFirst { + case (expr, ne) if expr semanticEquals e => ne.toAttribute + }.getOrElse(e) }).asInstanceOf[Seq[NamedExpression]] val partialComputation = namedGroupingExpressions.map(_._2) ++ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 73e8a0c8effb7..b009a200b920f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { // One match, but we also need to extract the requested nested field. case Seq((a, nestedFields)) => // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliases it with the last part of the identifier. + // and wrap it with UnresolvedAlias which will be removed later. // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias - // the final expression as "c". + // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as + // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))). val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => ExtractValue(expr, Literal(fieldName), resolver)) - val aliasName = nestedFields.last - Some(Alias(fieldExprs, aliasName)()) + Some(UnresolvedAlias(fieldExprs)) // No matches. case Seq() => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 45b3e1bc627d5..859224d263ec2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -70,27 +70,27 @@ class GroupedData protected[sql]( groupingExprs: Seq[Expression], private val groupType: GroupedData.GroupType) { - private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = { + private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = { val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) { - val retainedExprs = groupingExprs.map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - } - retainedExprs ++ aggExprs - } else { - aggExprs - } + groupingExprs ++ aggExprs + } else { + aggExprs + } + val aliasedAgg = aggregates.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.prettyString)() + } groupType match { case GroupedData.GroupByType => DataFrame( - df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan)) + df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates)) + df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) } } @@ -112,10 +112,7 @@ class GroupedData protected[sql]( namedExpr } } - toDF(columnExprs.map { c => - val a = f(c) - Alias(a, a.prettyString)() - }) + toDF(columnExprs.map(f)) } private[this] def strToExpr(expr: String): 
(Expression => Expression) = { @@ -169,8 +166,7 @@ class GroupedData protected[sql]( */ def agg(exprs: Map[String, String]): DataFrame = { toDF(exprs.map { case (colName, expr) => - val a = strToExpr(expr)(df(colName).expr) - Alias(a, a.prettyString)() + strToExpr(expr)(df(colName).expr) }.toSeq) } @@ -224,10 +220,7 @@ class GroupedData protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr).map { - case expr: NamedExpression => expr - case expr: Expression => Alias(expr, expr.prettyString)() - }) + toDF((expr +: exprs).map(_.expr)) } /** From 4cfd23cdbabef7fa90c7976824e1d1c7f9c074ce Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 18 Jun 2015 23:18:17 +0800 Subject: [PATCH 5/8] fix order by --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++-------- .../sql/catalyst/analysis/unresolved.scala | 2 +- .../org/apache/spark/sql/DataFrame.scala | 2 +- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7b1d6f8528ca1..8b5554a155e89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.catalyst.expressions._ @@ -153,9 +151,11 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case Aggregate(groups, aggs, child) if child.resolved => + case Aggregate(groups, aggs, child) + if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => Aggregate(groups, assignAliases(aggs), child) - case Project(projectList, child) if child.resolved => + case Project(projectList, child) + if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => Project(assignAliases(projectList), child) } } @@ -371,12 +371,10 @@ class Analyzer( q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(nameParts) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = withPosition(u) { - q.resolveChildren(nameParts, resolver).map { - case UnresolvedAlias(child) => child - case other => other - }.getOrElse(u) - } + val result = + withPosition(u) { + q.resolveChildren(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) + } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -402,6 +400,11 @@ class Analyzer( exprs.exists(_.collect { case _: Star => true }.nonEmpty) } + private def trimUnresolvedAlias(ne: NamedExpression) = ne match { + case UnresolvedAlias(child) => child + case other => other + } + private def resolveSortOrders(ordering: Seq[SortOrder], plan: LogicalPlan, throws: Boolean) = { ordering.map { order => // Resolve SortOrder in one round. 
@@ -411,7 +414,7 @@ class Analyzer( try { val newOrder = order transformUp { case u @ UnresolvedAttribute(nameParts) => - plan.resolve(nameParts, resolver).getOrElse(u) + plan.resolve(nameParts, resolver).map(trimUnresolvedAlias).getOrElse(u) case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index aee2c04b04a38..c51c4dc83f51a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -222,6 +222,6 @@ case class UnresolvedAlias(child: Expression) extends NamedExpression override lazy val resolved = false - override def eval(input: Row = null): Any = + override def eval(input: InternalRow = null): Any = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d3c4489ec6215..184d019923793 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{MultiAlias, ResolvedStar, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, ResolvedStar, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} From 73a90cb72cb036e5f44e0a64aa741da1f5778bc2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jun 2015 00:20:27 +0800 Subject: [PATCH 6/8] fix case-preserve of ExtractValue --- .../sql/catalyst/expressions/ExtractValue.scala | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index f65a107924ec5..9832207ee940c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.{catalyst, AnalysisException} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -41,12 +41,14 @@ object ExtractValue { resolver: Resolver): ExtractValue = { (child.dataType, extraction) match { - case (StructType(fields), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetStructField(child, fields(ordinal), ordinal) - case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) + case (StructType(fields), NonNullLiteral(v, StringType)) 
=> + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) case (_: MapType, _) => From 5b5786d6dc2c672e450e511f5a7b6424f21d3377 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jun 2015 02:14:27 +0800 Subject: [PATCH 7/8] fix agg --- .../sql/catalyst/analysis/Analyzer.scala | 17 ++++++++++------ .../sql/catalyst/analysis/unresolved.scala | 1 - .../catalyst/expressions/ExtractValue.scala | 4 ++++ .../plans/logical/basicOperators.scala | 20 ++++++++++++++++--- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8b5554a155e89..39aa32b954043 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -154,6 +154,11 @@ class Analyzer( case Aggregate(groups, aggs, child) if child.resolved && aggs.exists(_.isInstanceOf[UnresolvedAlias]) => Aggregate(groups, assignAliases(aggs), child) + + case g: GroupingAnalytics + if g.child.resolved && g.aggregations.exists(_.isInstanceOf[UnresolvedAlias]) => + g.withNewAggs(assignAliases(g.aggregations)) + case Project(projectList, child) if child.resolved && projectList.exists(_.isInstanceOf[UnresolvedAlias]) => Project(assignAliases(projectList), child) @@ -267,24 +272,24 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) - case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(child = f.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateArray(args), name) if containsStar(args) => + UnresolvedAlias(child = f.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil - case Alias(c @ CreateStruct(args), name) if containsStar(args) => + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil + case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child.output, resolver) case o => o :: Nil } - Alias(c.copy(children = expandedArgs), name)() :: Nil + UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil case o => o :: Nil }, child) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index c51c4dc83f51a..ae3adbab05108 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala 
@@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.{errors, trees} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 9832207ee940c..013027b199e63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -45,14 +45,18 @@ object ExtractValue { val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) + case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) + case (_: MapType, _) => GetMapValue(child, extraction) + case (otherType, _) => val errorMsg = otherType match { case StructType(_) | ArrayType(StructType(_), _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 963c7820914f3..f8e5916d69f9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -242,6 +242,8 @@ trait GroupingAnalytics extends UnaryNode { def aggregations: Seq[NamedExpression] override def output: Seq[Attribute] = aggregations.map(_.toAttribute) + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics } /** @@ -266,7 +268,11 @@ case class GroupingSets( groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -284,7 +290,11 @@ case class Cube( groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} /** * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, @@ -303,7 +313,11 @@ case class Rollup( groupByExprs: Seq[Expression], child: LogicalPlan, aggregations: Seq[NamedExpression], - gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics + gid: AttributeReference = VirtualColumn.newGroupingId) extends GroupingAnalytics { + + def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = + this.copy(aggregations = aggs) +} case class Limit(limitExpr: Expression, child: 
LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output From 552eba4d7921f4a863bcae5f37725ccd3793ad3d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jun 2015 10:01:43 +0800 Subject: [PATCH 8/8] fix python --- python/pyspark/sql/context.py | 9 +++++---- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../main/scala/org/apache/spark/sql/DataFrame.scala | 12 ++++++++++-- .../scala/org/apache/spark/sql/GroupedData.scala | 6 +++++- .../org/apache/spark/sql/execution/pythonUdfs.scala | 2 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 6 +++--- 6 files changed, 25 insertions(+), 12 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 599c9ac5794a2..dc239226e6d3c 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -86,7 +86,8 @@ def __init__(self, sparkContext, sqlContext=None): >>> df.registerTempTable("allTypes") >>> sqlContext.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() - [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0, time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] + [Row(_c0=2, _c1=2.0, _c2=False, _c3=2, _c4=0, \ + time=datetime.datetime(2014, 8, 1, 14, 1, 5), a=1)] >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, datetime.datetime(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -176,17 +177,17 @@ def registerFunction(self, name, f, returnType=StringType()): >>> sqlContext.registerFunction("stringLengthString", lambda x: len(x)) >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(c0=u'4')] + [Row(_c0=u'4')] >>> from pyspark.sql.types import IntegerType >>> sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] >>> from pyspark.sql.types import IntegerType >>> sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() - [Row(c0=4)] + [Row(_c0=4)] """ func = lambda _, it: map(lambda x: f(*x), it) ser = AutoBatchedSerializer(PickleSerializer()) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 39aa32b954043..6311784422a91 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -68,11 +68,11 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: - ResolveAliases :: ResolveGroupingAnalytics :: ResolveSortReferences :: ResolveGenerate :: ResolveFunctions :: + ResolveAliases :: ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 184d019923793..492a3321bc0bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -32,7 +32,7 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, ResolvedStar, 
UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, _} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -629,7 +629,15 @@ class DataFrame private[sql]( @scala.annotation.varargs def select(cols: Column*): DataFrame = { val namedExpressions = cols.map { - case Column(expr: Expression) => UnresolvedAlias(expr) + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case Column(u: UnresolvedAttribute) => UnresolvedAlias(u) + case Column(expr: NamedExpression) => expr + // Leave an unaliased explode with an empty list of names since the analzyer will generate the + // correct defaults after the nested expression's type has been resolved. + case Column(explode: Explode) => MultiAlias(explode, Nil) + case Column(expr: Expression) => Alias(expr, expr.prettyString)() } // When user continuously call `select`, speed up analysis by collapsing `Project` import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 859224d263ec2..99d557b03a033 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConversions._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.Star +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, Star} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate} import org.apache.spark.sql.types.NumericType @@ -78,6 +78,10 @@ class GroupedData protected[sql]( } val aliasedAgg = aggregates.map { + // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we + // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to + // make it a NamedExpression. + case u: UnresolvedAttribute => UnresolvedAlias(u) case expr: NamedExpression => expr case expr: Expression => Alias(expr, expr.prettyString)() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 1ce150ceaf5f9..c8c67ce334002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -74,7 +74,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan - case plan: LogicalPlan => + case plan: LogicalPlan if plan.resolved => // Extract any PythonUDFs from the current operator. 
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) if (udfs.isEmpty) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4441afd6bd811..73bc6c999164e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1367,9 +1367,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-6145: special cases") { sqlContext.read.json(sqlContext.sparkContext.makeRDD( - """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") - checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) - checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) + """{"a": {"b": [1]}, "b": [{"a": 1}], "_c0": {"a": 1}}""" :: Nil)).registerTempTable("t") + checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY _c0.a"), Row(1)) + checkAnswer(sql("SELECT b[0].a FROM t ORDER BY _c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") {
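
Net effect of the series: eager aliasing moves out of SqlParser, HiveQl, DataFrame.select and GroupedData into a single analyzer rule. Callers now wrap unnamed expressions in UnresolvedAlias, and ResolveAliases later picks the concrete name: an existing NamedExpression keeps its name, a struct-field extraction is named after the field, a multi-column generator becomes a MultiAlias, anything still unresolved stays wrapped for the next round, and everything else gets a positional default. PATCH 4 switches that default prefix from c<i> to _c<i>, which is why the pyspark doctests and the SPARK-6145 test in PATCH 8 now expect _c0. Below is a minimal standalone sketch of that naming policy, for illustration only — Expr, Attr, GetField, Add, UnresolvedAlias and Alias are simplified stand-ins defined here, not Spark's catalyst classes:

    // Simplified stand-in expression tree (not org.apache.spark.sql.catalyst).
    sealed trait Expr
    case class Attr(name: String) extends Expr                   // already named
    case class GetField(child: Expr, field: String) extends Expr // struct extraction
    case class Add(l: Expr, r: Expr) extends Expr                // an unnamed computation
    case class UnresolvedAlias(child: Expr) extends Expr         // "name me later" marker
    case class Alias(child: Expr, name: String) extends Expr     // concrete alias

    object AliasSketch {
      // Mirrors the shape of ResolveAliases.assignAliases after PATCH 3/4:
      // UnresolvedAlias appears only at the root of each expression, so one
      // top-level match suffices — no full-tree transform. The real rule has
      // two more cases this sketch omits: unresolved children stay wrapped
      // for the next fixed-point round, and multi-column generators get a
      // MultiAlias instead of a single Alias.
      def assignAliases(exprs: Seq[Expr]): Seq[Expr] =
        exprs.zipWithIndex.map {
          case (UnresolvedAlias(child), i) =>
            child match {
              case a: Attr            => a                     // keep existing name
              case g @ GetField(_, f) => Alias(g, f)           // name after the field
              case other              => Alias(other, s"_c$i") // positional default
            }
          case (other, _) => other
        }

      def main(args: Array[String]): Unit = {
        val projectList: Seq[Expr] = Seq(
          UnresolvedAlias(Attr("key")),
          UnresolvedAlias(GetField(Attr("person"), "age")),
          UnresolvedAlias(Add(Attr("a"), Attr("b"))))
        assignAliases(projectList).foreach(println)
        // Attr(key)
        // Alias(GetField(Attr(person),age),age)
        // Alias(Add(Attr(a),Attr(b)),_c2)
      }
    }

The third expression landing as _c2 (indexed by position in the project list) matches the user-visible behavior the series tests for, e.g. sqlContext.sql("SELECT stringLengthInt('test')").collect() returning [Row(_c0=4)] in the updated context.py doctests above.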