diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index 5d46dbdcdd0e..dd1be7805c75 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -21,6 +21,7 @@ #include "operators/functions/RowFunctionWithNull.h" #include "velox/expression/SpecialFormRegistry.h" #include "velox/expression/VectorFunction.h" +#include "velox/functions/iceberg/Register.h" #include "velox/functions/lib/CheckedArithmetic.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" @@ -91,6 +92,8 @@ void registerAllFunctions() { // Using function overwrite to handle function names mismatch between Spark // and Velox. registerFunctionOverwrite(); + + velox::functions::iceberg::registerFunctions(); } } // namespace gluten diff --git a/gluten-substrait/pom.xml b/gluten-substrait/pom.xml index 06aee5d59c07..8abfe77b4dbc 100644 --- a/gluten-substrait/pom.xml +++ b/gluten-substrait/pom.xml @@ -262,7 +262,7 @@ com.google.protobuf:protoc:${protobuf.version}:exe:${os.detected.classifier} src/main/resources/substrait/proto - true + false diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index 4d62ae18048f..4c101e638bc3 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -144,6 +144,31 @@ object ExpressionConverter extends SQLConfHelper with Logging { DecimalArithmeticExpressionTransformer(substraitName, leftChild, rightChild, resultType, b) } + private def replaceIcebergStaticInvoke( + s: StaticInvoke, + attributeSeq: Seq[Attribute], + expressionsMap: Map[Class[_], String]): ExpressionTransformer = { + val invokeMap = Map( + "BucketFunction" -> ExpressionNames.BUCKET, + "TruncateFunction" -> ExpressionNames.TRUNCATE, + "YearsFunction" -> ExpressionNames.YEARS, + "MonthsFunction" -> ExpressionNames.MONTHS, + "DaysFunction" -> ExpressionNames.DAYS, + "HoursFunction" -> ExpressionNames.HOURS + ) + val objName = s.staticObject.getName + val transformer = invokeMap.find { + case (func, _) => objName.startsWith("org.apache.iceberg.spark.functions." + func) + } + if (transformer.isEmpty) { + throw new GlutenNotSupportException(s"Not supported staticInvoke call object: $objName") + } + GenericExpressionTransformer( + transformer.get._2, + s.arguments.map(replaceWithExpressionTransformer0(_, attributeSeq, expressionsMap)), + s) + } + private def replaceWithExpressionTransformer0( expr: Expression, attributeSeq: Seq[Attribute], @@ -180,6 +205,10 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformer0(i.arguments.head, attributeSeq, expressionsMap), i ) + case i: StaticInvoke + if i.functionName == "invoke" && i.staticObject.getName.startsWith( + "org.apache.iceberg.spark.functions.") => + return replaceIcebergStaticInvoke(i, attributeSeq, expressionsMap) case i: StaticInvoke => throw new GlutenNotSupportException( s"Not supported to transform StaticInvoke with object: ${i.staticObject.getName}, " + diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala index 021d02a13749..8fcdc01e5e86 100644 --- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala +++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala @@ -371,4 +371,12 @@ object ExpressionNames { // A placeholder for native UDF functions final val UDF_PLACEHOLDER = "udf_placeholder" final val UDAF_PLACEHOLDER = "udaf_placeholder" + + // Iceberg function names + final val YEARS = "years" + final val MONTHS = "months" + final val DAYS = "days" + final val HOURS = "hours" + final val BUCKET = "bucket" + final val TRUNCATE = "truncate" }