Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -713,13 +713,6 @@ TVM_DLL const Op& texture2d_store();
*/
TVM_DLL const Op& texture2d_load();

/*!
* \brief Copy 1d memory from source to destination
* Same semantics as memcpy(destination, source, size)
* Allows for device specific implementations e.g. direct memory access (DMA)
*/
TVM_DLL const Op& mem_copy();

/*!
* \brief Initiate a non-blocking DMA copy from source to destination
*/
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.dma_bypass_cache", Bool);

using tvm::Array;
using tvm::transform::Pass;
Expand Down
18 changes: 14 additions & 4 deletions src/runtime/hexagon/hexagon_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,12 @@

#include "hexagon_common.h"
#include "hexagon_device_api.h"
#include "qurt_memory.h"

namespace tvm {
namespace runtime {
namespace hexagon {

int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length);

struct Allocation {
Allocation(size_t allocation_nbytes, size_t alignment)
: allocation_nbytes_(allocation_nbytes), alignment_(alignment) {}
Expand Down Expand Up @@ -237,8 +236,19 @@ void hexagon_buffer_copy_across_regions(const BufferSet& dest, const BufferSet&

// Finally, do the memory copies.
for (const auto& copy : macro_copies) {
int error_code = hexagon_user_dma_1d_sync(copy.dest, copy.src, copy.num_bytes);
CHECK_EQ(error_code, 0);
// clean Hexagon cache before / after memcpy to ensure clean cache state to enable usage of DMA
// bypass mode for increased DMA bandwidth
// TODO(HWE): Switch to ION Buffer to avoid need for memcpy and potentially lighten or alleviate
// the burden of cache invalidation in this code
qurt_mem_cache_clean(reinterpret_cast<qurt_addr_t>(copy.dest), copy.num_bytes,
QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE);
qurt_mem_cache_clean(reinterpret_cast<qurt_addr_t>(copy.src), copy.num_bytes,
QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE);
memcpy(copy.dest, copy.src, copy.num_bytes);
qurt_mem_cache_clean(reinterpret_cast<qurt_addr_t>(copy.dest), copy.num_bytes,
QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE);
qurt_mem_cache_clean(reinterpret_cast<qurt_addr_t>(copy.src), copy.num_bytes,
QURT_MEM_CACHE_INVALIDATE, QURT_MEM_DCACHE);
}
}

Expand Down
29 changes: 13 additions & 16 deletions src/runtime/hexagon/hexagon_device_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ namespace tvm {
namespace runtime {
namespace hexagon {

int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length);

HexagonDeviceAPI* HexagonDeviceAPI::Global() {
static auto* inst = new HexagonDeviceAPI();
return inst;
Expand Down Expand Up @@ -206,39 +204,38 @@ void HexagonDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void
memcpy(static_cast<char*>(to) + to_offset, static_cast<const char*>(from) + from_offset, size);
}

// Packed function: synchronous DMA copy between the data pointers of two DLTensors.
//   args[0] - destination DLTensor*
//   args[1] - source DLTensor*
//   args[2] - number of bytes to copy; must be positive
//   args[3] - bypass_cache flag, forwarded unchanged to HexagonUserDMA::Copy
// Stores 0 in `rv`; any status other than DMA_SUCCESS aborts via CHECK.
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy_dltensor")
    .set_body([](TVMArgs args, TVMRetValue* rv) {
      DLTensor* dst = args[0];
      DLTensor* src = args[1];
      int size = args[2];
      ICHECK(size > 0);
      bool bypass_cache = args[3];

      // Resubmit the transfer for as long as the DMA engine answers DMA_RETRY.
      int ret = DMA_RETRY;
      do {
        ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(SYNC_DMA_QUEUE, dst->data, src->data,
                                                          size, bypass_cache);
      } while (ret == DMA_RETRY);
      CHECK(ret == DMA_SUCCESS);
      // Make the copy synchronous: wait until at most 0 DMAs are in flight on the queue.
      HexagonDeviceAPI::Global()->UserDMA()->Wait(SYNC_DMA_QUEUE, 0);

      *rv = static_cast<int32_t>(0);
    });

// Packed function wrapping a synchronous 1d DMA copy between two raw pointers.
//   args[0] - destination pointer
//   args[1] - source pointer
//   args[2] - number of bytes to copy
// Stores 0 in `rv`; a non-zero status from the copy aborts via CHECK_EQ.
TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
  void* const to = args[0];
  void* const from = args[1];
  const int nbytes = args[2];

  // The name `error_code` is kept because CHECK_EQ prints the expression text on failure.
  const int error_code = hexagon_user_dma_1d_sync(to, from, nbytes);
  CHECK_EQ(error_code, 0);

  *rv = static_cast<int32_t>(0);
});

// Packed function: submit one DMA copy on the given queue.
//   args[0] - DMA queue id
//   args[1] - destination pointer
//   args[2] - source pointer
//   args[3] - number of bytes to copy; must be positive
//   args[4] - bypass_cache flag, forwarded unchanged to HexagonUserDMA::Copy
// Stores the final DMA status in `rv`; any status other than DMA_SUCCESS aborts via CHECK.
// This function does not wait for the transfer to complete.
TVM_REGISTER_GLOBAL("device_api.hexagon.dma_copy").set_body([](TVMArgs args, TVMRetValue* rv) {
  int queue_id = args[0];
  void* dst = args[1];
  void* src = args[2];
  int size = args[3];
  ICHECK(size > 0);
  // The flag is the fifth argument; reading args[3] here would reinterpret the
  // byte count as the flag.
  bool bypass_cache = args[4];

  // Resubmit the transfer for as long as the DMA engine answers DMA_RETRY.
  int ret = DMA_RETRY;
  do {
    ret = HexagonDeviceAPI::Global()->UserDMA()->Copy(queue_id, dst, src, size, bypass_cache);
  } while (ret == DMA_RETRY);
  CHECK(ret == DMA_SUCCESS);
  *rv = static_cast<int32_t>(ret);
});

Expand Down
61 changes: 19 additions & 42 deletions src/runtime/hexagon/hexagon_user_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ unsigned int HexagonUserDMA::Init() {
return status;
}

int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length) {
int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache) {
// length limited to 24 bits
if (length > DESC_LENGTH_MASK) {
return DMA_FAILURE;
Expand Down Expand Up @@ -66,8 +66,24 @@ int HexagonUserDMA::Copy(int queue_id, void* dst, void* src, uint32_t length) {
dma_desc_set_desctype(dma_desc, DESC_DESCTYPE_1D);
dma_desc_set_dstcomp(dma_desc, DESC_COMP_NONE);
dma_desc_set_srccomp(dma_desc, DESC_COMP_NONE);
dma_desc_set_bypassdst(dma_desc, DESC_BYPASS_OFF);
dma_desc_set_bypasssrc(dma_desc, DESC_BYPASS_OFF);

bool dst_is_ddr = !HexagonDeviceAPI::Global()->VtcmPool()->IsVtcm(dst, length);
bool src_is_ddr = !HexagonDeviceAPI::Global()->VtcmPool()->IsVtcm(src, length);

// VTCM -> DDR with bypass enabled
if (dst_is_ddr && !src_is_ddr && bypass_cache) {
dma_desc_set_bypassdst(dma_desc, DESC_BYPASS_ON);
} else {
dma_desc_set_bypassdst(dma_desc, DESC_BYPASS_OFF);
}

// DDR -> VTCM with bypass enabled
if (src_is_ddr && !dst_is_ddr && bypass_cache) {
dma_desc_set_bypasssrc(dma_desc, DESC_BYPASS_ON);
} else {
dma_desc_set_bypasssrc(dma_desc, DESC_BYPASS_OFF);
}

dma_desc_set_order(dma_desc, DESC_ORDER_ORDER);
dma_desc_set_done(dma_desc, DESC_DONE_INCOMPLETE);
dma_desc_set_src(dma_desc, src32);
Expand Down Expand Up @@ -117,45 +133,6 @@ HexagonUserDMA::~HexagonUserDMA() {
delete descriptors_;
}

int hexagon_user_dma_1d_sync(void* dst, void* src, uint32_t length) {
HexagonUserDMA* user_dma = HexagonDeviceAPI::Global()->UserDMA();

// One DMA transfer can copy at most DESC_LENGTH_MASK bytes.
// Make the common case quick.
if (length <= DESC_LENGTH_MASK) {
// sync DMA -> `Copy` and then `Wait(0)`
int ret_val = user_dma->Copy(SYNC_DMA_QUEUE, dst, src, length);
if (ret_val != DMA_SUCCESS) return ret_val;
user_dma->Wait(SYNC_DMA_QUEUE, 0);
return DMA_SUCCESS;
}

// Split big transfers into smaller transfers.
char* cast_src = static_cast<char*>(src);
char* cast_dst = static_cast<char*>(dst);
for (uint32_t i = 0; i < length;) {
// Ensure there is no overflow while updating i
uint32_t cur_len = std::min<uint32_t>(length - i, DESC_LENGTH_MASK);
// sync DMA -> `Copy` and then `Wait(0)`
int ret_val = user_dma->Copy(SYNC_DMA_QUEUE, &cast_dst[i], &cast_src[i], cur_len);
if (ret_val != DMA_SUCCESS) return ret_val;
user_dma->Wait(SYNC_DMA_QUEUE, 0);
// 2 cases for new val for i:
// 1. length - i <= DESC_LENGTH_MASK (<= MAX_UINT)
// new_i = i + (length - i) = length, no more iter
// and no overflow (since (length - i) <= (MAX_UINT - i))
// 2. length - i > DESC_LENGTH_MASK
// length > (i + DESC_LENGTH_MASK)
// new_i = (i + DESC_LENGTH_MASK)
// length > new_i for next iter, we're done
// length - i > DESC_LENGTH_MASK
// and length <= MAX_UINT,
// so MAX_UINT >= length > DESC_LEN_MASK + i
// MAX_UINT > (DESC_LEN_MASK + i), so no overflow
i += cur_len;
}
return DMA_SUCCESS;
}
} // namespace hexagon
} // namespace runtime
} // namespace tvm
2 changes: 1 addition & 1 deletion src/runtime/hexagon/hexagon_user_dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class HexagonUserDMA {
* \param length Length in bytes to copy
* \returns Status: DMA_SUCCESS or DMA_FAILURE
*/
int Copy(int queue_id, void* dst, void* src, uint32_t length);
int Copy(int queue_id, void* dst, void* src, uint32_t length, bool bypass_cache);

/*!
* \brief Wait until the number of DMAs in flight is less than or equal to some maximum
Expand Down
12 changes: 12 additions & 0 deletions src/runtime/hexagon/hexagon_vtcm_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ class HexagonVtcmPool {
//! \brief Returns the total number of bytes in this pool
size_t TotalBytes() { return reinterpret_cast<size_t>(vtcm_size_); }

/*!
 * \brief Check whether a byte range lies entirely inside this VTCM pool
 * \param ptr Start address of the range; must not be null
 * \param size Length of the range in bytes
 * \returns true if [ptr, ptr + size) is fully contained in
 * [vtcm_data_, vtcm_data_ + vtcm_size_], false otherwise
 */
bool IsVtcm(void* ptr, unsigned size) {
  auto char_ptr = static_cast<char*>(ptr);
  CHECK(char_ptr != nullptr);
  auto char_vtcm = static_cast<char*>(vtcm_data_);
  CHECK(vtcm_data_ != nullptr);

  if (char_ptr < char_vtcm) {
    return false;
  }
  // Compare sizes instead of end addresses: `char_ptr + size` can wrap around
  // the address space for a pointer near its top, which would make a range
  // outside the pool appear to end inside it.
  size_t offset = static_cast<size_t>(char_ptr - char_vtcm);
  return offset <= vtcm_size_ && size <= vtcm_size_ - offset;
}

private:
//! \brief Total size of VTCM pool
unsigned int vtcm_size_;
Expand Down
3 changes: 0 additions & 3 deletions src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,6 @@ TIR_DEFINE_BUILTIN_FUNC(texture2d_load)
.set_attr<TVectorizable>("TVectorizable", true)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(mem_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

TIR_DEFINE_BUILTIN_FUNC(dma_copy).set_attr<TCallEffectKind>("TCallEffectKind",
Integer(CallEffectKind::kOpaque));

Expand Down
14 changes: 8 additions & 6 deletions src/tir/transforms/lower_async_dma.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tir {

class AsyncDMALowerer : public StmtExprMutator {
public:
AsyncDMALowerer() {}
explicit AsyncDMALowerer(bool dma_bypass_cache) : dma_bypass_cache_(dma_bypass_cache) {}

Stmt VisitStmt_(const AttrStmtNode* op) final {
// Convert this, for example:
Expand All @@ -52,7 +52,7 @@ class AsyncDMALowerer : public StmtExprMutator {
int queue_id = queue_id_node->value;

// abort if we have not seen this queue ID in `copy` transform
if (queue_ids.find(queue_id) == queue_ids.end()) {
if (queue_ids_.find(queue_id) == queue_ids_.end()) {
DLOG(INFO) << "AsyncDMALowerer exiting because the queue ID observed in the "
"`async_wait_queue_scope` transform has not been previously observed in the "
"`async_commit_queue_scope` transform";
Expand Down Expand Up @@ -160,29 +160,31 @@ class AsyncDMALowerer : public StmtExprMutator {

// now that we are about to perform the `copy` transform
// save queue ID for inspection in `wait` transform
queue_ids.insert(queue_id);
queue_ids_.insert(queue_id);

return Evaluate(Call(DataType::Int(32), builtin::dma_copy(),
{queue_id,
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferstorenode->buffer, store_index)}),
Call(DataType::Handle(), builtin::address_of(),
{BufferLoad(bufferloadnode->buffer, load_index)}),
for_loop->extent * bufferloadnode->dtype.bytes()}));
for_loop->extent * bufferloadnode->dtype.bytes(), dma_bypass_cache_}));
}
return StmtExprMutator::VisitStmt_(op);
}

private:
std::set<int> queue_ids;
std::set<int> queue_ids_;
bool dma_bypass_cache_;
};

namespace transform {

Pass LowerAsyncDMA() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto fptr = f.CopyOnWrite();
fptr->body = AsyncDMALowerer()(std::move(fptr->body));
bool dma_bypass_cache = ctx->GetConfig<Bool>("tir.dma_bypass_cache", Bool(false)).value();
fptr->body = AsyncDMALowerer(dma_bypass_cache)(std::move(fptr->body));
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerAsyncDMA", {});
Expand Down
15 changes: 0 additions & 15 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,6 @@ class BuiltinLower : public StmtExprMutator {
return MakeArray(op);
} else if (op->op.same_as(builtin::tvm_context_id())) {
return make_zero(op->dtype);
} else if (op->op.same_as(builtin::mem_copy())) {
return MakeMemCopy(op);
} else if (op->op.same_as(builtin::dma_copy())) {
return MakeDMACopy(op);
} else if (op->op.same_as(builtin::dma_wait())) {
Expand All @@ -326,19 +324,6 @@ class BuiltinLower : public StmtExprMutator {
}
}

// Lower a `mem_copy` builtin call into a packed call to "<prefix>.mem_copy",
// where <prefix> is "device_api." followed by the device name of `device_type_`.
// The three call operands are forwarded unchanged, and the resulting packed
// call is itself visited before being returned.
PrimExpr MakeMemCopy(const CallNode* op) {
  // Operands in call order: destination, source, size.
  PrimExpr dst_arg = op->args[0];
  PrimExpr src_arg = op->args[1];
  PrimExpr size_arg = op->args[2];

  std::string packed_name = "device_api.";
  packed_name += std::string(runtime::DeviceName(device_type_.as<IntImmNode>()->value));
  packed_name += ".mem_copy";

  return VisitExpr(Call(DataType::Int(32), builtin::tvm_call_packed(),
                        {StringImm(packed_name), dst_arg, src_arg, size_arg}));
}

PrimExpr MakeDMACopy(const CallNode* op) {
PrimExpr queue_id = op->args[0];
PrimExpr dst = op->args[1];
Expand Down
Loading