-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Hexagon] Generalize builtin for Nd memory alloc with storage scope and add lowering for VTCM / Hexagon #10558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6bbe6cc
3cb0121
5365fab
8044dc0
82f61be
c1843ed
4cef769
f13cd4c
e9ef946
daac188
5ca8970
8cea1e1
05423ea
a96c062
0ecd017
2c1ee84
351b0af
0f37782
6678e14
1088c66
7de3ae0
794dbbf
c21b254
7b06e7c
bc372da
4ff6471
1c23651
7e43cd8
5132fd6
b28ff9c
3d28c59
47268c5
40b0dd5
10ee3a5
1645238
caed9f1
0f59317
520f517
b7d8dd0
c3a3b30
53ce1ee
412857c
38cd975
484cc7b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,7 +62,7 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh | |
| CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; | ||
|
|
||
| // Forcing contiguous allocation, for now | ||
| // TODO(Straw): Enable discontiguous allocation after RFC 39 lands | ||
| // TODO(Straw): Enable discontiguous allocation | ||
| size_t nallocs = 1; | ||
| size_t nbytes = 1; | ||
| for (int i = 0; i < ndim; ++i) { | ||
|
|
@@ -107,7 +107,7 @@ void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType typ | |
| dmlc::ThreadLocalStore<HexagonWorkspacePool>::Get()->AllocWorkspace(dev, size)); | ||
|
|
||
| // Assumes a single contiguous allocation | ||
| // TODO(Straw): Enable discontiguous allocation after RFC 39 lands | ||
| // TODO(Straw): Enable discontiguous allocation | ||
| void* ptr = hexbuf->GetPointer()[0]; | ||
| workspace_allocations_.insert({ptr, hexbuf}); | ||
| return ptr; | ||
|
|
@@ -122,6 +122,20 @@ void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) { | |
| workspace_allocations_.erase(it); | ||
| } | ||
|
|
||
| void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, | ||
| DLDataType dtype, Optional<String> mem_scope) { | ||
| CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; | ||
| // Forcing contiguous allocation, for now | ||
| // TODO(Straw): Enable discontiguous allocation | ||
| CHECK_EQ(ndim, 1); | ||
| return AllocDataSpace(dev, ndim, shape, dtype, mem_scope); | ||
| } | ||
|
|
||
| void HexagonDeviceAPIv2::FreeVtcmWorkspace(Device dev, void* ptr) { | ||
| CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; | ||
| FreeDataSpace(dev, ptr); | ||
| } | ||
|
|
||
| void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { | ||
| CHECK_EQ(from->byte_offset, 0); | ||
| CHECK_EQ(to->byte_offset, 0); | ||
|
|
@@ -166,6 +180,60 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM | |
| *rv = static_cast<int32_t>(0); | ||
| }); | ||
|
|
||
| std::map<void*, HexagonBuffer*> vtcmallocs; | ||
|
|
||
| TVM_REGISTER_GLOBAL("device_api.hexagon.AllocNd").set_body([](TVMArgs args, TVMRetValue* rv) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead of using a lambda for the body, can we move this to a separate function? It tends to make debugging easier later on, and becomes a template for a generalized function we can propose adding to
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This PR creates a separate function called Note that the lambda must construct |
||
| int32_t device_type = args[0]; | ||
| int32_t device_id = args[1]; | ||
| int32_t dtype_code_hint = args[2]; | ||
| int32_t dtype_bits_hint = args[3]; | ||
| std::string scope = args[4]; | ||
| CHECK(scope.find("global.vtcm") != std::string::npos); | ||
| int64_t ndim = args[5]; | ||
| // Forcing contiguous allocation, for now | ||
| // TODO(Straw): Enable discontiguous allocation | ||
| CHECK_EQ(ndim, 1); | ||
| int64_t* shape = static_cast<int64_t*>(static_cast<void*>(args[6])); | ||
|
|
||
| Device dev; | ||
| dev.device_type = static_cast<DLDeviceType>(device_type); | ||
| dev.device_id = device_id; | ||
|
|
||
| DLDataType type_hint; | ||
| type_hint.code = static_cast<decltype(type_hint.code)>(dtype_code_hint); | ||
| type_hint.bits = static_cast<decltype(type_hint.bits)>(dtype_bits_hint); | ||
| type_hint.lanes = 1; | ||
|
|
||
| HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); | ||
| HexagonBuffer* hexbuf = reinterpret_cast<HexagonBuffer*>( | ||
| hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope))); | ||
|
|
||
| // Assumes a single contiguous allocation | ||
| // TODO(Straw): Enable discontiguous allocation | ||
| void* ptr = hexbuf->GetPointer()[0]; | ||
| vtcmallocs[ptr] = hexbuf; | ||
| *rv = ptr; | ||
| }); | ||
|
|
||
| TVM_REGISTER_GLOBAL("device_api.hexagon.FreeNd").set_body([](TVMArgs args, TVMRetValue* rv) { | ||
adstraw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| int32_t device_type = args[0]; | ||
| int32_t device_id = args[1]; | ||
| std::string scope = args[2]; | ||
| CHECK(scope.find("vtcm") != std::string::npos); | ||
| void* ptr = args[3]; | ||
| CHECK(vtcmallocs.find(ptr) != vtcmallocs.end()); | ||
|
|
||
| HexagonBuffer* hexbuf = vtcmallocs[ptr]; | ||
|
|
||
| Device dev; | ||
| dev.device_type = static_cast<DLDeviceType>(device_type); | ||
| dev.device_id = device_id; | ||
|
|
||
| HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); | ||
| hexapi->FreeVtcmWorkspace(dev, hexbuf); | ||
| *rv = static_cast<int32_t>(0); | ||
| }); | ||
|
|
||
| TVM_REGISTER_GLOBAL("device_api.hexagon.v2").set_body([](TVMArgs args, TVMRetValue* rv) { | ||
| DeviceAPI* ptr = HexagonDeviceAPIv2::Global(); | ||
| *rv = static_cast<void*>(ptr); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| #include <tvm/tir/builtin.h> | ||
| #include <tvm/tir/stmt.h> | ||
| #include <tvm/tir/transform.h> | ||
|
|
||
| #include "../../arith/ir_visitor_with_analyzer.h" | ||
|
|
||
| namespace tvm { | ||
| namespace tir { | ||
|
|
||
| inline bool IsVtcmStorage(std::string scope) { | ||
| return scope.find("global.vtcm") != std::string::npos; | ||
| } | ||
|
|
||
| class VtcmAllocator : public StmtExprMutator { | ||
| public: | ||
| using StmtExprMutator::VisitStmt_; | ||
| VtcmAllocator() {} | ||
|
|
||
| Stmt VisitStmt_(const AllocateNode* op) final { | ||
| std::string storage_scope = GetStorageScope(op->buffer_var); | ||
| if (IsVtcmStorage(storage_scope)) { | ||
| Stmt body = this->VisitStmt(op->body); | ||
adstraw marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Array<PrimExpr> args; | ||
| args.push_back(StringImm(storage_scope)); | ||
| args.push_back(IntImm(DataType::Int(64), op->extents.size())); | ||
| args.push_back(Call(DataType::Handle(), builtin::tvm_stack_make_shape(), op->extents)); | ||
| return LetStmt(op->buffer_var, | ||
| Call(op->buffer_var.dtype(), builtin::nd_mem_alloc_with_scope(), args), body); | ||
| } | ||
| return StmtExprMutator::VisitStmt_(op); | ||
| } | ||
|
|
||
| protected: | ||
| std::string GetStorageScope(const Var& var) { | ||
| auto* ptr = var->type_annotation.as<PointerTypeNode>(); | ||
| ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; | ||
| return ptr->storage_scope; | ||
| } | ||
| }; | ||
|
|
||
| PrimFunc LowerVtcmAlloc(PrimFunc func) { | ||
| auto fptr = func.CopyOnWrite(); | ||
| fptr->body = VtcmAllocator()(std::move(fptr->body)); | ||
| return func; | ||
| } | ||
|
|
||
| namespace transform { | ||
|
|
||
| Pass LowerVtcmAlloc() { | ||
| auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { | ||
| return LowerVtcmAlloc(std::move(f)); | ||
| }; | ||
| return CreatePrimFuncPass(pass_func, 0, "tir.LowerVtcmAlloc", {}); | ||
| } | ||
|
|
||
| TVM_REGISTER_GLOBAL("tir.transform.LowerVtcmAlloc").set_body_typed(LowerVtcmAlloc); | ||
|
|
||
| } // namespace transform | ||
|
|
||
| } // namespace tir | ||
| } // namespace tvm | ||
Uh oh!
There was an error while loading. Please reload this page.