From cfbc41a075dee32edf4583a44f35a888ae819f0b Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 01:45:21 -0400 Subject: [PATCH 01/27] Add literal doc definition --- CMakeLists.txt | 1 + include/tvm/script/printer/doc.h | 155 ++++++++++++++++++++++++++ python/tvm/script/printer/_ffi_api.py | 20 ++++ python/tvm/script/printer/doc.py | 57 ++++++++++ src/script/printer/doc.cc | 44 ++++++++ 5 files changed, 277 insertions(+) create mode 100644 include/tvm/script/printer/doc.h create mode 100644 python/tvm/script/printer/_ffi_api.py create mode 100644 python/tvm/script/printer/doc.py create mode 100644 src/script/printer/doc.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 306a8be30858..46de8f5d07fa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -281,6 +281,7 @@ tvm_file_glob(GLOB_RECURSE COMPILER_SRCS src/parser/*.cc src/printer/*.cc src/support/*.cc + src/script/*.cc ) tvm_file_glob(GLOB CODEGEN_SRCS diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h new file mode 100644 index 000000000000..e82efaa68eda --- /dev/null +++ b/include/tvm/script/printer/doc.h @@ -0,0 +1,155 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_DOC_H_ +#define TVM_SCRIPT_PRINTER_DOC_H_ + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief The base class of all Doc. + * + * Doc is an intermediate representation between IR from TVM + * and the TVMScript code. + * During printing, IR graph is first translated into Doc tree, + * then the Doc tree is translated to the target language in + * text format. + * + * \sa Doc + */ +class DocNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "script.printer.Doc"; + TVM_DECLARE_BASE_OBJECT_INFO(DocNode, Object); + + public: + virtual ~DocNode() = default; +}; + +/*! + * \brief Reference type of DocNode. + * + * \sa DocNode + */ +class Doc : public ObjectRef { + protected: + Doc() = default; + + public: + virtual ~Doc() = default; + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Doc, ObjectRef, DocNode); +}; + +/*! + * \brief The base class of expression doc. + * + * \sa ExprDoc + */ +class ExprDocNode : public DocNode { + public: + void VisitAttrs(AttrVisitor* v) { DocNode::VisitAttrs(v); } + + static constexpr const char* _type_key = "script.printer.ExprDoc"; + TVM_DECLARE_BASE_OBJECT_INFO(ExprDocNode, DocNode); +}; + +/*! + * \brief Reference type of ExprDocNode. + * + * \sa ExprDocNode + */ +class ExprDoc : public Doc { + protected: + ExprDoc() = default; + + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ExprDoc, Doc, ExprDocNode); +}; + +/*! + * \brief Doc that represents literal value. + * + * \sa LiteralDoc + */ +class LiteralDocNode : public ExprDocNode { + public: + /*! + * \brief the internal representation of the literal value. + * + * The actual type is union of IntImm, FloatImm and String, or a + * null ObjectRef. + */ + ObjectRef value; + + void VisitAttrs(AttrVisitor* v) { + ExprDocNode::VisitAttrs(v); + v->Visit("value", &value); + } + + static constexpr const char* _type_key = "script.printer.LiteralDoc"; + TVM_DECLARE_FINAL_OBJECT_INFO(LiteralDocNode, ExprDocNode); +}; + +/*! + * \brief Reference type of LiteralDocNode. + * + * \sa LiteralDocNode + */ +class LiteralDoc : public ExprDoc { + protected: + explicit LiteralDoc(ObjectRef value); + + public: + /*! + * \brief Create a LiteralDoc to represent None/null/empty value. + */ + static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + */ + static LiteralDoc Int(const IntImm& v) { return LiteralDoc(v); } + + /*! + * \brief Create a LiteralDoc to represent float. + * \param v The float value. + */ + static LiteralDoc Float(const FloatImm& v) { return LiteralDoc(v); } + + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + */ + static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); +}; + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif diff --git a/python/tvm/script/printer/_ffi_api.py b/python/tvm/script/printer/_ffi_api.py new file mode 100644 index 000000000000..baa639fe2d67 --- /dev/null +++ b/python/tvm/script/printer/_ffi_api.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""FFI APIs for tvm.script.printer""" +import tvm._ffi + +tvm._ffi._init_api("script.printer", __name__) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py new file mode 100644 index 000000000000..c5be57d9fdd1 --- /dev/null +++ b/python/tvm/script/printer/doc.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Doc types for TVMScript Unified Printer""" + +import tvm._ffi +from tvm.runtime import Object +from tvm.tir import FloatImm, IntImm + +from . import _ffi_api + + +class Doc(Object): + """Base class of all Docs""" + + +class ExprDoc(Object): + """Base class of all expression Docs""" + + +@tvm._ffi.register_object("script.printer.LiteralDoc") +class LiteralDoc(ExprDoc): + """Doc that represents literal value""" + + def __init__(self, value): + if isinstance(value, str): + self.__init_handle_by_constructor__( + _ffi_api.LiteralDoc.Str, value + ) + elif isinstance(value, (float, FloatImm)): + self.__init_handle_by_constructor__( + _ffi_api.LiteralDoc.Float, value + ) + elif isinstance(value, (int, IntImm)): + self.__init_handle_by_constructor__( + _ffi_api.LiteralDoc.Int, value + ) + elif value is None: + self.__init_handle_by_constructor__( + _ffi_api.LiteralDoc.None_ + ) + else: + raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") + diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc new file mode 100644 index 000000000000..0fb7ebee6f76 --- /dev/null +++ b/src/script/printer/doc.cc @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include + +#include + +namespace tvm { +namespace script { +namespace printer { + +TVM_REGISTER_NODE_TYPE(DocNode); +TVM_REGISTER_NODE_TYPE(ExprDocNode); + +LiteralDoc::LiteralDoc(ObjectRef value) { + ObjectPtr n = make_object(); + n->value = value; + this->data_ = std::move(n); +} +TVM_REGISTER_NODE_TYPE(LiteralDocNode); +// Underscore is added to avoid syntax error in Python FFI binding +TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.None_").set_body_typed(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.Int").set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.Float").set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.Str").set_body_typed(LiteralDoc::Str); + +} // namespace printer +} // namespace script +} // namespace tvm From 03017beba2148a61ccaecfe597b390ae7e6db7a9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 11:38:34 -0400 Subject: [PATCH 02/27] Add test for literal doc construction --- include/tvm/script/printer/doc.h | 3 ++ python/tvm/script/printer/doc.py | 34 +++++++++-------- src/script/printer/doc.cc | 16 +++++--- .../test_tvmscript_unified_printer_doc.py | 37 +++++++++++++++++++ 4 files changed, 68 insertions(+), 22 deletions(-) create mode 100644 tests/python/unittest/test_tvmscript_unified_printer_doc.py diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index e82efaa68eda..5d74b10b9e10 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -21,6 +21,7 @@ #include #include +#include "tvm/runtime/data_type.h" namespace tvm { namespace script { @@ -132,12 +133,14 @@ class LiteralDoc : public ExprDoc { * \param v The integer value. */ static LiteralDoc Int(const IntImm& v) { return LiteralDoc(v); } + static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ static LiteralDoc Float(const FloatImm& v) { return LiteralDoc(v); } + static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } /*! * \brief Create a LiteralDoc to represent string. diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index c5be57d9fdd1..2bcc7e279802 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -31,27 +31,29 @@ class ExprDoc(Object): """Base class of all expression Docs""" +_literal_constructors = [ + (str, _ffi_api.LiteralDocStr), + (float, _ffi_api.LiteralDocFloat), + (int, _ffi_api.LiteralDocInt), + (IntImm, _ffi_api.LiteralDocIntImm), + (FloatImm, _ffi_api.LiteralDocFloatImm) +] + + @tvm._ffi.register_object("script.printer.LiteralDoc") class LiteralDoc(ExprDoc): """Doc that represents literal value""" def __init__(self, value): - if isinstance(value, str): - self.__init_handle_by_constructor__( - _ffi_api.LiteralDoc.Str, value - ) - elif isinstance(value, (float, FloatImm)): - self.__init_handle_by_constructor__( - _ffi_api.LiteralDoc.Float, value - ) - elif isinstance(value, (int, IntImm)): - self.__init_handle_by_constructor__( - _ffi_api.LiteralDoc.Int, value - ) - elif value is None: - self.__init_handle_by_constructor__( - _ffi_api.LiteralDoc.None_ - ) + if value is None: + self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) + return + + for (cls, constructor) in _literal_constructors: + if isinstance(value, cls): + break else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") + self.__init_handle_by_constructor__(constructor, value) + diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 0fb7ebee6f76..51fdb2e37023 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -17,7 +17,6 @@ * under the License. */ #include - #include namespace tvm { @@ -33,11 +32,16 @@ LiteralDoc::LiteralDoc(ObjectRef value) { this->data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(LiteralDocNode); -// Underscore is added to avoid syntax error in Python FFI binding -TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.None_").set_body_typed(LiteralDoc::None); -TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.Int").set_body_typed(LiteralDoc::Int); -TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.Float").set_body_typed(LiteralDoc::Float); -TVM_REGISTER_GLOBAL("script.printer.LiteralDoc.Str").set_body_typed(LiteralDoc::Str); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") + .set_body_typed(static_cast(LiteralDoc::Int)); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocIntImm") + .set_body_typed(static_cast(LiteralDoc::Int)); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") + .set_body_typed(static_cast(LiteralDoc::Float)); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloatImm") + .set_body_typed(static_cast(LiteralDoc::Float)); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); } // namespace printer } // namespace script diff --git a/tests/python/unittest/test_tvmscript_unified_printer_doc.py b/tests/python/unittest/test_tvmscript_unified_printer_doc.py new file mode 100644 index 000000000000..a3b3b5b653bd --- /dev/null +++ b/tests/python/unittest/test_tvmscript_unified_printer_doc.py @@ -0,0 +1,37 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.tir import FloatImm, IntImm +from tvm.script.printer.doc import LiteralDoc + + +@pytest.mark.parametrize("value", [ + None, + "test", + 1, + 1.5, + FloatImm("float32", 3.2), + IntImm("int8", 5) +]) +def test_literal_doc_construction(value): + doc = LiteralDoc(value) + if isinstance(value, float): + # FloatImm isn't unpacked to Python's float automatically + assert float(doc.value) == pytest.approx(value) + else: + assert doc.value == value From da5ee17d0196e07012f694e9a1aaa078acf49469 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 13:37:45 -0400 Subject: [PATCH 03/27] Add doc printer --- include/tvm/script/printer/doc.h | 2 - include/tvm/script/printer/doc_printer.h | 70 ++++++++++++ python/tvm/script/printer/__init__.py | 19 ++++ python/tvm/script/printer/doc.py | 4 - src/script/printer/doc.cc | 10 +- src/script/printer/doc_printer.cc | 52 +++++++++ src/script/printer/python_doc_printer.cc | 101 ++++++++++++++++++ .../test_tvmscript_python_doc_printer.py | 74 +++++++++++++ .../test_tvmscript_unified_printer_doc.py | 3 - 9 files changed, 318 insertions(+), 17 deletions(-) create mode 100644 include/tvm/script/printer/doc_printer.h create mode 100644 python/tvm/script/printer/__init__.py create mode 100644 src/script/printer/doc_printer.cc create mode 100644 src/script/printer/python_doc_printer.cc create mode 100644 tests/python/unittest/test_tvmscript_python_doc_printer.py diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 5d74b10b9e10..4cb69785586a 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -132,14 +132,12 @@ class LiteralDoc : public ExprDoc { * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ - static LiteralDoc Int(const IntImm& v) { return LiteralDoc(v); } static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ - static LiteralDoc Float(const FloatImm& v) { return LiteralDoc(v); } static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } /*! diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h new file mode 100644 index 000000000000..ee84dbc12bda --- /dev/null +++ b/include/tvm/script/printer/doc_printer.h @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ + +#include + +namespace tvm { +namespace script { +namespace printer { + +struct DocPrinterOptions { + int indent_spaces = 4; +}; + +class DocPrinter { + public: + explicit DocPrinter(const DocPrinterOptions& options); + virtual ~DocPrinter() = default; + + void Append(const Doc& doc); + String GetString() const; + + protected: + void PrintDoc(const Doc& doc); + + virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; + + using OutputStream = std::ostringstream; + + void IncreaseIndent() { indent_ += options_.indent_spaces; } + + void DecreaseIndent() { indent_ -= options_.indent_spaces; } + + OutputStream& NewLine() { + output_ << "\n"; + output_ << std::string(indent_, ' '); + return output_; + } + + OutputStream output_; + + private: + DocPrinterOptions options_; + int indent_ = 0; +}; + +std::unique_ptr GetPythonDocPrinter(const DocPrinterOptions& options); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py new file mode 100644 index 000000000000..12bcecca6d5a --- /dev/null +++ b/python/tvm/script/printer/__init__.py @@ -0,0 +1,19 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from . import _ffi_api + diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 2bcc7e279802..a8e0b3565182 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -18,7 +18,6 @@ import tvm._ffi from tvm.runtime import Object -from tvm.tir import FloatImm, IntImm from . import _ffi_api @@ -35,8 +34,6 @@ class ExprDoc(Object): (str, _ffi_api.LiteralDocStr), (float, _ffi_api.LiteralDocFloat), (int, _ffi_api.LiteralDocInt), - (IntImm, _ffi_api.LiteralDocIntImm), - (FloatImm, _ffi_api.LiteralDocFloatImm) ] @@ -56,4 +53,3 @@ def __init__(self, value): raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") self.__init_handle_by_constructor__(constructor, value) - diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 51fdb2e37023..333d8a3ab7ed 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -33,14 +33,8 @@ LiteralDoc::LiteralDoc(ObjectRef value) { } TVM_REGISTER_NODE_TYPE(LiteralDocNode); TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") - .set_body_typed(static_cast(LiteralDoc::Int)); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocIntImm") - .set_body_typed(static_cast(LiteralDoc::Int)); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") - .set_body_typed(static_cast(LiteralDoc::Float)); -TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloatImm") - .set_body_typed(static_cast(LiteralDoc::Float)); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); } // namespace printer diff --git a/src/script/printer/doc_printer.cc b/src/script/printer/doc_printer.cc new file mode 100644 index 000000000000..6b48c1c3ead0 --- /dev/null +++ b/src/script/printer/doc_printer.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace script { +namespace printer { + +DocPrinter::DocPrinter(const DocPrinterOptions& options) : options_(options) {} + +void DocPrinter::Append(const Doc& doc) { + PrintDoc(doc); +} + +String DocPrinter::GetString() const { + std::string text = output_.str(); + if (!text.empty() && text.back() != '\n') { + text.push_back('\n'); + } + return text; +} + +void DocPrinter::PrintDoc(const Doc& doc) { + if (const auto* doc_node = doc.as()) { + PrintTypedDoc(GetRef(doc_node)); + } else { + LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey(); + throw; + } +} + + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc new file mode 100644 index 000000000000..73db9a7c4411 --- /dev/null +++ b/src/script/printer/python_doc_printer.cc @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +namespace { + +void PrintLiteralString(const String& string, std::ostringstream& out) { + // TODO: Escape and smart quote (choose ' or " automatically) + out << "\"" << string << "\""; +} + +void PrintLiteralPrimExpr(const PrimExpr& expr, std::ostringstream& out) { + const DataType& dtype = expr->dtype; + + if (dtype == DataType::Int(64)) { + out << Downcast(expr)->value; + } else if (dtype == DataType::Float(64)) { + // TODO: make the float printing roundtrippable + std::ostringstream number_value; + number_value.precision(17); + number_value << Downcast(expr)->value; + out << number_value.str(); + } else if (dtype == DataType::Bool()) { + out << (Downcast(expr)->value ? "True" : "False"); + } else { + LOG(FATAL) << "Cannot print value with dtype " << dtype << " as literal expression"; + } +} + +} // namespace + +class PythonDocPrinter : public DocPrinter { + public: + PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} + + protected: + using DocPrinter::PrintDoc; + + void PrintTypedDoc(const LiteralDoc& doc) final; + + private: + template + std::enable_if_t::value, void> PrintObject(const ObjType& obj) { + PrintDoc(obj); + } + + template + std::enable_if_t::value, void> PrintObject(const ObjType& obj) { + output_ << obj; + } +}; + +void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { + const ObjectRef& value = doc->value; + if (!value.defined()) { + output_ << "None"; + } else if (const auto* expr_node = value.as()) { + PrintLiteralPrimExpr(GetRef(expr_node), output_); + } else if (const auto* string_obj = value.as()) { + PrintLiteralString(GetRef(string_obj), output_); + } else { + LOG(FATAL) << "Unsupported literal value type " << value->GetTypeKey(); + } +} + +std::unique_ptr GetPythonDocPrinter(const DocPrinterOptions& options) { + return std::make_unique(options); +} + +TVM_REGISTER_GLOBAL("script.printer.PrintDocAsPython") + .set_body_typed([](Doc doc, int indent_spaces = 4) { + PythonDocPrinter printer({.indent_spaces = indent_spaces}); + printer.Append(doc); + return printer.GetString(); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_python_doc_printer.py b/tests/python/unittest/test_tvmscript_python_doc_printer.py new file mode 100644 index 000000000000..276b23c3bc03 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_python_doc_printer.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.script.printer import _ffi_api +from tvm.script.printer.doc import LiteralDoc + + +def format_script(s: str) -> str: + """ + Remove leading and trailing blank lines, and make the minimum idention 0 + """ + s = s.strip("\n") + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] + spaces_to_remove = min(line_indents) + return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + + +def print_doc_as_python(doc, indent_spaces=4): + return format_script(_ffi_api.PrintDocAsPython(doc, indent_spaces)) + + +@pytest.mark.parametrize("doc,expected", [ + ( + LiteralDoc(None), + "None" + ), + ( + LiteralDoc("test"), + '"test"' + ), + ( + LiteralDoc(""), + '""' + ), + # TODO: make the string printing add character escaping + pytest.param(LiteralDoc("\""), r'"\""', marks=pytest.mark.xfail), + ( + LiteralDoc(0), + "0" + ), + ( + LiteralDoc(-1), + "-1" + ), + ( + LiteralDoc(3.25), + "3.25" + ), + ( + LiteralDoc(-0.5), + "-0.5" + ), + # TODO: make the float number printing preserve percision and roundtrippable + pytest.param(LiteralDoc(0.0), "0.0", marks=pytest.mark.xfail), + pytest.param(LiteralDoc(3.14), "3.14", marks=pytest.mark.xfail) +]) +def test_print_literal_doc(doc, expected): + assert print_doc_as_python(doc) == format_script(expected) diff --git a/tests/python/unittest/test_tvmscript_unified_printer_doc.py b/tests/python/unittest/test_tvmscript_unified_printer_doc.py index a3b3b5b653bd..003386f2cb5c 100644 --- a/tests/python/unittest/test_tvmscript_unified_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_unified_printer_doc.py @@ -16,7 +16,6 @@ # under the License. import pytest -from tvm.tir import FloatImm, IntImm from tvm.script.printer.doc import LiteralDoc @@ -25,8 +24,6 @@ "test", 1, 1.5, - FloatImm("float32", 3.2), - IntImm("int8", 5) ]) def test_literal_doc_construction(value): doc = LiteralDoc(value) From 3525a7084546e9f0717445cfa28e2274bbe8bab7 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 14:18:21 -0400 Subject: [PATCH 04/27] Add documentation --- include/tvm/script/printer/doc_printer.h | 75 ++++++++++++++++++++++++ python/tvm/script/printer/__init__.py | 9 ++- src/script/printer/python_doc_printer.cc | 25 ++++---- 3 files changed, 97 insertions(+), 12 deletions(-) diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h index ee84dbc12bda..e90749846b86 100644 --- a/include/tvm/script/printer/doc_printer.h +++ b/include/tvm/script/printer/doc_printer.h @@ -25,42 +25,117 @@ namespace tvm { namespace script { namespace printer { +/*! + * \brief Configurable options for DocPrinter + * + * \sa DocPrinter + */ struct DocPrinterOptions { int indent_spaces = 4; }; +/*! + * \brief DocPrinter is responsible for printing Doc tree into text format + * \details This is the base class for translating Doc into string. + * Each target language needs to have its subclass of DocPrinter + * to define the actual logic of printing Doc. + * + * \sa Doc + */ class DocPrinter { public: + /*! + * \brief The constructor of DocPrinter + * + * \param options the option for printer + */ explicit DocPrinter(const DocPrinterOptions& options); virtual ~DocPrinter() = default; + /*! + * \brief Append a doc into the final content + * + * \param doc the Doc to be printed + * + * \sa GetString + */ void Append(const Doc& doc); + + /*! + * \brief Get the printed string of all Doc appended + * + * The content of each Doc in the returned string will + * appear in the same order as they are appended. + * + * \sa Append + */ String GetString() const; protected: + /*! + * \brief Get the printed string + * + * It will dispatch to the PrintTypedDoc method based on + * the actual type of Doc. + * + * \sa PrintTypedDoc + */ void PrintDoc(const Doc& doc); + /*! + * \brief Virtual method to print a LiteralDoc + */ virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; using OutputStream = std::ostringstream; + /*! + * \brief Increase the indent level of any content to be + * printed after this call + */ void IncreaseIndent() { indent_ += options_.indent_spaces; } + /*! + * \brief Decrease the indent level of any content to be + * printed after this call + */ void DecreaseIndent() { indent_ -= options_.indent_spaces; } + /*! + * \brief Add a new line into the output stream + * + * \sa output_ + */ OutputStream& NewLine() { output_ << "\n"; output_ << std::string(indent_, ' '); return output_; } + /*! + * \brief The output stream of printer + * + * All printed content will be stored in this stream and returned + * when GetString is called. + * + * \sa GetString + */ OutputStream output_; private: + /*! \brief the printer options */ DocPrinterOptions options_; + + /*! \brief the current level of indent */ int indent_ = 0; }; +/*! + * \brief Get a doc printer to print Doc into Python code + * + * \param options the option for printer + * \return A pointer to the printer + */ std::unique_ptr GetPythonDocPrinter(const DocPrinterOptions& options); } // namespace printer diff --git a/python/tvm/script/printer/__init__.py b/python/tvm/script/printer/__init__.py index 12bcecca6d5a..84ab7b0ba836 100644 --- a/python/tvm/script/printer/__init__.py +++ b/python/tvm/script/printer/__init__.py @@ -14,6 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +""" +TVMScript Unified Printer -from . import _ffi_api +This package provides a set of APIs to print supported TVM IR into TVMScript +in a roundtrippable way. + +https://github.com/apache/tvm-rfcs/blob/main/rfcs/0074-tvmscript-unified-printer.md +""" +from . import _ffi_api diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 73db9a7c4411..3bdffa6667ba 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -26,11 +26,25 @@ namespace printer { namespace { +/*! + * \brief Print a Python literal string + * + * \param string the string to be printed + * \param out the output stream + */ void PrintLiteralString(const String& string, std::ostringstream& out) { // TODO: Escape and smart quote (choose ' or " automatically) out << "\"" << string << "\""; } +/*! + * \brief Print a tvm::ir::PrimExpr as Python literal + * + * This only supports IntImm and FloatImm with size of 64 bits + * + * \param expr the PrimExpr to be printed + * \param out the output stream + */ void PrintLiteralPrimExpr(const PrimExpr& expr, std::ostringstream& out) { const DataType& dtype = expr->dtype; @@ -59,17 +73,6 @@ class PythonDocPrinter : public DocPrinter { using DocPrinter::PrintDoc; void PrintTypedDoc(const LiteralDoc& doc) final; - - private: - template - std::enable_if_t::value, void> PrintObject(const ObjType& obj) { - PrintDoc(obj); - } - - template - std::enable_if_t::value, void> PrintObject(const ObjType& obj) { - output_ << obj; - } }; void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { From ae2a7c15fc116838899462b137ce68b5e1aa03fd Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 14:19:58 -0400 Subject: [PATCH 05/27] Rename test file --- ...cript_unified_printer_doc.py => test_tvmscript_printer_doc.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/unittest/{test_tvmscript_unified_printer_doc.py => test_tvmscript_printer_doc.py} (100%) diff --git a/tests/python/unittest/test_tvmscript_unified_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py similarity index 100% rename from tests/python/unittest/test_tvmscript_unified_printer_doc.py rename to tests/python/unittest/test_tvmscript_printer_doc.py From 0ca5dc54c095e5d71013dcd0193d2b3db68b03fc Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 14:30:23 -0400 Subject: [PATCH 06/27] Format code --- include/tvm/script/printer/doc.h | 5 +- include/tvm/script/printer/doc_printer.h | 15 +++--- src/script/printer/doc_printer.cc | 5 +- src/script/printer/python_doc_printer.cc | 4 +- .../test_tvmscript_python_doc_printer.py | 52 ++++++------------- 5 files changed, 32 insertions(+), 49 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 4cb69785586a..7e860f04f992 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -21,6 +21,7 @@ #include #include + #include "tvm/runtime/data_type.h" namespace tvm { @@ -140,7 +141,7 @@ class LiteralDoc : public ExprDoc { */ static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } - /*! + /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ @@ -153,4 +154,4 @@ class LiteralDoc : public ExprDoc { } // namespace script } // namespace tvm -#endif +#endif // TVM_SCRIPT_PRINTER_DOC_H_ diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h index e90749846b86..0a61df5b6707 100644 --- a/include/tvm/script/printer/doc_printer.h +++ b/include/tvm/script/printer/doc_printer.h @@ -21,6 +21,9 @@ #include +#include +#include + namespace tvm { namespace script { namespace printer { @@ -64,7 +67,7 @@ class DocPrinter { /*! * \brief Get the printed string of all Doc appended * - * The content of each Doc in the returned string will + * The content of each Doc in the returned string will * appear in the same order as they are appended. * * \sa Append @@ -75,7 +78,7 @@ class DocPrinter { /*! * \brief Get the printed string * - * It will dispatch to the PrintTypedDoc method based on + * It will dispatch to the PrintTypedDoc method based on * the actual type of Doc. * * \sa PrintTypedDoc @@ -90,13 +93,13 @@ class DocPrinter { using OutputStream = std::ostringstream; /*! - * \brief Increase the indent level of any content to be + * \brief Increase the indent level of any content to be * printed after this call */ void IncreaseIndent() { indent_ += options_.indent_spaces; } /*! - * \brief Decrease the indent level of any content to be + * \brief Decrease the indent level of any content to be * printed after this call */ void DecreaseIndent() { indent_ -= options_.indent_spaces; } @@ -115,7 +118,7 @@ class DocPrinter { /*! * \brief The output stream of printer * - * All printed content will be stored in this stream and returned + * All printed content will be stored in this stream and returned * when GetString is called. * * \sa GetString @@ -142,4 +145,4 @@ std::unique_ptr GetPythonDocPrinter(const DocPrinterOptions& options } // namespace script } // namespace tvm -#endif +#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ diff --git a/src/script/printer/doc_printer.cc b/src/script/printer/doc_printer.cc index 6b48c1c3ead0..cbd760f55a22 100644 --- a/src/script/printer/doc_printer.cc +++ b/src/script/printer/doc_printer.cc @@ -25,9 +25,7 @@ namespace printer { DocPrinter::DocPrinter(const DocPrinterOptions& options) : options_(options) {} -void DocPrinter::Append(const Doc& doc) { - PrintDoc(doc); -} +void DocPrinter::Append(const Doc& doc) { PrintDoc(doc); } String DocPrinter::GetString() const { std::string text = output_.str(); @@ -46,7 +44,6 @@ void DocPrinter::PrintDoc(const Doc& doc) { } } - } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 3bdffa6667ba..a97fc1bed3ae 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -33,7 +33,7 @@ namespace { * \param out the output stream */ void PrintLiteralString(const String& string, std::ostringstream& out) { - // TODO: Escape and smart quote (choose ' or " automatically) + // TODO(yelite): Escape and smart quote (choose ' or " automatically) out << "\"" << string << "\""; } @@ -51,7 +51,7 @@ void PrintLiteralPrimExpr(const PrimExpr& expr, std::ostringstream& out) { if (dtype == DataType::Int(64)) { out << Downcast(expr)->value; } else if (dtype == DataType::Float(64)) { - // TODO: make the float printing roundtrippable + // TODO(yelite): make the float printing roundtrippable std::ostringstream number_value; number_value.precision(17); number_value << Downcast(expr)->value; diff --git a/tests/python/unittest/test_tvmscript_python_doc_printer.py b/tests/python/unittest/test_tvmscript_python_doc_printer.py index 276b23c3bc03..9732ff8c97c5 100644 --- a/tests/python/unittest/test_tvmscript_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_python_doc_printer.py @@ -35,40 +35,22 @@ def print_doc_as_python(doc, indent_spaces=4): return format_script(_ffi_api.PrintDocAsPython(doc, indent_spaces)) -@pytest.mark.parametrize("doc,expected", [ - ( - LiteralDoc(None), - "None" - ), - ( - LiteralDoc("test"), - '"test"' - ), - ( - LiteralDoc(""), - '""' - ), - # TODO: make the string printing add character escaping - pytest.param(LiteralDoc("\""), r'"\""', marks=pytest.mark.xfail), - ( - LiteralDoc(0), - "0" - ), - ( - LiteralDoc(-1), - "-1" - ), - ( - LiteralDoc(3.25), - "3.25" - ), - ( - LiteralDoc(-0.5), - "-0.5" - ), - # TODO: make the float number printing preserve percision and roundtrippable - pytest.param(LiteralDoc(0.0), "0.0", marks=pytest.mark.xfail), - pytest.param(LiteralDoc(3.14), "3.14", marks=pytest.mark.xfail) -]) +@pytest.mark.parametrize( + "doc,expected", + [ + (LiteralDoc(None), "None"), + (LiteralDoc("test"), '"test"'), + (LiteralDoc(""), '""'), + # TODO: make the string printing add character escaping + pytest.param(LiteralDoc('"'), r'"\""', marks=pytest.mark.xfail), + (LiteralDoc(0), "0"), + (LiteralDoc(-1), "-1"), + (LiteralDoc(3.25), "3.25"), + (LiteralDoc(-0.5), "-0.5"), + # TODO: make the float number printing preserve percision and roundtrippable + pytest.param(LiteralDoc(0.0), "0.0", marks=pytest.mark.xfail), + pytest.param(LiteralDoc(3.14), "3.14", marks=pytest.mark.xfail), + ], +) def test_print_literal_doc(doc, expected): assert print_doc_as_python(doc) == format_script(expected) From 3baec2122ea7287cdce9716bd5d7f8a306ff07c6 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 14:30:30 -0400 Subject: [PATCH 07/27] Add more test cases for literal doc --- .../unittest/test_tvmscript_printer_doc.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 003386f2cb5c..d98d036fb33b 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -19,12 +19,19 @@ from tvm.script.printer.doc import LiteralDoc -@pytest.mark.parametrize("value", [ - None, - "test", - 1, - 1.5, -]) +@pytest.mark.parametrize( + "value", + [ + None, + "test", + 0, + 1, + -2, + 0.0, + 1.5, + -1.3 + ], +) def test_literal_doc_construction(value): doc = LiteralDoc(value) if isinstance(value, float): From 5e8522d2ec7badb1fc997b423bb914758551f935 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 16:13:46 -0400 Subject: [PATCH 08/27] Format Python code --- tests/python/unittest/test_tvmscript_printer_doc.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index d98d036fb33b..3fd8f25473a3 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -21,16 +21,7 @@ @pytest.mark.parametrize( "value", - [ - None, - "test", - 0, - 1, - -2, - 0.0, - 1.5, - -1.3 - ], + [None, "test", 0, 1, -2, 0.0, 1.5, -1.3], ) def test_literal_doc_construction(value): doc = LiteralDoc(value) From 11d6a514bc87483afb973e0d622dd426557c0370 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 30 Jun 2022 17:34:44 -0400 Subject: [PATCH 09/27] Remove type alias --- include/tvm/script/printer/doc_printer.h | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h index 0a61df5b6707..32a366d36782 100644 --- a/include/tvm/script/printer/doc_printer.h +++ b/include/tvm/script/printer/doc_printer.h @@ -21,8 +21,9 @@ #include -#include #include +#include +#include namespace tvm { namespace script { @@ -90,8 +91,6 @@ class DocPrinter { */ virtual void PrintTypedDoc(const LiteralDoc& doc) = 0; - using OutputStream = std::ostringstream; - /*! * \brief Increase the indent level of any content to be * printed after this call @@ -109,7 +108,7 @@ class DocPrinter { * * \sa output_ */ - OutputStream& NewLine() { + std::ostream& NewLine() { output_ << "\n"; output_ << std::string(indent_, ' '); return output_; @@ -123,7 +122,7 @@ class DocPrinter { * * \sa GetString */ - OutputStream output_; + std::ostringstream output_; private: /*! \brief the printer options */ From 93a044b30c31d69bb022e3ea81ddcc2c224d2e9b Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 1 Jul 2022 16:57:21 -0400 Subject: [PATCH 10/27] Fix name convention --- python/tvm/script/printer/doc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index a8e0b3565182..f3ab2573556b 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -30,7 +30,7 @@ class ExprDoc(Object): """Base class of all expression Docs""" -_literal_constructors = [ +LITERAL_CONSTRUCTORS = [ (str, _ffi_api.LiteralDocStr), (float, _ffi_api.LiteralDocFloat), (int, _ffi_api.LiteralDocInt), @@ -46,7 +46,7 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) return - for (cls, constructor) in _literal_constructors: + for (cls, constructor) in LITERAL_CONSTRUCTORS: if isinstance(value, cls): break else: From 9d297e10260bef4700cd159ad3df46c93fccf575 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 1 Jul 2022 16:59:07 -0400 Subject: [PATCH 11/27] Move doc printer to private headers --- src/script/printer/doc_printer.cc | 2 +- {include/tvm => src}/script/printer/doc_printer.h | 0 src/script/printer/python_doc_printer.cc | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) rename {include/tvm => src}/script/printer/doc_printer.h (100%) diff --git a/src/script/printer/doc_printer.cc b/src/script/printer/doc_printer.cc index cbd760f55a22..888f3e3ed1e9 100644 --- a/src/script/printer/doc_printer.cc +++ b/src/script/printer/doc_printer.cc @@ -17,7 +17,7 @@ * under the License. */ -#include +#include "./doc_printer.h" namespace tvm { namespace script { diff --git a/include/tvm/script/printer/doc_printer.h b/src/script/printer/doc_printer.h similarity index 100% rename from include/tvm/script/printer/doc_printer.h rename to src/script/printer/doc_printer.h diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index a97fc1bed3ae..52c9731e8092 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -18,7 +18,8 @@ */ #include -#include + +#include "./doc_printer.h" namespace tvm { namespace script { From e194332b96b9fcd0dab9f9f6d5c972e43d6b27fd Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 11:21:18 -0400 Subject: [PATCH 12/27] Fix doc and include --- include/tvm/script/printer/doc.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 7e860f04f992..c7a74bfcf4e2 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -21,8 +21,7 @@ #include #include - -#include "tvm/runtime/data_type.h" +#include namespace tvm { namespace script { @@ -100,8 +99,11 @@ class LiteralDocNode : public ExprDocNode { /*! * \brief the internal representation of the literal value. * - * The actual type is union of IntImm, FloatImm and String, or a - * null ObjectRef. + * Possible actual types: + * - IntImm (integer or boolean) + * - FloatImm + * - String + * - null */ ObjectRef value; From d6173107fa0050fe465a1dfe8d951c344cd2a27b Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 11:21:44 -0400 Subject: [PATCH 13/27] Remove indirections from LiteralDoc Python constructor --- python/tvm/script/printer/doc.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index f3ab2573556b..d59f13d838f7 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -30,26 +30,18 @@ class ExprDoc(Object): """Base class of all expression Docs""" -LITERAL_CONSTRUCTORS = [ - (str, _ffi_api.LiteralDocStr), - (float, _ffi_api.LiteralDocFloat), - (int, _ffi_api.LiteralDocInt), -] - - @tvm._ffi.register_object("script.printer.LiteralDoc") class LiteralDoc(ExprDoc): """Doc that represents literal value""" def __init__(self, value): - if value is None: - self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) - return - - for (cls, constructor) in LITERAL_CONSTRUCTORS: - if isinstance(value, cls): - break + if isinstance(value, str): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr(), value) # type: ignore + elif isinstance(value, float): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat(), value) # type: ignore + elif isinstance(value, int): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt(), value) # type: ignore + elif value is None: + self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone()) # type: ignore else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") - - self.__init_handle_by_constructor__(constructor, value) From 2701bb0572baa3d83e45d19508f65315c5e22064 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 11:21:52 -0400 Subject: [PATCH 14/27] Add printer package to mypy check --- tests/scripts/task_mypy.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/scripts/task_mypy.sh b/tests/scripts/task_mypy.sh index 1ef7db589432..f165adfe1bc4 100755 --- a/tests/scripts/task_mypy.sh +++ b/tests/scripts/task_mypy.sh @@ -32,6 +32,9 @@ mypy --check-untyped-defs python/tvm/tir/analysis/ echo "Checking MyPy Type defs in the transform package." mypy --check-untyped-defs python/tvm/tir/transform/ +echo "Checking MyPy Type defs in the tvmscript printer package." +mypy --check-untyped-defs python/tvm/script/printer + echo "Checking MyPy Type defs in the TIR package with unittest" MYPYPATH=$TVM_PATH/python mypy --check-untyped-defs tests/python/unittest/test_tvmscript_type.py From 64d4814b65120e5177bf76bf12c1af24c8a4f8de Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 11:29:47 -0400 Subject: [PATCH 15/27] Move around registration --- src/script/printer/doc.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 333d8a3ab7ed..88b21445567c 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -23,14 +23,14 @@ namespace tvm { namespace script { namespace printer { -TVM_REGISTER_NODE_TYPE(DocNode); -TVM_REGISTER_NODE_TYPE(ExprDocNode); - LiteralDoc::LiteralDoc(ObjectRef value) { ObjectPtr n = make_object(); n->value = value; this->data_ = std::move(n); } + +TVM_REGISTER_NODE_TYPE(DocNode); +TVM_REGISTER_NODE_TYPE(ExprDocNode); TVM_REGISTER_NODE_TYPE(LiteralDocNode); TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); From b6e548b61648c57982a1d3af1361b263660c8b87 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 11:33:39 -0400 Subject: [PATCH 16/27] Fix typo in doc.py --- python/tvm/script/printer/doc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index d59f13d838f7..8339e62120f8 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -36,12 +36,12 @@ class LiteralDoc(ExprDoc): def __init__(self, value): if isinstance(value, str): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr(), value) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore elif isinstance(value, float): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat(), value) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore elif isinstance(value, int): - self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt(), value) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore elif value is None: - self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone()) # type: ignore + self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") From 52787c5c4dc29b2f91174a4686e1ebe39241b4d6 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 11:52:01 -0400 Subject: [PATCH 17/27] Add string escape --- src/script/printer/python_doc_printer.cc | 4 ++-- tests/python/unittest/test_tvmscript_python_doc_printer.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 52c9731e8092..5ac4ef52c2be 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -19,6 +19,7 @@ #include +#include "../../support/str_escape.h" #include "./doc_printer.h" namespace tvm { @@ -34,8 +35,7 @@ namespace { * \param out the output stream */ void PrintLiteralString(const String& string, std::ostringstream& out) { - // TODO(yelite): Escape and smart quote (choose ' or " automatically) - out << "\"" << string << "\""; + out << "\"" << support::StrEscape(string->data, string->size) << "\""; } /*! diff --git a/tests/python/unittest/test_tvmscript_python_doc_printer.py b/tests/python/unittest/test_tvmscript_python_doc_printer.py index 9732ff8c97c5..ae7d61f33d68 100644 --- a/tests/python/unittest/test_tvmscript_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_python_doc_printer.py @@ -41,8 +41,10 @@ def print_doc_as_python(doc, indent_spaces=4): (LiteralDoc(None), "None"), (LiteralDoc("test"), '"test"'), (LiteralDoc(""), '""'), - # TODO: make the string printing add character escaping - pytest.param(LiteralDoc('"'), r'"\""', marks=pytest.mark.xfail), + (LiteralDoc('""'), r'"\"\""'), + (LiteralDoc('\n\t\\test\r'), r'"\n\t\\test\r"'), + # TODO: make the roundatrippable problem caused by utf8 + pytest.param(LiteralDoc('\x88'), r'"\x88"', marks=pytest.mark.xfail), (LiteralDoc(0), "0"), (LiteralDoc(-1), "-1"), (LiteralDoc(3.25), "3.25"), From 5bf12034fd5bad67d2a73f8ab85093e66d73eee3 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 12:25:53 -0400 Subject: [PATCH 18/27] Add boolean support and reorganize code --- include/tvm/script/printer/doc.h | 6 ++ python/tvm/script/printer/doc.py | 2 + src/script/printer/doc.cc | 1 + src/script/printer/python_doc_printer.cc | 56 ++++--------------- .../unittest/test_tvmscript_printer_doc.py | 5 +- .../test_tvmscript_python_doc_printer.py | 2 + 6 files changed, 26 insertions(+), 46 deletions(-) diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index c7a74bfcf4e2..67c27bd45a1d 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -137,6 +137,12 @@ class LiteralDoc : public ExprDoc { */ static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + /*! + * \brief Create a LiteralDoc to represent boolean. + * \param v The boolean value. + */ + static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 8339e62120f8..96a5efcbebb3 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -39,6 +39,8 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore elif isinstance(value, float): self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore + elif isinstance(value, bool): + self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore elif isinstance(value, int): self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore elif value is None: diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index 88b21445567c..e54adbd36b4c 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -34,6 +34,7 @@ TVM_REGISTER_NODE_TYPE(ExprDocNode); TVM_REGISTER_NODE_TYPE(LiteralDocNode); TVM_REGISTER_GLOBAL("script.printer.LiteralDocNone").set_body_typed(LiteralDoc::None); TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt").set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean").set_body_typed(LiteralDoc::Boolean); TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat").set_body_typed(LiteralDoc::Float); TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr").set_body_typed(LiteralDoc::Str); diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 5ac4ef52c2be..87b391aecab3 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -26,46 +26,6 @@ namespace tvm { namespace script { namespace printer { -namespace { - -/*! - * \brief Print a Python literal string - * - * \param string the string to be printed - * \param out the output stream - */ -void PrintLiteralString(const String& string, std::ostringstream& out) { - out << "\"" << support::StrEscape(string->data, string->size) << "\""; -} - -/*! - * \brief Print a tvm::ir::PrimExpr as Python literal - * - * This only supports IntImm and FloatImm with size of 64 bits - * - * \param expr the PrimExpr to be printed - * \param out the output stream - */ -void PrintLiteralPrimExpr(const PrimExpr& expr, std::ostringstream& out) { - const DataType& dtype = expr->dtype; - - if (dtype == DataType::Int(64)) { - out << Downcast(expr)->value; - } else if (dtype == DataType::Float(64)) { - // TODO(yelite): make the float printing roundtrippable - std::ostringstream number_value; - number_value.precision(17); - number_value << Downcast(expr)->value; - out << number_value.str(); - } else if (dtype == DataType::Bool()) { - out << (Downcast(expr)->value ? "True" : "False"); - } else { - LOG(FATAL) << "Cannot print value with dtype " << dtype << " as literal expression"; - } -} - -} // namespace - class PythonDocPrinter : public DocPrinter { public: PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} @@ -80,12 +40,20 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { const ObjectRef& value = doc->value; if (!value.defined()) { output_ << "None"; - } else if (const auto* expr_node = value.as()) { - PrintLiteralPrimExpr(GetRef(expr_node), output_); + } else if (const auto* int_imm = value.as()) { + if (int_imm->dtype.is_bool()) { + output_ << (int_imm->value ? "True" : "False"); + } else { + output_ << int_imm->value; + } + } else if (const auto* float_imm = value.as()) { + // TODO(yelite): Make float number printing roundtrippable + output_.precision(17); + output_ << float_imm->value; } else if (const auto* string_obj = value.as()) { - PrintLiteralString(GetRef(string_obj), output_); + output_ << "\"" << support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { - LOG(FATAL) << "Unsupported literal value type " << value->GetTypeKey(); + LOG(FATAL) << "TypeError: Unsupported literal value type: " << value->GetTypeKey(); } } diff --git a/tests/python/unittest/test_tvmscript_printer_doc.py b/tests/python/unittest/test_tvmscript_printer_doc.py index 3fd8f25473a3..6330d33bf25a 100644 --- a/tests/python/unittest/test_tvmscript_printer_doc.py +++ b/tests/python/unittest/test_tvmscript_printer_doc.py @@ -16,17 +16,18 @@ # under the License. import pytest +from tvm.tir import IntImm from tvm.script.printer.doc import LiteralDoc @pytest.mark.parametrize( "value", - [None, "test", 0, 1, -2, 0.0, 1.5, -1.3], + [None, "test", 0, 1, -2, 0.0, 1.5, -1.3, True, False], ) def test_literal_doc_construction(value): doc = LiteralDoc(value) if isinstance(value, float): - # FloatImm isn't unpacked to Python's float automatically + # FloatImm cannot be compared with Python's float directly assert float(doc.value) == pytest.approx(value) else: assert doc.value == value diff --git a/tests/python/unittest/test_tvmscript_python_doc_printer.py b/tests/python/unittest/test_tvmscript_python_doc_printer.py index ae7d61f33d68..8b9ecffad0e9 100644 --- a/tests/python/unittest/test_tvmscript_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_python_doc_printer.py @@ -39,6 +39,8 @@ def print_doc_as_python(doc, indent_spaces=4): "doc,expected", [ (LiteralDoc(None), "None"), + (LiteralDoc(True), "True"), + (LiteralDoc(False), "False"), (LiteralDoc("test"), '"test"'), (LiteralDoc(""), '""'), (LiteralDoc('""'), r'"\"\""'), From 677adb1426eee73265449c7ce88245ebb0e73e73 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 12:26:26 -0400 Subject: [PATCH 19/27] Rename test file for better file name consistency --- ...oc_printer.py => test_tvmscript_printer_python_doc_printer.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/python/unittest/{test_tvmscript_python_doc_printer.py => test_tvmscript_printer_python_doc_printer.py} (100%) diff --git a/tests/python/unittest/test_tvmscript_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py similarity index 100% rename from tests/python/unittest/test_tvmscript_python_doc_printer.py rename to tests/python/unittest/test_tvmscript_printer_python_doc_printer.py From 93e24f9d1ba429a9a1c617b158ade461b3848f60 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 14:03:11 -0400 Subject: [PATCH 20/27] Fix lint problem --- src/script/printer/python_doc_printer.cc | 2 +- .../unittest/test_tvmscript_printer_python_doc_printer.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 87b391aecab3..c9baf60572b7 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -28,7 +28,7 @@ namespace printer { class PythonDocPrinter : public DocPrinter { public: - PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} + explicit PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} protected: using DocPrinter::PrintDoc; diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 8b9ecffad0e9..ba5ccccdabdb 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -44,9 +44,9 @@ def print_doc_as_python(doc, indent_spaces=4): (LiteralDoc("test"), '"test"'), (LiteralDoc(""), '""'), (LiteralDoc('""'), r'"\"\""'), - (LiteralDoc('\n\t\\test\r'), r'"\n\t\\test\r"'), + (LiteralDoc("\n\t\\test\r"), r'"\n\t\\test\r"'), # TODO: make the roundatrippable problem caused by utf8 - pytest.param(LiteralDoc('\x88'), r'"\x88"', marks=pytest.mark.xfail), + pytest.param(LiteralDoc("\x88"), r'"\x88"', marks=pytest.mark.xfail), (LiteralDoc(0), "0"), (LiteralDoc(-1), "-1"), (LiteralDoc(3.25), "3.25"), From 9daff554081f2342f270552fb487d9fa1a2ed5ce Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 14:39:41 -0400 Subject: [PATCH 21/27] Expose the entry function of printing Doc to Python script to public headers --- include/tvm/script/printer/doc_printer.h | 56 +++++++++++++++++++ python/tvm/script/printer/doc_printer.py | 39 +++++++++++++ .../{doc_printer.cc => base_doc_printer.cc} | 4 +- .../{doc_printer.h => base_doc_printer.h} | 30 +++------- src/script/printer/python_doc_printer.cc | 24 ++++---- ...st_tvmscript_printer_python_doc_printer.py | 8 +-- 6 files changed, 120 insertions(+), 41 deletions(-) create mode 100644 include/tvm/script/printer/doc_printer.h create mode 100644 python/tvm/script/printer/doc_printer.py rename src/script/printer/{doc_printer.cc => base_doc_printer.cc} (92%) rename src/script/printer/{doc_printer.h => base_doc_printer.h} (83%) diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h new file mode 100644 index 000000000000..cbe7ee214e2a --- /dev/null +++ b/include/tvm/script/printer/doc_printer.h @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ + +#include + +namespace tvm { +namespace script { +namespace printer { + +/*! + * \brief Configurable options for converting Doc into text format + */ +struct DocPrintingOptions { + int indent_spaces = 4; +}; + +/*! + * \brief Convert Doc into Python script. + * + * \param options the option for printer + */ +String DocToPythonScript(Doc doc, DocPrintingOptions options); + +/*! + * \brief Convert Doc into Python script. + * + * This function unpacks the DocPrinterOptions into function arguments + * to be FFI friendly. + * + * \param indent_spaces the number of spaces used for indention + */ +String DocToPythonScript(Doc doc, int indent_spaces = 4); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py new file mode 100644 index 000000000000..7f4765d86269 --- /dev/null +++ b/python/tvm/script/printer/doc_printer.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Functions to print doc into text format""" + +from . import _ffi_api +from .doc import Doc + + +def to_python_script(doc: Doc, indent_spaces: int = 4) -> str: + """ + Convert Doc into Python script. + + Parameters + ---------- + doc : Doc + The doc to convert into Python script + indent_spaces : int + The number of indent spaces to use in the output + + Returns + ------- + script : str + The text representation of Doc in Python syntax + """ + return _ffi_api.DocToScriptPython(doc, indent_spaces) # type: ignore diff --git a/src/script/printer/doc_printer.cc b/src/script/printer/base_doc_printer.cc similarity index 92% rename from src/script/printer/doc_printer.cc rename to src/script/printer/base_doc_printer.cc index 888f3e3ed1e9..ebf0c9576080 100644 --- a/src/script/printer/doc_printer.cc +++ b/src/script/printer/base_doc_printer.cc @@ -17,13 +17,13 @@ * under the License. */ -#include "./doc_printer.h" +#include "./base_doc_printer.h" namespace tvm { namespace script { namespace printer { -DocPrinter::DocPrinter(const DocPrinterOptions& options) : options_(options) {} +DocPrinter::DocPrinter(const DocPrintingOptions& options) : options_(options) {} void DocPrinter::Append(const Doc& doc) { PrintDoc(doc); } diff --git a/src/script/printer/doc_printer.h b/src/script/printer/base_doc_printer.h similarity index 83% rename from src/script/printer/doc_printer.h rename to src/script/printer/base_doc_printer.h index 32a366d36782..5b09da29231b 100644 --- a/src/script/printer/doc_printer.h +++ b/src/script/printer/base_doc_printer.h @@ -16,10 +16,11 @@ * specific language governing permissions and limitations * under the License. */ -#ifndef TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ -#define TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ +#ifndef TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ +#define TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ #include +#include #include #include @@ -29,15 +30,6 @@ namespace tvm { namespace script { namespace printer { -/*! - * \brief Configurable options for DocPrinter - * - * \sa DocPrinter - */ -struct DocPrinterOptions { - int indent_spaces = 4; -}; - /*! * \brief DocPrinter is responsible for printing Doc tree into text format * \details This is the base class for translating Doc into string. @@ -53,7 +45,7 @@ class DocPrinter { * * \param options the option for printer */ - explicit DocPrinter(const DocPrinterOptions& options); + explicit DocPrinter(const DocPrintingOptions& options); virtual ~DocPrinter() = default; /*! @@ -125,23 +117,15 @@ class DocPrinter { std::ostringstream output_; private: - /*! \brief the printer options */ - DocPrinterOptions options_; + /*! \brief the printing options */ + DocPrintingOptions options_; /*! \brief the current level of indent */ int indent_ = 0; }; -/*! - * \brief Get a doc printer to print Doc into Python code - * - * \param options the option for printer - * \return A pointer to the printer - */ -std::unique_ptr GetPythonDocPrinter(const DocPrinterOptions& options); - } // namespace printer } // namespace script } // namespace tvm -#endif // TVM_SCRIPT_PRINTER_DOC_PRINTER_H_ +#endif // TVM_SCRIPT_PRINTER_BASE_DOC_PRINTER_H_ diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index c9baf60572b7..2b3595651894 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -20,7 +20,7 @@ #include #include "../../support/str_escape.h" -#include "./doc_printer.h" +#include "./base_doc_printer.h" namespace tvm { namespace script { @@ -28,7 +28,7 @@ namespace printer { class PythonDocPrinter : public DocPrinter { public: - explicit PythonDocPrinter(const DocPrinterOptions& options) : DocPrinter(options) {} + explicit PythonDocPrinter(const DocPrintingOptions& options) : DocPrinter(options) {} protected: using DocPrinter::PrintDoc; @@ -57,16 +57,20 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } -std::unique_ptr GetPythonDocPrinter(const DocPrinterOptions& options) { - return std::make_unique(options); +String DocToPythonScript(Doc doc, DocPrintingOptions options) { + PythonDocPrinter printer(options); + printer.Append(doc); + return printer.GetString(); } -TVM_REGISTER_GLOBAL("script.printer.PrintDocAsPython") - .set_body_typed([](Doc doc, int indent_spaces = 4) { - PythonDocPrinter printer({.indent_spaces = indent_spaces}); - printer.Append(doc); - return printer.GetString(); - }); +String DocToPythonScript(Doc doc, int indent_spaces) { + DocPrintingOptions options; + options.indent_spaces = indent_spaces; + return DocToPythonScript(doc, options); +} + +TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript") + .set_body_typed(DocToPythonScript); } // namespace printer } // namespace script diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index ba5ccccdabdb..97949c666943 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -16,7 +16,7 @@ # under the License. import pytest -from tvm.script.printer import _ffi_api +from tvm.script.printer.doc_printer import to_python_script from tvm.script.printer.doc import LiteralDoc @@ -31,10 +31,6 @@ def format_script(s: str) -> str: return "\n".join(line[spaces_to_remove:] for line in s.splitlines()) -def print_doc_as_python(doc, indent_spaces=4): - return format_script(_ffi_api.PrintDocAsPython(doc, indent_spaces)) - - @pytest.mark.parametrize( "doc,expected", [ @@ -57,4 +53,4 @@ def print_doc_as_python(doc, indent_spaces=4): ], ) def test_print_literal_doc(doc, expected): - assert print_doc_as_python(doc) == format_script(expected) + assert to_python_script(doc) == format_script(expected) From 5af3f05fac1f893e8c7f2ede8066bd52bdd08525 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 14:51:27 -0400 Subject: [PATCH 22/27] Add missing doc --- include/tvm/script/printer/doc_printer.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h index cbe7ee214e2a..d56c8f4da63a 100644 --- a/include/tvm/script/printer/doc_printer.h +++ b/include/tvm/script/printer/doc_printer.h @@ -35,6 +35,7 @@ struct DocPrintingOptions { /*! * \brief Convert Doc into Python script. * + * \param doc the doc to be converted * \param options the option for printer */ String DocToPythonScript(Doc doc, DocPrintingOptions options); @@ -45,6 +46,7 @@ String DocToPythonScript(Doc doc, DocPrintingOptions options); * This function unpacks the DocPrinterOptions into function arguments * to be FFI friendly. * + * \param doc the doc to be converted * \param indent_spaces the number of spaces used for indention */ String DocToPythonScript(Doc doc, int indent_spaces = 4); From 37c4f7d6391683a0193f187ea84298a37c4bcb7f Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 5 Jul 2022 16:57:42 -0400 Subject: [PATCH 23/27] Fix failed printer tests --- python/tvm/script/printer/doc_printer.py | 2 +- .../unittest/test_tvmscript_printer_python_doc_printer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/printer/doc_printer.py b/python/tvm/script/printer/doc_printer.py index 7f4765d86269..404632b44c07 100644 --- a/python/tvm/script/printer/doc_printer.py +++ b/python/tvm/script/printer/doc_printer.py @@ -36,4 +36,4 @@ def to_python_script(doc: Doc, indent_spaces: int = 4) -> str: script : str The text representation of Doc in Python syntax """ - return _ffi_api.DocToScriptPython(doc, indent_spaces) # type: ignore + return _ffi_api.DocToPythonScript(doc, indent_spaces) # type: ignore diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 97949c666943..55e1a0c35cfc 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -53,4 +53,4 @@ def format_script(s: str) -> str: ], ) def test_print_literal_doc(doc, expected): - assert to_python_script(doc) == format_script(expected) + assert to_python_script(doc).rstrip("\n") == format_script(expected) From 79cb3dbf9d1fcbac0ae34b8d213d1ec06b13e78f Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 6 Jul 2022 10:06:21 -0400 Subject: [PATCH 24/27] Remove unnecessary xfail test cases --- .../unittest/test_tvmscript_printer_python_doc_printer.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 55e1a0c35cfc..59cf6eae92a9 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -47,9 +47,6 @@ def format_script(s: str) -> str: (LiteralDoc(-1), "-1"), (LiteralDoc(3.25), "3.25"), (LiteralDoc(-0.5), "-0.5"), - # TODO: make the float number printing preserve percision and roundtrippable - pytest.param(LiteralDoc(0.0), "0.0", marks=pytest.mark.xfail), - pytest.param(LiteralDoc(3.14), "3.14", marks=pytest.mark.xfail), ], ) def test_print_literal_doc(doc, expected): From 0c80e6ee551c26789786e277751da4f40bc7a3c6 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 6 Jul 2022 10:06:39 -0400 Subject: [PATCH 25/27] Check if value is None first --- python/tvm/script/printer/doc.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/script/printer/doc.py b/python/tvm/script/printer/doc.py index 96a5efcbebb3..f6179d7351b2 100644 --- a/python/tvm/script/printer/doc.py +++ b/python/tvm/script/printer/doc.py @@ -35,7 +35,9 @@ class LiteralDoc(ExprDoc): """Doc that represents literal value""" def __init__(self, value): - if isinstance(value, str): + if value is None: + self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore + elif isinstance(value, str): self.__init_handle_by_constructor__(_ffi_api.LiteralDocStr, value) # type: ignore elif isinstance(value, float): self.__init_handle_by_constructor__(_ffi_api.LiteralDocFloat, value) # type: ignore @@ -43,7 +45,5 @@ def __init__(self, value): self.__init_handle_by_constructor__(_ffi_api.LiteralDocBoolean, value) # type: ignore elif isinstance(value, int): self.__init_handle_by_constructor__(_ffi_api.LiteralDocInt, value) # type: ignore - elif value is None: - self.__init_handle_by_constructor__(_ffi_api.LiteralDocNone) # type: ignore else: raise TypeError(f"Unsupported type {type(value)} for LiteralDoc") From f95f5fc0dfac4bd9c34925b8a405448e5e733219 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 6 Jul 2022 10:20:31 -0400 Subject: [PATCH 26/27] Fix typos --- .../unittest/test_tvmscript_printer_python_doc_printer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py index 59cf6eae92a9..55b5e88c88c8 100644 --- a/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py +++ b/tests/python/unittest/test_tvmscript_printer_python_doc_printer.py @@ -41,7 +41,7 @@ def format_script(s: str) -> str: (LiteralDoc(""), '""'), (LiteralDoc('""'), r'"\"\""'), (LiteralDoc("\n\t\\test\r"), r'"\n\t\\test\r"'), - # TODO: make the roundatrippable problem caused by utf8 + # TODO: fix the roundatrippable problem caused by utf8 pytest.param(LiteralDoc("\x88"), r'"\x88"', marks=pytest.mark.xfail), (LiteralDoc(0), "0"), (LiteralDoc(-1), "-1"), From 2f5f5620f6596302e2401b393df433d65bc89667 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 6 Jul 2022 10:28:17 -0400 Subject: [PATCH 27/27] Remove DocPrintingOptions --- include/tvm/script/printer/doc_printer.h | 15 --------------- src/script/printer/base_doc_printer.cc | 2 +- src/script/printer/base_doc_printer.h | 10 +++++----- src/script/printer/python_doc_printer.cc | 15 ++++----------- 4 files changed, 10 insertions(+), 32 deletions(-) diff --git a/include/tvm/script/printer/doc_printer.h b/include/tvm/script/printer/doc_printer.h index d56c8f4da63a..6bf502fab910 100644 --- a/include/tvm/script/printer/doc_printer.h +++ b/include/tvm/script/printer/doc_printer.h @@ -25,21 +25,6 @@ namespace tvm { namespace script { namespace printer { -/*! - * \brief Configurable options for converting Doc into text format - */ -struct DocPrintingOptions { - int indent_spaces = 4; -}; - -/*! - * \brief Convert Doc into Python script. - * - * \param doc the doc to be converted - * \param options the option for printer - */ -String DocToPythonScript(Doc doc, DocPrintingOptions options); - /*! * \brief Convert Doc into Python script. * diff --git a/src/script/printer/base_doc_printer.cc b/src/script/printer/base_doc_printer.cc index ebf0c9576080..f6874ba1a2ee 100644 --- a/src/script/printer/base_doc_printer.cc +++ b/src/script/printer/base_doc_printer.cc @@ -23,7 +23,7 @@ namespace tvm { namespace script { namespace printer { -DocPrinter::DocPrinter(const DocPrintingOptions& options) : options_(options) {} +DocPrinter::DocPrinter(int indent_spaces) : indent_spaces_(indent_spaces) {} void DocPrinter::Append(const Doc& doc) { PrintDoc(doc); } diff --git a/src/script/printer/base_doc_printer.h b/src/script/printer/base_doc_printer.h index 5b09da29231b..128fcef2ea32 100644 --- a/src/script/printer/base_doc_printer.h +++ b/src/script/printer/base_doc_printer.h @@ -45,7 +45,7 @@ class DocPrinter { * * \param options the option for printer */ - explicit DocPrinter(const DocPrintingOptions& options); + explicit DocPrinter(int indent_spaces = 4); virtual ~DocPrinter() = default; /*! @@ -87,13 +87,13 @@ class DocPrinter { * \brief Increase the indent level of any content to be * printed after this call */ - void IncreaseIndent() { indent_ += options_.indent_spaces; } + void IncreaseIndent() { indent_ += indent_spaces_; } /*! * \brief Decrease the indent level of any content to be * printed after this call */ - void DecreaseIndent() { indent_ -= options_.indent_spaces; } + void DecreaseIndent() { indent_ -= indent_spaces_; } /*! * \brief Add a new line into the output stream @@ -117,8 +117,8 @@ class DocPrinter { std::ostringstream output_; private: - /*! \brief the printing options */ - DocPrintingOptions options_; + /*! \brief the number of spaces for one level of indentation */ + int indent_spaces_ = 4; /*! \brief the current level of indent */ int indent_ = 0; diff --git a/src/script/printer/python_doc_printer.cc b/src/script/printer/python_doc_printer.cc index 2b3595651894..cd816e4f7010 100644 --- a/src/script/printer/python_doc_printer.cc +++ b/src/script/printer/python_doc_printer.cc @@ -28,7 +28,7 @@ namespace printer { class PythonDocPrinter : public DocPrinter { public: - explicit PythonDocPrinter(const DocPrintingOptions& options) : DocPrinter(options) {} + explicit PythonDocPrinter(int indent_spaces = 4) : DocPrinter(indent_spaces) {} protected: using DocPrinter::PrintDoc; @@ -57,20 +57,13 @@ void PythonDocPrinter::PrintTypedDoc(const LiteralDoc& doc) { } } -String DocToPythonScript(Doc doc, DocPrintingOptions options) { - PythonDocPrinter printer(options); +String DocToPythonScript(Doc doc, int indent_spaces) { + PythonDocPrinter printer(indent_spaces); printer.Append(doc); return printer.GetString(); } -String DocToPythonScript(Doc doc, int indent_spaces) { - DocPrintingOptions options; - options.indent_spaces = indent_spaces; - return DocToPythonScript(doc, options); -} - -TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript") - .set_body_typed(DocToPythonScript); +TVM_REGISTER_GLOBAL("script.printer.DocToPythonScript").set_body_typed(DocToPythonScript); } // namespace printer } // namespace script