Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,8 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
std::unordered_map<Var, Doc, ObjectPtrHash, ObjectPtrEqual> memo_var_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<Buffer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_buf_;
/*! \brief Map from Buffer to Doc */
std::unordered_map<DataProducer, Doc, ObjectPtrHash, ObjectPtrEqual> memo_producer_;
/*! \brief name allocation map */
std::unordered_map<std::string, int> name_alloc_map_;

Expand Down Expand Up @@ -321,7 +323,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc VisitStmt_(const AssertStmtNode* op) override;
Doc VisitStmt_(const StoreNode* op) override;
Doc VisitStmt_(const BufferStoreNode* op) override;
Doc VisitStmt_(const ProducerStoreNode* op) override;
Doc VisitStmt_(const BufferRealizeNode* op) override;
Doc VisitStmt_(const ProducerRealizeNode* op) override;
Doc VisitStmt_(const AllocateNode* op) override;
Doc VisitStmt_(const IfThenElseNode* op) override;
Doc VisitStmt_(const SeqStmtNode* op) override;
Expand All @@ -342,7 +346,9 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc PrintIterVar(const IterVarNode* op);
Doc PrintRange(const RangeNode* op);
Doc PrintBuffer(const BufferNode* op);
Doc PrintProducer(const DataProducerNode* op);
Doc BufferNode2Doc(const BufferNode* op, Doc doc);
Doc DataProducerNode2Doc(const DataProducerNode* op, Doc doc);
Doc PrintString(const StringObj* op) { return Doc::StrLiteral(op->data); }
Doc PrintBufferRegion(const BufferRegionNode* op);

Expand All @@ -361,6 +367,7 @@ class TIRTextPrinter : public StmtFunctor<Doc(const Stmt&)>,
Doc GetUniqueName(std::string prefix);
Doc AllocVar(const Var& var);
Doc AllocBuf(const Buffer& buffer);
Doc AllocProducer(const DataProducer& buffer);
/*!
* \brief special method to render vectors of docs with a separator
* \param vec vector of docs
Expand Down
47 changes: 47 additions & 0 deletions src/printer/tir_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Doc TIRTextPrinter::Print(const ObjectRef& node) {
return PrintRange(node.as<RangeNode>());
} else if (node->IsInstance<BufferNode>()) {
return PrintBuffer(node.as<BufferNode>());
} else if (node->IsInstance<DataProducerNode>()) {
return PrintProducer(node.as<DataProducerNode>());
} else if (node->IsInstance<StringObj>()) {
return PrintString(node.as<StringObj>());
} else if (node->IsInstance<BufferRegionNode>()) {
Expand Down Expand Up @@ -199,6 +201,19 @@ Doc TIRTextPrinter::PrintBuffer(const BufferNode* op) {
}
}

Doc TIRTextPrinter::PrintProducer(const DataProducerNode* op) {
const DataProducer& prod = GetRef<DataProducer>(op);

if (meta_->InMeta(prod)) {
return meta_->GetMetaNode(prod);
} else if (memo_producer_.count(prod)) {
return memo_producer_[prod];
} else {
memo_producer_[prod] = AllocProducer(prod);
return DataProducerNode2Doc(op, memo_producer_[prod]);
}
}

Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
doc << Doc::Text(": Buffer(") << Print(buf->data) << ", " << PrintDType(buf->dtype) << ", "
<< Print(buf->shape) << ", " << Print(buf->strides);
Expand All @@ -220,6 +235,11 @@ Doc TIRTextPrinter::BufferNode2Doc(const BufferNode* buf, Doc doc) {
return doc << ")";
}

Doc TIRTextPrinter::DataProducerNode2Doc(const DataProducerNode* prod, Doc doc) {
return doc << Doc::Text(": DataProducer(") << Print(prod->GetNameHint()) << ", "
<< PrintDType(prod->GetDataType()) << ", " << Print(prod->GetShape()) << ")";
}

Doc TIRTextPrinter::PrintBufferRegion(const BufferRegionNode* op) {
Doc doc;
doc << Print(op->buffer) << "[";
Expand Down Expand Up @@ -439,13 +459,26 @@ Doc TIRTextPrinter::VisitStmt_(const BufferStoreNode* op) {
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const ProducerStoreNode* op) {
Doc doc;
doc << Print(op->producer) << Print(op->indices) << " = " << Print(op->value);
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const BufferRealizeNode* op) {
Doc doc;
doc << "realize(" << Print(op->buffer) << ", " << Print(op->bounds) << ", "
<< Print(op->condition) << PrintBody(op->body) << ")";
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const ProducerRealizeNode* op) {
Doc doc;
doc << "producer_realize(" << Print(op->producer) << ", " << Print(op->bounds) << ", "
<< Print(op->condition) << ", " << PrintBody(op->body) << ")";
return doc;
}

Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) {
Doc doc;
auto scope = GetPtrStorageScope(op->buffer_var);
Expand Down Expand Up @@ -709,6 +742,20 @@ Doc TIRTextPrinter::AllocBuf(const Buffer& buffer) {
return val;
}

Doc TIRTextPrinter::AllocProducer(const DataProducer& producer) {
const auto& it = memo_producer_.find(producer);
if (it != memo_producer_.end()) {
return it->second;
}
std::string name = producer->GetNameHint();
if (name.length() == 0 || !std::isalpha(name[0])) {
name = "tensor_" + name;
}
Doc val = GetUniqueName(name);
memo_producer_[producer] = val;
return val;
}

Doc TIRTextPrinter::PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
Expand Down