From 22755f79818326b1b7be88b7d8874421a821f062 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 16 Sep 2022 23:26:01 -0400 Subject: [PATCH 1/3] Update TVMScript infra for future TIR printer changes Co-authored-by: Greg Bonik --- include/tvm/script/printer/doc.h | 64 +++++++++++++ .../script/printer/traced_object_functor.h | 37 +------- include/tvm/script/printer/var_table.h | 11 +++ src/script/printer/doc.cc | 30 ++++-- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/utils.h | 93 +++++++++++++++++++ src/script/printer/var_table.cc | 3 +- .../cpp/tvmscript_printer_irdocsifier_test.cc | 13 ++- ...ript_printer_traced_object_functor_test.cc | 37 ++++---- 9 files changed, 228 insertions(+), 62 deletions(-) create mode 100644 src/script/printer/utils.h diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 72f343354b1b..1ee7fd6a7fd4 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace tvm { namespace script { @@ -87,6 +88,15 @@ class ExprDocNode : public DocNode { */ ExprDoc Attr(String attr) const; + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param attr The attribute to access. + * + * The ObjectPath of attr will be pushed to the source_path of the returned + * doc. + */ + ExprDoc Attr(TracedObject attr) const; + /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. @@ -242,6 +252,7 @@ class LiteralDocNode : public ExprDocNode { class LiteralDoc : public ExprDoc { protected: explicit LiteralDoc(ObjectRef value); + LiteralDoc(ObjectRef value, ObjectPath object_path); public: /*! @@ -249,30 +260,83 @@ class LiteralDoc : public ExprDoc { */ static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + /*! + * \brief Create a LiteralDoc to represent None/null/empty value. + * \param object_path The source path of the returned Doc. + */ + static LiteralDoc None(ObjectPath object_path) { + return LiteralDoc(ObjectRef(nullptr), object_path); + } + /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath()); + } /*! * \brief Create a LiteralDoc to represent boolean. * \param v The boolean value. */ static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + /*! + * \brief Create a LiteralDoc to represent boolean. + * \param v The boolean value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Boolean(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } + /*! + * \brief Create a LiteralDoc to represent float. + * \param v The float value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Float(const TracedObject& v) { + return LiteralDoc(v.Get(), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Str(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index 6caaf8a6e0d5..8f72d139a5a5 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -34,35 +34,6 @@ namespace tvm { namespace script { namespace printer { -namespace { - -namespace detail { -/*! - * \brief Helper template class to extract the type of first argument of a function - * \tparam FType The function type. - */ -template -struct FirstArgTypeGetter; - -template -struct FirstArgTypeGetter { - using T = ArgOne; -}; - -/*! - * \brief Template alias for the type of first argument of a function - * \tparam FType The function type. - * - * The name of public functions are in snake case to be consistent with - * tvm/node/functor.h - */ -template -using FirstArgType = typename detail::FirstArgTypeGetter< - typename tvm::runtime::detail::function_signature::FType>::T; -} // namespace detail - -} // namespace - /* * This type alias and the following free functions are created to reduce the binary bloat * from template and also hide implementation details from this header @@ -156,8 +127,7 @@ class TracedObjectFunctor { * * The diaptch function should have signature `R(TracedObject, Args...)`. */ - template ::ObjectRefType, + template ::value>> TSelf& set_dispatch(String token, TCallable f) { return set_dispatch( @@ -177,9 +147,10 @@ class TracedObjectFunctor { * * Default dispatch function has an empty string as dispatch token. */ - template + template ::value>> TSelf& set_dispatch(TCallable&& f) { - return set_dispatch(kDefaultDispatchToken, std::forward(f)); + return set_dispatch(kDefaultDispatchToken, std::forward(f)); } /*! diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h index 9300a976c569..2cd9335213a3 100644 --- a/include/tvm/script/printer/var_table.h +++ b/include/tvm/script/printer/var_table.h @@ -103,6 +103,17 @@ class VarTableNode : public Object { */ Optional GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; + /*! + * \brief Get the doc for variable. + * \param obj The traced variable object. + * + * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. + */ + template + Optional GetVarDoc(const TracedObject obj) const { + return GetVarDoc(obj.Get(), obj.GetPath()); + } + /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index d6f5ff35ab53..f3b431bd62db 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -27,6 +27,12 @@ namespace printer { ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } +ExprDoc ExprDocNode::Attr(TracedObject attr) const { + auto doc = AttrAccessDoc(GetRef(this), attr.Get()); + doc->source_paths.push_back(attr.GetPath()); + return doc; +} + ExprDoc ExprDocNode::operator[](Array indices) const { return IndexDoc(GetRef(this), indices); } @@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) { this->data_ = std::move(n); } +LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) { + ObjectPtr n = make_object(); + n->value = value; + n->source_paths.push_back(object_path); + this->data_ = std::move(n); +} + IdDoc::IdDoc(String name) { ObjectPtr n = make_object(); n->name = name; @@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") }); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); +TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr") + .set_body_method(&ExprDocNode::Attr); TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex") .set_body_method(&ExprDocNode::operator[]); TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") @@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") + .set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean") + .set_body_typed(LiteralDoc::Boolean); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") + .set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr") + .set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index b72ed48db63b..7f032ec50269 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) { // }); // \endcode TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { + .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { String top_dispatch_token = p->dispatch_tokens.back(); ICHECK_NE(top_dispatch_token, ""); ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented."; diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h new file mode 100644 index 000000000000..abe7ce5e9a88 --- /dev/null +++ b/src/script/printer/utils.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ +#define TVM_SCRIPT_PRINTER_UTILS_H_ + +#include +#include + +#include + +namespace tvm { +namespace script { +namespace printer { + +template +Array AsDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsDocArray(std::initializer_list&& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto& ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsExprDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + return AsDocArray(refs, ir_docsifier); +} + +template +Array AsExprDocArray(std::initializer_list&& refs, + const IRDocsifier& ir_docsifier) { + return AsDocArray(std::move(refs), ir_docsifier); +} + +inline DictDoc AsDictDoc(const TracedMap& dict, + const IRDocsifier& ir_docsifier) { + Array keys; + Array values; + + for (auto p : dict) { + keys.push_back(LiteralDoc::Str(p.first)); + values.push_back(ir_docsifier->AsExprDoc(p.second)); + } + + auto doc = DictDoc(keys, values); + doc->source_paths.push_back(dict.GetPath()); + return doc; +} + +template +inline ListDoc AsListDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +template +inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc index 49ba93f9bcfe..62d8b2f66cc2 100644 --- a/src/script/printer/var_table.cc +++ b/src/script/printer/var_table.cc @@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") obj, [f = std::move(factory)]() { return f(); }, frame); }); TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") - .set_body_method(&VarTableNode::GetVarDoc); + .set_body_method, const ObjectRef&, + const ObjectPath&>(&VarTableNode::GetVarDoc); TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") .set_body_method(&VarTableNode::IsVarDefined); diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc index fcdb5ed04e41..8c68399df222 100644 --- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc +++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc @@ -45,14 +45,19 @@ class TestObject : public ObjectRef { TVM_REGISTER_NODE_TYPE(TestObjectNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); + .set_dispatch([](TracedObject obj, IRDocsifier p) { + return IdDoc("x"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { return IdDoc("tir"); }); + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + return IdDoc("tir"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", - [](TracedObject obj, IRDocsifier p) { return IdDoc("relax"); }); + .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { + return IdDoc("relax"); + }); TEST(PrinterIRDocsifierTest, AsDoc) { IRDocsifier p(Map{}); diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index 374eb609b6cb..d662ce132405 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -33,7 +33,7 @@ class FooObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.FooObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject"; TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); }; @@ -49,7 +49,7 @@ class BarObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.BarObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject"; TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); }; @@ -69,8 +69,8 @@ TEST(TracedObjectFunctorTest, NormalRegistration) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); @@ -80,8 +80,8 @@ TEST(TracedObjectFunctorTest, RegistrationWithFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); - functor.set_dispatch("tir", ComputeFoo); + functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); + functor.set_dispatch("tir", ComputeFoo); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); @@ -91,9 +91,11 @@ TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); - functor.set_dispatch("relax", [](TracedObject o) -> String { return "Foo relax"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch("relax", + [](TracedObject o) -> String { return "Foo relax"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -119,8 +121,8 @@ TEST(TracedObjectFunctorTest, ExtraArg) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); - functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); @@ -131,8 +133,9 @@ TEST(TracedObjectFunctorTest, RemoveDispatchFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -158,11 +161,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } @@ -173,11 +176,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } From 7551fe007eca2d5fe22ac0a3655b9b4ede27a402 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 17 Aug 2022 17:13:11 -0400 Subject: [PATCH 2/3] Add TIR var and type printing Co-authored-by: Greg Bonik --- src/script/printer/tir/tir.cc | 77 ++++++++++++++++ src/script/printer/tir/tir.h | 89 ++++++++++++++++++ src/script/printer/tir/type.cc | 69 ++++++++++++++ src/script/printer/tir/var.cc | 77 ++++++++++++++++ .../unittest/test_tvmscript_printer_tir.py | 92 +++++++++++++++++++ 5 files changed, 404 insertions(+) create mode 100644 src/script/printer/tir/tir.cc create mode 100644 src/script/printer/tir/tir.h create mode 100644 src/script/printer/tir/type.cc create mode 100644 src/script/printer/tir/var.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_tir.py diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc new file mode 100644 index 000000000000..38bd94a72bb5 --- /dev/null +++ b/src/script/printer/tir/tir.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./tir.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} + +TIRGeneralFrame::TIRGeneralFrame() : TIRFrame(make_object()) {} + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { + auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); + if (type_annotation.Get().defined()) { + return p->AsExprDoc(type_annotation); + } else { + auto dtype = var.GetAttr(&tir::VarNode::dtype); + Type raw_type = GetTypeFromRuntimeDataType(dtype.Get()); + return p->AsExprDoc(MakeTraced(raw_type, dtype.GetPath())); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + const ObjectRef& root_node = obj.Get()->root_node; + + TIRTopLevelFrame top_level_frame; + auto frame_ctx = p->WithFrame(top_level_frame); + + // Because we are printing a single element, concise scoping should be allowed by default + top_level_frame->allow_concise_scoping = true; + + Doc root_doc = p->AsDoc(MakeTraced(root_node)); + + Array doc_to_print = top_level_frame->free_var_definitions; + + if (const auto* stmt_doc_node = root_doc.as()) { + doc_to_print.push_back(GetRef(stmt_doc_node)); + } else if (const auto* expr_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(GetRef(expr_doc_node))); + } else if (const auto* stmt_block_node = root_doc.as()) { + doc_to_print = runtime::Concat(doc_to_print, stmt_block_node->stmts); + } else if (const auto* slice_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(IdDoc("_")[{GetRef(slice_doc_node)}])); + } else { + ICHECK(false) << "Cannot print " << root_doc->GetTypeKey() << " as top level doc for TIR."; + } + + return StmtBlockDoc(doc_to_print); + }); +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h new file mode 100644 index 000000000000..bb5973ee4f3b --- /dev/null +++ b/src/script/printer/tir/tir.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_TIR_H_ +#define TVM_SCRIPT_PRINTER_TIR_TIR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class TIRFrameNode : public FrameNode { + public: + mutable bool allow_concise_scoping{false}; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("allow_concise_scoping", &allow_concise_scoping); + } + + static constexpr const char* _type_key = "script.printer.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, FrameNode); +}; + +class TIRFrame : public Frame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); +}; + +class TIRTopLevelFrameNode : public TIRFrameNode { + public: + Array free_var_definitions; + + void VisitAttrs(AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("free_var_definitions", &free_var_definitions); + } + + static constexpr const char* _type_key = "script.printer.TIRTopLevelFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, FrameNode); +}; + +class TIRTopLevelFrame : public TIRFrame { + public: + TIRTopLevelFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRTopLevelFrame, TIRFrame, + TIRTopLevelFrameNode); +}; + +class TIRGeneralFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.printer.TIRGeneralFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, FrameNode); +}; + +class TIRGeneralFrame : public TIRFrame { + public: + TIRGeneralFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRGeneralFrame, TIRFrame, TIRGeneralFrameNode); +}; + +inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); } + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_TIR_H_ diff --git a/src/script/printer/tir/type.cc b/src/script/printer/tir/type.cc new file mode 100644 index 000000000000..09aa96be7847 --- /dev/null +++ b/src/script/printer/tir/type.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../utils.h" +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedBasicValue dtype = ty.GetAttr(&PrimTypeNode::dtype); + String ty_str = runtime::DLDataType2String(dtype.Get()); + return TIR(p)->Attr(MakeTraced(ty_str, ty.GetPath())); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedObject element_type = ty.GetAttr(&PointerTypeNode::element_type); + TracedObject storage_scope = ty.GetAttr(&PointerTypeNode::storage_scope); + + ExprDoc element_type_doc = p->AsDoc(element_type); + if (storage_scope.Get().empty()) { + return TIR(p)->Attr("Ptr")->Call({element_type_doc}); + } else { + return TIR(p)->Attr("Ptr")->Call({element_type_doc, LiteralDoc::Str(storage_scope)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + auto fields = ty.GetAttr(&TupleTypeNode::fields); + + if (fields.empty()) { + return LiteralDoc::None(fields.GetPath()); + } + return TIR(p)->Attr("Tuple")->Call(AsExprDocArray(fields, p)); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/var.cc b/src/script/printer/tir/var.cc new file mode 100644 index 000000000000..e6c200e1fe8e --- /dev/null +++ b/src/script/printer/tir/var.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TracedObject GetVarNameHint(const TracedObject& var) { + TracedObject name_hint = var.GetAttr(&tir::VarNode::name_hint); + if (name_hint.Get().empty()) { + return MakeTraced(String("v"), var.GetPath()); + } else { + return name_hint; + } +} + +IdDoc CreateFreeVariableDefinition(TracedObject var, IRDocsifier p) { + TracedObject name_hint = GetVarNameHint(var); + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + IdDoc doc = p->vars->Define(var.Get(), name_hint, top_level_frame); + StmtDoc def_doc = AssignDoc(doc, NullOpt, GetTypeAnnotationDocForVar(var, p)); + top_level_frame->free_var_definitions.push_back(def_doc); + return doc; +} + +ExprDoc PrintVariable(TracedObject var, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(var); + if (doc.defined()) { + return doc.value(); + } else { + return CreateFreeVariableDefinition(var, p); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintVariable); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject var, IRDocsifier p) { + return PrintVariable(MakeTraced(var.Get(), var.GetPath()), p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject v, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print IterVar directly. Please use the helper functions in tir.h for " + "specific usage of IterVar."; + throw; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py new file mode 100644 index 000000000000..936a2d74b48e --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.ir import PointerType, PrimType, TupleType +from tvm.script.printer import script +from tvm.tir import SizeVar, Var + + +def format_script(s: str) -> str: + """ + Remove leading and trailing blank lines, and make the minimum idention 0 + """ + s = s.strip("\n") + + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + if not non_empty_lines: + # no actual content + return "\n" + + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] + spaces_to_remove = min(line_indents) + + cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + if not cleaned_lines.endswith("\n"): + cleaned_lines += "\n" + return cleaned_lines + + +@pytest.mark.parametrize( + "ty, expected", + [ + ( + PrimType("int8"), + """ + T.int8 + """, + ), + ( + PrimType("float32"), + """ + T.float32 + """, + ), + ( + PointerType(PrimType("int32")), + """ + T.Ptr(T.int32) + """, + ), + ( + PointerType(PrimType("int32"), "global"), + """ + T.Ptr(T.int32, "global") + """, + ), + ( + TupleType([]), + """ + None + """, + ), + ], +) +def test_type(ty, expected): + assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + + +@pytest.mark.parametrize("var_type", [Var, SizeVar]) +def test_var(var_type): + var = var_type("x", "int8") + + assert script(var, "tir", {"tir": "T"}) == format_script( + """ + x: T.int8 + x + """ + ) From 882cfc5c6f4c04c33734775a2650b872db4ebbc2 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 19 Aug 2022 12:57:52 -0400 Subject: [PATCH 3/3] Add TIR buffer printing Co-authored-by: Junru Shao Co-authored-by: Greg Bonik Fix lint --- include/tvm/script/printer/visit_traced.h | 37 +++ src/script/printer/tir/buffer.cc | 298 ++++++++++++++++++ src/script/printer/tir/buffer.h | 105 ++++++ src/script/printer/tir/expr.cc | 35 ++ src/script/printer/tir/tir.cc | 19 ++ src/script/printer/tir/tir.h | 7 + src/script/printer/utils.h | 11 + src/script/printer/visit_traced.cc | 91 ++++++ .../unittest/test_tvmscript_printer_tir.py | 89 +++++- 9 files changed, 684 insertions(+), 8 deletions(-) create mode 100644 include/tvm/script/printer/visit_traced.h create mode 100644 src/script/printer/tir/buffer.cc create mode 100644 src/script/printer/tir/buffer.h create mode 100644 src/script/printer/tir/expr.cc create mode 100644 src/script/printer/visit_traced.cc diff --git a/include/tvm/script/printer/visit_traced.h b/include/tvm/script/printer/visit_traced.h new file mode 100644 index 000000000000..62f5d3a5d641 --- /dev/null +++ b/include/tvm/script/printer/visit_traced.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ +#define TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ + +#include + +namespace tvm { +namespace script { +namespace printer { + +void PostOrderVisitTraced(const TracedObject& object, + const std::function& node_predicate, + const std::function&)>& callback); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc new file mode 100644 index 000000000000..c85169865cb4 --- /dev/null +++ b/src/script/printer/tir/buffer.cc @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./buffer.h" + +#include +#include +#include + +#include + +#include "../utils.h" +#include "./tir.h" +#include "tvm/runtime/data_type.h" + +namespace tvm { +namespace script { +namespace printer { + +ExprDoc BufferPrintInfo::AsCall( + const ExprDoc& prefix, std::function&)> converter) const { + return AsCall(prefix, {}, converter); +} + +ExprDoc BufferPrintInfo::AsCall( + const ExprDoc& prefix, const Array& extra_args, + std::function&)> converter) const { + Array args(extra_args); + Array kwargs_keys; + Array kwargs_values; + { + Array results; + results.reserve(shape.size()); + for (TracedObject e : shape) { + results.push_back(converter(e)); + } + kwargs_keys.push_back("shape"); + kwargs_values.push_back(TupleDoc(results)); + } + if (dtype.defined()) { + args.push_back(dtype.value()); + } + if (data.defined()) { + kwargs_keys.push_back("data"); + kwargs_values.push_back(converter(data.value())); + } + if (strides.defined()) { + Array results; + results.reserve(strides.value().size()); + for (TracedObject stride : strides.value()) { + results.push_back(converter(stride)); + } + kwargs_keys.push_back("strides"); + kwargs_values.push_back(TupleDoc(results)); + } + if (elem_offset.defined()) { + kwargs_keys.push_back("elem_offset"); + kwargs_values.push_back(converter(elem_offset.value())); + } + if (scope.defined()) { + kwargs_keys.push_back("scope"); + kwargs_values.push_back(scope.value()); + } + if (align.defined()) { + kwargs_keys.push_back("align"); + kwargs_values.push_back(align.value()); + } + if (offset_factor.defined()) { + kwargs_keys.push_back("offset_factor"); + kwargs_values.push_back(offset_factor.value()); + } + if (buffer_type.defined()) { + kwargs_keys.push_back("buffer_type"); + kwargs_values.push_back(buffer_type.value()); + } + return prefix->Call(args, kwargs_keys, kwargs_values); +} + +static Optional GetBufferScope(const TracedObject& buffer) { + auto data = buffer.GetAttr(&tir::BufferNode::data); + auto type = data.GetAttr(&tir::VarNode::type_annotation).Downcast(); + auto scope = type.GetAttr(&PointerTypeNode::storage_scope); + if (scope.Get().empty() || scope.Get() == "global") { + return NullOpt; + } else { + return LiteralDoc::Str(scope); + } +} + +static Optional GetBufferDtype(const TracedObject& buffer) { + auto dtype = buffer.GetAttr(&tir::BufferNode::dtype); + if (dtype.Get() == DataType::Float(32)) { + return NullOpt; + } else { + return DType2Literal(dtype); + } +} + +static bool HasDefaultDataPtr(const tir::Buffer& buffer) { + const auto* ptr_type = buffer->data->type_annotation.as(); + ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + if (const auto* element_type = ptr_type->element_type.as()) { + DataType default_dtype = buffer->dtype; + if (buffer->dtype.is_bool()) { + default_dtype = DataType::Int(8); + } + return element_type->dtype == default_dtype; + } else { + return false; + } +} + +static TracedOptional GetBufferData(const TracedObject& buffer, + const BufferAssociatedVariables& associated_vars) { + auto data = buffer.GetAttr(&tir::BufferNode::data); + if (associated_vars.IsAssociatedWith(data.Get(), buffer.Get()) && + HasDefaultDataPtr(buffer.Get())) { + return TracedOptional(NullOpt, data.GetPath()); + } else { + return data; + } +} + +static TracedOptional> GetBufferStrides(const TracedObject& buffer) { + auto strides = buffer.GetAttr(&tir::BufferNode::strides); + if (!strides.empty()) { + return TracedOptional>(strides); + } else { + return TracedOptional>(NullOpt, strides.GetPath()); + } +} + +static TracedOptional GetBufferElemOffset( + const TracedObject& buffer, const BufferAssociatedVariables& associated_vars) { + auto elem_offset = buffer.GetAttr(&tir::BufferNode::elem_offset); + if (elem_offset.defined()) { + // Don't print the offset if it is an associated variable + if (elem_offset.IsInstance() && + associated_vars.IsAssociatedWith(elem_offset.Get(), buffer.Get())) { + return TracedOptional(NullOpt, elem_offset.GetPath()); + } + + // Don't print the offset if it is zero + if (auto i = elem_offset.TryDowncast()) { + if (i.value().Get()->value == 0 && i.value().Get()->dtype == DataType::Int(32)) { + return TracedOptional(NullOpt, elem_offset.GetPath()); + } + } + } + return elem_offset; +} + +static Optional GetBufferAlignment(const TracedObject& buffer) { + auto data_alignment = buffer.GetAttr(&tir::BufferNode::data_alignment); + if (data_alignment.Get() != runtime::kAllocAlignment) { + return LiteralDoc::Int(data_alignment); + } else { + return NullOpt; + } +} + +static Optional GetBufferOffsetFactor(const TracedObject& buffer) { + auto offset_factor = buffer.GetAttr(&tir::BufferNode::offset_factor); + if (offset_factor.Get() != 1) { + return LiteralDoc::Int(offset_factor); + } else { + return NullOpt; + } +} + +static Optional GetBufferType(const TracedObject& buffer) { + auto buffer_type = buffer.GetAttr(&tir::BufferNode::buffer_type); + if (buffer_type.Get() != tir::BufferType::kDefault) { + return LiteralDoc::Str(MakeTraced(String("auto"), buffer_type.GetPath())); + } else { + return NullOpt; + } +} + +std::vector GetBufferPrintInfo( + const std::vector>& buffers, // + std::function f_var_defined, + std::unordered_map* var_explicit_def, + BufferAssociatedVariables* associated_vars) { + using namespace tvm::tir; + auto check_explicit_def = [&](const TracedObject& e) -> void { + PostOrderVisitExprTraced(e, [&](const TracedObject& n) -> void { + if (const auto* v = n.Get().as()) { + if (!f_var_defined(v) && !associated_vars->IsAssociated(v)) { + var_explicit_def->insert({v, n.GetPath()}); + } + } + }); + }; + for (const TracedObject& traced_buffer : buffers) { + const Buffer& buffer = traced_buffer.Get(); + if (!f_var_defined(buffer->data.get()) && HasDefaultDataPtr(buffer)) { + associated_vars->AssociateIfNotAlready(buffer->data.get(), buffer); + } + if (const auto* elem_offset = buffer->elem_offset.as()) { + if (!f_var_defined(elem_offset)) { + associated_vars->AssociateIfNotAlready(elem_offset, buffer); + } + } + } + for (TracedObject buffer : buffers) { + auto shape = buffer.GetAttr(&tir::BufferNode::shape); + std::for_each(shape.begin(), shape.end(), check_explicit_def); + + auto strides = buffer.GetAttr(&tir::BufferNode::strides); + std::for_each(strides.begin(), strides.end(), check_explicit_def); + + check_explicit_def(buffer.GetAttr(&tir::BufferNode::data)); + check_explicit_def(buffer.GetAttr(&tir::BufferNode::elem_offset)); + } + std::vector results; + for (TracedObject buffer : buffers) { + results.push_back( + BufferPrintInfo{/* .buffer = */ buffer, + /* .shape = */ buffer.GetAttr(&tir::BufferNode::shape), + /* .dtype = */ GetBufferDtype(buffer), + /* .data = */ GetBufferData(buffer, *associated_vars), + /* .strides = */ GetBufferStrides(buffer), + /* .elem_offset = */ GetBufferElemOffset(buffer, *associated_vars), + /* .scope = */ GetBufferScope(buffer), + /* .align = */ GetBufferAlignment(buffer), + /* .offset_factor = */ GetBufferOffsetFactor(buffer), + /* .buffer_type = */ GetBufferType(buffer)}); + } + return results; +} + +TracedObject GetBufferNameHint(const TracedObject& buf) { + TracedObject name_hint = buf.GetAttr(&tir::BufferNode::name); + if (name_hint.Get().empty()) { + return MakeTraced(String("buf"), buf.GetPath()); + } else { + return name_hint; + } +} + +IdDoc DefineFreeBuffer(const TracedObject& buf, const Frame& frame, + const IRDocsifier& p, std::function add_definiton) { + TracedObject name_hint = GetBufferNameHint(buf); + IdDoc buf_doc = p->vars->Define(buf.Get(), name_hint, frame); + auto f_var_defined = [&p](const tir::VarNode* var) -> bool { + return p->vars->IsVarDefined(GetRef(var)); + }; + std::unordered_map var_explicit_def; + BufferAssociatedVariables associated_vars; + + BufferPrintInfo buffer_print_info = + GetBufferPrintInfo({buf}, f_var_defined, &var_explicit_def, &associated_vars)[0]; + + associated_vars.Define(p->vars.get(), frame); + + ExprDoc buf_definition = buffer_print_info.AsCall( + TIR(p)->Attr("Buffer"), + [&p](const TracedObject& expr) -> ExprDoc { return p->AsDoc(expr); }); + add_definiton(AssignDoc(buf_doc, NullOpt, buf_definition)); + + return buf_doc; +} + +ExprDoc PrintBuffer(TracedObject buf, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(buf); + if (doc.defined()) { + return doc.value(); + } else { + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + return DefineFreeBuffer(buf, top_level_frame, p, [top_level_frame](StmtDoc definition) { + top_level_frame->free_var_definitions.push_back(definition); + }); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBuffer); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/buffer.h b/src/script/printer/tir/buffer.h new file mode 100644 index 000000000000..10c2039b4e57 --- /dev/null +++ b/src/script/printer/tir/buffer.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ +#define TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class BufferAssociatedVariables { + public: + void Disassociate(const tir::VarNode* var) { var2buffer_.erase(var); } + + void AssociateIfNotAlready(const tir::VarNode* var, const tir::Buffer& buffer) { + var2buffer_.insert({var, buffer}); + } + + bool IsAssociated(const tir::VarNode* var) const { return var2buffer_.count(var) != 0; } + + bool IsAssociatedWith(const PrimExpr& e, const tir::Buffer& buffer) const { + if (const auto* v = e.as()) { + auto it = var2buffer_.find(v); + return it != var2buffer_.end() && it->second == buffer; + } + return false; + } + + void Define(VarTableNode* vars, const Frame& frame) const { + for (const auto& kv : var2buffer_) { + const tir::VarNode* var = kv.first; + const tir::Buffer& buffer = kv.second; + + ExprDoc buffer_name = vars->GetVarDoc(MakeTraced(buffer)).value(); + buffer_name->source_paths.clear(); + + if (buffer->data.get() == var) { + vars->DefineByDoc( + buffer->data, [buffer_name]() { return buffer_name->Attr("data"); }, frame); + } else if (buffer->elem_offset.get() == var) { + vars->DefineByDoc( + buffer->elem_offset, [buffer_name]() { return buffer_name->Attr("elem_offset"); }, + frame); + } else { + ICHECK(false) << "Unexpected association. Buffer: " << buffer + << "; Var: " << GetRef(var); + } + } + } + + private: + std::unordered_map var2buffer_; +}; + +struct BufferPrintInfo { + TracedObject buffer; + TracedArray shape; + Optional dtype; + TracedOptional data; + TracedOptional> strides; + TracedOptional elem_offset; + Optional scope; + Optional align; + Optional offset_factor; + Optional buffer_type; + + ExprDoc AsCall(const ExprDoc& prefix, + std::function&)> converter) const; + ExprDoc AsCall(const ExprDoc& prefix, const Array& extra_args, + std::function&)> converter) const; +}; + +std::vector GetBufferPrintInfo( + const std::vector>& buffers, // + std::function f_var_defined, + std::unordered_map* var_explicit_def, + BufferAssociatedVariables* associated_vars); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc new file mode 100644 index 000000000000..fad85f269668 --- /dev/null +++ b/src/script/printer/tir/expr.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject i, IRDocsifier p) { return LiteralDoc::Int(i); }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc index 38bd94a72bb5..48975791bd15 100644 --- a/src/script/printer/tir/tir.cc +++ b/src/script/printer/tir/tir.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace tvm { @@ -33,6 +34,24 @@ TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} +void PostOrderVisitExprTraced(const TracedObject& expr, + const std::function&)>& callback) { + PostOrderVisitTraced( + expr, [](const ObjectRef& object) { return object->IsInstance(); }, + [&](const TracedObject& object) { callback(object.Downcast()); }); +} + +void PostOrderVisitStmtExprTraced( + const TracedObject& stmt, + const std::function&)>& callback) { + PostOrderVisitTraced( + stmt, + [](const ObjectRef& object) { + return object->IsInstance() || object->IsInstance(); + }, + [&](const TracedObject& object) { callback(object); }); +} + ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); if (type_annotation.Get().defined()) { diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index bb5973ee4f3b..eb9b3f544d24 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -82,6 +82,13 @@ inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").va ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); +void PostOrderVisitExprTraced(const TracedObject& expr, + const std::function&)>& callback); + +void PostOrderVisitStmtExprTraced( + const TracedObject& expr, + const std::function&)>& callback); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index abe7ce5e9a88..eef93b9348d9 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -86,6 +86,17 @@ inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docs return ret; } +inline LiteralDoc DType2Literal(const DLDataType& dtype) { + using runtime::DLDataType2String; + return LiteralDoc::Str(DLDataType2String(dtype)); +} + +inline LiteralDoc DType2Literal(const TracedBasicValue& dtype) { + auto doc = DType2Literal(dtype.Get()); + doc->source_paths.push_back(dtype.GetPath()); + return doc; +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/visit_traced.cc b/src/script/printer/visit_traced.cc new file mode 100644 index 000000000000..87ab787c2074 --- /dev/null +++ b/src/script/printer/visit_traced.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +void PostOrderVisitTracedImpl(const ObjectRef& object, const ObjectPath& path, + const std::function& node_predicate, + const std::function&)>& callback); + +struct ObjAttrVisitor : public AttrVisitor { + ObjAttrVisitor(const ObjectPath& path, const std::function node_predicate, + const std::function&)>& callback) + : path(path), node_predicate(node_predicate), callback(callback) {} + + const ObjectPath& path; + const std::function node_predicate; + const std::function&)>& callback; + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, ObjectRef* value) final { + PostOrderVisitTracedImpl(*value, path->Attr(key), node_predicate, callback); + } +}; + +void PostOrderVisitTracedImpl(const ObjectRef& object, const ObjectPath& path, + const std::function& node_predicate, + const std::function&)>& callback) { + if (!object.defined()) { + return; + } + + if (object->IsInstance()) { + const ArrayNode* node = static_cast(object.get()); + for (size_t i = 0; i < node->size(); ++i) { + PostOrderVisitTracedImpl(node->at(i), path->ArrayIndex(i), node_predicate, callback); + } + } else if (object->IsInstance()) { + const MapNode* node = static_cast(object.get()); + for (auto kv : *node) { + PostOrderVisitTracedImpl(kv.second, path->MapValue(kv.first), node_predicate, callback); + } + } else { + if (!node_predicate(object)) { + return; + } + + ObjAttrVisitor visitor(path, node_predicate, callback); + ReflectionVTable::Global()->VisitAttrs(const_cast(object.get()), &visitor); + + callback(MakeTraced(object, path)); + } +} + +void PostOrderVisitTraced(const TracedObject& object, + const std::function& node_predicate, + const std::function&)>& callback) { + PostOrderVisitTracedImpl(object.Get(), object.GetPath(), node_predicate, callback); +} + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 936a2d74b48e..2d597f512a87 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -18,7 +18,7 @@ from tvm.ir import PointerType, PrimType, TupleType from tvm.script.printer import script -from tvm.tir import SizeVar, Var +from tvm.tir import SizeVar, Var, decl_buffer def format_script(s: str) -> str: @@ -41,52 +41,125 @@ def format_script(s: str) -> str: return cleaned_lines +def as_tir_script(node): + return script(node, "tir", {"tir": "T"}) + + @pytest.mark.parametrize( "ty, expected", [ - ( + pytest.param( PrimType("int8"), """ T.int8 """, + id="int", ), - ( + pytest.param( PrimType("float32"), """ T.float32 """, + id="float", ), - ( + pytest.param( PointerType(PrimType("int32")), """ T.Ptr(T.int32) """, + id="pointer", ), - ( + pytest.param( PointerType(PrimType("int32"), "global"), """ T.Ptr(T.int32, "global") """, + id="with_scope", ), - ( + pytest.param( TupleType([]), """ None """, + id="none", ), ], ) def test_type(ty, expected): - assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + assert format_script(expected) == as_tir_script(ty) @pytest.mark.parametrize("var_type", [Var, SizeVar]) def test_var(var_type): var = var_type("x", "int8") - assert script(var, "tir", {"tir": "T"}) == format_script( + assert as_tir_script(var) == format_script( """ x: T.int8 x """ ) + + +@pytest.mark.parametrize( + "buffer, expected", + [ + pytest.param( + decl_buffer((5, 10), name="b"), + """ + b: T.Buffer(shape=(5, 10)) + b + """, + id="simple", + ), + pytest.param( + decl_buffer((5), name=""), + """ + buf: T.Buffer(shape=(5,)) + buf + """, + id="no_name", + ), + pytest.param( + decl_buffer((SizeVar("m", "int"), SizeVar("n", "int")), dtype="int8"), + """ + m: T.int32 + n: T.int32 + buffer: T.Buffer("int8", shape=(m, n)) + buffer + """, + id="symbolic_shape", + ), + pytest.param( + decl_buffer( + (4, 10), + dtype="int8", + data=Var("p", PointerType(PrimType("int16"), "local")), + strides=[2, 5], + elem_offset=2, + data_alignment=16, + offset_factor=2, + scope="local", + ), + """ + p: T.Ptr(T.int16, "local") + buffer: T.Buffer("int8", shape=(4, 10), data=p, strides=(2, 5), elem_offset=2, scope="local", align=16, offset_factor=2) + buffer + """, + id="all_param", + ), + pytest.param( + decl_buffer( + (4, 10), + dtype="bool", + ), + """ + buffer: T.Buffer("bool", shape=(4, 10)) + buffer + """, + id="bool_different_ptr_type", + ), + ], +) +def test_buffer(buffer, expected): + assert as_tir_script(buffer) == format_script(expected)