Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/script/printer/doc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1067,7 +1067,7 @@ class FunctionDocNode : public StmtDocNode {
/*! \brief Decorators of function. */
Array<ExprDoc> decorators;
/*! \brief The return type of function. */
ExprDoc return_type{nullptr};
Optional<ExprDoc> return_type{NullOpt};
/*! \brief The body of function. */
Array<StmtDoc> body;

Expand Down Expand Up @@ -1100,7 +1100,7 @@ class FunctionDoc : public StmtDoc {
* \param body The body of function.
*/
explicit FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body);
Optional<ExprDoc> return_type, Array<StmtDoc> body);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FunctionDoc, StmtDoc, FunctionDocNode);
};

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/printer/doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,15 @@ class FunctionDoc(StmtDoc):
name: IdDoc
args: Sequence[AssignDoc]
decorators: Sequence[ExprDoc]
return_type: ExprDoc
return_type: Optional[ExprDoc]
body: Sequence[StmtDoc]

def __init__(
self,
name: IdDoc,
args: List[AssignDoc],
decorators: List[ExprDoc],
return_type: ExprDoc,
return_type: Optional[ExprDoc],
body: List[StmtDoc],
):
self.__init_handle_by_constructor__(
Expand Down
4 changes: 2 additions & 2 deletions src/script/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ ReturnDoc::ReturnDoc(ExprDoc value) {
}

FunctionDoc::FunctionDoc(IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body) {
Optional<ExprDoc> return_type, Array<StmtDoc> body) {
ObjectPtr<FunctionDocNode> n = make_object<FunctionDocNode>();
n->name = name;
n->args = args;
Expand Down Expand Up @@ -345,7 +345,7 @@ TVM_REGISTER_GLOBAL("script.printer.ReturnDoc").set_body_typed([](ExprDoc value)
TVM_REGISTER_NODE_TYPE(FunctionDocNode);
TVM_REGISTER_GLOBAL("script.printer.FunctionDoc")
.set_body_typed([](IdDoc name, Array<AssignDoc> args, Array<ExprDoc> decorators,
ExprDoc return_type, Array<StmtDoc> body) {
Optional<ExprDoc> return_type, Array<StmtDoc> body) {
return FunctionDoc(name, args, decorators, return_type, body);
});

Expand Down
176 changes: 164 additions & 12 deletions src/script/printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,111 @@ namespace tvm {
namespace script {
namespace printer {

/*!
* \brief Operator precedence
*
* This is based on
* https://docs.python.org/3/reference/expressions.html#operator-precedence
*/
enum class ExprPrecedence : int32_t {
/*! \brief Unknown precedence */
kUnkown = 0,
/*! \brief Lambda Expression */
kLambda = 1,
/*! \brief Conditional Expression */
kIfThenElse = 2,
/*! \brief Boolean OR */
kBooleanOr = 3,
/*! \brief Boolean AND */
kBooleanAnd = 4,
/*! \brief Boolean NOT */
kBooleanNot = 5,
/*! \brief Comparisons */
kComparison = 6,
/*! \brief Bitwise OR */
kBitwiseOr = 7,
/*! \brief Bitwise XOR */
kBitwiseXor = 8,
/*! \brief Bitwise AND */
kBitwiseAnd = 9,
/*! \brief Shift Operators */
kShift = 10,
/*! \brief Addition and subtraction */
kAdd = 11,
/*! \brief Multiplication, division, floor division, remainder */
kMult = 12,
/*! \brief Positive negative and bitwise NOT */
kUnary = 13,
/*! \brief Exponentiation */
kExp = 14,
/*! \brief Index access, attribute access, call and atom expression */
kIdentity = 15,
};

ExprPrecedence GetExprPrecedence(const ExprDoc& doc) {
// Key is the value of OperationDocNode::Kind
static const std::vector<ExprPrecedence> op_kind_precedence = []() {
using OpKind = OperationDocNode::Kind;
std::map<OpKind, ExprPrecedence> raw_table = {
{OpKind::kUSub, ExprPrecedence::kUnary},
{OpKind::kInvert, ExprPrecedence::kUnary},
{OpKind::kAdd, ExprPrecedence::kAdd},
{OpKind::kSub, ExprPrecedence::kAdd},
{OpKind::kMult, ExprPrecedence::kMult},
{OpKind::kDiv, ExprPrecedence::kMult},
{OpKind::kFloorDiv, ExprPrecedence::kMult},
{OpKind::kMod, ExprPrecedence::kMult},
{OpKind::kPow, ExprPrecedence::kExp},
{OpKind::kLShift, ExprPrecedence::kShift},
{OpKind::kRShift, ExprPrecedence::kShift},
{OpKind::kBitAnd, ExprPrecedence::kBitwiseAnd},
{OpKind::kBitOr, ExprPrecedence::kBitwiseOr},
{OpKind::kBitXor, ExprPrecedence::kBitwiseXor},
{OpKind::kLt, ExprPrecedence::kComparison},
{OpKind::kLtE, ExprPrecedence::kComparison},
{OpKind::kEq, ExprPrecedence::kComparison},
{OpKind::kNotEq, ExprPrecedence::kComparison},
{OpKind::kGt, ExprPrecedence::kComparison},
{OpKind::kGtE, ExprPrecedence::kComparison},
{OpKind::kIfThenElse, ExprPrecedence::kIfThenElse},
};
int n = static_cast<int>(OpKind::kSpecialEnd);
std::vector<ExprPrecedence> table(n + 1, ExprPrecedence::kUnkown);
for (const auto& kv : raw_table) {
table[static_cast<int>(kv.first)] = kv.second;
}
return table;
}();

// Key is the type index of Doc
static const std::unordered_map<uint32_t, ExprPrecedence> doc_type_precedence = {
{LiteralDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{IdDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{AttrAccessDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{IndexDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{CallDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{LambdaDocNode::RuntimeTypeIndex(), ExprPrecedence::kLambda},
{TupleDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{ListDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
{DictDocNode::RuntimeTypeIndex(), ExprPrecedence::kIdentity},
};

if (const auto* op_doc = doc.as<OperationDocNode>()) {
size_t kind = static_cast<int>(op_doc->kind);
ICHECK_LT(kind, op_kind_precedence.size()) << "ValueError: Invalid operation: " << kind;
ExprPrecedence precedence = op_kind_precedence[kind];
ICHECK(precedence != ExprPrecedence::kUnkown)
<< "Precedence for operator " << static_cast<int>(op_doc->kind) << " is unknown";
return precedence;
}
auto it = doc_type_precedence.find(doc->type_index());
if (it != doc_type_precedence.end()) {
return it->second;
}
ICHECK(false) << "Precedence for doc type " << doc->GetTypeKey() << " is unknown";
throw;
}

class PythonDocPrinter : public DocPrinter {
public:
explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {}
Expand Down Expand Up @@ -98,6 +203,42 @@ class PythonDocPrinter : public DocPrinter {
}
}

/*!
* \brief Print expression and add parenthesis if needed.
*/
void PrintChildExpr(const ExprDoc& doc, ExprPrecedence parent_precedence,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence doc_precedence = GetExprPrecedence(doc);
if (doc_precedence < parent_precedence ||
(parenthesis_for_same_precedence && doc_precedence == parent_precedence)) {
output_ << "(";
PrintDoc(doc);
output_ << ")";
} else {
PrintDoc(doc);
}
}

/*!
* \brief Print expression and add parenthesis if doc has lower precedence than parent.
*/
void PrintChildExpr(const ExprDoc& doc, const ExprDoc& parent,
bool parenthesis_for_same_precedence = false) {
ExprPrecedence parent_precedence = GetExprPrecedence(parent);
return PrintChildExpr(doc, parent_precedence, parenthesis_for_same_precedence);
}

/*!
* \brief Print expression and add parenthesis if doc doesn't have higher precedence than parent.
*
* This function should be used to print an child expression that needs to be wrapped
* by parenthesis even if it has the same precedence as its parent, e.g., the `b` in `a + b`
* and the `b` and `c` in `a if b else c`.
*/
void PrintChildExprConservatively(const ExprDoc& doc, const ExprDoc& parent) {
PrintChildExpr(doc, parent, /*parenthesis_for_same_precedence=*/true);
}

void MaybePrintCommentInline(const StmtDoc& stmt) {
if (stmt->comment.defined()) {
const std::string& comment = stmt->comment.value();
Expand Down Expand Up @@ -161,12 +302,12 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; }

void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
output_ << "." << doc->name;
}

void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) {
PrintDoc(doc->value);
PrintChildExpr(doc->value, doc);
if (doc->indices.size() == 0) {
output_ << "[()]";
} else {
Expand Down Expand Up @@ -226,29 +367,38 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) {
// Unary Operators
ICHECK_EQ(doc->operands.size(), 1);
output_ << OperatorToString(doc->kind);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
} else if (doc->kind == OpKind::kPow) {
// Power operator is different than other binary operators
// It's right-associative and binds less tightly than unary operator on its right.
// https://docs.python.org/3/reference/expressions.html#the-power-operator
// https://docs.python.org/3/reference/expressions.html#operator-precedence
ICHECK_EQ(doc->operands.size(), 2);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " ** ";
PrintChildExpr(doc->operands[1], ExprPrecedence::kUnary);
} else if (doc->kind < OpKind::kBinaryEnd) {
// Binary Operator
ICHECK_EQ(doc->operands.size(), 2);
PrintDoc(doc->operands[0]);
PrintChildExpr(doc->operands[0], doc);
output_ << " " << OperatorToString(doc->kind) << " ";
PrintDoc(doc->operands[1]);
PrintChildExprConservatively(doc->operands[1], doc);
} else if (doc->kind == OpKind::kIfThenElse) {
ICHECK_EQ(doc->operands.size(), 3)
<< "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size();
PrintDoc(doc->operands[1]);
PrintChildExpr(doc->operands[1], doc);
output_ << " if ";
PrintDoc(doc->operands[0]);
PrintChildExprConservatively(doc->operands[0], doc);
output_ << " else ";
PrintDoc(doc->operands[2]);
PrintChildExprConservatively(doc->operands[2], doc);
} else {
LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast<int>(doc->kind);
throw;
}
}

void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) {
PrintDoc(doc->callee);
PrintChildExpr(doc->callee, doc);

output_ << "(";

Expand Down Expand Up @@ -285,7 +435,7 @@ void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) {
output_ << "lambda ";
PrintJoinedDocs(doc->args, ", ");
output_ << ": ";
PrintDoc(doc->body);
PrintChildExpr(doc->body, doc);
}

void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) {
Expand Down Expand Up @@ -444,8 +594,10 @@ void PythonDocPrinter::PrintTypedDoc(const FunctionDoc& doc) {
PrintJoinedDocs(doc->args, ", ");
output_ << ")";

output_ << " -> ";
PrintDoc(doc->return_type);
if (doc->return_type.defined()) {
output_ << " -> ";
PrintDoc(doc->return_type.value());
}

output_ << ":";

Expand Down
47 changes: 29 additions & 18 deletions tests/python/unittest/test_tvmscript_printer_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,31 @@

import pytest

import tvm
from tvm.script.printer.doc import (
LiteralDoc,
IdDoc,
AssertDoc,
AssignDoc,
AttrAccessDoc,
IndexDoc,
CallDoc,
OperationKind,
OperationDoc,
ClassDoc,
DictDoc,
ExprStmtDoc,
ForDoc,
FunctionDoc,
IdDoc,
IfDoc,
IndexDoc,
LambdaDoc,
TupleDoc,
ListDoc,
DictDoc,
LiteralDoc,
OperationDoc,
OperationKind,
ReturnDoc,
ScopeDoc,
SliceDoc,
StmtBlockDoc,
AssignDoc,
IfDoc,
TupleDoc,
WhileDoc,
ForDoc,
ScopeDoc,
ExprStmtDoc,
AssertDoc,
ReturnDoc,
FunctionDoc,
ClassDoc,
)


Expand Down Expand Up @@ -450,6 +451,13 @@ def test_return_doc():
[IdDoc("test"), IdDoc("test2")],
],
)
@pytest.mark.parametrize(
"return_type",
[
None,
LiteralDoc(None),
],
)
@pytest.mark.parametrize(
"body",
[
Expand All @@ -458,9 +466,8 @@ def test_return_doc():
[ExprStmtDoc(IdDoc("x")), ExprStmtDoc(IdDoc("y"))],
],
)
def test_function_doc(args, decorators, body):
def test_function_doc(args, decorators, return_type, body):
name = IdDoc("name")
return_type = LiteralDoc(None)

doc = FunctionDoc(name, args, decorators, return_type, body)

Expand Down Expand Up @@ -504,3 +511,7 @@ def test_stmt_doc_comment():
comment = "test comment"
doc.comment = comment
assert doc.comment == comment


if __name__ == "__main__":
tvm.testing.main()
Loading