Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,9 @@ if(GTEST_FOUND)
if(DEFINED LLVM_LIBS)
target_link_libraries(cpptest PRIVATE ${LLVM_LIBS})
endif()
if(DEFINED ETHOSN_RUNTIME_LIBRARY)
target_link_libraries(cpptest PRIVATE ${ETHOSN_RUNTIME_LIBRARY})
endif()
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_ALL 1)
set_target_properties(cpptest PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1)
if(USE_RELAY_DEBUG)
Expand Down
70 changes: 44 additions & 26 deletions src/runtime/contrib/ethosn/ethosn_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

#include <algorithm>
#include <memory>
#include <string>

#include "ethosn_driver_library/Buffer.hpp"
#include "ethosn_runtime.h"
Expand All @@ -48,7 +49,7 @@ namespace ethosn {

namespace dl = ::ethosn::driver_library;

bool WaitForInference(dl::Inference* inference, int timeout) {
InferenceWaitStatus WaitForInference(dl::Inference* inference, int timeout) {
// Wait for inference to complete
int fd = inference->GetFileDescriptor();
struct pollfd fds;
Expand All @@ -58,20 +59,32 @@ bool WaitForInference(dl::Inference* inference, int timeout) {

const int ms_per_seconds = 1000;
int poll_result = poll(&fds, 1, timeout * ms_per_seconds);
if (poll_result > 0) {
dl::InferenceResult result;
if (read(fd, &result, sizeof(result)) != sizeof(result)) {
return false;
}
if (result != dl::InferenceResult::Completed) {
return false;
}
int poll_error_code = errno;

if (poll_result < 0) {
return InferenceWaitStatus(InferenceWaitErrorCode::kError,
"Error while waiting for the inference to complete (" +
std::string(strerror(poll_error_code)) + ")");
} else if (poll_result == 0) {
return false;
} else {
return false;
return InferenceWaitStatus(InferenceWaitErrorCode::kTimeout,
"Timed out while waiting for the inference to complete.");
}
return true;

// poll_result > 0
dl::InferenceResult npu_result;
if (read(fd, &npu_result, sizeof(npu_result)) != static_cast<ssize_t>(sizeof(npu_result))) {
return InferenceWaitStatus(
InferenceWaitErrorCode::kError,
"Failed to read inference result status (" + std::string(strerror(poll_error_code)) + ")");
}

if (npu_result != dl::InferenceResult::Completed) {
return InferenceWaitStatus(
InferenceWaitErrorCode::kError,
"Inference failed with status " + std::to_string(static_cast<uint32_t>(npu_result)));
}

return InferenceWaitStatus(InferenceWaitErrorCode::kSuccess);
}

void CreateBuffers(std::vector<std::shared_ptr<dl::Buffer>>* fm,
Expand Down Expand Up @@ -123,21 +136,26 @@ bool Inference(tvm::runtime::TVMArgs args, dl::Network* npu,
}

// Execute the inference.
std::unique_ptr<dl::Inference> result(
std::unique_ptr<dl::Inference> inference(
npu->ScheduleInference(ifm_raw, n_inputs, ofm_raw, n_outputs));
bool inferenceCompleted = WaitForInference(result.get(), 60);
if (inferenceCompleted) {
for (size_t i = 0; i < n_outputs; i++) {
DLTensor* tensor = outputs[i];
dl::Buffer* source_buffer = ofm_raw[i];
uint8_t* dest_buffer = static_cast<uint8_t*>(tensor->data);
size_t size = source_buffer->GetSize();
uint8_t* source_buffer_data = source_buffer->Map();
std::copy(source_buffer_data, source_buffer_data + size, dest_buffer);
source_buffer->Unmap();
}
InferenceWaitStatus result = WaitForInference(inference.get(), 60);

if (result.GetErrorCode() != InferenceWaitErrorCode::kSuccess) {
LOG(FATAL) << "An error has occured waiting for the inference of a sub-graph on the NPU: "
<< result.GetErrorDescription();
}

for (size_t i = 0; i < n_outputs; i++) {
DLTensor* tensor = outputs[i];
dl::Buffer* source_buffer = ofm_raw[i];
uint8_t* dest_buffer = static_cast<uint8_t*>(tensor->data);
size_t size = source_buffer->GetSize();
uint8_t* source_buffer_data = source_buffer->Map();
std::copy(source_buffer_data, source_buffer_data + size, dest_buffer);
source_buffer->Unmap();
}
return inferenceCompleted;

return true;
}

} // namespace ethosn
Expand Down
33 changes: 33 additions & 0 deletions src/runtime/contrib/ethosn/ethosn_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,39 @@ class EthosnModule : public ModuleNode {
std::map<std::string, OrderedCompiledNetwork> network_map_;
};

/*!
* \brief Error codes for evaluating the result of inference on the NPU.
*/
enum class InferenceWaitErrorCode { kSuccess = 0, kTimeout = 1, kError = 2 };

/*!
* \brief A helper class holding the status of inference on the NPU and
* associated error message(s) if any occurred.
*
* Similar to the implementation of 'WaitStatus' in the driver stack:
* https://github.com/ARM-software/ethos-n-driver-stack/blob/22.08/armnn-ethos-n-backend/workloads/EthosNPreCompiledWorkload.cpp#L48
*/
class InferenceWaitStatus {
public:
InferenceWaitStatus() : error_code_(InferenceWaitErrorCode::kSuccess), error_description_("") {}

explicit InferenceWaitStatus(InferenceWaitErrorCode errorCode, std::string errorDescription = "")
: error_code_(errorCode), error_description_(errorDescription) {}

InferenceWaitStatus(const InferenceWaitStatus&) = default;
InferenceWaitStatus(InferenceWaitStatus&&) = default;
InferenceWaitStatus& operator=(const InferenceWaitStatus&) = default;
InferenceWaitStatus& operator=(InferenceWaitStatus&&) = default;

explicit operator bool() const { return error_code_ == InferenceWaitErrorCode::kSuccess; }
InferenceWaitErrorCode GetErrorCode() const { return error_code_; }
std::string GetErrorDescription() const { return error_description_; }

private:
InferenceWaitErrorCode error_code_;
std::string error_description_;
};

} // namespace ethosn
} // namespace runtime
} // namespace tvm
Expand Down
74 changes: 74 additions & 0 deletions tests/cpp/runtime/contrib/ethosn/inference_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tests/cpp/runtime/contrib/ethosn/inference_test.cc
* \brief Tests to check Arm(R) Ethos(TM)-N runtime components used during inference.
*/

#ifdef ETHOSN_HW

#include <gtest/gtest.h>

#include "../../../../../src/runtime/contrib/ethosn/ethosn_device.cc"

namespace tvm {
namespace runtime {
namespace ethosn {

TEST(WaitForInference, InferenceScheduled) {
const int inference_result = 0 /* Scheduled */;
const int timeout = 0;

dl::Inference inference = dl::Inference(inference_result);
InferenceWaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kTimeout);
ICHECK_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

TEST(WaitForInference, InferenceRunning) {
const int inference_result = 1 /* Running */;
const int timeout = 0;

dl::Inference inference = dl::Inference(inference_result);
InferenceWaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kTimeout);
std::cout << result.GetErrorDescription() << std::endl;
ICHECK_EQ(result.GetErrorDescription(), "Timed out while waiting for the inference to complete.");
}

TEST(WaitForInference, InferenceError) {
const int inference_result = 3 /* Error */;
const int timeout = 0;

dl::Inference inference = dl::Inference(inference_result);
InferenceWaitStatus result = WaitForInference(&inference, timeout);

ASSERT_EQ(result.GetErrorCode(), InferenceWaitErrorCode::kError);
ICHECK_EQ(result.GetErrorDescription(),
"Failed to read inference result status (No such file or directory)");
}

} // namespace ethosn
} // namespace runtime
} // namespace tvm

#endif