@@ -38,7 +38,13 @@ public Count(FieldReference column, boolean isDistinct) {
public boolean isDistinct() { return isDistinct; }

@Override
public String toString() { return "Count(" + column.describe() + "," + isDistinct + ")"; }
public String toString() {
if (isDistinct) {
return "COUNT(DISTINCT " + column.describe() + ")";
} else {
return "COUNT(" + column.describe() + ")";
}
}

@Override
public String describe() { return this.toString(); }
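
The SQL-style toString() (here and in CountStar, Max and Min below) lets the rendered aggregate be dropped straight into a generated query instead of the old debug form. A minimal sketch of the expected rendering, assuming FieldReference("price").describe() yields the bare column name; the column name is illustrative, not from this patch:

    import org.apache.spark.sql.connector.expressions.{Count, FieldReference}
    // FieldReference(String) returns a NamedReference, hence the cast (same idiom as in DataSourceStrategy below).
    val ref = FieldReference("price").asInstanceOf[FieldReference]
    new Count(ref, true).toString   // "COUNT(DISTINCT price)"
    new Count(ref, false).toString  // "COUNT(price)"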
@@ -31,7 +31,7 @@ public CountStar() {
}

@Override
public String toString() { return "CountStar()"; }
public String toString() { return "COUNT(*)"; }

@Override
public String describe() { return this.toString(); }
@@ -33,7 +33,7 @@ public final class Max implements AggregateFunc {
public FieldReference column() { return column; }

@Override
public String toString() { return "Max(" + column.describe() + ")"; }
public String toString() { return "MAX(" + column.describe() + ")"; }

@Override
public String describe() { return this.toString(); }
@@ -33,7 +33,7 @@ public final class Min implements AggregateFunc {
public FieldReference column() { return column; }

@Override
public String toString() { return "Min(" + column.describe() + ")"; }
public String toString() { return "MIN(" + column.describe() + ")"; }

@Override
public String describe() { return this.toString(); }
@@ -18,7 +18,6 @@
package org.apache.spark.sql.connector.expressions;

import org.apache.spark.annotation.Evolving;
- import org.apache.spark.sql.types.DataType;

/**
* An aggregate function that returns the summation of all the values in a group.
@@ -28,22 +27,23 @@
@Evolving
public final class Sum implements AggregateFunc {
private final FieldReference column;
- private final DataType dataType;
private final boolean isDistinct;

- public Sum(FieldReference column, DataType dataType, boolean isDistinct) {
+ public Sum(FieldReference column, boolean isDistinct) {
this.column = column;
- this.dataType = dataType;
this.isDistinct = isDistinct;
}

public FieldReference column() { return column; }
- public DataType dataType() { return dataType; }
public boolean isDistinct() { return isDistinct; }

@Override
public String toString() {
return "Sum(" + column.describe() + "," + dataType + "," + isDistinct + ")";
if (isDistinct) {
return "SUM(DISTINCT " + column.describe() + ")";
} else {
return "SUM(" + column.describe() + ")";
}
}

@Override
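
With the DataType parameter gone, Sum is built from just the column reference and the distinct flag, matching the other aggregate expressions. A small sketch of the new constructor and rendering, with an illustrative column name:

    import org.apache.spark.sql.connector.expressions.{FieldReference, Sum}
    val ref = FieldReference("sales").asInstanceOf[FieldReference]
    new Sum(ref, false).toString  // "SUM(sales)"
    new Sum(ref, true).toString   // "SUM(DISTINCT sales)"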
@@ -714,8 +714,7 @@ object DataSourceStrategy
case _ => None
}
case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
- Some(new Sum(FieldReference(name).asInstanceOf[FieldReference],
- sum.dataType, aggregates.isDistinct))
+ Some(new Sum(FieldReference(name).asInstanceOf[FieldReference], aggregates.isDistinct))
case _ => None
}
} else {
@@ -25,7 +25,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
- import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, FieldReference, Max, Min, Sum}
+ import org.apache.spark.sql.connector.expressions.{AggregateFunc, Count, CountStar, Max, Min, Sum}
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -54,9 +54,14 @@ object JDBCRDD extends Logging {
val url = options.url
val table = options.tableOrQuery
val dialect = JdbcDialects.get(url)
+ getQueryOutputSchema(dialect.getSchemaQuery(table), options, dialect)
+ }
+
+ def getQueryOutputSchema(
+ query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
val conn: Connection = JdbcUtils.createConnectionFactory(options)()
try {
- val statement = conn.prepareStatement(dialect.getSchemaQuery(table))
+ val statement = conn.prepareStatement(query)
try {
statement.setQueryTimeout(options.queryTimeout)
val rs = statement.executeQuery()
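
resolveTable now delegates to the extracted getQueryOutputSchema, which can resolve the output schema of any query rather than only dialect.getSchemaQuery(table). A hedged sketch of how a caller could use it for a rewritten aggregate query; the query text is illustrative only, and dialect, table and options are assumed to be in scope as in resolveTable:

    // Sketch: derive the schema of a pushed-down aggregate query without fetching rows.
    val aggQuery = s"SELECT MAX(${dialect.quoteIdentifier("price")}) FROM $table WHERE 1=0"
    val aggSchema: StructType = JDBCRDD.getQueryOutputSchema(aggQuery, options, dialect)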
@@ -136,30 +141,30 @@

def compileAggregates(
aggregates: Seq[AggregateFunc],
- dialect: JdbcDialect): Seq[String] = {
+ dialect: JdbcDialect): Option[Seq[String]] = {
def quote(colName: String): String = dialect.quoteIdentifier(colName)

- aggregates.map {
+ Some(aggregates.map {
case min: Min =>
- assert(min.column.fieldNames.length == 1)
+ if (min.column.fieldNames.length != 1) return None
s"MIN(${quote(min.column.fieldNames.head)})"
case max: Max =>
- assert(max.column.fieldNames.length == 1)
+ if (max.column.fieldNames.length != 1) return None
s"MAX(${quote(max.column.fieldNames.head)})"
case count: Count =>
- assert(count.column.fieldNames.length == 1)
- val distinct = if (count.isDistinct) "DISTINCT" else ""
+ if (count.column.fieldNames.length != 1) return None
+ val distinct = if (count.isDistinct) "DISTINCT " else ""
val column = quote(count.column.fieldNames.head)
s"COUNT($distinct $column)"
s"COUNT($distinct$column)"
case sum: Sum =>
- assert(sum.column.fieldNames.length == 1)
- val distinct = if (sum.isDistinct) "DISTINCT" else ""
+ if (sum.column.fieldNames.length != 1) return None
+ val distinct = if (sum.isDistinct) "DISTINCT " else ""
val column = quote(sum.column.fieldNames.head)
s"SUM($distinct $column)"
s"SUM($distinct$column)"
case _: CountStar =>
s"COUNT(1)"
case _ => ""
}
s"COUNT(*)"
case _ => return None
})
}
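
compileAggregates now returns Option[Seq[String]] and gives up with None on anything it cannot compile (nested columns, unrecognized functions) instead of asserting or emitting empty strings, so the caller can decide not to push aggregates down at all. Note that `return None` inside the map is a non-local return: it exits compileAggregates itself, not just the closure. A usage sketch with illustrative variable names:

    // Sketch: fall back to a normal scan when any aggregate cannot be compiled.
    JDBCRDD.compileAggregates(translatedAggregates, dialect) match {
      case Some(sqlAggs) => sqlAggs.mkString(", ")  // e.g. MAX("price"), COUNT(DISTINCT "dept")
      case None => ""                               // push nothing down; Spark keeps the Aggregate node
    }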

/**
@@ -185,7 +190,7 @@
parts: Array[Partition],
options: JDBCOptions,
outputSchema: Option[StructType] = None,
- groupByColumns: Option[Array[FieldReference]] = None): RDD[InternalRow] = {
+ groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -221,7 +226,7 @@ private[jdbc] class JDBCRDD(
partitions: Array[Partition],
url: String,
options: JDBCOptions,
- groupByColumns: Option[Array[FieldReference]])
+ groupByColumns: Option[Array[String]])
extends RDD[InternalRow](sc, Nil) {

/**
@@ -266,10 +271,8 @@
*/
private def getGroupByClause: String = {
if (groupByColumns.nonEmpty && groupByColumns.get.nonEmpty) {
- assert(groupByColumns.get.forall(_.fieldNames.length == 1))
- val dialect = JdbcDialects.get(url)
- val quotedColumns = groupByColumns.get.map(c => dialect.quoteIdentifier(c.fieldNames.head))
- s"GROUP BY ${quotedColumns.mkString(", ")}"
+ // The GROUP BY columns should already be quoted by the caller side.
+ s"GROUP BY ${groupByColumns.get.mkString(", ")}"
} else {
""
}
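
getGroupByClause now trusts the caller to pass already-quoted column names, which is why the FieldReference handling and the dialect lookup disappear here. A sketch of what the caller is expected to supply, assuming the default JdbcDialect double-quote identifier quoting; the column names are illustrative:

    val dialect = JdbcDialects.get(url)
    val groupByColumns = Some(Array("dept", "city").map(dialect.quoteIdentifier))
    // getGroupByClause then renders: GROUP BY "dept", "city"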
@@ -27,7 +27,6 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
- import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
@@ -291,9 +290,9 @@ private[sql] case class JDBCRelation(

def buildScan(
requiredColumns: Array[String],
- requireSchema: Option[StructType],
+ finalSchema: StructType,
filters: Array[Filter],
- groupByColumns: Option[Array[FieldReference]]): RDD[Row] = {
+ groupByColumns: Option[Array[String]]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
@@ -302,7 +301,7 @@
filters,
parts,
jdbcOptions,
- requireSchema,
+ Some(finalSchema),
groupByColumns).asInstanceOf[RDD[Row]]
}

@@ -91,7 +91,7 @@ object PushDownUtils extends PredicateHelper {
}

scanBuilder match {
- case r: SupportsPushDownAggregates =>
+ case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
val translatedGroupBys = groupBy.flatMap(columnAsString)

@@ -17,6 +17,8 @@

package org.apache.spark.sql.execution.datasources.v2

+ import scala.collection.mutable
+
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -76,9 +78,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
sHolder.builder match {
case _: SupportsPushDownAggregates =>
+ val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
+ var ordinal = 0
val aggregates = resultExpressions.flatMap { expr =>
expr.collect {
- case agg: AggregateExpression => agg
+ // Do not push down duplicated aggregate expressions. For example,
+ // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+ // `max(a)` to the data source.
+ case agg: AggregateExpression
+ if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+ aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+ ordinal += 1
+ agg
}
}
val pushedAggregates = PushDownUtils
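
The ordinal map keeps semantically equal aggregates from being pushed down twice while remembering which output column each one maps to. A self-contained, simplified sketch of the same bookkeeping (plain strings stand in for canonicalized aggregate expressions; this is not Spark code):

    import scala.collection.mutable
    val seen = mutable.HashMap.empty[String, Int]   // canonical form -> output ordinal
    var ordinal = 0
    val exprs = Seq("max(a)", "max(a)", "min(b)")   // e.g. from SELECT max(a) + 1, max(a) + 2, min(b)
    val deduped = exprs.flatMap { e =>
      if (seen.contains(e)) None
      else { seen(e) = ordinal; ordinal += 1; Some(e) }
    }
    // deduped == Seq("max(a)", "min(b)"); seen == Map("max(a)" -> 0, "min(b)" -> 1)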
@@ -144,19 +155,18 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
// +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
// scalastyle:on
- var i = 0
val aggOutput = output.drop(groupAttrs.length)
plan.transformExpressions {
case agg: AggregateExpression =>
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
val aggFunction: aggregate.AggregateFunction =
agg.aggregateFunction match {
- case max: aggregate.Max => max.copy(child = aggOutput(i))
- case min: aggregate.Min => min.copy(child = aggOutput(i))
- case sum: aggregate.Sum => sum.copy(child = aggOutput(i))
- case _: aggregate.Count => aggregate.Sum(aggOutput(i))
+ case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
+ case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
+ case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
+ case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
case other => other
}
- i += 1
agg.copy(aggregateFunction = aggFunction)
}
}
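
Looking the column up by the canonicalized expression, rather than advancing a counter per rewritten node, keeps every duplicate of the same aggregate pointed at the single pushed-down column. A simplified contrast, not Spark code:

    val aggOutput = Seq("max(a)#21")                  // one column for the deduplicated max(a)
    val ordinals = Map("max(a)" -> 0)                 // from aggExprToOutputOrdinal
    Seq("max(a)", "max(a)").map(e => aggOutput(ordinals(e)))  // both occurrences resolve to max(a)#21
    // The old running counter would have asked for aggOutput(1) on the second occurrence.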
@@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
- import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
@@ -29,7 +28,7 @@ case class JDBCScan(
prunedSchema: StructType,
pushedFilters: Array[Filter],
pushedAggregateColumn: Array[String] = Array(),
- groupByColumns: Option[Array[FieldReference]]) extends V1Scan {
+ groupByColumns: Option[Array[String]]) extends V1Scan {

override def readSchema(): StructType = prunedSchema

@@ -39,13 +38,12 @@
override def schema: StructType = prunedSchema
override def needConversion: Boolean = relation.needConversion
override def buildScan(): RDD[Row] = {
- if (groupByColumns.isEmpty) {
- relation.buildScan(
- prunedSchema.map(_.name).toArray, Some(prunedSchema), pushedFilters, groupByColumns)
+ val columnList = if (groupByColumns.isEmpty) {
+ prunedSchema.map(_.name).toArray
} else {
- relation.buildScan(
- pushedAggregateColumn, Some(prunedSchema), pushedFilters, groupByColumns)
+ pushedAggregateColumn
}
+ relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
}
}.asInstanceOf[T]
}
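
buildScan now funnels both cases through a single relation.buildScan call and only varies the column list: pruned schema columns when nothing is pushed, the compiled aggregate expressions otherwise. An illustrative sketch of the two shapes the list can take (values are not from this patch):

    val withoutAggregates: Array[String] = Array("name", "price")            // pruned schema column names
    val withAggregates: Array[String] = Array("MAX(\"price\")", "COUNT(*)")  // compiled aggregate SQL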