Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,24 +211,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
*
Expand Down Expand Up @@ -268,17 +250,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_device The default device used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated device.
*/
TVM_DLL std::unordered_map<Expr, Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const Device& default_device);

} // namespace relay
} // namespace tvm

Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,12 @@ TVM_DLL Pass ToANormalForm();
/*!
* \brief ToANormalForm but on incomplete graph.
*
* \param maybe_mod optional module holding definitions for global vars in \p expr
* \param expr the graph.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);
TVM_DLL Expr ToANormalForm(const Optional<IRModule>& maybe_mod, const Expr& expr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The optional argument should probably be second to allow for the default behaviour to be just passing expr ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that I think about it again I don't need this argument -- if during conversion to ANF the pass asks for the device for a global var then it is only for the purposes of maybe wrapping the same global var with an "on_device" annotation, but that is always a no-op. I''ll remove it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably change this to be a real pass?


/*!
* \brief Turn an expression into continuation passing style(CPS).
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ struct VMFunction {
/*! \brief The size of the frame for this function */
Index register_file_size;
/*! \brief The device type of each parameter for this function. */
std::vector<Index> params_device_type;
std::vector<DLDeviceType> params_device_type;

VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions, Index register_file_size,
const std::vector<Index> params_device_type = {})
const std::vector<DLDeviceType> params_device_type = {})
: name(name),
params(params),
instructions(instructions),
Expand Down
49 changes: 0 additions & 49 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,6 @@
from .feature import Feature


def context_analysis(mod, default_device):
"""Analyze the device context information of each IR node in a Relay
program.

Parameters
----------
mod : tvm.IRModule
The input module.

default_device : tvm.runtime.Device
The default context allocated to an IR node.
"""
return _ffi_api.ContextAnalysis(mod, default_device)


def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
Expand Down Expand Up @@ -268,40 +253,6 @@ def all_dtypes(expr):
return set(_ffi_api.all_dtypes(expr))


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

Returns
-------
ret : Dict[tvm.relay.ir.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _ffi_api.CollectDeviceInfo(expr)


def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.

Parameters
----------
expr : tvm.relay.Expr
The input expression.

Returns
-------
ret : Dict[tvm.relay.Expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _ffi_api.CollectDeviceAnnotationOps(expr)


def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
register_broadcast_schedule("fast_erf")
# a fake on_device schedule.
# this will not be used in actual computation
# as on_device will be removed during DeviceAnnotation pass
register_injective_schedule("on_device")


Expand Down
21 changes: 0 additions & 21 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,27 +544,6 @@ def MergeCompilerRegions():
return _ffi_api.MergeCompilerRegions()


def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.

Parameters
----------
fallback_device : int
The fallback device type. It is also used as the default device for
operators with no annotated device.

Returns
-------
ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
return _ffi_api.RewriteDeviceAnnotation(fallback_device)


def ToANormalForm():
"""Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
Expand Down
48 changes: 24 additions & 24 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,7 @@ class Parser {
* Useful for matching optional tokens, effectively looksahead by one.
*/
bool WhenMatch(const TokenType& token_type) {
VLOG(1) << "Parser::WhenMatch: Peek() == " << Peek();
VLOG(9) << "Parser::WhenMatch: Peek() == " << Peek();
if (Peek()->token_type == token_type) {
Consume(token_type);
return true;
Expand Down Expand Up @@ -594,7 +594,7 @@ class Parser {
template <typename R>
R WithSpan(std::function<R()> parser) {
auto start_span = Peek()->span;
VLOG(0) << "WithSpan: start_span = " << start_span;
VLOG(9) << "WithSpan: start_span = " << start_span;
R ast = parser();
if (ast.defined()) {
// The token at the head of the stream is now 1 past where we parsed. So we find its start
Expand All @@ -608,7 +608,7 @@ class Parser {
span_pos--;
}
auto end_token = tokens.at(span_pos);
VLOG(0) << "WithSpan: end_span = " << end_token->span;
VLOG(9) << "WithSpan: end_span = " << end_token->span;
ast->span = start_span.Merge(end_token->span);
}
return ast;
Expand Down Expand Up @@ -668,7 +668,7 @@ class Parser {
template <typename T>
Array<T> ParseSequence(TokenType start, TokenType sep, TokenType stop, std::function<T()> parse,
std::function<bool()> before_stop = nullptr) {
VLOG(0) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
VLOG(9) << "Parser::ParseSequence: start=" << ToString(start) << " sep=" << ToString(sep)
<< " stop=" << ToString(stop);
Match(start);

Expand All @@ -686,7 +686,7 @@ class Parser {
if (WhenMatch(stop)) {
return Array<T>();
} else {
VLOG(0) << "Parser::ParseSequence: parse first";
VLOG(9) << "Parser::ParseSequence: parse first";
auto data = parse();
Array<T> elements = {data};

Expand All @@ -695,7 +695,7 @@ class Parser {
// parse '( expr ',' * ')'
} else if (WhenMatch(sep)) {
while (true) {
VLOG(0) << "Parser::ParseSequence: parse element";
VLOG(9) << "Parser::ParseSequence: parse element";
if (WhenMatch(stop)) {
break;
} else {
Expand Down Expand Up @@ -893,12 +893,12 @@ class Parser {

/*! \brief Parse a single Relay expression. */
Expr ParseExpr() {
VLOG(0) << "Parser::ParseExpr";
VLOG(9) << "Parser::ParseExpr";
return WithSpan<Expr>([this] {
std::vector<Expr> exprs;

while (true) {
VLOG(0) << "Parser::ParseExpr: parsing a single expression";
VLOG(9) << "Parser::ParseExpr: parsing a single expression";
auto next = Peek();
switch (next->token_type) {
// For graph or let, match first rhs, then invoke ParseBindingExpr
Expand Down Expand Up @@ -1011,7 +1011,7 @@ class Parser {
// This ensures for n sequential bindings
// the call depth will be the same before
// and after parsing the n bindings.
VLOG(0) << "Parser::ParseBindingExpr";
VLOG(9) << "Parser::ParseBindingExpr";
std::vector<std::tuple<Var, Expr, Span>> bindings;
int scopes = 0;

Expand Down Expand Up @@ -1085,7 +1085,7 @@ class Parser {
* Handles things of the form [T1, ..., TN](arg1: U1, ..., argN : UN) -> Ret { body }.
*/
Function ParseFunctionDef() {
VLOG(0) << "Parser::ParseFunctionDef";
VLOG(9) << "Parser::ParseFunctionDef";
return WithSpan<Function>([&]() {
PushScope();
PushTypeScope();
Expand Down Expand Up @@ -1147,7 +1147,7 @@ class Parser {
/*! \brief Parse an if-expression. */
Expr ParseIf() {
return WithSpan<Expr>([&]() {
VLOG(0) << "Parser::ParseIf";
VLOG(9) << "Parser::ParseIf";
Consume(TokenType::kIf);

auto guard = WithSpan<Expr>([&] { return Parens<Expr>([&] { return ParseExpr(); }); });
Expand Down Expand Up @@ -1186,7 +1186,7 @@ class Parser {
* This function recursively parses a pattern.
*/
Pattern ParsePattern() {
VLOG(0) << "Parser::ParsePattern";
VLOG(9) << "Parser::ParsePattern";
auto next = Peek();
switch (next->token_type) {
case TokenType::kUnderscore: {
Expand Down Expand Up @@ -1249,7 +1249,7 @@ class Parser {
}

Expr ParseExprBinOp() {
VLOG(0) << "Parser::ParseExprBinOp";
VLOG(9) << "Parser::ParseExprBinOp";
return WithSpan<Expr>([this] {
// We must parse at least one expression, the default
// case is that there is no operator and we will fall
Expand Down Expand Up @@ -1333,7 +1333,7 @@ class Parser {
}

ObjectRef ParseAttributeValue() {
VLOG(0) << "Parser::ParseAttributeValue";
VLOG(9) << "Parser::ParseAttributeValue";
auto next = Peek();
switch (next->token_type) {
case TokenType::kFloat:
Expand Down Expand Up @@ -1375,7 +1375,7 @@ class Parser {
}

Map<String, ObjectRef> ParseAttrs() {
VLOG(0) << "Parser::ParseAttrs";
VLOG(9) << "Parser::ParseAttrs";
Map<String, ObjectRef> kwargs;
while (Peek()->token_type == TokenType::kIdentifier) {
auto key = GetHierarchicalName(ParseHierarchicalName().data);
Expand All @@ -1387,14 +1387,14 @@ class Parser {
kwargs.Set(key, value);
WhenMatch(TokenType::kComma);
}
VLOG(0) << "Parser::ParseAttrs: kwargs=" << kwargs;
VLOG(9) << "Parser::ParseAttrs: kwargs=" << kwargs;
return kwargs;
}

Expr ParseCallArgs(Expr op) {
ICHECK(op.defined()) << "the operator must be defined";

VLOG(0) << "Parser::ParseCallArgs";
VLOG(9) << "Parser::ParseCallArgs";
Attrs attrs;
std::string op_key;
bool is_op = false;
Expand Down Expand Up @@ -1471,7 +1471,7 @@ class Parser {
}

Expr ParseCallExpr() {
VLOG(0) << "Parser::ParseCallExpr";
VLOG(9) << "Parser::ParseCallExpr";
return WithSpan<Expr>([this] {
Expr expr = ParseAtomicExpr();
// Parse as many call args as possible, building up expression
Expand Down Expand Up @@ -1500,7 +1500,7 @@ class Parser {
}

Expr GetOp(const std::string& op_name, const Span& span) {
VLOG(0) << "op_name=" << op_name << " span=" << span;
VLOG(9) << "op_name=" << op_name << " span=" << span;
try {
return Op::Get(op_name);
} catch (const Error& e) {
Expand All @@ -1513,7 +1513,7 @@ class Parser {
}

Expr ParseAtomicExpr() {
VLOG(0) << "Parser::ParseAtomicExpr";
VLOG(9) << "Parser::ParseAtomicExpr";
Expr expr = WithSpan<Expr>([this] {
auto next = Peek();
switch (next->token_type) {
Expand Down Expand Up @@ -1649,7 +1649,7 @@ class Parser {
auto token = Match(TokenType::kInteger);
auto index = token.ToNumber();
auto span = token->span.Merge(expr->span);
VLOG(0) << "Parser::ParseAtomicExpr: tuple get item";
VLOG(9) << "Parser::ParseAtomicExpr: tuple get item";
return relay::TupleGetItem(expr, index, span);
} else {
return expr;
Expand Down Expand Up @@ -1870,7 +1870,7 @@ class Parser {

Parser InitParser(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
VLOG(9) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
SourceName src_name = SourceName::Get(file_name);
Source source(src_name, file_content);

Expand Down Expand Up @@ -1909,7 +1909,7 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "ParseModule";
VLOG(9) << "ParseModule";
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
Expand All @@ -1923,7 +1923,7 @@ IRModule ParseModule(const std::string& file_name, const std::string& file_conte
}

Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
VLOG(0) << "ParseExpr";
VLOG(9) << "ParseExpr";
auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
parser.ParseSemVer(false);
parser.PushScope();
Expand Down
4 changes: 2 additions & 2 deletions src/parser/tokenizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ struct Tokenizer {
int line = this->line;
int col = this->col;
auto next = Peek();
VLOG(1) << "tvm::parser::TokenizeOnce: next=" << next;
VLOG(9) << "tvm::parser::TokenizeOnce: next=" << next;
if (next == '\n') {
auto token = NewToken(TokenType::kNewline);
Next();
Expand Down Expand Up @@ -550,7 +550,7 @@ struct Tokenizer {
}

void Tokenize() {
VLOG(0) << "tvm::parser::Tokenize";
VLOG(9) << "tvm::parser::Tokenize";
while (this->More()) {
auto token = TokenizeOnce();
ICHECK(token.defined());
Expand Down
Loading