Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVM

int ret = DMA_RETRY;
do {
ret = HexagonUserDMA::Get().Copy(dst, src, size);
ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(dst, src, size);
} while (ret == DMA_RETRY);
*rv = static_cast<int32_t>(ret);
});
Expand All @@ -227,7 +227,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.dma_wait").set_body([](TVMArgs args, TVM
ICHECK(queue_id == 0 && "Hexagon supports just a single asynchronous queue for DMA");
int inflight = args[1];
ICHECK(inflight >= 0);
HexagonUserDMA::Get().Wait(inflight);
HexagonDeviceAPI::Global()->UserDMA()->Wait(inflight);
*rv = static_cast<int32_t>(0);
});

Expand Down
20 changes: 19 additions & 1 deletion src/runtime/hexagon/hexagon_device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "hexagon_buffer.h"
#include "hexagon_buffer_manager.h"
#include "hexagon_thread_manager.h"
#include "hexagon_user_dma.h"

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -61,10 +62,18 @@ class HexagonDeviceAPI final : public DeviceAPI {
CHECK_EQ(runtime_threads, nullptr);
runtime_threads = std::make_unique<HexagonThreadManager>(threads, stack_size, pipe_size);
DLOG(INFO) << "runtime_threads created";

CHECK_EQ(runtime_dma, nullptr);
runtime_dma = std::make_unique<HexagonUserDMA>();
DLOG(INFO) << "runtime_dma created";
}

//! \brief Ensures all runtime resources are freed
void ReleaseResources() {
CHECK(runtime_dma) << "runtime_dma was not created in AcquireResources";
runtime_dma.reset();
DLOG(INFO) << "runtime_dma reset";

CHECK(runtime_threads) << "runtime_threads was not created in AcquireResources";
runtime_threads.reset();
DLOG(INFO) << "runtime_threads reset";
Expand Down Expand Up @@ -150,7 +159,13 @@ class HexagonDeviceAPI final : public DeviceAPI {
void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final;

HexagonThreadManager* ThreadManager() {
return runtime_threads ? runtime_threads.get() : nullptr;
CHECK(runtime_threads) << "runtime_threads has not been created";
return runtime_threads.get();
}

HexagonUserDMA* UserDMA() {
CHECK(runtime_dma) << "runtime_dma has not been created";
return runtime_dma.get();
}

protected:
Expand Down Expand Up @@ -184,6 +199,9 @@ class HexagonDeviceAPI final : public DeviceAPI {
const unsigned threads{6};
const unsigned pipe_size{1000};
const unsigned stack_size{0x4000}; // 16KB

//! \brief User DMA manager
std::unique_ptr<HexagonUserDMA> runtime_dma;
};
} // namespace hexagon
} // namespace runtime
Expand Down
12 changes: 8 additions & 4 deletions src/runtime/hexagon/hexagon_user_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

#include <algorithm>

#include "hexagon_device_api.h"

namespace tvm {
namespace runtime {
namespace hexagon {
Expand Down Expand Up @@ -116,13 +118,15 @@ HexagonUserDMA::~HexagonUserDMA() {
}

int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) {
HexagonUserDMA* user_dma = HexagonDeviceAPI::Global()->UserDMA();

// One DMA transfer can copy at most DESC_LENGTH_MASK bytes.
// Make the common case quick.
if (length <= DESC_LENGTH_MASK) {
// sync DMA -> `Copy` and then `Wait(0)`
int ret_val = HexagonUserDMA::Get().Copy(dst, src, length);
int ret_val = user_dma->Copy(dst, src, length);
if (ret_val != DMA_SUCCESS) return ret_val;
HexagonUserDMA::Get().Wait(0);
user_dma->Wait(0);
return DMA_SUCCESS;
}

Expand All @@ -133,9 +137,9 @@ int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) {
// Ensure there is no overflow while updating i
uint32_t cur_len = std::min<uint32_t>(length - i, DESC_LENGTH_MASK);
// sync DMA -> `Copy` and then `Wait(0)`
int ret_val = HexagonUserDMA::Get().Copy(&cast_dst[i], &cast_src[i], cur_len);
int ret_val = user_dma->Copy(&cast_dst[i], &cast_src[i], cur_len);
if (ret_val != DMA_SUCCESS) return ret_val;
HexagonUserDMA::Get().Wait(0);
user_dma->Wait(0);
// 2 cases for new val for i:
// 1. length - i <= DESC_LENGTH_MASK (<= MAX_UINT)
// new_i = i + (length - i) = length, no more iter
Expand Down
21 changes: 7 additions & 14 deletions src/runtime/hexagon/hexagon_user_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@ namespace hexagon {

class HexagonUserDMA {
public:
HexagonUserDMA();
~HexagonUserDMA();
HexagonUserDMA(const HexagonUserDMA&) = delete;
HexagonUserDMA& operator=(const HexagonUserDMA&) = delete;
HexagonUserDMA(HexagonUserDMA&&) = delete;
HexagonUserDMA& operator=(HexagonUserDMA&&) = delete;

/*!
* \brief Initiate DMA to copy memory from source to destination address
* \param dst Destination address
Expand All @@ -59,21 +66,7 @@ class HexagonUserDMA {
*/
uint32_t Poll();

//! \brief HexagonUserDMA uses the singleton pattern
static HexagonUserDMA& Get() {
static HexagonUserDMA* hud = new HexagonUserDMA();
return *hud;
}

private:
// HexagonUserDMA uses the singleton pattern
HexagonUserDMA();
~HexagonUserDMA();
HexagonUserDMA(const HexagonUserDMA&) = delete;
HexagonUserDMA& operator=(const HexagonUserDMA&) = delete;
HexagonUserDMA(HexagonUserDMA&&) = delete;
HexagonUserDMA& operator=(HexagonUserDMA&&) = delete;

//! \brief Initializes the Hexagon User DMA engine
unsigned int Init();

Expand Down
13 changes: 11 additions & 2 deletions tests/cpp-runtime/hexagon/hexagon_device_api_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,16 @@ TEST_F(HexagonDeviceAPITest, thread_manager) {
HexagonThreadManager* threads = hexapi->ThreadManager();
CHECK(threads != nullptr);
hexapi->ReleaseResources();
threads = hexapi->ThreadManager();
CHECK(threads == nullptr);
EXPECT_THROW(hexapi->ThreadManager(), InternalError);
hexapi->AcquireResources();
}

// Ensure thread manager is properly configured and destroyed
// in Acquire/Release
TEST_F(HexagonDeviceAPITest, user_dma) {
HexagonUserDMA* user_dma = hexapi->UserDMA();
CHECK(user_dma != nullptr);
hexapi->ReleaseResources();
EXPECT_THROW(hexapi->UserDMA(), InternalError);
hexapi->AcquireResources();
}
48 changes: 25 additions & 23 deletions tests/cpp-runtime/hexagon/hexagon_user_dma_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@

#include <gtest/gtest.h>

#include "../src/runtime/hexagon/hexagon_user_dma.h"
#include "../src/runtime/hexagon/hexagon_device_api.h"

using namespace tvm::runtime;
using namespace tvm::runtime::hexagon;

class HexagonUserDMATest : public ::testing::Test {
void SetUp() override {
user_dma = HexagonDeviceAPI::Global()->UserDMA();
src = malloc(length);
dst = malloc(length);
ASSERT_NE(src, nullptr);
Expand All @@ -44,6 +45,7 @@ class HexagonUserDMATest : public ::testing::Test {
}

public:
HexagonUserDMA* user_dma;
int ret{0};
void* src{nullptr};
void* dst{nullptr};
Expand All @@ -53,29 +55,29 @@ class HexagonUserDMATest : public ::testing::Test {
};

TEST_F(HexagonUserDMATest, wait) {
HexagonUserDMA::Get().Wait(0);
HexagonUserDMA::Get().Wait(10);
user_dma->Wait(0);
user_dma->Wait(10);
}

TEST_F(HexagonUserDMATest, poll) { ASSERT_EQ(HexagonUserDMA::Get().Poll(), 0); }
TEST_F(HexagonUserDMATest, poll) { ASSERT_EQ(user_dma->Poll(), 0); }

TEST_F(HexagonUserDMATest, bad_copy) {
uint64_t bigaddr = 0x100000000;
void* src64 = reinterpret_cast<void*>(bigaddr);
void* dst64 = reinterpret_cast<void*>(bigaddr);
uint32_t biglength = 0x1000000;
ASSERT_NE(HexagonUserDMA::Get().Copy(dst64, src, length), DMA_SUCCESS);
ASSERT_NE(HexagonUserDMA::Get().Copy(dst, src64, length), DMA_SUCCESS);
ASSERT_NE(HexagonUserDMA::Get().Copy(dst, src, biglength), DMA_SUCCESS);
ASSERT_NE(user_dma->Copy(dst64, src, length), DMA_SUCCESS);
ASSERT_NE(user_dma->Copy(dst, src64, length), DMA_SUCCESS);
ASSERT_NE(user_dma->Copy(dst, src, biglength), DMA_SUCCESS);
}

TEST_F(HexagonUserDMATest, sync_dma) {
// kick off 1 DMA
ret = HexagonUserDMA::Get().Copy(dst, src, length);
ret = user_dma->Copy(dst, src, length);
ASSERT_EQ(ret, DMA_SUCCESS);

// wait for DMA to complete
HexagonUserDMA::Get().Wait(0);
user_dma->Wait(0);

// verify
for (uint32_t i = 0; i < length; ++i) {
Expand All @@ -86,31 +88,31 @@ TEST_F(HexagonUserDMATest, sync_dma) {
TEST_F(HexagonUserDMATest, async_dma_wait) {
// kick off 10x duplicate DMAs
for (uint32_t i = 0; i < 10; ++i) {
ret = HexagonUserDMA::Get().Copy(dst, src, length);
ret = user_dma->Copy(dst, src, length);
ASSERT_EQ(ret, DMA_SUCCESS);
}

// wait for at least 1 DMA to complete
HexagonUserDMA::Get().Wait(9);
user_dma->Wait(9);

// verify
for (uint32_t i = 0; i < length; ++i) {
ASSERT_EQ(src_char[i], dst_char[i]);
}

// empty the DMA queue
HexagonUserDMA::Get().Wait(0);
user_dma->Wait(0);
}

TEST_F(HexagonUserDMATest, async_dma_poll) {
// kick off 10x duplicate DMAs
for (uint32_t i = 0; i < 10; ++i) {
ret = HexagonUserDMA::Get().Copy(dst, src, length);
ret = user_dma->Copy(dst, src, length);
ASSERT_EQ(ret, DMA_SUCCESS);
}

// poll until at least 1 DMA is complete
while (HexagonUserDMA::Get().Poll() == 10) {
while (user_dma->Poll() == 10) {
};

// verify
Expand All @@ -119,7 +121,7 @@ TEST_F(HexagonUserDMATest, async_dma_poll) {
}

// empty the DMA queue
HexagonUserDMA::Get().Wait(0);
user_dma->Wait(0);
}

// TODO: Run non-pipelined case with sync DMA and execution time vs. pipelined case
Expand All @@ -128,26 +130,26 @@ TEST_F(HexagonUserDMATest, pipeline) {
uint32_t pipeline_length = length / pipeline_depth;

for (uint32_t i = 0; i < pipeline_depth; ++i) {
ret |= HexagonUserDMA::Get().Copy(dst_char + i * pipeline_length,
src_char + i * pipeline_length, pipeline_length);
ret |= user_dma->Copy(dst_char + i * pipeline_length, src_char + i * pipeline_length,
pipeline_length);
}

HexagonUserDMA::Get().Wait(3);
user_dma->Wait(3);
for (uint32_t i = 0; i < pipeline_length; ++i) {
dst_char[i]++;
}

HexagonUserDMA::Get().Wait(2);
user_dma->Wait(2);
for (uint32_t i = pipeline_length; i < 2 * pipeline_length; ++i) {
dst_char[i]++;
}

HexagonUserDMA::Get().Wait(1);
user_dma->Wait(1);
for (uint32_t i = 2 * pipeline_length; i < 3 * pipeline_length; ++i) {
dst_char[i]++;
}

HexagonUserDMA::Get().Wait(0);
user_dma->Wait(0);
for (uint32_t i = 3 * pipeline_length; i < 4 * pipeline_length; ++i) {
dst_char[i]++;
}
Expand All @@ -165,8 +167,8 @@ TEST_F(HexagonUserDMATest, overflow_ring_buffer) {

for (uint32_t i = 0; i < number_of_dmas; ++i) {
do {
ret = HexagonUserDMA::Get().Copy(dst_char + i * length_of_each_dma,
src_char + i * length_of_each_dma, length_of_each_dma);
ret = user_dma->Copy(dst_char + i * length_of_each_dma, src_char + i * length_of_each_dma,
length_of_each_dma);
} while (ret == DMA_RETRY);
ASSERT_EQ(ret, DMA_SUCCESS);
}
Expand Down