Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Detectors/TRD/pid/include/TRDPID/ML.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,14 @@ class ML : public PIDBase
[](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, const char* message) {
LOG(warn) << "Ort " << severity << ": [" << logid << "|" << category << "|" << code_location << "]: " << message << ((intptr_t)param == 3 ? " [valid]" : " [error]");
},
(void*)3}; ///< ONNX enviroment
const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api
(void*)3}; ///< ONNX enviroment
const OrtApi& mApi{Ort::GetApi()}; ///< ONNX api
#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
std::unique_ptr<Ort::Experimental::Session> mSession; ///< ONNX session
Ort::SessionOptions mSessionOptions; ///< ONNX session options
#else
std::unique_ptr<Ort::Session> mSession; ///< ONNX session
#endif
Ort::SessionOptions mSessionOptions; ///< ONNX session options
Ort::AllocatorWithDefaultOptions mAllocator;

// Input/Output
Expand Down
11 changes: 11 additions & 0 deletions Detectors/TRD/pid/src/ML.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@ void ML::init(o2::framework::ProcessingContext& pc)
LOG(info) << "Set GraphOptimizationLevel to " << mParams.graphOptimizationLevel;

// create actual session
#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
mSession = std::make_unique<Ort::Experimental::Session>(mEnv, reinterpret_cast<void*>(model_data.data()), model_data.size(), mSessionOptions);
#else
mSession = std::make_unique<Ort::Session>(mEnv, reinterpret_cast<void*>(model_data.data()), model_data.size(), mSessionOptions);
#endif
LOG(info) << "ONNX runtime session created";

// print name/shape of inputs
Expand Down Expand Up @@ -104,8 +108,15 @@ float ML::process(const TrackTRD& trk, const o2::globaltracking::RecoContainer&
try {
auto input = prepareModelInput(trk, inputTracks);
// create memory mapping to vector above
#if __has_include(<onnxruntime/core/session/experimental_onnxruntime_cxx_api.h>)
auto inputTensor = Ort::Experimental::Value::CreateTensor<float>(input.data(), input.size(),
{static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
#else
Ort::MemoryInfo mem_info =
Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
auto inputTensor = Ort::Value::CreateTensor<float>(mem_info, input.data(), input.size(),
{static_cast<int64_t>(input.size()) / mInputShapes[0][1], mInputShapes[0][1]});
#endif
std::vector<Ort::Value> ortTensor;
ortTensor.push_back(std::move(inputTensor));
auto outTensor = mSession->Run(mInputNames, ortTensor, mOutputNames);
Expand Down
3 changes: 3 additions & 0 deletions dependencies/FindONNXRuntime.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,7 @@ endif()

if (NOT ONNXRuntime::ONNXRuntime_FOUND)
find_package(onnxruntime CONFIG)
if (onnxruntime_FOUND)
add_library(ONNXRuntime::ONNXRuntime ALIAS onnxruntime::onnxruntime)
endif()
endif()