Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/core/jni/JniCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ static inline jclass createGlobalClassReference(JNIEnv* env, const char* classNa
static inline jclass createGlobalClassReferenceOrError(JNIEnv* env, const char* className) {
jclass globalClass = createGlobalClassReference(env, className);
if (globalClass == nullptr) {
std::string errorMessage = "Unable to CreateGlobalClassReferenceOrError for" + std::string(className);
std::string errorMessage = "Unable to create global class reference for" + std::string(className);
throw gluten::GlutenException(errorMessage);
}
return globalClass;
Expand Down
6 changes: 1 addition & 5 deletions cpp/velox/benchmarks/PlanValidatorUtil.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,11 @@ int main(int argc, char** argv) {
std::unordered_map<std::string, std::string> conf;
conf.insert({kDebugModeEnabled, "true"});
initVeloxBackend(conf);
std::unordered_map<std::string, std::string> configs{{core::QueryConfig::kSparkPartitionId, "0"}};
auto queryCtx = core::QueryCtx::create(nullptr, core::QueryConfig(configs));
auto pool = defaultLeafVeloxMemoryPool().get();
core::ExecCtx execCtx(pool, queryCtx.get());
SubstraitToVeloxPlanValidator planValidator(pool);

::substrait::Plan subPlan;
parseProtobuf(reinterpret_cast<uint8_t*>(plan.data()), plan.size(), &subPlan);

SubstraitToVeloxPlanValidator planValidator(pool, &execCtx);
try {
if (!planValidator.validate(subPlan)) {
auto reason = planValidator.getValidateLog();
Expand Down
45 changes: 18 additions & 27 deletions cpp/velox/jni/VeloxJniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ using namespace gluten;
using namespace facebook;

namespace {
jclass infoCls;
jmethodID infoClsInitMethod;

jclass blockStripesClass;
jmethodID blockStripesConstructor;
} // namespace
Expand All @@ -61,6 +64,9 @@ jint JNI_OnLoad(JavaVM* vm, void*) {
initVeloxJniFileSystem(env);
initVeloxJniUDF(env);

infoCls = createGlobalClassReferenceOrError(env, "Lorg/apache/gluten/validate/NativePlanValidationInfo;");
infoClsInitMethod = env->GetMethodID(infoCls, "<init>", "(ILjava/lang/String;)V");

blockStripesClass =
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/sql/execution/datasources/BlockStripes;");
blockStripesConstructor = env->GetMethodID(blockStripesClass, "<init>", "(J[J[II[B)V");
Expand Down Expand Up @@ -116,52 +122,37 @@ Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail
jobject wrapper,
jbyteArray planArray) {
JNI_METHOD_START
auto ctx = getRuntime(env, wrapper);
auto safeArray = getByteArrayElementsSafe(env, planArray);
auto planData = safeArray.elems();
auto planSize = env->GetArrayLength(planArray);
auto runtime = dynamic_cast<VeloxRuntime*>(ctx);
const auto ctx = getRuntime(env, wrapper);
const auto safeArray = getByteArrayElementsSafe(env, planArray);
const auto planData = safeArray.elems();
const auto planSize = env->GetArrayLength(planArray);
const auto runtime = dynamic_cast<VeloxRuntime*>(ctx);
if (runtime->debugModeEnabled()) {
try {
auto jsonPlan = substraitFromPbToJson("Plan", planData, planSize, std::nullopt);
const auto jsonPlan = substraitFromPbToJson("Plan", planData, planSize, std::nullopt);
LOG(INFO) << std::string(50, '#') << " received substrait::Plan: for validation";
LOG(INFO) << jsonPlan;
} catch (const std::exception& e) {
LOG(WARNING) << "Error converting Substrait plan for validation to JSON: " << e.what();
}
}

const auto pool = defaultLeafVeloxMemoryPool().get();
SubstraitToVeloxPlanValidator planValidator(pool);
::substrait::Plan subPlan;
parseProtobuf(planData, planSize, &subPlan);

// A query context with dummy configs. Used for function validation.
std::unordered_map<std::string, std::string> configs{
{velox::core::QueryConfig::kSparkPartitionId, "0"}, {velox::core::QueryConfig::kSessionTimezone, "GMT"}};
auto queryCtx = velox::core::QueryCtx::create(nullptr, velox::core::QueryConfig(configs));
auto pool = defaultLeafVeloxMemoryPool().get();
// An execution context used for function validation.
velox::core::ExecCtx execCtx(pool, queryCtx.get());

SubstraitToVeloxPlanValidator planValidator(pool, &execCtx);
jclass infoCls = env->FindClass("Lorg/apache/gluten/validate/NativePlanValidationInfo;");
if (infoCls == nullptr) {
std::string errorMessage = "Unable to CreateGlobalClassReferenceOrError for NativePlanValidationInfo";
throw GlutenException(errorMessage);
}
jmethodID method = env->GetMethodID(infoCls, "<init>", "(ILjava/lang/String;)V");
try {
auto isSupported = planValidator.validate(subPlan);
auto logs = planValidator.getValidateLog();
const auto isSupported = planValidator.validate(subPlan);
const auto logs = planValidator.getValidateLog();
std::string concatLog;
for (int i = 0; i < logs.size(); i++) {
concatLog += logs[i] + "@";
}
return env->NewObject(infoCls, method, isSupported, env->NewStringUTF(concatLog.c_str()));
return env->NewObject(infoCls, infoClsInitMethod, isSupported, env->NewStringUTF(concatLog.c_str()));
} catch (std::invalid_argument& e) {
LOG(INFO) << "Failed to validate substrait plan because " << e.what();
// return false;
auto isSupported = false;
return env->NewObject(infoCls, method, isSupported, env->NewStringUTF(""));
return env->NewObject(infoCls, infoClsInitMethod, false, env->NewStringUTF(""));
}
JNI_METHOD_END(nullptr)
}
Expand Down
22 changes: 11 additions & 11 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ExpandRel& expan
if (rowType) {
// Try to compile the expressions. If there is any unregistered
// function or mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
exec::ExprSet exprSet(std::move(expressions), execCtx_.get());
}
} else {
LOG_VALIDATION_MSG("Only SwitchingField is supported in ExpandRel.");
Expand Down Expand Up @@ -669,7 +669,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo
}
// Try to compile the expressions. If there is any unregistred funciton or
// mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
exec::ExprSet exprSet(std::move(expressions), execCtx_.get());

// Validate Sort expression
const auto& sorts = windowRel.sorts();
Expand All @@ -692,7 +692,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo
LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field.");
return false;
}
exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
exec::ExprSet exprSet1({std::move(expression)}, execCtx_.get());
}
}

Expand Down Expand Up @@ -740,7 +740,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit
}
// Try to compile the expressions. If there is any unregistered function or
// mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
exec::ExprSet exprSet(std::move(expressions), execCtx_.get());
// Validate Sort expression
const auto& sorts = windowGroupLimitRel.sorts();
for (const auto& sort : sorts) {
Expand All @@ -762,7 +762,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit
LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort Operator only support field.");
return false;
}
exec::ExprSet exprSet1({std::move(expression)}, execCtx_);
exec::ExprSet exprSet1({std::move(expression)}, execCtx_.get());
}
}

Expand Down Expand Up @@ -864,7 +864,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel
LOG_VALIDATION_MSG("in SortRel, the sorting key in Sort Operator only support field.");
return false;
}
exec::ExprSet exprSet({std::move(expression)}, execCtx_);
exec::ExprSet exprSet({std::move(expression)}, execCtx_.get());
}
}

Expand Down Expand Up @@ -911,7 +911,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ProjectRel& proj
}
// Try to compile the expressions. If there is any unregistered function or
// mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
exec::ExprSet exprSet(std::move(expressions), execCtx_.get());
return true;
}

Expand Down Expand Up @@ -950,7 +950,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FilterRel& filte
expressions.emplace_back(exprConverter_->toVeloxExpr(filterRel.condition(), rowType));
// Try to compile the expressions. If there is any unregistered function
// or mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
exec::ExprSet exprSet(std::move(expressions), execCtx_.get());
return true;
}

Expand Down Expand Up @@ -1024,7 +1024,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel

if (joinRel.has_post_join_filter()) {
auto expression = exprConverter_->toVeloxExpr(joinRel.post_join_filter(), rowType);
exec::ExprSet exprSet({std::move(expression)}, execCtx_);
exec::ExprSet exprSet({std::move(expression)}, execCtx_.get());
}
return true;
}
Expand Down Expand Up @@ -1073,7 +1073,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR

if (crossRel.has_expression()) {
auto expression = exprConverter_->toVeloxExpr(crossRel.expression(), rowType);
exec::ExprSet exprSet({std::move(expression)}, execCtx_);
exec::ExprSet exprSet({std::move(expression)}, execCtx_.get());
}

return true;
Expand Down Expand Up @@ -1299,7 +1299,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ReadRel& readRel
expressions.emplace_back(exprConverter_->toVeloxExpr(readRel.filter(), rowType));
// Try to compile the expressions. If there is any unregistered function
// or mismatched type, exception will be thrown.
exec::ExprSet exprSet(std::move(expressions), execCtx_);
exec::ExprSet exprSet(std::move(expressions), execCtx_.get());
}

return true;
Expand Down
19 changes: 11 additions & 8 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@
#include "SubstraitToVeloxPlan.h"
#include "velox/core/QueryCtx.h"

using namespace facebook;

namespace gluten {

/// This class is used to validate whether the computing of
/// a Substrait plan is supported in Velox.
class SubstraitToVeloxPlanValidator {
public:
SubstraitToVeloxPlanValidator(memory::MemoryPool* pool, core::ExecCtx* execCtx)
: pool_(pool), execCtx_(execCtx), planConverter_(pool_, confMap_, std::nullopt, true) {}
SubstraitToVeloxPlanValidator(memory::MemoryPool* pool) : planConverter_(pool, {}, std::nullopt, true) {
const std::unordered_map<std::string, std::string> configs{
{velox::core::QueryConfig::kSparkPartitionId, "0"}, {velox::core::QueryConfig::kSessionTimezone, "GMT"}};
queryCtx_ = velox::core::QueryCtx::create(nullptr, velox::core::QueryConfig(configs));
// An execution context used for function validation.
execCtx_ = std::make_unique<velox::core::ExecCtx>(pool, queryCtx_.get());
}

/// Used to validate whether the computing of this Plan is supported.
bool validate(const ::substrait::Plan& plan);
Expand Down Expand Up @@ -88,14 +95,10 @@ class SubstraitToVeloxPlanValidator {
/// Used to validate whether the computing of this RelRoot is supported.
bool validate(const ::substrait::RelRoot& relRoot);

/// A memory pool used for function validation.
memory::MemoryPool* pool_;
std::shared_ptr<velox::core::QueryCtx> queryCtx_;

/// An execution context used for function validation.
core::ExecCtx* execCtx_;

// Unused customized conf map.
std::unordered_map<std::string, std::string> confMap_ = {};
std::unique_ptr<core::ExecCtx> execCtx_;

/// A converter used to convert Substrait plan into Velox's plan node.
SubstraitToVeloxPlanConverter planConverter_;
Expand Down
7 changes: 1 addition & 6 deletions cpp/velox/tests/Substrait2VeloxPlanValidatorTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,7 @@ class Substrait2VeloxPlanValidatorTest : public exec::test::HiveConnectorTestBas
}

bool validatePlan(::substrait::Plan& plan) {
auto queryCtx = core::QueryCtx::create();

// An execution context used for function validation.
std::unique_ptr<core::ExecCtx> execCtx = std::make_unique<core::ExecCtx>(pool_.get(), queryCtx.get());

auto planValidator = std::make_shared<SubstraitToVeloxPlanValidator>(pool_.get(), execCtx.get());
auto planValidator = std::make_shared<SubstraitToVeloxPlanValidator>(pool_.get());
return planValidator->validate(plan);
}
};
Expand Down
Loading