diff --git a/cpp/core/jni/JniCommon.h b/cpp/core/jni/JniCommon.h index 0c3096d9d693..ecbac66a497f 100644 --- a/cpp/core/jni/JniCommon.h +++ b/cpp/core/jni/JniCommon.h @@ -70,7 +70,7 @@ static inline jclass createGlobalClassReference(JNIEnv* env, const char* classNa static inline jclass createGlobalClassReferenceOrError(JNIEnv* env, const char* className) { jclass globalClass = createGlobalClassReference(env, className); if (globalClass == nullptr) { - std::string errorMessage = "Unable to CreateGlobalClassReferenceOrError for" + std::string(className); + std::string errorMessage = "Unable to create global class reference for" + std::string(className); throw gluten::GlutenException(errorMessage); } return globalClass; diff --git a/cpp/velox/benchmarks/PlanValidatorUtil.cc b/cpp/velox/benchmarks/PlanValidatorUtil.cc index 46f2733f29ea..20d02db6c49e 100644 --- a/cpp/velox/benchmarks/PlanValidatorUtil.cc +++ b/cpp/velox/benchmarks/PlanValidatorUtil.cc @@ -44,15 +44,11 @@ int main(int argc, char** argv) { std::unordered_map conf; conf.insert({kDebugModeEnabled, "true"}); initVeloxBackend(conf); - std::unordered_map configs{{core::QueryConfig::kSparkPartitionId, "0"}}; - auto queryCtx = core::QueryCtx::create(nullptr, core::QueryConfig(configs)); auto pool = defaultLeafVeloxMemoryPool().get(); - core::ExecCtx execCtx(pool, queryCtx.get()); + SubstraitToVeloxPlanValidator planValidator(pool); ::substrait::Plan subPlan; parseProtobuf(reinterpret_cast(plan.data()), plan.size(), &subPlan); - - SubstraitToVeloxPlanValidator planValidator(pool, &execCtx); try { if (!planValidator.validate(subPlan)) { auto reason = planValidator.getValidateLog(); diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc index d29f5b0f7bb0..cd751f5aedff 100644 --- a/cpp/velox/jni/VeloxJniWrapper.cc +++ b/cpp/velox/jni/VeloxJniWrapper.cc @@ -42,6 +42,9 @@ using namespace gluten; using namespace facebook; namespace { +jclass infoCls; +jmethodID infoClsInitMethod; + jclass blockStripesClass; jmethodID blockStripesConstructor; } // namespace @@ -61,6 +64,9 @@ jint JNI_OnLoad(JavaVM* vm, void*) { initVeloxJniFileSystem(env); initVeloxJniUDF(env); + infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;"); + infoClsInitMethod = env->GetMethodID(infoCls, "", "(ILjava/lang/String;)V"); + blockStripesClass = createGlobalClassReferenceOrError(env, "Lorg/apache/spark/sql/execution/datasources/BlockStripes;"); blockStripesConstructor = env->GetMethodID(blockStripesClass, "", "(J[J[II[B)V"); @@ -116,14 +122,14 @@ Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail jobject wrapper, jbyteArray planArray) { JNI_METHOD_START - auto ctx = getRuntime(env, wrapper); - auto safeArray = getByteArrayElementsSafe(env, planArray); - auto planData = safeArray.elems(); - auto planSize = env->GetArrayLength(planArray); - auto runtime = dynamic_cast(ctx); + const auto ctx = getRuntime(env, wrapper); + const auto safeArray = getByteArrayElementsSafe(env, planArray); + const auto planData = safeArray.elems(); + const auto planSize = env->GetArrayLength(planArray); + const auto runtime = dynamic_cast(ctx); if (runtime->debugModeEnabled()) { try { - auto jsonPlan = substraitFromPbToJson("Plan", planData, planSize, std::nullopt); + const auto jsonPlan = substraitFromPbToJson("Plan", planData, planSize, std::nullopt); LOG(INFO) << std::string(50, '#') << " received substrait::Plan: for validation"; LOG(INFO) << jsonPlan; } catch (const std::exception& e) { @@ -131,37 +137,22 @@ Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail } } + const auto pool = defaultLeafVeloxMemoryPool().get(); + SubstraitToVeloxPlanValidator planValidator(pool); ::substrait::Plan subPlan; parseProtobuf(planData, planSize, &subPlan); - // A query context with dummy configs. Used for function validation. - std::unordered_map configs{ - {velox::core::QueryConfig::kSparkPartitionId, "0"}, {velox::core::QueryConfig::kSessionTimezone, "GMT"}}; - auto queryCtx = velox::core::QueryCtx::create(nullptr, velox::core::QueryConfig(configs)); - auto pool = defaultLeafVeloxMemoryPool().get(); - // An execution context used for function validation. - velox::core::ExecCtx execCtx(pool, queryCtx.get()); - - SubstraitToVeloxPlanValidator planValidator(pool, &execCtx); - jclass infoCls = env->FindClass("Lorg/apache/gluten/validate/NativePlanValidationInfo;"); - if (infoCls == nullptr) { - std::string errorMessage = "Unable to CreateGlobalClassReferenceOrError for NativePlanValidationInfo"; - throw GlutenException(errorMessage); - } - jmethodID method = env->GetMethodID(infoCls, "", "(ILjava/lang/String;)V"); try { - auto isSupported = planValidator.validate(subPlan); - auto logs = planValidator.getValidateLog(); + const auto isSupported = planValidator.validate(subPlan); + const auto logs = planValidator.getValidateLog(); std::string concatLog; for (int i = 0; i < logs.size(); i++) { concatLog += logs[i] + "@"; } - return env->NewObject(infoCls, method, isSupported, env->NewStringUTF(concatLog.c_str())); + return env->NewObject(infoCls, infoClsInitMethod, isSupported, env->NewStringUTF(concatLog.c_str())); } catch (std::invalid_argument& e) { LOG(INFO) << "Failed to validate substrait plan because " << e.what(); - // return false; - auto isSupported = false; - return env->NewObject(infoCls, method, isSupported, env->NewStringUTF("")); + return env->NewObject(infoCls, infoClsInitMethod, false, env->NewStringUTF("")); } JNI_METHOD_END(nullptr) } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 899a16d0d942..bac419883b5b 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -555,7 +555,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ExpandRel& expan if (rowType) { // Try to compile the expressions. If there is any unregistered // function or mismatched type, exception will be thrown. - exec::ExprSet exprSet(std::move(expressions), execCtx_); + exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); } } else { LOG_VALIDATION_MSG("Only SwitchingField is supported in ExpandRel."); @@ -669,7 +669,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo } // Try to compile the expressions. If there is any unregistred funciton or // mismatched type, exception will be thrown. - exec::ExprSet exprSet(std::move(expressions), execCtx_); + exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); // Validate Sort expression const auto& sorts = windowRel.sorts(); @@ -692,7 +692,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet1({std::move(expression)}, execCtx_); + exec::ExprSet exprSet1({std::move(expression)}, execCtx_.get()); } } @@ -740,7 +740,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit } // Try to compile the expressions. If there is any unregistered function or // mismatched type, exception will be thrown. - exec::ExprSet exprSet(std::move(expressions), execCtx_); + exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); // Validate Sort expression const auto& sorts = windowGroupLimitRel.sorts(); for (const auto& sort : sorts) { @@ -762,7 +762,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet1({std::move(expression)}, execCtx_); + exec::ExprSet exprSet1({std::move(expression)}, execCtx_.get()); } } @@ -864,7 +864,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel LOG_VALIDATION_MSG("in SortRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet({std::move(expression)}, execCtx_.get()); } } @@ -911,7 +911,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ProjectRel& proj } // Try to compile the expressions. If there is any unregistered function or // mismatched type, exception will be thrown. - exec::ExprSet exprSet(std::move(expressions), execCtx_); + exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); return true; } @@ -950,7 +950,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FilterRel& filte expressions.emplace_back(exprConverter_->toVeloxExpr(filterRel.condition(), rowType)); // Try to compile the expressions. If there is any unregistered function // or mismatched type, exception will be thrown. - exec::ExprSet exprSet(std::move(expressions), execCtx_); + exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); return true; } @@ -1024,7 +1024,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel if (joinRel.has_post_join_filter()) { auto expression = exprConverter_->toVeloxExpr(joinRel.post_join_filter(), rowType); - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet({std::move(expression)}, execCtx_.get()); } return true; } @@ -1073,7 +1073,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR if (crossRel.has_expression()) { auto expression = exprConverter_->toVeloxExpr(crossRel.expression(), rowType); - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet({std::move(expression)}, execCtx_.get()); } return true; @@ -1299,7 +1299,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ReadRel& readRel expressions.emplace_back(exprConverter_->toVeloxExpr(readRel.filter(), rowType)); // Try to compile the expressions. If there is any unregistered function // or mismatched type, exception will be thrown. - exec::ExprSet exprSet(std::move(expressions), execCtx_); + exec::ExprSet exprSet(std::move(expressions), execCtx_.get()); } return true; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h index 881a0e514809..28d82f9cce6f 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h @@ -20,14 +20,21 @@ #include "SubstraitToVeloxPlan.h" #include "velox/core/QueryCtx.h" +using namespace facebook; + namespace gluten { /// This class is used to validate whether the computing of /// a Substrait plan is supported in Velox. class SubstraitToVeloxPlanValidator { public: - SubstraitToVeloxPlanValidator(memory::MemoryPool* pool, core::ExecCtx* execCtx) - : pool_(pool), execCtx_(execCtx), planConverter_(pool_, confMap_, std::nullopt, true) {} + SubstraitToVeloxPlanValidator(memory::MemoryPool* pool) : planConverter_(pool, {}, std::nullopt, true) { + const std::unordered_map configs{ + {velox::core::QueryConfig::kSparkPartitionId, "0"}, {velox::core::QueryConfig::kSessionTimezone, "GMT"}}; + queryCtx_ = velox::core::QueryCtx::create(nullptr, velox::core::QueryConfig(configs)); + // An execution context used for function validation. + execCtx_ = std::make_unique(pool, queryCtx_.get()); + } /// Used to validate whether the computing of this Plan is supported. bool validate(const ::substrait::Plan& plan); @@ -88,14 +95,10 @@ class SubstraitToVeloxPlanValidator { /// Used to validate whether the computing of this RelRoot is supported. bool validate(const ::substrait::RelRoot& relRoot); - /// A memory pool used for function validation. - memory::MemoryPool* pool_; + std::shared_ptr queryCtx_; /// An execution context used for function validation. - core::ExecCtx* execCtx_; - - // Unused customized conf map. - std::unordered_map confMap_ = {}; + std::unique_ptr execCtx_; /// A converter used to convert Substrait plan into Velox's plan node. SubstraitToVeloxPlanConverter planConverter_; diff --git a/cpp/velox/tests/Substrait2VeloxPlanValidatorTest.cc b/cpp/velox/tests/Substrait2VeloxPlanValidatorTest.cc index 3f90c865df16..2476e2a2f810 100644 --- a/cpp/velox/tests/Substrait2VeloxPlanValidatorTest.cc +++ b/cpp/velox/tests/Substrait2VeloxPlanValidatorTest.cc @@ -47,12 +47,7 @@ class Substrait2VeloxPlanValidatorTest : public exec::test::HiveConnectorTestBas } bool validatePlan(::substrait::Plan& plan) { - auto queryCtx = core::QueryCtx::create(); - - // An execution context used for function validation. - std::unique_ptr execCtx = std::make_unique(pool_.get(), queryCtx.get()); - - auto planValidator = std::make_shared(pool_.get(), execCtx.get()); + auto planValidator = std::make_shared(pool_.get()); return planValidator->validate(plan); } };