diff --git a/src/enclave/Enclave/ExpressionEvaluation.h b/src/enclave/Enclave/ExpressionEvaluation.h index 299088568c..1c91d2e3f4 100644 --- a/src/enclave/Enclave/ExpressionEvaluation.h +++ b/src/enclave/Enclave/ExpressionEvaluation.h @@ -265,6 +265,25 @@ class FlatbuffersExpressionEvaluator { case tuix::ExprUnion_Literal: { + auto * literal = static_cast<const tuix::Literal *>(expr->expr()); + const tuix::Field *value = literal->value(); + + // If type is CalendarInterval, manually return a calendar interval field. + // Otherwise 'days' disappears in conversion. + if (value->value_type() == tuix::FieldUnion_CalendarIntervalField) { + + auto *interval = value->value_as_CalendarIntervalField(); + uint32_t months = interval->months(); + uint32_t days = interval->days(); + uint64_t ms = interval->microseconds(); + + return tuix::CreateField( + builder, + tuix::FieldUnion_CalendarIntervalField, + tuix::CreateCalendarIntervalField(builder, months, days, ms).Union(), + false); + } + return flatbuffers_copy( static_cast<const tuix::Literal *>(expr->expr())->value(), builder); } @@ -403,6 +422,7 @@ class FlatbuffersExpressionEvaluator { auto add = static_cast<const tuix::Add *>(expr->expr()); auto left_offset = eval_helper(row, add->left()); auto right_offset = eval_helper(row, add->right()); + return eval_binary_arithmetic_op<tuix::Add, std::plus>( builder, flatbuffers::GetTemporaryPointer(builder, left_offset), @@ -1041,6 +1061,102 @@ class FlatbuffersExpressionEvaluator { false); } + // Time expressions + case tuix::ExprUnion_DateAdd: + { + auto c = static_cast<const tuix::DateAdd *>(expr->expr()); + auto left_offset = eval_helper(row, c->left()); + auto right_offset = eval_helper(row, c->right()); + + // Note: These temporary pointers will be invalidated when we next write to builder + const tuix::Field *left = flatbuffers::GetTemporaryPointer(builder, left_offset); + const tuix::Field *right = flatbuffers::GetTemporaryPointer(builder, right_offset); + + if (left->value_type() != tuix::FieldUnion_DateField + || right->value_type() != tuix::FieldUnion_IntegerField) { + throw 
std::runtime_error( + std::string("tuix::DateAdd requires date Date, increment Integer, not ") + + std::string("date ") + + std::string(tuix::EnumNameFieldUnion(left->value_type())) + + std::string(", increment ") + + std::string(tuix::EnumNameFieldUnion(right->value_type()))); + } + + bool result_is_null = left->is_null() || right->is_null(); + + if (!result_is_null) { + auto left_field = static_cast<const tuix::DateField *>(left->value()); + auto right_field = static_cast<const tuix::IntegerField *>(right->value()); + + uint32_t result = left_field->value() + right_field->value(); + + return tuix::CreateField( + builder, + tuix::FieldUnion_DateField, + tuix::CreateDateField(builder, result).Union(), + result_is_null); + } else { + uint32_t result = 0; + return tuix::CreateField( + builder, + tuix::FieldUnion_DateField, + tuix::CreateDateField(builder, result).Union(), + result_is_null); + } + } + + case tuix::ExprUnion_DateAddInterval: + { + auto c = static_cast<const tuix::DateAddInterval *>(expr->expr()); + auto left_offset = eval_helper(row, c->left()); + auto right_offset = eval_helper(row, c->right()); + + // Note: These temporary pointers will be invalidated when we next write to builder + const tuix::Field *left = flatbuffers::GetTemporaryPointer(builder, left_offset); + const tuix::Field *right = flatbuffers::GetTemporaryPointer(builder, right_offset); + + if (left->value_type() != tuix::FieldUnion_DateField + || right->value_type() != tuix::FieldUnion_CalendarIntervalField) { + throw std::runtime_error( + std::string("tuix::DateAddInterval requires date Date, interval CalendarIntervalField, not ") + + std::string("date ") + + std::string(tuix::EnumNameFieldUnion(left->value_type())) + + std::string(", interval ") + + std::string(tuix::EnumNameFieldUnion(right->value_type()))); + } + + bool result_is_null = left->is_null() || right->is_null(); + uint32_t result = 0; + + if (!result_is_null) { + + auto left_field = static_cast<const tuix::DateField *>(left->value()); + auto right_field = static_cast<const tuix::CalendarIntervalField *>(right->value()); + + //This is an approximation + //TODO take 
into account leap seconds + uint64_t date = 86400L*left_field->value(); + struct tm tm; + secs_to_tm(date, &tm); + tm.tm_mon += right_field->months(); + tm.tm_mday += right_field->days(); + time_t time = std::mktime(&tm); + result = (time + (right_field->microseconds() / 1000000)) / 86400L; + + return tuix::CreateField( + builder, + tuix::FieldUnion_DateField, + tuix::CreateDateField(builder, result).Union(), + result_is_null); + } else { + return tuix::CreateField( + builder, + tuix::FieldUnion_DateField, + tuix::CreateDateField(builder, result).Union(), + result_is_null); + } + } + case tuix::ExprUnion_Year: { auto e = static_cast<const tuix::Year *>(expr->expr()); diff --git a/src/flatbuffers/Expr.fbs b/src/flatbuffers/Expr.fbs index 6e29cf2c95..d09441942c 100644 --- a/src/flatbuffers/Expr.fbs +++ b/src/flatbuffers/Expr.fbs @@ -36,7 +36,9 @@ union ExprUnion { Exp, ClosestPoint, CreateArray, - Upper + Upper, + DateAdd, + DateAddInterval } table Expr { @@ -165,6 +167,16 @@ table Year { child:Expr; } +table DateAdd { + left:Expr; + right:Expr; +} + +table DateAddInterval { + left:Expr; + right:Expr; +} + // Math expressions table Exp { child:Expr; diff --git a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala index 7da1a4e21a..e3da1eafda 100644 --- a/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala +++ b/src/main/scala/edu/berkeley/cs/rise/opaque/Utils.scala @@ -44,6 +44,8 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.expressions.Contains +import org.apache.spark.sql.catalyst.expressions.DateAdd +import org.apache.spark.sql.catalyst.expressions.DateAddInterval import org.apache.spark.sql.catalyst.expressions.Descending import org.apache.spark.sql.catalyst.expressions.Divide import org.apache.spark.sql.catalyst.expressions.EndsWith @@ -69,6 
+71,7 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.expressions.StartsWith import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.Subtract +import org.apache.spark.sql.catalyst.expressions.TimeAdd import org.apache.spark.sql.catalyst.expressions.UnaryMinus import org.apache.spark.sql.catalyst.expressions.Upper import org.apache.spark.sql.catalyst.expressions.Year @@ -1000,6 +1003,7 @@ object Utils extends Logging { tuix.Contains.createContains( builder, leftOffset, rightOffset)) + // Time expressions case (Year(child), Seq(childOffset)) => tuix.Expr.createExpr( builder, @@ -1007,6 +1011,20 @@ object Utils extends Logging { tuix.Year.createYear( builder, childOffset)) + case (DateAdd(left, right), Seq(leftOffset, rightOffset)) => + tuix.Expr.createExpr( + builder, + tuix.ExprUnion.DateAdd, + tuix.DateAdd.createDateAdd( + builder, leftOffset, rightOffset)) + + case (DateAddInterval(left, right, _, _), Seq(leftOffset, rightOffset)) => + tuix.Expr.createExpr( + builder, + tuix.ExprUnion.DateAddInterval, + tuix.DateAddInterval.createDateAddInterval( + builder, leftOffset, rightOffset)) + // Math expressions case (Exp(child), Seq(childOffset)) => tuix.Expr.createExpr( diff --git a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala index d0a2e2ffe9..219a39c54e 100644 --- a/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala +++ b/src/test/scala/edu/berkeley/cs/rise/opaque/OpaqueOperatorTests.scala @@ -122,6 +122,45 @@ trait OpaqueOperatorTests extends FunSuite with BeforeAndAfterAll { self => } } + testAgainstSpark("Interval SQL") { securityLevel => + val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) + val df = makeDF(data, securityLevel, "index", "time") + df.createTempView("Interval") + try { + spark.sql("SELECT time + INTERVAL 7 DAY 
FROM Interval").collect + } finally { + spark.catalog.dropTempView("Interval") + } + } + + testAgainstSpark("Interval Week SQL") { securityLevel => + val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) + val df = makeDF(data, securityLevel, "index", "time") + df.createTempView("Interval") + try { + spark.sql("SELECT time + INTERVAL 7 WEEK FROM Interval").collect + } finally { + spark.catalog.dropTempView("Interval") + } + } + + testAgainstSpark("Interval Month SQL") { securityLevel => + val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) + val df = makeDF(data, securityLevel, "index", "time") + df.createTempView("Interval") + try { + spark.sql("SELECT time + INTERVAL 6 MONTH FROM Interval").collect + } finally { + spark.catalog.dropTempView("Interval") + } + } + + testAgainstSpark("Date Add") { securityLevel => + val data = Seq(Tuple2(1, new java.sql.Date(new java.util.Date().getTime()))) + val df = makeDF(data, securityLevel, "index", "time") + df.select(date_add($"time", 3)).collect + } + testAgainstSpark("create DataFrame from sequence") { securityLevel => val data = for (i <- 0 until 5) yield ("foo", i) makeDF(data, securityLevel, "word", "count").collect