Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ namespace deepmd {
typedef double ENERGYTYPE;
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };

/**
* @brief Get the backend of the model.
* @param[in] model The model name.
* @return The backend of the model.
**/
DPBackend get_backend(const std::string& model);

struct NeighborListData {
/// Array stores the core region atom's index
std::vector<int> ilist;
Expand Down
3 changes: 1 addition & 2 deletions source/api_cc/src/DataModifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@ void DipoleChargeModifier::init(const std::string& model,
<< std::endl;
return;
}
// TODO: To implement detect_backend
DPBackend backend = deepmd::DPBackend::TensorFlow;
const DPBackend backend = get_backend(model);
if (deepmd::DPBackend::TensorFlow == backend) {
#ifdef BUILD_TENSORFLOW
dcm = std::make_shared<deepmd::DipoleChargeModifierTF>(model, gpu_rank,
Expand Down
12 changes: 1 addition & 11 deletions source/api_cc/src/DeepPot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,7 @@ void DeepPot::init(const std::string& model,
<< std::endl;
return;
}
DPBackend backend;
if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
backend = deepmd::DPBackend::PyTorch;
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
backend = deepmd::DPBackend::TensorFlow;
} else if (model.length() >= 11 &&
model.substr(model.length() - 11) == ".savedmodel") {
backend = deepmd::DPBackend::JAX;
} else {
throw deepmd::deepmd_exception("Unsupported model file format");
}
const DPBackend backend = get_backend(model);
if (deepmd::DPBackend::TensorFlow == backend) {
#ifdef BUILD_TENSORFLOW
dp = std::make_shared<deepmd::DeepPotTF>(model, gpu_rank, file_content);
Expand Down
9 changes: 1 addition & 8 deletions source/api_cc/src/DeepSpin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,7 @@ void DeepSpin::init(const std::string& model,
<< std::endl;
return;
}
DPBackend backend;
if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
backend = deepmd::DPBackend::PyTorch;
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
backend = deepmd::DPBackend::TensorFlow;
} else {
throw deepmd::deepmd_exception("Unsupported model file format");
}
const DPBackend backend = get_backend(model);
if (deepmd::DPBackend::TensorFlow == backend) {
#ifdef BUILD_TENSORFLOW
dp = std::make_shared<deepmd::DeepSpinTF>(model, gpu_rank, file_content);
Expand Down
3 changes: 1 addition & 2 deletions source/api_cc/src/DeepTensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ void DeepTensor::init(const std::string &model,
<< std::endl;
return;
}
// TODO: To implement detect_backend
DPBackend backend = deepmd::DPBackend::TensorFlow;
const DPBackend backend = get_backend(model);
if (deepmd::DPBackend::TensorFlow == backend) {
#ifdef BUILD_TENSORFLOW
dt = std::make_shared<deepmd::DeepTensorTF>(model, gpu_rank, name_scope_);
Expand Down
12 changes: 12 additions & 0 deletions source/api_cc/src/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1399,3 +1399,15 @@ void deepmd::print_summary(const std::string& pre) {
<< "set tf inter_op_parallelism_threads: " << num_inter_nthreads
<< std::endl;
}

deepmd::DPBackend deepmd::get_backend(const std::string& model) {
if (model.length() >= 4 && model.substr(model.length() - 4) == ".pth") {
return deepmd::DPBackend::PyTorch;
} else if (model.length() >= 3 && model.substr(model.length() - 3) == ".pb") {
return deepmd::DPBackend::TensorFlow;
} else if (model.length() >= 11 &&
model.substr(model.length() - 11) == ".savedmodel") {
return deepmd::DPBackend::JAX;
}
throw deepmd::deepmd_exception("Unsupported model file format");
}