Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
*/
constexpr const char* opengl_stage_scope = "opengl_stage_scope";

/*!
 * \brief Attribute key marking that the annotated statement is in the device scope.
 */
constexpr const char* device_scope = "device_scope";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
9 changes: 9 additions & 0 deletions include/tvm/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ Stmt RewriteUnsafeSelect(Stmt stmt);
*/
Stmt LowerStorageAccessInfo(Stmt stmt);

/*!
 * \brief Decorate the stmt with a device scope. This is helpful for
 *  hardware accelerators without thread blocks.
 *
 * \param stmt The stmt to be transformed.
 * \return Transformed stmt.
 */
Stmt DecorateDeviceScope(Stmt stmt);

/*!
* \brief Make an user callable API LoweredFunc.
*
Expand Down
1 change: 1 addition & 0 deletions src/api/api_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,5 +154,6 @@ REGISTER_PASS1(LowerTVMBuiltin);
REGISTER_PASS1(CombineContextCall);
REGISTER_PASS2(VerifyMemory);
REGISTER_PASS2(VerifyGPUCode);
REGISTER_PASS1(DecorateDeviceScope);
} // namespace ir
} // namespace tvm
21 changes: 21 additions & 0 deletions src/pass/detect_device.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*!
* Copyright (c) 2018 by Contributors
* \file detect_device.cc
*/

#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include "../pass/ir_util.h"

namespace tvm {
namespace ir {
/*!
 * \brief Wrap a statement in an AttrStmt keyed by attr::device_scope.
 *
 * \param stmt The statement to wrap; it becomes the body of the new AttrStmt.
 * \return An AttrStmt whose attr_key is attr::device_scope and whose body is stmt.
 */
Stmt DecorateDeviceScope(Stmt stmt) {
  // Both the attribute node and the attribute value are integer zeros here;
  // only the key and the wrapped body carry information.
  auto attr_node = make_zero(Int(32));
  return AttrStmt::make(attr_node, ir::attr::device_scope, 0, stmt);
}

} // namespace ir
} // namespace tvm
3 changes: 2 additions & 1 deletion src/pass/split_host_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@ class HostDeviceSplitter : public IRMutator {

Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent ||
op->attr_key == attr::pipeline_exec_scope) {
op->attr_key == attr::pipeline_exec_scope ||
op->attr_key == attr::device_scope) {
return SplitDeviceFunc(s);
}
return IRMutator::Mutate_(op, s);
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_pass_decorate_device_scope.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import tvm

def test_decorate_device():
    """Check that DecorateDeviceScope wraps a stmt in a "device_scope" AttrStmt."""
    rows = tvm.var('m')
    cols = tvm.var('l')
    A = tvm.placeholder((rows, cols), name='A')

    # Two-stage elementwise pipeline: A -> A1 -> A2.
    A1 = tvm.compute((rows, cols), lambda i, j: A[i, j], name='A1')
    A2 = tvm.compute((rows, cols), lambda i, j: A1[i, j] + 3, name='A2')

    sched = tvm.create_schedule(A2.op)
    outer, _inner = sched[A2].split(A2.op.axis[0], factor=8)
    sched[A1].compute_at(sched[A2], outer)
    sched[A1].set_scope("shared")

    # Build and simplify the stmt, then apply the pass under test.
    bounds = tvm.schedule.InferBound(sched)
    lowered = tvm.ir_pass.Simplify(tvm.schedule.ScheduleOps(sched, bounds))
    decorated = tvm.ir_pass.DecorateDeviceScope(lowered)

    # The result must be an AttrStmt with the device_scope key whose body
    # is the original statement.
    assert isinstance(decorated, tvm.stmt.AttrStmt)
    assert decorated.attr_key == "device_scope"
    assert lowered == decorated.body

# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_decorate_device()