From 759d945a2c77bc712be122a0148bcbc05fb80768 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 8 Oct 2021 09:57:46 -0500 Subject: [PATCH] [TIR] Added PrettyPrint of ProducerStore/ProducerRealize nodes --- src/printer/text_printer.h | 7 +++++ src/printer/tir_text_printer.cc | 47 +++++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 3514f3228e27..a2178167b2e3 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -276,6 +276,8 @@ class TIRTextPrinter : public StmtFunctor, std::unordered_map memo_var_; /*! \brief Map from Buffer to Doc */ std::unordered_map memo_buf_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_producer_; /*! \brief name allocation map */ std::unordered_map name_alloc_map_; @@ -321,7 +323,9 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const AssertStmtNode* op) override; Doc VisitStmt_(const StoreNode* op) override; Doc VisitStmt_(const BufferStoreNode* op) override; + Doc VisitStmt_(const ProducerStoreNode* op) override; Doc VisitStmt_(const BufferRealizeNode* op) override; + Doc VisitStmt_(const ProducerRealizeNode* op) override; Doc VisitStmt_(const AllocateNode* op) override; Doc VisitStmt_(const IfThenElseNode* op) override; Doc VisitStmt_(const SeqStmtNode* op) override; @@ -342,7 +346,9 @@ class TIRTextPrinter : public StmtFunctor, Doc PrintIterVar(const IterVarNode* op); Doc PrintRange(const RangeNode* op); Doc PrintBuffer(const BufferNode* op); + Doc PrintProducer(const DataProducerNode* op); Doc BufferNode2Doc(const BufferNode* op, Doc doc); + Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc); Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); } Doc PrintBufferRegion(const BufferRegionNode* op); @@ -361,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor, Doc GetUniqueName(std::string prefix); Doc AllocVar(const Var& var); Doc AllocBuf(const Buffer& buffer); + Doc AllocProducer(const DataProducer& buffer); /*! * \brief special method to render vectors of docs with a separator * \param vec vector of docs diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fa132f079793..302c4491cebe 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -65,6 +65,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) { return PrintRange(node.as()); } else if (node->IsInstance()) { return PrintBuffer(node.as()); + } else if (node->IsInstance()) { + return PrintProducer(node.as()); } else if (node->IsInstance()) { return PrintString(node.as()); } else if (node->IsInstance()) { @@ -199,6 +201,19 @@ Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) { } } +Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) { + const DataProducer& prod = GetRef(op); + + if (meta_->InMeta(prod)) { + return meta_->GetMetaNode(prod); + } else if (memo_producer_.count(prod)) { + return memo_producer_[prod]; + } else { + memo_producer_[prod] = AllocProducer(prod); + return DataProducerNode2Doc(op, memo_producer_[prod]); + } +} + Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", " << Print(buf->shape) << ", " << Print(buf->strides); @@ -220,6 +235,11 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) { return doc << ")"; } +Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) { + return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", " + << PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")"; +} + Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) { Doc doc; doc << Print(op->buffer) << "["; @@ -439,6 +459,12 @@ Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) { + Doc doc; + doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value); + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { Doc doc; doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", " @@ -446,6 +472,13 @@ Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) { return doc; } +Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) { + Doc doc; + doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", " + << Print(op->condition) << ", " << PrintBody(op->body) << ")"; + return doc; +} + Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { Doc doc; auto scope = GetPtrStorageScope(op->buffer_var); @@ -709,6 +742,20 @@ Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) { return val; } +Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) { + const auto& it = memo_producer_.find(producer); + if (it != memo_producer_.end()) { + return it->second; + } + std::string name = producer->GetNameHint(); + if (name.length() == 0 || !std::isalpha(name[0])) { + name = "tensor_" + name; + } + Doc val = GetUniqueName(name); + memo_producer_[producer] = val; + return val; +} + Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { Doc seq; if (vec.size() != 0) {