From 358d7a4c43239458c18e3b7c1875917c1f86ae49 Mon Sep 17 00:00:00 2001 From: Octavian Sima Date: Wed, 10 Feb 2021 23:26:42 +0000 Subject: [PATCH] matching in strategies.scala set up class thing cleanup added test cases for non-equi left anti join rename to serializeEquiJoinExpression added isEncrypted condition set up keys JoinExpr now has condition rename serialization does not throw compile error for BNLJ split up added condition in ExpressionEvaluation.h zipPartitions cpp put in place typo added func to header two loops in place update tests condition fixed scala loop interchange rows added tags ensure cached == match working comparison decoupling in ExpressionEvalulation save compiles and condition works is printing fix swap outer/inner o_i_match show() has the same result tests pass test cleanup added test cases for different condition BuildLeft works optional keys in scala started C++ passes the operator tests comments, cleanup attemping to do it the ~right~ way comments to distinguish between primary/secondary, operator tests pass cleanup comments, about to begin implementation for distinct agg ops is_distinct added test case serializing with isDistinct is_distinct in ExpressionEvaluation.h removed unused code from join implementation remove RowWriter/Reader in condition evaluation (join) easier test serialization done correct checking in Scala set is set up spaghetti but it finally works function for clearing values condition_eval isntead of condition goto comment started impl of multiple partitions fix added rangepartitionexec that runs partitioning cleanup serialization properly comments, generalization for > 1 distinct function comments about to refactor into logical.Aggregation the new case has distinct in result expressions need to match on distinct removed new case (doesn't make difference?) works remove traces of distinct more cleanup address comments rename equi join split Join.cpp into two files Update App.cpp fixed swap issues one more swap stream/broadcast concatEncryptedBlocks, remove import iostream comment for for loop added comments explaining constraints with broadcast side comments left semi done, existence serializes remove existence serialization fixed --- src/enclave/App/App.cpp | 44 ++++++++ src/enclave/App/SGXEnclave.h | 4 + .../Enclave/BroadcastNestedLoopJoin.cpp | 54 ++++++++++ src/enclave/Enclave/BroadcastNestedLoopJoin.h | 8 ++ src/enclave/Enclave/CMakeLists.txt | 3 +- src/enclave/Enclave/Enclave.cpp | 22 +++- src/enclave/Enclave/Enclave.edl | 6 ++ src/enclave/Enclave/ExpressionEvaluation.h | 100 +++++++++++++----- ...Join.cpp => NonObliviousSortMergeJoin.cpp} | 11 +- .../{Join.h => NonObliviousSortMergeJoin.h} | 5 - src/flatbuffers/operators.fbs | 9 +- .../edu/berkeley/cs/rise/opaque/Utils.scala | 39 ++++--- .../cs/rise/opaque/execution/SGXEnclave.scala | 3 + .../cs/rise/opaque/execution/operators.scala | 72 ++++++++++++- .../berkeley/cs/rise/opaque/strategies.scala | 47 +++++++- .../cs/rise/opaque/OpaqueOperatorTests.scala | 54 ++++++++++ 16 files changed, 418 insertions(+), 63 deletions(-) create mode 100644 src/enclave/Enclave/BroadcastNestedLoopJoin.cpp create mode 100644 src/enclave/Enclave/BroadcastNestedLoopJoin.h rename src/enclave/Enclave/{Join.cpp => NonObliviousSortMergeJoin.cpp} (88%) rename src/enclave/Enclave/{Join.h => NonObliviousSortMergeJoin.h} (85%) diff --git a/src/enclave/App/App.cpp b/src/enclave/App/App.cpp index 64013d2ab7..596e593d52 100644 --- a/src/enclave/App/App.cpp +++ b/src/enclave/App/App.cpp @@ -555,6 +555,50 @@ Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( return ret; } +JNIEXPORT jbyteArray JNICALL +Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin( + JNIEnv *env, jobject obj, jlong eid, jbyteArray join_expr, jbyteArray outer_rows, jbyteArray inner_rows) { + (void)obj; + + jboolean if_copy; + + uint32_t join_expr_length = (uint32_t) env->GetArrayLength(join_expr); + uint8_t *join_expr_ptr = (uint8_t *) env->GetByteArrayElements(join_expr, &if_copy); + + uint32_t outer_rows_length = (uint32_t) env->GetArrayLength(outer_rows); + uint8_t *outer_rows_ptr = (uint8_t *) env->GetByteArrayElements(outer_rows, &if_copy); + + uint32_t inner_rows_length = (uint32_t) env->GetArrayLength(inner_rows); + uint8_t *inner_rows_ptr = (uint8_t *) env->GetByteArrayElements(inner_rows, &if_copy); + + uint8_t *output_rows = nullptr; + size_t output_rows_length = 0; + + if (outer_rows_ptr == nullptr) { + ocall_throw("BroadcastNestedLoopJoin: JNI failed to get inner byte array."); + } else if (inner_rows_ptr == nullptr) { + ocall_throw("BroadcastNestedLoopJoin: JNI failed to get outer byte array."); + } else { + oe_check_and_time("Broadcast Nested Loop Join", + ecall_broadcast_nested_loop_join( + (oe_enclave_t*)eid, + join_expr_ptr, join_expr_length, + outer_rows_ptr, outer_rows_length, + inner_rows_ptr, inner_rows_length, + &output_rows, &output_rows_length)); + } + + jbyteArray ret = env->NewByteArray(output_rows_length); + env->SetByteArrayRegion(ret, 0, output_rows_length, (jbyte *) output_rows); + free(output_rows); + + env->ReleaseByteArrayElements(join_expr, (jbyte *) join_expr_ptr, 0); + env->ReleaseByteArrayElements(outer_rows, (jbyte *) outer_rows_ptr, 0); + env->ReleaseByteArrayElements(inner_rows, (jbyte *) inner_rows_ptr, 0); + + return ret; +} + JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( JNIEnv *env, jobject obj, jlong eid, jbyteArray agg_op, jbyteArray input_rows, jboolean isPartial) { diff --git a/src/enclave/App/SGXEnclave.h b/src/enclave/App/SGXEnclave.h index 2b74c42763..1ddd0d8497 100644 --- a/src/enclave/App/SGXEnclave.h +++ b/src/enclave/App/SGXEnclave.h @@ -41,6 +41,10 @@ extern "C" { Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousSortMergeJoin( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray); + JNIEXPORT jbyteArray JNICALL + Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_BroadcastNestedLoopJoin( + JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jbyteArray); + JNIEXPORT jobject JNICALL Java_edu_berkeley_cs_rise_opaque_execution_SGXEnclave_NonObliviousAggregate( JNIEnv *, jobject, jlong, jbyteArray, jbyteArray, jboolean); diff --git a/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp b/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp new file mode 100644 index 0000000000..c99297ebf5 --- /dev/null +++ b/src/enclave/Enclave/BroadcastNestedLoopJoin.cpp @@ -0,0 +1,54 @@ +#include "BroadcastNestedLoopJoin.h" + +#include "ExpressionEvaluation.h" +#include "FlatbuffersReaders.h" +#include "FlatbuffersWriters.h" +#include "common.h" + +/** C++ implementation of a broadcast nested loop join. + * Assumes outer_rows is streamed and inner_rows is broadcast. + * DOES NOT rely on rows to be tagged primary or secondary, and that + * assumption will break the implementation. + */ +void broadcast_nested_loop_join( + uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length) { + + FlatbuffersJoinExprEvaluator join_expr_eval(join_expr, join_expr_length); + const tuix::JoinType join_type = join_expr_eval.get_join_type(); + + RowReader outer_r(BufferRefView(outer_rows, outer_rows_length)); + RowWriter w; + + while (outer_r.has_next()) { + const tuix::Row *outer = outer_r.next(); + bool o_i_match = false; + + RowReader inner_r(BufferRefView(inner_rows, inner_rows_length)); + const tuix::Row *inner; + while (inner_r.has_next()) { + inner = inner_r.next(); + o_i_match |= join_expr_eval.eval_condition(outer, inner); + } + + switch(join_type) { + case tuix::JoinType_LeftAnti: + if (!o_i_match) { + w.append(outer); + } + break; + case tuix::JoinType_LeftSemi: + if (o_i_match) { + w.append(outer); + } + break; + default: + throw std::runtime_error( + std::string("Join type not supported: ") + + std::string(to_string(join_type))); + } + } + w.output_buffer(output_rows, output_rows_length); +} diff --git a/src/enclave/Enclave/BroadcastNestedLoopJoin.h b/src/enclave/Enclave/BroadcastNestedLoopJoin.h new file mode 100644 index 0000000000..55c934067b --- /dev/null +++ b/src/enclave/Enclave/BroadcastNestedLoopJoin.h @@ -0,0 +1,8 @@ +#include +#include + +void broadcast_nested_loop_join( + uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length); diff --git a/src/enclave/Enclave/CMakeLists.txt b/src/enclave/Enclave/CMakeLists.txt index 6a72e76dfd..07e6130d80 100644 --- a/src/enclave/Enclave/CMakeLists.txt +++ b/src/enclave/Enclave/CMakeLists.txt @@ -10,7 +10,8 @@ set(SOURCES Flatbuffers.cpp FlatbuffersReaders.cpp FlatbuffersWriters.cpp - Join.cpp + NonObliviousSortMergeJoin.cpp + BroadcastNestedLoopJoin.cpp Limit.cpp Project.cpp Sort.cpp diff --git a/src/enclave/Enclave/Enclave.cpp b/src/enclave/Enclave/Enclave.cpp index e9342875b2..fde1806a97 100644 --- a/src/enclave/Enclave/Enclave.cpp +++ b/src/enclave/Enclave/Enclave.cpp @@ -6,7 +6,8 @@ #include "Aggregate.h" #include "Crypto.h" #include "Filter.h" -#include "Join.h" +#include "NonObliviousSortMergeJoin.h" +#include "BroadcastNestedLoopJoin.h" #include "Limit.h" #include "Project.h" #include "Sort.h" @@ -161,6 +162,25 @@ void ecall_non_oblivious_sort_merge_join(uint8_t *join_expr, size_t join_expr_le } } +void ecall_broadcast_nested_loop_join(uint8_t *join_expr, size_t join_expr_length, + uint8_t *outer_rows, size_t outer_rows_length, + uint8_t *inner_rows, size_t inner_rows_length, + uint8_t **output_rows, size_t *output_rows_length) { + // Guard against operating on arbitrary enclave memory + assert(oe_is_outside_enclave(outer_rows, outer_rows_length) == 1); + assert(oe_is_outside_enclave(inner_rows, inner_rows_length) == 1); + __builtin_ia32_lfence(); + + try { + broadcast_nested_loop_join(join_expr, join_expr_length, + outer_rows, outer_rows_length, + inner_rows, inner_rows_length, + output_rows, output_rows_length); + } catch (const std::runtime_error &e) { + ocall_throw(e.what()); + } +} + void ecall_non_oblivious_aggregate( uint8_t *agg_op, size_t agg_op_length, uint8_t *input_rows, size_t input_rows_length, diff --git a/src/enclave/Enclave/Enclave.edl b/src/enclave/Enclave/Enclave.edl index 44eccc7a76..1789ff2b64 100644 --- a/src/enclave/Enclave/Enclave.edl +++ b/src/enclave/Enclave/Enclave.edl @@ -51,6 +51,12 @@ enclave { [user_check] uint8_t *input_rows, size_t input_rows_length, [out] uint8_t **output_rows, [out] size_t *output_rows_length); + public void ecall_broadcast_nested_loop_join( + [in, count=join_expr_length] uint8_t *join_expr, size_t join_expr_length, + [user_check] uint8_t *outer_rows, size_t outer_rows_length, + [user_check] uint8_t *inner_rows, size_t inner_rows_length, + [out] uint8_t **output_rows, [out] size_t *output_rows_length); + public void ecall_non_oblivious_aggregate( [in, count=agg_op_length] uint8_t *agg_op, size_t agg_op_length, [user_check] uint8_t *input_rows, size_t input_rows_length, diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 0f48c56d48..06693d84fa 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -1725,60 +1725,104 @@ class FlatbuffersJoinExprEvaluator { } const tuix::JoinExpr* join_expr = flatbuffers::GetRoot(buf); - join_type = join_expr->join_type(); - if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { - throw std::runtime_error("Mismatched join key lengths"); - } - for (auto key_it = join_expr->left_keys()->begin(); - key_it != join_expr->left_keys()->end(); ++key_it) { - left_key_evaluators.emplace_back( - std::unique_ptr( - new FlatbuffersExpressionEvaluator(*key_it))); + join_type = join_expr->join_type(); + if (join_expr->condition() != NULL) { + condition_eval = std::unique_ptr( + new FlatbuffersExpressionEvaluator(join_expr->condition())); } - for (auto key_it = join_expr->right_keys()->begin(); - key_it != join_expr->right_keys()->end(); ++key_it) { - right_key_evaluators.emplace_back( - std::unique_ptr( - new FlatbuffersExpressionEvaluator(*key_it))); + is_equi_join = false; + + if (join_expr->left_keys() != NULL && join_expr->right_keys() != NULL) { + is_equi_join = true; + if (join_expr->condition() != NULL) { + throw std::runtime_error("Equi join cannot have condition"); + } + if (join_expr->left_keys()->size() != join_expr->right_keys()->size()) { + throw std::runtime_error("Mismatched join key lengths"); + } + for (auto key_it = join_expr->left_keys()->begin(); + key_it != join_expr->left_keys()->end(); ++key_it) { + left_key_evaluators.emplace_back( + std::unique_ptr( + new FlatbuffersExpressionEvaluator(*key_it))); + } + for (auto key_it = join_expr->right_keys()->begin(); + key_it != join_expr->right_keys()->end(); ++key_it) { + right_key_evaluators.emplace_back( + std::unique_ptr( + new FlatbuffersExpressionEvaluator(*key_it))); + } } } - /** - * Return true if the given row is from the primary table, indicated by its first field, which - * must be an IntegerField. + /** Return true if the given row is from the primary table, indicated by its first field, which + * must be an IntegerField. + * Rows MUST have been tagged in Scala. */ bool is_primary(const tuix::Row *row) { return static_cast( row->field_values()->Get(0)->value())->value() == 0; } - /** Return true if the two rows are from the same join group. */ - bool is_same_group(const tuix::Row *row1, const tuix::Row *row2) { - auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators; - auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators; + /** Returns the row evaluator corresponding to the primary row + * Rows MUST have been tagged in Scala. + */ + const tuix::Row *get_primary_row( + const tuix::Row *row1, const tuix::Row *row2) { + return is_primary(row1) ? row1 : row2; + } + /** Return true if the two rows satisfy the join condition. */ + bool eval_condition(const tuix::Row *row1, const tuix::Row *row2) { builder.Clear(); + bool row1_equals_row2; + + /** Check equality for equi joins. If it is a non-equi join, + * the key evaluators will be empty, so the code never enters the for loop. + */ + auto &row1_evaluators = is_primary(row1) ? left_key_evaluators : right_key_evaluators; + auto &row2_evaluators = is_primary(row2) ? left_key_evaluators : right_key_evaluators; for (uint32_t i = 0; i < row1_evaluators.size(); i++) { const tuix::Field *row1_eval_tmp = row1_evaluators[i]->eval(row1); auto row1_eval_offset = flatbuffers_copy(row1_eval_tmp, builder); + auto row1_field = flatbuffers::GetTemporaryPointer(builder, row1_eval_offset); + const tuix::Field *row2_eval_tmp = row2_evaluators[i]->eval(row2); auto row2_eval_offset = flatbuffers_copy(row2_eval_tmp, builder); + auto row2_field = flatbuffers::GetTemporaryPointer(builder, row2_eval_offset); - bool row1_equals_row2 = + flatbuffers::Offset comparison = eval_binary_comparison( + builder, + row1_field, + row2_field); + row1_equals_row2 = static_cast( flatbuffers::GetTemporaryPointer( builder, - eval_binary_comparison( - builder, - flatbuffers::GetTemporaryPointer(builder, row1_eval_offset), - flatbuffers::GetTemporaryPointer(builder, row2_eval_offset))) - ->value())->value(); + comparison)->value())->value(); if (!row1_equals_row2) { return false; } } + + /* Check condition for non-equi joins */ + if (!is_equi_join) { + std::vector> concat_fields; + for (auto field : *row1->field_values()) { + concat_fields.push_back(flatbuffers_copy(field, builder)); + } + for (auto field : *row2->field_values()) { + concat_fields.push_back(flatbuffers_copy(field, builder)); + } + flatbuffers::Offset concat = tuix::CreateRowDirect(builder, &concat_fields); + const tuix::Row *concat_ptr = flatbuffers::GetTemporaryPointer(builder, concat); + + const tuix::Field *condition_result = condition_eval->eval(concat_ptr); + + return static_cast(condition_result->value())->value(); + } return true; } @@ -1791,6 +1835,8 @@ class FlatbuffersJoinExprEvaluator { tuix::JoinType join_type; std::vector> left_key_evaluators; std::vector> right_key_evaluators; + bool is_equi_join; + std::unique_ptr condition_eval; }; class AggregateExpressionEvaluator { diff --git a/src/enclave/Enclave/Join.cpp b/src/enclave/Enclave/NonObliviousSortMergeJoin.cpp similarity index 88% rename from src/enclave/Enclave/Join.cpp rename to src/enclave/Enclave/NonObliviousSortMergeJoin.cpp index 828c963d40..67bc546c0f 100644 --- a/src/enclave/Enclave/Join.cpp +++ b/src/enclave/Enclave/NonObliviousSortMergeJoin.cpp @@ -1,10 +1,13 @@ -#include "Join.h" +#include "NonObliviousSortMergeJoin.h" #include "ExpressionEvaluation.h" #include "FlatbuffersReaders.h" #include "FlatbuffersWriters.h" #include "common.h" +/** C++ implementation of a non-oblivious sort merge join. + * Rows MUST be tagged primary or secondary for this to work. + */ void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, @@ -25,7 +28,7 @@ void non_oblivious_sort_merge_join( if (join_expr_eval.is_primary(current)) { if (last_primary_of_group.get() - && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) { + && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) { // Add this primary row to the current group primary_group.append(current); last_primary_of_group.set(current); @@ -50,13 +53,13 @@ void non_oblivious_sort_merge_join( } else { // Output the joined rows resulting from this foreign row if (last_primary_of_group.get() - && join_expr_eval.is_same_group(last_primary_of_group.get(), current)) { + && join_expr_eval.eval_condition(last_primary_of_group.get(), current)) { auto primary_group_buffer = primary_group.output_buffer(); RowReader primary_group_reader(primary_group_buffer.view()); while (primary_group_reader.has_next()) { const tuix::Row *primary = primary_group_reader.next(); - if (!join_expr_eval.is_same_group(primary, current)) { + if (!join_expr_eval.eval_condition(primary, current)) { throw std::runtime_error( std::string("Invariant violation: rows of primary_group " "are not of the same group: ") diff --git a/src/enclave/Enclave/Join.h b/src/enclave/Enclave/NonObliviousSortMergeJoin.h similarity index 85% rename from src/enclave/Enclave/Join.h rename to src/enclave/Enclave/NonObliviousSortMergeJoin.h index b380909027..ef60c38437 100644 --- a/src/enclave/Enclave/Join.h +++ b/src/enclave/Enclave/NonObliviousSortMergeJoin.h @@ -1,12 +1,7 @@ #include #include -#ifndef JOIN_H -#define JOIN_H - void non_oblivious_sort_merge_join( uint8_t *join_expr, size_t join_expr_length, uint8_t *input_rows, size_t input_rows_length, uint8_t **output_rows, size_t *output_rows_length); - -#endif diff --git a/src/flatbuffers/operators.fbs b/src/flatbuffers/operators.fbs index 1ebd06c971..9fa82b6cab 100644 --- a/src/flatbuffers/operators.fbs +++ b/src/flatbuffers/operators.fbs @@ -54,10 +54,11 @@ enum JoinType : ubyte { } table JoinExpr { join_type:JoinType; - // Currently only cross joins and equijoins are supported, so we store - // parallel arrays of key expressions and the join outputs pairs of rows - // where each expression from the left is equal to the matching expression - // from the right. + // In the case of equi joins, we store parallel arrays of key expressions and have the join output + // pairs of rows where each expression from the left is equal to the matching expression from the right. left_keys:[Expr]; right_keys:[Expr]; + // In the case of non-equi joins, we pass in a condition as an expression and evaluate that on each pair of rows. + // TODO: have equi joins use this condition rather than an additional filter operation. + condition:Expr; } diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 4c6970e489..5af2806c1e 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -1246,8 +1246,9 @@ object Utils extends Logging { } def serializeJoinExpression( - joinType: JoinType, leftKeys: Seq[Expression], rightKeys: Seq[Expression], - leftSchema: Seq[Attribute], rightSchema: Seq[Attribute]): Array[Byte] = { + joinType: JoinType, leftKeys: Option[Seq[Expression]], rightKeys: Option[Seq[Expression]], + leftSchema: Seq[Attribute], rightSchema: Seq[Attribute], + condition: Option[Expression] = None): Array[Byte] = { val builder = new FlatBufferBuilder builder.finish( tuix.JoinExpr.createJoinExpr( @@ -1266,12 +1267,28 @@ object Utils extends Logging { case UsingJoin(_, _) => ??? // scalastyle:on }, - tuix.JoinExpr.createLeftKeysVector( - builder, - leftKeys.map(e => flatbuffersSerializeExpression(builder, e, leftSchema)).toArray), - tuix.JoinExpr.createRightKeysVector( - builder, - rightKeys.map(e => flatbuffersSerializeExpression(builder, e, rightSchema)).toArray))) + // Non-zero when equi join + leftKeys match { + case Some(leftKeys) => + tuix.JoinExpr.createLeftKeysVector( + builder, + leftKeys.map(e => flatbuffersSerializeExpression(builder, e, leftSchema)).toArray) + case None => 0 + }, + // Non-zero when equi join + rightKeys match { + case Some(rightKeys) => + tuix.JoinExpr.createRightKeysVector( + builder, + rightKeys.map(e => flatbuffersSerializeExpression(builder, e, rightSchema)).toArray) + case None => 0 + }, + // Non-zero when non-equi join + condition match { + case Some(condition) => + flatbuffersSerializeExpression(builder, condition, leftSchema ++ rightSchema) + case _ => 0 + })) builder.sizedByteArray() } @@ -1371,8 +1388,7 @@ object Utils extends Logging { updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray), tuix.AggregateExpr.createEvaluateExprsVector( builder, - evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray) - ) + evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)) case c @ Count(children) => val count = c.aggBufferAttributes(0) @@ -1410,8 +1426,7 @@ object Utils extends Logging { updateExprs.map(e => flatbuffersSerializeExpression(builder, e, concatSchema)).toArray), tuix.AggregateExpr.createEvaluateExprsVector( builder, - evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray) - ) + evaluateExprs.map(e => flatbuffersSerializeExpression(builder, e, aggSchema)).toArray)) case f @ First(child, false) => val first = f.aggBufferAttributes(0) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala index b49090ced1..e1f1d31261 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/SGXEnclave.scala @@ -42,6 +42,9 @@ class SGXEnclave extends java.io.Serializable { @native def NonObliviousSortMergeJoin( eid: Long, joinExpr: Array[Byte], input: Array[Byte]): Array[Byte] + @native def BroadcastNestedLoopJoin( + eid: Long, joinExpr: Array[Byte], outerBlock: Array[Byte], innerBlock: Array[Byte]): Array[Byte] + @native def NonObliviousAggregate( eid: Long, aggOp: Array[Byte], inputRows: Array[Byte], isPartial: Boolean): (Array[Byte]) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala index 4eb941157e..6983df047b 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/execution/operators.scala @@ -26,12 +26,11 @@ import org.apache.spark.sql.catalyst.expressions.AttributeSet import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan +import edu.berkeley.cs.rise.opaque.OpaqueException trait LeafExecNode extends SparkPlan { override final def children: Seq[SparkPlan] = Nil @@ -294,7 +293,7 @@ case class EncryptedSortMergeJoinExec( override def executeBlocked(): RDD[Block] = { val joinExprSer = Utils.serializeJoinExpression( - joinType, leftKeys, rightKeys, leftSchema, rightSchema) + joinType, Some(leftKeys), Some(rightKeys), leftSchema, rightSchema) timeOperator( child.asInstanceOf[OpaqueOperatorExec].executeBlocked(), @@ -308,6 +307,69 @@ case class EncryptedSortMergeJoinExec( } } +case class EncryptedBroadcastNestedLoopJoinExec( + left: SparkPlan, + right: SparkPlan, + buildSide: BuildSide, + joinType: JoinType, + condition: Option[Expression]) + extends BinaryExecNode with OpaqueOperatorExec { + + override def output: Seq[Attribute] = { + joinType match { + case _: InnerLike => + left.output ++ right.output + case LeftOuter => + left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => + left.output.map(_.withNullability(true)) ++ right.output + case FullOuter => + left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) + case j: ExistenceJoin => + left.output :+ j.exists + case LeftExistence(_) => + left.output + case x => + throw new IllegalArgumentException( + s"BroadcastNestedLoopJoin should not take $x as the JoinType") + } + } + + override def executeBlocked(): RDD[Block] = { + val joinExprSer = Utils.serializeJoinExpression( + joinType, None, None, left.output, right.output, condition) + + val leftRDD = left.asInstanceOf[OpaqueOperatorExec].executeBlocked() + val rightRDD = right.asInstanceOf[OpaqueOperatorExec].executeBlocked() + + joinType match { + case LeftExistence(_) => { + join(leftRDD, rightRDD, joinExprSer) + } + case _ => + throw new OpaqueException(s"$joinType JoinType is not yet supported") + } + } + + def join(leftRDD: RDD[Block], rightRDD: RDD[Block], + joinExprSer: Array[Byte]): RDD[Block] = { + // We pick which side to broadcast/stream according to buildSide. + // BuildRight means the right relation <=> the broadcast relation. + // NOTE: outer_rows and inner_rows in C++ correspond to stream and broadcast side respectively. + var (streamRDD, broadcastRDD) = buildSide match { + case BuildRight => + (leftRDD, rightRDD) + case BuildLeft => + (rightRDD, leftRDD) + } + val broadcast = Utils.concatEncryptedBlocks(broadcastRDD.collect) + streamRDD.map { block => + val (enclave, eid) = Utils.initEnclave() + Block(enclave.BroadcastNestedLoopJoin(eid, joinExprSer, block.bytes, broadcast.bytes)) + } + } +} + case class EncryptedUnionExec( left: SparkPlan, right: SparkPlan) diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala index 0c8f188369..dd104d2ad2 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/strategies.scala @@ -32,13 +32,19 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.planning.PhysicalAggregation import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.FullOuter import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.InnerLike import org.apache.spark.sql.catalyst.plans.LeftAnti import org.apache.spark.sql.catalyst.plans.LeftSemi +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.SparkPlan import edu.berkeley.cs.rise.opaque.execution._ import edu.berkeley.cs.rise.opaque.logical._ +import org.apache.spark.sql.catalyst.plans.LeftExistence object OpaqueOperators extends Strategy { @@ -73,6 +79,7 @@ object OpaqueOperators extends Strategy { case Sort(sortExprs, global, child) if isEncrypted(child) => EncryptedSortExec(sortExprs, global, planLater(child)) :: Nil + // Used to match equi joins case p @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _) if isEncrypted(p) => val (leftProjSchema, leftKeysProj, tag) = tagForJoin(leftKeys, left.output, true) val (rightProjSchema, rightKeysProj, _) = tagForJoin(rightKeys, right.output, false) @@ -105,6 +112,26 @@ object OpaqueOperators extends Strategy { filtered :: Nil + // Used to match non-equi joins + case Join(left, right, joinType, condition, hint) if isEncrypted(left) && isEncrypted(right) => + // How to pick broadcast side: if left join, broadcast right. If right join, broadcast left. + // This is the simplest and most performant method, but may be worth revisting if one side is + // significantly smaller than the other. Otherwise, pick the smallest side to broadcast. + // NOTE: the current implementation of BNLJ only works under the assumption that + // left join <==> broadcast right AND right join <==> broadcast left. + val desiredBuildSide = if (joinType.isInstanceOf[InnerLike] || joinType == FullOuter) + getSmallerSide(left, right) else + getBroadcastSideBNLJ(joinType) + + val joined = EncryptedBroadcastNestedLoopJoinExec( + planLater(left), + planLater(right), + desiredBuildSide, + joinType, + condition) + + joined :: Nil + case a @ PhysicalAggregation(groupingExpressions, aggExpressions, resultExpressions, child) if (isEncrypted(child) && aggExpressions.forall(expr => expr.isInstanceOf[AggregateExpression])) => @@ -183,17 +210,29 @@ object OpaqueOperators extends Strategy { (Seq(tag) ++ keysProj ++ input, keysProj.map(_.toAttribute), tag.toAttribute) } - private def sortForJoin( - leftKeys: Seq[Expression], tag: Expression, input: Seq[Attribute]): Seq[SortOrder] = - leftKeys.map(k => SortOrder(k, Ascending)) :+ SortOrder(tag, Ascending) - private def dropTags( leftOutput: Seq[Attribute], rightOutput: Seq[Attribute]): Seq[NamedExpression] = leftOutput ++ rightOutput + private def sortForJoin( + leftKeys: Seq[Expression], tag: Expression, input: Seq[Attribute]): Seq[SortOrder] = + leftKeys.map(k => SortOrder(k, Ascending)) :+ SortOrder(tag, Ascending) + private def tagForGlobalAggregate(input: Seq[Attribute]) : (Seq[NamedExpression], NamedExpression) = { val tag = Alias(Literal(0), "_tag")() (Seq(tag) ++ input, tag.toAttribute) } + + private def getBroadcastSideBNLJ(joinType: JoinType): BuildSide = { + joinType match { + case LeftExistence(_) => BuildRight + case _ => BuildLeft + } + } + + // Everything below is a private method in SparkStrategies.scala + private def getSmallerSide(left: LogicalPlan, right: LogicalPlan): BuildSide = { + if (right.stats.sizeInBytes <= left.stats.sizeInBytes) BuildRight else BuildLeft + } } diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index a69894d13c..ff1865f343 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -326,6 +326,24 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("non-equi left semi join") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect + } + + testAgainstSpark("non-equi left semi join negated") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 8).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, (i % 8).toString, i * 10) + val p = makeDF(p_data, securityLevel, "id1", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id2", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_semi").sort($"join_col_1", $"id1") + df.collect + } + testAgainstSpark("left anti join 1") { securityLevel => val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) @@ -335,6 +353,24 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("non-equi left anti join 1") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("non-equi left anti join 1 negated") { securityLevel => + val p_data = for (i <- 1 to 128) yield (i, (i % 16).toString, i * 10) + val f_data = for (i <- 1 to 256 if (i % 3) + 1 == 0 || (i % 3) + 5 == 0) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + testAgainstSpark("left anti join 2") { securityLevel => val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) @@ -344,6 +380,24 @@ trait OpaqueOperatorTests extends OpaqueTestsBase { self => df.collect } + testAgainstSpark("non-equi left anti join 2") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" >= $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + + testAgainstSpark("non-equi left anti join 2 negated") { securityLevel => + val p_data = for (i <- 1 to 16) yield (i, (i % 4).toString, i * 10) + val f_data = for (i <- 1 to 32) yield (i, i.toString, i * 10) + val p = makeDF(p_data, securityLevel, "id", "join_col_1", "x") + val f = makeDF(f_data, securityLevel, "id", "join_col_2", "x") + val df = p.join(f, $"join_col_1" < $"join_col_2", "left_anti").sort($"join_col_1", $"id") + df.collect + } + def abc(i: Int): String = (i % 3) match { case 0 => "A" case 1 => "B"