diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
index 3c9b5b2e9f..79e5a6dd3c 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala
@@ -1415,6 +1415,10 @@ object Utils extends Logging {
         }
         (Seq(countUpdateExpr), Seq(count))
       }
+      case PartialMerge => {
+        val countUpdateExpr = Add(count, c.inputAggBufferAttributes(0))
+        (Seq(countUpdateExpr), Seq(count))
+      }
       case Final => {
         val countUpdateExpr = Add(count, c.inputAggBufferAttributes(0))
         (Seq(countUpdateExpr), Seq(count))
@@ -1423,7 +1427,7 @@
         val countUpdateExpr = Add(count, Literal(1L))
         (Seq(countUpdateExpr), Seq(count))
       }
-      case _ => 
+      case _ =>
     }
 
     tuix.AggregateExpr.createAggregateExpr(
@@ -1594,6 +1598,11 @@
           val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum)
           (Seq(sumUpdateExpr), Seq(sum))
         }
+        case PartialMerge => {
+          val partialSum = Add(If(IsNull(sum), Literal.default(sumDataType), sum), s.inputAggBufferAttributes(0))
+          val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum)
+          (Seq(sumUpdateExpr), Seq(sum))
+        }
        case Final => {
           val partialSum = Add(If(IsNull(sum), Literal.default(sumDataType), sum), s.inputAggBufferAttributes(0))
           val sumUpdateExpr = If(IsNull(partialSum), sum, partialSum)
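The new `PartialMerge` arms are deliberately identical to the existing `Final` arms: in both modes the enclave consumes another task's serialized aggregation buffer rather than raw input rows. A minimal standalone sketch of the mode semantics for `COUNT` (the `CountState`, `update`, and `merge` names are illustrative, not from this patch):

```scala
// Toy model of Spark's aggregate modes for COUNT, to motivate why the
// PartialMerge arms above duplicate the Final arms. All names here are
// hypothetical; only the mode semantics mirror the patch.
object CountModes {
  final case class CountState(count: Long)

  // Partial / Complete: fold original input rows into the buffer,
  // skipping NULL inputs (the Add(count, Literal(1L)) update arm).
  def update(state: CountState, row: Option[Any]): CountState =
    row.fold(state)(_ => CountState(state.count + 1L))

  // PartialMerge / Final: fold another task's buffer into this one
  // (the Add(count, c.inputAggBufferAttributes(0)) arms). The modes only
  // differ in what happens afterwards: PartialMerge keeps the buffer,
  // Final evaluates the result expression.
  def merge(state: CountState, other: CountState): CountState =
    CountState(state.count + other.count)

  def main(args: Array[String]): Unit = {
    val partial = Seq(Some("a"), None, Some("b")).foldLeft(CountState(0L))(update)
    println(merge(partial, CountState(5L)).count) // prints 7
  }
}
```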
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
index 07da3b7d80..5f269595de 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/benchmark/TPCHBenchmark.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SQLContext
 object TPCHBenchmark {
 
   // Add query numbers here once they are supported
-  val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 21, 22)
+  val supportedQueries = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 16, 17, 18, 19, 20, 21, 22)
 
   def query(queryNumber: Int, tpch: TPCH, sqlContext: SQLContext, numPartitions: Int) = {
     val sqlStr = tpch.getQuery(queryNumber)
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
index e2d62cde51..6d7855f46a 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala
@@ -233,43 +233,34 @@ case class EncryptedFilterExec(condition: Expression, child: SparkPlan)
 
 case class EncryptedAggregateExec(
     groupingExpressions: Seq[NamedExpression],
-    aggExpressions: Seq[AggregateExpression],
-    mode: AggregateMode,
+    aggregateExpressions: Seq[AggregateExpression],
     child: SparkPlan)
   extends UnaryExecNode with OpaqueOperatorExec {
 
   override def producedAttributes: AttributeSet =
-    AttributeSet(aggExpressions) -- AttributeSet(groupingExpressions)
-
-  override def output: Seq[Attribute] = mode match {
-    case Partial => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.copy(mode = Partial)).flatMap(_.aggregateFunction.inputAggBufferAttributes)
-    case Final => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute)
-    case Complete => groupingExpressions.map(_.toAttribute) ++ aggExpressions.map(_.resultAttribute)
-  }
+    AttributeSet(aggregateExpressions) -- AttributeSet(groupingExpressions)
+
+  override def output: Seq[Attribute] =
+    groupingExpressions.map(_.toAttribute) ++
+      aggregateExpressions.flatMap(expr => {
+        expr.mode match {
+          case Partial | PartialMerge =>
+            expr.aggregateFunction.inputAggBufferAttributes
+          case _ =>
+            Seq(expr.resultAttribute)
+        }
+      })
 
   override def executeBlocked(): RDD[Block] = {
-    val (groupingExprs, aggExprs) = mode match {
-      case Partial => {
-        val partialAggExpressions = aggExpressions.map(_.copy(mode = Partial))
-        (groupingExpressions, partialAggExpressions)
-      }
-      case Final => {
-        val finalGroupingExpressions = groupingExpressions.map(_.toAttribute)
-        val finalAggExpressions = aggExpressions.map(_.copy(mode = Final))
-        (finalGroupingExpressions, finalAggExpressions)
-      }
-      case Complete => {
-        (groupingExpressions, aggExpressions.map(_.copy(mode = Complete)))
-      }
-    }
+    val aggExprSer = Utils.serializeAggOp(groupingExpressions, aggregateExpressions, child.output)
+    val isPartial = aggregateExpressions.map(expr => expr.mode)
+      .exists(mode => mode == Partial || mode == PartialMerge)
 
-    val aggExprSer = Utils.serializeAggOp(groupingExprs, aggExprs, child.output)
     timeOperator(child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), "EncryptedPartialAggregateExec") { childRDD =>
       childRDD.map { block =>
         val (enclave, eid) = Utils.initEnclave()
-        Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, (mode == Partial)))
+        Block(enclave.NonObliviousAggregate(eid, aggExprSer, block.bytes, isPartial))
       }
     }
   }
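Dropping the operator-wide `mode` field means each `AggregateExpression` now carries its own mode, so a single `EncryptedAggregateExec` can mix, say, a `PartialMerge` count with a `Partial` distinct count. A toy model of the new `output` derivation (the `Mode`/`AggExpr` types and attribute strings are hypothetical stand-ins for the Catalyst types):

```scala
// Mirrors the per-expression flatMap in EncryptedAggregateExec.output:
// buffer attributes while an expression is still mid-aggregation, the
// result attribute once it reaches Final/Complete.
sealed trait Mode
case object Partial extends Mode
case object PartialMerge extends Mode
case object Final extends Mode

final case class AggExpr(name: String, mode: Mode) {
  def inputAggBufferAttributes: Seq[String] = Seq(s"$name#buffer")
  def resultAttribute: String = s"$name#result"
}

object OutputDemo {
  def output(grouping: Seq[String], aggs: Seq[AggExpr]): Seq[String] =
    grouping ++ aggs.flatMap { e =>
      e.mode match {
        case Partial | PartialMerge => e.inputAggBufferAttributes
        case _                      => Seq(e.resultAttribute)
      }
    }

  def main(args: Array[String]): Unit =
    // A mixed-mode operator, like the one the planner rule below builds:
    println(output(Seq("category"),
      Seq(AggExpr("count(price)", PartialMerge), AggExpr("count(distinct id)", Partial))))
  // List(category, count(price)#buffer, count(distinct id)#buffer)
}
```

This per-expression derivation is what the distinct-aggregation planning below relies on.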
diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
index 9f7e325131..d36d3c6d01 100644
--- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
+++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala
@@ -132,25 +132,90 @@ object OpaqueOperators extends Strategy {
         if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) =>
       val aggregateExpressions = aggExpressions.map(expr => expr.asInstanceOf[AggregateExpression])
-
-      if (groupingExpressions.size == 0) {
-        // Global aggregation
-        val partialAggregate = EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial, planLater(child))
-        val partialOutput = partialAggregate.output
-        val (projSchema, tag) = tagForGlobalAggregate(partialOutput)
-
-        EncryptedProjectExec(resultExpressions,
-          EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
-            EncryptedProjectExec(partialOutput,
-              EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
-                EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
-      } else {
-        // Grouping aggregation
-        EncryptedProjectExec(resultExpressions,
-          EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Final,
-            EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
-              EncryptedAggregateExec(groupingExpressions, aggregateExpressions, Partial,
-                EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
+      val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct)
+
+      functionsWithDistinct.size match {
+        case 0 => // No distinct aggregate operations
+          if (groupingExpressions.size == 0) {
+            // Global aggregation
+            val partialAggregate = EncryptedAggregateExec(groupingExpressions,
+              aggregateExpressions.map(_.copy(mode = Partial)), planLater(child))
+            val partialOutput = partialAggregate.output
+            val (projSchema, tag) = tagForGlobalAggregate(partialOutput)
+
+            EncryptedProjectExec(resultExpressions,
+              EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Final)),
+                EncryptedProjectExec(partialOutput,
+                  EncryptedSortExec(Seq(SortOrder(tag, Ascending)), true,
+                    EncryptedProjectExec(projSchema, partialAggregate))))) :: Nil
+          } else {
+            // Grouping aggregation
+            EncryptedProjectExec(resultExpressions,
+              EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Final)),
+                EncryptedSortExec(groupingExpressions.map(_.toAttribute).map(e => SortOrder(e, Ascending)), true,
+                  EncryptedAggregateExec(groupingExpressions, aggregateExpressions.map(_.copy(mode = Partial)),
+                    EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)), false, planLater(child)))))) :: Nil
+          }
+        case 1 => // One distinct aggregate operation
+          // Because we also group on the columns used in the distinct expressions,
+          // we do not need separate cases for global and grouping aggregation.
+
+          // Extract named expressions from the children of the distinct aggregate
+          // functions so that we can group by those columns.
+          val namedDistinctExpressions = functionsWithDistinct.head.aggregateFunction.children.flatMap { e =>
+            e match {
+              case ne: NamedExpression =>
+                Seq(ne)
+              case _ =>
+                e.children.filter(child => child.isInstanceOf[NamedExpression])
+                  .map(child => child.asInstanceOf[NamedExpression])
+            }
+          }
+          val combinedGroupingExpressions = groupingExpressions ++ namedDistinctExpressions
+
+          // 1. Create an Aggregate operator for partial aggregations.
+          val partialAggregate = {
+            val sorted = EncryptedSortExec(combinedGroupingExpressions.map(e => SortOrder(e, Ascending)), false,
+              planLater(child))
+            EncryptedAggregateExec(combinedGroupingExpressions, functionsWithoutDistinct.map(_.copy(mode = Partial)), sorted)
+          }
+
+          // 2. Create an Aggregate operator for partial merge aggregations.
+          val partialMergeAggregate = {
+            // Partition based on the final grouping expressions.
+            val partitionOrder = groupingExpressions.map(e => SortOrder(e, Ascending))
+            val partitioned = EncryptedRangePartitionExec(partitionOrder, partialAggregate)
+
+            // Local sort on the combined grouping expressions.
+            val sortOrder = combinedGroupingExpressions.map(e => SortOrder(e, Ascending))
+            val sorted = EncryptedSortExec(sortOrder, false, partitioned)
+
+            EncryptedAggregateExec(combinedGroupingExpressions,
+              functionsWithoutDistinct.map(_.copy(mode = PartialMerge)), sorted)
+          }
+
+          // 3. Create an Aggregate operator for partial aggregation of distinct aggregate expressions.
+          val partialDistinctAggregate = {
+            // Non-distinct functions operate on aggregation buffers, since partial aggregation
+            // has already run, but distinct functions operate on the original input to the
+            // aggregation.
+            EncryptedAggregateExec(groupingExpressions,
+              functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) ++
+                functionsWithDistinct.map(_.copy(mode = Partial)), partialMergeAggregate)
+          }
+
+          // 4. Create an Aggregate operator for the final aggregation.
+          val finalAggregate = {
+            val sorted = EncryptedSortExec(groupingExpressions.map(e => SortOrder(e, Ascending)),
+              true, partialDistinctAggregate)
+            EncryptedAggregateExec(groupingExpressions,
+              (functionsWithoutDistinct ++ functionsWithDistinct).map(_.copy(mode = Final)), sorted)
+          }
+
+          EncryptedProjectExec(resultExpressions, finalAggregate) :: Nil
+
+        case _ => // More than one distinct aggregate operation
+          throw new UnsupportedOperationException("Aggregate operations with more than one distinct expression are not yet supported.")
       }
 
     case p @ Union(Seq(left, right)) if isEncrypted(p) =>
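For reference, the single-distinct branch above produces a plan of the following shape for a query like `SELECT category, COUNT(DISTINCT id), COUNT(price) FROM t GROUP BY category` (a sketch; attribute lists are abbreviated and illustrative):

```text
EncryptedProjectExec resultExpressions
└─ EncryptedAggregateExec group=[category],
     agg=[count(price) Final, count(distinct id) Final]                 // step 4
   └─ EncryptedSortExec [category], global=true
      └─ EncryptedAggregateExec group=[category],
           agg=[count(price) PartialMerge, count(distinct id) Partial]  // step 3
         └─ EncryptedAggregateExec group=[category, id],
              agg=[count(price) PartialMerge]                           // step 2
            └─ EncryptedSortExec [category, id], global=false
               └─ EncryptedRangePartitionExec [category]
                  └─ EncryptedAggregateExec group=[category, id],
                       agg=[count(price) Partial]                       // step 1
                     └─ EncryptedSortExec [category, id], global=false
                        └─ planLater(child)
```

Note that step 3 needs no extra sort: step 2's output is locally sorted on `[category, id]`, and rows sorted on that combined key are also sorted on the `[category]` prefix that step 3 groups by.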
diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
index ebc2b09dce..ed59f7cba1 100644
--- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
+++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala
@@ -479,6 +479,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
       .collect.sortBy { case Row(category: String, _) => category }
   }
 
+  testAgainstSpark("aggregate count distinct and indistinct") { securityLevel =>
+    val data = (0 until 64).map { i =>
+      if (i % 6 == 0)
+        (abc(i), null.asInstanceOf[Int], i % 8)
+      else
+        (abc(i), i % 4, i % 8)
+    }.toSeq
+    val words = makeDF(data, securityLevel, "category", "id", "price")
+    words.groupBy("category").agg(countDistinct("id").as("num_unique_ids"),
+      count("price").as("num_prices")).collect.toSet
+  }
+
+  testAgainstSpark("aggregate count distinct") { securityLevel =>
+    val data = (0 until 64).map { i =>
+      if (i % 6 == 0)
+        (abc(i), null.asInstanceOf[Int])
+      else
+        (abc(i), i % 8)
+    }.toSeq
+    val words = makeDF(data, securityLevel, "category", "price")
+    words.groupBy("category").agg(countDistinct("price").as("num_unique_prices"))
+      .collect.sortBy { case Row(category: String, _) => category }
+  }
+
   testAgainstSpark("aggregate first") { securityLevel =>
     val data = for (i <- 0 until 256) yield (i, abc(i), 1)
     val words = makeDF(data, securityLevel, "id", "category", "price")
@@ -526,6 +550,30 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
       .collect.sortBy { case Row(word: String, _) => word }
   }
 
+  testAgainstSpark("aggregate sum distinct and indistinct") { securityLevel =>
+    val data = (0 until 64).map { i =>
+      if (i % 6 == 0)
+        (abc(i), null.asInstanceOf[Int], i % 8)
+      else
+        (abc(i), i % 4, i % 8)
+    }.toSeq
+    val words = makeDF(data, securityLevel, "category", "id", "price")
+    words.groupBy("category").agg(sumDistinct("id").as("sum_unique_ids"),
+      sum("price").as("sum_prices")).collect.toSet
+  }
+
+  testAgainstSpark("aggregate sum distinct") { securityLevel =>
+    val data = (0 until 64).map { i =>
+      if (i % 6 == 0)
+        (abc(i), null.asInstanceOf[Int])
+      else
+        (abc(i), i % 8)
+    }.toSeq
+    val words = makeDF(data, securityLevel, "category", "price")
+    words.groupBy("category").agg(sumDistinct("price").as("sum_unique_prices"))
+      .collect.sortBy { case Row(category: String, _) => category }
+  }
+
   testAgainstSpark("aggregate on multiple columns") { securityLevel =>
     val data = for (i <- 0 until 256) yield (abc(i), 1, 1.0f)
     val words = makeDF(data, securityLevel, "str", "x", "y")
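These tests seed NULL-like values into the distinct column to exercise NULL handling, and use `toSet`/`sortBy` so the comparison is order-insensitive. Restated against vanilla Spark for reference (a sketch; the local `SparkSession` setup is illustrative, and `Option[Int]` is used here to produce genuine SQL NULLs rather than the tests' `null.asInstanceOf[Int]`):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, countDistinct}

// Sketch of what the mixed distinct/indistinct test asserts, run against
// vanilla Spark. Session setup and data values are illustrative.
object DistinctNullDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._

    val df = Seq(("a", Option(1), 3), ("a", None, 4), ("b", Option(2), 5))
      .toDF("category", "id", "price")

    // COUNT(DISTINCT id) ignores NULL ids; COUNT(price) counts non-NULL prices.
    df.groupBy("category")
      .agg(countDistinct("id").as("num_unique_ids"), count("price").as("num_prices"))
      .show()
    // category=a -> num_unique_ids=1, num_prices=2
    // category=b -> num_unique_ids=1, num_prices=1

    spark.stop()
  }
}
```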
@@ -557,6 +605,12 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self =>
     words.agg(sum("count").as("totalCount")).collect
   }
 
+  testAgainstSpark("global aggregate count distinct") { securityLevel =>
+    val data = for (i <- 0 until 256) yield (i, abc(i), i % 64)
+    val words = makeDF(data, securityLevel, "id", "word", "price")
+    words.agg(countDistinct("price").as("num_unique_prices")).collect
+  }
+
   testAgainstSpark("global aggregate with 0 rows") { securityLevel =>
     val data = for (i <- 0 until 256) yield (i, abc(i), 1)
     val words = makeDF(data, securityLevel, "id", "word", "count")
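Because the single-distinct branch folds the distinct column into the grouping key, this global (no `GROUP BY`) case flows through the same plan shape as the grouped tests. Its vanilla-Spark equivalent, assuming the same session and `spark.implicits._` scope as the earlier sketch:

```scala
import org.apache.spark.sql.functions.countDistinct

// With price = i % 64 over 256 rows there are exactly 64 distinct prices.
val words = (0 until 256).map(i => (i, i % 64)).toDF("id", "price")
words.agg(countDistinct("price").as("num_unique_prices")).show() // 64
```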