diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/runtime/hexagon/hexagon_device_api.cc index 84232a614428..06254fba4585 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/runtime/hexagon/hexagon_device_api.cc @@ -217,7 +217,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM int ret = DMA_RETRY; do { - ret = HexagonUserDMA::Get().Copy(dst, src, size); + ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(dst, src, size); } while (ret == DMA_RETRY); *rv = static_cast(ret); }); @@ -227,7 +227,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVM ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA"); int inflight = args[1]; ICHECK(inflight >= 0); - HexagonUserDMA::Get().Wait(inflight); + HexagonDeviceAPI::Global()->UserDMA()->Wait(inflight); *rv = static_cast(0); }); diff --git a/src/runtime/hexagon/hexagon_device_api.h b/src/runtime/hexagon/hexagon_device_api.h index 4f544faffba1..555ca0fa51a8 100644 --- a/src/runtime/hexagon/hexagon_device_api.h +++ b/src/runtime/hexagon/hexagon_device_api.h @@ -32,6 +32,7 @@ #include "hexagon_buffer.h" #include "hexagon_buffer_manager.h" #include "hexagon_thread_manager.h" +#include "hexagon_user_dma.h" namespace tvm { namespace runtime { @@ -61,10 +62,18 @@ class HexagonDeviceAPI final : public DeviceAPI { CHECK_EQ(runtime_threads, nullptr); runtime_threads = std::make_unique(threads, stack_size, pipe_size); DLOG(INFO) << "runtime_threads created"; + + CHECK_EQ(runtime_dma, nullptr); + runtime_dma = std::make_unique(); + DLOG(INFO) << "runtime_dma created"; } //! \brief Ensures all runtime resources are freed void ReleaseResources() { + CHECK(runtime_dma) << "runtime_dma was not created in AcquireResources"; + runtime_dma.reset(); + DLOG(INFO) << "runtime_dma reset"; + CHECK(runtime_threads) << "runtime_threads was not created in AcquireResources"; runtime_threads.reset(); DLOG(INFO) << "runtime_threads reset"; @@ -150,7 +159,13 @@ class HexagonDeviceAPI final : public DeviceAPI { void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final; HexagonThreadManager* ThreadManager() { - return runtime_threads ? runtime_threads.get() : nullptr; + CHECK(runtime_threads) << "runtime_threads has not been created"; + return runtime_threads.get(); + } + + HexagonUserDMA* UserDMA() { + CHECK(runtime_dma) << "runtime_dma has not been created"; + return runtime_dma.get(); } protected: @@ -184,6 +199,9 @@ class HexagonDeviceAPI final : public DeviceAPI { const unsigned threads{6}; const unsigned pipe_size{1000}; const unsigned stack_size{0x4000}; // 16KB + + //! \brief User DMA manager + std::unique_ptr runtime_dma; }; } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/hexagon_user_dma.cc b/src/runtime/hexagon/hexagon_user_dma.cc index 8d45b7590bc4..ab464c150319 100644 --- a/src/runtime/hexagon/hexagon_user_dma.cc +++ b/src/runtime/hexagon/hexagon_user_dma.cc @@ -21,6 +21,8 @@ #include +#include "hexagon_device_api.h" + namespace tvm { namespace runtime { namespace hexagon { @@ -116,13 +118,15 @@ HexagonUserDMA::~HexagonUserDMA() { } int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) { + HexagonUserDMA* user_dma = HexagonDeviceAPI::Global()->UserDMA(); + // One DMA transfer can copy at most DESC_LENGTH_MASK bytes. // Make the common case quick. if (length <= DESC_LENGTH_MASK) { // sync DMA -> `Copy` and then `Wait(0)` - int ret_val = HexagonUserDMA::Get().Copy(dst, src, length); + int ret_val = user_dma->Copy(dst, src, length); if (ret_val != DMA_SUCCESS) return ret_val; - HexagonUserDMA::Get().Wait(0); + user_dma->Wait(0); return DMA_SUCCESS; } @@ -133,9 +137,9 @@ int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) { // Ensure there is no overflow while updating i uint32_t cur_len = std::min(length - i, DESC_LENGTH_MASK); // sync DMA -> `Copy` and then `Wait(0)` - int ret_val = HexagonUserDMA::Get().Copy(&cast_dst[i], &cast_src[i], cur_len); + int ret_val = user_dma->Copy(&cast_dst[i], &cast_src[i], cur_len); if (ret_val != DMA_SUCCESS) return ret_val; - HexagonUserDMA::Get().Wait(0); + user_dma->Wait(0); // 2 cases for new val for i: // 1. length - i <= DESC_LENGTH_MASK (<= MAX_UINT) // new_i = i + (length - i) = length, no more iter diff --git a/src/runtime/hexagon/hexagon_user_dma.h b/src/runtime/hexagon/hexagon_user_dma.h index aa00df79c4d0..f8838ee2dcc9 100644 --- a/src/runtime/hexagon/hexagon_user_dma.h +++ b/src/runtime/hexagon/hexagon_user_dma.h @@ -37,6 +37,13 @@ namespace hexagon { class HexagonUserDMA { public: + HexagonUserDMA(); + ~HexagonUserDMA(); + HexagonUserDMA(const HexagonUserDMA&) = delete; + HexagonUserDMA& operator=(const HexagonUserDMA&) = delete; + HexagonUserDMA(HexagonUserDMA&&) = delete; + HexagonUserDMA& operator=(HexagonUserDMA&&) = delete; + /*! * \brief Initiate DMA to copy memory from source to destination address * \param dst Destination address @@ -59,21 +66,7 @@ class HexagonUserDMA { */ uint32_t Poll(); - //! \brief HexagonUserDMA uses the singleton pattern - static HexagonUserDMA& Get() { - static HexagonUserDMA* hud = new HexagonUserDMA(); - return *hud; - } - private: - // HexagonUserDMA uses the singleton pattern - HexagonUserDMA(); - ~HexagonUserDMA(); - HexagonUserDMA(const HexagonUserDMA&) = delete; - HexagonUserDMA& operator=(const HexagonUserDMA&) = delete; - HexagonUserDMA(HexagonUserDMA&&) = delete; - HexagonUserDMA& operator=(HexagonUserDMA&&) = delete; - //! \brief Initializes the Hexagon User DMA engine unsigned int Init(); diff --git a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc index b54e40e87958..d0f962cfcee5 100644 --- a/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc @@ -171,7 +171,16 @@ TEST_F(HexagonDeviceAPITest, thread_manager) { HexagonThreadManager* threads = hexapi->ThreadManager(); CHECK(threads != nullptr); hexapi->ReleaseResources(); - threads = hexapi->ThreadManager(); - CHECK(threads == nullptr); + EXPECT_THROW(hexapi->ThreadManager(), InternalError); + hexapi->AcquireResources(); +} + +// Ensure thread manager is properly configured and destroyed +// in Acquire/Release +TEST_F(HexagonDeviceAPITest, user_dma) { + HexagonUserDMA* user_dma = hexapi->UserDMA(); + CHECK(user_dma != nullptr); + hexapi->ReleaseResources(); + EXPECT_THROW(hexapi->UserDMA(), InternalError); hexapi->AcquireResources(); } diff --git a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc index bf7a23712d7d..fb46cb3fd976 100644 --- a/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc +++ b/tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc @@ -19,13 +19,14 @@ #include -#include "../src/runtime/hexagon/hexagon_user_dma.h" +#include "../src/runtime/hexagon/hexagon_device_api.h" using namespace tvm::runtime; using namespace tvm::runtime::hexagon; class HexagonUserDMATest : public ::testing::Test { void SetUp() override { + user_dma = HexagonDeviceAPI::Global()->UserDMA(); src = malloc(length); dst = malloc(length); ASSERT_NE(src, nullptr); @@ -44,6 +45,7 @@ class HexagonUserDMATest : public ::testing::Test { } public: + HexagonUserDMA* user_dma; int ret{0}; void* src{nullptr}; void* dst{nullptr}; @@ -53,29 +55,29 @@ class HexagonUserDMATest : public ::testing::Test { }; TEST_F(HexagonUserDMATest, wait) { - HexagonUserDMA::Get().Wait(0); - HexagonUserDMA::Get().Wait(10); + user_dma->Wait(0); + user_dma->Wait(10); } -TEST_F(HexagonUserDMATest, poll) { ASSERT_EQ(HexagonUserDMA::Get().Poll(), 0); } +TEST_F(HexagonUserDMATest, poll) { ASSERT_EQ(user_dma->Poll(), 0); } TEST_F(HexagonUserDMATest, bad_copy) { uint64_t bigaddr = 0x100000000; void* src64 = reinterpret_cast(bigaddr); void* dst64 = reinterpret_cast(bigaddr); uint32_t biglength = 0x1000000; - ASSERT_NE(HexagonUserDMA::Get().Copy(dst64, src, length), DMA_SUCCESS); - ASSERT_NE(HexagonUserDMA::Get().Copy(dst, src64, length), DMA_SUCCESS); - ASSERT_NE(HexagonUserDMA::Get().Copy(dst, src, biglength), DMA_SUCCESS); + ASSERT_NE(user_dma->Copy(dst64, src, length), DMA_SUCCESS); + ASSERT_NE(user_dma->Copy(dst, src64, length), DMA_SUCCESS); + ASSERT_NE(user_dma->Copy(dst, src, biglength), DMA_SUCCESS); } TEST_F(HexagonUserDMATest, sync_dma) { // kick off 1 DMA - ret = HexagonUserDMA::Get().Copy(dst, src, length); + ret = user_dma->Copy(dst, src, length); ASSERT_EQ(ret, DMA_SUCCESS); // wait for DMA to complete - HexagonUserDMA::Get().Wait(0); + user_dma->Wait(0); // verify for (uint32_t i = 0; i < length; ++i) { @@ -86,12 +88,12 @@ TEST_F(HexagonUserDMATest, sync_dma) { TEST_F(HexagonUserDMATest, async_dma_wait) { // kick off 10x duplicate DMAs for (uint32_t i = 0; i < 10; ++i) { - ret = HexagonUserDMA::Get().Copy(dst, src, length); + ret = user_dma->Copy(dst, src, length); ASSERT_EQ(ret, DMA_SUCCESS); } // wait for at least 1 DMA to complete - HexagonUserDMA::Get().Wait(9); + user_dma->Wait(9); // verify for (uint32_t i = 0; i < length; ++i) { @@ -99,18 +101,18 @@ TEST_F(HexagonUserDMATest, async_dma_wait) { } // empty the DMA queue - HexagonUserDMA::Get().Wait(0); + user_dma->Wait(0); } TEST_F(HexagonUserDMATest, async_dma_poll) { // kick off 10x duplicate DMAs for (uint32_t i = 0; i < 10; ++i) { - ret = HexagonUserDMA::Get().Copy(dst, src, length); + ret = user_dma->Copy(dst, src, length); ASSERT_EQ(ret, DMA_SUCCESS); } // poll until at least 1 DMA is complete - while (HexagonUserDMA::Get().Poll() == 10) { + while (user_dma->Poll() == 10) { }; // verify @@ -119,7 +121,7 @@ TEST_F(HexagonUserDMATest, async_dma_poll) { } // empty the DMA queue - HexagonUserDMA::Get().Wait(0); + user_dma->Wait(0); } // TODO: Run non-pipelined case with sync DMA and execution time vs. pipelined case @@ -128,26 +130,26 @@ TEST_F(HexagonUserDMATest, pipeline) { uint32_t pipeline_length = length / pipeline_depth; for (uint32_t i = 0; i < pipeline_depth; ++i) { - ret |= HexagonUserDMA::Get().Copy(dst_char + i * pipeline_length, - src_char + i * pipeline_length, pipeline_length); + ret |= user_dma->Copy(dst_char + i * pipeline_length, src_char + i * pipeline_length, + pipeline_length); } - HexagonUserDMA::Get().Wait(3); + user_dma->Wait(3); for (uint32_t i = 0; i < pipeline_length; ++i) { dst_char[i]++; } - HexagonUserDMA::Get().Wait(2); + user_dma->Wait(2); for (uint32_t i = pipeline_length; i < 2 * pipeline_length; ++i) { dst_char[i]++; } - HexagonUserDMA::Get().Wait(1); + user_dma->Wait(1); for (uint32_t i = 2 * pipeline_length; i < 3 * pipeline_length; ++i) { dst_char[i]++; } - HexagonUserDMA::Get().Wait(0); + user_dma->Wait(0); for (uint32_t i = 3 * pipeline_length; i < 4 * pipeline_length; ++i) { dst_char[i]++; } @@ -165,8 +167,8 @@ TEST_F(HexagonUserDMATest, overflow_ring_buffer) { for (uint32_t i = 0; i < number_of_dmas; ++i) { do { - ret = HexagonUserDMA::Get().Copy(dst_char + i * length_of_each_dma, - src_char + i * length_of_each_dma, length_of_each_dma); + ret = user_dma->Copy(dst_char + i * length_of_each_dma, src_char + i * length_of_each_dma, + length_of_each_dma); } while (ret == DMA_RETRY); ASSERT_EQ(ret, DMA_SUCCESS); }