From 22755f79818326b1b7be88b7d8874421a821f062 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 16 Sep 2022 23:26:01 -0400 Subject: [PATCH 1/8] Update TVMScript infra for future TIR printer changes Co-authored-by: Greg Bonik --- include/tvm/script/printer/doc.h | 64 +++++++++++++ .../script/printer/traced_object_functor.h | 37 +------- include/tvm/script/printer/var_table.h | 11 +++ src/script/printer/doc.cc | 30 ++++-- src/script/printer/ir_docsifier.cc | 2 +- src/script/printer/utils.h | 93 +++++++++++++++++++ src/script/printer/var_table.cc | 3 +- .../cpp/tvmscript_printer_irdocsifier_test.cc | 13 ++- ...ript_printer_traced_object_functor_test.cc | 37 ++++---- 9 files changed, 228 insertions(+), 62 deletions(-) create mode 100644 src/script/printer/utils.h diff --git a/include/tvm/script/printer/doc.h b/include/tvm/script/printer/doc.h index 72f343354b1b..1ee7fd6a7fd4 100644 --- a/include/tvm/script/printer/doc.h +++ b/include/tvm/script/printer/doc.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace tvm { namespace script { @@ -87,6 +88,15 @@ class ExprDocNode : public DocNode { */ ExprDoc Attr(String attr) const; + /*! + * \brief Create a doc representing attribute access on the current ExprDoc + * \param attr The attribute to access. + * + * The ObjectPath of attr will be pushed to the source_path of the returned + * doc. + */ + ExprDoc Attr(TracedObject attr) const; + /*! * \brief Create a doc representing index access on the current ExprDoc * \param indices The indices to access. @@ -242,6 +252,7 @@ class LiteralDocNode : public ExprDocNode { class LiteralDoc : public ExprDoc { protected: explicit LiteralDoc(ObjectRef value); + LiteralDoc(ObjectRef value, ObjectPath object_path); public: /*! @@ -249,30 +260,83 @@ class LiteralDoc : public ExprDoc { */ static LiteralDoc None() { return LiteralDoc(ObjectRef(nullptr)); } + /*! + * \brief Create a LiteralDoc to represent None/null/empty value. + * \param object_path The source path of the returned Doc. + */ + static LiteralDoc None(ObjectPath object_path) { + return LiteralDoc(ObjectRef(nullptr), object_path); + } + /*! * \brief Create a LiteralDoc to represent integer. * \param v The integer value. */ static LiteralDoc Int(int v) { return LiteralDoc(IntImm(DataType::Int(64), v)); } + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + + /*! + * \brief Create a LiteralDoc to represent integer. + * \param v The integer value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Int(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Int(64), v.Get()), v.GetPath()); + } /*! * \brief Create a LiteralDoc to represent boolean. * \param v The boolean value. */ static LiteralDoc Boolean(bool v) { return LiteralDoc(IntImm(DataType::Bool(), v)); } + /*! + * \brief Create a LiteralDoc to represent boolean. + * \param v The boolean value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Boolean(const TracedBasicValue& v) { + return LiteralDoc(IntImm(DataType::Bool(), v.Get()), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent float. * \param v The float value. */ static LiteralDoc Float(double v) { return LiteralDoc(FloatImm(DataType::Float(64), v)); } + /*! + * \brief Create a LiteralDoc to represent float. + * \param v The float value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Float(const TracedObject& v) { + return LiteralDoc(v.Get(), v.GetPath()); + } + /*! * \brief Create a LiteralDoc to represent string. * \param v The string value. */ static LiteralDoc Str(const String& v) { return LiteralDoc(v); } + /*! + * \brief Create a LiteralDoc to represent string. + * \param v The string value. + * + * The ObjectPath of v will be pushed to the source_path of the returned doc. + */ + static LiteralDoc Str(const TracedObject& v) { return LiteralDoc(v.Get(), v.GetPath()); } + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LiteralDoc, ExprDoc, LiteralDocNode); }; diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index 6caaf8a6e0d5..8f72d139a5a5 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -34,35 +34,6 @@ namespace tvm { namespace script { namespace printer { -namespace { - -namespace detail { -/*! - * \brief Helper template class to extract the type of first argument of a function - * \tparam FType The function type. - */ -template -struct FirstArgTypeGetter; - -template -struct FirstArgTypeGetter { - using T = ArgOne; -}; - -/*! - * \brief Template alias for the type of first argument of a function - * \tparam FType The function type. - * - * The name of public functions are in snake case to be consistent with - * tvm/node/functor.h - */ -template -using FirstArgType = typename detail::FirstArgTypeGetter< - typename tvm::runtime::detail::function_signature::FType>::T; -} // namespace detail - -} // namespace - /* * This type alias and the following free functions are created to reduce the binary bloat * from template and also hide implementation details from this header @@ -156,8 +127,7 @@ class TracedObjectFunctor { * * The diaptch function should have signature `R(TracedObject, Args...)`. */ - template ::ObjectRefType, + template ::value>> TSelf& set_dispatch(String token, TCallable f) { return set_dispatch( @@ -177,9 +147,10 @@ class TracedObjectFunctor { * * Default dispatch function has an empty string as dispatch token. */ - template + template ::value>> TSelf& set_dispatch(TCallable&& f) { - return set_dispatch(kDefaultDispatchToken, std::forward(f)); + return set_dispatch(kDefaultDispatchToken, std::forward(f)); } /*! diff --git a/include/tvm/script/printer/var_table.h b/include/tvm/script/printer/var_table.h index 9300a976c569..2cd9335213a3 100644 --- a/include/tvm/script/printer/var_table.h +++ b/include/tvm/script/printer/var_table.h @@ -103,6 +103,17 @@ class VarTableNode : public Object { */ Optional GetVarDoc(const ObjectRef& obj, const ObjectPath& object_path) const; + /*! + * \brief Get the doc for variable. + * \param obj The traced variable object. + * + * \return The doc for variable, if it exists in the table. Otherwise it returns NullOpt. + */ + template + Optional GetVarDoc(const TracedObject obj) const { + return GetVarDoc(obj.Get(), obj.GetPath()); + } + /*! * \brief Check if a variable exists in the table. * \param obj The variable object. diff --git a/src/script/printer/doc.cc b/src/script/printer/doc.cc index d6f5ff35ab53..f3b431bd62db 100644 --- a/src/script/printer/doc.cc +++ b/src/script/printer/doc.cc @@ -27,6 +27,12 @@ namespace printer { ExprDoc ExprDocNode::Attr(String attr) const { return AttrAccessDoc(GetRef(this), attr); } +ExprDoc ExprDocNode::Attr(TracedObject attr) const { + auto doc = AttrAccessDoc(GetRef(this), attr.Get()); + doc->source_paths.push_back(attr.GetPath()); + return doc; +} + ExprDoc ExprDocNode::operator[](Array indices) const { return IndexDoc(GetRef(this), indices); } @@ -54,6 +60,13 @@ LiteralDoc::LiteralDoc(ObjectRef value) { this->data_ = std::move(n); } +LiteralDoc::LiteralDoc(ObjectRef value, ObjectPath object_path) { + ObjectPtr n = make_object(); + n->value = value; + n->source_paths.push_back(object_path); + this->data_ = std::move(n); +} + IdDoc::IdDoc(String name) { ObjectPtr n = make_object(); n->name = name; @@ -225,7 +238,8 @@ TVM_REGISTER_GLOBAL("script.printer.DocSetSourcePaths") }); TVM_REGISTER_NODE_TYPE(ExprDocNode); -TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr").set_body_method(&ExprDocNode::Attr); +TVM_REGISTER_GLOBAL("script.printer.ExprDocAttr") + .set_body_method(&ExprDocNode::Attr); TVM_REGISTER_GLOBAL("script.printer.ExprDocIndex") .set_body_method(&ExprDocNode::operator[]); TVM_REGISTER_GLOBAL("script.printer.ExprDocCall") @@ -242,11 +256,15 @@ TVM_REGISTER_GLOBAL("script.printer.StmtBlockDoc").set_body_typed([](Array(LiteralDoc::None); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocInt") + .set_body_typed(LiteralDoc::Int); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocBoolean") + .set_body_typed(LiteralDoc::Boolean); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocFloat") + .set_body_typed(LiteralDoc::Float); +TVM_REGISTER_GLOBAL("script.printer.LiteralDocStr") + .set_body_typed(LiteralDoc::Str); TVM_REGISTER_NODE_TYPE(IdDocNode); TVM_REGISTER_GLOBAL("script.printer.IdDoc").set_body_typed([](String name) { return IdDoc(name); }); diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index b72ed48db63b..7f032ec50269 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -61,7 +61,7 @@ RootNodeContainer::RootNodeContainer(ObjectRef root_node) { // }); // \endcode TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { + .set_dispatch([](TracedObject obj, IRDocsifier p) -> Doc { String top_dispatch_token = p->dispatch_tokens.back(); ICHECK_NE(top_dispatch_token, ""); ICHECK(false) << "Printing IR " << top_dispatch_token << " is not implemented."; diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h new file mode 100644 index 000000000000..abe7ce5e9a88 --- /dev/null +++ b/src/script/printer/utils.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_UTILS_H_ +#define TVM_SCRIPT_PRINTER_UTILS_H_ + +#include +#include + +#include + +namespace tvm { +namespace script { +namespace printer { + +template +Array AsDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsDocArray(std::initializer_list&& refs, const IRDocsifier& ir_docsifier) { + Array result; + for (auto& ref : refs) { + result.push_back(ir_docsifier->AsExprDoc(ref)); + } + return result; +} + +template +Array AsExprDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { + return AsDocArray(refs, ir_docsifier); +} + +template +Array AsExprDocArray(std::initializer_list&& refs, + const IRDocsifier& ir_docsifier) { + return AsDocArray(std::move(refs), ir_docsifier); +} + +inline DictDoc AsDictDoc(const TracedMap& dict, + const IRDocsifier& ir_docsifier) { + Array keys; + Array values; + + for (auto p : dict) { + keys.push_back(LiteralDoc::Str(p.first)); + values.push_back(ir_docsifier->AsExprDoc(p.second)); + } + + auto doc = DictDoc(keys, values); + doc->source_paths.push_back(dict.GetPath()); + return doc; +} + +template +inline ListDoc AsListDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = ListDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +template +inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docsifier) { + auto ret = TupleDoc(AsExprDocArray(arr, ir_docsifier)); + ret->source_paths.push_back(arr.GetPath()); + return ret; +} + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_UTILS_H_ diff --git a/src/script/printer/var_table.cc b/src/script/printer/var_table.cc index 49ba93f9bcfe..62d8b2f66cc2 100644 --- a/src/script/printer/var_table.cc +++ b/src/script/printer/var_table.cc @@ -99,7 +99,8 @@ TVM_REGISTER_GLOBAL("script.printer.VarTableDefineByDoc") obj, [f = std::move(factory)]() { return f(); }, frame); }); TVM_REGISTER_GLOBAL("script.printer.VarTableGetVarDoc") - .set_body_method(&VarTableNode::GetVarDoc); + .set_body_method, const ObjectRef&, + const ObjectPath&>(&VarTableNode::GetVarDoc); TVM_REGISTER_GLOBAL("script.printer.VarTableIsVarDefined") .set_body_method(&VarTableNode::IsVarDefined); diff --git a/tests/cpp/tvmscript_printer_irdocsifier_test.cc b/tests/cpp/tvmscript_printer_irdocsifier_test.cc index fcdb5ed04e41..8c68399df222 100644 --- a/tests/cpp/tvmscript_printer_irdocsifier_test.cc +++ b/tests/cpp/tvmscript_printer_irdocsifier_test.cc @@ -45,14 +45,19 @@ class TestObject : public ObjectRef { TVM_REGISTER_NODE_TYPE(TestObjectNode); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject obj, IRDocsifier p) { return IdDoc("x"); }); + .set_dispatch([](TracedObject obj, IRDocsifier p) { + return IdDoc("x"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { return IdDoc("tir"); }); + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + return IdDoc("tir"); + }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch("relax", - [](TracedObject obj, IRDocsifier p) { return IdDoc("relax"); }); + .set_dispatch("relax", [](TracedObject obj, IRDocsifier p) { + return IdDoc("relax"); + }); TEST(PrinterIRDocsifierTest, AsDoc) { IRDocsifier p(Map{}); diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index 374eb609b6cb..d662ce132405 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -33,7 +33,7 @@ class FooObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.FooObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.FooObject"; TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); }; @@ -49,7 +49,7 @@ class BarObjectNode : public Object { public: void VisitAttrs(AttrVisitor* v) {} - static constexpr const char* _type_key = "test.BarObject"; + static constexpr const char* _type_key = "test.TracedObjectFunctor.BarObject"; TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); }; @@ -69,8 +69,8 @@ TEST(TracedObjectFunctorTest, NormalRegistration) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); @@ -80,8 +80,8 @@ TEST(TracedObjectFunctorTest, RegistrationWithFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); - functor.set_dispatch("tir", ComputeFoo); + functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); + functor.set_dispatch("tir", ComputeFoo); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); @@ -91,9 +91,11 @@ TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); - functor.set_dispatch("relax", [](TracedObject o) -> String { return "Foo relax"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch("relax", + [](TracedObject o) -> String { return "Foo relax"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -119,8 +121,8 @@ TEST(TracedObjectFunctorTest, ExtraArg) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); - functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); @@ -131,8 +133,9 @@ TEST(TracedObjectFunctorTest, RemoveDispatchFunction) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); - functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", + [](TracedObject o) -> String { return "Foo tir"; }); ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); @@ -158,11 +161,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } @@ -173,11 +176,11 @@ TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { TracedObjectFunctor functor; ObjectPath path = ObjectPath::Root(); - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); bool failed = false; try { - functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); } catch (...) { failed = true; } From 7551fe007eca2d5fe22ac0a3655b9b4ede27a402 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 17 Aug 2022 17:13:11 -0400 Subject: [PATCH 2/8] Add TIR var and type printing Co-authored-by: Greg Bonik --- src/script/printer/tir/tir.cc | 77 ++++++++++++++++ src/script/printer/tir/tir.h | 89 ++++++++++++++++++ src/script/printer/tir/type.cc | 69 ++++++++++++++ src/script/printer/tir/var.cc | 77 ++++++++++++++++ .../unittest/test_tvmscript_printer_tir.py | 92 +++++++++++++++++++ 5 files changed, 404 insertions(+) create mode 100644 src/script/printer/tir/tir.cc create mode 100644 src/script/printer/tir/tir.h create mode 100644 src/script/printer/tir/type.cc create mode 100644 src/script/printer/tir/var.cc create mode 100644 tests/python/unittest/test_tvmscript_printer_tir.py diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc new file mode 100644 index 000000000000..38bd94a72bb5 --- /dev/null +++ b/src/script/printer/tir/tir.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "./tir.h" + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} + +TIRGeneralFrame::TIRGeneralFrame() : TIRFrame(make_object()) {} + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { + auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); + if (type_annotation.Get().defined()) { + return p->AsExprDoc(type_annotation); + } else { + auto dtype = var.GetAttr(&tir::VarNode::dtype); + Type raw_type = GetTypeFromRuntimeDataType(dtype.Get()); + return p->AsExprDoc(MakeTraced(raw_type, dtype.GetPath())); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { + const ObjectRef& root_node = obj.Get()->root_node; + + TIRTopLevelFrame top_level_frame; + auto frame_ctx = p->WithFrame(top_level_frame); + + // Because we are printing a single element, concise scoping should be allowed by default + top_level_frame->allow_concise_scoping = true; + + Doc root_doc = p->AsDoc(MakeTraced(root_node)); + + Array doc_to_print = top_level_frame->free_var_definitions; + + if (const auto* stmt_doc_node = root_doc.as()) { + doc_to_print.push_back(GetRef(stmt_doc_node)); + } else if (const auto* expr_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(GetRef(expr_doc_node))); + } else if (const auto* stmt_block_node = root_doc.as()) { + doc_to_print = runtime::Concat(doc_to_print, stmt_block_node->stmts); + } else if (const auto* slice_doc_node = root_doc.as()) { + doc_to_print.push_back(ExprStmtDoc(IdDoc("_")[{GetRef(slice_doc_node)}])); + } else { + ICHECK(false) << "Cannot print " << root_doc->GetTypeKey() << " as top level doc for TIR."; + } + + return StmtBlockDoc(doc_to_print); + }); +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h new file mode 100644 index 000000000000..bb5973ee4f3b --- /dev/null +++ b/src/script/printer/tir/tir.h @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_TIR_H_ +#define TVM_SCRIPT_PRINTER_TIR_TIR_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class TIRFrameNode : public FrameNode { + public: + mutable bool allow_concise_scoping{false}; + + void VisitAttrs(AttrVisitor* v) { + FrameNode::VisitAttrs(v); + v->Visit("allow_concise_scoping", &allow_concise_scoping); + } + + static constexpr const char* _type_key = "script.printer.TIRFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRFrameNode, FrameNode); +}; + +class TIRFrame : public Frame { + public: + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRFrame, Frame, TIRFrameNode); +}; + +class TIRTopLevelFrameNode : public TIRFrameNode { + public: + Array free_var_definitions; + + void VisitAttrs(AttrVisitor* v) { + TIRFrameNode::VisitAttrs(v); + v->Visit("free_var_definitions", &free_var_definitions); + } + + static constexpr const char* _type_key = "script.printer.TIRTopLevelFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, FrameNode); +}; + +class TIRTopLevelFrame : public TIRFrame { + public: + TIRTopLevelFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRTopLevelFrame, TIRFrame, + TIRTopLevelFrameNode); +}; + +class TIRGeneralFrameNode : public TIRFrameNode { + public: + static constexpr const char* _type_key = "script.printer.TIRGeneralFrame"; + TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, FrameNode); +}; + +class TIRGeneralFrame : public TIRFrame { + public: + TIRGeneralFrame(); + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(TIRGeneralFrame, TIRFrame, TIRGeneralFrameNode); +}; + +inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); } + +ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_TIR_H_ diff --git a/src/script/printer/tir/type.cc b/src/script/printer/tir/type.cc new file mode 100644 index 000000000000..09aa96be7847 --- /dev/null +++ b/src/script/printer/tir/type.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../utils.h" +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedBasicValue dtype = ty.GetAttr(&PrimTypeNode::dtype); + String ty_str = runtime::DLDataType2String(dtype.Get()); + return TIR(p)->Attr(MakeTraced(ty_str, ty.GetPath())); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + TracedObject element_type = ty.GetAttr(&PointerTypeNode::element_type); + TracedObject storage_scope = ty.GetAttr(&PointerTypeNode::storage_scope); + + ExprDoc element_type_doc = p->AsDoc(element_type); + if (storage_scope.Get().empty()) { + return TIR(p)->Attr("Ptr")->Call({element_type_doc}); + } else { + return TIR(p)->Attr("Ptr")->Call({element_type_doc, LiteralDoc::Str(storage_scope)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch("tir", [](TracedObject ty, IRDocsifier p) -> Doc { + auto fields = ty.GetAttr(&TupleTypeNode::fields); + + if (fields.empty()) { + return LiteralDoc::None(fields.GetPath()); + } + return TIR(p)->Attr("Tuple")->Call(AsExprDocArray(fields, p)); + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/var.cc b/src/script/printer/tir/var.cc new file mode 100644 index 000000000000..e6c200e1fe8e --- /dev/null +++ b/src/script/printer/tir/var.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include + +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TracedObject GetVarNameHint(const TracedObject& var) { + TracedObject name_hint = var.GetAttr(&tir::VarNode::name_hint); + if (name_hint.Get().empty()) { + return MakeTraced(String("v"), var.GetPath()); + } else { + return name_hint; + } +} + +IdDoc CreateFreeVariableDefinition(TracedObject var, IRDocsifier p) { + TracedObject name_hint = GetVarNameHint(var); + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + IdDoc doc = p->vars->Define(var.Get(), name_hint, top_level_frame); + StmtDoc def_doc = AssignDoc(doc, NullOpt, GetTypeAnnotationDocForVar(var, p)); + top_level_frame->free_var_definitions.push_back(def_doc); + return doc; +} + +ExprDoc PrintVariable(TracedObject var, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(var); + if (doc.defined()) { + return doc.value(); + } else { + return CreateFreeVariableDefinition(var, p); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintVariable); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject var, IRDocsifier p) { + return PrintVariable(MakeTraced(var.Get(), var.GetPath()), p); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject v, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print IterVar directly. Please use the helper functions in tir.h for " + "specific usage of IterVar."; + throw; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py new file mode 100644 index 000000000000..936a2d74b48e --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +from tvm.ir import PointerType, PrimType, TupleType +from tvm.script.printer import script +from tvm.tir import SizeVar, Var + + +def format_script(s: str) -> str: + """ + Remove leading and trailing blank lines, and make the minimum idention 0 + """ + s = s.strip("\n") + + non_empty_lines = [line for line in s.splitlines() if line and not line.isspace()] + if not non_empty_lines: + # no actual content + return "\n" + + line_indents = [len(line) - len(line.lstrip(" ")) for line in non_empty_lines] + spaces_to_remove = min(line_indents) + + cleaned_lines = "\n".join(line[spaces_to_remove:] for line in s.splitlines()) + if not cleaned_lines.endswith("\n"): + cleaned_lines += "\n" + return cleaned_lines + + +@pytest.mark.parametrize( + "ty, expected", + [ + ( + PrimType("int8"), + """ + T.int8 + """, + ), + ( + PrimType("float32"), + """ + T.float32 + """, + ), + ( + PointerType(PrimType("int32")), + """ + T.Ptr(T.int32) + """, + ), + ( + PointerType(PrimType("int32"), "global"), + """ + T.Ptr(T.int32, "global") + """, + ), + ( + TupleType([]), + """ + None + """, + ), + ], +) +def test_type(ty, expected): + assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + + +@pytest.mark.parametrize("var_type", [Var, SizeVar]) +def test_var(var_type): + var = var_type("x", "int8") + + assert script(var, "tir", {"tir": "T"}) == format_script( + """ + x: T.int8 + x + """ + ) From 882cfc5c6f4c04c33734775a2650b872db4ebbc2 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 19 Aug 2022 12:57:52 -0400 Subject: [PATCH 3/8] Add TIR buffer printing Co-authored-by: Junru Shao Co-authored-by: Greg Bonik Fix lint --- include/tvm/script/printer/visit_traced.h | 37 +++ src/script/printer/tir/buffer.cc | 298 ++++++++++++++++++ src/script/printer/tir/buffer.h | 105 ++++++ src/script/printer/tir/expr.cc | 35 ++ src/script/printer/tir/tir.cc | 19 ++ src/script/printer/tir/tir.h | 7 + src/script/printer/utils.h | 11 + src/script/printer/visit_traced.cc | 91 ++++++ .../unittest/test_tvmscript_printer_tir.py | 89 +++++- 9 files changed, 684 insertions(+), 8 deletions(-) create mode 100644 include/tvm/script/printer/visit_traced.h create mode 100644 src/script/printer/tir/buffer.cc create mode 100644 src/script/printer/tir/buffer.h create mode 100644 src/script/printer/tir/expr.cc create mode 100644 src/script/printer/visit_traced.cc diff --git a/include/tvm/script/printer/visit_traced.h b/include/tvm/script/printer/visit_traced.h new file mode 100644 index 000000000000..62f5d3a5d641 --- /dev/null +++ b/include/tvm/script/printer/visit_traced.h @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ +#define TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ + +#include + +namespace tvm { +namespace script { +namespace printer { + +void PostOrderVisitTraced(const TracedObject& object, + const std::function& node_predicate, + const std::function&)>& callback); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_VISIT_TRACED_H_ diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc new file mode 100644 index 000000000000..c85169865cb4 --- /dev/null +++ b/src/script/printer/tir/buffer.cc @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "./buffer.h" + +#include +#include +#include + +#include + +#include "../utils.h" +#include "./tir.h" +#include "tvm/runtime/data_type.h" + +namespace tvm { +namespace script { +namespace printer { + +ExprDoc BufferPrintInfo::AsCall( + const ExprDoc& prefix, std::function&)> converter) const { + return AsCall(prefix, {}, converter); +} + +ExprDoc BufferPrintInfo::AsCall( + const ExprDoc& prefix, const Array& extra_args, + std::function&)> converter) const { + Array args(extra_args); + Array kwargs_keys; + Array kwargs_values; + { + Array results; + results.reserve(shape.size()); + for (TracedObject e : shape) { + results.push_back(converter(e)); + } + kwargs_keys.push_back("shape"); + kwargs_values.push_back(TupleDoc(results)); + } + if (dtype.defined()) { + args.push_back(dtype.value()); + } + if (data.defined()) { + kwargs_keys.push_back("data"); + kwargs_values.push_back(converter(data.value())); + } + if (strides.defined()) { + Array results; + results.reserve(strides.value().size()); + for (TracedObject stride : strides.value()) { + results.push_back(converter(stride)); + } + kwargs_keys.push_back("strides"); + kwargs_values.push_back(TupleDoc(results)); + } + if (elem_offset.defined()) { + kwargs_keys.push_back("elem_offset"); + kwargs_values.push_back(converter(elem_offset.value())); + } + if (scope.defined()) { + kwargs_keys.push_back("scope"); + kwargs_values.push_back(scope.value()); + } + if (align.defined()) { + kwargs_keys.push_back("align"); + kwargs_values.push_back(align.value()); + } + if (offset_factor.defined()) { + kwargs_keys.push_back("offset_factor"); + kwargs_values.push_back(offset_factor.value()); + } + if (buffer_type.defined()) { + kwargs_keys.push_back("buffer_type"); + kwargs_values.push_back(buffer_type.value()); + } + return prefix->Call(args, kwargs_keys, kwargs_values); +} + +static Optional GetBufferScope(const TracedObject& buffer) { + auto data = buffer.GetAttr(&tir::BufferNode::data); + auto type = data.GetAttr(&tir::VarNode::type_annotation).Downcast(); + auto scope = type.GetAttr(&PointerTypeNode::storage_scope); + if (scope.Get().empty() || scope.Get() == "global") { + return NullOpt; + } else { + return LiteralDoc::Str(scope); + } +} + +static Optional GetBufferDtype(const TracedObject& buffer) { + auto dtype = buffer.GetAttr(&tir::BufferNode::dtype); + if (dtype.Get() == DataType::Float(32)) { + return NullOpt; + } else { + return DType2Literal(dtype); + } +} + +static bool HasDefaultDataPtr(const tir::Buffer& buffer) { + const auto* ptr_type = buffer->data->type_annotation.as(); + ICHECK(ptr_type) << "Buffer variable is not of pointer type"; + if (const auto* element_type = ptr_type->element_type.as()) { + DataType default_dtype = buffer->dtype; + if (buffer->dtype.is_bool()) { + default_dtype = DataType::Int(8); + } + return element_type->dtype == default_dtype; + } else { + return false; + } +} + +static TracedOptional GetBufferData(const TracedObject& buffer, + const BufferAssociatedVariables& associated_vars) { + auto data = buffer.GetAttr(&tir::BufferNode::data); + if (associated_vars.IsAssociatedWith(data.Get(), buffer.Get()) && + HasDefaultDataPtr(buffer.Get())) { + return TracedOptional(NullOpt, data.GetPath()); + } else { + return data; + } +} + +static TracedOptional> GetBufferStrides(const TracedObject& buffer) { + auto strides = buffer.GetAttr(&tir::BufferNode::strides); + if (!strides.empty()) { + return TracedOptional>(strides); + } else { + return TracedOptional>(NullOpt, strides.GetPath()); + } +} + +static TracedOptional GetBufferElemOffset( + const TracedObject& buffer, const BufferAssociatedVariables& associated_vars) { + auto elem_offset = buffer.GetAttr(&tir::BufferNode::elem_offset); + if (elem_offset.defined()) { + // Don't print the offset if it is an associated variable + if (elem_offset.IsInstance() && + associated_vars.IsAssociatedWith(elem_offset.Get(), buffer.Get())) { + return TracedOptional(NullOpt, elem_offset.GetPath()); + } + + // Don't print the offset if it is zero + if (auto i = elem_offset.TryDowncast()) { + if (i.value().Get()->value == 0 && i.value().Get()->dtype == DataType::Int(32)) { + return TracedOptional(NullOpt, elem_offset.GetPath()); + } + } + } + return elem_offset; +} + +static Optional GetBufferAlignment(const TracedObject& buffer) { + auto data_alignment = buffer.GetAttr(&tir::BufferNode::data_alignment); + if (data_alignment.Get() != runtime::kAllocAlignment) { + return LiteralDoc::Int(data_alignment); + } else { + return NullOpt; + } +} + +static Optional GetBufferOffsetFactor(const TracedObject& buffer) { + auto offset_factor = buffer.GetAttr(&tir::BufferNode::offset_factor); + if (offset_factor.Get() != 1) { + return LiteralDoc::Int(offset_factor); + } else { + return NullOpt; + } +} + +static Optional GetBufferType(const TracedObject& buffer) { + auto buffer_type = buffer.GetAttr(&tir::BufferNode::buffer_type); + if (buffer_type.Get() != tir::BufferType::kDefault) { + return LiteralDoc::Str(MakeTraced(String("auto"), buffer_type.GetPath())); + } else { + return NullOpt; + } +} + +std::vector GetBufferPrintInfo( + const std::vector>& buffers, // + std::function f_var_defined, + std::unordered_map* var_explicit_def, + BufferAssociatedVariables* associated_vars) { + using namespace tvm::tir; + auto check_explicit_def = [&](const TracedObject& e) -> void { + PostOrderVisitExprTraced(e, [&](const TracedObject& n) -> void { + if (const auto* v = n.Get().as()) { + if (!f_var_defined(v) && !associated_vars->IsAssociated(v)) { + var_explicit_def->insert({v, n.GetPath()}); + } + } + }); + }; + for (const TracedObject& traced_buffer : buffers) { + const Buffer& buffer = traced_buffer.Get(); + if (!f_var_defined(buffer->data.get()) && HasDefaultDataPtr(buffer)) { + associated_vars->AssociateIfNotAlready(buffer->data.get(), buffer); + } + if (const auto* elem_offset = buffer->elem_offset.as()) { + if (!f_var_defined(elem_offset)) { + associated_vars->AssociateIfNotAlready(elem_offset, buffer); + } + } + } + for (TracedObject buffer : buffers) { + auto shape = buffer.GetAttr(&tir::BufferNode::shape); + std::for_each(shape.begin(), shape.end(), check_explicit_def); + + auto strides = buffer.GetAttr(&tir::BufferNode::strides); + std::for_each(strides.begin(), strides.end(), check_explicit_def); + + check_explicit_def(buffer.GetAttr(&tir::BufferNode::data)); + check_explicit_def(buffer.GetAttr(&tir::BufferNode::elem_offset)); + } + std::vector results; + for (TracedObject buffer : buffers) { + results.push_back( + BufferPrintInfo{/* .buffer = */ buffer, + /* .shape = */ buffer.GetAttr(&tir::BufferNode::shape), + /* .dtype = */ GetBufferDtype(buffer), + /* .data = */ GetBufferData(buffer, *associated_vars), + /* .strides = */ GetBufferStrides(buffer), + /* .elem_offset = */ GetBufferElemOffset(buffer, *associated_vars), + /* .scope = */ GetBufferScope(buffer), + /* .align = */ GetBufferAlignment(buffer), + /* .offset_factor = */ GetBufferOffsetFactor(buffer), + /* .buffer_type = */ GetBufferType(buffer)}); + } + return results; +} + +TracedObject GetBufferNameHint(const TracedObject& buf) { + TracedObject name_hint = buf.GetAttr(&tir::BufferNode::name); + if (name_hint.Get().empty()) { + return MakeTraced(String("buf"), buf.GetPath()); + } else { + return name_hint; + } +} + +IdDoc DefineFreeBuffer(const TracedObject& buf, const Frame& frame, + const IRDocsifier& p, std::function add_definiton) { + TracedObject name_hint = GetBufferNameHint(buf); + IdDoc buf_doc = p->vars->Define(buf.Get(), name_hint, frame); + auto f_var_defined = [&p](const tir::VarNode* var) -> bool { + return p->vars->IsVarDefined(GetRef(var)); + }; + std::unordered_map var_explicit_def; + BufferAssociatedVariables associated_vars; + + BufferPrintInfo buffer_print_info = + GetBufferPrintInfo({buf}, f_var_defined, &var_explicit_def, &associated_vars)[0]; + + associated_vars.Define(p->vars.get(), frame); + + ExprDoc buf_definition = buffer_print_info.AsCall( + TIR(p)->Attr("Buffer"), + [&p](const TracedObject& expr) -> ExprDoc { return p->AsDoc(expr); }); + add_definiton(AssignDoc(buf_doc, NullOpt, buf_definition)); + + return buf_doc; +} + +ExprDoc PrintBuffer(TracedObject buf, IRDocsifier p) { + Optional doc = p->vars->GetVarDoc(buf); + if (doc.defined()) { + return doc.value(); + } else { + // TODO(yelite): When implementing the PrimFunc printing, the logic here + // needs to change, putting variable def into PrimFuncFrame if it exists. + TIRTopLevelFrame top_level_frame = p->GetFrame().value(); + return DefineFreeBuffer(buf, top_level_frame, p, [top_level_frame](StmtDoc definition) { + top_level_frame->free_var_definitions.push_back(definition); + }); + } +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBuffer); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/buffer.h b/src/script/printer/tir/buffer.h new file mode 100644 index 000000000000..10c2039b4e57 --- /dev/null +++ b/src/script/printer/tir/buffer.h @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ +#define TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ + +#include +#include +#include + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +class BufferAssociatedVariables { + public: + void Disassociate(const tir::VarNode* var) { var2buffer_.erase(var); } + + void AssociateIfNotAlready(const tir::VarNode* var, const tir::Buffer& buffer) { + var2buffer_.insert({var, buffer}); + } + + bool IsAssociated(const tir::VarNode* var) const { return var2buffer_.count(var) != 0; } + + bool IsAssociatedWith(const PrimExpr& e, const tir::Buffer& buffer) const { + if (const auto* v = e.as()) { + auto it = var2buffer_.find(v); + return it != var2buffer_.end() && it->second == buffer; + } + return false; + } + + void Define(VarTableNode* vars, const Frame& frame) const { + for (const auto& kv : var2buffer_) { + const tir::VarNode* var = kv.first; + const tir::Buffer& buffer = kv.second; + + ExprDoc buffer_name = vars->GetVarDoc(MakeTraced(buffer)).value(); + buffer_name->source_paths.clear(); + + if (buffer->data.get() == var) { + vars->DefineByDoc( + buffer->data, [buffer_name]() { return buffer_name->Attr("data"); }, frame); + } else if (buffer->elem_offset.get() == var) { + vars->DefineByDoc( + buffer->elem_offset, [buffer_name]() { return buffer_name->Attr("elem_offset"); }, + frame); + } else { + ICHECK(false) << "Unexpected association. Buffer: " << buffer + << "; Var: " << GetRef(var); + } + } + } + + private: + std::unordered_map var2buffer_; +}; + +struct BufferPrintInfo { + TracedObject buffer; + TracedArray shape; + Optional dtype; + TracedOptional data; + TracedOptional> strides; + TracedOptional elem_offset; + Optional scope; + Optional align; + Optional offset_factor; + Optional buffer_type; + + ExprDoc AsCall(const ExprDoc& prefix, + std::function&)> converter) const; + ExprDoc AsCall(const ExprDoc& prefix, const Array& extra_args, + std::function&)> converter) const; +}; + +std::vector GetBufferPrintInfo( + const std::vector>& buffers, // + std::function f_var_defined, + std::unordered_map* var_explicit_def, + BufferAssociatedVariables* associated_vars); + +} // namespace printer +} // namespace script +} // namespace tvm + +#endif // TVM_SCRIPT_PRINTER_TIR_BUFFER_H_ diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc new file mode 100644 index 000000000000..fad85f269668 --- /dev/null +++ b/src/script/printer/tir/expr.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include + +#include "../utils.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject i, IRDocsifier p) { return LiteralDoc::Int(i); }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc index 38bd94a72bb5..48975791bd15 100644 --- a/src/script/printer/tir/tir.cc +++ b/src/script/printer/tir/tir.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include namespace tvm { @@ -33,6 +34,24 @@ TIRTopLevelFrame::TIRTopLevelFrame() : TIRFrame(make_object()) {} +void PostOrderVisitExprTraced(const TracedObject& expr, + const std::function&)>& callback) { + PostOrderVisitTraced( + expr, [](const ObjectRef& object) { return object->IsInstance(); }, + [&](const TracedObject& object) { callback(object.Downcast()); }); +} + +void PostOrderVisitStmtExprTraced( + const TracedObject& stmt, + const std::function&)>& callback) { + PostOrderVisitTraced( + stmt, + [](const ObjectRef& object) { + return object->IsInstance() || object->IsInstance(); + }, + [&](const TracedObject& object) { callback(object); }); +} + ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p) { auto type_annotation = var.GetAttr(&tir::VarNode::type_annotation); if (type_annotation.Get().defined()) { diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index bb5973ee4f3b..eb9b3f544d24 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -82,6 +82,13 @@ inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").va ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); +void PostOrderVisitExprTraced(const TracedObject& expr, + const std::function&)>& callback); + +void PostOrderVisitStmtExprTraced( + const TracedObject& expr, + const std::function&)>& callback); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index abe7ce5e9a88..eef93b9348d9 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -86,6 +86,17 @@ inline TupleDoc AsTupleDoc(const TracedArray& arr, const IRDocsifier& ir_docs return ret; } +inline LiteralDoc DType2Literal(const DLDataType& dtype) { + using runtime::DLDataType2String; + return LiteralDoc::Str(DLDataType2String(dtype)); +} + +inline LiteralDoc DType2Literal(const TracedBasicValue& dtype) { + auto doc = DType2Literal(dtype.Get()); + doc->source_paths.push_back(dtype.GetPath()); + return doc; +} + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/visit_traced.cc b/src/script/printer/visit_traced.cc new file mode 100644 index 000000000000..87ab787c2074 --- /dev/null +++ b/src/script/printer/visit_traced.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +void PostOrderVisitTracedImpl(const ObjectRef& object, const ObjectPath& path, + const std::function& node_predicate, + const std::function&)>& callback); + +struct ObjAttrVisitor : public AttrVisitor { + ObjAttrVisitor(const ObjectPath& path, const std::function node_predicate, + const std::function&)>& callback) + : path(path), node_predicate(node_predicate), callback(callback) {} + + const ObjectPath& path; + const std::function node_predicate; + const std::function&)>& callback; + + void Visit(const char* key, double* value) final {} + void Visit(const char* key, int64_t* value) final {} + void Visit(const char* key, uint64_t* value) final {} + void Visit(const char* key, int* value) final {} + void Visit(const char* key, bool* value) final {} + void Visit(const char* key, void** value) final {} + void Visit(const char* key, DataType* value) final {} + void Visit(const char* key, std::string* value) final {} + void Visit(const char* key, runtime::NDArray* value) final {} + void Visit(const char* key, ObjectRef* value) final { + PostOrderVisitTracedImpl(*value, path->Attr(key), node_predicate, callback); + } +}; + +void PostOrderVisitTracedImpl(const ObjectRef& object, const ObjectPath& path, + const std::function& node_predicate, + const std::function&)>& callback) { + if (!object.defined()) { + return; + } + + if (object->IsInstance()) { + const ArrayNode* node = static_cast(object.get()); + for (size_t i = 0; i < node->size(); ++i) { + PostOrderVisitTracedImpl(node->at(i), path->ArrayIndex(i), node_predicate, callback); + } + } else if (object->IsInstance()) { + const MapNode* node = static_cast(object.get()); + for (auto kv : *node) { + PostOrderVisitTracedImpl(kv.second, path->MapValue(kv.first), node_predicate, callback); + } + } else { + if (!node_predicate(object)) { + return; + } + + ObjAttrVisitor visitor(path, node_predicate, callback); + ReflectionVTable::Global()->VisitAttrs(const_cast(object.get()), &visitor); + + callback(MakeTraced(object, path)); + } +} + +void PostOrderVisitTraced(const TracedObject& object, + const std::function& node_predicate, + const std::function&)>& callback) { + PostOrderVisitTracedImpl(object.Get(), object.GetPath(), node_predicate, callback); +} + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 936a2d74b48e..2d597f512a87 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -18,7 +18,7 @@ from tvm.ir import PointerType, PrimType, TupleType from tvm.script.printer import script -from tvm.tir import SizeVar, Var +from tvm.tir import SizeVar, Var, decl_buffer def format_script(s: str) -> str: @@ -41,52 +41,125 @@ def format_script(s: str) -> str: return cleaned_lines +def as_tir_script(node): + return script(node, "tir", {"tir": "T"}) + + @pytest.mark.parametrize( "ty, expected", [ - ( + pytest.param( PrimType("int8"), """ T.int8 """, + id="int", ), - ( + pytest.param( PrimType("float32"), """ T.float32 """, + id="float", ), - ( + pytest.param( PointerType(PrimType("int32")), """ T.Ptr(T.int32) """, + id="pointer", ), - ( + pytest.param( PointerType(PrimType("int32"), "global"), """ T.Ptr(T.int32, "global") """, + id="with_scope", ), - ( + pytest.param( TupleType([]), """ None """, + id="none", ), ], ) def test_type(ty, expected): - assert format_script(expected) == script(ty, "tir", {"tir": "T"}) + assert format_script(expected) == as_tir_script(ty) @pytest.mark.parametrize("var_type", [Var, SizeVar]) def test_var(var_type): var = var_type("x", "int8") - assert script(var, "tir", {"tir": "T"}) == format_script( + assert as_tir_script(var) == format_script( """ x: T.int8 x """ ) + + +@pytest.mark.parametrize( + "buffer, expected", + [ + pytest.param( + decl_buffer((5, 10), name="b"), + """ + b: T.Buffer(shape=(5, 10)) + b + """, + id="simple", + ), + pytest.param( + decl_buffer((5), name=""), + """ + buf: T.Buffer(shape=(5,)) + buf + """, + id="no_name", + ), + pytest.param( + decl_buffer((SizeVar("m", "int"), SizeVar("n", "int")), dtype="int8"), + """ + m: T.int32 + n: T.int32 + buffer: T.Buffer("int8", shape=(m, n)) + buffer + """, + id="symbolic_shape", + ), + pytest.param( + decl_buffer( + (4, 10), + dtype="int8", + data=Var("p", PointerType(PrimType("int16"), "local")), + strides=[2, 5], + elem_offset=2, + data_alignment=16, + offset_factor=2, + scope="local", + ), + """ + p: T.Ptr(T.int16, "local") + buffer: T.Buffer("int8", shape=(4, 10), data=p, strides=(2, 5), elem_offset=2, scope="local", align=16, offset_factor=2) + buffer + """, + id="all_param", + ), + pytest.param( + decl_buffer( + (4, 10), + dtype="bool", + ), + """ + buffer: T.Buffer("bool", shape=(4, 10)) + buffer + """, + id="bool_different_ptr_type", + ), + ], +) +def test_buffer(buffer, expected): + assert as_tir_script(buffer) == format_script(expected) From 373ff0f2f07ba45b4a1e53577b611682a318d7e9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Tue, 23 Aug 2022 00:54:07 -0400 Subject: [PATCH 4/8] Add TIR expr printing Co-authored-by: Greg Bonik --- src/script/printer/tir/expr.cc | 249 ++++++++- src/script/printer/tir/tir.cc | 57 +++ src/script/printer/tir/tir.h | 7 + .../unittest/test_tvmscript_printer_tir.py | 481 +++++++++++++++++- 4 files changed, 791 insertions(+), 3 deletions(-) diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index fad85f269668..f3eb46ca5fe7 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -17,19 +17,266 @@ * under the License. */ +#include "tvm/ir/expr.h" + +#include +#include #include #include #include +#include #include "../utils.h" +#include "./tir.h" namespace tvm { namespace script { namespace printer { TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch([](TracedObject i, IRDocsifier p) { return LiteralDoc::Int(i); }); + .set_dispatch([](TracedObject s, IRDocsifier p) { + auto value = s.GetAttr(&tir::StringImmNode::value); + return LiteralDoc::Str(value); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject i, IRDocsifier p) -> ExprDoc { + const IntImm& node = i.Get(); + if (node->dtype == DataType::Int(32)) { + return LiteralDoc::Int(i); + } else if (node->dtype.is_bool()) { + return LiteralDoc::Boolean(MakeTraced(i.Get()->value != 0, i.GetPath())); + } else { + String type_name = runtime::DLDataType2String(node->dtype); + return TIR(p) + ->Attr(MakeTraced(type_name, i.GetAttr(&PrimExprNode::dtype).GetPath())) + ->Call({LiteralDoc::Int(i)}); + } + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject f, IRDocsifier p) { + String type_name = runtime::DLDataType2String(f.Get()->dtype); + return TIR(p) + ->Attr(MakeTraced(type_name, f.GetAttr(&PrimExprNode::dtype).GetPath())) + ->Call({LiteralDoc::Float(f)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject cast, IRDocsifier p) { + auto value = cast.GetAttr(&tir::CastNode::value); + auto dtype = cast.GetAttr(&tir::CastNode::dtype); + return TIR(p)->Attr("cast")->Call({p->AsExprDoc(value), DType2Literal(dtype)}); + }); + +using OpKind = OperationDocNode::Kind; + +template +ExprDoc PrintBinOp(TracedObject expr, IRDocsifier p) { + using NodeType = typename BinOpType::ContainerType; + auto a = expr.GetAttr(&NodeType::a); + auto b = expr.GetAttr(&NodeType::b); + return OperationDoc(op_kind, {p->AsExprDoc(a), p->AsExprDoc(b)}); +} + +#define TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(Op, DocOpKind) \ + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBinOp); + +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Add, OpKind::kAdd); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Sub, OpKind::kSub); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Mul, OpKind::kMult); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Div, OpKind::kDiv); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::FloorDiv, OpKind::kFloorDiv); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::FloorMod, OpKind::kMod); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::LT, OpKind::kLt); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::LE, OpKind::kLtE); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::GT, OpKind::kGt); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::GE, OpKind::kGtE); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::EQ, OpKind::kEq); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::NE, OpKind::kNotEq); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::And, OpKind::kAnd); +TVM_SCRIPT_PRINTER_SET_TIR_BINARY_OP(tir::Or, OpKind::kOr); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) { + return OperationDoc(OpKind::kNot, {p->AsExprDoc(e.GetAttr(&tir::NotNode::a))}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto condition = expr.GetAttr(&tir::SelectNode::condition); + auto true_value = expr.GetAttr(&tir::SelectNode::true_value); + auto false_value = expr.GetAttr(&tir::SelectNode::false_value); + return TIR(p)->Attr("Select")->Call( + {p->AsExprDoc(condition), p->AsExprDoc(true_value), p->AsExprDoc(false_value)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto buffer = expr.GetAttr(&tir::BufferLoadNode::buffer); + auto indices = expr.GetAttr(&tir::BufferLoadNode::indices); + + ExprDoc base = p->AsExprDoc(buffer); + return base[AsDocArray(indices, p)]; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) -> Doc { + LOG(FATAL) + << "Cannot print a tir.ProducerLoad as it is not valid in TIR Primfuncs. You need to " + "lower this function first."; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print a tir.Load"; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto base = expr.GetAttr(&tir::RampNode::base); + auto stride = expr.GetAttr(&tir::RampNode::stride); + auto lanes = expr.GetAttr(&tir::RampNode::lanes); + return TIR(p)->Attr("ramp")->Call( + {p->AsExprDoc(base), p->AsExprDoc(stride), LiteralDoc::Int(lanes)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto value = expr.GetAttr(&tir::BroadcastNode::value); + auto lanes = expr.GetAttr(&tir::BroadcastNode::lanes); + return TIR(p)->Attr("broadcast")->Call({p->AsExprDoc(value), LiteralDoc::Int(lanes)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + TIRGeneralFrame frame; + WithCtx with_frame = p->WithFrame(frame); + + auto var = expr.GetAttr(&tir::LetNode::var); + auto value = expr.GetAttr(&tir::LetNode::value); + auto body = expr.GetAttr(&tir::LetNode::body); + + auto value_doc = p->AsExprDoc(value); + IdDoc var_doc = DefineTIRVar(var, frame, p); + return TIR(p)->Attr("let")->Call({var_doc, value_doc, p->AsExprDoc(body)}); + }); + +ExprDoc PrintCall(TracedObject call, IRDocsifier p) { + auto op_or_global_var = call.GetAttr(&tir::CallNode::op); + + if (op_or_global_var.IsInstance()) { + // TODO(yelite): Call PrintOpCall once it's finished + TracedObject op_name = op_or_global_var.Downcast().GetAttr(&OpNode::name); + Array arg_docs{LiteralDoc::Str(op_name)}; + TracedArray args = call.GetAttr(&tir::CallNode::args); + arg_docs = Concat(arg_docs, AsExprDocArray(args, p)); + return TIR(p)->Attr("call")->Call(arg_docs); + } else { + auto op_gvar = op_or_global_var.Downcast(); + auto name_hint = op_gvar.GetAttr(&GlobalVarNode::name_hint); + auto args = call.GetAttr(&tir::CallNode::args); + + IdDoc name_doc(name_hint.Get()); + name_doc->source_paths.push_back(name_hint.GetPath()); + + return name_doc->Call(AsExprDocArray(args, p)); + } +} +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintCall); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto vectors = expr.GetAttr(&tir::ShuffleNode::vectors); + auto indices = expr.GetAttr(&tir::ShuffleNode::indices); + return TIR(p)->Attr("shuffle")->Call({AsListDoc(vectors, p), AsListDoc(indices, p)}); + }); + +ExprDoc PrintCommReducer(TracedObject expr, IRDocsifier p) { + TIRGeneralFrame frame; + WithCtx with_frame = p->WithFrame(frame); + + auto lhs = expr.GetAttr(&tir::CommReducerNode::lhs); + auto rhs = expr.GetAttr(&tir::CommReducerNode::rhs); + + Array reducer_args; + for (TracedObject v_lhs : lhs) { + IdDoc var_doc = DefineTIRVar(v_lhs, frame, p); + reducer_args.push_back(var_doc); + } + for (TracedObject v_rhs : rhs) { + IdDoc var_doc = DefineTIRVar(v_rhs, frame, p); + reducer_args.push_back(var_doc); + } + + auto result = expr.GetAttr(&tir::CommReducerNode::result); + + ExprDoc reducer_body = rhs.size() == 1 ? p->AsExprDoc(result[0]) : AsTupleDoc(result, p); + + LambdaDoc reducer{reducer_args, reducer_body}; + + auto identity_element = expr.GetAttr(&tir::CommReducerNode::identity_element); + ListDoc identity_elements_doc = AsListDoc(identity_element, p); + + return TIR(p)->Attr("comm_reducer")->Call({reducer, identity_elements_doc}); +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintCommReducer); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto combiner = expr.GetAttr(&tir::ReduceNode::combiner); + auto source = expr.GetAttr(&tir::ReduceNode::source); + auto axis = expr.GetAttr(&tir::ReduceNode::axis); + auto value_index = expr.GetAttr(&tir::ReduceNode::value_index); + + Array axis_docs; + for (const auto& iter_var : axis) { + axis_docs.push_back(IterVarStandaloneDef(iter_var, p)); + } + ListDoc axis_list_doc = ListDoc(axis_docs); + axis_list_doc->source_paths.push_back(axis.GetPath()); + + return TIR(p)->Attr("reduce")->Call({p->AsExprDoc(combiner), AsListDoc(source, p), + axis_list_doc, LiteralDoc::Int(value_index)}); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject expr, IRDocsifier p) { + auto min = expr.GetAttr(&RangeNode::min); + auto extent = expr.GetAttr(&RangeNode::extent); + auto max = MakeTraced(min.Get() + extent.Get(), extent.GetPath()); + return SliceDoc(p->AsExprDoc(min), p->AsExprDoc(max), NullOpt); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject e, IRDocsifier p) -> Doc { + LOG(FATAL) << "Cannot print tir::Any"; + throw; + }); + +ExprDoc PrintBufferRegion(TracedObject buffer_region, IRDocsifier p) { + auto region = buffer_region.GetAttr(&tir::BufferRegionNode::region); + + Array indices; + + for (TracedObject range : region) { + auto extent = range.GetAttr(&RangeNode::extent); + if (tir::is_one(extent.Get())) { + auto index = p->AsExprDoc(range.GetAttr(&RangeNode::min)); + index->source_paths.push_back(extent.GetPath()); + indices.push_back(std::move(index)); + } else { + indices.push_back(p->AsDoc(range)); + } + } + auto buffer = buffer_region.GetAttr(&tir::BufferRegionNode::buffer); + return p->AsExprDoc(buffer)[indices]; +} +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch(PrintBufferRegion); } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/tir/tir.cc b/src/script/printer/tir/tir.cc index 48975791bd15..cbe88d5bc2c4 100644 --- a/src/script/printer/tir/tir.cc +++ b/src/script/printer/tir/tir.cc @@ -63,6 +63,63 @@ ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDo } } +static String GetIterTypePyStr(tir::IterVarType iter_type) { + switch (iter_type) { + case tir::kDataPar: + return "DataPar"; + case tir::kThreadIndex: + return "ThreadIndex"; + case tir::kCommReduce: + return "CommReduce"; + case tir::kOrdered: + return "Ordered"; + case tir::kOpaque: + return "DimInfo"; + case tir::kUnrolled: + return "Unrolled"; + case tir::kVectorized: + return "Vectorized"; + case tir::kParallelized: + return "Parallelized"; + case tir::kTensorized: + return "Tensorized"; + default: + LOG(FATAL) << "Unknown iter type: " << iter_type; + throw; + } +} + +ExprDoc IterVarStandaloneDef(const TracedObject iter_var, const IRDocsifier& p) { + Array args; + + args.push_back(p->AsExprDoc(iter_var.GetAttr(&tir::IterVarNode::var))); + + if (iter_var.Get()->dom.defined()) { + auto dom = iter_var.GetAttr(&tir::IterVarNode::dom); + auto min = dom.GetAttr(&RangeNode::min); + auto extent = dom.GetAttr(&RangeNode::extent); + if (tir::is_zero(min.Get())) { + auto extent_doc = p->AsExprDoc(extent); + extent_doc->source_paths.push_back(min.GetPath()); + args.push_back(extent_doc); + } else { + auto max = MakeTraced(min.Get() + extent.Get(), extent.GetPath()); + args.push_back(TupleDoc({p->AsExprDoc(min), p->AsExprDoc(max)})); + } + } else { + args.push_back(LiteralDoc::None()); + } + + auto iter_type = iter_var.GetAttr(&tir::IterVarNode::iter_type); + args.push_back( + LiteralDoc::Str(MakeTraced(GetIterTypePyStr(iter_type.Get()), iter_type.GetPath()))); + args.push_back(LiteralDoc::Str(iter_var.GetAttr(&tir::IterVarNode::thread_tag))); + + ExprDoc result = TIR(p)->Attr("iter_var")->Call(args); + result->source_paths.push_back(iter_var.GetPath()); + return result; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch("tir", [](TracedObject obj, IRDocsifier p) { const ObjectRef& root_node = obj.Get()->root_node; diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index eb9b3f544d24..6020fe335e54 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -80,6 +80,10 @@ class TIRGeneralFrame : public TIRFrame { inline IdDoc TIR(const IRDocsifier& p) { return IdDoc(p->ir_prefix.Get("tir").value_or("T")); } +inline IdDoc DefineTIRVar(const TracedObject& var, const Frame& frame, IRDocsifier p) { + return p->vars->Define(var.Get(), var.GetAttr(&tir::VarNode::name_hint), frame); +} + ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); void PostOrderVisitExprTraced(const TracedObject& expr, @@ -89,6 +93,9 @@ void PostOrderVisitStmtExprTraced( const TracedObject& expr, const std::function&)>& callback); +// Print IterVar as T.iter_var(...) +ExprDoc IterVarStandaloneDef(const TracedObject iter_var, const IRDocsifier& p); + } // namespace printer } // namespace script } // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 2d597f512a87..b03ce24fa808 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -16,9 +16,43 @@ # under the License. import pytest -from tvm.ir import PointerType, PrimType, TupleType +from tvm.ir import GlobalVar, PointerType, PrimType, Range, TupleType from tvm.script.printer import script -from tvm.tir import SizeVar, Var, decl_buffer +from tvm.tir import ( + EQ, + GE, + GT, + LE, + LT, + NE, + Add, + And, + Broadcast, + BufferLoad, + BufferRegion, + Call, + Cast, + CommReducer, + Div, + FloatImm, + FloorDiv, + FloorMod, + IntImm, + IterVar, + Let, + Mul, + Not, + Or, + Ramp, + Reduce, + Select, + Shuffle, + SizeVar, + StringImm, + Sub, + Var, + decl_buffer, +) def format_script(s: str) -> str: @@ -163,3 +197,446 @@ def test_var(var_type): ) def test_buffer(buffer, expected): assert as_tir_script(buffer) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + StringImm("test"), + """ + "test" + """, + id="string", + ), + pytest.param( + StringImm(""), + """ + "" + """, + id="empty", + ), + pytest.param( + StringImm("test1\ntest2\n"), + r""" + "test1\ntest2\n" + """, + id="multiline", + ), + ], +) +def test_string_imm(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + IntImm("int32", 1), + """ + 1 + """, + id="default-dtype", + ), + pytest.param( + IntImm("int8", 0), + """ + T.int8(0) + """, + id="int8", + ), + pytest.param( + IntImm("int64", -1), + """ + T.int64(-1) + """, + id="int64", + ), + pytest.param( + IntImm("bool", 1), + """ + True + """, + id="boolean-true", + ), + pytest.param( + IntImm("bool", 0), + """ + False + """, + id="boolean-true", + ), + ], +) +def test_int_imm(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + FloatImm("float32", 1.5), + """ + T.float32(1.5) + """, + id="f32", + ), + pytest.param( + FloatImm("float16", 1.5), + """ + T.float16(1.5) + """, + id="f16", + ), + ], +) +def test_float_imm(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + Cast("float32", Var("x", dtype="int32")), + """ + x: T.int32 + T.cast(x, "float32") + """, + id="with-var", + ), + pytest.param( + Cast("int8", IntImm("int16", 1)), + """ + T.cast(T.int16(1), "int8") + """, + id="with-var", + ), + ], +) +def test_cast(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + Add(1, 0), + """ + 1 + 0 + """, + id="add", + ), + pytest.param( + Sub(1, 0), + """ + 1 - 0 + """, + id="sub", + ), + pytest.param( + Mul(-2.5, 1.5), + """ + T.float32(-2.5) * T.float32(1.5) + """, + id="mul", + ), + pytest.param( + Div(5, 2), + """ + 5 / 2 + """, + id="div", + ), + pytest.param( + FloorDiv(3, 2), + """ + 3 // 2 + """, + id="floor-div", + ), + pytest.param( + FloorMod(IntImm("int8", 5), IntImm("int8", 2)), + """ + T.int8(5) % T.int8(2) + """, + id="floor-mod", + ), + pytest.param( + LT(0, 1), + """ + 0 < 1 + """, + id="lt", + ), + pytest.param( + LE(1.0, 5.0), + """ + T.float32(1) <= T.float32(5) + """, + id="le", + ), + pytest.param( + GT(1, 0), + """ + 1 > 0 + """, + id="gt", + ), + pytest.param( + GE(Var("n", "int32"), 0), + """ + n: T.int32 + n >= 0 + """, + id="ge", + ), + pytest.param( + EQ(Var("n", "int32"), Var("m", "int32")), + """ + n: T.int32 + m: T.int32 + n == m + """, + id="eq", + ), + pytest.param( + NE(1, 0), + """ + 1 != 0 + """, + id="ne", + ), + pytest.param( + And(IntImm("bool", 1), IntImm("bool", 0)), + """ + True and False + """, + id="and", + ), + pytest.param( + Or(IntImm("bool", 1), IntImm("bool", 0)), + """ + True or False + """, + id="or", + ), + ], +) +def test_binary_op(node, expected): + assert as_tir_script(node) == format_script(expected) + + +def test_not(): + assert as_tir_script(Not(IntImm("bool", 1))) == format_script( + """ + not True + """ + ) + + +def test_select(): + node = Select(IntImm("bool", 1), 0, 1) + assert as_tir_script(node) == format_script( + """ + T.Select(True, 0, 1) + """ + ) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + BufferLoad(decl_buffer((5, 10), name="b"), [0, 1]), + """ + b: T.Buffer(shape=(5, 10)) + b[0, 1] + """, + id="normal", + ), + pytest.param( + BufferLoad( + decl_buffer((5,), name="b"), + [ + 0, + ], + ), + """ + b: T.Buffer(shape=(5,)) + b[0] + """, + id="1d", + ), + pytest.param( + BufferLoad(decl_buffer((), name="b"), []), + """ + b: T.Buffer(shape=()) + b[()] + """, + id="0d", + ), + ], +) +def test_buffer_load(node, expected): + assert as_tir_script(node) == format_script(expected) + + +def test_ramp(): + node = Ramp(0, 1, 8) + expected = """ + T.ramp(0, 1, 8) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_broadcast(): + node = Broadcast(0, 4) + expected = """ + T.broadcast(0, 4) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_let(): + x = Var("x", "int32") + y = Var("y", "int32") + node = Let(x, 1, x + y) + # Not var definition for x because x isn't free variable in this expression + expected = """ + y: T.int32 + T.let(x, 1, x + y) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_call_tir_op(): + node = Call("float32", "tir.tvm_stack_make_shape", [0, 1]) + expected = """ + T.call("tir.tvm_stack_make_shape", 0, 1) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_call_global_var(): + f_var = GlobalVar("test_f") + node = Call("float32", f_var, [0, 1]) + expected = """ + test_f(0, 1) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_shuffle(): + x = Var("x", "int32") + y = Var("y", "int32") + node = Shuffle([x, 1, 10], [0, 1, y]) + expected = """ + x: T.int32 + y: T.int32 + T.shuffle([x, 1, 10], [0, 1, y]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_comm_reducer_single_value(): + x = Var("x", "int32") + y = Var("y", "int32") + node = CommReducer([x], [y], [x + y], [0]) + expected = """ + T.comm_reducer(lambda x, y: x + y, [0]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_comm_reducer_multi_value(): + x0 = Var("x0", "int32") + x1 = Var("x1", "int32") + y0 = Var("y0", "int32") + y1 = Var("y1", "int32") + node = CommReducer([x0, x1], [y0, y1], [x0 + y0, x1 * y1], [0, 1]) + expected = """ + T.comm_reducer(lambda x0, x1, y0, y1: (x0 + y0, x1 * y1), [0, 1]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_reduce(): + x = Var("x", "int32") + y = Var("y", "int32") + m = Var("m", "int32") + comm_reducer = CommReducer([x], [y], [x + y], [0]) + node = Reduce(comm_reducer, [m], [IterVar(None, Var("i", "int32"), 2)], True, 0) + expected = """ + i: T.int32 + m: T.int32 + T.reduce(T.comm_reducer(lambda x, y: x + y, [0]), [m], [T.iter_var(i, None, "CommReduce", "")], 0) + """ + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + Range(0, 1), + """ + _[0:1] + """, + id="normal", + ), + pytest.param( + Range(10), + """ + _[0:10] + """, + id="one-arg", + ), + ], +) +def test_range(node, expected): + assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + BufferRegion(decl_buffer((5, 10), name="b"), [Range(0, 4), Range(0, 9)]), + """ + b: T.Buffer(shape=(5, 10)) + b[0:4, 0:9] + """, + id="normal", + ), + pytest.param( + BufferRegion(decl_buffer((5, 10), name="b"), [Range(0, 1), Range(5, 9)]), + """ + b: T.Buffer(shape=(5, 10)) + b[0, 5:9] + """, + id="scalar-range", + ), + pytest.param( + BufferRegion(decl_buffer((5,), name="b"), [Range(0, 3)]), + """ + b: T.Buffer(shape=(5,)) + b[0:3] + """, + id="1d", + ), + pytest.param( + BufferRegion(decl_buffer((), name="b"), []), + """ + b: T.Buffer(shape=()) + b[()] + """, + id="0d", + ), + ], +) +def test_buffer_region(node, expected): + assert as_tir_script(node) == format_script(expected) From 34e20967224b7829869b2ef6b4b3d9e808aa0fc9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 24 Aug 2022 14:09:53 -0400 Subject: [PATCH 5/8] Add op.cc Co-authored-by: Greg Bonik --- src/script/printer/tir/expr.cc | 7 +- src/script/printer/tir/op.cc | 88 +++++++++++++++++++ src/script/printer/tir/tir.h | 2 + .../unittest/test_tvmscript_printer_tir.py | 4 +- 4 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 src/script/printer/tir/op.cc diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc index f3eb46ca5fe7..5cb9cd37aeb6 100644 --- a/src/script/printer/tir/expr.cc +++ b/src/script/printer/tir/expr.cc @@ -168,12 +168,7 @@ ExprDoc PrintCall(TracedObject call, IRDocsifier p) { auto op_or_global_var = call.GetAttr(&tir::CallNode::op); if (op_or_global_var.IsInstance()) { - // TODO(yelite): Call PrintOpCall once it's finished - TracedObject op_name = op_or_global_var.Downcast().GetAttr(&OpNode::name); - Array arg_docs{LiteralDoc::Str(op_name)}; - TracedArray args = call.GetAttr(&tir::CallNode::args); - arg_docs = Concat(arg_docs, AsExprDocArray(args, p)); - return TIR(p)->Attr("call")->Call(arg_docs); + return PrintOpCall(call, p); } else { auto op_gvar = op_or_global_var.Downcast(); auto name_hint = op_gvar.GetAttr(&GlobalVarNode::name_hint); diff --git a/src/script/printer/tir/op.cc b/src/script/printer/tir/op.cc new file mode 100644 index 000000000000..ff50553a15a5 --- /dev/null +++ b/src/script/printer/tir/op.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +#include "../utils.h" +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +constexpr const char kFTVMScriptOpSugarKey[] = "FTVMScriptOpSugar"; + +ExprDoc PrintOpCall(TracedObject call, IRDocsifier p) { + static auto op_sugar_map = Op::GetAttrMap(kFTVMScriptOpSugarKey); + auto op = call.GetAttr(&tir::CallNode::op).Downcast(); + auto args = call.GetAttr(&tir::CallNode::args); + + if (op_sugar_map.count(op.Get())) { + auto name_str = MakeTraced(op_sugar_map[op.Get()], op.GetPath()); + return TIR(p)->Attr(name_str)->Call(AsExprDocArray(args, p), {}, {}); + } else { + auto op_name = op.GetAttr(&OpNode::name); + Array arg_docs{LiteralDoc::Str(op_name)}; + arg_docs = Concat(arg_docs, AsExprDocArray(args, p)); + return TIR(p)->Attr("call")->Call(arg_docs); + } +} + +#define TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT(name) \ + TVM_REGISTER_OP("tir." name).set_attr(kFTVMScriptOpSugarKey, String(name)) + +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("trunc"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("exp"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("exp2"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("exp10"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("erf"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("tanh"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("sigmoid"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("sqrt"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("rsqrt"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("log"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("log2"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("log1p"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("log10"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("tan"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("cos"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("cosh"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("sin"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("sinh"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("asin"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("acos"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("atan"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("acosh"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("asinh"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("atanh"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("clz"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("atan2"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("nextafter"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("hypot"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("copysign"); +TVM_SCRIPT_TIR_OP_SUGAR_DEFAULT("ldexp"); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index 6020fe335e54..f518b50298a3 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -84,6 +84,8 @@ inline IdDoc DefineTIRVar(const TracedObject& var, const Frame& frame, return p->vars->Define(var.Get(), var.GetAttr(&tir::VarNode::name_hint), frame); } +ExprDoc PrintOpCall(TracedObject call, IRDocsifier p); + ExprDoc GetTypeAnnotationDocForVar(const TracedObject& var, const IRDocsifier& p); void PostOrderVisitExprTraced(const TracedObject& expr, diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index b03ce24fa808..f45a12706439 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -514,9 +514,9 @@ def test_let(): def test_call_tir_op(): - node = Call("float32", "tir.tvm_stack_make_shape", [0, 1]) + node = Call("float64", "tir.exp", [0.0]) expected = """ - T.call("tir.tvm_stack_make_shape", 0, 1) + T.exp(T.float32(0)) """ assert as_tir_script(node) == format_script(expected) From c591926f5055673d4a7bee9d9776d2e96d952131 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Wed, 24 Aug 2022 15:53:29 -0400 Subject: [PATCH 6/8] Add TIR simple stmt printing Co-authored-by: Greg Bonik --- src/script/printer/tir/stmt.cc | 78 +++++++++++++++++++ .../unittest/test_tvmscript_printer_tir.py | 45 +++++++++++ 2 files changed, 123 insertions(+) create mode 100644 src/script/printer/tir/stmt.cc diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc new file mode 100644 index 000000000000..3314d9c05c8c --- /dev/null +++ b/src/script/printer/tir/stmt.cc @@ -0,0 +1,78 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "../utils.h" +#include "./tir.h" + +namespace tvm { +namespace script { +namespace printer { + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + Array indices = AsExprDocArray(stmt.GetAttr(&tir::BufferStoreNode::indices), p); + Array index_docs(indices.begin(), indices.end()); + return AssignDoc(p->AsExprDoc(stmt.GetAttr(&tir::BufferStoreNode::buffer))[index_docs], + p->AsExprDoc(stmt.GetAttr(&tir::BufferStoreNode::value)), NullOpt); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + return ExprStmtDoc(p->AsExprDoc(stmt.GetAttr(&tir::EvaluateNode::value))); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) -> Doc { + LOG(FATAL) << "tir::Store cannot be printed. Store is replaced by BufferStore."; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, + IRDocsifier p) -> Doc { + LOG(FATAL) + << "tir::BufferRealize cannot be printed. All the BufferRealize should be nested inside " + "with AttrStmt."; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, + IRDocsifier p) -> Doc { + LOG(FATAL) << "tir::ProducerStore cannot be printed"; + throw; + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, + IRDocsifier p) -> Doc { + LOG(FATAL) << "tir::ProducerRealize cannot be printed"; + throw; + }); + +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index f45a12706439..188ae98c87bc 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -30,10 +30,12 @@ Broadcast, BufferLoad, BufferRegion, + BufferStore, Call, Cast, CommReducer, Div, + Evaluate, FloatImm, FloorDiv, FloorMod, @@ -640,3 +642,46 @@ def test_range(node, expected): ) def test_buffer_region(node, expected): assert as_tir_script(node) == format_script(expected) + + +@pytest.mark.parametrize( + "node, expected", + [ + pytest.param( + BufferStore(decl_buffer((5, 10), name="buf"), 1.0, [0, 1]), + """ + buf: T.Buffer(shape=(5, 10)) + buf[0, 1] = T.float32(1) + """, + id="2d", + ), + pytest.param( + BufferStore(decl_buffer((5,), name="buf"), 1.0, [0]), + """ + buf: T.Buffer(shape=(5,)) + buf[0] = T.float32(1) + """, + id="1d", + ), + pytest.param( + BufferStore(decl_buffer((), name="buf"), 1.0, []), + """ + buf: T.Buffer(shape=()) + buf[()] = T.float32(1) + """, + id="0d", + ), + ], +) +def test_buffer_store(node, expected): + assert as_tir_script(node) == format_script(expected) + + +def test_evaluate(): + var = Var("a", "int32") + node = Evaluate(3 + var) + expected = """ + a: T.int32 + 3 + a + """ + assert as_tir_script(node) == format_script(expected) From d26a634da6866f2bd1d1ca68c3717670f44d2ec9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Thu, 25 Aug 2022 13:22:45 -0400 Subject: [PATCH 7/8] Add TIR assert stmt printing Co-authored-by: Greg Bonik --- src/script/printer/tir/stmt.cc | 163 ++++++++++++++++++ src/script/printer/tir/tir.h | 2 +- .../unittest/test_tvmscript_printer_tir.py | 40 +++++ 3 files changed, 204 insertions(+), 1 deletion(-) diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 3314d9c05c8c..83c3bfd34b05 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -31,6 +31,169 @@ namespace tvm { namespace script { namespace printer { +/* + * \brief Helper to print stmt in the concise scoping form. + * + * For example, the allocate statment in TIR can be written as + * \code + * ... + * with T.allocate([16], "float32", "global") as buf: + * buf[0] = 0.0 # inside the allocate + * T.evaluate(T.call_extern(...)) # outside the allocate + * \endcode + * This representation is ambiguilty-free, but it adds one extra indent to + * the code, which reduces readability if multiple statements are nested together. + * + * If the allocate statement is the last statement in its parent, it can be + * written in the concise scoping form, avoiding adding extra level of indent. + * \code + * ... + * buf = T.allocate([16], "float32", "global") + * buf[0] = 0.0 + * ... + * \endcode + * + * This builder class helps print stmt in the concise scoping form. The attributes + * of this builder map to the output as, + * \code + * # Normal form + * with as : + * + * + * # Concise form + * = + * + * + * # Concise form if the `concise_stmt_override` is defined + * + * + * + * \endcode + * + */ +class ConciseScopedStmtBuilder { + public: + Optional target{NullOpt}; + ExprDoc parent_expr{nullptr}; + Array body; + Optional concise_stmt_override{NullOpt}; + + ConciseScopedStmtBuilder() {} + + using TSelf = ConciseScopedStmtBuilder; + + TSelf& WithBody(Array body) { + this->body = body; + return *this; + } + + TSelf& WithConciseFormStmt(StmtDoc stmt) { + this->concise_stmt_override = stmt; + return *this; + } + + TSelf& WithTarget(ExprDoc target) { + this->target = target; + return *this; + } + + TSelf& WithParentExpr(ExprDoc expr) { + this->parent_expr = expr; + return *this; + } + + StmtBlockDoc ToDoc(const IRDocsifier& p) { return ToDoc(p->GetFrame().value()); } + + StmtBlockDoc ToDoc(const TIRFrame& frame) { + ICHECK(parent_expr.defined()); + if (frame->allow_concise_scoping) { + StmtDoc first_doc = ExprStmtDoc(parent_expr); + if (concise_stmt_override) { + first_doc = concise_stmt_override.value(); + } else if (target.defined()) { + first_doc = AssignDoc(target.value(), parent_expr, NullOpt); + } + + return StmtBlockDoc(runtime::Concat({first_doc}, body)); + } else { + return StmtBlockDoc({ScopeDoc(target, parent_expr, body)}); + } + } +}; + +std::vector> FlattenSeqStmt(const TracedObject& stmt) { + std::vector> result; + + if (stmt.IsInstance()) { + auto seq = stmt.Downcast().GetAttr(&tir::SeqStmtNode::seq); + for (const TracedObject& child : seq) { + std::vector> flattened_child = FlattenSeqStmt(child); + result.insert(result.end(), flattened_child.begin(), flattened_child.end()); + } + } else { + result.push_back(stmt); + } + + return result; +} + +Array FlattenStmtDoc(const Doc& doc) { + if (const auto* stmt_block = doc.as()) { + return stmt_block->stmts; + } else if (const auto* stmt_doc = doc.as()) { + return {GetRef(stmt_doc)}; + } else { + LOG(FATAL) << "Expect to get StmtBlockDoc or StmtDoc, got " << doc->GetTypeKey(); + throw; + } +} + +Array AsStmtDocArray(const TracedObject& obj, IRDocsifier p) { + Array result; + std::vector> flattened_stmts = FlattenSeqStmt(obj); + + const auto* frame_node = p->frames.back().as(); + ICHECK_NOTNULL(frame_node); + + size_t stmt_count = flattened_stmts.size(); + + const bool original_concise_scoping_status = frame_node->allow_concise_scoping; + frame_node->allow_concise_scoping = false; + for (size_t i = 0; i < stmt_count; i++) { + if (i == stmt_count - 1) { + frame_node->allow_concise_scoping = true; + } + result = runtime::Concat(result, FlattenStmtDoc(p->AsDoc(flattened_stmts[i]))); + } + frame_node->allow_concise_scoping = original_concise_scoping_status; + + return result; +} + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) -> Doc { + if (!p->frames.back()->IsInstance()) { + // Throw error + LOG(FATAL) << "tir::SeqStmt can only be printed when it's the top level statement. " + "Use AsStmtDocArray to print the body of statement"; + throw; + } + return StmtBlockDoc(AsStmtDocArray(stmt, p)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + ExprDoc condition_expr = p->AsExprDoc(stmt.GetAttr(&tir::AssertStmtNode::condition)); + ExprDoc message_expr = p->AsExprDoc(stmt.GetAttr(&tir::AssertStmtNode::message)); + Array body = AsStmtDocArray(stmt.GetAttr(&tir::AssertStmtNode::body), p); + + return ConciseScopedStmtBuilder() + .WithParentExpr(TIR(p)->Attr("Assert")->Call({condition_expr, message_expr})) + .WithConciseFormStmt(AssertDoc(condition_expr, message_expr)) + .WithBody(body) + .ToDoc(p); + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch([](TracedObject stmt, IRDocsifier p) { Array indices = AsExprDocArray(stmt.GetAttr(&tir::BufferStoreNode::indices), p); diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index f518b50298a3..36ad53a3e7fc 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -56,7 +56,7 @@ class TIRTopLevelFrameNode : public TIRFrameNode { } static constexpr const char* _type_key = "script.printer.TIRTopLevelFrame"; - TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, FrameNode); + TVM_DECLARE_BASE_OBJECT_INFO(TIRTopLevelFrameNode, TIRFrameNode); }; class TIRTopLevelFrame : public TIRFrame { diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index 188ae98c87bc..b9330ace356b 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -27,6 +27,7 @@ NE, Add, And, + AssertStmt, Broadcast, BufferLoad, BufferRegion, @@ -48,6 +49,7 @@ Ramp, Reduce, Select, + SeqStmt, Shuffle, SizeVar, StringImm, @@ -685,3 +687,41 @@ def test_evaluate(): 3 + a """ assert as_tir_script(node) == format_script(expected) + + +def test_assert_normal_form(): + body = SeqStmt([Evaluate(2), Evaluate(3)]) + node = SeqStmt([AssertStmt(1, StringImm("test"), body), Evaluate(4)]) + expected = """ + with T.Assert(1, "test"): + 2 + 3 + 4 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_assert_concise_form(): + body = SeqStmt([Evaluate(2), Evaluate(3)]) + node = AssertStmt(1, StringImm("test"), body) + expected = """ + assert 1, "test" + 2 + 3 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_assert_body_concise_form(): + body_assert1 = AssertStmt(2, StringImm("test"), Evaluate(3)) + body_assert2 = AssertStmt(4, StringImm("test"), Evaluate(5)) + body = SeqStmt([body_assert1, body_assert2]) + node = AssertStmt(1, StringImm("test"), body) + expected = """ + assert 1, "test" + with T.Assert(2, "test"): + 3 + assert 4, "test" + 5 + """ + assert as_tir_script(node) == format_script(expected) From 041b9a4baa7bd86c263266883f3a6cf0bd568c8e Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Sun, 28 Aug 2022 10:46:02 -0400 Subject: [PATCH 8/8] Add TIR statement printing Co-authored-by: Greg Bonik --- src/script/printer/tir/buffer.cc | 41 +++--- src/script/printer/tir/buffer.h | 5 + src/script/printer/tir/stmt.cc | 106 +++++++++++++++ src/script/printer/tir/tir.h | 2 +- src/script/printer/utils.h | 2 +- .../unittest/test_tvmscript_printer_tir.py | 123 ++++++++++++++++++ 6 files changed, 262 insertions(+), 17 deletions(-) diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index c85169865cb4..55a943cf842e 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -254,27 +254,34 @@ TracedObject GetBufferNameHint(const TracedObject& buf) { } } -IdDoc DefineFreeBuffer(const TracedObject& buf, const Frame& frame, - const IRDocsifier& p, std::function add_definiton) { - TracedObject name_hint = GetBufferNameHint(buf); - IdDoc buf_doc = p->vars->Define(buf.Get(), name_hint, frame); +std::vector DefineBuffers(const std::vector>& buffers, + const Frame& frame, const IRDocsifier& p, + const ExprDoc& definition_prefix, + std::function add_definiton) { + std::vector result; + auto f_var_defined = [&p](const tir::VarNode* var) -> bool { return p->vars->IsVarDefined(GetRef(var)); }; std::unordered_map var_explicit_def; BufferAssociatedVariables associated_vars; - BufferPrintInfo buffer_print_info = - GetBufferPrintInfo({buf}, f_var_defined, &var_explicit_def, &associated_vars)[0]; + std::vector buffers_print_info = + GetBufferPrintInfo(buffers, f_var_defined, &var_explicit_def, &associated_vars); + for (const BufferPrintInfo& buffer_print_info : buffers_print_info) { + TracedObject buffer = buffer_print_info.buffer; + TracedObject name_hint = GetBufferNameHint(buffer); + IdDoc buf_doc = p->vars->Define(buffer.Get(), name_hint, frame); + result.push_back(buf_doc); + ExprDoc buf_definition = buffer_print_info.AsCall( + definition_prefix, + [&p](const TracedObject& expr) -> ExprDoc { return p->AsDoc(expr); }); + add_definiton(buf_doc, buf_definition); + } associated_vars.Define(p->vars.get(), frame); - ExprDoc buf_definition = buffer_print_info.AsCall( - TIR(p)->Attr("Buffer"), - [&p](const TracedObject& expr) -> ExprDoc { return p->AsDoc(expr); }); - add_definiton(AssignDoc(buf_doc, NullOpt, buf_definition)); - - return buf_doc; + return result; } ExprDoc PrintBuffer(TracedObject buf, IRDocsifier p) { @@ -285,9 +292,13 @@ ExprDoc PrintBuffer(TracedObject buf, IRDocsifier p) { // TODO(yelite): When implementing the PrimFunc printing, the logic here // needs to change, putting variable def into PrimFuncFrame if it exists. TIRTopLevelFrame top_level_frame = p->GetFrame().value(); - return DefineFreeBuffer(buf, top_level_frame, p, [top_level_frame](StmtDoc definition) { - top_level_frame->free_var_definitions.push_back(definition); - }); + auto add_free_buffer_definition = [top_level_frame](IdDoc buf_indentifier, + ExprDoc buf_definition) { + top_level_frame->free_var_definitions.push_back( + AssignDoc(buf_indentifier, NullOpt, buf_definition)); + }; + return DefineBuffers({buf}, top_level_frame, p, TIR(p)->Attr("Buffer"), + add_free_buffer_definition)[0]; } } diff --git a/src/script/printer/tir/buffer.h b/src/script/printer/tir/buffer.h index 10c2039b4e57..3af66dd00f4e 100644 --- a/src/script/printer/tir/buffer.h +++ b/src/script/printer/tir/buffer.h @@ -98,6 +98,11 @@ std::vector GetBufferPrintInfo( std::unordered_map* var_explicit_def, BufferAssociatedVariables* associated_vars); +std::vector DefineBuffers(const std::vector>& buffers, + const Frame& frame, const IRDocsifier& p, + const ExprDoc& definition_prefix, + std::function add_definiton); + } // namespace printer } // namespace script } // namespace tvm diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 83c3bfd34b05..15f8ce75f5a6 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -25,6 +25,7 @@ #include #include "../utils.h" +#include "./buffer.h" #include "./tir.h" namespace tvm { @@ -170,6 +171,33 @@ Array AsStmtDocArray(const TracedObject& obj, IRDocsifier p) return result; } +static TracedOptional GetUsedBuffer(const TracedObject& stmt_or_expr) { + if (auto load = stmt_or_expr.TryDowncast()) { + return load.value().GetAttr(&tir::BufferLoadNode::buffer); + } else if (auto store = stmt_or_expr.TryDowncast()) { + return store.value().GetAttr(&tir::BufferStoreNode::buffer); + } else { + return TracedOptional(NullOpt, ObjectPath::Root()); + } +} + +std::vector> FindBufferVarUsage(tir::Var buffer_var, + TracedObject body) { + std::vector> ret; + PostOrderVisitStmtExprTraced( + body, [&ret, buffer_var](const TracedObject& stmt_or_expr) { + if (auto buffer_opt = GetUsedBuffer(stmt_or_expr)) { + auto buffer = buffer_opt.value(); + if (buffer.Get()->data.same_as(buffer_var) && + std::find_if(ret.begin(), ret.end(), + [&](const auto& b) { return b.Get() == buffer.Get(); }) == ret.end()) { + ret.push_back(buffer); + } + } + }); + return ret; +} + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch([](TracedObject stmt, IRDocsifier p) -> Doc { if (!p->frames.back()->IsInstance()) { @@ -202,6 +230,84 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) p->AsExprDoc(stmt.GetAttr(&tir::BufferStoreNode::value)), NullOpt); }); +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + ExprDoc predicate = p->AsExprDoc(stmt.GetAttr(&tir::IfThenElseNode::condition)); + Array then_branch = AsStmtDocArray(stmt.GetAttr(&tir::IfThenElseNode::then_case), p); + Array else_branch = AsStmtDocArray(stmt.GetAttr(&tir::IfThenElseNode::else_case), p); + return IfDoc(predicate, then_branch, else_branch); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + return WhileDoc(p->AsExprDoc(stmt.GetAttr(&tir::WhileNode::condition)), + AsStmtDocArray(stmt.GetAttr(&tir::WhileNode::body), p)); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + auto buffer = stmt.GetAttr(&tir::PrefetchNode::buffer); + auto bounds = stmt.GetAttr(&tir::PrefetchNode::bounds); + return ExprStmtDoc( + TIR(p)->Attr("prefetch")->Call({p->AsExprDoc(buffer)[AsDocArray(bounds, p)]})); + }); + +TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) + .set_dispatch([](TracedObject stmt, IRDocsifier p) { + TIRFrame previous_frame = p->GetFrame().value(); + TIRGeneralFrame let_frame; + WithCtx ctx = p->WithFrame(let_frame); + + auto var = stmt.GetAttr(&tir::LetStmtNode::var); + bool is_var_defined_previously = p->vars->IsVarDefined(var.Get()); + ExprDoc var_doc{nullptr}; + if (is_var_defined_previously) { + var_doc = p->vars->GetVarDoc(var).value(); + } else { + var_doc = DefineTIRVar(var, let_frame, p); + } + + auto value_doc = p->AsExprDoc(stmt.GetAttr(&tir::LetStmtNode::value)); + auto dtype = var.GetAttr(&tir::VarNode::dtype); + auto type_annotation_doc = GetTypeAnnotationDocForVar(var, p); + + TracedObject body_stmt = stmt.GetAttr(&tir::LetStmtNode::body); + Array body_doc; + + // Print definition of buffers that aliases the variable of this Let stmt. + std::vector> aliasing_buffers = + FindBufferVarUsage(var.Get(), body_stmt); + std::vector> buffers_to_define; + for (const TracedObject& buffer : aliasing_buffers) { + if (!p->vars->IsVarDefined(buffer.Get())) { + buffers_to_define.push_back(buffer); + } + } + DefineBuffers(buffers_to_define, let_frame, p, TIR(p)->Attr("decl_buffer"), + [&body_doc](IdDoc buf_identifier, ExprDoc buf_definition) { + body_doc.push_back(AssignDoc(buf_identifier, buf_definition, NullOpt)); + }); + + body_doc = runtime::Concat(body_doc, AsStmtDocArray(body_stmt, p)); + + if (previous_frame->allow_concise_scoping) { + // dtype won't be linked to a doc object if it does concise scoping + // here we manually link it to type annotation + type_annotation_doc->source_paths.push_back(dtype.GetPath()); + AssignDoc var_assignment = AssignDoc(var_doc, value_doc, type_annotation_doc); + return StmtBlockDoc(runtime::Concat({var_assignment}, body_doc)); + } else { + Array result; + if (!is_var_defined_previously) { + result.push_back(AssignDoc(var_doc, TIR(p)->Attr("var")->Call({DType2Literal(dtype)}), + type_annotation_doc)); + } + ExprDoc let_call = TIR(p)->Attr("let")->Call({var_doc, value_doc}); + result.push_back(ScopeDoc(NullOpt, let_call, body_doc)); + return StmtBlockDoc(result); + } + }); + TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch([](TracedObject stmt, IRDocsifier p) { return ExprStmtDoc(p->AsExprDoc(stmt.GetAttr(&tir::EvaluateNode::value))); diff --git a/src/script/printer/tir/tir.h b/src/script/printer/tir/tir.h index 36ad53a3e7fc..31e885d901bf 100644 --- a/src/script/printer/tir/tir.h +++ b/src/script/printer/tir/tir.h @@ -69,7 +69,7 @@ class TIRTopLevelFrame : public TIRFrame { class TIRGeneralFrameNode : public TIRFrameNode { public: static constexpr const char* _type_key = "script.printer.TIRGeneralFrame"; - TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, FrameNode); + TVM_DECLARE_BASE_OBJECT_INFO(TIRGeneralFrameNode, TIRFrameNode); }; class TIRGeneralFrame : public TIRFrame { diff --git a/src/script/printer/utils.h b/src/script/printer/utils.h index eef93b9348d9..40ac7ca17865 100644 --- a/src/script/printer/utils.h +++ b/src/script/printer/utils.h @@ -32,7 +32,7 @@ template Array AsDocArray(const TracedArray& refs, const IRDocsifier& ir_docsifier) { Array result; for (auto ref : refs) { - result.push_back(ir_docsifier->AsExprDoc(ref)); + result.push_back(ir_docsifier->AsDoc(ref)); } return result; } diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index b9330ace356b..b6beafc549dc 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -40,12 +40,15 @@ FloatImm, FloorDiv, FloorMod, + IfThenElse, IntImm, IterVar, Let, + LetStmt, Mul, Not, Or, + Prefetch, Ramp, Reduce, Select, @@ -55,6 +58,7 @@ StringImm, Sub, Var, + While, decl_buffer, ) @@ -186,6 +190,18 @@ def test_var(var_type): """, id="all_param", ), + pytest.param( + decl_buffer( + (4, 10), + dtype="int32", + data=Var("p", PointerType(PrimType("int32"))), + ), + """ + buffer: T.Buffer("int32", shape=(4, 10)) + buffer + """, + id="default_data_ptr", + ), pytest.param( decl_buffer( (4, 10), @@ -195,6 +211,7 @@ def test_var(var_type): buffer: T.Buffer("bool", shape=(4, 10)) buffer """, + # Boolean buffer has different ptr type (Int8) and buffer type (UInt1) id="bool_different_ptr_type", ), ], @@ -203,6 +220,27 @@ def test_buffer(buffer, expected): assert as_tir_script(buffer) == format_script(expected) +def test_buffer_free_buffer_aliasing(): + ptr_var = Var("p", PointerType(PrimType("int16"))) + buffer_a = decl_buffer((4, 10), name="buffer_a", dtype="int8", data=ptr_var) + buffer_b = decl_buffer((8, 5), name="buffer_b", dtype="int4", data=ptr_var) + node = SeqStmt( + [ + Evaluate(BufferLoad(buffer_a, [0, 0])), + Evaluate(BufferLoad(buffer_b, [0, 1])), + ] + ) + # only prints one p + expected = """ + p: T.Ptr(T.int16) + buffer_a: T.Buffer("int8", shape=(4, 10), data=p) + buffer_b: T.Buffer("int4", shape=(8, 5), data=p) + buffer_a[0, 0] + buffer_b[0, 1] + """ + assert as_tir_script(node) == format_script(expected) + + @pytest.mark.parametrize( "node, expected", [ @@ -725,3 +763,88 @@ def test_assert_body_concise_form(): 5 """ assert as_tir_script(node) == format_script(expected) + + +def test_if_then_else(): + node = IfThenElse(1, Evaluate(2), Evaluate(3)) + expected = """ + if 1: + 2 + else: + 3 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_if_then_else_free_var(): + var = Var("a", "int32") + node = IfThenElse(1, Evaluate(var), Evaluate(var + 1)) + expected = """ + a: T.int32 + if 1: + a + else: + a + 1 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_while(): + var = Var("a", "int32") + node = While(var, Evaluate(var * 2)) + expected = """ + a: T.int32 + while a: + a * 2 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_prefetch(): + buf = decl_buffer((5, 10), name="buf") + ranges = [Range(1, 2), Range(5, 6)] + node = Prefetch(buf, ranges) + expected = """ + buf: T.Buffer(shape=(5, 10)) + T.prefetch(buf[1:2, 5:6]) + """ + assert as_tir_script(node) == format_script(expected) + + +def test_let_stmt(): + x = Var("x", "int32") + node = LetStmt(x, 1, Evaluate(x + 1)) + expected = """ + x: T.int32 = 1 + x + 1 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_let_stmt_concise_scoping(): + x = Var("x", "int32") + node = SeqStmt([LetStmt(x, 1, Evaluate(x + 1)), LetStmt(x, 2, Evaluate(x + 2))]) + expected = """ + x: T.int32 = T.var("int32") + with T.let(x, 1): + x + 1 + x: T.int32 = 2 + x + 2 + """ + assert as_tir_script(node) == format_script(expected) + + +def test_let_stmt_alias_buffer(): + buf0 = decl_buffer((10, 10), name="buf0") + ptr = Var("ptr", PointerType(PrimType("float32"))) + buf1 = decl_buffer((5, 10), name="buf1", data=ptr) + buf2 = decl_buffer((2, 5), name="buf2", data=ptr) + node = LetStmt(ptr, buf0.data, SeqStmt([BufferStore(buf1, BufferLoad(buf2, [2, 4]), [1, 2])])) + expected = """ + buf0: T.Ptr(T.float32) + ptr: T.Ptr(T.float32) = buf0 + buf2 = T.decl_buffer(shape=(2, 5), data=ptr) + buf1 = T.decl_buffer(shape=(5, 10), data=ptr) + buf1[1, 2] = buf2[2, 4] + """ + assert as_tir_script(node) == format_script(expected)