From 22755f79818326b1b7be88b7d8874421a821f062 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 16 Sep 2022 23:26:01 -0400 Subject: [PATCH 1/4] Update TVMScript infra for future TIR printer changes Co-authored-by: Greg Bonik --- include/tvm/script/printer/doc.h | 64 +++++++++++++ .../script/printer/traced_object_functor.h | 37 +------- include/tvm/script/printer/var_table.h | 11 +++ src/script/printer/doc.cc | 30 ++++-- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/utils.h | 93 +++++++++++++++++++ src/script/printer/var_table.cc | 3 +- .../cpp/tvmscript_printer_irdocsifier_test.cc | 13 ++- ...ript_printer_traced_object_functor_test.cc | 37 ++++---- 9 files changed, 228 insertions(+), 62 deletions(-) create mode 100644 src/script/printer/utils.h diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 72f343354b1b..1ee7fd6a7fd4 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace tvm { namespace script { @@ -87,6 +88,15 @@ class ExprDocNode : public DocNode { */ ExprDoc Attr(String attr) const; + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param attr The attribute to access. + * + * The ObjectPath of attr will be pushed to the source_path of the returned + * doc. + */ + ExprDoc Attr(TracedObject attr) const; + /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. @@ -242,6 +252,7 @@ class LiteralDocNode : public ExprDocNode { class LiteralDoc : public ExprDoc { protected: explicit LiteralDoc(ObjectRef value); + LiteralDoc(ObjectRef value, ObjectPath object_path); public: /*! @@ -249,30 +260,83 @@ class LiteralDoc : public ExprDoc { */ static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + /*! + * \brief Create a LiteralDoc to represent None/null/empty value. + * \param object_path The source path of the returned Doc. + */ + static LiteralDoc None(ObjectPath object_path) { + return LiteralDoc(ObjectRef(nullptr), object_path); + } + /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath()); + } /*! * \brief Create a LiteralDoc to represent boolean. * \param v The boolean value. */ static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + /*! + * \brief Create a LiteralDoc to represent boolean. + * \param v The boolean value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Boolean(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } + /*! + * \brief Create a LiteralDoc to represent float. + * \param v The float value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Float(const TracedObject& v) { + return LiteralDoc(v.Get(), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Str(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index 6caaf8a6e0d5..8f72d139a5a5 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -34,35 +34,6 @@ namespace tvm { namespace script { namespace printer { -namespace { - -namespace detail { -/*! - * \brief Helper template class to extract the type of first argument of a function - * \tparam FType The function type. - */ -template -struct FirstArgTypeGetter; - -template -struct FirstArgTypeGetter { - using T = ArgOne; -}; - -/*! - * \brief Template alias for the type of first argument of a function - * \tparam FType The function type. - * - * The name of public functions are in snake case to be consistent with - * tvm/node/functor.h - */ -template -using FirstArgType = typename detail::FirstArgTypeGetter< - typename tvm::runtime::detail::function_signature::FType>::T; -} // namespace detail - -} // namespace - /* * This type alias and the following free functions are created to reduce the binary bloat * from template and also hide implementation details from this header @@ -156,8 +127,7 @@ class TracedObjectFunctor { * * The diaptch function should have signature `R(TracedObject, Args...)`. */ - template ::ObjectRefType, + template ::value>> TSelf& set_dispatch(String token, TCallable f) { return set_dispatch( @@ -177,9 +147,10 @@ class TracedObjectFunctor { * * Default dispatch function has an empty string as dispatch token. */ - template + template ::value>> TSelf& set_dispatch(TCallable&& f) { - return set_dispatch(kDefaultDispatchToken, std::forward(f)); + return set_dispatch(kDefaultDispatchToken, std::forward(f)); } /*! diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h index 9300a976c569..2cd9335213a3 100644 --- a/include/tvm/script/printer/var_table.h +++ b/include/tvm/script/printer/var_table.h @@ -103,6 +103,17 @@ class VarTableNode : public Object { */ Optional GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; + /*! + * \brief Get the doc for variable. + * \param obj The traced variable object. + * + * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. + */ + template + Optional GetVarDoc(const TracedObject obj) const { + return GetVarDoc(obj.Get(), obj.GetPath()); + } + /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index d6f5ff35ab53..f3b431bd62db 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -27,6 +27,12 @@ namespace printer { ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } +ExprDoc ExprDocNode::Attr(TracedObject attr) const { + auto doc = AttrAccessDoc(GetRef(this), attr.Get()); + doc->source_paths.push_back(attr.GetPath()); + return doc; +} + ExprDoc ExprDocNode::operator[](Array indices) const { return IndexDoc(GetRef(this), indices); } @@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) { this->data_ = std::move(n); } +LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) { + ObjectPtr n = make_object(); + n->value = value; + n->source_paths.push_back(object_path); + this->data_ = std::move(n); +} + IdDoc::IdDoc(String name) { ObjectPtr n = make_object(); n->name = name; @@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") }); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); +TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr") + .set_body_method(&ExprDocNode::Attr); TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex") .set_body_method(&ExprDocNode::operator[]); TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") @@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") + .set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean") + .set_body_typed(LiteralDoc::Boolean); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") + .set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr") + .set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index b72ed48db63b..7f032ec50269 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) { // }); // \endcode TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { + .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { String top_dispatch_token = p->dispatch_tokens.back(); ICHECK_NE(top_dispatch_token, ""); ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented."; diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h new file mode 100644 index 000000000000..abe7ce5e9a88 --- /dev/null +++ b/src/script/printer/utils.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ +#define TVM_SCRIPT_PRINTER_UTILS_H_ + +#include +#include + +#include + +namespace tvm { +namespace script { +namespace printer { + +template +Array AsDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsDocArray(std::initializer_list&& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto& ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsExprDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + return AsDocArray(refs, ir_docsifier); +} + +template +Array AsExprDocArray(std::initializer_list&& refs, + const IRDocsifier& ir_docsifier) { + return AsDocArray(std::move(refs), ir_docsifier); +} + +inline DictDoc AsDictDoc(const TracedMap& dict, + const IRDocsifier& ir_docsifier) { + Array keys; + Array values; + + for (auto p : dict) { + keys.push_back(LiteralDoc::Str(p.first)); + values.push_back(ir_docsifier->AsExprDoc(p.second)); + } + + auto doc = DictDoc(keys, values); + doc->source_paths.push_back(dict.GetPath()); + return doc; +} + +template +inline ListDoc AsListDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +template +inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc index 49ba93f9bcfe..62d8b2f66cc2 100644 --- a/src/script/printer/var_table.cc +++ b/src/script/printer/var_table.cc @@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") obj, [f = std::move(factory)]() { return f(); }, frame); }); TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") - .set_body_method(&VarTableNode::GetVarDoc); + .set_body_method, const ObjectRef&, + const ObjectPath&>(&VarTableNode::GetVarDoc); TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") .set_body_method(&VarTableNode::IsVarDefined); diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc index fcdb5ed04e41..8c68399df222 100644 --- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc +++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc @@ -45,14 +45,19 @@ class TestObject : public ObjectRef { TVM_REGISTER_NODE_TYPE(TestObjectNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); + .set_dispatch([](TracedObject obj, IRDocsifier p) { + return IdDoc("x"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { return IdDoc("tir"); }); + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + return IdDoc("tir"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", - [](TracedObject obj, IRDocsifier p) { return IdDoc("relax"); }); + .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { + return IdDoc("relax"); + }); TEST(PrinterIRDocsifierTest, AsDoc) { IRDocsifier p(Map{}); diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index 374eb609b6cb..d662ce132405 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -33,7 +33,7 @@ class FooObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.FooObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject"; TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); }; @@ -49,7 +49,7 @@ class BarObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.BarObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject"; TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); }; @@ -69,8 +69,8 @@ TEST(TracedObjectFunctorTest, NormalRegistration) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); @@ -80,8 +80,8 @@ TEST(TracedObjectFunctorTest, RegistrationWithFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); - functor.set_dispatch("tir", ComputeFoo); + functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); + functor.set_dispatch("tir", ComputeFoo); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); @@ -91,9 +91,11 @@ TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); - functor.set_dispatch("relax", [](TracedObject o) -> String { return "Foo relax"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch("relax", + [](TracedObject o) -> String { return "Foo relax"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -119,8 +121,8 @@ TEST(TracedObjectFunctorTest, ExtraArg) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); - functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); @@ -131,8 +133,9 @@ TEST(TracedObjectFunctorTest, RemoveDispatchFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -158,11 +161,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } @@ -173,11 +176,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } From 7551fe007eca2d5fe22ac0a3655b9b4ede27a402 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 17 Aug 2022 17:13:11 -0400 Subject: [PATCH 2/4] Add TIR var and type printing Co-authored-by: Greg Bonik --- src/script/printer/tir/tir.cc | 77 ++++++++++++++++ src/script/printer/tir/tir.h | 89 ++++++++++++++++++ src/script/printer/tir/type.cc | 69 ++++++++++++++ src/script/printer/tir/var.cc | 77 ++++++++++++++++ .../unittest/test_tvmscript_printer_tir.py | 92 +++++++++++++++++++ 5 files changed, 404 insertions(+) create mode 100644 src/script/printer/tir/tir.cc create mode 100644 src/script/printer/tir/tir.h create mode 100644 src/script/printer/tir/type.cc create mode 100644 src/script/printer/tir/var.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_tir.py diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc new file mode 100644 index 000000000000..38bd94a72bb5 --- /dev/null +++ b/src/script/printer/tir/tir.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./tir.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} + +TIRGeneralFrame::TIRGeneralFrame() : TIRFrame(make_object()) {} + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { + auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); + if (type_annotation.Get().defined()) { + return p->AsExprDoc(type_annotation); + } else { + auto dtype = var.GetAttr(&tir::VarNode::dtype); + Type raw_type = GetTypeFromRuntimeDataType(dtype.Get()); + return p->AsExprDoc(MakeTraced(raw_type, dtype.GetPath())); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + const ObjectRef& root_node = obj.Get()->root_node; + + TIRTopLevelFrame top_level_frame; + auto frame_ctx = p->WithFrame(top_level_frame); + + // Because we are printing a single element, concise scoping should be allowed by default + top_level_frame->allow_concise_scoping = true; + + Doc root_doc = p->AsDoc(MakeTraced(root_node)); + + Array doc_to_print = top_level_frame->free_var_definitions; + + if (const auto* stmt_doc_node = root_doc.as()) { + doc_to_print.push_back(GetRef(stmt_doc_node)); + } else if (const auto* expr_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(GetRef(expr_doc_node))); + } else if (const auto* stmt_block_node = root_doc.as()) { + doc_to_print = runtime::Concat(doc_to_print, stmt_block_node->stmts); + } else if (const auto* slice_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(IdDoc("_")[{GetRef(slice_doc_node)}])); + } else { + ICHECK(false) << "Cannot print " << root_doc->GetTypeKey() << " as top level doc for TIR."; + } + + return StmtBlockDoc(doc_to_print); + }); +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h new file mode 100644 index 000000000000..bb5973ee4f3b --- /dev/null +++ b/src/script/printer/tir/tir.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_TIR_H_ +#define TVM_SCRIPT_PRINTER_TIR_TIR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class TIRFrameNode : public FrameNode { + public: + mutable bool allow_concise_scoping{false}; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("allow_concise_scoping", &allow_concise_scoping); + } + + static constexpr const char* _type_key = "script.printer.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, FrameNode); +}; + +class TIRFrame : public Frame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); +}; + +class TIRTopLevelFrameNode : public TIRFrameNode { + public: + Array free_var_definitions; + + void VisitAttrs(AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("free_var_definitions", &free_var_definitions); + } + + static constexpr const char* _type_key = "script.printer.TIRTopLevelFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, FrameNode); +}; + +class TIRTopLevelFrame : public TIRFrame { + public: + TIRTopLevelFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRTopLevelFrame, TIRFrame, + TIRTopLevelFrameNode); +}; + +class TIRGeneralFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.printer.TIRGeneralFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, FrameNode); +}; + +class TIRGeneralFrame : public TIRFrame { + public: + TIRGeneralFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRGeneralFrame, TIRFrame, TIRGeneralFrameNode); +}; + +inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); } + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_TIR_H_ diff --git a/src/script/printer/tir/type.cc b/src/script/printer/tir/type.cc new file mode 100644 index 000000000000..09aa96be7847 --- /dev/null +++ b/src/script/printer/tir/type.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../utils.h" +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedBasicValue dtype = ty.GetAttr(&PrimTypeNode::dtype); + String ty_str = runtime::DLDataType2String(dtype.Get()); + return TIR(p)->Attr(MakeTraced(ty_str, ty.GetPath())); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedObject element_type = ty.GetAttr(&PointerTypeNode::element_type); + TracedObject storage_scope = ty.GetAttr(&PointerTypeNode::storage_scope); + + ExprDoc element_type_doc = p->AsDoc(element_type); + if (storage_scope.Get().empty()) { + return TIR(p)->Attr("Ptr")->Call({element_type_doc}); + } else { + return TIR(p)->Attr("Ptr")->Call({element_type_doc, LiteralDoc::Str(storage_scope)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + auto fields = ty.GetAttr(&TupleTypeNode::fields); + + if (fields.empty()) { + return LiteralDoc::None(fields.GetPath()); + } + return TIR(p)->Attr("Tuple")->Call(AsExprDocArray(fields, p)); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/var.cc b/src/script/printer/tir/var.cc new file mode 100644 index 000000000000..e6c200e1fe8e --- /dev/null +++ b/src/script/printer/tir/var.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TracedObject GetVarNameHint(const TracedObject& var) { + TracedObject name_hint = var.GetAttr(&tir::VarNode::name_hint); + if (name_hint.Get().empty()) { + return MakeTraced(String("v"), var.GetPath()); + } else { + return name_hint; + } +} + +IdDoc CreateFreeVariableDefinition(TracedObject var, IRDocsifier p) { + TracedObject name_hint = GetVarNameHint(var); + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + IdDoc doc = p->vars->Define(var.Get(), name_hint, top_level_frame); + StmtDoc def_doc = AssignDoc(doc, NullOpt, GetTypeAnnotationDocForVar(var, p)); + top_level_frame->free_var_definitions.push_back(def_doc); + return doc; +} + +ExprDoc PrintVariable(TracedObject var, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(var); + if (doc.defined()) { + return doc.value(); + } else { + return CreateFreeVariableDefinition(var, p); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintVariable); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject var, IRDocsifier p) { + return PrintVariable(MakeTraced(var.Get(), var.GetPath()), p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject v, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print IterVar directly. Please use the helper functions in tir.h for " + "specific usage of IterVar."; + throw; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py new file mode 100644 index 000000000000..936a2d74b48e --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.ir import PointerType, PrimType, TupleType +from tvm.script.printer import script +from tvm.tir import SizeVar, Var + + +def format_script(s: str) -> str: + """ + Remove leading and trailing blank lines, and make the minimum idention 0 + """ + s = s.strip("\n") + + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + if not non_empty_lines: + # no actual content + return "\n" + + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] + spaces_to_remove = min(line_indents) + + cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + if not cleaned_lines.endswith("\n"): + cleaned_lines += "\n" + return cleaned_lines + + +@pytest.mark.parametrize( + "ty, expected", + [ + ( + PrimType("int8"), + """ + T.int8 + """, + ), + ( + PrimType("float32"), + """ + T.float32 + """, + ), + ( + PointerType(PrimType("int32")), + """ + T.Ptr(T.int32) + """, + ), + ( + PointerType(PrimType("int32"), "global"), + """ + T.Ptr(T.int32, "global") + """, + ), + ( + TupleType([]), + """ + None + """, + ), + ], +) +def test_type(ty, expected): + assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + + +@pytest.mark.parametrize("var_type", [Var, SizeVar]) +def test_var(var_type): + var = var_type("x", "int8") + + assert script(var, "tir", {"tir": "T"}) == format_script( + """ + x: T.int8 + x + """ + ) From 882cfc5c6f4c04c33734775a2650b872db4ebbc2 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 19 Aug 2022 12:57:52 -0400 Subject: [PATCH 3/4] Add TIR buffer printing Co-authored-by: Junru Shao Co-authored-by: Greg Bonik Fix lint --- include/tvm/script/printer/visit_traced.h | 37 +++ src/script/printer/tir/buffer.cc | 298 ++++++++++++++++++ src/script/printer/tir/buffer.h | 105 ++++++ src/script/printer/tir/expr.cc | 35 ++ src/script/printer/tir/tir.cc | 19 ++ src/script/printer/tir/tir.h | 7 + src/script/printer/utils.h | 11 + src/script/printer/visit_traced.cc | 91 ++++++ .../unittest/test_tvmscript_printer_tir.py | 89 +++++- 9 files changed, 684 insertions(+), 8 deletions(-) create mode 100644 include/tvm/script/printer/visit_traced.h create mode 100644 src/script/printer/tir/buffer.cc create mode 100644 src/script/printer/tir/buffer.h create mode 100644 src/script/printer/tir/expr.cc create mode 100644 src/script/printer/visit_traced.cc diff --git a/include/tvm/script/printer/visit_traced.h b/include/tvm/script/printer/visit_traced.h new file mode 100644 index 000000000000..62f5d3a5d641 --- /dev/null +++ b/include/tvm/script/printer/visit_traced.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ +#define TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ + +#include + +namespace tvm { +namespace script { +namespace printer { + +void PostOrderVisitTraced(const TracedObject& object, + const std::function& node_predicate, + const std::function&)>& callback); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc new file mode 100644 index 000000000000..c85169865cb4 --- /dev/null +++ b/src/script/printer/tir/buffer.cc @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./buffer.h" + +#include +#include +#include + +#include + +#include "../utils.h" +#include "./tir.h" +#include "tvm/runtime/data_type.h" + +namespace tvm { +namespace script { +namespace printer { + +ExprDoc BufferPrintInfo::AsCall( + const ExprDoc& prefix, std::function&)> converter) const { + return AsCall(prefix, {}, converter); +} + +ExprDoc BufferPrintInfo::AsCall( + const ExprDoc& prefix, const Array& extra_args, + std::function&)> converter) const { + Array args(extra_args); + Array kwargs_keys; + Array kwargs_values; + { + Array results; + results.reserve(shape.size()); + for (TracedObject e : shape) { + results.push_back(converter(e)); + } + kwargs_keys.push_back("shape"); + kwargs_values.push_back(TupleDoc(results)); + } + if (dtype.defined()) { + args.push_back(dtype.value()); + } + if (data.defined()) { + kwargs_keys.push_back("data"); + kwargs_values.push_back(converter(data.value())); + } + if (strides.defined()) { + Array results; + results.reserve(strides.value().size()); + for (TracedObject stride : strides.value()) { + results.push_back(converter(stride)); + } + kwargs_keys.push_back("strides"); + kwargs_values.push_back(TupleDoc(results)); + } + if (elem_offset.defined()) { + kwargs_keys.push_back("elem_offset"); + kwargs_values.push_back(converter(elem_offset.value())); + } + if (scope.defined()) { + kwargs_keys.push_back("scope"); + kwargs_values.push_back(scope.value()); + } + if (align.defined()) { + kwargs_keys.push_back("align"); + kwargs_values.push_back(align.value()); + } + if (offset_factor.defined()) { + kwargs_keys.push_back("offset_factor"); + kwargs_values.push_back(offset_factor.value()); + } + if (buffer_type.defined()) { + kwargs_keys.push_back("buffer_type"); + kwargs_values.push_back(buffer_type.value()); + } + return prefix->Call(args, kwargs_keys, kwargs_values); +} + +static Optional GetBufferScope(const TracedObject& buffer) { + auto data = buffer.GetAttr(&tir::BufferNode::data); + auto type = data.GetAttr(&tir::VarNode::type_annotation).Downcast(); + auto scope = type.GetAttr(&PointerTypeNode::storage_scope); + if (scope.Get().empty() || scope.Get() == "global") { + return NullOpt; + } else { + return LiteralDoc::Str(scope); + } +} + +static Optional GetBufferDtype(const TracedObject& buffer) { + auto dtype = buffer.GetAttr(&tir::BufferNode::dtype); + if (dtype.Get() == DataType::Float(32)) { + return NullOpt; + } else { + return DType2Literal(dtype); + } +} + +static bool HasDefaultDataPtr(const tir::Buffer& buffer) { + const auto* ptr_type = buffer->data->type_annotation.as(); + ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + if (const auto* element_type = ptr_type->element_type.as()) { + DataType default_dtype = buffer->dtype; + if (buffer->dtype.is_bool()) { + default_dtype = DataType::Int(8); + } + return element_type->dtype == default_dtype; + } else { + return false; + } +} + +static TracedOptional GetBufferData(const TracedObject& buffer, + const BufferAssociatedVariables& associated_vars) { + auto data = buffer.GetAttr(&tir::BufferNode::data); + if (associated_vars.IsAssociatedWith(data.Get(), buffer.Get()) && + HasDefaultDataPtr(buffer.Get())) { + return TracedOptional(NullOpt, data.GetPath()); + } else { + return data; + } +} + +static TracedOptional> GetBufferStrides(const TracedObject& buffer) { + auto strides = buffer.GetAttr(&tir::BufferNode::strides); + if (!strides.empty()) { + return TracedOptional>(strides); + } else { + return TracedOptional>(NullOpt, strides.GetPath()); + } +} + +static TracedOptional GetBufferElemOffset( + const TracedObject& buffer, const BufferAssociatedVariables& associated_vars) { + auto elem_offset = buffer.GetAttr(&tir::BufferNode::elem_offset); + if (elem_offset.defined()) { + // Don't print the offset if it is an associated variable + if (elem_offset.IsInstance() && + associated_vars.IsAssociatedWith(elem_offset.Get(), buffer.Get())) { + return TracedOptional(NullOpt, elem_offset.GetPath()); + } + + // Don't print the offset if it is zero + if (auto i = elem_offset.TryDowncast()) { + if (i.value().Get()->value == 0 && i.value().Get()->dtype == DataType::Int(32)) { + return TracedOptional(NullOpt, elem_offset.GetPath()); + } + } + } + return elem_offset; +} + +static Optional GetBufferAlignment(const TracedObject& buffer) { + auto data_alignment = buffer.GetAttr(&tir::BufferNode::data_alignment); + if (data_alignment.Get() != runtime::kAllocAlignment) { + return LiteralDoc::Int(data_alignment); + } else { + return NullOpt; + } +} + +static Optional GetBufferOffsetFactor(const TracedObject& buffer) { + auto offset_factor = buffer.GetAttr(&tir::BufferNode::offset_factor); + if (offset_factor.Get() != 1) { + return LiteralDoc::Int(offset_factor); + } else { + return NullOpt; + } +} + +static Optional GetBufferType(const TracedObject& buffer) { + auto buffer_type = buffer.GetAttr(&tir::BufferNode::buffer_type); + if (buffer_type.Get() != tir::BufferType::kDefault) { + return LiteralDoc::Str(MakeTraced(String("auto"), buffer_type.GetPath())); + } else { + return NullOpt; + } +} + +std::vector GetBufferPrintInfo( + const std::vector>& buffers, // + std::function f_var_defined, + std::unordered_map* var_explicit_def, + BufferAssociatedVariables* associated_vars) { + using namespace tvm::tir; + auto check_explicit_def = [&](const TracedObject& e) -> void { + PostOrderVisitExprTraced(e, [&](const TracedObject& n) -> void { + if (const auto* v = n.Get().as()) { + if (!f_var_defined(v) && !associated_vars->IsAssociated(v)) { + var_explicit_def->insert({v, n.GetPath()}); + } + } + }); + }; + for (const TracedObject& traced_buffer : buffers) { + const Buffer& buffer = traced_buffer.Get(); + if (!f_var_defined(buffer->data.get()) && HasDefaultDataPtr(buffer)) { + associated_vars->AssociateIfNotAlready(buffer->data.get(), buffer); + } + if (const auto* elem_offset = buffer->elem_offset.as()) { + if (!f_var_defined(elem_offset)) { + associated_vars->AssociateIfNotAlready(elem_offset, buffer); + } + } + } + for (TracedObject buffer : buffers) { + auto shape = buffer.GetAttr(&tir::BufferNode::shape); + std::for_each(shape.begin(), shape.end(), check_explicit_def); + + auto strides = buffer.GetAttr(&tir::BufferNode::strides); + std::for_each(strides.begin(), strides.end(), check_explicit_def); + + check_explicit_def(buffer.GetAttr(&tir::BufferNode::data)); + check_explicit_def(buffer.GetAttr(&tir::BufferNode::elem_offset)); + } + std::vector results; + for (TracedObject buffer : buffers) { + results.push_back( + BufferPrintInfo{/* .buffer = */ buffer, + /* .shape = */ buffer.GetAttr(&tir::BufferNode::shape), + /* .dtype = */ GetBufferDtype(buffer), + /* .data = */ GetBufferData(buffer, *associated_vars), + /* .strides = */ GetBufferStrides(buffer), + /* .elem_offset = */ GetBufferElemOffset(buffer, *associated_vars), + /* .scope = */ GetBufferScope(buffer), + /* .align = */ GetBufferAlignment(buffer), + /* .offset_factor = */ GetBufferOffsetFactor(buffer), + /* .buffer_type = */ GetBufferType(buffer)}); + } + return results; +} + +TracedObject GetBufferNameHint(const TracedObject& buf) { + TracedObject name_hint = buf.GetAttr(&tir::BufferNode::name); + if (name_hint.Get().empty()) { + return MakeTraced(String("buf"), buf.GetPath()); + } else { + return name_hint; + } +} + +IdDoc DefineFreeBuffer(const TracedObject& buf, const Frame& frame, + const IRDocsifier& p, std::function add_definiton) { + TracedObject name_hint = GetBufferNameHint(buf); + IdDoc buf_doc = p->vars->Define(buf.Get(), name_hint, frame); + auto f_var_defined = [&p](const tir::VarNode* var) -> bool { + return p->vars->IsVarDefined(GetRef(var)); + }; + std::unordered_map var_explicit_def; + BufferAssociatedVariables associated_vars; + + BufferPrintInfo buffer_print_info = + GetBufferPrintInfo({buf}, f_var_defined, &var_explicit_def, &associated_vars)[0]; + + associated_vars.Define(p->vars.get(), frame); + + ExprDoc buf_definition = buffer_print_info.AsCall( + TIR(p)->Attr("Buffer"), + [&p](const TracedObject& expr) -> ExprDoc { return p->AsDoc(expr); }); + add_definiton(AssignDoc(buf_doc, NullOpt, buf_definition)); + + return buf_doc; +} + +ExprDoc PrintBuffer(TracedObject buf, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(buf); + if (doc.defined()) { + return doc.value(); + } else { + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + return DefineFreeBuffer(buf, top_level_frame, p, [top_level_frame](StmtDoc definition) { + top_level_frame->free_var_definitions.push_back(definition); + }); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBuffer); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/buffer.h b/src/script/printer/tir/buffer.h new file mode 100644 index 000000000000..10c2039b4e57 --- /dev/null +++ b/src/script/printer/tir/buffer.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ +#define TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class BufferAssociatedVariables { + public: + void Disassociate(const tir::VarNode* var) { var2buffer_.erase(var); } + + void AssociateIfNotAlready(const tir::VarNode* var, const tir::Buffer& buffer) { + var2buffer_.insert({var, buffer}); + } + + bool IsAssociated(const tir::VarNode* var) const { return var2buffer_.count(var) != 0; } + + bool IsAssociatedWith(const PrimExpr& e, const tir::Buffer& buffer) const { + if (const auto* v = e.as()) { + auto it = var2buffer_.find(v); + return it != var2buffer_.end() && it->second == buffer; + } + return false; + } + + void Define(VarTableNode* vars, const Frame& frame) const { + for (const auto& kv : var2buffer_) { + const tir::VarNode* var = kv.first; + const tir::Buffer& buffer = kv.second; + + ExprDoc buffer_name = vars->GetVarDoc(MakeTraced(buffer)).value(); + buffer_name->source_paths.clear(); + + if (buffer->data.get() == var) { + vars->DefineByDoc( + buffer->data, [buffer_name]() { return buffer_name->Attr("data"); }, frame); + } else if (buffer->elem_offset.get() == var) { + vars->DefineByDoc( + buffer->elem_offset, [buffer_name]() { return buffer_name->Attr("elem_offset"); }, + frame); + } else { + ICHECK(false) << "Unexpected association. Buffer: " << buffer + << "; Var: " << GetRef(var); + } + } + } + + private: + std::unordered_map var2buffer_; +}; + +struct BufferPrintInfo { + TracedObject buffer; + TracedArray shape; + Optional dtype; + TracedOptional data; + TracedOptional> strides; + TracedOptional elem_offset; + Optional scope; + Optional align; + Optional offset_factor; + Optional buffer_type; + + ExprDoc AsCall(const ExprDoc& prefix, + std::function&)> converter) const; + ExprDoc AsCall(const ExprDoc& prefix, const Array& extra_args, + std::function&)> converter) const; +}; + +std::vector GetBufferPrintInfo( + const std::vector>& buffers, // + std::function f_var_defined, + std::unordered_map* var_explicit_def, + BufferAssociatedVariables* associated_vars); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc new file mode 100644 index 000000000000..fad85f269668 --- /dev/null +++ b/src/script/printer/tir/expr.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject i, IRDocsifier p) { return LiteralDoc::Int(i); }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc index 38bd94a72bb5..48975791bd15 100644 --- a/src/script/printer/tir/tir.cc +++ b/src/script/printer/tir/tir.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace tvm { @@ -33,6 +34,24 @@ TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} +void PostOrderVisitExprTraced(const TracedObject& expr, + const std::function&)>& callback) { + PostOrderVisitTraced( + expr, [](const ObjectRef& object) { return object->IsInstance(); }, + [&](const TracedObject& object) { callback(object.Downcast()); }); +} + +void PostOrderVisitStmtExprTraced( + const TracedObject& stmt, + const std::function&)>& callback) { + PostOrderVisitTraced( + stmt, + [](const ObjectRef& object) { + return object->IsInstance() || object->IsInstance(); + }, + [&](const TracedObject& object) { callback(object); }); +} + ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); if (type_annotation.Get().defined()) { diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index bb5973ee4f3b..eb9b3f544d24 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -82,6 +82,13 @@ inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").va ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); +void PostOrderVisitExprTraced(const TracedObject& expr, + const std::function&)>& callback); + +void PostOrderVisitStmtExprTraced( + const TracedObject& expr, + const std::function&)>& callback); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index abe7ce5e9a88..eef93b9348d9 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -86,6 +86,17 @@ inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docs return ret; } +inline LiteralDoc DType2Literal(const DLDataType& dtype) { + using runtime::DLDataType2String; + return LiteralDoc::Str(DLDataType2String(dtype)); +} + +inline LiteralDoc DType2Literal(const TracedBasicValue& dtype) { + auto doc = DType2Literal(dtype.Get()); + doc->source_paths.push_back(dtype.GetPath()); + return doc; +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/visit_traced.cc b/src/script/printer/visit_traced.cc new file mode 100644 index 000000000000..87ab787c2074 --- /dev/null +++ b/src/script/printer/visit_traced.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +void PostOrderVisitTracedImpl(const ObjectRef& object, const ObjectPath& path, + const std::function& node_predicate, + const std::function&)>& callback); + +struct ObjAttrVisitor : public AttrVisitor { + ObjAttrVisitor(const ObjectPath& path, const std::function node_predicate, + const std::function&)>& callback) + : path(path), node_predicate(node_predicate), callback(callback) {} + + const ObjectPath& path; + const std::function node_predicate; + const std::function&)>& callback; + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, ObjectRef* value) final { + PostOrderVisitTracedImpl(*value, path->Attr(key), node_predicate, callback); + } +}; + +void PostOrderVisitTracedImpl(const ObjectRef& object, const ObjectPath& path, + const std::function& node_predicate, + const std::function&)>& callback) { + if (!object.defined()) { + return; + } + + if (object->IsInstance()) { + const ArrayNode* node = static_cast(object.get()); + for (size_t i = 0; i < node->size(); ++i) { + PostOrderVisitTracedImpl(node->at(i), path->ArrayIndex(i), node_predicate, callback); + } + } else if (object->IsInstance()) { + const MapNode* node = static_cast(object.get()); + for (auto kv : *node) { + PostOrderVisitTracedImpl(kv.second, path->MapValue(kv.first), node_predicate, callback); + } + } else { + if (!node_predicate(object)) { + return; + } + + ObjAttrVisitor visitor(path, node_predicate, callback); + ReflectionVTable::Global()->VisitAttrs(const_cast(object.get()), &visitor); + + callback(MakeTraced(object, path)); + } +} + +void PostOrderVisitTraced(const TracedObject& object, + const std::function& node_predicate, + const std::function&)>& callback) { + PostOrderVisitTracedImpl(object.Get(), object.GetPath(), node_predicate, callback); +} + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 936a2d74b48e..2d597f512a87 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -18,7 +18,7 @@ from tvm.ir import PointerType, PrimType, TupleType from tvm.script.printer import script -from tvm.tir import SizeVar, Var +from tvm.tir import SizeVar, Var, decl_buffer def format_script(s: str) -> str: @@ -41,52 +41,125 @@ def format_script(s: str) -> str: return cleaned_lines +def as_tir_script(node): + return script(node, "tir", {"tir": "T"}) + + @pytest.mark.parametrize( "ty, expected", [ - ( + pytest.param( PrimType("int8"), """ T.int8 """, + id="int", ), - ( + pytest.param( PrimType("float32"), """ T.float32 """, + id="float", ), - ( + pytest.param( PointerType(PrimType("int32")), """ T.Ptr(T.int32) """, + id="pointer", ), - ( + pytest.param( PointerType(PrimType("int32"), "global"), """ T.Ptr(T.int32, "global") """, + id="with_scope", ), - ( + pytest.param( TupleType([]), """ None """, + id="none", ), ], ) def test_type(ty, expected): - assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + assert format_script(expected) == as_tir_script(ty) @pytest.mark.parametrize("var_type", [Var, SizeVar]) def test_var(var_type): var = var_type("x", "int8") - assert script(var, "tir", {"tir": "T"}) == format_script( + assert as_tir_script(var) == format_script( """ x: T.int8 x """ ) + + +@pytest.mark.parametrize( + "buffer, expected", + [ + pytest.param( + decl_buffer((5, 10), name="b"), + """ + b: T.Buffer(shape=(5, 10)) + b + """, + id="simple", + ), + pytest.param( + decl_buffer((5), name=""), + """ + buf: T.Buffer(shape=(5,)) + buf + """, + id="no_name", + ), + pytest.param( + decl_buffer((SizeVar("m", "int"), SizeVar("n", "int")), dtype="int8"), + """ + m: T.int32 + n: T.int32 + buffer: T.Buffer("int8", shape=(m, n)) + buffer + """, + id="symbolic_shape", + ), + pytest.param( + decl_buffer( + (4, 10), + dtype="int8", + data=Var("p", PointerType(PrimType("int16"), "local")), + strides=[2, 5], + elem_offset=2, + data_alignment=16, + offset_factor=2, + scope="local", + ), + """ + p: T.Ptr(T.int16, "local") + buffer: T.Buffer("int8", shape=(4, 10), data=p, strides=(2, 5), elem_offset=2, scope="local", align=16, offset_factor=2) + buffer + """, + id="all_param", + ), + pytest.param( + decl_buffer( + (4, 10), + dtype="bool", + ), + """ + buffer: T.Buffer("bool", shape=(4, 10)) + buffer + """, + id="bool_different_ptr_type", + ), + ], +) +def test_buffer(buffer, expected): + assert as_tir_script(buffer) == format_script(expected) From 373ff0f2f07ba45b4a1e53577b611682a318d7e9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 23 Aug 2022 00:54:07 -0400 Subject: [PATCH 4/4] Add TIR expr printing Co-authored-by: Greg Bonik --- src/script/printer/tir/expr.cc | 249 ++++++++- src/script/printer/tir/tir.cc | 57 +++ src/script/printer/tir/tir.h | 7 + .../unittest/test_tvmscript_printer_tir.py | 481 +++++++++++++++++- 4 files changed, 791 insertions(+), 3 deletions(-) diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index fad85f269668..f3eb46ca5fe7 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -17,19 +17,266 @@ * under the License. */ +#include "tvm/ir/expr.h" + +#include +#include #include #include #include +#include #include "../utils.h" +#include "./tir.h" namespace tvm { namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject i, IRDocsifier p) { return LiteralDoc::Int(i); }); + .set_dispatch([](TracedObject s, IRDocsifier p) { + auto value = s.GetAttr(&tir::StringImmNode::value); + return LiteralDoc::Str(value); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject i, IRDocsifier p) -> ExprDoc { + const IntImm& node = i.Get(); + if (node->dtype == DataType::Int(32)) { + return LiteralDoc::Int(i); + } else if (node->dtype.is_bool()) { + return LiteralDoc::Boolean(MakeTraced(i.Get()->value != 0, i.GetPath())); + } else { + String type_name = runtime::DLDataType2String(node->dtype); + return TIR(p) + ->Attr(MakeTraced(type_name, i.GetAttr(&PrimExprNode::dtype).GetPath())) + ->Call({LiteralDoc::Int(i)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject f, IRDocsifier p) { + String type_name = runtime::DLDataType2String(f.Get()->dtype); + return TIR(p) + ->Attr(MakeTraced(type_name, f.GetAttr(&PrimExprNode::dtype).GetPath())) + ->Call({LiteralDoc::Float(f)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject cast, IRDocsifier p) { + auto value = cast.GetAttr(&tir::CastNode::value); + auto dtype = cast.GetAttr(&tir::CastNode::dtype); + return TIR(p)->Attr("cast")->Call({p->AsExprDoc(value), DType2Literal(dtype)}); + }); + +using OpKind = OperationDocNode::Kind; + +template +ExprDoc PrintBinOp(TracedObject expr, IRDocsifier p) { + using NodeType = typename BinOpType::ContainerType; + auto a = expr.GetAttr(&NodeType::a); + auto b = expr.GetAttr(&NodeType::b); + return OperationDoc(op_kind, {p->AsExprDoc(a), p->AsExprDoc(b)}); +} + +#define TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(Op, DocOpKind) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBinOp); + +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Add, OpKind::kAdd); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Sub, OpKind::kSub); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Mul, OpKind::kMult); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Div, OpKind::kDiv); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::FloorDiv, OpKind::kFloorDiv); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::FloorMod, OpKind::kMod); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::LT, OpKind::kLt); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::LE, OpKind::kLtE); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::GT, OpKind::kGt); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::GE, OpKind::kGtE); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::EQ, OpKind::kEq); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::NE, OpKind::kNotEq); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::And, OpKind::kAnd); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Or, OpKind::kOr); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) { + return OperationDoc(OpKind::kNot, {p->AsExprDoc(e.GetAttr(&tir::NotNode::a))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto condition = expr.GetAttr(&tir::SelectNode::condition); + auto true_value = expr.GetAttr(&tir::SelectNode::true_value); + auto false_value = expr.GetAttr(&tir::SelectNode::false_value); + return TIR(p)->Attr("Select")->Call( + {p->AsExprDoc(condition), p->AsExprDoc(true_value), p->AsExprDoc(false_value)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto buffer = expr.GetAttr(&tir::BufferLoadNode::buffer); + auto indices = expr.GetAttr(&tir::BufferLoadNode::indices); + + ExprDoc base = p->AsExprDoc(buffer); + return base[AsDocArray(indices, p)]; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) -> Doc { + LOG(FATAL) + << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to " + "lower this function first."; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print a tir.Load"; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto base = expr.GetAttr(&tir::RampNode::base); + auto stride = expr.GetAttr(&tir::RampNode::stride); + auto lanes = expr.GetAttr(&tir::RampNode::lanes); + return TIR(p)->Attr("ramp")->Call( + {p->AsExprDoc(base), p->AsExprDoc(stride), LiteralDoc::Int(lanes)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto value = expr.GetAttr(&tir::BroadcastNode::value); + auto lanes = expr.GetAttr(&tir::BroadcastNode::lanes); + return TIR(p)->Attr("broadcast")->Call({p->AsExprDoc(value), LiteralDoc::Int(lanes)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + TIRGeneralFrame frame; + WithCtx with_frame = p->WithFrame(frame); + + auto var = expr.GetAttr(&tir::LetNode::var); + auto value = expr.GetAttr(&tir::LetNode::value); + auto body = expr.GetAttr(&tir::LetNode::body); + + auto value_doc = p->AsExprDoc(value); + IdDoc var_doc = DefineTIRVar(var, frame, p); + return TIR(p)->Attr("let")->Call({var_doc, value_doc, p->AsExprDoc(body)}); + }); + +ExprDoc PrintCall(TracedObject call, IRDocsifier p) { + auto op_or_global_var = call.GetAttr(&tir::CallNode::op); + + if (op_or_global_var.IsInstance()) { + // TODO(yelite): Call PrintOpCall once it's finished + TracedObject op_name = op_or_global_var.Downcast().GetAttr(&OpNode::name); + Array arg_docs{LiteralDoc::Str(op_name)}; + TracedArray args = call.GetAttr(&tir::CallNode::args); + arg_docs = Concat(arg_docs, AsExprDocArray(args, p)); + return TIR(p)->Attr("call")->Call(arg_docs); + } else { + auto op_gvar = op_or_global_var.Downcast(); + auto name_hint = op_gvar.GetAttr(&GlobalVarNode::name_hint); + auto args = call.GetAttr(&tir::CallNode::args); + + IdDoc name_doc(name_hint.Get()); + name_doc->source_paths.push_back(name_hint.GetPath()); + + return name_doc->Call(AsExprDocArray(args, p)); + } +} +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintCall); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto vectors = expr.GetAttr(&tir::ShuffleNode::vectors); + auto indices = expr.GetAttr(&tir::ShuffleNode::indices); + return TIR(p)->Attr("shuffle")->Call({AsListDoc(vectors, p), AsListDoc(indices, p)}); + }); + +ExprDoc PrintCommReducer(TracedObject expr, IRDocsifier p) { + TIRGeneralFrame frame; + WithCtx with_frame = p->WithFrame(frame); + + auto lhs = expr.GetAttr(&tir::CommReducerNode::lhs); + auto rhs = expr.GetAttr(&tir::CommReducerNode::rhs); + + Array reducer_args; + for (TracedObject v_lhs : lhs) { + IdDoc var_doc = DefineTIRVar(v_lhs, frame, p); + reducer_args.push_back(var_doc); + } + for (TracedObject v_rhs : rhs) { + IdDoc var_doc = DefineTIRVar(v_rhs, frame, p); + reducer_args.push_back(var_doc); + } + + auto result = expr.GetAttr(&tir::CommReducerNode::result); + + ExprDoc reducer_body = rhs.size() == 1 ? p->AsExprDoc(result[0]) : AsTupleDoc(result, p); + + LambdaDoc reducer{reducer_args, reducer_body}; + + auto identity_element = expr.GetAttr(&tir::CommReducerNode::identity_element); + ListDoc identity_elements_doc = AsListDoc(identity_element, p); + + return TIR(p)->Attr("comm_reducer")->Call({reducer, identity_elements_doc}); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintCommReducer); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto combiner = expr.GetAttr(&tir::ReduceNode::combiner); + auto source = expr.GetAttr(&tir::ReduceNode::source); + auto axis = expr.GetAttr(&tir::ReduceNode::axis); + auto value_index = expr.GetAttr(&tir::ReduceNode::value_index); + + Array axis_docs; + for (const auto& iter_var : axis) { + axis_docs.push_back(IterVarStandaloneDef(iter_var, p)); + } + ListDoc axis_list_doc = ListDoc(axis_docs); + axis_list_doc->source_paths.push_back(axis.GetPath()); + + return TIR(p)->Attr("reduce")->Call({p->AsExprDoc(combiner), AsListDoc(source, p), + axis_list_doc, LiteralDoc::Int(value_index)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto min = expr.GetAttr(&RangeNode::min); + auto extent = expr.GetAttr(&RangeNode::extent); + auto max = MakeTraced(min.Get() + extent.Get(), extent.GetPath()); + return SliceDoc(p->AsExprDoc(min), p->AsExprDoc(max), NullOpt); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print tir::Any"; + throw; + }); + +ExprDoc PrintBufferRegion(TracedObject buffer_region, IRDocsifier p) { + auto region = buffer_region.GetAttr(&tir::BufferRegionNode::region); + + Array indices; + + for (TracedObject range : region) { + auto extent = range.GetAttr(&RangeNode::extent); + if (tir::is_one(extent.Get())) { + auto index = p->AsExprDoc(range.GetAttr(&RangeNode::min)); + index->source_paths.push_back(extent.GetPath()); + indices.push_back(std::move(index)); + } else { + indices.push_back(p->AsDoc(range)); + } + } + auto buffer = buffer_region.GetAttr(&tir::BufferRegionNode::buffer); + return p->AsExprDoc(buffer)[indices]; +} +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBufferRegion); } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc index 48975791bd15..cbe88d5bc2c4 100644 --- a/src/script/printer/tir/tir.cc +++ b/src/script/printer/tir/tir.cc @@ -63,6 +63,63 @@ ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDo } } +static String GetIterTypePyStr(tir::IterVarType iter_type) { + switch (iter_type) { + case tir::kDataPar: + return "DataPar"; + case tir::kThreadIndex: + return "ThreadIndex"; + case tir::kCommReduce: + return "CommReduce"; + case tir::kOrdered: + return "Ordered"; + case tir::kOpaque: + return "DimInfo"; + case tir::kUnrolled: + return "Unrolled"; + case tir::kVectorized: + return "Vectorized"; + case tir::kParallelized: + return "Parallelized"; + case tir::kTensorized: + return "Tensorized"; + default: + LOG(FATAL) << "Unknown iter type: " << iter_type; + throw; + } +} + +ExprDoc IterVarStandaloneDef(const TracedObject iter_var, const IRDocsifier& p) { + Array args; + + args.push_back(p->AsExprDoc(iter_var.GetAttr(&tir::IterVarNode::var))); + + if (iter_var.Get()->dom.defined()) { + auto dom = iter_var.GetAttr(&tir::IterVarNode::dom); + auto min = dom.GetAttr(&RangeNode::min); + auto extent = dom.GetAttr(&RangeNode::extent); + if (tir::is_zero(min.Get())) { + auto extent_doc = p->AsExprDoc(extent); + extent_doc->source_paths.push_back(min.GetPath()); + args.push_back(extent_doc); + } else { + auto max = MakeTraced(min.Get() + extent.Get(), extent.GetPath()); + args.push_back(TupleDoc({p->AsExprDoc(min), p->AsExprDoc(max)})); + } + } else { + args.push_back(LiteralDoc::None()); + } + + auto iter_type = iter_var.GetAttr(&tir::IterVarNode::iter_type); + args.push_back( + LiteralDoc::Str(MakeTraced(GetIterTypePyStr(iter_type.Get()), iter_type.GetPath()))); + args.push_back(LiteralDoc::Str(iter_var.GetAttr(&tir::IterVarNode::thread_tag))); + + ExprDoc result = TIR(p)->Attr("iter_var")->Call(args); + result->source_paths.push_back(iter_var.GetPath()); + return result; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { const ObjectRef& root_node = obj.Get()->root_node; diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index eb9b3f544d24..6020fe335e54 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -80,6 +80,10 @@ class TIRGeneralFrame : public TIRFrame { inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); } +inline IdDoc DefineTIRVar(const TracedObject& var, const Frame& frame, IRDocsifier p) { + return p->vars->Define(var.Get(), var.GetAttr(&tir::VarNode::name_hint), frame); +} + ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); void PostOrderVisitExprTraced(const TracedObject& expr, @@ -89,6 +93,9 @@ void PostOrderVisitStmtExprTraced( const TracedObject& expr, const std::function&)>& callback); +// Print IterVar as T.iter_var(...) +ExprDoc IterVarStandaloneDef(const TracedObject iter_var, const IRDocsifier& p); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 2d597f512a87..b03ce24fa808 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -16,9 +16,43 @@ # under the License. import pytest -from tvm.ir import PointerType, PrimType, TupleType +from tvm.ir import GlobalVar, PointerType, PrimType, Range, TupleType from tvm.script.printer import script -from tvm.tir import SizeVar, Var, decl_buffer +from tvm.tir import ( + EQ, + GE, + GT, + LE, + LT, + NE, + Add, + And, + Broadcast, + BufferLoad, + BufferRegion, + Call, + Cast, + CommReducer, + Div, + FloatImm, + FloorDiv, + FloorMod, + IntImm, + IterVar, + Let, + Mul, + Not, + Or, + Ramp, + Reduce, + Select, + Shuffle, + SizeVar, + StringImm, + Sub, + Var, + decl_buffer, +) def format_script(s: str) -> str: @@ -163,3 +197,446 @@ def test_var(var_type): ) def test_buffer(buffer, expected): assert as_tir_script(buffer) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + StringImm("test"), + """ + "test" + """, + id="string", + ), + pytest.param( + StringImm(""), + """ + "" + """, + id="empty", + ), + pytest.param( + StringImm("test1\ntest2\n"), + r""" + "test1\ntest2\n" + """, + id="multiline", + ), + ], +) +def test_string_imm(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + IntImm("int32", 1), + """ + 1 + """, + id="default-dtype", + ), + pytest.param( + IntImm("int8", 0), + """ + T.int8(0) + """, + id="int8", + ), + pytest.param( + IntImm("int64", -1), + """ + T.int64(-1) + """, + id="int64", + ), + pytest.param( + IntImm("bool", 1), + """ + True + """, + id="boolean-true", + ), + pytest.param( + IntImm("bool", 0), + """ + False + """, + id="boolean-true", + ), + ], +) +def test_int_imm(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + FloatImm("float32", 1.5), + """ + T.float32(1.5) + """, + id="f32", + ), + pytest.param( + FloatImm("float16", 1.5), + """ + T.float16(1.5) + """, + id="f16", + ), + ], +) +def test_float_imm(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + Cast("float32", Var("x", dtype="int32")), + """ + x: T.int32 + T.cast(x, "float32") + """, + id="with-var", + ), + pytest.param( + Cast("int8", IntImm("int16", 1)), + """ + T.cast(T.int16(1), "int8") + """, + id="with-var", + ), + ], +) +def test_cast(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + Add(1, 0), + """ + 1 + 0 + """, + id="add", + ), + pytest.param( + Sub(1, 0), + """ + 1 - 0 + """, + id="sub", + ), + pytest.param( + Mul(-2.5, 1.5), + """ + T.float32(-2.5) * T.float32(1.5) + """, + id="mul", + ), + pytest.param( + Div(5, 2), + """ + 5 / 2 + """, + id="div", + ), + pytest.param( + FloorDiv(3, 2), + """ + 3 // 2 + """, + id="floor-div", + ), + pytest.param( + FloorMod(IntImm("int8", 5), IntImm("int8", 2)), + """ + T.int8(5) % T.int8(2) + """, + id="floor-mod", + ), + pytest.param( + LT(0, 1), + """ + 0 < 1 + """, + id="lt", + ), + pytest.param( + LE(1.0, 5.0), + """ + T.float32(1) <= T.float32(5) + """, + id="le", + ), + pytest.param( + GT(1, 0), + """ + 1 > 0 + """, + id="gt", + ), + pytest.param( + GE(Var("n", "int32"), 0), + """ + n: T.int32 + n >= 0 + """, + id="ge", + ), + pytest.param( + EQ(Var("n", "int32"), Var("m", "int32")), + """ + n: T.int32 + m: T.int32 + n == m + """, + id="eq", + ), + pytest.param( + NE(1, 0), + """ + 1 != 0 + """, + id="ne", + ), + pytest.param( + And(IntImm("bool", 1), IntImm("bool", 0)), + """ + True and False + """, + id="and", + ), + pytest.param( + Or(IntImm("bool", 1), IntImm("bool", 0)), + """ + True or False + """, + id="or", + ), + ], +) +def test_binary_op(node, expected): + assert as_tir_script(node) == format_script(expected) + + +def test_not(): + assert as_tir_script(Not(IntImm("bool", 1))) == format_script( + """ + not True + """ + ) + + +def test_select(): + node = Select(IntImm("bool", 1), 0, 1) + assert as_tir_script(node) == format_script( + """ + T.Select(True, 0, 1) + """ + ) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + BufferLoad(decl_buffer((5, 10), name="b"), [0, 1]), + """ + b: T.Buffer(shape=(5, 10)) + b[0, 1] + """, + id="normal", + ), + pytest.param( + BufferLoad( + decl_buffer((5,), name="b"), + [ + 0, + ], + ), + """ + b: T.Buffer(shape=(5,)) + b[0] + """, + id="1d", + ), + pytest.param( + BufferLoad(decl_buffer((), name="b"), []), + """ + b: T.Buffer(shape=()) + b[()] + """, + id="0d", + ), + ], +) +def test_buffer_load(node, expected): + assert as_tir_script(node) == format_script(expected) + + +def test_ramp(): + node = Ramp(0, 1, 8) + expected = """ + T.ramp(0, 1, 8) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_broadcast(): + node = Broadcast(0, 4) + expected = """ + T.broadcast(0, 4) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_let(): + x = Var("x", "int32") + y = Var("y", "int32") + node = Let(x, 1, x + y) + # Not var definition for x because x isn't free variable in this expression + expected = """ + y: T.int32 + T.let(x, 1, x + y) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_call_tir_op(): + node = Call("float32", "tir.tvm_stack_make_shape", [0, 1]) + expected = """ + T.call("tir.tvm_stack_make_shape", 0, 1) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_call_global_var(): + f_var = GlobalVar("test_f") + node = Call("float32", f_var, [0, 1]) + expected = """ + test_f(0, 1) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_shuffle(): + x = Var("x", "int32") + y = Var("y", "int32") + node = Shuffle([x, 1, 10], [0, 1, y]) + expected = """ + x: T.int32 + y: T.int32 + T.shuffle([x, 1, 10], [0, 1, y]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_comm_reducer_single_value(): + x = Var("x", "int32") + y = Var("y", "int32") + node = CommReducer([x], [y], [x + y], [0]) + expected = """ + T.comm_reducer(lambda x, y: x + y, [0]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_comm_reducer_multi_value(): + x0 = Var("x0", "int32") + x1 = Var("x1", "int32") + y0 = Var("y0", "int32") + y1 = Var("y1", "int32") + node = CommReducer([x0, x1], [y0, y1], [x0 + y0, x1 * y1], [0, 1]) + expected = """ + T.comm_reducer(lambda x0, x1, y0, y1: (x0 + y0, x1 * y1), [0, 1]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_reduce(): + x = Var("x", "int32") + y = Var("y", "int32") + m = Var("m", "int32") + comm_reducer = CommReducer([x], [y], [x + y], [0]) + node = Reduce(comm_reducer, [m], [IterVar(None, Var("i", "int32"), 2)], True, 0) + expected = """ + i: T.int32 + m: T.int32 + T.reduce(T.comm_reducer(lambda x, y: x + y, [0]), [m], [T.iter_var(i, None, "CommReduce", "")], 0) + """ + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + Range(0, 1), + """ + _[0:1] + """, + id="normal", + ), + pytest.param( + Range(10), + """ + _[0:10] + """, + id="one-arg", + ), + ], +) +def test_range(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + BufferRegion(decl_buffer((5, 10), name="b"), [Range(0, 4), Range(0, 9)]), + """ + b: T.Buffer(shape=(5, 10)) + b[0:4, 0:9] + """, + id="normal", + ), + pytest.param( + BufferRegion(decl_buffer((5, 10), name="b"), [Range(0, 1), Range(5, 9)]), + """ + b: T.Buffer(shape=(5, 10)) + b[0, 5:9] + """, + id="scalar-range", + ), + pytest.param( + BufferRegion(decl_buffer((5,), name="b"), [Range(0, 3)]), + """ + b: T.Buffer(shape=(5,)) + b[0:3] + """, + id="1d", + ), + pytest.param( + BufferRegion(decl_buffer((), name="b"), []), + """ + b: T.Buffer(shape=()) + b[()] + """, + id="0d", + ), + ], +) +def test_buffer_region(node, expected): + assert as_tir_script(node) == format_script(expected)