Skip to content

Commit 9c5d8cd

Browse files
YuchenJinZihengJiangMasterJH5574sungggjunrushao
committed
[Unity] Relax VM (#13878)
This PR implements a flexible register-based VM to execute relax programs with dynamic shape and control flow. Design: https://github.com/tlc-pack/relax/wiki/Relax-VM-Design. Co-Authored-by: Ziheng Jiang <ziheng@apache.org> Co-Authored-by: Ruihang Lai <ruihangl@cs.cmu.edu> Co-Authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com> Co-Authored-by: Junru Shao <junrushao1994@gmail.com> Co-Authored-by: Prakalp Srivastava <prakalp@octoml.ai> Co-Authored-by: Yong Wu <yongcale@gmail.com> Co-Authored-by: Steven S. Lyubomirsky <slyubomirsky@octoml.ai> Co-Authored-by: Tianqi Chen <tianqi.tchen@gmail.com> Co-Authored-by: Hongyi Jin <3231950289@qq.com>
1 parent 1e988a4 commit 9c5d8cd

File tree

21 files changed

+4797
-0
lines changed

21 files changed

+4797
-0
lines changed

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS
289289
src/driver/*.cc
290290
src/support/*.cc
291291
src/script/*.cc
292+
src/relax/backend/vm/*.cc
292293
)
293294

294295
tvm_file_glob(GLOB CODEGEN_SRCS
@@ -335,6 +336,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS
335336
src/runtime/*.cc
336337
src/runtime/vm/*.cc
337338
src/runtime/minrpc/*.cc
339+
src/runtime/relax_vm/*.cc
338340
)
339341

340342
if(BUILD_FOR_HEXAGON)

include/tvm/relax/exec_builder.h

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/relax/exec_builder.h
22+
*/
23+
#ifndef TVM_RELAX_EXEC_BUILDER_H_
24+
#define TVM_RELAX_EXEC_BUILDER_H_
25+
26+
#include <tvm/ir/expr.h>
27+
#include <tvm/node/reflection.h>
28+
#include <tvm/node/repr_printer.h>
29+
#include <tvm/runtime/object.h>
30+
#include <tvm/runtime/registry.h>
31+
#include <tvm/runtime/relax_vm/bytecode.h>
32+
#include <tvm/runtime/relax_vm/executable.h>
33+
34+
#include <string>
35+
#include <unordered_map>
36+
#include <vector>
37+
38+
namespace tvm {
39+
namespace relax {
40+
41+
namespace vm = tvm::runtime::relax_vm;
42+
43+
class ExecBuilder;
44+
45+
/*!
46+
* \brief A builder provides api to build VM executable with instructions.
47+
*/
48+
class ExecBuilderNode : public Object {
49+
public:
50+
/*!
51+
* \brief Declare a function, it is OK to have multiple declarations.
52+
* \param func The function name.
53+
* \param kind The kind of the function.
54+
*/
55+
void DeclareFunction(const std::string& func, vm::VMFuncInfo::FuncKind kind);
56+
/*!
57+
* \brief To annotate the start of a vm function.
58+
* \param func The function name.
59+
* \param num_inputs The number of inputs.
60+
* \param param_names The function parameter names.
61+
* \param kind The kind of the function.
62+
* \param init_register_size Initial setting of register file size.
63+
*/
64+
void EmitFunction(const std::string& func, int64_t num_inputs,
65+
Optional<Array<String>> param_names,
66+
vm::VMFuncInfo::FuncKind kind = vm::VMFuncInfo::FuncKind::kVMFunc,
67+
int64_t init_register_size = 0);
68+
/*!
69+
* \brief Annotate the end of a vm function.
70+
* \param func The function name.
71+
*/
72+
void EndFunction(const std::string& func);
73+
/*!
74+
* \brief Emit a call instruction for a packed function.
75+
* \param func The packed function name.
76+
* \param args The arguments of the function.
77+
* \param ret The return register.
78+
*/
79+
void EmitCall(const std::string& func, std::vector<vm::Instruction::Arg> args, vm::RegName ret);
80+
/*!
81+
* \brief Emit a call instruction with func as argument.
82+
* \param func The packed function index.
83+
* \param args The arguments of the function.
84+
* \param ret The return register.
85+
*/
86+
void EmitCall(vm::Instruction::Arg func, std::vector<vm::Instruction::Arg> args, vm::RegName ret);
87+
/*!
88+
* \brief Emit a ret instruction.
89+
* \param result The return result.
90+
* \note result must be a register.
91+
*/
92+
void EmitRet(vm::Instruction::Arg result);
93+
/*!
94+
* \brief Emit a goto instruction.
95+
* \param pc_offset The program counter offset as the jump offset.
96+
*/
97+
void EmitGoto(vm::Index pc_offset);
98+
/*!
99+
* \brief Emit an If instruction.
100+
* \param cond The register containing the cond value.
101+
* \param false_offset The program counter offset for the false branch.
102+
* \note result must be a register.
103+
*/
104+
void EmitIf(vm::Instruction::Arg cond, vm::Index false_offset);
105+
/*!
106+
* \brief Get function index by its name.
107+
* \param name The name of the function.
108+
* \return The argument corresponding to the function index.
109+
*/
110+
vm::Instruction::Arg GetFunction(const std::string& name);
111+
/*!
112+
* \brief Convert a constant value something that exec builder can understand.
113+
*
114+
* This function may update the constant pool to include the obj value.
115+
*
116+
* \param value The input constant value
117+
* \return An Arg that represents the result of constant argument.
118+
*/
119+
template <typename T>
120+
vm::Instruction::Arg ConvertConstant(T value) {
121+
TVMRetValue rv;
122+
rv = value;
123+
return ConvertConstant_(rv);
124+
}
125+
/*!
126+
* \brief Raw access to underlying executable build in progress.
127+
*/
128+
vm::Executable* exec() const;
129+
/*!
130+
* \brief Finalize the build, run formalize and get the final result.
131+
* \note This function should not be called during construction.
132+
*/
133+
ObjectPtr<vm::Executable> Get();
134+
/*!
135+
* \brief Create an ExecBuilder.
136+
* \return The ExecBuilder.
137+
*/
138+
TVM_DLL static ExecBuilder Create();
139+
140+
void VisitAttrs(AttrVisitor* v) {}
141+
142+
static constexpr const uint32_t _type_index = TypeIndex::kDynamic;
143+
static constexpr const char* _type_key = "relax.ExecBuilder";
144+
TVM_DECLARE_FINAL_OBJECT_INFO(ExecBuilderNode, Object);
145+
146+
private:
147+
/*!
148+
* \brief Convert a constant value something that exec builder can understand.
149+
*
150+
* This function may update the constant pool to include the obj value.
151+
*
152+
* \param obj The constant value to be emitted
153+
* \return An Arg that represents the result of constant argument.
154+
*/
155+
vm::Instruction::Arg ConvertConstant_(TVMRetValue obj);
156+
157+
/*!
158+
* \brief A helper function to check if an executable is legal by checking if registers are used
159+
* properly
160+
*/
161+
void CheckExecutable();
162+
/*!
163+
* \brief Formalize the executable.
164+
*/
165+
void Formalize();
166+
167+
/*! \brief The mutable internal executable. */
168+
ObjectPtr<vm::Executable> exec_; // mutable
169+
/*! \brief internal dedup map when creating index for a new constant */
170+
std::unordered_map<ObjectRef, vm::Index, StructuralHash, StructuralEqual> const_dedup_map_;
171+
};
172+
173+
class ExecBuilder : public ObjectRef {
174+
public:
175+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecBuilder, ObjectRef, ExecBuilderNode);
176+
};
177+
178+
} // namespace relax
179+
} // namespace tvm
180+
181+
#endif // TVM_RELAX_EXEC_BUILDER_H_
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file tvm/runtime/relax_vm/builtin.h
22+
* \brief Builtin runtime APIs.
23+
*/
24+
#ifndef TVM_RUNTIME_RELAX_VM_BUILTIN_H_
25+
#define TVM_RUNTIME_RELAX_VM_BUILTIN_H_
26+
27+
namespace tvm {
28+
namespace runtime {
29+
namespace relax_vm {
30+
31+
/*!
32+
* \brief Op code used in built-in match-shape function.
33+
*
34+
* The function takes the following signature:
35+
36+
* MatchShape(input_shape, shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n], err_ctx)
37+
*
38+
* This function provides runtime shape population and checking support for match-cast.
39+
* When a shape variable appears in the first time, we should load the shape and
40+
* populate the variable. When a shape variable already appears, we should
41+
* assert that it already equals an existing shape value.
42+
*
43+
* NOTE: It is OK to pass nullptr shape_heap if all code are AssertEqualToImm.
44+
*/
45+
enum class MatchShapeCode : int {
46+
/*!
47+
* \brief Perform an assertion that shape equals immediate.
48+
*
49+
* assert input_shape[i] == r[i]
50+
*/
51+
kAssertEqualToImm = 0,
52+
/*!
53+
* \brief This is the first time we see a symbolic shape variable, store to heap.
54+
*
55+
* shape_heap[r[i]] = input_shape[i]
56+
*/
57+
kStoreToHeap = 1,
58+
/*!
59+
* \brief skip and do not do anything.
60+
*/
61+
kNoOp = 2,
62+
/*!
63+
* \brief Peform an assertion that the shape equals a loaded value.
64+
*
65+
* assert input_shape[i] == shape_heap[r[i]]
66+
*/
67+
kAssertEqualToLoad = 3,
68+
};
69+
70+
/*!
71+
* \brief Op code used in builtin function MakeShape.
72+
*
73+
* MakeShape(shape_heap, n, c[0], r[0], c[1], r[1], ... c[n], r[n]).
74+
*
75+
* \note It is OK to pass nullptr to shape_heap if all code are UseImm.
76+
*/
77+
enum class MakeShapeCode : int {
78+
/*! \brief Use the following r[i] as immediate shape value. */
79+
kUseImm = 0,
80+
/*!
81+
* \brief Load shape value from the shape_heap[[r[i]].
82+
*/
83+
kLoadShape = 1,
84+
};
85+
86+
} // namespace relax_vm
87+
} // namespace runtime
88+
} // namespace tvm
89+
#endif // TVM_RUNTIME_RELAX_VM_BUILTIN_H_

0 commit comments

Comments
 (0)