From 696bdaf5054f33fb7101d920ed003567fb1177da Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 8 Jul 2022 21:52:40 -0400 Subject: [PATCH 01/23] Add expr doc --- include/tvm/script/printer/doc.h | 477 ++++++++++++++++++ python/tvm/script/printer/doc.py | 220 +++++++- src/script/printer/doc.cc | 145 ++++++ .../unittest/test_tvmscript_printer_doc.py | 218 +++++++- 4 files changed, 1057 insertions(+), 3 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 67c27bd45a1d..1f80fc905db0 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -63,6 +63,8 @@ class Doc : public ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); }; +class ExprDoc; + /*! * \brief The base class of expression doc. * @@ -70,6 +72,34 @@ class Doc : public ObjectRef { */ class ExprDocNode : public DocNode { public: + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param attr The attribute to access. + */ + ExprDoc Attr(String attr) const; + + /*! + * \brief Create a doc representing index access on the current ExprDoc + * \param indices The indices to access. + */ + ExprDoc Index(Array indices) const; + + /*! + * \brief Create a doc representing calling the current ExprDoc + * \param args The positional arguments of the function call. + */ + ExprDoc Call(Array args) const; + + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param args The positional arguments of the function call. + * \param kwargs_keys Keys of keywords arguments of the function call. + * \param kwargs_values Values of keywords arguments of the function call. + */ + ExprDoc Call(Array args, // + Array kwargs_keys, // + Array kwargs_values) const; + void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); } static constexpr const char* _type_key = "script.printer.ExprDoc"; @@ -158,6 +188,453 @@ class LiteralDoc : public ExprDoc { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; +/*! + * \brief Doc that represents identifier. + * + * \sa IdDoc + */ +class IdDocNode : public ExprDocNode { + public: + /*! \brief The name of the identifier */ + String name; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("name", &name); + } + + static constexpr const char* _type_key = "script.printer.IdDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(IdDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of IdDocNode. + * + * \sa IdDocNode + */ +class IdDoc : public ExprDoc { + public: + /*! + * \brief Constructor of IdDoc. + * \param name The name of identifier. + */ + explicit IdDoc(String name); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode); +}; + +/*! + * \brief Doc that represents attribute access on another expression. + * + * \sa AttrAccessDoc + */ +class AttrAccessDocNode : public ExprDocNode { + public: + /*! \brief The target expression to be accessed */ + ExprDoc value{nullptr}; + /*! \brief The attribute to be accessed */ + String attr; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("value", &value); + v->Visit("attr", &attr); + } + + static constexpr const char* _type_key = "script.printer.AttrAccessDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrAccessDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of AttrAccessDocNode. + * + * \sa AttrAccessDocNode + */ +class AttrAccessDoc : public ExprDoc { + public: + /*! + * \brief Constructor of AttrAccessDoc + * \param value The target expression of attribute access. + * \param attr The name of attribute to access. + */ + explicit AttrAccessDoc(ExprDoc value, String attr); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode); +}; + +/*! + * \brief Doc that represents index access on another expression. + * + * \sa IndexDoc + */ +class IndexDocNode : public ExprDocNode { + public: + /*! \brief The container value to be accessed */ + ExprDoc value{nullptr}; + /*! + * \brief The indices to access + * + * Possible actual types: + * - ExprDoc (single point access like a[1, 2]) + * - SliceDoc (slice access like a[1:5, 2]) + */ + Array indices; // Each element is union of: Slice / ExprDoc + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("value", &value); + v->Visit("indices", &indices); + } + + static constexpr const char* _type_key = "script.printer.IndexDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(IndexDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of IndexDocNode. + * + * \sa IndexDocNode + */ +class IndexDoc : public ExprDoc { + public: + /*! + * \brief Constructor of IndexDoc + * \param value The target expression of index access. + * \param indices The indices to access. + */ + explicit IndexDoc(ExprDoc value, Array indices); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IndexDoc, ExprDoc, IndexDocNode); +}; + +/*! + * \brief Doc that represents function call. + * + * \sa CallDoc + */ +class CallDocNode : public ExprDocNode { + public: + /*! \brief The callee of this function call */ + ExprDoc callee{nullptr}; + /*! \brief The positional arguments */ + Array args; + /*! \brief The keys of keyword arguments */ + Array kwargs_keys; + /*! + * \brief The values of keyword arguments. + * + * The i-th element is the value of the i-th key in `kwargs_keys`. + * It must have the same length as `kwargs_keys`. + */ + Array kwargs_values; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("callee", &callee); + v->Visit("args", &args); + v->Visit("kwargs_keys", &kwargs_keys); + v->Visit("kwargs_values", &kwargs_values); + } + + static constexpr const char* _type_key = "script.printer.CallDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of CallDocNode. + * + * \sa CallDocNode + */ +class CallDoc : public ExprDoc { + public: + /*! + * \brief Constructor of CallDoc + * \param callee The callee of this function call. + * \param args The positional arguments. + * \param kwargs_keys Keys of keyword arguments. + * \param kwargs_values Values of keyword arguments, must have the same length as `kwargs_keys. + */ + CallDoc(ExprDoc callee, Array args, Array kwargs_keys, + Array kwargs_values); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CallDoc, ExprDoc, CallDocNode); +}; + +/*! + * \brief Doc that represents operation. + * + * It can be unary, binary and other special operators (for example, + * the if-then-else expression). + * + * \sa OperationDoc + */ +class OperationDocNode : public ExprDocNode { + public: + enum class Kind : int32_t { + // Unary operators + kUnaryStart, + kUSub, // -x + kInvert, // ~x + kUnaryEnd, + + // Binary operators + kBinaryStart, + kAdd, // + + kSub, // - + kMult, // * + kDiv, // / + kFloorDiv, // // in Python + kMod, // % in Python + kPow, // ** in Python + kLShift, // << + kRShift, // >> + kBitAnd, // & + kBitOr, // | + kBitXor, // ^ + kLt, // < + kLtE, // <= + kEq, // == + kNotEq, // != + kGt, // > + kGtE, // >= + kBinaryEnd, + + // Special + kSpecialStart, + kAssert, + }; + + /*! \brief The kind of operation (operator) */ + Kind kind; + /*! \brief Operands of this expression */ + Array operands; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("kind", &kind); + v->Visit("operands", &operands); + } + + static constexpr const char* _type_key = "script.printer.OperationDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(OperationDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of OperationDocNode. + * + * \sa OperationDocNode + */ +class OperationDoc : public ExprDoc { + public: + /*! + * \brief Constructor of OperationDoc + * \param kind The kind of operation. + * \param operands Operands of this expression. + */ + explicit OperationDoc(OperationDocNode::Kind kind, Array operands); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(OperationDoc, ExprDoc, OperationDocNode); +}; + +/*! + * \brief Doc that represents anonymous function. + * + * LambdaDoc can only have positional arguments without type annotation, + * and a single expression as body. + * + * \sa LambdaDoc + */ +class LambdaDocNode : public ExprDocNode { + public: + /*! \brief The arguments of this anonymous function */ + Array args; + /*! \brief The body of this anonymous function */ + ExprDoc body{nullptr}; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("args", &args); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.LambdaDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of LambdaDocNode. + * + * \sa LambdaDocNode + */ +class LambdaDoc : public ExprDoc { + public: + /*! + * \brief Constructor of LambdaDoc + * \param args Arguments of this function. + * \param body Body expression of this function. + */ + explicit LambdaDoc(Array args, ExprDoc body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, ExprDoc, LambdaDocNode); +}; + +/*! + * \brief Doc that represents tuple literal. + * + * \sa TupleDoc + */ +class TupleDocNode : public ExprDocNode { + public: + /*! \brief Elements of tuple */ + Array elements; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("elements", &elements); + } + + static constexpr const char* _type_key = "script.printer.TupleDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of TupleDocNode. + * + * \sa TupleDocNode + */ +class TupleDoc : public ExprDoc { + public: + /*! + * \brief Create an empty TupleDoc + */ + TupleDoc() : TupleDoc(runtime::make_object()) {} + /*! + * \brief Constructor of TupleDoc + * \param elements Elements of tuple. + */ + explicit TupleDoc(Array elements); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleDoc, ExprDoc, TupleDocNode); +}; + +/*! + * \brief Doc that represents list literal. + * + * \sa AttrAccessDoc + */ +class ListDocNode : public ExprDocNode { + public: + /*! \brief Elements of list */ + Array elements; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("elements", &elements); + } + + static constexpr const char* _type_key = "script.printer.ListDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ListDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of ListDocNode. + * + * \sa ListDocNode + */ +class ListDoc : public ExprDoc { + public: + /*! + * \brief Create an empty ListDoc + */ + ListDoc() : ListDoc(runtime::make_object()) {} + /*! + * \brief Constructor of ListDoc + * \param elements Elements of list. + */ + explicit ListDoc(Array elements); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ListDoc, ExprDoc, ListDocNode); +}; + +/*! + * \brief Doc that represents dictionary literal. + * + * \sa AttrAccessDoc + */ +class DictDocNode : public ExprDocNode { + public: + /*! \brief keys of dictionary */ + Array keys; + /*! + * \brief Values of dictionary + * + * The i-th element is the value of the i-th element of `keys`. + * It must have the same length as `keys`. + */ + Array values; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("keys", &keys); + v->Visit("values", &values); + } + + static constexpr const char* _type_key = "script.printer.DictDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(DictDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of DictDocNode. + * + * \sa DictDocNode + */ +class DictDoc : public ExprDoc { + public: + /*! + * \brief Create an empty dictionary + */ + DictDoc() : DictDoc(runtime::make_object()) {} + /*! + * \brief Constructor of DictDoc + * \param keys Keys of dictionary. + * \param values Values of dictionary, must have same length as `keys`. + */ + explicit DictDoc(Array keys, Array values); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DictDoc, ExprDoc, DictDocNode); +}; + +/*! + * \brief Doc that represents slice in Index expression. + * + * This doc can only appear in IndexDoc::indices. + * + * \sa AttrAccessDoc + */ +class SliceDocNode : public DocNode { + public: + /*! \brief The start of slice */ + Optional start; + /*! \brief The exclusive end of slice */ + Optional stop; + + void VisitAttrs(AttrVisitor* v) { + DocNode::VisitAttrs(v); + v->Visit("start", &start); + v->Visit("stop", &stop); + } + + static constexpr const char* _type_key = "script.printer.SliceDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(SliceDocNode, DocNode); +}; + +/*! + * \brief Reference type of SliceDocNode. + * + * \sa SliceDocNode + */ +class SliceDoc : public Doc { + public: + /*! + * \brief Constructor of SliceDoc + * \param start The start of slice. + * \param start The exclusive end of slice. + */ + explicit SliceDoc(Optional start, Optional stop); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); +}; + } // namespace printer } // namespace script } // namespace tvm diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index f6179d7351b2..0bddb2608851 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -16,8 +16,13 @@ # under the License. """Doc types for TVMScript Unified Printer""" +from typing import List, Dict, Tuple, Optional, Union +from enum import IntEnum, auto, unique + import tvm._ffi +import tvm.ir.container from tvm.runtime import Object +from tvm.tir import FloatImm, IntImm from . import _ffi_api @@ -29,12 +34,63 @@ class Doc(Object): class ExprDoc(Object): """Base class of all expression Docs""" + def attr_access(self, attr: str) -> "AttrAccessDoc": + """ + Create a doc that represents attribute access on self. + + Parameters + ---------- + attr : str + The attribute name to access + + Returns + ------- + doc : AttrAccessDoc + """ + return _ffi_api.ExprDocAttr(self, attr) # type: ignore + + def index_access(self, indices: List[Union["ExprDoc", "SliceDoc"]]) -> "IndexDoc": + """ + Create a doc that represents index access on self. + + Parameters + ---------- + indices : List[Union["ExprDoc", "SliceDoc"]] + The indices to access + + Returns + ------- + doc : IndexDoc + """ + return _ffi_api.ExprDocIndex(self, indices) # type: ignore + + def call_with(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc": + """ + Create a doc that represents function call, with self as callee. + + Parameters + ---------- + *args : ExprDoc + The positional arguments of the function call. + **kwargs + The keyword arguments of the function call. + + Returns + ------- + doc : CallDoc + """ + kwargs_keys = list(kwargs.keys()) + kwargs_values = list(kwargs.values()) + return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values) # type: ignore + @tvm._ffi.register_object("script.printer.LiteralDoc") class LiteralDoc(ExprDoc): """Doc that represents literal value""" - def __init__(self, value): + value: Union[str, IntImm, FloatImm, None] + + def __init__(self, value: Union[str, float, bool, int, None]): if value is None: self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore elif isinstance(value, str): @@ -47,3 +103,165 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") + + +@tvm._ffi.register_object("script.printer.IdDoc") +class IdDoc(ExprDoc): + """Doc that represents identifier""" + + name: str + + def __init__(self, name: str): + self.__init_handle_by_constructor__(_ffi_api.IdDoc, name) # type: ignore + + +@tvm._ffi.register_object("script.printer.AttrAccessDoc") +class AttrAccessDoc(ExprDoc): + """Doc that represents attribute access on an expression""" + + value: ExprDoc + attr: str + + def __init__(self, value: ExprDoc, attr: str): + self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, attr) # type: ignore + + +@tvm._ffi.register_object("script.printer.IndexDoc") +class IndexDoc(ExprDoc): + """Doc that represents index access on an expression""" + + value: ExprDoc + indices: tvm.ir.container.Array # actual type: List[Union[ExprDoc, "SliceDoc"]] + + def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]): + self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices) # type: ignore + + +@tvm._ffi.register_object("script.printer.CallDoc") +class CallDoc(ExprDoc): + """Doc that represents function call""" + + callee: ExprDoc + args: tvm.ir.container.Array # actual type: List[ExprDoc] + kwargs_keys: tvm.ir.container.Array # actual type: List[str] + kwargs_values: tvm.ir.container.Array # actual type: List[ExprDoc] + + def __init__(self, callee: ExprDoc, *args: Tuple[ExprDoc], **kwargs: Dict[str, ExprDoc]): + kwargs_keys = list(kwargs.keys()) + kwargs_values = list(kwargs.values()) + self.__init_handle_by_constructor__( + _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values # type: ignore + ) + + +@unique +class OperationKind(IntEnum): + """ + This enum represents the kind of operation (operator) in OpeartionDoc + + It's mirrored from OperationDocNode::Kind at include/tvm/script/printer/doc.h + """ + + _UnaryStart = 0 + USub = auto() + Invert = auto() + UnaryEnd = auto() + + _BinaryStart = auto() + Add = auto() + Sub = auto() + Mult = auto() + Div = auto() + FloorDiv = auto() + Mod = auto() + Pow = auto() + LShift = auto() + RShift = auto() + BitAnd = auto() + BitOr = auto() + BitXor = auto() + Lt = auto() + LtE = auto() + Eq = auto() + NotEq = auto() + Gt = auto() + GtE = auto() + _BinaryEnd = auto() + + _SpecialStart = auto() + Assert = auto() + + +@tvm._ffi.register_object("script.printer.OperationDoc") +class OperationDoc(ExprDoc): + """ + Doc that represents operation + + It can be unary, binary and other special operators (for example, the + if-then-else expression). + """ + + kind: OperationKind + operands: tvm.ir.container.Array # actual type: List[ExprDoc] + + def __init__(self, kind: OperationKind, operands: List[ExprDoc]): + self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands) # type: ignore + + +@tvm._ffi.register_object("script.printer.LambdaDoc") +class LambdaDoc(ExprDoc): + """Doc that represents lambda function""" + + args: tvm.ir.container.Array # actual type: List[IdDoc] + body: ExprDoc + + def __init__(self, args: List[IdDoc], body: ExprDoc): + self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body) # type: ignore + + +@tvm._ffi.register_object("script.printer.TupleDoc") +class TupleDoc(ExprDoc): + """Doc that represents tuple literal""" + + elements: tvm.ir.container.Array # actual type: List[ExprDoc] + + def __init__(self, elements: List[ExprDoc]): + self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements) # type: ignore + + +@tvm._ffi.register_object("script.printer.ListDoc") +class ListDoc(ExprDoc): + """Doc that represents list literal""" + + elements: tvm.ir.container.Array # actual type: List[ExprDoc] + + def __init__(self, elements: List[ExprDoc]): + self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements) # type: ignore + + +@tvm._ffi.register_object("script.printer.DictDoc") +class DictDoc(ExprDoc): + """Doc that represents dict literal""" + + keys: tvm.ir.container.Array # actual type: List[ExprDoc] + values: tvm.ir.container.Array # actual type: List[ExprDoc] + + def __init__(self, content: Dict[ExprDoc, ExprDoc]): + keys = list(content.keys()) + values = list(content.values()) + self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values) # type: ignore + + +@tvm._ffi.register_object("script.printer.SliceDoc") +class SliceDoc(ExprDoc): + """ + Doc that represents slice in Index expression + + This doc can only appear in `IndexDoc.indices`. + """ + + start: Optional[ExprDoc] + stop: Optional[ExprDoc] + + def __init__(self, start: Optional[ExprDoc], stop: Optional[ExprDoc] = None): + self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop) # type: ignore diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index e54adbd36b4c..3422f9d2fdc4 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -23,14 +23,106 @@ namespace tvm { namespace script { namespace printer { +ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } + +ExprDoc ExprDocNode::Index(Array indices) const { + return IndexDoc(GetRef(this), indices); +} + +ExprDoc ExprDocNode::Call(Array args) const { + return CallDoc(GetRef(this), args, {}, {}); +} + +ExprDoc ExprDocNode::Call(Array args, Array kwargs_keys, + Array kwargs_values) const { + return CallDoc(GetRef(this), args, kwargs_keys, kwargs_values); +} + LiteralDoc::LiteralDoc(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; this->data_ = std::move(n); } +IdDoc::IdDoc(String name) { + ObjectPtr n = make_object(); + n->name = name; + this->data_ = std::move(n); +} + +AttrAccessDoc::AttrAccessDoc(ExprDoc value, String attr) { + ObjectPtr n = make_object(); + n->value = value; + n->attr = attr; + this->data_ = std::move(n); +} + +IndexDoc::IndexDoc(ExprDoc value, Array indices) { + ObjectPtr n = make_object(); + n->value = value; + n->indices = indices; + this->data_ = std::move(n); +} + +CallDoc::CallDoc(ExprDoc callee, Array args, Array kwargs_keys, + Array kwargs_values) { + ObjectPtr n = make_object(); + n->callee = callee; + n->args = args; + n->kwargs_keys = kwargs_keys; + n->kwargs_values = kwargs_values; + this->data_ = std::move(n); +} + +OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array operands) { + ObjectPtr n = make_object(); + n->kind = kind; + n->operands = operands; + this->data_ = std::move(n); +} + +LambdaDoc::LambdaDoc(Array args, ExprDoc body) { + ObjectPtr n = make_object(); + n->args = args; + n->body = body; + this->data_ = std::move(n); +} + +TupleDoc::TupleDoc(Array elements) { + ObjectPtr n = make_object(); + n->elements = elements; + this->data_ = std::move(n); +} + +ListDoc::ListDoc(Array elements) { + ObjectPtr n = make_object(); + n->elements = elements; + this->data_ = std::move(n); +} + +DictDoc::DictDoc(Array keys, Array values) { + ObjectPtr n = make_object(); + n->keys = keys; + n->values = values; + this->data_ = std::move(n); +} + +SliceDoc::SliceDoc(Optional start, Optional stop) { + ObjectPtr n = make_object(); + n->start = start; + n->stop = stop; + this->data_ = std::move(n); +} + TVM_REGISTER_NODE_TYPE(DocNode); + TVM_REGISTER_NODE_TYPE(ExprDocNode); +TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); +TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::Index); +TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") + .set_body_method, Array, Array>( + &ExprDocNode::Call); + TVM_REGISTER_NODE_TYPE(LiteralDocNode); TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); @@ -38,6 +130,59 @@ TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDo TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); +TVM_REGISTER_NODE_TYPE(IdDocNode); +TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); + +TVM_REGISTER_NODE_TYPE(AttrAccessDocNode); +TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc").set_body_typed([](ExprDoc value, String attr) { + return AttrAccessDoc(value, attr); +}); + +TVM_REGISTER_NODE_TYPE(IndexDocNode); +TVM_REGISTER_GLOBAL("script.printer.IndexDoc") + .set_body_typed([](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); + +TVM_REGISTER_NODE_TYPE(CallDocNode); +TVM_REGISTER_GLOBAL("script.printer.CallDoc") + .set_body_typed([](ExprDoc callee, // + Array args, // + Array kwargs_keys, // + Array kwargs_values) { + return CallDoc(callee, args, kwargs_keys, kwargs_values); + }); + +TVM_REGISTER_NODE_TYPE(OperationDocNode); +TVM_REGISTER_GLOBAL("script.printer.OperationDoc") + .set_body_typed([](int32_t kind, Array operands) { + return OperationDoc(OperationDocNode::Kind(kind), operands); + }); + +TVM_REGISTER_NODE_TYPE(LambdaDocNode); +TVM_REGISTER_GLOBAL("script.printer.LambdaDoc").set_body_typed([](Array args, ExprDoc body) { + return LambdaDoc(args, body); +}); + +TVM_REGISTER_NODE_TYPE(TupleDocNode); +TVM_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { + return TupleDoc(elements); +}); + +TVM_REGISTER_NODE_TYPE(ListDocNode); +TVM_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { + return ListDoc(elements); +}); + +TVM_REGISTER_NODE_TYPE(DictDocNode); +TVM_REGISTER_GLOBAL("script.printer.DictDoc") + .set_body_typed([](Array keys, Array values) { + return DictDoc(keys, values); + }); + +TVM_REGISTER_NODE_TYPE(SliceDocNode); +TVM_REGISTER_GLOBAL("script.printer.SliceDoc") + .set_body_typed([](Optional start, Optional stop) { + return SliceDoc(start, stop); + }); } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 6330d33bf25a..164889f8568b 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -16,8 +16,20 @@ # under the License. import pytest -from tvm.tir import IntImm -from tvm.script.printer.doc import LiteralDoc +from tvm.script.printer.doc import ( + LiteralDoc, + IdDoc, + AttrAccessDoc, + IndexDoc, + CallDoc, + OperationKind, + OperationDoc, + LambdaDoc, + TupleDoc, + ListDoc, + DictDoc, + SliceDoc, +) @pytest.mark.parametrize( @@ -26,8 +38,210 @@ ) def test_literal_doc_construction(value): doc = LiteralDoc(value) + if isinstance(value, float): # FloatImm cannot be compared with Python's float directly assert float(doc.value) == pytest.approx(value) else: assert doc.value == value + + +def test_id_doc(): + doc = IdDoc("name") + + assert doc.name == "name" + + +def test_attr_access_doc(): + target = IdDoc("x") + + doc = AttrAccessDoc(target, "attribute") + + assert doc.value == target + assert doc.attr == "attribute" + + +@pytest.mark.parametrize( + "indices", + [ + [], + [LiteralDoc(1)], + [LiteralDoc(2), IdDoc("x")], + [SliceDoc(LiteralDoc(1), LiteralDoc(2))], + [SliceDoc(LiteralDoc(1)), IdDoc("y")], + ], +) +def test_index_doc(indices): + target = IdDoc("x") + + doc = IndexDoc(target, indices) + + assert doc.value == target + assert list(doc.indices) == indices + + +@pytest.mark.parametrize( + "args, kwargs", + [ + ([], {}), + ([LiteralDoc("arg")], {}), + ([LiteralDoc("arg"), IdDoc("x")], {}), + ([], {"x": LiteralDoc("x")}), + ([], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg"), IdDoc("x")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ], +) +def test_call_doc(args, kwargs): + target = IdDoc("x") + + doc = CallDoc(target, *args, **kwargs) + + assert doc.callee == target + assert list(doc.args) == args + assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs + + +@pytest.mark.parametrize( + "operands", + [ + [], + [LiteralDoc(1)], + [LiteralDoc(2), IdDoc("x")], + [LiteralDoc(2), IdDoc("x"), LiteralDoc("y")], + ], +) +def test_operation_doc(operands): + # Here we just test the contructor and attr visitor of OperationDoc + # so the choice of OperationKind doesn't matter + operator = OperationKind.Add + + doc = OperationDoc(OperationKind.Add, operands) + + assert doc.kind == operator + assert list(doc.operands) == operands + + +@pytest.mark.parametrize( + "args", + [ + [], + [IdDoc("x")], + [IdDoc("x"), IdDoc("y")], + ], +) +def test_lambda_doc(args): + body = LiteralDoc(1) + + doc = LambdaDoc(args, body) + + assert doc.body == body + assert list(doc.args) == args + + +@pytest.mark.parametrize( + "elements", + [ + [], + [IdDoc("x")], + [IdDoc("x"), IdDoc("y")], + ], +) +def test_tuple_doc(elements): + doc = TupleDoc(elements) + + assert list(doc.elements) == elements + + +@pytest.mark.parametrize( + "elements", + [ + [], + [IdDoc("x")], + [IdDoc("x"), IdDoc("y")], + ], +) +def test_list_doc(elements): + doc = ListDoc(elements) + + assert list(doc.elements) == elements + + +@pytest.mark.parametrize( + "content", + [ + {}, + {LiteralDoc("k"): IdDoc("v")}, + {LiteralDoc("k"): IdDoc("v"), LiteralDoc("k2"): IdDoc("v2")}, + ], +) +def test_dict_doc(content): + doc = DictDoc(content) + + assert dict(zip(doc.keys, doc.values)) == content + + +@pytest.mark.parametrize( + "start,stop", + [ + (LiteralDoc(1), LiteralDoc(2)), + (LiteralDoc(1), None), + (None, LiteralDoc(2)), + ], +) +def test_slice_doc(start, stop): + doc = SliceDoc(start, stop) + + assert doc.start == start + assert doc.stop == stop + + +def test_expr_doc_attr_access(): + target = IdDoc("x") + attr = "test" + + doc = target.attr_access(attr) + + assert doc.value == target + assert doc.attr == attr + + +@pytest.mark.parametrize( + "indices", + [ + [], + [LiteralDoc(1)], + [LiteralDoc(2), IdDoc("x")], + [SliceDoc(LiteralDoc(1), LiteralDoc(2))], + [SliceDoc(LiteralDoc(1)), IdDoc("y")], + ], +) +def test_expr_doc_index_access(indices): + target = IdDoc("x") + + doc = target.index_access(indices) + + assert doc.value == target + assert list(doc.indices) == indices + + +@pytest.mark.parametrize( + "args, kwargs", + [ + ([], {}), + ([LiteralDoc("arg")], {}), + ([LiteralDoc("arg"), IdDoc("x")], {}), + ([], {"x": LiteralDoc("x")}), + ([], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg"), IdDoc("x")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ], +) +def test_expr_doc_call_with(args, kwargs): + target = IdDoc("x") + + doc = target.call_with(*args, **kwargs) + + assert doc.callee == target + assert list(doc.args) == args + assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs From 07a51024a38f40e6c31b8e32c7744879ee74d615 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 11 Jul 2022 10:54:22 -0400 Subject: [PATCH 02/23] Fix lint --- python/tvm/script/printer/doc.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 0bddb2608851..905bb24db1d3 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -150,7 +150,7 @@ def __init__(self, callee: ExprDoc, *args: Tuple[ExprDoc], **kwargs: Dict[str, E kwargs_keys = list(kwargs.keys()) kwargs_values = list(kwargs.values()) self.__init_handle_by_constructor__( - _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values # type: ignore + _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values # type: ignore ) @@ -162,6 +162,9 @@ class OperationKind(IntEnum): It's mirrored from OperationDocNode::Kind at include/tvm/script/printer/doc.h """ + # The name convention follows https://docs.python.org/3/library/ast.html + # pylint: disable=invalid-name + _UnaryStart = 0 USub = auto() Invert = auto() @@ -191,6 +194,8 @@ class OperationKind(IntEnum): _SpecialStart = auto() Assert = auto() + # pylint: enable=invalid-name + @tvm._ffi.register_object("script.printer.OperationDoc") class OperationDoc(ExprDoc): From 4dee76c9e0191dc9e5405152af10a817797105aa Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 11 Jul 2022 15:57:12 -0400 Subject: [PATCH 03/23] Add ExprDoc support in PythonDocPrinter --- include/tvm/script/printer/doc.h | 1 + python/tvm/script/printer/doc.py | 5 +- src/script/printer/base_doc_printer.cc | 20 + src/script/printer/base_doc_printer.h | 50 ++ src/script/printer/python_doc_printer.cc | 214 +++++++++ .../unittest/test_tvmscript_printer_doc.py | 1 + ...st_tvmscript_printer_python_doc_printer.py | 429 +++++++++++++++++- 7 files changed, 715 insertions(+), 5 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 1f80fc905db0..d04102ab1bb8 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -398,6 +398,7 @@ class OperationDocNode : public ExprDocNode { // Special kSpecialStart, kAssert, + kSpecialEnd }; /*! \brief The kind of operation (operator) */ diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 905bb24db1d3..24ce04782a89 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -168,7 +168,7 @@ class OperationKind(IntEnum): _UnaryStart = 0 USub = auto() Invert = auto() - UnaryEnd = auto() + _UnaryEnd = auto() _BinaryStart = auto() Add = auto() @@ -193,6 +193,7 @@ class OperationKind(IntEnum): _SpecialStart = auto() Assert = auto() + _SpecialEnd = auto() # pylint: enable=invalid-name @@ -268,5 +269,5 @@ class SliceDoc(ExprDoc): start: Optional[ExprDoc] stop: Optional[ExprDoc] - def __init__(self, start: Optional[ExprDoc], stop: Optional[ExprDoc] = None): + def __init__(self, start: Optional[ExprDoc] = None, stop: Optional[ExprDoc] = None): self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop) # type: ignore diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc index f6874ba1a2ee..42d3f2d8f3ac 100644 --- a/src/script/printer/base_doc_printer.cc +++ b/src/script/printer/base_doc_printer.cc @@ -38,6 +38,26 @@ String DocPrinter::GetString() const { void DocPrinter::PrintDoc(const Doc& doc) { if (const auto* doc_node = doc.as()) { PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h index 128fcef2ea32..d5bfdcd94c6b 100644 --- a/src/script/printer/base_doc_printer.h +++ b/src/script/printer/base_doc_printer.h @@ -83,6 +83,56 @@ class DocPrinter { */ virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; + /*! + * \brief Virtual method to print a IdDoc + */ + virtual void PrintTypedDoc(const IdDoc& doc) = 0; + + /*! + * \brief Virtual method to print a AttrAccessDoc + */ + virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0; + + /*! + * \brief Virtual method to print a IndexDoc + */ + virtual void PrintTypedDoc(const IndexDoc& doc) = 0; + + /*! + * \brief Virtual method to print a OperationDoc + */ + virtual void PrintTypedDoc(const OperationDoc& doc) = 0; + + /*! + * \brief Virtual method to print a CallDoc + */ + virtual void PrintTypedDoc(const CallDoc& doc) = 0; + + /*! + * \brief Virtual method to print a LambdaDoc + */ + virtual void PrintTypedDoc(const LambdaDoc& doc) = 0; + + /*! + * \brief Virtual method to print a ListDoc + */ + virtual void PrintTypedDoc(const ListDoc& doc) = 0; + + /*! + * \brief Virtual method to print a TupleDoc + */ + virtual void PrintTypedDoc(const TupleDoc& doc) = 0; + + /*! + * \brief Virtual method to print a DictDoc + */ + virtual void PrintTypedDoc(const DictDoc& doc) = 0; + + /*! + * \brief Virtual method to print a SliceDoc + */ + virtual void PrintTypedDoc(const SliceDoc& doc) = 0; + /*! * \brief Increase the indent level of any content to be * printed after this call diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index cd816e4f7010..3725a40ce6bb 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -19,8 +19,11 @@ #include +#include + #include "../../support/str_escape.h" #include "./base_doc_printer.h" +#include "tvm/runtime/logging.h" namespace tvm { namespace script { @@ -34,6 +37,33 @@ class PythonDocPrinter : public DocPrinter { using DocPrinter::PrintDoc; void PrintTypedDoc(const LiteralDoc& doc) final; + void PrintTypedDoc(const IdDoc& doc) final; + void PrintTypedDoc(const AttrAccessDoc& doc) final; + void PrintTypedDoc(const IndexDoc& doc) final; + void PrintTypedDoc(const OperationDoc& doc) final; + void PrintTypedDoc(const CallDoc& doc) final; + void PrintTypedDoc(const LambdaDoc& doc) final; + void PrintTypedDoc(const ListDoc& doc) final; + void PrintTypedDoc(const DictDoc& doc) final; + void PrintTypedDoc(const TupleDoc& doc) final; + void PrintTypedDoc(const SliceDoc& doc) final; + + private: + template + void PrintJoinedElements(const std::string& left, const Array& elements, + const std::string& separator, const std::string& right) { + output_ << left; + bool is_first = true; + for (auto& element : elements) { + if (is_first) { + is_first = false; + } else { + output_ << separator; + } + PrintDoc(element); + } + output_ << right; + } }; void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { @@ -57,6 +87,190 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } +/* + * This function checks whether an input string is a valid + * identifier. Invalid identifier name can make the result + * still parsable but into a different IR tree. So we want + * to fail as soon as possible. + */ +bool IsValidPythonIdentifier(const std::string& id) { + // This regex is just an approximation of the Python identifier + // rule. This doesn't exclude the reserved keywords. But it should + // be good enough for roundtrippable TVMScript printing and parsing. + const static std::regex id_pattern(R"(^[^\d\W]\w*$)"); + return std::regex_match(id, id_pattern); +} + +void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { + CHECK(IsValidPythonIdentifier(doc->name)) + << "ValueError: " << doc->name << " is not a valid identifier."; + output_ << doc->name; +} + +void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { + CHECK(IsValidPythonIdentifier(doc->attr)) + << "ValueError: " << doc->attr << " is not a valid attribute."; + PrintDoc(doc->value); + output_ << "." << doc->attr; +} + +void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { + PrintDoc(doc->value); + if (doc->indices.size() == 0) { + output_ << "[()]"; + } else { + PrintJoinedElements("[", doc->indices, ", ", "]"); + } +} + +constexpr int OP_STR_TABLE_SIZE = static_cast(OperationDocNode::Kind::kSpecialEnd) + 1; +static const std::array OP_STR_TABLE = []() { + using OpKind = OperationDocNode::Kind; + std::array table; + auto set_op = [&table](auto op, const char* str) { table[static_cast(op)] = str; }; + + set_op(OpKind::kUSub, "-"); + set_op(OpKind::kInvert, "~"); + set_op(OpKind::kAdd, "+"); + set_op(OpKind::kSub, "-"); + set_op(OpKind::kMult, "*"); + set_op(OpKind::kDiv, "/"); + set_op(OpKind::kFloorDiv, "//"); + set_op(OpKind::kMod, "%"); + set_op(OpKind::kPow, "**"); + set_op(OpKind::kLShift, "<<"); + set_op(OpKind::kRShift, ">>"); + set_op(OpKind::kBitAnd, "&"); + set_op(OpKind::kBitOr, "|"); + set_op(OpKind::kBitXor, "^"); + set_op(OpKind::kLt, "<"); + set_op(OpKind::kLtE, "<="); + set_op(OpKind::kEq, "=="); + set_op(OpKind::kNotEq, "!="); + set_op(OpKind::kGt, ">"); + set_op(OpKind::kGtE, ">="); + + return table; +}(); + +const char* OperatorToString(OperationDocNode::Kind operation_kind) { + auto op_index = static_cast(operation_kind); + ICHECK_LT(op_index, OP_STR_TABLE_SIZE); + const char* str = OP_STR_TABLE[static_cast(operation_kind)]; + if (str == nullptr) { + LOG(FATAL) << "OperationDocNode::Kind " << static_cast(operation_kind) + << " cannot be converted to operator token in Python directly."; + throw; + } + return str; +} + +void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { + using OpKind = OperationDocNode::Kind; + if (doc->kind < OpKind::kUnaryEnd) { + // Unary Operators + ICHECK_EQ(doc->operands.size(), 1); + output_ << OperatorToString(doc->kind); + PrintDoc(doc->operands[0]); + } else if (doc->kind < OpKind::kBinaryEnd) { + // Binary Operator + ICHECK_EQ(doc->operands.size(), 2); + PrintDoc(doc->operands[0]); + output_ << " " << OperatorToString(doc->kind) << " "; + PrintDoc(doc->operands[1]); + } else if (doc->kind == OpKind::kAssert) { + // Special Operator: Assert + output_ << "assert "; + PrintDoc(doc->operands[0]); + if (doc->operands.size() > 1) { + output_ << ", "; + PrintDoc(doc->operands[1]); + } + } else { + LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast(doc->kind); + throw; + } +} + +void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { + PrintDoc(doc->callee); + + output_ << "("; + + // Print positional args + bool is_first = true; + for (const ExprDoc& arg : doc->args) { + if (is_first) { + is_first = false; + } else { + output_ << ", "; + } + PrintDoc(arg); + } + + // Print keyword args + for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { + if (is_first) { + is_first = false; + } else { + output_ << ", "; + } + const String& keyword = doc->kwargs_keys[i]; + CHECK(IsValidPythonIdentifier(keyword)) + << "ValueError: " << keyword << " is not a valid name for keyword parameter."; + output_ << keyword; + output_ << "="; + PrintDoc(doc->kwargs_values[i]); + } + + output_ << ")"; +} + +void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) { + output_ << "lambda "; + PrintJoinedElements("", doc->args, ", ", ": "); + PrintDoc(doc->body); +} + +void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) { + PrintJoinedElements("[", doc->elements, ", ", "]"); +} + +void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { + if (doc->elements.size() == 1) { + output_ << "("; + PrintDoc(doc->elements[0]); + output_ << ",)"; + } else { + PrintJoinedElements("(", doc->elements, ", ", ")"); + } +} + +void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { + output_ << "{"; + size_t idx = 0; + for (const ExprDoc& key : doc->keys) { + if (idx > 0) { + output_ << ", "; + } + PrintDoc(key); + output_ << ": "; + PrintDoc(doc->values[idx]); + idx++; + } + output_ << "}"; +} + +void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { + if (doc->start != nullptr) { + PrintDoc(doc->start.value()); + } + output_ << ":"; + if (doc->stop != nullptr) { + PrintDoc(doc->stop.value()); + } +} + String DocToPythonScript(Doc doc, int indent_spaces) { PythonDocPrinter printer(indent_spaces); printer.Append(doc); diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 164889f8568b..fcb630bd3ec3 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -187,6 +187,7 @@ def test_dict_doc(content): (LiteralDoc(1), LiteralDoc(2)), (LiteralDoc(1), None), (None, LiteralDoc(2)), + (None, None), ], ) def test_slice_doc(start, stop): diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 55b5e88c88c8..fdbe96fc3c4c 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -16,8 +16,19 @@ # under the License. import pytest +from tvm.script.printer.doc import ( + CallDoc, + DictDoc, + IdDoc, + LambdaDoc, + ListDoc, + LiteralDoc, + OperationDoc, + OperationKind, + SliceDoc, + TupleDoc, +) from tvm.script.printer.doc_printer import to_python_script -from tvm.script.printer.doc import LiteralDoc def format_script(s: str) -> str: @@ -28,7 +39,7 @@ def format_script(s: str) -> str: non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] spaces_to_remove = min(line_indents) - return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + "\n" @pytest.mark.parametrize( @@ -50,4 +61,416 @@ def format_script(s: str) -> str: ], ) def test_print_literal_doc(doc, expected): - assert to_python_script(doc).rstrip("\n") == format_script(expected) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "name", + [ + "test", + "_test", + "TestCase", + "test_case", + "test123", + ], +) +def test_print_id_doc(name): + doc = IdDoc(name) + assert to_python_script(doc) == format_script(name) + + +INVALID_IDENTIFIERS = [ + "", + "123", + "@test", + "test@", + "test case", + "test, case", + "test[0]", + "test.case", +] + + +@pytest.mark.parametrize( + "name", + INVALID_IDENTIFIERS, +) +def test_print_invalid_id_doc(name): + doc = IdDoc(name) + with pytest.raises(ValueError) as e: + to_python_script(doc) + assert "IsValidPythonIdentifier" in str(e.value) + + +@pytest.mark.parametrize( + "attr", + [ + "attr", + "_attr", + "Attr", + "attr_1", + ], +) +def test_print_attr_doc(attr): + doc = IdDoc("x").attr_access(attr) + assert to_python_script(doc) == format_script(f"x.{attr}") + + +@pytest.mark.parametrize( + "attr", + [ + "", + "123", + "@attr", + "attr@", + "attr with space", + "attr, with dot", + "attr[0]", + "attr.dot", + ], +) +def test_print_invalid_attr_doc(attr): + doc = IdDoc("x").attr_access(attr) + with pytest.raises(ValueError) as e: + to_python_script(doc) + assert "IsValidPythonIdentifier" in str(e.value) + + +@pytest.mark.parametrize( + "indices,expected", + [ + ( + [], + "[()]", + ), + ( + [LiteralDoc(1)], + "[1]", + ), + ( + [LiteralDoc(2), IdDoc("x")], + "[2, x]", + ), + ( + [SliceDoc(LiteralDoc(1), LiteralDoc(2))], + "[1:2]", + ), + ( + [SliceDoc(LiteralDoc(1)), IdDoc("y")], + "[1:, y]", + ), + ( + [SliceDoc(), IdDoc("y")], + "[:, y]", + ), + ( + [IdDoc("x"), IdDoc("y"), IdDoc("z")], + "[x, y, z]", + ), + ], +) +def test_print_index_doc(indices, expected): + doc = IdDoc("x").index_access(indices) + assert to_python_script(doc) == format_script(f"x{expected}") + + +UNARY_OP_TOKENS = { + OperationKind.USub: "-", + OperationKind.Invert: "~", +} + + +@pytest.mark.parametrize( + "op_kind, expected_token", + list(UNARY_OP_TOKENS.items()), + ids=UNARY_OP_TOKENS.keys(), +) +def test_print_unary_operation_doc(op_kind, expected_token): + doc = OperationDoc(op_kind, [IdDoc("x")]) + assert to_python_script(doc) == format_script(f"{expected_token}x") + + +BINARY_OP_TOKENS = { + OperationKind.Add: "+", + OperationKind.Sub: "-", + OperationKind.Mult: "*", + OperationKind.Div: "/", + OperationKind.FloorDiv: "//", + OperationKind.Mod: "%", + OperationKind.Pow: "**", + OperationKind.LShift: "<<", + OperationKind.RShift: ">>", + OperationKind.BitAnd: "&", + OperationKind.BitOr: "|", + OperationKind.BitXor: "^", + OperationKind.Lt: "<", + OperationKind.LtE: "<=", + OperationKind.Eq: "==", + OperationKind.NotEq: "!=", + OperationKind.Gt: ">", + OperationKind.GtE: ">=", +} + + +@pytest.mark.parametrize( + "op_kind, expected_token", + list(BINARY_OP_TOKENS.items()), + ids=BINARY_OP_TOKENS.keys(), +) +def test_print_binary_operation_doc(op_kind, expected_token): + doc = OperationDoc(op_kind, [IdDoc("x"), IdDoc("y")]) + assert to_python_script(doc) == format_script(f"x {expected_token} y") + + +SPECIAL_OP_CASES = [ + ( + OperationKind.Assert, + [LiteralDoc(True), LiteralDoc("assert_message")], + 'assert True, "assert_message"', + ), + ( + OperationKind.Assert, + [LiteralDoc(True)], + "assert True", + ), +] + + +@pytest.mark.parametrize( + "op_kind, operands, expected", SPECIAL_OP_CASES, ids=[kind for (kind, *_) in SPECIAL_OP_CASES] +) +def test_print_special_operation_doc(op_kind, operands, expected): + doc = OperationDoc(op_kind, operands) + assert to_python_script(doc) == format_script(expected) + + +def test_operation_doc_test_exhaustive(): + special_op_covered = {k for k, *_ in SPECIAL_OP_CASES} + for op_kind in OperationKind: + if OperationKind._UnaryStart < op_kind < OperationKind._UnaryEnd: + assert op_kind in UNARY_OP_TOKENS, ( + f"{op_kind.name} not covered in test_print_unary_operation_doc. " + f"Please add the expected token to UNARY_OP_TOKENS" + ) + elif OperationKind._BinaryStart < op_kind < OperationKind._BinaryEnd: + assert op_kind in BINARY_OP_TOKENS, ( + f"{op_kind.name} not covered in test_print_binary_operation_doc. " + f"Please add the expected token to BINARY_OP_TOKENS" + ) + elif not op_kind.name.startswith("_"): + # Special Op + assert op_kind in special_op_covered, ( + f"{op_kind.name} not covered in test_print_special_operation_doc. " + f"Please add the expected token to SPECIAL_OP_CASES" + ) + + +@pytest.mark.parametrize( + "args, kwargs, expected", + [ + ( + (), + {}, + "()", + ), + ( + (), + {"key0": IdDoc("u")}, + "(key0=u)", + ), + ( + (), + {"key0": IdDoc("u"), "key1": IdDoc("v")}, + "(key0=u, key1=v)", + ), + ( + (IdDoc("x"),), + {}, + "(x)", + ), + ( + (IdDoc("x"),), + {"key0": IdDoc("u")}, + "(x, key0=u)", + ), + ( + (IdDoc("x"),), + {"key0": IdDoc("u"), "key1": IdDoc("v")}, + "(x, key0=u, key1=v)", + ), + ( + (IdDoc("x"), (IdDoc("y"))), + {}, + "(x, y)", + ), + ( + (IdDoc("x"), (IdDoc("y"))), + {"key0": IdDoc("u")}, + "(x, y, key0=u)", + ), + ( + (IdDoc("x"), (IdDoc("y"))), + {"key0": IdDoc("u"), "key1": IdDoc("v")}, + "(x, y, key0=u, key1=v)", + ), + ], +) +def test_print_call_doc(args, kwargs, expected): + doc = CallDoc(IdDoc("f"), *args, **kwargs) + assert to_python_script(doc) == format_script(f"f{expected}") + + +@pytest.mark.parametrize( + "args, kwargs", + [ + *[((), {invalid_name: IdDoc("u")}) for invalid_name in INVALID_IDENTIFIERS], + *[ + ((), {"valid_key": IdDoc("v"), invalid_name: IdDoc("u")}) + for invalid_name in INVALID_IDENTIFIERS + ], + *[((IdDoc("x"),), {invalid_name: IdDoc("u")}) for invalid_name in INVALID_IDENTIFIERS], + *[ + ((IdDoc("x"),), {"valid_key": IdDoc("v"), invalid_name: IdDoc("u")}) + for invalid_name in INVALID_IDENTIFIERS + ], + ], +) +def test_print_call_doc_invalid_kwarg_key(args, kwargs): + doc = CallDoc(IdDoc("f"), *args, **kwargs) + with pytest.raises(ValueError) as e: + to_python_script(doc) + assert "IsValidPythonIdentifier" in str(e.value) + + +@pytest.mark.parametrize( + "args, expected", + [ + ( + (), + "lambda : 0", + ), + ( + (IdDoc("x"),), + "lambda x: 0", + ), + ( + (IdDoc("x"), IdDoc("y")), + "lambda x, y: 0", + ), + ( + (IdDoc("x"), IdDoc("y"), IdDoc("z")), + "lambda x, y, z: 0", + ), + ], +) +def test_print_lambda_doc(args, expected): + doc = LambdaDoc(args, body=LiteralDoc(0)) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "elements, expected", + [ + ( + (), + "[]", + ), + ( + [IdDoc("x")], + "[x]", + ), + ( + [IdDoc("x"), IdDoc("y")], + "[x, y]", + ), + ( + [IdDoc("x"), IdDoc("y"), IdDoc("z")], + "[x, y, z]", + ), + ], +) +def test_print_list_doc(elements, expected): + doc = ListDoc(elements) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "elements, expected", + [ + ( + (), + "()", + ), + ( + [IdDoc("x")], + "(x,)", + ), + ( + [IdDoc("x"), IdDoc("y")], + "(x, y)", + ), + ( + [IdDoc("x"), IdDoc("y"), IdDoc("z")], + "(x, y, z)", + ), + ], +) +def test_print_tuple_doc(elements, expected): + doc = TupleDoc(elements) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "content, expected", + [ + ( + {}, + "{}", + ), + ( + {LiteralDoc("key_x"): IdDoc("x")}, + '{"key_x": x}', + ), + ( + {LiteralDoc("key_x"): IdDoc("x"), LiteralDoc("key_y"): IdDoc("y")}, + '{"key_x": x, "key_y": y}', + ), + ( + { + LiteralDoc("key_x"): IdDoc("x"), + LiteralDoc("key_y"): IdDoc("y"), + LiteralDoc("key_z"): IdDoc("z"), + }, + '{"key_x": x, "key_y": y, "key_z": z}', + ), + ], +) +def test_print_dict_doc(content, expected): + doc = DictDoc(content) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "slice_doc, expected", + [ + ( + SliceDoc(), + ":", + ), + ( + SliceDoc(LiteralDoc(1)), + "1:", + ), + ( + SliceDoc(None, LiteralDoc(2)), + ":2", + ), + ( + SliceDoc(LiteralDoc(1), LiteralDoc(2)), + "1:2", + ), + ], +) +def test_print_slice_doc(slice_doc, expected): + doc = IdDoc("x").index_access([slice_doc]) + assert to_python_script(doc) == format_script(f"x[{expected}]") From 1580f48282956ca43d509c1fac32e437f83b0c96 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 11 Jul 2022 16:02:21 -0400 Subject: [PATCH 04/23] Fix typo --- .../unittest/test_tvmscript_printer_python_doc_printer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index fdbe96fc3c4c..6cdd3576a14c 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -137,7 +137,7 @@ def test_print_invalid_attr_doc(attr): @pytest.mark.parametrize( - "indices,expected", + "indices, expected", [ ( [], @@ -261,7 +261,7 @@ def test_operation_doc_test_exhaustive(): # Special Op assert op_kind in special_op_covered, ( f"{op_kind.name} not covered in test_print_special_operation_doc. " - f"Please add the expected token to SPECIAL_OP_CASES" + f"Please add the test cases for it to SPECIAL_OP_CASES" ) From fc5a7ef9e600a6a654a3719aa0538c9c764bb1d8 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 11 Jul 2022 16:10:04 -0400 Subject: [PATCH 05/23] Simplify PrintJoinedDocs --- src/script/printer/python_doc_printer.cc | 27 +++++++++++++----------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 3725a40ce6bb..715cb5351aa6 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -50,19 +50,16 @@ class PythonDocPrinter : public DocPrinter { private: template - void PrintJoinedElements(const std::string& left, const Array& elements, - const std::string& separator, const std::string& right) { - output_ << left; + void PrintJoinedDocs(const Array& docs, const std::string& separator) { bool is_first = true; - for (auto& element : elements) { + for (auto& doc : docs) { if (is_first) { is_first = false; } else { output_ << separator; } - PrintDoc(element); + PrintDoc(doc); } - output_ << right; } }; @@ -119,7 +116,9 @@ void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { if (doc->indices.size() == 0) { output_ << "[()]"; } else { - PrintJoinedElements("[", doc->indices, ", ", "]"); + output_ << "["; + PrintJoinedDocs(doc->indices, ", "); + output_ << "]"; } } @@ -228,22 +227,26 @@ void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) { output_ << "lambda "; - PrintJoinedElements("", doc->args, ", ", ": "); + PrintJoinedDocs(doc->args, ", "); + output_ << ": "; PrintDoc(doc->body); } void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) { - PrintJoinedElements("[", doc->elements, ", ", "]"); + output_ << "["; + PrintJoinedDocs(doc->elements, ", "); + output_ << "]"; } void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { + output_ << "("; if (doc->elements.size() == 1) { - output_ << "("; PrintDoc(doc->elements[0]); - output_ << ",)"; + output_ << ","; } else { - PrintJoinedElements("(", doc->elements, ", ", ")"); + PrintJoinedDocs(doc->elements, ", "); } + output_ << ")"; } void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { From 150f9e2687fef1e37da6aaffed060c0ba2c66964 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 12 Jul 2022 01:05:51 -0400 Subject: [PATCH 06/23] Fix lint --- include/tvm/script/printer/doc.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index d04102ab1bb8..4cedbdc35bd4 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -630,7 +630,7 @@ class SliceDoc : public Doc { /*! * \brief Constructor of SliceDoc * \param start The start of slice. - * \param start The exclusive end of slice. + * \param stop The exclusive end of slice. */ explicit SliceDoc(Optional start, Optional stop); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); From 1b78fdeabd1d5373b25b0b198ffbc5ab4122cae9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 12 Jul 2022 10:53:03 -0400 Subject: [PATCH 07/23] Fix lint --- src/script/printer/python_doc_printer.cc | 58 ++++++++++++------------ 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 715cb5351aa6..c78f40158aab 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -94,7 +94,7 @@ bool IsValidPythonIdentifier(const std::string& id) { // This regex is just an approximation of the Python identifier // rule. This doesn't exclude the reserved keywords. But it should // be good enough for roundtrippable TVMScript printing and parsing. - const static std::regex id_pattern(R"(^[^\d\W]\w*$)"); + static const std::regex id_pattern(R"(^[^\d\W]\w*$)"); return std::regex_match(id, id_pattern); } @@ -122,37 +122,37 @@ void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { } } -constexpr int OP_STR_TABLE_SIZE = static_cast(OperationDocNode::Kind::kSpecialEnd) + 1; -static const std::array OP_STR_TABLE = []() { - using OpKind = OperationDocNode::Kind; - std::array table; - auto set_op = [&table](auto op, const char* str) { table[static_cast(op)] = str; }; +const char* OperatorToString(OperationDocNode::Kind operation_kind) { + constexpr int OP_STR_TABLE_SIZE = static_cast(OperationDocNode::Kind::kSpecialEnd) + 1; + static const std::array OP_STR_TABLE = []() { + using OpKind = OperationDocNode::Kind; + std::array table; + auto set_op = [&table](auto op, const char* str) { table[static_cast(op)] = str; }; - set_op(OpKind::kUSub, "-"); - set_op(OpKind::kInvert, "~"); - set_op(OpKind::kAdd, "+"); - set_op(OpKind::kSub, "-"); - set_op(OpKind::kMult, "*"); - set_op(OpKind::kDiv, "/"); - set_op(OpKind::kFloorDiv, "//"); - set_op(OpKind::kMod, "%"); - set_op(OpKind::kPow, "**"); - set_op(OpKind::kLShift, "<<"); - set_op(OpKind::kRShift, ">>"); - set_op(OpKind::kBitAnd, "&"); - set_op(OpKind::kBitOr, "|"); - set_op(OpKind::kBitXor, "^"); - set_op(OpKind::kLt, "<"); - set_op(OpKind::kLtE, "<="); - set_op(OpKind::kEq, "=="); - set_op(OpKind::kNotEq, "!="); - set_op(OpKind::kGt, ">"); - set_op(OpKind::kGtE, ">="); + set_op(OpKind::kUSub, "-"); + set_op(OpKind::kInvert, "~"); + set_op(OpKind::kAdd, "+"); + set_op(OpKind::kSub, "-"); + set_op(OpKind::kMult, "*"); + set_op(OpKind::kDiv, "/"); + set_op(OpKind::kFloorDiv, "//"); + set_op(OpKind::kMod, "%"); + set_op(OpKind::kPow, "**"); + set_op(OpKind::kLShift, "<<"); + set_op(OpKind::kRShift, ">>"); + set_op(OpKind::kBitAnd, "&"); + set_op(OpKind::kBitOr, "|"); + set_op(OpKind::kBitXor, "^"); + set_op(OpKind::kLt, "<"); + set_op(OpKind::kLtE, "<="); + set_op(OpKind::kEq, "=="); + set_op(OpKind::kNotEq, "!="); + set_op(OpKind::kGt, ">"); + set_op(OpKind::kGtE, ">="); - return table; -}(); + return table; + }(); -const char* OperatorToString(OperationDocNode::Kind operation_kind) { auto op_index = static_cast(operation_kind); ICHECK_LT(op_index, OP_STR_TABLE_SIZE); const char* str = OP_STR_TABLE[static_cast(operation_kind)]; From 9371b5457f7b00363a398570779daca3a3401b70 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 12 Jul 2022 13:21:42 -0400 Subject: [PATCH 08/23] include array --- src/script/printer/python_doc_printer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index c78f40158aab..8759000710b6 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -19,6 +19,7 @@ #include +#include #include #include "../../support/str_escape.h" From 070c60a7a139286a6c8272ee3fe8cfd957290add Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 12 Jul 2022 16:33:40 -0400 Subject: [PATCH 09/23] Try to fix MSVC compile error --- src/script/printer/python_doc_printer.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 8759000710b6..7db180109201 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -125,9 +125,11 @@ void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { const char* OperatorToString(OperationDocNode::Kind operation_kind) { constexpr int OP_STR_TABLE_SIZE = static_cast(OperationDocNode::Kind::kSpecialEnd) + 1; - static const std::array OP_STR_TABLE = []() { + using OpStrTable = std::array; + // Add explicit return type to satisfy MSVC + static const OpStrTable OP_STR_TABLE = [&]() -> OpStrTable { using OpKind = OperationDocNode::Kind; - std::array table; + OpStrTable table; auto set_op = [&table](auto op, const char* str) { table[static_cast(op)] = str; }; set_op(OpKind::kUSub, "-"); From 95f782d365b028bbecfc0a088c4347b0b05873df Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 12 Jul 2022 17:16:46 -0400 Subject: [PATCH 10/23] Simplify the implementation of OperatorToString --- src/script/printer/python_doc_printer.cc | 65 +++++++++++++----------- 1 file changed, 34 insertions(+), 31 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 7db180109201..cd5ed814360a 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -123,43 +123,46 @@ void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { } } -const char* OperatorToString(OperationDocNode::Kind operation_kind) { - constexpr int OP_STR_TABLE_SIZE = static_cast(OperationDocNode::Kind::kSpecialEnd) + 1; - using OpStrTable = std::array; - // Add explicit return type to satisfy MSVC - static const OpStrTable OP_STR_TABLE = [&]() -> OpStrTable { +const std::string OperatorToString(OperationDocNode::Kind operation_kind) { + static const std::vector OP_STR_TABLE = []() { using OpKind = OperationDocNode::Kind; - OpStrTable table; - auto set_op = [&table](auto op, const char* str) { table[static_cast(op)] = str; }; - - set_op(OpKind::kUSub, "-"); - set_op(OpKind::kInvert, "~"); - set_op(OpKind::kAdd, "+"); - set_op(OpKind::kSub, "-"); - set_op(OpKind::kMult, "*"); - set_op(OpKind::kDiv, "/"); - set_op(OpKind::kFloorDiv, "//"); - set_op(OpKind::kMod, "%"); - set_op(OpKind::kPow, "**"); - set_op(OpKind::kLShift, "<<"); - set_op(OpKind::kRShift, ">>"); - set_op(OpKind::kBitAnd, "&"); - set_op(OpKind::kBitOr, "|"); - set_op(OpKind::kBitXor, "^"); - set_op(OpKind::kLt, "<"); - set_op(OpKind::kLtE, "<="); - set_op(OpKind::kEq, "=="); - set_op(OpKind::kNotEq, "!="); - set_op(OpKind::kGt, ">"); - set_op(OpKind::kGtE, ">="); + std::map raw_table = { + {OpKind::kUSub, "-"}, // + {OpKind::kInvert, "~"}, // + {OpKind::kAdd, "+"}, // + {OpKind::kSub, "-"}, // + {OpKind::kMult, "*"}, // + {OpKind::kDiv, "/"}, // + {OpKind::kFloorDiv, "//"}, // + {OpKind::kMod, "%"}, // + {OpKind::kPow, "**"}, // + {OpKind::kLShift, "<<"}, // + {OpKind::kRShift, ">>"}, // + {OpKind::kBitAnd, "&"}, // + {OpKind::kBitOr, "|"}, // + {OpKind::kBitXor, "^"}, // + {OpKind::kLt, "<"}, // + {OpKind::kLtE, "<="}, // + {OpKind::kEq, "=="}, // + {OpKind::kNotEq, "!="}, // + {OpKind::kGt, ">"}, // + {OpKind::kGtE, ">="}, // + }; + + std::vector table; + table.resize(static_cast(OperationDocNode::Kind::kSpecialEnd) + 1); + + for (const auto& kv : raw_table) { + table[static_cast(kv.first)] = kv.second; + } return table; }(); auto op_index = static_cast(operation_kind); - ICHECK_LT(op_index, OP_STR_TABLE_SIZE); - const char* str = OP_STR_TABLE[static_cast(operation_kind)]; - if (str == nullptr) { + ICHECK_LT(op_index, OP_STR_TABLE.size()); + const std::string str = OP_STR_TABLE[op_index]; + if (str.empty()) { LOG(FATAL) << "OperationDocNode::Kind " << static_cast(operation_kind) << " cannot be converted to operator token in Python directly."; throw; From 7c81bf3b21d65b1448db19910febab7033180f2f Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 12 Jul 2022 17:18:22 -0400 Subject: [PATCH 11/23] Remove unused header --- src/script/printer/python_doc_printer.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index cd5ed814360a..11bc8a8f9889 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -19,7 +19,6 @@ #include -#include #include #include "../../support/str_escape.h" From 216c98ea7d7f6bb2536db29e869c48f9a431d057 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 13 Jul 2022 10:27:36 -0400 Subject: [PATCH 12/23] Remove the usage of regex to unblock PR --- src/script/printer/python_doc_printer.cc | 26 +------- ...st_tvmscript_printer_python_doc_printer.py | 65 ------------------- 2 files changed, 1 insertion(+), 90 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 11bc8a8f9889..57859771bdbe 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -19,8 +19,6 @@ #include -#include - #include "../../support/str_escape.h" #include "./base_doc_printer.h" #include "tvm/runtime/logging.h" @@ -84,29 +82,9 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } -/* - * This function checks whether an input string is a valid - * identifier. Invalid identifier name can make the result - * still parsable but into a different IR tree. So we want - * to fail as soon as possible. - */ -bool IsValidPythonIdentifier(const std::string& id) { - // This regex is just an approximation of the Python identifier - // rule. This doesn't exclude the reserved keywords. But it should - // be good enough for roundtrippable TVMScript printing and parsing. - static const std::regex id_pattern(R"(^[^\d\W]\w*$)"); - return std::regex_match(id, id_pattern); -} - -void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { - CHECK(IsValidPythonIdentifier(doc->name)) - << "ValueError: " << doc->name << " is not a valid identifier."; - output_ << doc->name; -} +void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { - CHECK(IsValidPythonIdentifier(doc->attr)) - << "ValueError: " << doc->attr << " is not a valid attribute."; PrintDoc(doc->value); output_ << "." << doc->attr; } @@ -220,8 +198,6 @@ void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { output_ << ", "; } const String& keyword = doc->kwargs_keys[i]; - CHECK(IsValidPythonIdentifier(keyword)) - << "ValueError: " << keyword << " is not a valid name for keyword parameter."; output_ << keyword; output_ << "="; PrintDoc(doc->kwargs_values[i]); diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 6cdd3576a14c..fb30b7564991 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -79,29 +79,6 @@ def test_print_id_doc(name): assert to_python_script(doc) == format_script(name) -INVALID_IDENTIFIERS = [ - "", - "123", - "@test", - "test@", - "test case", - "test, case", - "test[0]", - "test.case", -] - - -@pytest.mark.parametrize( - "name", - INVALID_IDENTIFIERS, -) -def test_print_invalid_id_doc(name): - doc = IdDoc(name) - with pytest.raises(ValueError) as e: - to_python_script(doc) - assert "IsValidPythonIdentifier" in str(e.value) - - @pytest.mark.parametrize( "attr", [ @@ -116,26 +93,6 @@ def test_print_attr_doc(attr): assert to_python_script(doc) == format_script(f"x.{attr}") -@pytest.mark.parametrize( - "attr", - [ - "", - "123", - "@attr", - "attr@", - "attr with space", - "attr, with dot", - "attr[0]", - "attr.dot", - ], -) -def test_print_invalid_attr_doc(attr): - doc = IdDoc("x").attr_access(attr) - with pytest.raises(ValueError) as e: - to_python_script(doc) - assert "IsValidPythonIdentifier" in str(e.value) - - @pytest.mark.parametrize( "indices, expected", [ @@ -320,28 +277,6 @@ def test_print_call_doc(args, kwargs, expected): assert to_python_script(doc) == format_script(f"f{expected}") -@pytest.mark.parametrize( - "args, kwargs", - [ - *[((), {invalid_name: IdDoc("u")}) for invalid_name in INVALID_IDENTIFIERS], - *[ - ((), {"valid_key": IdDoc("v"), invalid_name: IdDoc("u")}) - for invalid_name in INVALID_IDENTIFIERS - ], - *[((IdDoc("x"),), {invalid_name: IdDoc("u")}) for invalid_name in INVALID_IDENTIFIERS], - *[ - ((IdDoc("x"),), {"valid_key": IdDoc("v"), invalid_name: IdDoc("u")}) - for invalid_name in INVALID_IDENTIFIERS - ], - ], -) -def test_print_call_doc_invalid_kwarg_key(args, kwargs): - doc = CallDoc(IdDoc("f"), *args, **kwargs) - with pytest.raises(ValueError) as e: - to_python_script(doc) - assert "IsValidPythonIdentifier" in str(e.value) - - @pytest.mark.parametrize( "args, expected", [ From fb7a5e05592816614c2981b3850126aa2fdd5dab Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 14 Jul 2022 11:35:32 -0400 Subject: [PATCH 13/23] Change include style --- src/script/printer/python_doc_printer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 57859771bdbe..c64c67d3837a 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -17,11 +17,11 @@ * under the License. */ +#include #include #include "../../support/str_escape.h" #include "./base_doc_printer.h" -#include "tvm/runtime/logging.h" namespace tvm { namespace script { From 5e233ffeb7ac62c9ac612052ea944b31aa51b227 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 15 Jul 2022 10:14:31 -0400 Subject: [PATCH 14/23] Remove unnecessary template parameter --- include/tvm/script/printer/doc.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 4cedbdc35bd4..4e52d4a3776d 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -96,8 +96,8 @@ class ExprDocNode : public DocNode { * \param kwargs_keys Keys of keywords arguments of the function call. * \param kwargs_values Values of keywords arguments of the function call. */ - ExprDoc Call(Array args, // - Array kwargs_keys, // + ExprDoc Call(Array args, // + Array kwargs_keys, // Array kwargs_values) const; void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); } From c0a429e88bf831f60cb256ae3872bf91503d4d99 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 15 Jul 2022 10:21:46 -0400 Subject: [PATCH 15/23] Remove assert op and add if-then-else as op --- include/tvm/script/printer/doc.h | 2 +- python/tvm/script/printer/doc.py | 2 +- src/script/printer/python_doc_printer.cc | 14 +++++++------- .../test_tvmscript_printer_python_doc_printer.py | 12 ++++++------ 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 4e52d4a3776d..4541c1f761e0 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -397,7 +397,7 @@ class OperationDocNode : public ExprDocNode { // Special kSpecialStart, - kAssert, + kIfThenElse, // if else kSpecialEnd }; diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 24ce04782a89..6f582ebdb589 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -192,7 +192,7 @@ class OperationKind(IntEnum): _BinaryEnd = auto() _SpecialStart = auto() - Assert = auto() + IfThenElse = auto() _SpecialEnd = auto() # pylint: enable=invalid-name diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index c64c67d3837a..ccdfe3fed057 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -160,14 +160,14 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { PrintDoc(doc->operands[0]); output_ << " " << OperatorToString(doc->kind) << " "; PrintDoc(doc->operands[1]); - } else if (doc->kind == OpKind::kAssert) { - // Special Operator: Assert - output_ << "assert "; + } else if (doc->kind == OpKind::kIfThenElse) { + CHECK_EQ(doc->operands.size(), 3) + << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); + PrintDoc(doc->operands[1]); + output_ << " if "; PrintDoc(doc->operands[0]); - if (doc->operands.size() > 1) { - output_ << ", "; - PrintDoc(doc->operands[1]); - } + output_ << " else "; + PrintDoc(doc->operands[2]); } else { LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast(doc->kind); throw; diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index fb30b7564991..9dfae35eb040 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -181,14 +181,14 @@ def test_print_binary_operation_doc(op_kind, expected_token): SPECIAL_OP_CASES = [ ( - OperationKind.Assert, - [LiteralDoc(True), LiteralDoc("assert_message")], - 'assert True, "assert_message"', + OperationKind.IfThenElse, + [LiteralDoc(True), LiteralDoc("true"), LiteralDoc("false")], + '"true" if True else "false"', ), ( - OperationKind.Assert, - [LiteralDoc(True)], - "assert True", + OperationKind.IfThenElse, + [IdDoc("x"), LiteralDoc(None), LiteralDoc(1)], + 'None if x else 1', ), ] From 99ffc1847dd15cee509c7b1b1db66ef2ef639e81 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 15 Jul 2022 12:03:49 -0400 Subject: [PATCH 16/23] Fix lint --- .../unittest/test_tvmscript_printer_python_doc_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 9dfae35eb040..702fc151381d 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -188,7 +188,7 @@ def test_print_binary_operation_doc(op_kind, expected_token): ( OperationKind.IfThenElse, [IdDoc("x"), LiteralDoc(None), LiteralDoc(1)], - 'None if x else 1', + "None if x else 1", ), ] From a1ca5f46f115fe099b88a962d7bc7c828aae415d Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Sat, 16 Jul 2022 21:56:48 -0400 Subject: [PATCH 17/23] Be explicit on the enum value --- include/tvm/script/printer/doc.h | 54 ++++++++++++++--------------- python/tvm/script/printer/doc.py | 58 ++++++++++++++++---------------- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 4541c1f761e0..0fe108f6cb78 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -368,37 +368,37 @@ class OperationDocNode : public ExprDocNode { public: enum class Kind : int32_t { // Unary operators - kUnaryStart, - kUSub, // -x - kInvert, // ~x - kUnaryEnd, + kUnaryStart = 0, + kUSub = 1, // -x + kInvert = 2, // ~x + kUnaryEnd = 3, // Binary operators - kBinaryStart, - kAdd, // + - kSub, // - - kMult, // * - kDiv, // / - kFloorDiv, // // in Python - kMod, // % in Python - kPow, // ** in Python - kLShift, // << - kRShift, // >> - kBitAnd, // & - kBitOr, // | - kBitXor, // ^ - kLt, // < - kLtE, // <= - kEq, // == - kNotEq, // != - kGt, // > - kGtE, // >= - kBinaryEnd, + kBinaryStart = 4, + kAdd = 5, // + + kSub = 6, // - + kMult = 7, // * + kDiv = 8, // / + kFloorDiv = 9, // // in Python + kMod = 10, // % in Python + kPow = 11, // ** in Python + kLShift = 12, // << + kRShift = 13, // >> + kBitAnd = 14, // & + kBitOr = 15, // | + kBitXor = 16, // ^ + kLt = 17, // < + kLtE = 18, // <= + kEq = 19, // == + kNotEq = 20, // != + kGt = 21, // > + kGtE = 22, // >= + kBinaryEnd = 23, // Special - kSpecialStart, - kIfThenElse, // if else - kSpecialEnd + kSpecialStart = 24, + kIfThenElse = 25, // if else + kSpecialEnd = 26 }; /*! \brief The kind of operation (operator) */ diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 6f582ebdb589..31cb89799430 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -17,7 +17,7 @@ """Doc types for TVMScript Unified Printer""" from typing import List, Dict, Tuple, Optional, Union -from enum import IntEnum, auto, unique +from enum import IntEnum, unique import tvm._ffi import tvm.ir.container @@ -166,34 +166,34 @@ class OperationKind(IntEnum): # pylint: disable=invalid-name _UnaryStart = 0 - USub = auto() - Invert = auto() - _UnaryEnd = auto() - - _BinaryStart = auto() - Add = auto() - Sub = auto() - Mult = auto() - Div = auto() - FloorDiv = auto() - Mod = auto() - Pow = auto() - LShift = auto() - RShift = auto() - BitAnd = auto() - BitOr = auto() - BitXor = auto() - Lt = auto() - LtE = auto() - Eq = auto() - NotEq = auto() - Gt = auto() - GtE = auto() - _BinaryEnd = auto() - - _SpecialStart = auto() - IfThenElse = auto() - _SpecialEnd = auto() + USub = 1 + Invert = 2 + _UnaryEnd = 3 + + _BinaryStart = 4 + Add = 5 + Sub = 6 + Mult = 7 + Div = 8 + FloorDiv = 9 + Mod = 10 + Pow = 11 + LShift = 12 + RShift = 13 + BitAnd = 14 + BitOr = 15 + BitXor = 16 + Lt = 17 + LtE = 18 + Eq = 19 + NotEq = 20 + Gt = 21 + GtE = 22 + _BinaryEnd = 23 + + _SpecialStart = 24 + IfThenElse = 25 + _SpecialEnd = 26 # pylint: enable=invalid-name From 12a9b3b22480c50af010c86bb706828cb2869c97 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 19 Jul 2022 00:41:46 -0400 Subject: [PATCH 18/23] Add attr step to SliceDoc --- include/tvm/script/printer/doc.h | 5 ++++- python/tvm/script/printer/doc.py | 5 +++-- src/script/printer/doc.cc | 7 ++++--- src/script/printer/python_doc_printer.cc | 4 ++++ .../unittest/test_tvmscript_printer_doc.py | 14 ++++---------- .../test_tvmscript_printer_python_doc_printer.py | 16 ++++++++++++++++ 6 files changed, 35 insertions(+), 16 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 0fe108f6cb78..a3ffc737611a 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -609,11 +609,14 @@ class SliceDocNode : public DocNode { Optional start; /*! \brief The exclusive end of slice */ Optional stop; + /*! \brief The step of slice */ + Optional step; void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); v->Visit("start", &start); v->Visit("stop", &stop); + v->Visit("step", &step); } static constexpr const char* _type_key = "script.printer.SliceDoc"; @@ -632,7 +635,7 @@ class SliceDoc : public Doc { * \param start The start of slice. * \param stop The exclusive end of slice. */ - explicit SliceDoc(Optional start, Optional stop); + explicit SliceDoc(Optional start, Optional stop, Optional step); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); }; diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 31cb89799430..2c81e72de348 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -268,6 +268,7 @@ class SliceDoc(ExprDoc): start: Optional[ExprDoc] stop: Optional[ExprDoc] + step: Optional[ExprDoc] - def __init__(self, start: Optional[ExprDoc] = None, stop: Optional[ExprDoc] = None): - self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop) # type: ignore + def __init__(self, start: Optional[ExprDoc] = None, stop: Optional[ExprDoc] = None, step: Optional[ExprDoc] = None): + self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 3422f9d2fdc4..a529fbf7e63c 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -107,10 +107,11 @@ DictDoc::DictDoc(Array keys, Array values) { this->data_ = std::move(n); } -SliceDoc::SliceDoc(Optional start, Optional stop) { +SliceDoc::SliceDoc(Optional start, Optional stop, Optional step) { ObjectPtr n = make_object(); n->start = start; n->stop = stop; + n->step = step; this->data_ = std::move(n); } @@ -180,8 +181,8 @@ TVM_REGISTER_GLOBAL("script.printer.DictDoc") TVM_REGISTER_NODE_TYPE(SliceDocNode); TVM_REGISTER_GLOBAL("script.printer.SliceDoc") - .set_body_typed([](Optional start, Optional stop) { - return SliceDoc(start, stop); + .set_body_typed([](Optional start, Optional stop, Optional step) { + return SliceDoc(start, stop, step); }); } // namespace printer } // namespace script diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index ccdfe3fed057..c5ee1175b118 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -253,6 +253,10 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { if (doc->stop != nullptr) { PrintDoc(doc->stop.value()); } + if (doc->step != nullptr) { + output_ << ":"; + PrintDoc(doc->step.value()); + } } String DocToPythonScript(Doc doc, int indent_spaces) { diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index fcb630bd3ec3..7942eb8a9a9c 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -181,16 +181,10 @@ def test_dict_doc(content): assert dict(zip(doc.keys, doc.values)) == content -@pytest.mark.parametrize( - "start,stop", - [ - (LiteralDoc(1), LiteralDoc(2)), - (LiteralDoc(1), None), - (None, LiteralDoc(2)), - (None, None), - ], -) -def test_slice_doc(start, stop): +@pytest.mark.parametrize("start", [LiteralDoc(1), None]) +@pytest.mark.parametrize("stop", [LiteralDoc(2), None]) +@pytest.mark.parametrize("step", [LiteralDoc(3), None]) +def test_slice_doc(start, stop, step): doc = SliceDoc(start, stop) assert doc.start == start diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 702fc151381d..c0e2c75ac00c 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -404,6 +404,22 @@ def test_print_dict_doc(content, expected): SliceDoc(LiteralDoc(1), LiteralDoc(2)), "1:2", ), + ( + SliceDoc(None, None, LiteralDoc(3)), + "::3", + ), + ( + SliceDoc(LiteralDoc(1), None, LiteralDoc(3)), + "1::3", + ), + ( + SliceDoc(None, LiteralDoc(2), LiteralDoc(3)), + ":2:3", + ), + ( + SliceDoc(LiteralDoc(1), LiteralDoc(2), LiteralDoc(3)), + "1:2:3", + ), ], ) def test_print_slice_doc(slice_doc, expected): From 19bfb25270c8c2c324c6de4ddd0b46c7af6781bb Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 19 Jul 2022 00:51:55 -0400 Subject: [PATCH 19/23] Use accurate typing --- python/tvm/script/printer/doc.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 2c81e72de348..efa35cb8a329 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -16,7 +16,7 @@ # under the License. """Doc types for TVMScript Unified Printer""" -from typing import List, Dict, Tuple, Optional, Union +from typing import List, Dict, Tuple, Optional, Union, Sequence from enum import IntEnum, unique import tvm._ffi @@ -131,7 +131,7 @@ class IndexDoc(ExprDoc): """Doc that represents index access on an expression""" value: ExprDoc - indices: tvm.ir.container.Array # actual type: List[Union[ExprDoc, "SliceDoc"]] + indices: Sequence[Union[ExprDoc, "SliceDoc"]] def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]): self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices) # type: ignore @@ -142,9 +142,9 @@ class CallDoc(ExprDoc): """Doc that represents function call""" callee: ExprDoc - args: tvm.ir.container.Array # actual type: List[ExprDoc] - kwargs_keys: tvm.ir.container.Array # actual type: List[str] - kwargs_values: tvm.ir.container.Array # actual type: List[ExprDoc] + args: Sequence[ExprDoc] + kwargs_keys: Sequence[str] + kwargs_values: Sequence[ExprDoc] def __init__(self, callee: ExprDoc, *args: Tuple[ExprDoc], **kwargs: Dict[str, ExprDoc]): kwargs_keys = list(kwargs.keys()) @@ -208,7 +208,7 @@ class OperationDoc(ExprDoc): """ kind: OperationKind - operands: tvm.ir.container.Array # actual type: List[ExprDoc] + operands: Sequence[ExprDoc] def __init__(self, kind: OperationKind, operands: List[ExprDoc]): self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands) # type: ignore @@ -218,7 +218,7 @@ def __init__(self, kind: OperationKind, operands: List[ExprDoc]): class LambdaDoc(ExprDoc): """Doc that represents lambda function""" - args: tvm.ir.container.Array # actual type: List[IdDoc] + args: Sequence[IdDoc] body: ExprDoc def __init__(self, args: List[IdDoc], body: ExprDoc): @@ -229,7 +229,7 @@ def __init__(self, args: List[IdDoc], body: ExprDoc): class TupleDoc(ExprDoc): """Doc that represents tuple literal""" - elements: tvm.ir.container.Array # actual type: List[ExprDoc] + elements: Sequence[ExprDoc] def __init__(self, elements: List[ExprDoc]): self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements) # type: ignore @@ -239,7 +239,7 @@ def __init__(self, elements: List[ExprDoc]): class ListDoc(ExprDoc): """Doc that represents list literal""" - elements: tvm.ir.container.Array # actual type: List[ExprDoc] + elements: Sequence[ExprDoc] def __init__(self, elements: List[ExprDoc]): self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements) # type: ignore @@ -249,8 +249,8 @@ def __init__(self, elements: List[ExprDoc]): class DictDoc(ExprDoc): """Doc that represents dict literal""" - keys: tvm.ir.container.Array # actual type: List[ExprDoc] - values: tvm.ir.container.Array # actual type: List[ExprDoc] + keys: Sequence[ExprDoc] + values: Sequence[ExprDoc] def __init__(self, content: Dict[ExprDoc, ExprDoc]): keys = list(content.keys()) @@ -270,5 +270,10 @@ class SliceDoc(ExprDoc): stop: Optional[ExprDoc] step: Optional[ExprDoc] - def __init__(self, start: Optional[ExprDoc] = None, stop: Optional[ExprDoc] = None, step: Optional[ExprDoc] = None): + def __init__( + self, + start: Optional[ExprDoc] = None, + stop: Optional[ExprDoc] = None, + step: Optional[ExprDoc] = None, + ): self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore From 7b53eb24a46727ba84f89d39192b2b15c64df324 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 19 Jul 2022 00:57:48 -0400 Subject: [PATCH 20/23] Change variable names and add more validation --- src/script/printer/python_doc_printer.cc | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index c5ee1175b118..b01c85bb9ad7 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -101,7 +101,7 @@ void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { } const std::string OperatorToString(OperationDocNode::Kind operation_kind) { - static const std::vector OP_STR_TABLE = []() { + static const std::vector op_kind2str = []() { using OpKind = OperationDocNode::Kind; std::map raw_table = { {OpKind::kUSub, "-"}, // @@ -137,13 +137,10 @@ const std::string OperatorToString(OperationDocNode::Kind operation_kind) { }(); auto op_index = static_cast(operation_kind); - ICHECK_LT(op_index, OP_STR_TABLE.size()); - const std::string str = OP_STR_TABLE[op_index]; - if (str.empty()) { - LOG(FATAL) << "OperationDocNode::Kind " << static_cast(operation_kind) - << " cannot be converted to operator token in Python directly."; - throw; - } + ICHECK_LT(op_index, op_kind2str.size()); + const std::string str = op_kind2str[op_index]; + ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast(operation_kind) + << " cannot be converted to operator token in Python directly."; return str; } @@ -161,7 +158,7 @@ void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { output_ << " " << OperatorToString(doc->kind) << " "; PrintDoc(doc->operands[1]); } else if (doc->kind == OpKind::kIfThenElse) { - CHECK_EQ(doc->operands.size(), 3) + ICHECK_EQ(doc->operands.size(), 3) << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); PrintDoc(doc->operands[1]); output_ << " if "; @@ -191,6 +188,8 @@ void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { } // Print keyword args + ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) + << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { if (is_first) { is_first = false; @@ -231,6 +230,8 @@ void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { } void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { + ICHECK_EQ(doc->keys.size(), doc->values.size()) + << "DictDoc should have equal number of elements in keys and values."; output_ << "{"; size_t idx = 0; for (const ExprDoc& key : doc->keys) { From c3e3d3bf0763ab366f1369fda8259a3751dcbe4f Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 19 Jul 2022 14:33:38 -0400 Subject: [PATCH 21/23] Rename methods to be consistent between Python and C++ --- include/tvm/script/printer/doc.h | 8 +-- python/tvm/script/printer/doc.py | 50 ++++++++++++------- src/script/printer/doc.cc | 9 ++-- src/script/printer/python_doc_printer.cc | 2 +- .../unittest/test_tvmscript_printer_doc.py | 28 ++++++----- ...st_tvmscript_printer_python_doc_printer.py | 20 ++++---- 6 files changed, 69 insertions(+), 48 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index a3ffc737611a..eaa16fab170c 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -82,7 +82,7 @@ class ExprDocNode : public DocNode { * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. */ - ExprDoc Index(Array indices) const; + ExprDoc operator[](Array indices) const; /*! * \brief Create a doc representing calling the current ExprDoc @@ -232,12 +232,12 @@ class AttrAccessDocNode : public ExprDocNode { /*! \brief The target expression to be accessed */ ExprDoc value{nullptr}; /*! \brief The attribute to be accessed */ - String attr; + String name; void VisitAttrs(AttrVisitor* v) { ExprDocNode::VisitAttrs(v); v->Visit("value", &value); - v->Visit("attr", &attr); + v->Visit("name", &name); } static constexpr const char* _type_key = "script.printer.AttrAccessDoc"; @@ -256,7 +256,7 @@ class AttrAccessDoc : public ExprDoc { * \param value The target expression of attribute access. * \param attr The name of attribute to access. */ - explicit AttrAccessDoc(ExprDoc value, String attr); + explicit AttrAccessDoc(ExprDoc value, String name); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode); }; diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index efa35cb8a329..42eecd8d6015 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -34,7 +34,7 @@ class Doc(Object): class ExprDoc(Object): """Base class of all expression Docs""" - def attr_access(self, attr: str) -> "AttrAccessDoc": + def attr(self, attr: str) -> "AttrAccessDoc": """ Create a doc that represents attribute access on self. @@ -49,22 +49,7 @@ def attr_access(self, attr: str) -> "AttrAccessDoc": """ return _ffi_api.ExprDocAttr(self, attr) # type: ignore - def index_access(self, indices: List[Union["ExprDoc", "SliceDoc"]]) -> "IndexDoc": - """ - Create a doc that represents index access on self. - - Parameters - ---------- - indices : List[Union["ExprDoc", "SliceDoc"]] - The indices to access - - Returns - ------- - doc : IndexDoc - """ - return _ffi_api.ExprDocIndex(self, indices) # type: ignore - - def call_with(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc": + def call(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc": """ Create a doc that represents function call, with self as callee. @@ -83,6 +68,37 @@ def call_with(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> kwargs_values = list(kwargs.values()) return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values) # type: ignore + _IndexType = Union["ExprDoc", "SliceDoc"] + + def __getitem__(self, indices: Union[Tuple[_IndexType], _IndexType]) -> "IndexDoc": + """ + Create a doc that represents index access on self. + + Parameters + ---------- + indices : Union[Tuple[Union["ExprDoc", "SliceDoc"]], Union["ExprDoc", "SliceDoc"]] + The indices to access + + Returns + ------- + doc : IndexDoc + """ + if not isinstance(indices, tuple): + indices = (indices,) + return _ffi_api.ExprDocIndex(self, indices) # type: ignore + + def __iter__(self): + """ + This is implemented to prevent confusing error message when trying to use ExprDoc + as iterable. According to PEP-234, An object can be iterated over if it + implements __iter__() or __getitem__(). If an object has only __getitem__ + but not __iter__, interpreter will iterate the object by calling + __getitem__ with 0, 1, 2, ..., until an IndexError is raised. + + https://peps.python.org/pep-0234/#python-api-specification + """ + raise RuntimeError(f"{self.__class__} cannot be used as iterable.") + @tvm._ffi.register_object("script.printer.LiteralDoc") class LiteralDoc(ExprDoc): diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index a529fbf7e63c..ed81f9d2dd26 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -25,7 +25,7 @@ namespace printer { ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } -ExprDoc ExprDocNode::Index(Array indices) const { +ExprDoc ExprDocNode::operator[](Array indices) const { return IndexDoc(GetRef(this), indices); } @@ -50,10 +50,10 @@ IdDoc::IdDoc(String name) { this->data_ = std::move(n); } -AttrAccessDoc::AttrAccessDoc(ExprDoc value, String attr) { +AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) { ObjectPtr n = make_object(); n->value = value; - n->attr = attr; + n->name = name; this->data_ = std::move(n); } @@ -119,7 +119,8 @@ TVM_REGISTER_NODE_TYPE(DocNode); TVM_REGISTER_NODE_TYPE(ExprDocNode); TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); -TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex").set_body_method(&ExprDocNode::Index); +TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex") + .set_body_method(&ExprDocNode::operator[]); TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") .set_body_method, Array, Array>( &ExprDocNode::Call); diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index b01c85bb9ad7..5c7b048f8144 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -86,7 +86,7 @@ void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { PrintDoc(doc->value); - output_ << "." << doc->attr; + output_ << "." << doc->name; } void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 7942eb8a9a9c..4ff6a0f547d7 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -58,7 +58,7 @@ def test_attr_access_doc(): doc = AttrAccessDoc(target, "attribute") assert doc.value == target - assert doc.attr == "attribute" + assert doc.name == "attribute" @pytest.mark.parametrize( @@ -195,29 +195,33 @@ def test_expr_doc_attr_access(): target = IdDoc("x") attr = "test" - doc = target.attr_access(attr) + doc = target.attr(attr) assert doc.value == target - assert doc.attr == attr + assert doc.name == attr @pytest.mark.parametrize( "indices", [ - [], - [LiteralDoc(1)], - [LiteralDoc(2), IdDoc("x")], - [SliceDoc(LiteralDoc(1), LiteralDoc(2))], - [SliceDoc(LiteralDoc(1)), IdDoc("y")], + (), + LiteralDoc(1), + SliceDoc(LiteralDoc(1), LiteralDoc(2)), + (LiteralDoc(1),), + (LiteralDoc(2), IdDoc("x")), + (SliceDoc(LiteralDoc(1), LiteralDoc(2)),), + (SliceDoc(LiteralDoc(1)), IdDoc("y")), ], ) -def test_expr_doc_index_access(indices): +def test_expr_doc_get_item(indices): target = IdDoc("x") - doc = target.index_access(indices) + doc = target[indices] assert doc.value == target - assert list(doc.indices) == indices + if not isinstance(indices, tuple): + indices = (indices,) + assert tuple(doc.indices) == indices @pytest.mark.parametrize( @@ -235,7 +239,7 @@ def test_expr_doc_index_access(indices): def test_expr_doc_call_with(args, kwargs): target = IdDoc("x") - doc = target.call_with(*args, **kwargs) + doc = target.call(*args, **kwargs) assert doc.callee == target assert list(doc.args) == args diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index c0e2c75ac00c..b65eaa6b98a1 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -89,7 +89,7 @@ def test_print_id_doc(name): ], ) def test_print_attr_doc(attr): - doc = IdDoc("x").attr_access(attr) + doc = IdDoc("x").attr(attr) assert to_python_script(doc) == format_script(f"x.{attr}") @@ -97,37 +97,37 @@ def test_print_attr_doc(attr): "indices, expected", [ ( - [], + (), "[()]", ), ( - [LiteralDoc(1)], + (LiteralDoc(1),), "[1]", ), ( - [LiteralDoc(2), IdDoc("x")], + (LiteralDoc(2), IdDoc("x")), "[2, x]", ), ( - [SliceDoc(LiteralDoc(1), LiteralDoc(2))], + (SliceDoc(LiteralDoc(1), LiteralDoc(2)),), "[1:2]", ), ( - [SliceDoc(LiteralDoc(1)), IdDoc("y")], + (SliceDoc(LiteralDoc(1)), IdDoc("y")), "[1:, y]", ), ( - [SliceDoc(), IdDoc("y")], + (SliceDoc(), IdDoc("y")), "[:, y]", ), ( - [IdDoc("x"), IdDoc("y"), IdDoc("z")], + (IdDoc("x"), IdDoc("y"), IdDoc("z")), "[x, y, z]", ), ], ) def test_print_index_doc(indices, expected): - doc = IdDoc("x").index_access(indices) + doc = IdDoc("x")[indices] assert to_python_script(doc) == format_script(f"x{expected}") @@ -423,5 +423,5 @@ def test_print_dict_doc(content, expected): ], ) def test_print_slice_doc(slice_doc, expected): - doc = IdDoc("x").index_access([slice_doc]) + doc = IdDoc("x")[slice_doc] assert to_python_script(doc) == format_script(f"x[{expected}]") From 6fb9a1dff00c4d29edecfe21202e1be856e8564f Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 19 Jul 2022 14:34:49 -0400 Subject: [PATCH 22/23] Add missing doc --- include/tvm/script/printer/doc.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index eaa16fab170c..cc9925b45456 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -634,6 +634,7 @@ class SliceDoc : public Doc { * \brief Constructor of SliceDoc * \param start The start of slice. * \param stop The exclusive end of slice. + * \param step The step of slice. */ explicit SliceDoc(Optional start, Optional stop, Optional step); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); From f53aaccb1e29dbcc34a1d3e6bf5aac0c4eafa2fb Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 19 Jul 2022 15:11:48 -0400 Subject: [PATCH 23/23] Fix lint --- include/tvm/script/printer/doc.h | 2 +- python/tvm/script/printer/doc.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index cc9925b45456..f3f980e53f5e 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -254,7 +254,7 @@ class AttrAccessDoc : public ExprDoc { /*! * \brief Constructor of AttrAccessDoc * \param value The target expression of attribute access. - * \param attr The name of attribute to access. + * \param name The name of attribute to access. */ explicit AttrAccessDoc(ExprDoc value, String name); TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode); diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 42eecd8d6015..acdb63dcf250 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -34,20 +34,20 @@ class Doc(Object): class ExprDoc(Object): """Base class of all expression Docs""" - def attr(self, attr: str) -> "AttrAccessDoc": + def attr(self, name: str) -> "AttrAccessDoc": """ Create a doc that represents attribute access on self. Parameters ---------- - attr : str + name : str The attribute name to access Returns ------- doc : AttrAccessDoc """ - return _ffi_api.ExprDocAttr(self, attr) # type: ignore + return _ffi_api.ExprDocAttr(self, name) # type: ignore def call(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc": """ @@ -136,10 +136,10 @@ class AttrAccessDoc(ExprDoc): """Doc that represents attribute access on an expression""" value: ExprDoc - attr: str + name: str - def __init__(self, value: ExprDoc, attr: str): - self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, attr) # type: ignore + def __init__(self, value: ExprDoc, name: str): + self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name) # type: ignore @tvm._ffi.register_object("script.printer.IndexDoc")