diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 67c27bd45a1d..f3f980e53f5e 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -63,6 +63,8 @@ class Doc : public ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); }; +class ExprDoc; + /*! * \brief The base class of expression doc. * @@ -70,6 +72,34 @@ class Doc : public ObjectRef { */ class ExprDocNode : public DocNode { public: + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param attr The attribute to access. + */ + ExprDoc Attr(String attr) const; + + /*! + * \brief Create a doc representing index access on the current ExprDoc + * \param indices The indices to access. + */ + ExprDoc operator[](Array indices) const; + + /*! + * \brief Create a doc representing calling the current ExprDoc + * \param args The positional arguments of the function call. + */ + ExprDoc Call(Array args) const; + + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param args The positional arguments of the function call. + * \param kwargs_keys Keys of keywords arguments of the function call. + * \param kwargs_values Values of keywords arguments of the function call. + */ + ExprDoc Call(Array args, // + Array kwargs_keys, // + Array kwargs_values) const; + void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); } static constexpr const char* _type_key = "script.printer.ExprDoc"; @@ -158,6 +188,458 @@ class LiteralDoc : public ExprDoc { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; +/*! + * \brief Doc that represents identifier. + * + * \sa IdDoc + */ +class IdDocNode : public ExprDocNode { + public: + /*! \brief The name of the identifier */ + String name; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("name", &name); + } + + static constexpr const char* _type_key = "script.printer.IdDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(IdDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of IdDocNode. + * + * \sa IdDocNode + */ +class IdDoc : public ExprDoc { + public: + /*! + * \brief Constructor of IdDoc. + * \param name The name of identifier. + */ + explicit IdDoc(String name); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IdDoc, ExprDoc, IdDocNode); +}; + +/*! + * \brief Doc that represents attribute access on another expression. + * + * \sa AttrAccessDoc + */ +class AttrAccessDocNode : public ExprDocNode { + public: + /*! \brief The target expression to be accessed */ + ExprDoc value{nullptr}; + /*! \brief The attribute to be accessed */ + String name; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("value", &value); + v->Visit("name", &name); + } + + static constexpr const char* _type_key = "script.printer.AttrAccessDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(AttrAccessDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of AttrAccessDocNode. + * + * \sa AttrAccessDocNode + */ +class AttrAccessDoc : public ExprDoc { + public: + /*! + * \brief Constructor of AttrAccessDoc + * \param value The target expression of attribute access. + * \param name The name of attribute to access. + */ + explicit AttrAccessDoc(ExprDoc value, String name); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(AttrAccessDoc, ExprDoc, AttrAccessDocNode); +}; + +/*! + * \brief Doc that represents index access on another expression. + * + * \sa IndexDoc + */ +class IndexDocNode : public ExprDocNode { + public: + /*! \brief The container value to be accessed */ + ExprDoc value{nullptr}; + /*! + * \brief The indices to access + * + * Possible actual types: + * - ExprDoc (single point access like a[1, 2]) + * - SliceDoc (slice access like a[1:5, 2]) + */ + Array indices; // Each element is union of: Slice / ExprDoc + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("value", &value); + v->Visit("indices", &indices); + } + + static constexpr const char* _type_key = "script.printer.IndexDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(IndexDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of IndexDocNode. + * + * \sa IndexDocNode + */ +class IndexDoc : public ExprDoc { + public: + /*! + * \brief Constructor of IndexDoc + * \param value The target expression of index access. + * \param indices The indices to access. + */ + explicit IndexDoc(ExprDoc value, Array indices); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(IndexDoc, ExprDoc, IndexDocNode); +}; + +/*! + * \brief Doc that represents function call. + * + * \sa CallDoc + */ +class CallDocNode : public ExprDocNode { + public: + /*! \brief The callee of this function call */ + ExprDoc callee{nullptr}; + /*! \brief The positional arguments */ + Array args; + /*! \brief The keys of keyword arguments */ + Array kwargs_keys; + /*! + * \brief The values of keyword arguments. + * + * The i-th element is the value of the i-th key in `kwargs_keys`. + * It must have the same length as `kwargs_keys`. + */ + Array kwargs_values; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("callee", &callee); + v->Visit("args", &args); + v->Visit("kwargs_keys", &kwargs_keys); + v->Visit("kwargs_values", &kwargs_values); + } + + static constexpr const char* _type_key = "script.printer.CallDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(CallDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of CallDocNode. + * + * \sa CallDocNode + */ +class CallDoc : public ExprDoc { + public: + /*! + * \brief Constructor of CallDoc + * \param callee The callee of this function call. + * \param args The positional arguments. + * \param kwargs_keys Keys of keyword arguments. + * \param kwargs_values Values of keyword arguments, must have the same length as `kwargs_keys. + */ + CallDoc(ExprDoc callee, Array args, Array kwargs_keys, + Array kwargs_values); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(CallDoc, ExprDoc, CallDocNode); +}; + +/*! + * \brief Doc that represents operation. + * + * It can be unary, binary and other special operators (for example, + * the if-then-else expression). + * + * \sa OperationDoc + */ +class OperationDocNode : public ExprDocNode { + public: + enum class Kind : int32_t { + // Unary operators + kUnaryStart = 0, + kUSub = 1, // -x + kInvert = 2, // ~x + kUnaryEnd = 3, + + // Binary operators + kBinaryStart = 4, + kAdd = 5, // + + kSub = 6, // - + kMult = 7, // * + kDiv = 8, // / + kFloorDiv = 9, // // in Python + kMod = 10, // % in Python + kPow = 11, // ** in Python + kLShift = 12, // << + kRShift = 13, // >> + kBitAnd = 14, // & + kBitOr = 15, // | + kBitXor = 16, // ^ + kLt = 17, // < + kLtE = 18, // <= + kEq = 19, // == + kNotEq = 20, // != + kGt = 21, // > + kGtE = 22, // >= + kBinaryEnd = 23, + + // Special + kSpecialStart = 24, + kIfThenElse = 25, // if else + kSpecialEnd = 26 + }; + + /*! \brief The kind of operation (operator) */ + Kind kind; + /*! \brief Operands of this expression */ + Array operands; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("kind", &kind); + v->Visit("operands", &operands); + } + + static constexpr const char* _type_key = "script.printer.OperationDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(OperationDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of OperationDocNode. + * + * \sa OperationDocNode + */ +class OperationDoc : public ExprDoc { + public: + /*! + * \brief Constructor of OperationDoc + * \param kind The kind of operation. + * \param operands Operands of this expression. + */ + explicit OperationDoc(OperationDocNode::Kind kind, Array operands); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(OperationDoc, ExprDoc, OperationDocNode); +}; + +/*! + * \brief Doc that represents anonymous function. + * + * LambdaDoc can only have positional arguments without type annotation, + * and a single expression as body. + * + * \sa LambdaDoc + */ +class LambdaDocNode : public ExprDocNode { + public: + /*! \brief The arguments of this anonymous function */ + Array args; + /*! \brief The body of this anonymous function */ + ExprDoc body{nullptr}; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("args", &args); + v->Visit("body", &body); + } + + static constexpr const char* _type_key = "script.printer.LambdaDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(LambdaDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of LambdaDocNode. + * + * \sa LambdaDocNode + */ +class LambdaDoc : public ExprDoc { + public: + /*! + * \brief Constructor of LambdaDoc + * \param args Arguments of this function. + * \param body Body expression of this function. + */ + explicit LambdaDoc(Array args, ExprDoc body); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LambdaDoc, ExprDoc, LambdaDocNode); +}; + +/*! + * \brief Doc that represents tuple literal. + * + * \sa TupleDoc + */ +class TupleDocNode : public ExprDocNode { + public: + /*! \brief Elements of tuple */ + Array elements; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("elements", &elements); + } + + static constexpr const char* _type_key = "script.printer.TupleDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(TupleDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of TupleDocNode. + * + * \sa TupleDocNode + */ +class TupleDoc : public ExprDoc { + public: + /*! + * \brief Create an empty TupleDoc + */ + TupleDoc() : TupleDoc(runtime::make_object()) {} + /*! + * \brief Constructor of TupleDoc + * \param elements Elements of tuple. + */ + explicit TupleDoc(Array elements); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleDoc, ExprDoc, TupleDocNode); +}; + +/*! + * \brief Doc that represents list literal. + * + * \sa AttrAccessDoc + */ +class ListDocNode : public ExprDocNode { + public: + /*! \brief Elements of list */ + Array elements; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("elements", &elements); + } + + static constexpr const char* _type_key = "script.printer.ListDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(ListDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of ListDocNode. + * + * \sa ListDocNode + */ +class ListDoc : public ExprDoc { + public: + /*! + * \brief Create an empty ListDoc + */ + ListDoc() : ListDoc(runtime::make_object()) {} + /*! + * \brief Constructor of ListDoc + * \param elements Elements of list. + */ + explicit ListDoc(Array elements); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ListDoc, ExprDoc, ListDocNode); +}; + +/*! + * \brief Doc that represents dictionary literal. + * + * \sa AttrAccessDoc + */ +class DictDocNode : public ExprDocNode { + public: + /*! \brief keys of dictionary */ + Array keys; + /*! + * \brief Values of dictionary + * + * The i-th element is the value of the i-th element of `keys`. + * It must have the same length as `keys`. + */ + Array values; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("keys", &keys); + v->Visit("values", &values); + } + + static constexpr const char* _type_key = "script.printer.DictDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(DictDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of DictDocNode. + * + * \sa DictDocNode + */ +class DictDoc : public ExprDoc { + public: + /*! + * \brief Create an empty dictionary + */ + DictDoc() : DictDoc(runtime::make_object()) {} + /*! + * \brief Constructor of DictDoc + * \param keys Keys of dictionary. + * \param values Values of dictionary, must have same length as `keys`. + */ + explicit DictDoc(Array keys, Array values); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(DictDoc, ExprDoc, DictDocNode); +}; + +/*! + * \brief Doc that represents slice in Index expression. + * + * This doc can only appear in IndexDoc::indices. + * + * \sa AttrAccessDoc + */ +class SliceDocNode : public DocNode { + public: + /*! \brief The start of slice */ + Optional start; + /*! \brief The exclusive end of slice */ + Optional stop; + /*! \brief The step of slice */ + Optional step; + + void VisitAttrs(AttrVisitor* v) { + DocNode::VisitAttrs(v); + v->Visit("start", &start); + v->Visit("stop", &stop); + v->Visit("step", &step); + } + + static constexpr const char* _type_key = "script.printer.SliceDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(SliceDocNode, DocNode); +}; + +/*! + * \brief Reference type of SliceDocNode. + * + * \sa SliceDocNode + */ +class SliceDoc : public Doc { + public: + /*! + * \brief Constructor of SliceDoc + * \param start The start of slice. + * \param stop The exclusive end of slice. + * \param step The step of slice. + */ + explicit SliceDoc(Optional start, Optional stop, Optional step); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(SliceDoc, Doc, SliceDocNode); +}; + } // namespace printer } // namespace script } // namespace tvm diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index f6179d7351b2..acdb63dcf250 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -16,8 +16,13 @@ # under the License. """Doc types for TVMScript Unified Printer""" +from typing import List, Dict, Tuple, Optional, Union, Sequence +from enum import IntEnum, unique + import tvm._ffi +import tvm.ir.container from tvm.runtime import Object +from tvm.tir import FloatImm, IntImm from . import _ffi_api @@ -29,12 +34,79 @@ class Doc(Object): class ExprDoc(Object): """Base class of all expression Docs""" + def attr(self, name: str) -> "AttrAccessDoc": + """ + Create a doc that represents attribute access on self. + + Parameters + ---------- + name : str + The attribute name to access + + Returns + ------- + doc : AttrAccessDoc + """ + return _ffi_api.ExprDocAttr(self, name) # type: ignore + + def call(self, *args: Tuple["ExprDoc"], **kwargs: Dict[str, "ExprDoc"]) -> "CallDoc": + """ + Create a doc that represents function call, with self as callee. + + Parameters + ---------- + *args : ExprDoc + The positional arguments of the function call. + **kwargs + The keyword arguments of the function call. + + Returns + ------- + doc : CallDoc + """ + kwargs_keys = list(kwargs.keys()) + kwargs_values = list(kwargs.values()) + return _ffi_api.ExprDocCall(self, args, kwargs_keys, kwargs_values) # type: ignore + + _IndexType = Union["ExprDoc", "SliceDoc"] + + def __getitem__(self, indices: Union[Tuple[_IndexType], _IndexType]) -> "IndexDoc": + """ + Create a doc that represents index access on self. + + Parameters + ---------- + indices : Union[Tuple[Union["ExprDoc", "SliceDoc"]], Union["ExprDoc", "SliceDoc"]] + The indices to access + + Returns + ------- + doc : IndexDoc + """ + if not isinstance(indices, tuple): + indices = (indices,) + return _ffi_api.ExprDocIndex(self, indices) # type: ignore + + def __iter__(self): + """ + This is implemented to prevent confusing error message when trying to use ExprDoc + as iterable. According to PEP-234, An object can be iterated over if it + implements __iter__() or __getitem__(). If an object has only __getitem__ + but not __iter__, interpreter will iterate the object by calling + __getitem__ with 0, 1, 2, ..., until an IndexError is raised. + + https://peps.python.org/pep-0234/#python-api-specification + """ + raise RuntimeError(f"{self.__class__} cannot be used as iterable.") + @tvm._ffi.register_object("script.printer.LiteralDoc") class LiteralDoc(ExprDoc): """Doc that represents literal value""" - def __init__(self, value): + value: Union[str, IntImm, FloatImm, None] + + def __init__(self, value: Union[str, float, bool, int, None]): if value is None: self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore elif isinstance(value, str): @@ -47,3 +119,177 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") + + +@tvm._ffi.register_object("script.printer.IdDoc") +class IdDoc(ExprDoc): + """Doc that represents identifier""" + + name: str + + def __init__(self, name: str): + self.__init_handle_by_constructor__(_ffi_api.IdDoc, name) # type: ignore + + +@tvm._ffi.register_object("script.printer.AttrAccessDoc") +class AttrAccessDoc(ExprDoc): + """Doc that represents attribute access on an expression""" + + value: ExprDoc + name: str + + def __init__(self, value: ExprDoc, name: str): + self.__init_handle_by_constructor__(_ffi_api.AttrAccessDoc, value, name) # type: ignore + + +@tvm._ffi.register_object("script.printer.IndexDoc") +class IndexDoc(ExprDoc): + """Doc that represents index access on an expression""" + + value: ExprDoc + indices: Sequence[Union[ExprDoc, "SliceDoc"]] + + def __init__(self, value: ExprDoc, indices: List[Union[ExprDoc, "SliceDoc"]]): + self.__init_handle_by_constructor__(_ffi_api.IndexDoc, value, indices) # type: ignore + + +@tvm._ffi.register_object("script.printer.CallDoc") +class CallDoc(ExprDoc): + """Doc that represents function call""" + + callee: ExprDoc + args: Sequence[ExprDoc] + kwargs_keys: Sequence[str] + kwargs_values: Sequence[ExprDoc] + + def __init__(self, callee: ExprDoc, *args: Tuple[ExprDoc], **kwargs: Dict[str, ExprDoc]): + kwargs_keys = list(kwargs.keys()) + kwargs_values = list(kwargs.values()) + self.__init_handle_by_constructor__( + _ffi_api.CallDoc, callee, args, kwargs_keys, kwargs_values # type: ignore + ) + + +@unique +class OperationKind(IntEnum): + """ + This enum represents the kind of operation (operator) in OpeartionDoc + + It's mirrored from OperationDocNode::Kind at include/tvm/script/printer/doc.h + """ + + # The name convention follows https://docs.python.org/3/library/ast.html + # pylint: disable=invalid-name + + _UnaryStart = 0 + USub = 1 + Invert = 2 + _UnaryEnd = 3 + + _BinaryStart = 4 + Add = 5 + Sub = 6 + Mult = 7 + Div = 8 + FloorDiv = 9 + Mod = 10 + Pow = 11 + LShift = 12 + RShift = 13 + BitAnd = 14 + BitOr = 15 + BitXor = 16 + Lt = 17 + LtE = 18 + Eq = 19 + NotEq = 20 + Gt = 21 + GtE = 22 + _BinaryEnd = 23 + + _SpecialStart = 24 + IfThenElse = 25 + _SpecialEnd = 26 + + # pylint: enable=invalid-name + + +@tvm._ffi.register_object("script.printer.OperationDoc") +class OperationDoc(ExprDoc): + """ + Doc that represents operation + + It can be unary, binary and other special operators (for example, the + if-then-else expression). + """ + + kind: OperationKind + operands: Sequence[ExprDoc] + + def __init__(self, kind: OperationKind, operands: List[ExprDoc]): + self.__init_handle_by_constructor__(_ffi_api.OperationDoc, kind, operands) # type: ignore + + +@tvm._ffi.register_object("script.printer.LambdaDoc") +class LambdaDoc(ExprDoc): + """Doc that represents lambda function""" + + args: Sequence[IdDoc] + body: ExprDoc + + def __init__(self, args: List[IdDoc], body: ExprDoc): + self.__init_handle_by_constructor__(_ffi_api.LambdaDoc, args, body) # type: ignore + + +@tvm._ffi.register_object("script.printer.TupleDoc") +class TupleDoc(ExprDoc): + """Doc that represents tuple literal""" + + elements: Sequence[ExprDoc] + + def __init__(self, elements: List[ExprDoc]): + self.__init_handle_by_constructor__(_ffi_api.TupleDoc, elements) # type: ignore + + +@tvm._ffi.register_object("script.printer.ListDoc") +class ListDoc(ExprDoc): + """Doc that represents list literal""" + + elements: Sequence[ExprDoc] + + def __init__(self, elements: List[ExprDoc]): + self.__init_handle_by_constructor__(_ffi_api.ListDoc, elements) # type: ignore + + +@tvm._ffi.register_object("script.printer.DictDoc") +class DictDoc(ExprDoc): + """Doc that represents dict literal""" + + keys: Sequence[ExprDoc] + values: Sequence[ExprDoc] + + def __init__(self, content: Dict[ExprDoc, ExprDoc]): + keys = list(content.keys()) + values = list(content.values()) + self.__init_handle_by_constructor__(_ffi_api.DictDoc, keys, values) # type: ignore + + +@tvm._ffi.register_object("script.printer.SliceDoc") +class SliceDoc(ExprDoc): + """ + Doc that represents slice in Index expression + + This doc can only appear in `IndexDoc.indices`. + """ + + start: Optional[ExprDoc] + stop: Optional[ExprDoc] + step: Optional[ExprDoc] + + def __init__( + self, + start: Optional[ExprDoc] = None, + stop: Optional[ExprDoc] = None, + step: Optional[ExprDoc] = None, + ): + self.__init_handle_by_constructor__(_ffi_api.SliceDoc, start, stop, step) # type: ignore diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc index f6874ba1a2ee..42d3f2d8f3ac 100644 --- a/src/script/printer/base_doc_printer.cc +++ b/src/script/printer/base_doc_printer.cc @@ -38,6 +38,26 @@ String DocPrinter::GetString() const { void DocPrinter::PrintDoc(const Doc& doc) { if (const auto* doc_node = doc.as()) { PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); } else { LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); throw; diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h index 128fcef2ea32..d5bfdcd94c6b 100644 --- a/src/script/printer/base_doc_printer.h +++ b/src/script/printer/base_doc_printer.h @@ -83,6 +83,56 @@ class DocPrinter { */ virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; + /*! + * \brief Virtual method to print a IdDoc + */ + virtual void PrintTypedDoc(const IdDoc& doc) = 0; + + /*! + * \brief Virtual method to print a AttrAccessDoc + */ + virtual void PrintTypedDoc(const AttrAccessDoc& doc) = 0; + + /*! + * \brief Virtual method to print a IndexDoc + */ + virtual void PrintTypedDoc(const IndexDoc& doc) = 0; + + /*! + * \brief Virtual method to print a OperationDoc + */ + virtual void PrintTypedDoc(const OperationDoc& doc) = 0; + + /*! + * \brief Virtual method to print a CallDoc + */ + virtual void PrintTypedDoc(const CallDoc& doc) = 0; + + /*! + * \brief Virtual method to print a LambdaDoc + */ + virtual void PrintTypedDoc(const LambdaDoc& doc) = 0; + + /*! + * \brief Virtual method to print a ListDoc + */ + virtual void PrintTypedDoc(const ListDoc& doc) = 0; + + /*! + * \brief Virtual method to print a TupleDoc + */ + virtual void PrintTypedDoc(const TupleDoc& doc) = 0; + + /*! + * \brief Virtual method to print a DictDoc + */ + virtual void PrintTypedDoc(const DictDoc& doc) = 0; + + /*! + * \brief Virtual method to print a SliceDoc + */ + virtual void PrintTypedDoc(const SliceDoc& doc) = 0; + /*! * \brief Increase the indent level of any content to be * printed after this call diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index e54adbd36b4c..ed81f9d2dd26 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -23,14 +23,108 @@ namespace tvm { namespace script { namespace printer { +ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } + +ExprDoc ExprDocNode::operator[](Array indices) const { + return IndexDoc(GetRef(this), indices); +} + +ExprDoc ExprDocNode::Call(Array args) const { + return CallDoc(GetRef(this), args, {}, {}); +} + +ExprDoc ExprDocNode::Call(Array args, Array kwargs_keys, + Array kwargs_values) const { + return CallDoc(GetRef(this), args, kwargs_keys, kwargs_values); +} + LiteralDoc::LiteralDoc(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; this->data_ = std::move(n); } +IdDoc::IdDoc(String name) { + ObjectPtr n = make_object(); + n->name = name; + this->data_ = std::move(n); +} + +AttrAccessDoc::AttrAccessDoc(ExprDoc value, String name) { + ObjectPtr n = make_object(); + n->value = value; + n->name = name; + this->data_ = std::move(n); +} + +IndexDoc::IndexDoc(ExprDoc value, Array indices) { + ObjectPtr n = make_object(); + n->value = value; + n->indices = indices; + this->data_ = std::move(n); +} + +CallDoc::CallDoc(ExprDoc callee, Array args, Array kwargs_keys, + Array kwargs_values) { + ObjectPtr n = make_object(); + n->callee = callee; + n->args = args; + n->kwargs_keys = kwargs_keys; + n->kwargs_values = kwargs_values; + this->data_ = std::move(n); +} + +OperationDoc::OperationDoc(OperationDocNode::Kind kind, Array operands) { + ObjectPtr n = make_object(); + n->kind = kind; + n->operands = operands; + this->data_ = std::move(n); +} + +LambdaDoc::LambdaDoc(Array args, ExprDoc body) { + ObjectPtr n = make_object(); + n->args = args; + n->body = body; + this->data_ = std::move(n); +} + +TupleDoc::TupleDoc(Array elements) { + ObjectPtr n = make_object(); + n->elements = elements; + this->data_ = std::move(n); +} + +ListDoc::ListDoc(Array elements) { + ObjectPtr n = make_object(); + n->elements = elements; + this->data_ = std::move(n); +} + +DictDoc::DictDoc(Array keys, Array values) { + ObjectPtr n = make_object(); + n->keys = keys; + n->values = values; + this->data_ = std::move(n); +} + +SliceDoc::SliceDoc(Optional start, Optional stop, Optional step) { + ObjectPtr n = make_object(); + n->start = start; + n->stop = stop; + n->step = step; + this->data_ = std::move(n); +} + TVM_REGISTER_NODE_TYPE(DocNode); + TVM_REGISTER_NODE_TYPE(ExprDocNode); +TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); +TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex") + .set_body_method(&ExprDocNode::operator[]); +TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") + .set_body_method, Array, Array>( + &ExprDocNode::Call); + TVM_REGISTER_NODE_TYPE(LiteralDocNode); TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); @@ -38,6 +132,59 @@ TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDo TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); +TVM_REGISTER_NODE_TYPE(IdDocNode); +TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); + +TVM_REGISTER_NODE_TYPE(AttrAccessDocNode); +TVM_REGISTER_GLOBAL("script.printer.AttrAccessDoc").set_body_typed([](ExprDoc value, String attr) { + return AttrAccessDoc(value, attr); +}); + +TVM_REGISTER_NODE_TYPE(IndexDocNode); +TVM_REGISTER_GLOBAL("script.printer.IndexDoc") + .set_body_typed([](ExprDoc value, Array indices) { return IndexDoc(value, indices); }); + +TVM_REGISTER_NODE_TYPE(CallDocNode); +TVM_REGISTER_GLOBAL("script.printer.CallDoc") + .set_body_typed([](ExprDoc callee, // + Array args, // + Array kwargs_keys, // + Array kwargs_values) { + return CallDoc(callee, args, kwargs_keys, kwargs_values); + }); + +TVM_REGISTER_NODE_TYPE(OperationDocNode); +TVM_REGISTER_GLOBAL("script.printer.OperationDoc") + .set_body_typed([](int32_t kind, Array operands) { + return OperationDoc(OperationDocNode::Kind(kind), operands); + }); + +TVM_REGISTER_NODE_TYPE(LambdaDocNode); +TVM_REGISTER_GLOBAL("script.printer.LambdaDoc").set_body_typed([](Array args, ExprDoc body) { + return LambdaDoc(args, body); +}); + +TVM_REGISTER_NODE_TYPE(TupleDocNode); +TVM_REGISTER_GLOBAL("script.printer.TupleDoc").set_body_typed([](Array elements) { + return TupleDoc(elements); +}); + +TVM_REGISTER_NODE_TYPE(ListDocNode); +TVM_REGISTER_GLOBAL("script.printer.ListDoc").set_body_typed([](Array elements) { + return ListDoc(elements); +}); + +TVM_REGISTER_NODE_TYPE(DictDocNode); +TVM_REGISTER_GLOBAL("script.printer.DictDoc") + .set_body_typed([](Array keys, Array values) { + return DictDoc(keys, values); + }); + +TVM_REGISTER_NODE_TYPE(SliceDocNode); +TVM_REGISTER_GLOBAL("script.printer.SliceDoc") + .set_body_typed([](Optional start, Optional stop, Optional step) { + return SliceDoc(start, stop, step); + }); } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index cd816e4f7010..5c7b048f8144 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include "../../support/str_escape.h" @@ -34,6 +35,30 @@ class PythonDocPrinter : public DocPrinter { using DocPrinter::PrintDoc; void PrintTypedDoc(const LiteralDoc& doc) final; + void PrintTypedDoc(const IdDoc& doc) final; + void PrintTypedDoc(const AttrAccessDoc& doc) final; + void PrintTypedDoc(const IndexDoc& doc) final; + void PrintTypedDoc(const OperationDoc& doc) final; + void PrintTypedDoc(const CallDoc& doc) final; + void PrintTypedDoc(const LambdaDoc& doc) final; + void PrintTypedDoc(const ListDoc& doc) final; + void PrintTypedDoc(const DictDoc& doc) final; + void PrintTypedDoc(const TupleDoc& doc) final; + void PrintTypedDoc(const SliceDoc& doc) final; + + private: + template + void PrintJoinedDocs(const Array& docs, const std::string& separator) { + bool is_first = true; + for (auto& doc : docs) { + if (is_first) { + is_first = false; + } else { + output_ << separator; + } + PrintDoc(doc); + } + } }; void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { @@ -57,6 +82,184 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } +void PythonDocPrinter::PrintTypedDoc(const IdDoc& doc) { output_ << doc->name; } + +void PythonDocPrinter::PrintTypedDoc(const AttrAccessDoc& doc) { + PrintDoc(doc->value); + output_ << "." << doc->name; +} + +void PythonDocPrinter::PrintTypedDoc(const IndexDoc& doc) { + PrintDoc(doc->value); + if (doc->indices.size() == 0) { + output_ << "[()]"; + } else { + output_ << "["; + PrintJoinedDocs(doc->indices, ", "); + output_ << "]"; + } +} + +const std::string OperatorToString(OperationDocNode::Kind operation_kind) { + static const std::vector op_kind2str = []() { + using OpKind = OperationDocNode::Kind; + std::map raw_table = { + {OpKind::kUSub, "-"}, // + {OpKind::kInvert, "~"}, // + {OpKind::kAdd, "+"}, // + {OpKind::kSub, "-"}, // + {OpKind::kMult, "*"}, // + {OpKind::kDiv, "/"}, // + {OpKind::kFloorDiv, "//"}, // + {OpKind::kMod, "%"}, // + {OpKind::kPow, "**"}, // + {OpKind::kLShift, "<<"}, // + {OpKind::kRShift, ">>"}, // + {OpKind::kBitAnd, "&"}, // + {OpKind::kBitOr, "|"}, // + {OpKind::kBitXor, "^"}, // + {OpKind::kLt, "<"}, // + {OpKind::kLtE, "<="}, // + {OpKind::kEq, "=="}, // + {OpKind::kNotEq, "!="}, // + {OpKind::kGt, ">"}, // + {OpKind::kGtE, ">="}, // + }; + + std::vector table; + table.resize(static_cast(OperationDocNode::Kind::kSpecialEnd) + 1); + + for (const auto& kv : raw_table) { + table[static_cast(kv.first)] = kv.second; + } + + return table; + }(); + + auto op_index = static_cast(operation_kind); + ICHECK_LT(op_index, op_kind2str.size()); + const std::string str = op_kind2str[op_index]; + ICHECK(!str.empty()) << "OperationDocNode::Kind " << static_cast(operation_kind) + << " cannot be converted to operator token in Python directly."; + return str; +} + +void PythonDocPrinter::PrintTypedDoc(const OperationDoc& doc) { + using OpKind = OperationDocNode::Kind; + if (doc->kind < OpKind::kUnaryEnd) { + // Unary Operators + ICHECK_EQ(doc->operands.size(), 1); + output_ << OperatorToString(doc->kind); + PrintDoc(doc->operands[0]); + } else if (doc->kind < OpKind::kBinaryEnd) { + // Binary Operator + ICHECK_EQ(doc->operands.size(), 2); + PrintDoc(doc->operands[0]); + output_ << " " << OperatorToString(doc->kind) << " "; + PrintDoc(doc->operands[1]); + } else if (doc->kind == OpKind::kIfThenElse) { + ICHECK_EQ(doc->operands.size(), 3) + << "ValueError: IfThenElse requires 3 operands, but got " << doc->operands.size(); + PrintDoc(doc->operands[1]); + output_ << " if "; + PrintDoc(doc->operands[0]); + output_ << " else "; + PrintDoc(doc->operands[2]); + } else { + LOG(FATAL) << "Unknown OperationDocNode::Kind " << static_cast(doc->kind); + throw; + } +} + +void PythonDocPrinter::PrintTypedDoc(const CallDoc& doc) { + PrintDoc(doc->callee); + + output_ << "("; + + // Print positional args + bool is_first = true; + for (const ExprDoc& arg : doc->args) { + if (is_first) { + is_first = false; + } else { + output_ << ", "; + } + PrintDoc(arg); + } + + // Print keyword args + ICHECK_EQ(doc->kwargs_keys.size(), doc->kwargs_values.size()) + << "CallDoc should have equal number of elements in kwargs_keys and kwargs_values."; + for (size_t i = 0; i < doc->kwargs_keys.size(); i++) { + if (is_first) { + is_first = false; + } else { + output_ << ", "; + } + const String& keyword = doc->kwargs_keys[i]; + output_ << keyword; + output_ << "="; + PrintDoc(doc->kwargs_values[i]); + } + + output_ << ")"; +} + +void PythonDocPrinter::PrintTypedDoc(const LambdaDoc& doc) { + output_ << "lambda "; + PrintJoinedDocs(doc->args, ", "); + output_ << ": "; + PrintDoc(doc->body); +} + +void PythonDocPrinter::PrintTypedDoc(const ListDoc& doc) { + output_ << "["; + PrintJoinedDocs(doc->elements, ", "); + output_ << "]"; +} + +void PythonDocPrinter::PrintTypedDoc(const TupleDoc& doc) { + output_ << "("; + if (doc->elements.size() == 1) { + PrintDoc(doc->elements[0]); + output_ << ","; + } else { + PrintJoinedDocs(doc->elements, ", "); + } + output_ << ")"; +} + +void PythonDocPrinter::PrintTypedDoc(const DictDoc& doc) { + ICHECK_EQ(doc->keys.size(), doc->values.size()) + << "DictDoc should have equal number of elements in keys and values."; + output_ << "{"; + size_t idx = 0; + for (const ExprDoc& key : doc->keys) { + if (idx > 0) { + output_ << ", "; + } + PrintDoc(key); + output_ << ": "; + PrintDoc(doc->values[idx]); + idx++; + } + output_ << "}"; +} + +void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { + if (doc->start != nullptr) { + PrintDoc(doc->start.value()); + } + output_ << ":"; + if (doc->stop != nullptr) { + PrintDoc(doc->stop.value()); + } + if (doc->step != nullptr) { + output_ << ":"; + PrintDoc(doc->step.value()); + } +} + String DocToPythonScript(Doc doc, int indent_spaces) { PythonDocPrinter printer(indent_spaces); printer.Append(doc); diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 6330d33bf25a..4ff6a0f547d7 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -16,8 +16,20 @@ # under the License. import pytest -from tvm.tir import IntImm -from tvm.script.printer.doc import LiteralDoc +from tvm.script.printer.doc import ( + LiteralDoc, + IdDoc, + AttrAccessDoc, + IndexDoc, + CallDoc, + OperationKind, + OperationDoc, + LambdaDoc, + TupleDoc, + ListDoc, + DictDoc, + SliceDoc, +) @pytest.mark.parametrize( @@ -26,8 +38,209 @@ ) def test_literal_doc_construction(value): doc = LiteralDoc(value) + if isinstance(value, float): # FloatImm cannot be compared with Python's float directly assert float(doc.value) == pytest.approx(value) else: assert doc.value == value + + +def test_id_doc(): + doc = IdDoc("name") + + assert doc.name == "name" + + +def test_attr_access_doc(): + target = IdDoc("x") + + doc = AttrAccessDoc(target, "attribute") + + assert doc.value == target + assert doc.name == "attribute" + + +@pytest.mark.parametrize( + "indices", + [ + [], + [LiteralDoc(1)], + [LiteralDoc(2), IdDoc("x")], + [SliceDoc(LiteralDoc(1), LiteralDoc(2))], + [SliceDoc(LiteralDoc(1)), IdDoc("y")], + ], +) +def test_index_doc(indices): + target = IdDoc("x") + + doc = IndexDoc(target, indices) + + assert doc.value == target + assert list(doc.indices) == indices + + +@pytest.mark.parametrize( + "args, kwargs", + [ + ([], {}), + ([LiteralDoc("arg")], {}), + ([LiteralDoc("arg"), IdDoc("x")], {}), + ([], {"x": LiteralDoc("x")}), + ([], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg"), IdDoc("x")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ], +) +def test_call_doc(args, kwargs): + target = IdDoc("x") + + doc = CallDoc(target, *args, **kwargs) + + assert doc.callee == target + assert list(doc.args) == args + assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs + + +@pytest.mark.parametrize( + "operands", + [ + [], + [LiteralDoc(1)], + [LiteralDoc(2), IdDoc("x")], + [LiteralDoc(2), IdDoc("x"), LiteralDoc("y")], + ], +) +def test_operation_doc(operands): + # Here we just test the contructor and attr visitor of OperationDoc + # so the choice of OperationKind doesn't matter + operator = OperationKind.Add + + doc = OperationDoc(OperationKind.Add, operands) + + assert doc.kind == operator + assert list(doc.operands) == operands + + +@pytest.mark.parametrize( + "args", + [ + [], + [IdDoc("x")], + [IdDoc("x"), IdDoc("y")], + ], +) +def test_lambda_doc(args): + body = LiteralDoc(1) + + doc = LambdaDoc(args, body) + + assert doc.body == body + assert list(doc.args) == args + + +@pytest.mark.parametrize( + "elements", + [ + [], + [IdDoc("x")], + [IdDoc("x"), IdDoc("y")], + ], +) +def test_tuple_doc(elements): + doc = TupleDoc(elements) + + assert list(doc.elements) == elements + + +@pytest.mark.parametrize( + "elements", + [ + [], + [IdDoc("x")], + [IdDoc("x"), IdDoc("y")], + ], +) +def test_list_doc(elements): + doc = ListDoc(elements) + + assert list(doc.elements) == elements + + +@pytest.mark.parametrize( + "content", + [ + {}, + {LiteralDoc("k"): IdDoc("v")}, + {LiteralDoc("k"): IdDoc("v"), LiteralDoc("k2"): IdDoc("v2")}, + ], +) +def test_dict_doc(content): + doc = DictDoc(content) + + assert dict(zip(doc.keys, doc.values)) == content + + +@pytest.mark.parametrize("start", [LiteralDoc(1), None]) +@pytest.mark.parametrize("stop", [LiteralDoc(2), None]) +@pytest.mark.parametrize("step", [LiteralDoc(3), None]) +def test_slice_doc(start, stop, step): + doc = SliceDoc(start, stop) + + assert doc.start == start + assert doc.stop == stop + + +def test_expr_doc_attr_access(): + target = IdDoc("x") + attr = "test" + + doc = target.attr(attr) + + assert doc.value == target + assert doc.name == attr + + +@pytest.mark.parametrize( + "indices", + [ + (), + LiteralDoc(1), + SliceDoc(LiteralDoc(1), LiteralDoc(2)), + (LiteralDoc(1),), + (LiteralDoc(2), IdDoc("x")), + (SliceDoc(LiteralDoc(1), LiteralDoc(2)),), + (SliceDoc(LiteralDoc(1)), IdDoc("y")), + ], +) +def test_expr_doc_get_item(indices): + target = IdDoc("x") + + doc = target[indices] + + assert doc.value == target + if not isinstance(indices, tuple): + indices = (indices,) + assert tuple(doc.indices) == indices + + +@pytest.mark.parametrize( + "args, kwargs", + [ + ([], {}), + ([LiteralDoc("arg")], {}), + ([LiteralDoc("arg"), IdDoc("x")], {}), + ([], {"x": LiteralDoc("x")}), + ([], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ([LiteralDoc("arg"), IdDoc("x")], {"x": LiteralDoc("x"), "y": LiteralDoc("y")}), + ], +) +def test_expr_doc_call_with(args, kwargs): + target = IdDoc("x") + + doc = target.call(*args, **kwargs) + + assert doc.callee == target + assert list(doc.args) == args + assert dict(zip(doc.kwargs_keys, doc.kwargs_values)) == kwargs diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 55b5e88c88c8..b65eaa6b98a1 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -16,8 +16,19 @@ # under the License. import pytest +from tvm.script.printer.doc import ( + CallDoc, + DictDoc, + IdDoc, + LambdaDoc, + ListDoc, + LiteralDoc, + OperationDoc, + OperationKind, + SliceDoc, + TupleDoc, +) from tvm.script.printer.doc_printer import to_python_script -from tvm.script.printer.doc import LiteralDoc def format_script(s: str) -> str: @@ -28,7 +39,7 @@ def format_script(s: str) -> str: non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] spaces_to_remove = min(line_indents) - return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + "\n" @pytest.mark.parametrize( @@ -50,4 +61,367 @@ def format_script(s: str) -> str: ], ) def test_print_literal_doc(doc, expected): - assert to_python_script(doc).rstrip("\n") == format_script(expected) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "name", + [ + "test", + "_test", + "TestCase", + "test_case", + "test123", + ], +) +def test_print_id_doc(name): + doc = IdDoc(name) + assert to_python_script(doc) == format_script(name) + + +@pytest.mark.parametrize( + "attr", + [ + "attr", + "_attr", + "Attr", + "attr_1", + ], +) +def test_print_attr_doc(attr): + doc = IdDoc("x").attr(attr) + assert to_python_script(doc) == format_script(f"x.{attr}") + + +@pytest.mark.parametrize( + "indices, expected", + [ + ( + (), + "[()]", + ), + ( + (LiteralDoc(1),), + "[1]", + ), + ( + (LiteralDoc(2), IdDoc("x")), + "[2, x]", + ), + ( + (SliceDoc(LiteralDoc(1), LiteralDoc(2)),), + "[1:2]", + ), + ( + (SliceDoc(LiteralDoc(1)), IdDoc("y")), + "[1:, y]", + ), + ( + (SliceDoc(), IdDoc("y")), + "[:, y]", + ), + ( + (IdDoc("x"), IdDoc("y"), IdDoc("z")), + "[x, y, z]", + ), + ], +) +def test_print_index_doc(indices, expected): + doc = IdDoc("x")[indices] + assert to_python_script(doc) == format_script(f"x{expected}") + + +UNARY_OP_TOKENS = { + OperationKind.USub: "-", + OperationKind.Invert: "~", +} + + +@pytest.mark.parametrize( + "op_kind, expected_token", + list(UNARY_OP_TOKENS.items()), + ids=UNARY_OP_TOKENS.keys(), +) +def test_print_unary_operation_doc(op_kind, expected_token): + doc = OperationDoc(op_kind, [IdDoc("x")]) + assert to_python_script(doc) == format_script(f"{expected_token}x") + + +BINARY_OP_TOKENS = { + OperationKind.Add: "+", + OperationKind.Sub: "-", + OperationKind.Mult: "*", + OperationKind.Div: "/", + OperationKind.FloorDiv: "//", + OperationKind.Mod: "%", + OperationKind.Pow: "**", + OperationKind.LShift: "<<", + OperationKind.RShift: ">>", + OperationKind.BitAnd: "&", + OperationKind.BitOr: "|", + OperationKind.BitXor: "^", + OperationKind.Lt: "<", + OperationKind.LtE: "<=", + OperationKind.Eq: "==", + OperationKind.NotEq: "!=", + OperationKind.Gt: ">", + OperationKind.GtE: ">=", +} + + +@pytest.mark.parametrize( + "op_kind, expected_token", + list(BINARY_OP_TOKENS.items()), + ids=BINARY_OP_TOKENS.keys(), +) +def test_print_binary_operation_doc(op_kind, expected_token): + doc = OperationDoc(op_kind, [IdDoc("x"), IdDoc("y")]) + assert to_python_script(doc) == format_script(f"x {expected_token} y") + + +SPECIAL_OP_CASES = [ + ( + OperationKind.IfThenElse, + [LiteralDoc(True), LiteralDoc("true"), LiteralDoc("false")], + '"true" if True else "false"', + ), + ( + OperationKind.IfThenElse, + [IdDoc("x"), LiteralDoc(None), LiteralDoc(1)], + "None if x else 1", + ), +] + + +@pytest.mark.parametrize( + "op_kind, operands, expected", SPECIAL_OP_CASES, ids=[kind for (kind, *_) in SPECIAL_OP_CASES] +) +def test_print_special_operation_doc(op_kind, operands, expected): + doc = OperationDoc(op_kind, operands) + assert to_python_script(doc) == format_script(expected) + + +def test_operation_doc_test_exhaustive(): + special_op_covered = {k for k, *_ in SPECIAL_OP_CASES} + for op_kind in OperationKind: + if OperationKind._UnaryStart < op_kind < OperationKind._UnaryEnd: + assert op_kind in UNARY_OP_TOKENS, ( + f"{op_kind.name} not covered in test_print_unary_operation_doc. " + f"Please add the expected token to UNARY_OP_TOKENS" + ) + elif OperationKind._BinaryStart < op_kind < OperationKind._BinaryEnd: + assert op_kind in BINARY_OP_TOKENS, ( + f"{op_kind.name} not covered in test_print_binary_operation_doc. " + f"Please add the expected token to BINARY_OP_TOKENS" + ) + elif not op_kind.name.startswith("_"): + # Special Op + assert op_kind in special_op_covered, ( + f"{op_kind.name} not covered in test_print_special_operation_doc. " + f"Please add the test cases for it to SPECIAL_OP_CASES" + ) + + +@pytest.mark.parametrize( + "args, kwargs, expected", + [ + ( + (), + {}, + "()", + ), + ( + (), + {"key0": IdDoc("u")}, + "(key0=u)", + ), + ( + (), + {"key0": IdDoc("u"), "key1": IdDoc("v")}, + "(key0=u, key1=v)", + ), + ( + (IdDoc("x"),), + {}, + "(x)", + ), + ( + (IdDoc("x"),), + {"key0": IdDoc("u")}, + "(x, key0=u)", + ), + ( + (IdDoc("x"),), + {"key0": IdDoc("u"), "key1": IdDoc("v")}, + "(x, key0=u, key1=v)", + ), + ( + (IdDoc("x"), (IdDoc("y"))), + {}, + "(x, y)", + ), + ( + (IdDoc("x"), (IdDoc("y"))), + {"key0": IdDoc("u")}, + "(x, y, key0=u)", + ), + ( + (IdDoc("x"), (IdDoc("y"))), + {"key0": IdDoc("u"), "key1": IdDoc("v")}, + "(x, y, key0=u, key1=v)", + ), + ], +) +def test_print_call_doc(args, kwargs, expected): + doc = CallDoc(IdDoc("f"), *args, **kwargs) + assert to_python_script(doc) == format_script(f"f{expected}") + + +@pytest.mark.parametrize( + "args, expected", + [ + ( + (), + "lambda : 0", + ), + ( + (IdDoc("x"),), + "lambda x: 0", + ), + ( + (IdDoc("x"), IdDoc("y")), + "lambda x, y: 0", + ), + ( + (IdDoc("x"), IdDoc("y"), IdDoc("z")), + "lambda x, y, z: 0", + ), + ], +) +def test_print_lambda_doc(args, expected): + doc = LambdaDoc(args, body=LiteralDoc(0)) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "elements, expected", + [ + ( + (), + "[]", + ), + ( + [IdDoc("x")], + "[x]", + ), + ( + [IdDoc("x"), IdDoc("y")], + "[x, y]", + ), + ( + [IdDoc("x"), IdDoc("y"), IdDoc("z")], + "[x, y, z]", + ), + ], +) +def test_print_list_doc(elements, expected): + doc = ListDoc(elements) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "elements, expected", + [ + ( + (), + "()", + ), + ( + [IdDoc("x")], + "(x,)", + ), + ( + [IdDoc("x"), IdDoc("y")], + "(x, y)", + ), + ( + [IdDoc("x"), IdDoc("y"), IdDoc("z")], + "(x, y, z)", + ), + ], +) +def test_print_tuple_doc(elements, expected): + doc = TupleDoc(elements) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "content, expected", + [ + ( + {}, + "{}", + ), + ( + {LiteralDoc("key_x"): IdDoc("x")}, + '{"key_x": x}', + ), + ( + {LiteralDoc("key_x"): IdDoc("x"), LiteralDoc("key_y"): IdDoc("y")}, + '{"key_x": x, "key_y": y}', + ), + ( + { + LiteralDoc("key_x"): IdDoc("x"), + LiteralDoc("key_y"): IdDoc("y"), + LiteralDoc("key_z"): IdDoc("z"), + }, + '{"key_x": x, "key_y": y, "key_z": z}', + ), + ], +) +def test_print_dict_doc(content, expected): + doc = DictDoc(content) + assert to_python_script(doc) == format_script(expected) + + +@pytest.mark.parametrize( + "slice_doc, expected", + [ + ( + SliceDoc(), + ":", + ), + ( + SliceDoc(LiteralDoc(1)), + "1:", + ), + ( + SliceDoc(None, LiteralDoc(2)), + ":2", + ), + ( + SliceDoc(LiteralDoc(1), LiteralDoc(2)), + "1:2", + ), + ( + SliceDoc(None, None, LiteralDoc(3)), + "::3", + ), + ( + SliceDoc(LiteralDoc(1), None, LiteralDoc(3)), + "1::3", + ), + ( + SliceDoc(None, LiteralDoc(2), LiteralDoc(3)), + ":2:3", + ), + ( + SliceDoc(LiteralDoc(1), LiteralDoc(2), LiteralDoc(3)), + "1:2:3", + ), + ], +) +def test_print_slice_doc(slice_doc, expected): + doc = IdDoc("x")[slice_doc] + assert to_python_script(doc) == format_script(f"x[{expected}]")