From 6b8e03058eaf1cb7c618ce95fb62e56324004de9 Mon Sep 17 00:00:00 2001 From: driazati Date: Thu, 29 Sep 2022 16:54:36 -0700 Subject: [PATCH 1/7] TIR debug info --- .gitignore | 3 + include/tvm/tir/transform.h | 7 + src/driver/driver_api.cc | 8 + src/ir/transform.cc | 1 + src/printer/text_printer.h | 40 ++--- src/printer/tir_text_printer.cc | 53 +++---- src/printer/tir_text_printer_debug.cc | 66 ++++++++ src/printer/tir_text_printer_debug.h | 72 +++++++++ src/target/llvm/codegen_cpu.cc | 78 +++++----- src/target/llvm/codegen_cpu.h | 1 + src/target/llvm/codegen_llvm.cc | 75 +++++++-- src/target/llvm/codegen_llvm.h | 24 ++- src/tir/transforms/install_debug_spans.cc | 153 +++++++++++++++++++ src/tir/transforms/install_debug_spans.h | 80 ++++++++++ src/tir/transforms/install_debug_spans_ops.h | 78 ++++++++++ tests/python/tir/test_debug_info.py | 71 +++++++++ 16 files changed, 718 insertions(+), 92 deletions(-) create mode 100644 src/printer/tir_text_printer_debug.cc create mode 100644 src/printer/tir_text_printer_debug.h create mode 100644 src/tir/transforms/install_debug_spans.cc create mode 100644 src/tir/transforms/install_debug_spans.h create mode 100644 src/tir/transforms/install_debug_spans_ops.h create mode 100644 tests/python/tir/test_debug_info.py diff --git a/.gitignore b/.gitignore index 03c0a0bc6af9..8920bc741770 100644 --- a/.gitignore +++ b/.gitignore @@ -271,3 +271,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py # Used in CI to communicate between Python and Jenkins .docker-image-names/ + +# Printed TIR code on disk +*.tir diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 48372565469b..829594d61b98 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -496,6 +496,13 @@ TVM_DLL Pass LowerAsyncDMA(); */ TVM_DLL Pass CommonSubexprElimTIR(bool enable_cse_tir = true, bool identify_equiv_terms = false); +/*! + * \brief Add TIR-printer output as debug information to all ops in the module + * \return The pass. + */ + +TVM_DLL Pass InstallDebugSpans(); + /*! * \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and * "vthread.x/y/z". Before the unification, two vars that are bound to a thread axis (e.g., diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 92769d1cef45..288ac7b92a2c 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -45,6 +45,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_bound_checkers", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_assert", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_vectorize", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); @@ -603,6 +604,9 @@ TVM_REGISTER_GLOBAL("driver.mixed_mod_passes") }); transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_host) { + transform::PassContext pass_ctx = transform::PassContext::Current(); + bool enable_debug = pass_ctx->GetConfig("tir.enable_debug", Bool(false)).value(); + Array host_pass_list; runtime::TypedPackedFunc fcond = [](const tir::PrimFunc& f) { @@ -621,6 +625,10 @@ transform::Sequential HostModulePassManager(IRModule mixed_mod, Target target_ho host_pass_list.push_back(tir::transform::LowerDeviceStorageAccessInfo()); host_pass_list.push_back(tir::transform::CombineContextCall()); + if (enable_debug) { + host_pass_list.push_back(tir::transform::InstallDebugSpans()); + } + return transform::Sequential(host_pass_list); } diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 77ea942a0bb9..e0f08d28fb18 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -440,6 +440,7 @@ Pass GetPass(const String& pass_name) { // ordering problem needs to be handled in the future. IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const { for (const Pass& pass : passes) { + VLOG(0) << "Running pass " << pass->Info()->name; ICHECK(pass.defined()) << "Found undefined pass for optimization."; const PassInfo& pass_info = pass->Info(); if (!pass_ctx.PassEnabled(pass_info)) { diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 2dc0997f82ec..afc76112879e 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -280,6 +280,9 @@ class TIRTextPrinter : public StmtFunctor, explicit TIRTextPrinter(bool show_meta, TextMetaDataContext* meta) : show_meta_(show_meta), meta_(meta), meta_collector_(meta) {} + /*! \brief Output a newline */ + virtual Doc NewLine(); + /*! \brief Print the node */ Doc Print(const ObjectRef& node); @@ -290,24 +293,7 @@ class TIRTextPrinter : public StmtFunctor, */ bool GetVarName(::tvm::tir::Var v, std::string* s); - private: - /*! \brief whether show meta data */ - bool show_meta_; - /*! \brief meta data context */ - TextMetaDataContext* meta_; - /*! \brief meta collector */ - MetaCollector meta_collector_; - /*! \brief Map from Var to Doc */ - std::unordered_map memo_var_; - /*! \brief Map from Buffer to Doc */ - std::unordered_map memo_buf_; - /*! \brief Map from Buffer to Doc */ - std::unordered_map memo_producer_; - /*! \brief name allocation map */ - std::unordered_map name_alloc_map_; - - friend class tvm::TextPrinter; - + protected: Doc VisitExpr_(const IntImmNode* op) override; Doc VisitExpr_(const FloatImmNode* op) override; Doc VisitExpr_(const StringImmNode* op) override; @@ -363,6 +349,24 @@ class TIRTextPrinter : public StmtFunctor, Doc VisitStmt_(const BlockRealizeNode* op) override; Doc VisitStmtDefault_(const Object* op) override; + private: + /*! \brief whether show meta data */ + bool show_meta_; + /*! \brief meta data context */ + TextMetaDataContext* meta_; + /*! \brief meta collector */ + MetaCollector meta_collector_; + /*! \brief Map from Var to Doc */ + std::unordered_map memo_var_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_buf_; + /*! \brief Map from Buffer to Doc */ + std::unordered_map memo_producer_; + /*! \brief name allocation map */ + std::unordered_map name_alloc_map_; + + friend class tvm::TextPrinter; + Doc VisitType_(const PrimTypeNode* node) override; Doc VisitType_(const PointerTypeNode* node) override; Doc VisitType_(const TupleTypeNode* node) override; diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index fc3f49d76fae..4d74cc6d5a48 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -124,7 +124,7 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { for (const auto& it : op->attrs->dict) { attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); } - attr_doc << Doc::NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; + attr_doc << NewLine() << "attr = {" << PrintSep(attr_docs, Doc::Text(", ")) << "}"; doc << Doc::Indent(2, attr_doc); } @@ -136,8 +136,8 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { const Buffer buf = op->buffer_map[v]; buffer_docs.push_back(BufferNode2Doc(buf.get(), Print(buf))); } - buffer_doc << Doc::NewLine() << "buffers = {"; - buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << Doc::NewLine())); + buffer_doc << NewLine() << "buffers = {"; + buffer_doc << PrintSep(buffer_docs, Doc::Indent(11, Doc::Text(",") << NewLine())); doc << Doc::Indent(2, buffer_doc) << "}"; } @@ -149,26 +149,28 @@ Doc TIRTextPrinter::PrintPrimFunc(const PrimFunc& prim_func) { buffer_map_doc.push_back(Print(v) << ": " << Print(buf)); } doc << Doc::Indent( - 2, Doc::NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); + 2, NewLine() << "buffer_map = {" << PrintSep(buffer_map_doc, Doc::Text(", ")) << "}"); } doc << PrintBody(op->body); return doc; } +Doc TIRTextPrinter::NewLine() { return Doc::NewLine(); } + Doc TIRTextPrinter::PrintIRModule(const IRModule& module) { const auto* op = module.operator->(); Doc doc; Doc body; - body << Doc::NewLine(); + body << NewLine(); std::vector functions; for (auto it = op->functions.begin(); it != op->functions.end(); ++it) { if ((*it).second.as()) { functions.push_back(Print((*it).second)); } } - body << TIRTextPrinter::PrintSep(functions, Doc::NewLine() << Doc::NewLine()); + body << TIRTextPrinter::PrintSep(functions, NewLine() << NewLine()); doc << Doc::Indent(0, body); return doc; } @@ -451,7 +453,7 @@ Doc TIRTextPrinter::VisitExpr_(const ReduceNode* op) { Doc TIRTextPrinter::VisitStmt_(const LetStmtNode* op) { Doc doc; - doc << "let " << Print(op->var) << " = " << Print(op->value) << Doc::NewLine() << Print(op->body); + doc << "let " << Print(op->var) << " = " << Print(op->value) << NewLine() << Print(op->body); return doc; } @@ -463,14 +465,14 @@ Doc TIRTextPrinter::VisitStmt_(const AttrStmtNode* op) { if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { - doc << ";" << Doc::NewLine() << Print(op->body); + doc << ";" << NewLine() << Print(op->body); } return doc; } Doc TIRTextPrinter::VisitStmt_(const AssertStmtNode* op) { Doc doc; - doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << Doc::NewLine() + doc << "assert(" << Print(op->condition) << ", " << Print(op->message) << ")" << NewLine() << Print(op->body); return doc; } @@ -529,7 +531,7 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateNode* op) { if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { - doc << ";" << Doc::NewLine() << Print(op->body); + doc << ";" << NewLine() << Print(op->body); } return doc; } @@ -542,7 +544,7 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { - doc << ";" << Doc::NewLine() << Print(op->body); + doc << ";" << NewLine() << Print(op->body); } return doc; } @@ -550,11 +552,11 @@ Doc TIRTextPrinter::VisitStmt_(const AllocateConstNode* op) { Doc TIRTextPrinter::VisitStmt_(const DeclBufferNode* op) { Doc doc; doc << AllocBuf(op->buffer) << " = decl_buffer(" << Print(op->buffer->data) << ", " - << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << Doc::NewLine(); + << PrintDType(op->buffer->dtype) << ", " << Print(op->buffer->shape) << ")" << NewLine(); if (op->body->IsInstance()) { doc << PrintBody(op->body); } else { - doc << ";" << Doc::NewLine() << Print(op->body); + doc << ";" << NewLine() << Print(op->body); } return doc; } @@ -572,9 +574,9 @@ Doc TIRTextPrinter::VisitStmt_(const SeqStmtNode* op) { std::vector stmts; Doc seq_doc, doc; for (Stmt stmt : op->seq) { - seq_doc << Doc::NewLine() << Print(stmt); + seq_doc << NewLine() << Print(stmt); } - doc << " {" << Doc::Indent(2, seq_doc) << Doc::NewLine() << "}"; + doc << " {" << Doc::Indent(2, seq_doc) << NewLine() << "}"; return doc; } @@ -657,37 +659,36 @@ Doc TIRTextPrinter::VisitStmt_(const BlockRealizeNode* op) { Doc block_attr_doc; // print predicate, binding, read/write tensor region, annotations if (!is_one(op->predicate)) { - block_attr_doc << Doc::NewLine() << "where(" << Print(op->predicate) << ")"; + block_attr_doc << NewLine() << "where(" << Print(op->predicate) << ")"; } for (size_t i = 0; i < block_op->iter_vars.size(); ++i) - block_attr_doc << Doc::NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", " + block_attr_doc << NewLine() << "bind(" << Print(block_op->iter_vars[i]->var) << ", " << Print(op->iter_values[i]) << ")"; - block_attr_doc << Doc::NewLine() << "tir.reads(" << Print(block_op->reads) << ")"; - block_attr_doc << Doc::NewLine() << "tir.writes(" << Print(block_op->writes) << ")"; + block_attr_doc << NewLine() << "tir.reads(" << Print(block_op->reads) << ")"; + block_attr_doc << NewLine() << "tir.writes(" << Print(block_op->writes) << ")"; if (!block_op->annotations.empty()) { std::vector attr_docs; for (const auto& it : block_op->annotations) { attr_docs.push_back(Doc::StrLiteral(it.first) << ": " << Print(it.second)); } - block_attr_doc << Doc::NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) - << "})"; + block_attr_doc << NewLine() << "tir.attrs({" << PrintSep(attr_docs, Doc::Text(", ")) << "})"; } // print body Doc body; - body << Doc::NewLine(); + body << NewLine(); for (const auto& alloc_buf : block_op->alloc_buffers) { body << AllocBuf(alloc_buf) << " = alloc_buffer(" << PrintDType(alloc_buf->dtype) - << Print(alloc_buf->shape) << ")" << Doc::NewLine(); + << Print(alloc_buf->shape) << ")" << NewLine(); } for (const auto& match_buf : block_op->match_buffers) { body << AllocBuf(match_buf->buffer) << " = match_buffer(" << Print(match_buf->source) << ")" - << Doc::NewLine(); + << NewLine(); } if (block_op->init.defined()) { Doc init_block; init_block << "with init()"; init_block << PrintBody(block_op->init.value()); - body << init_block << Doc::NewLine(); + body << init_block << NewLine(); } body << Print(block_op->body); doc << Doc::Indent(2, block_attr_doc << body); @@ -826,7 +827,7 @@ Doc TIRTextPrinter::PrintSep(const std::vector& vec, const Doc& sep) { Doc TIRTextPrinter::PrintBody(const Stmt& body, bool indent) { Doc doc; if (body->IsInstance()) return Print(body); - doc << " {" << Doc::Indent(2, Doc::NewLine() << Print(body)) << Doc::NewLine() << "}"; + doc << " {" << Doc::Indent(2, NewLine() << Print(body)) << NewLine() << "}"; return doc; } diff --git a/src/printer/tir_text_printer_debug.cc b/src/printer/tir_text_printer_debug.cc new file mode 100644 index 000000000000..4afd700d446a --- /dev/null +++ b/src/printer/tir_text_printer_debug.cc @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir_text_printer.cc + * \brief Printer to print out the IR text format + * that can be parsed by a parser. + */ + +#include "tir_text_printer_debug.h" + +#include + +#include "text_printer.h" + +namespace tvm { +namespace tir { + +std::string span_text(const Span& span) { + if (!span.defined()) { + return "missing"; + } + std::string source("file"); + return source + ":" + std::to_string(span->line) + ":" + std::to_string(span->column); +} + +Doc TIRTextPrinterDebug::NewLine() { + current_line_ += 1; + + return TIRTextPrinter::NewLine(); +} + +#define X(TypeName) \ + Doc TIRTextPrinterDebug::VisitExpr_(const TypeName##Node* op) { \ + exprs_by_line_.push_back(std::make_tuple(op, current_line_)); \ + return TIRTextPrinter::VisitExpr_(op); \ + } +TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS +#undef X + +#define X(TypeName) \ + Doc TIRTextPrinterDebug::VisitStmt_(const TypeName##Node* op) { \ + stmts_by_line_.push_back(std::make_tuple(op, current_line_)); \ + return TIRTextPrinter::VisitStmt_(op); \ + } +TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS +#undef X + +} // namespace tir +} // namespace tvm diff --git a/src/printer/tir_text_printer_debug.h b/src/printer/tir_text_printer_debug.h new file mode 100644 index 000000000000..6150fcc2514e --- /dev/null +++ b/src/printer/tir_text_printer_debug.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file text_printer.h + * \brief Printer to print out the unified IR text format + * that can be parsed by a parser. + */ + +#ifndef TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ +#define TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ + +#include +#include + +#include "../tir/transforms/install_debug_spans_ops.h" +#include "text_printer.h" + +namespace tvm { +namespace tir { + +class TIRTextPrinterDebug : public TIRTextPrinter { + public: + TIRTextPrinterDebug() : TIRTextPrinter(false, &meta_), current_line_(1) {} + + std::vector> GetExprsByLine() const { + return exprs_by_line_; + } + + std::vector> GetStmtsByLine() const { return stmts_by_line_; } + + private: + Doc NewLine() override; + +#define X(TypeName) Doc VisitExpr_(const TypeName##Node* op) override; + TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS +#undef X + +#define X(TypeName) Doc VisitStmt_(const TypeName##Node* op) override; + TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS +#undef X + + TextMetaDataContext meta_; + + // Line that the printer is currently printing + size_t current_line_; + + // Record of all stmts and exprs and their corresponding line + std::vector> stmts_by_line_; + std::vector> exprs_by_line_; +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_PRINTER_TIR_TEXT_PRINTER_DEBUG_H_ diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index facb49660078..610c450aac91 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -183,57 +183,56 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, b InitGlobalContext(dynamic_lookup); } -void CodeGenCPU::AddFunction(const PrimFunc& f) { - CodeGenLLVM::AddFunction(f); - if (f_tvm_register_system_symbol_ != nullptr) { - auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - export_system_symbols_.emplace_back( - std::make_pair(global_symbol.value().operator std::string(), function_)); - } - AddDebugInformation(f, function_); -} - -// Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv -void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { +llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) { #if TVM_LLVM_VERSION >= 50 - ICHECK(!f_llvm->getSubprogram()); llvm::SmallVector paramTys; - // Functions in TIR can only return void or an int. - ICHECK(f_llvm->getReturnType() == t_void_ || f_llvm->getReturnType() == t_int_) - << "Unexpected return type"; - auto ret_type_tir = f_llvm->getReturnType() == t_int_ ? DataType::Int(32) : DataType::Void(); - llvm::DIType* returnTy = - GetDebugType(GetTypeFromRuntimeDataType(ret_type_tir), f_llvm->getReturnType()); - paramTys.push_back(returnTy); - for (size_t i = 0; i < f_llvm->arg_size(); ++i) { - paramTys.push_back( - GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i))); - } auto* DIFunctionTy = dbg_info_->di_builder_->createSubroutineType( dbg_info_->di_builder_->getOrCreateTypeArray(paramTys)); - bool local_to_unit = llvm::GlobalValue::isLocalLinkage(f_llvm->getLinkage()); + // TODO(driazati): add the right argument info to the function + bool local_to_unit = false; #if TVM_LLVM_VERSION >= 80 - auto SPFlags = - llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true, /*IsOptimized=*/true); + auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true, + /*IsOptimized=*/true); auto* DIFunction = dbg_info_->di_builder_->createFunction( - /*Scope=*/dbg_info_->file_, /*Name=*/f_llvm->getName(), /*LinkageName=*/"", + /*Scope=*/dbg_info_->file_, /*Name=*/"main.tir", /*LinkageName=*/"", /*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy, /*ScopeLine=*/0, /*Flags=*/llvm::DINode::FlagZero, /*SPFlags=*/SPFlags); #else auto* DIFunction = dbg_info_->di_builder_->createFunction( - /*Scope=*/dbg_info_->file_, /*Name=*/f_llvm->getName(), /*LinkageName=*/"", + /*Scope=*/dbg_info_->file_, /*Name=*/"main.tir", /*LinkageName=*/"", /*File=*/dbg_info_->file_, /*LineNo=*/0, /*Ty=*/DIFunctionTy, /*isLocalToUnit=*/local_to_unit, /*isDefinition=*/true, /*ScopeLine=*/0, /*Flags=*/llvm::DINode::FlagPrototyped, /*isOptimized=*/true); #endif + return DIFunction; +#endif +} - ICHECK(DIFunction); - f_llvm->setSubprogram(DIFunction); - ICHECK_EQ(f_llvm->getSubprogram(), DIFunction); +void CodeGenCPU::AddFunction(const PrimFunc& f) { +#if TVM_LLVM_VERSION >= 50 + di_subprogram_ = CreateDebugFunction(f); +#endif + + EmitDebugLocation(f->span); + CodeGenLLVM::AddFunction(f); + if (f_tvm_register_system_symbol_ != nullptr) { + auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()) + << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; + export_system_symbols_.emplace_back( + std::make_pair(global_symbol.value().operator std::string(), function_)); + } + AddDebugInformation(f, function_); +} + +// Following Glow |DebugInfo::generateFunctionDebugInfo|, https://git.io/fjadv +void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { +#if TVM_LLVM_VERSION >= 50 + ICHECK(di_subprogram_); + f_llvm->setSubprogram(di_subprogram_); + ICHECK_EQ(f_llvm->getSubprogram(), di_subprogram_); IRBuilder builder(&f_llvm->getEntryBlock()); if (!f_llvm->getEntryBlock().empty()) { @@ -246,11 +245,11 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { auto* paramAlloca = builder.CreateAlloca(f_llvm->getFunctionType()->getParamType(i)); std::string paramName = "arg" + std::to_string(i + 1); auto param = dbg_info_->di_builder_->createParameterVariable( - DIFunction, paramName, i + 1, dbg_info_->file_, 0, + di_subprogram_, paramName, i + 1, dbg_info_->file_, 0, GetDebugType(GetType(f_tir->params[i]), f_llvm->getFunctionType()->getParamType(i)), /*alwaysPreserve=*/true); auto* store = builder.CreateStore(f_llvm->arg_begin() + i, paramAlloca); - auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, DIFunction); + auto* di_loc = llvm::DILocation::get(*ctx, 0, 0, di_subprogram_); dbg_info_->di_builder_->insertDeclare(paramAlloca, param, dbg_info_->di_builder_->createExpression(), llvm::DebugLoc(di_loc), store); @@ -260,6 +259,7 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { if (!scope) { return; } + for (auto& BB : *f_llvm) { for (auto& I : BB) { if (I.getDebugLoc()) { @@ -541,6 +541,7 @@ llvm::BasicBlock* CodeGenCPU::CheckCallSuccess(llvm::Value* retcode) { } void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { + EmitDebugLocation(op); /*! \brief maintain states that should be guarded when step into compute scope */ struct ComputeScopeStates { explicit ComputeScopeStates(CodeGenCPU* parent) : parent_(parent) {} @@ -950,6 +951,7 @@ llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lo } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { + EmitDebugLocation(op); ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, op->args[4].as()->value, true); @@ -1385,6 +1387,7 @@ void CodeGenCPU::AddStartupFunction() { } llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { + EmitDebugLocation(op); if (op->op.same_as(builtin::tvm_call_packed_lowered())) { return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { @@ -1447,6 +1450,7 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { } void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { + EmitDebugLocation(op); llvm::Value* cond = MakeValue(op->condition); std::ostringstream os; os << "Assert fail: " << op->condition; @@ -1475,6 +1479,7 @@ void CodeGenCPU::VisitStmt_(const AssertStmtNode* op) { } void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { + EmitDebugLocation(op); if (op->attr_key == tir::attr::coproc_uop_scope) { const StringImmNode* value = op->value.as(); ICHECK(value != nullptr); @@ -1517,6 +1522,7 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { } void CodeGenCPU::VisitStmt_(const ForNode* op) { + EmitDebugLocation(op); ICHECK(is_zero(op->min)); if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) { CodeGenLLVM::VisitStmt_(op); diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index e0716ac8be2d..1dc914bc2c00 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -164,6 +164,7 @@ class CodeGenCPU : public CodeGenLLVM { // if not directly finalize function and pass on return code. // return the end block after the check llvm::BasicBlock* CheckCallSuccess(llvm::Value* retcode); + llvm::DISubprogram* CreateDebugFunction(const PrimFunc& f); // Context for injection lookup llvm::GlobalVariable* gv_mod_ctx_{nullptr}; llvm::GlobalVariable* gv_tvm_func_call_{nullptr}; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 526bcf0fb26e..430dab9b5264 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -298,6 +298,7 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { } #endif + EmitDebugLocation(f->span); if (ret_void) { builder_->CreateRetVoid(); } else { @@ -556,6 +557,7 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const PrimExpr& expr) const { // void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer_var, PrimExpr index, DataType access_dtype) { + EmitDebugLocation(index->span); if (alias_var_set_.count(buffer_var) != 0) { // Mark all possibly aliased pointer as same type. llvm::MDNode* meta = md_tbaa_alias_set_; @@ -663,12 +665,13 @@ std::unique_ptr CodeGenLLVM::CreateDebugInfo(llvm::Modul debug_info->di_builder_ = llvm::make_unique(*module); #endif // TODO(tulloch): pass this information through relay::Span classes to the IRModule instance? - debug_info->file_ = debug_info->di_builder_->createFile("model.tvm", "/tmp/"); + debug_info->file_ = debug_info->di_builder_->createFile("main.tir", "."); + const int runtime_version = 0; + const bool is_optimized = false; + const char* compiler_flags = ""; debug_info->compilation_unit_ = debug_info->di_builder_->createCompileUnit( - llvm::dwarf::DW_LANG_C, debug_info->file_, "TVM", 0, "", 0, "", - llvm::DICompileUnit::DebugEmissionKind::FullDebug, - /* SplitDebugInlining */ true, - /* DebugInfoForProfiling */ true); + /*Lang=*/llvm::dwarf::DW_LANG_C, /*File=*/debug_info->file_, /*Producer=*/"TVM", is_optimized, + compiler_flags, runtime_version); return debug_info; } @@ -722,6 +725,7 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { + ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); int num_elems = GetVectorNumElements(vec); if (num_elems == target_lanes) return vec; @@ -733,6 +737,7 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { } llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { + ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane // LLVM vector types. for (size_t i = 0, e = vecs.size(); i != e; ++i) { @@ -789,6 +794,7 @@ llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Value* stride, const Var& loop_var, const Stmt& body) { + EmitDebugLocation(body->span); llvm::BasicBlock* pre_block = builder_->GetInsertBlock(); std::string loop_var_name = loop_var->name_hint; llvm::LLVMContext* ctx = llvm_target_->GetContext(); @@ -802,8 +808,8 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va loop_value->addIncoming(begin, pre_block); ICHECK(!var_map_.count(loop_var.get())); var_map_[loop_var.get()] = loop_value; - builder_->CreateCondBr(CreateLT(loop_var.dtype(), loop_value, end), for_body, for_end, - md_very_likely_branch_); + auto lt = CreateLT(loop_var.dtype(), loop_value, end); + builder_->CreateCondBr(lt, for_body, for_end, md_very_likely_branch_); builder_->SetInsertPoint(for_body); this->VisitStmt(body); var_map_.erase(loop_var.get()); @@ -815,6 +821,7 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { + ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { @@ -850,6 +857,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name, llvm::GlobalValue::LinkageTypes linkage_type) { + ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; llvm::Type* ty = const_data->getType(); llvm::GlobalVariable* global = new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name); @@ -865,6 +873,7 @@ llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const } llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { + ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str); @@ -877,6 +886,7 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef indices, DataType value_dtype) { + ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers."; llvm::Value* index = indices[0]; @@ -916,6 +926,7 @@ llvm::Value* CodeGenLLVM::GetVarValue(const VarNode* v) const { void CodeGenLLVM::CreatePrintf(const std::string& format, llvm::ArrayRef format_args) { + EmitDebugLocation(); llvm::Function* func_printf = module_->getFunction("printf"); if (func_printf == nullptr) { llvm::FunctionType* ftype = llvm::FunctionType::get(t_int32_, true); @@ -946,6 +957,7 @@ void CodeGenLLVM::CreatePrintf(const std::string& format, } llvm::Value* CodeGenLLVM::CreateLookupReturnAddress(unsigned int level) { + EmitDebugLocation(); llvm::Value* level_val = llvm::ConstantInt::get(t_int32_, level); llvm::Function* builtin = llvm::Intrinsic::getDeclaration(module_.get(), llvm::Intrinsic::returnaddress); @@ -1234,6 +1246,7 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { } llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { + EmitDebugLocation(op); if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { ICHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); @@ -1270,6 +1283,7 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::bitwise_not())) { return builder_->CreateNot(MakeValue(op->args[0])); } else if (op->op.same_as(builtin::bitwise_xor())) { + EmitDebugLocation(op); return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1])); } else if (op->op.same_as(builtin::shift_left())) { return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1])); @@ -1398,20 +1412,29 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, std::functionvalue.dtype(), op->dtype, MakeValue(op->value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { + EmitDebugLocation(op); return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { + EmitDebugLocation(op); return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); } +llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { + EmitDebugLocation(op); + return GetConstString(op->value); +} #define DEFINE_CODEGEN_BINARY_OP(Op) \ llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ @@ -1433,6 +1456,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstS } \ } \ llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + EmitDebugLocation(op); \ return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } @@ -1452,6 +1476,7 @@ DEFINE_CODEGEN_BINARY_OP(Mul); } \ } \ llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ + EmitDebugLocation(op); \ return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ } @@ -1461,6 +1486,7 @@ DEFINE_CODEGEN_CMP_OP(GT); DEFINE_CODEGEN_CMP_OP(GE); llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { + EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->dtype.is_int()) { @@ -1474,6 +1500,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { + EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->dtype.is_int()) { @@ -1487,18 +1514,21 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { + EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { + EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { + EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { @@ -1509,6 +1539,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) { + EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { @@ -1519,23 +1550,28 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) { + EmitDebugLocation(op); return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); } llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) { + EmitDebugLocation(op); return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); } llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) { + EmitDebugLocation(op); return builder_->CreateNot(MakeValue(op->a)); } llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { + EmitDebugLocation(op); return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value), MakeValue(op->false_value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { + EmitDebugLocation(op); auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { ICHECK(deep_equal_(it->second->value, op->value)) @@ -1652,6 +1688,7 @@ void CodeGenLLVM::BufferAccessHelper( } llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { + EmitDebugLocation(op); DataType value_dtype = op->dtype; std::vector loads; @@ -1689,6 +1726,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { + EmitDebugLocation(op); if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { @@ -1714,6 +1752,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { + EmitDebugLocation(op); llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( @@ -1723,6 +1762,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { + EmitDebugLocation(op); std::vector vecs(op->vectors.size()); int total_lanes = 0; for (int i = 0, e = op->vectors.size(); i < e; ++i) { @@ -1747,6 +1787,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { + EmitDebugLocation(op); return CreateBroadcast(MakeValue(op->value), op->lanes); } @@ -1755,6 +1796,7 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { } void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { + EmitDebugLocation(op); DataType value_dtype = op->value.dtype(); Var buffer_var = op->buffer->data; @@ -1781,6 +1823,7 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) { } void CodeGenLLVM::VisitStmt_(const ForNode* op) { + EmitDebugLocation(op); ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); if (op->kind == ForKind::kUnrolled) { @@ -1794,6 +1837,7 @@ void CodeGenLLVM::VisitStmt_(const ForNode* op) { } void CodeGenLLVM::VisitStmt_(const WhileNode* op) { + EmitDebugLocation(op); llvm::LLVMContext* ctx = llvm_target_->GetContext(); auto* while_cond = llvm::BasicBlock::Create(*ctx, "while_cond", function_); auto* while_body = llvm::BasicBlock::Create(*ctx, "while_body", function_); @@ -1808,6 +1852,7 @@ void CodeGenLLVM::VisitStmt_(const WhileNode* op) { } void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { + EmitDebugLocation(op); llvm::Value* cond = MakeValue(op->condition); llvm::LLVMContext* ctx = llvm_target_->GetContext(); auto* then_block = llvm::BasicBlock::Create(*ctx, "if_then", function_); @@ -1831,6 +1876,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { } void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { + EmitDebugLocation(op); auto data = op->data.value(); auto array = NDArrayToLLVMArray(llvm_target_->GetContext(), data); std::string symbol_name = op->buffer_var->name_hint; @@ -1842,6 +1888,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateConstNode* op) { } void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { + EmitDebugLocation(op); ICHECK_EQ(op->extents.size(), 1) << "LLVM codegen only supports flat 1-d buffer allocation, but allocation of " << op->buffer_var->name_hint << " is " << op->extents << "-d"; @@ -1892,6 +1939,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { } void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { + EmitDebugLocation(op); if (op->attr_key == tir::attr::thread_extent) { IterVar iv = Downcast(op->node); if (iv->thread_tag.length() != 0) { @@ -1917,11 +1965,14 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } void CodeGenLLVM::VisitStmt_(const AssertStmtNode* op) { + EmitDebugLocation(op); + // auto a_cu = With cctx(analyzer_.get(), op->condition); this->VisitStmt(op->body); } void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { + EmitDebugLocation(op); const VarNode* v = op->var.get(); ICHECK(!var_map_.count(v)); if (v->dtype.is_handle()) { @@ -1941,12 +1992,16 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } void CodeGenLLVM::VisitStmt_(const SeqStmtNode* op) { + EmitDebugLocation(op); for (Stmt stmt : op->seq) { this->VisitStmt(stmt); } } -void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } +void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { + EmitDebugLocation(op); + MakeValue(op->value); +} } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 1ae9d14dc4ad..3536f27b0107 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -42,6 +42,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 140 #include #else @@ -70,6 +71,7 @@ #include "../../runtime/thread_storage_scope.h" #include "../../tir/transforms/ir_utils.h" #include "codegen_params.h" +#include "llvm_instance.h" namespace llvm { class Argument; @@ -92,8 +94,6 @@ class MDBuilder; namespace tvm { namespace codegen { -class LLVMTarget; - using namespace tir; /*! @@ -523,6 +523,8 @@ class CodeGenLLVM : public ExprFunctor, ExprDeepEqual deep_equal_; // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map let_binding_; + // debug info for function being compiled + llvm::DISubprogram* di_subprogram_; // Cache potential common path ops to slightly improve lookup time. // global symbol table. OpAttrMap op_attr_global_symbol_ = Op::GetAttrMap("TGlobalSymbol"); @@ -533,6 +535,24 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_lookup_param_ = builtin::lookup_param(); const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); + void EmitDebugLocation(const Span& span) { + ICHECK(di_subprogram_ != nullptr) << "DISubprogram not initialized"; + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + if (!span.defined()) { + auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, 212, 212, di_subprogram_)); + builder_->SetCurrentDebugLocation(loc); + } else { + auto loc = + llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_)); + builder_->SetCurrentDebugLocation(loc); + } + } + + void EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); } + + void EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); } + void EmitDebugLocation(const PrimExprNode* op) { EmitDebugLocation(op->span); } + /*! \brief Helper struct for debug infos. */ struct DebugInfo { ~DebugInfo(); // Because of the std::unique_ptr. diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc new file mode 100644 index 000000000000..fbb8c5fe1e88 --- /dev/null +++ b/src/tir/transforms/install_debug_spans.cc @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file install_debug_spans.cc + * \brief Prints TIR code in memory and replaces all spans in the module with + the location to which the ops would be printed + */ + +#include "install_debug_spans.h" + +#include + +#include +#include + +#include "../../printer/tir_text_printer_debug.h" + +namespace tvm { +namespace tir { + +Stmt DebugInfoInstaller::InstallInfo(const Stmt& stmt) { + DebugInfoInstaller installer(stmt, "main.tir"); + auto result = installer.VisitStmt(stmt); + + // TODO(driazati): remove debugging code + tvm::tir::TIRTextPrinterDebug printer; + // Fill in the stmts and exprs' line info + auto printed_with_spans = printer.Print(result).str(); + std::ofstream out("filled-main.tir"); + out << printed_with_spans; + out.close(); + + return result; +} + +DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) { + // Determine the line that each stmt/expr will be printed on + tvm::tir::TIRTextPrinterDebug printer; + + // Fill in the stmts and exprs' line info + auto result = printer.Print(stmt).str(); + + // Create map of the stmt/expr -> its line number in the output to later + // create new spans for each stmt/expr + const auto& stmts = printer.GetStmtsByLine(); + VLOG(0) << "Debug printer found " << stmts.size() << " stmts after printing"; + for (const auto& line : stmts) { + stmt_lines_[std::get<0>(line)] = std::get<1>(line); + } + + const auto& exprs = printer.GetExprsByLine(); + VLOG(0) << "Debug printer found " << exprs.size() << " exprs after printing"; + for (const auto& line : exprs) { + expr_lines_[std::get<0>(line)] = std::get<1>(line); + } + + // Output the printed TIR to the specified file + filename_ = std::move(filename); + std::ofstream out(filename_); + out << result; + out.close(); +} + +PrimExpr DebugInfoInstaller::VisitExpr(const PrimExpr& expr) { + PrimExpr result = expr; + result = StmtExprMutator::VisitExpr(result); + return result; +} + +Stmt DebugInfoInstaller::VisitStmt(const Stmt& stmt) { + Stmt result = stmt; + result = StmtExprMutator::VisitStmt(result); + return result; +} + +Span DebugInfoInstaller::MaybeSpan(const StmtNode* op) { + auto entry = stmt_lines_.find(op); + if (entry == stmt_lines_.end()) { + return Span(); + } else { + size_t column = 0; + size_t line = entry->second; + return Span(SourceName::Get(filename_), line, line, column, column); + } +} + +Span DebugInfoInstaller::MaybeSpan(const PrimExprNode* op) { + auto entry = expr_lines_.find(op); + if (entry == expr_lines_.end()) { + return Span(); + } else { + size_t column = 0; + size_t line = entry->second; + return Span(SourceName::Get(filename_), line, line, column, column); + } +} + +#define X(TypeName) \ + PrimExpr DebugInfoInstaller::VisitExpr_(const TypeName##Node* op) { \ + auto new_expr = StmtExprMutator::VisitExpr_(op); \ + auto new_type = Downcast(new_expr); \ + auto new_node = new_type.CopyOnWrite(); \ + new_node->span = MaybeSpan(op); \ + return new_type; \ + } +TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS +#undef X + +#define X(TypeName) \ + Stmt DebugInfoInstaller::VisitStmt_(const TypeName##Node* op) { \ + Stmt new_stmt = StmtExprMutator::VisitStmt_(op); \ + auto new_type = Downcast(new_stmt); \ + auto new_node = new_type.CopyOnWrite(); \ + new_node->span = MaybeSpan(op); \ + return new_type; \ + } +TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS +#undef X + +namespace transform { + +Pass InstallDebugSpans() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + n->body = DebugInfoInstaller::InstallInfo(std::move(f->body)); + + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans").set_body_typed(InstallDebugSpans); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/install_debug_spans.h b/src/tir/transforms/install_debug_spans.h new file mode 100644 index 000000000000..c7946e404400 --- /dev/null +++ b/src/tir/transforms/install_debug_spans.h @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file install_debug_spans.h + * \brief Interface of the InstallDebugSpans pass + */ + +#ifndef TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_H_ +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_H_ + +#include +#include +#include +#include + +#include +#include + +#include "install_debug_spans_ops.h" + +namespace tvm { +namespace tir { + +class DebugInfoInstaller : public StmtExprMutator { + public: + static Stmt InstallInfo(const Stmt& stmt); + + PrimExpr VisitExpr(const PrimExpr& expr) override; + Stmt VisitStmt(const Stmt& stmt) override; + + protected: + DebugInfoInstaller(const Stmt& stmt, const std::string& filename); + +#define X(TypeName) PrimExpr VisitExpr_(const TypeName##Node* op) override; + TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS +#undef X + +#define X(TypeName) Stmt VisitStmt_(const TypeName##Node* op) override; + TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS +#undef X + + private: + std::unordered_map stmt_lines_; + std::unordered_map expr_lines_; + std::string filename_; + + template + Stmt add_span(const ObjectName* op) { + Stmt new_stmt = StmtExprMutator::VisitStmt_(op); + auto new_type = Downcast(new_stmt); + auto new_node = new_type.CopyOnWrite(); + new_node->span = MaybeSpan(op); + return new_type; + } + + Span MaybeSpan(const StmtNode* op); + Span MaybeSpan(const PrimExprNode* op); +}; + +} // namespace tir +} // namespace tvm + +#endif // TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_H_ diff --git a/src/tir/transforms/install_debug_spans_ops.h b/src/tir/transforms/install_debug_spans_ops.h new file mode 100644 index 000000000000..245c0d164e85 --- /dev/null +++ b/src/tir/transforms/install_debug_spans_ops.h @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file install_debug_spans_ops.h + * \brief List of stmts and exprs supported by the debug info pass + */ + +#ifndef TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ + +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS \ + X(Call) \ + X(Add) \ + X(Sub) \ + X(Mul) \ + X(Div) \ + X(Mod) \ + X(FloorDiv) \ + X(FloorMod) \ + X(Min) \ + X(Max) \ + X(EQ) \ + X(NE) \ + X(LT) \ + X(LE) \ + X(GT) \ + X(GE) \ + X(And) \ + X(Or) \ + X(Reduce) \ + X(Cast) \ + X(Not) \ + X(Select) \ + X(Ramp) \ + X(Broadcast) \ + X(Shuffle) \ + X(IntImm) \ + X(FloatImm) \ + X(StringImm) + +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS \ + X(AttrStmt) \ + X(IfThenElse) \ + X(LetStmt) \ + X(For) \ + X(While) \ + X(Allocate) \ + X(AllocateConst) \ + X(DeclBuffer) \ + X(Store) \ + X(BufferStore) \ + X(BufferRealize) \ + X(AssertStmt) \ + X(ProducerStore) \ + X(ProducerRealize) \ + X(Prefetch) \ + X(SeqStmt) \ + X(Evaluate) \ + X(BlockRealize) + +#endif // TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py new file mode 100644 index 000000000000..ede777132c00 --- /dev/null +++ b/tests/python/tir/test_debug_info.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Test line-level debug info for TIR""" +import tvm +import tvm.testing +from tvm import tir +from tvm import relay +from tvm.script import tir as T + +from typing import List, Dict +import re + + +def find_di_locations(source: str) -> Dict[int, int]: + """ + Parse out DILocation references in printed LLVM IR + """ + result = {} + + for line in source.splitlines(): + m = re.match(r"!(\d+) = !DILocation\(line: (\d+).*", line) + if m: + debug_id, line = m.groups() + result[debug_id] = line + + return result + + +def test_llvm_ir_debug_info(): + @tvm.script.ir_module + class MyModule: + @T.prim_func + def main(a: T.handle, b: T.handle): + # We exchange data between function by handles, which are similar to pointer. + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # Create buffer from handles. + A = T.match_buffer(a, (8,), dtype="float32") + B = T.match_buffer(b, (8,), dtype="float32") + for i in range(8): + # A block is an abstraction for computation. + with T.block("B"): + # Define a spatial block iterator and bind it to value i. + vi = T.axis.spatial(8, i) + assert 1 == 0, "Some numbers" + B[vi] = A[vi] + 1.0 + + with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}): + runtime_module = tvm.build(MyModule, target="llvm") + + source = runtime_module.get_source() + + locations = find_di_locations(source) + assert len(locations) == 33 + + +if __name__ == "__main__": + tvm.testing.main() From 3ba16a866c7bfc2e8b7cf217985a712bbdeaf5c4 Mon Sep 17 00:00:00 2001 From: driazati Date: Thu, 20 Oct 2022 23:30:28 -0700 Subject: [PATCH 2/7] Fix location emission --- .gitignore | 3 ++ src/printer/tir_text_printer_debug.cc | 47 ++++++++++++++++-- src/printer/tir_text_printer_debug.h | 6 ++- src/target/llvm/codegen_cpu.cc | 17 +++++-- src/target/llvm/codegen_cpu.h | 1 + src/target/llvm/codegen_llvm.cc | 58 ++++++++--------------- src/target/llvm/codegen_llvm.h | 22 ++------- src/tir/transforms/install_debug_spans.cc | 27 +++++------ src/tir/transforms/install_debug_spans.h | 2 +- tests/python/tir/test_debug_info.py | 2 +- 10 files changed, 103 insertions(+), 82 deletions(-) diff --git a/.gitignore b/.gitignore index 8920bc741770..851552d95976 100644 --- a/.gitignore +++ b/.gitignore @@ -274,3 +274,6 @@ gallery/how_to/work_with_microtvm/micro_tvmc.py # Printed TIR code on disk *.tir + +# GDB history file +.gdb_history diff --git a/src/printer/tir_text_printer_debug.cc b/src/printer/tir_text_printer_debug.cc index 4afd700d446a..646985a64df9 100644 --- a/src/printer/tir_text_printer_debug.cc +++ b/src/printer/tir_text_printer_debug.cc @@ -25,6 +25,7 @@ #include "tir_text_printer_debug.h" +#include #include #include "text_printer.h" @@ -32,18 +33,56 @@ namespace tvm { namespace tir { -std::string span_text(const Span& span) { +std::optional span_text(const Span& span) { if (!span.defined()) { - return "missing"; + return std::nullopt; + } + + std::string source("main.tir"); + if (span->source_name.defined() && span->source_name->name.get()) { + source = span->source_name->name; } - std::string source("file"); return source + ":" + std::to_string(span->line) + ":" + std::to_string(span->column); } +template +void add_all_relevant_lines(const std::vector>& data, + size_t current_line, Doc* output) { + ICHECK(output) << "output must be a valid Doc"; + for (const auto& item : data) { + if (std::get<1>(item) != current_line - 1) { + // Item is not relevant for this line, skip it + continue; + } + + // Print out the item's span info if present + auto text = span_text(std::get<0>(item)->span); + if (text.has_value()) { + *output << *text; + } else { + *output << "missing"; + } + *output << ", "; + } +} + Doc TIRTextPrinterDebug::NewLine() { current_line_ += 1; - return TIRTextPrinter::NewLine(); + if (!show_spans_) { + return TIRTextPrinter::NewLine(); + } + + Doc output; + + output << " ["; + + add_all_relevant_lines(exprs_by_line_, current_line_, &output); + add_all_relevant_lines(stmts_by_line_, current_line_, &output); + + output << "]" << TIRTextPrinter::NewLine(); + + return output; } #define X(TypeName) \ diff --git a/src/printer/tir_text_printer_debug.h b/src/printer/tir_text_printer_debug.h index 6150fcc2514e..b6c77ce989ae 100644 --- a/src/printer/tir_text_printer_debug.h +++ b/src/printer/tir_text_printer_debug.h @@ -37,7 +37,8 @@ namespace tir { class TIRTextPrinterDebug : public TIRTextPrinter { public: - TIRTextPrinterDebug() : TIRTextPrinter(false, &meta_), current_line_(1) {} + explicit TIRTextPrinterDebug(bool show_spans) + : TIRTextPrinter(false, &meta_), current_line_(1), show_spans_(show_spans) {} std::vector> GetExprsByLine() const { return exprs_by_line_; @@ -61,6 +62,9 @@ class TIRTextPrinterDebug : public TIRTextPrinter { // Line that the printer is currently printing size_t current_line_; + // Whether to include spans relevant to each line before a newline or not + bool show_spans_; + // Record of all stmts and exprs and their corresponding line std::vector> stmts_by_line_; std::vector> exprs_by_line_; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 610c450aac91..292ceadfe843 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -186,11 +186,16 @@ void CodeGenCPU::Init(const std::string& module_name, LLVMTarget* llvm_target, b llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) { #if TVM_LLVM_VERSION >= 50 llvm::SmallVector paramTys; + + paramTys.push_back(GetDebugType(f->ret_type)); + for (const auto& param : f->params) { + paramTys.push_back(GetDebugType(GetType(param))); + } + auto* DIFunctionTy = dbg_info_->di_builder_->createSubroutineType( dbg_info_->di_builder_->getOrCreateTypeArray(paramTys)); - // TODO(driazati): add the right argument info to the function - bool local_to_unit = false; + bool local_to_unit = llvm::GlobalVariable::isLocalLinkage(llvm::GlobalValue::InternalLinkage); #if TVM_LLVM_VERSION >= 80 auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true, @@ -207,6 +212,8 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) { /*Flags=*/llvm::DINode::FlagPrototyped, /*isOptimized=*/true); #endif return DIFunction; +#else + return nullptr; #endif } @@ -214,7 +221,6 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { #if TVM_LLVM_VERSION >= 50 di_subprogram_ = CreateDebugFunction(f); #endif - EmitDebugLocation(f->span); CodeGenLLVM::AddFunction(f); if (f_tvm_register_system_symbol_ != nullptr) { @@ -272,6 +278,9 @@ void CodeGenCPU::AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm) { #endif } +llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir) { + return GetDebugType(ty_tir, GetLLVMType(ty_tir)); +} llvm::DIType* CodeGenCPU::GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm) { if (ty_llvm == t_void_) { return nullptr; @@ -951,7 +960,6 @@ llvm::Value* CodeGenCPU::CreateCallPacked(const CallNode* op, bool use_string_lo } llvm::Value* CodeGenCPU::CreateCallTracePacked(const CallNode* op) { - EmitDebugLocation(op); ICHECK_EQ(op->args.size(), 6U); PackedCall pc = MakeCallPackedLowered(op->args, op->dtype, op->args[3].as()->value, op->args[4].as()->value, true); @@ -1387,7 +1395,6 @@ void CodeGenCPU::AddStartupFunction() { } llvm::Value* CodeGenCPU::CreateIntrinsic(const CallNode* op) { - EmitDebugLocation(op); if (op->op.same_as(builtin::tvm_call_packed_lowered())) { return CreateCallPacked(op, true /* use_string_lookup */); } else if (op->op.same_as(builtin::tvm_call_trace_packed_lowered())) { diff --git a/src/target/llvm/codegen_cpu.h b/src/target/llvm/codegen_cpu.h index 1dc914bc2c00..afbd49e14348 100644 --- a/src/target/llvm/codegen_cpu.h +++ b/src/target/llvm/codegen_cpu.h @@ -195,6 +195,7 @@ class CodeGenCPU : public CodeGenLLVM { // Get the DWARF type corresponding to the LLVM type |ty|. The current API in practice only // generates |int32|, and |int8*|. + llvm::DIType* GetDebugType(const Type& ty_tir); llvm::DIType* GetDebugType(const Type& ty_tir, llvm::Type* ty_llvm); // Adds the DWARF debug information for |function| to |dbg_info_|. void AddDebugInformation(PrimFunc f_tir, llvm::Function* f_llvm); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 430dab9b5264..2182ecfa51ce 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -725,7 +725,6 @@ llvm::Value* CodeGenLLVM::CreateVecFlip(llvm::Value* vec) { } llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { - ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; llvm::Value* mask = llvm::UndefValue::get(DTypeToLLVMType(DataType::Int(32, target_lanes))); int num_elems = GetVectorNumElements(vec); if (num_elems == target_lanes) return vec; @@ -737,7 +736,6 @@ llvm::Value* CodeGenLLVM::CreateVecPad(llvm::Value* vec, int target_lanes) { } llvm::Value* CodeGenLLVM::CreateVecConcat(std::vector vecs) { - ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; // To allow creating vectors from scalars, convert any scalars in "vecs" to single-lane // LLVM vector types. for (size_t i = 0, e = vecs.size(); i != e; ++i) { @@ -821,7 +819,6 @@ void CodeGenLLVM::CreateSerialFor(llvm::Value* begin, llvm::Value* end, llvm::Va // cast operatpr llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* value) { - ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; llvm::Type* target = DTypeToLLVMType(to); if (value->getType() == target) return value; if (to.is_handle()) { @@ -857,7 +854,6 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, DataType to, llvm::Value* va llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const std::string& name, llvm::GlobalValue::LinkageTypes linkage_type) { - ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; llvm::Type* ty = const_data->getType(); llvm::GlobalVariable* global = new llvm::GlobalVariable(*module_, ty, true, linkage_type, const_data, name); @@ -873,7 +869,6 @@ llvm::Constant* CodeGenLLVM::GetGlobalConstant(llvm::Constant* const_data, const } llvm::Constant* CodeGenLLVM::GetConstString(const std::string& str) { - ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; auto it = str_map_.find(str); if (it != str_map_.end()) return it->second; auto llvm_str = llvm::ConstantDataArray::getString(*llvm_target_->GetContext(), str); @@ -886,7 +881,6 @@ CodeGenLLVM::TypedPointer CodeGenLLVM::CreateBufferPtr(llvm::Value* buffer_ptr, DataType buffer_element_dtype, llvm::ArrayRef indices, DataType value_dtype) { - ICHECK(builder_->getCurrentDebugLocation() != llvm::DebugLoc()) << "Debug information missing"; ICHECK_EQ(indices.size(), 1) << "CodeGenLLVM requires all buffers to be flat 1-d buffers."; llvm::Value* index = indices[0]; @@ -1246,7 +1240,6 @@ void CodeGenLLVM::EmitFloat16ConversionBuiltins(bool use_float16_abi) { } llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { - EmitDebugLocation(op); if (op->op.same_as(builtin_call_llvm_intrin_) || op->op.same_as(builtin_call_llvm_pure_intrin_)) { ICHECK_GE(op->args.size(), 2U); llvm::Intrinsic::ID id = static_cast(Downcast(op->args[0])->value); @@ -1283,7 +1276,6 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) { } else if (op->op.same_as(builtin::bitwise_not())) { return builder_->CreateNot(MakeValue(op->args[0])); } else if (op->op.same_as(builtin::bitwise_xor())) { - EmitDebugLocation(op); return builder_->CreateXor(MakeValue(op->args[0]), MakeValue(op->args[1])); } else if (op->op.same_as(builtin::shift_left())) { return builder_->CreateShl(MakeValue(op->args[0]), MakeValue(op->args[1])); @@ -1412,29 +1404,20 @@ void CodeGenLLVM::Scalarize(const PrimExpr& e, std::functionvalue.dtype(), op->dtype, MakeValue(op->value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const IntImmNode* op) { - EmitDebugLocation(op); return llvm::ConstantInt::getSigned(DTypeToLLVMType(op->dtype), op->value); } llvm::Value* CodeGenLLVM::VisitExpr_(const FloatImmNode* op) { - EmitDebugLocation(op); return llvm::ConstantFP::get(DTypeToLLVMType(op->dtype), op->value); } -llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { - EmitDebugLocation(op); - return GetConstString(op->value); -} +llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { return GetConstString(op->value); } #define DEFINE_CODEGEN_BINARY_OP(Op) \ llvm::Value* CodeGenLLVM::Create##Op(DataType t, llvm::Value* a, llvm::Value* b) { \ @@ -1456,7 +1439,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const StringImmNode* op) { } \ } \ llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ - EmitDebugLocation(op); \ return Create##Op(op->dtype, MakeValue(op->a), MakeValue(op->b)); \ } @@ -1476,7 +1458,6 @@ DEFINE_CODEGEN_BINARY_OP(Mul); } \ } \ llvm::Value* CodeGenLLVM::VisitExpr_(const Op##Node* op) { \ - EmitDebugLocation(op); \ return Create##Op(op->a.dtype(), MakeValue(op->a), MakeValue(op->b)); \ } @@ -1486,7 +1467,6 @@ DEFINE_CODEGEN_CMP_OP(GT); DEFINE_CODEGEN_CMP_OP(GE); llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { - EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->dtype.is_int()) { @@ -1500,7 +1480,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const DivNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { - EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->dtype.is_int()) { @@ -1514,21 +1493,18 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ModNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const MinNode* op) { - EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); return builder_->CreateSelect(CreateLT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const MaxNode* op) { - EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); return builder_->CreateSelect(CreateGT(op->a.dtype(), a, b), a, b); } llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { - EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { @@ -1539,7 +1515,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const EQNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) { - EmitDebugLocation(op); llvm::Value* a = MakeValue(op->a); llvm::Value* b = MakeValue(op->b); if (op->a.dtype().is_int() || op->a.dtype().is_uint()) { @@ -1550,28 +1525,23 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const NENode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const AndNode* op) { - EmitDebugLocation(op); return builder_->CreateAnd(MakeValue(op->a), MakeValue(op->b)); } llvm::Value* CodeGenLLVM::VisitExpr_(const OrNode* op) { - EmitDebugLocation(op); return builder_->CreateOr(MakeValue(op->a), MakeValue(op->b)); } llvm::Value* CodeGenLLVM::VisitExpr_(const NotNode* op) { - EmitDebugLocation(op); return builder_->CreateNot(MakeValue(op->a)); } llvm::Value* CodeGenLLVM::VisitExpr_(const SelectNode* op) { - EmitDebugLocation(op); return builder_->CreateSelect(MakeValue(op->condition), MakeValue(op->true_value), MakeValue(op->false_value)); } llvm::Value* CodeGenLLVM::VisitExpr_(const LetNode* op) { - EmitDebugLocation(op); auto it = let_binding_.find(op->var); if (it != let_binding_.end()) { ICHECK(deep_equal_(it->second->value, op->value)) @@ -1688,7 +1658,6 @@ void CodeGenLLVM::BufferAccessHelper( } llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { - EmitDebugLocation(op); DataType value_dtype = op->dtype; std::vector loads; @@ -1726,7 +1695,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { - EmitDebugLocation(op); if (auto* ptr_op = op->op.as()) { auto call_op = GetRef(ptr_op); if (op->op.same_as(builtin_call_extern_) || op->op.same_as(builtin_call_pure_extern_)) { @@ -1752,7 +1720,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { - EmitDebugLocation(op); llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype)); for (int i = 0; i < op->lanes; ++i) { vec = builder_->CreateInsertElement( @@ -1762,7 +1729,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { - EmitDebugLocation(op); std::vector vecs(op->vectors.size()); int total_lanes = 0; for (int i = 0, e = op->vectors.size(); i < e; ++i) { @@ -1787,7 +1753,6 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode* op) { } llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) { - EmitDebugLocation(op); return CreateBroadcast(MakeValue(op->value), op->lanes); } @@ -2003,6 +1968,25 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } +void CodeGenLLVM::EmitDebugLocation(const Span& span) { +#if TVM_LLVM_VERSION >= 50 + if (di_subprogram_ == nullptr) { + // debug info is not always generated outside of CPU codegen + return; + } + if (!span.defined()) { + VLOG(0) << "Cannot emit debug location for undefined span"; + return; + } + llvm::LLVMContext* ctx = llvm_target_->GetContext(); + auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_)); + builder_->SetCurrentDebugLocation(loc); +#endif +} + +void CodeGenLLVM::EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); } +void CodeGenLLVM::EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); } + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index 3536f27b0107..632cfaafc51a 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -37,12 +37,12 @@ #else #include #endif +#include #include #include #include #include #include -#include #if TVM_LLVM_VERSION >= 140 #include #else @@ -535,23 +535,9 @@ class CodeGenLLVM : public ExprFunctor, const Op& builtin_lookup_param_ = builtin::lookup_param(); const Op& builtin_tvm_call_cpacked_lowered_ = builtin::tvm_call_cpacked_lowered(); - void EmitDebugLocation(const Span& span) { - ICHECK(di_subprogram_ != nullptr) << "DISubprogram not initialized"; - llvm::LLVMContext* ctx = llvm_target_->GetContext(); - if (!span.defined()) { - auto loc = llvm::DebugLoc(llvm::DILocation::get(*ctx, 212, 212, di_subprogram_)); - builder_->SetCurrentDebugLocation(loc); - } else { - auto loc = - llvm::DebugLoc(llvm::DILocation::get(*ctx, span->line, span->column, di_subprogram_)); - builder_->SetCurrentDebugLocation(loc); - } - } - - void EmitDebugLocation() { builder_->SetCurrentDebugLocation(nullptr); } - - void EmitDebugLocation(const StmtNode* op) { EmitDebugLocation(op->span); } - void EmitDebugLocation(const PrimExprNode* op) { EmitDebugLocation(op->span); } + void EmitDebugLocation(); + void EmitDebugLocation(const Span& span); + void EmitDebugLocation(const StmtNode* op); /*! \brief Helper struct for debug infos. */ struct DebugInfo { diff --git a/src/tir/transforms/install_debug_spans.cc b/src/tir/transforms/install_debug_spans.cc index fbb8c5fe1e88..4daa1aafe8cc 100644 --- a/src/tir/transforms/install_debug_spans.cc +++ b/src/tir/transforms/install_debug_spans.cc @@ -35,24 +35,14 @@ namespace tvm { namespace tir { -Stmt DebugInfoInstaller::InstallInfo(const Stmt& stmt) { - DebugInfoInstaller installer(stmt, "main.tir"); - auto result = installer.VisitStmt(stmt); - - // TODO(driazati): remove debugging code - tvm::tir::TIRTextPrinterDebug printer; - // Fill in the stmts and exprs' line info - auto printed_with_spans = printer.Print(result).str(); - std::ofstream out("filled-main.tir"); - out << printed_with_spans; - out.close(); - - return result; +Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt) { + DebugInfoInstaller installer(stmt, name + ".tir"); + return installer.VisitStmt(stmt); } DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) { // Determine the line that each stmt/expr will be printed on - tvm::tir::TIRTextPrinterDebug printer; + tvm::tir::TIRTextPrinterDebug printer(false); // Fill in the stmts and exprs' line info auto result = printer.Print(stmt).str(); @@ -72,6 +62,7 @@ DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& file } // Output the printed TIR to the specified file + VLOG(0) << "Outputting TIR to " << filename; filename_ = std::move(filename); std::ofstream out(filename_); out << result; @@ -138,8 +129,14 @@ namespace transform { Pass InstallDebugSpans() { auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + ICHECK(m->functions.size() == 1) + << "Debug info can only be added to IRModules with a single function"; + // There is known to be only 1 function in the module at this point + auto entry = m->functions.begin(); + auto name = std::get<0>(*entry)->name_hint; auto* n = f.CopyOnWrite(); - n->body = DebugInfoInstaller::InstallInfo(std::move(f->body)); + + n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body)); return f; }; diff --git a/src/tir/transforms/install_debug_spans.h b/src/tir/transforms/install_debug_spans.h index c7946e404400..c5fc4922dc42 100644 --- a/src/tir/transforms/install_debug_spans.h +++ b/src/tir/transforms/install_debug_spans.h @@ -40,7 +40,7 @@ namespace tir { class DebugInfoInstaller : public StmtExprMutator { public: - static Stmt InstallInfo(const Stmt& stmt); + static Stmt InstallInfo(const std::string& name, const Stmt& stmt); PrimExpr VisitExpr(const PrimExpr& expr) override; Stmt VisitStmt(const Stmt& stmt) override; diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py index ede777132c00..03524e8776d6 100644 --- a/tests/python/tir/test_debug_info.py +++ b/tests/python/tir/test_debug_info.py @@ -64,7 +64,7 @@ def main(a: T.handle, b: T.handle): source = runtime_module.get_source() locations = find_di_locations(source) - assert len(locations) == 33 + assert len(locations) == 34 if __name__ == "__main__": From 5c0be6213ee7dd965bca6503bf0dfee1fe4ba707 Mon Sep 17 00:00:00 2001 From: driazati Date: Sun, 20 Nov 2022 21:19:22 -0800 Subject: [PATCH 3/7] Comments 1/N (docs, cleanups) --- src/target/llvm/codegen_cpu.cc | 1 + src/tir/transforms/install_debug_spans.h | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 292ceadfe843..21d2c6ebe0a5 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -197,6 +197,7 @@ llvm::DISubprogram* CodeGenCPU::CreateDebugFunction(const PrimFunc& f) { bool local_to_unit = llvm::GlobalVariable::isLocalLinkage(llvm::GlobalValue::InternalLinkage); + // TODO(driazati): determine the IRModule name instead of hardcoding 'main.tir' #if TVM_LLVM_VERSION >= 80 auto SPFlags = llvm::DISubprogram::toSPFlags(local_to_unit, /*IsDefinition=*/true, /*IsOptimized=*/true); diff --git a/src/tir/transforms/install_debug_spans.h b/src/tir/transforms/install_debug_spans.h index c5fc4922dc42..ab6910103b9d 100644 --- a/src/tir/transforms/install_debug_spans.h +++ b/src/tir/transforms/install_debug_spans.h @@ -38,6 +38,14 @@ namespace tvm { namespace tir { +/*! + * \brief This Pass prints out the provided 'stmt' through the TIR debug printer + while recording the statements and expressions printed on each line. Running + this pass uses the per-line information to change the Spans attached to each + statement and expression to the source location in the printed TIR. This pass + also writes to a file called '.tir' so the line information used is + saved to disk. + */ class DebugInfoInstaller : public StmtExprMutator { public: static Stmt InstallInfo(const std::string& name, const Stmt& stmt); From 849fe60679eef1f285b7559cfed02f7b2c49a4f2 Mon Sep 17 00:00:00 2001 From: driazati Date: Sun, 20 Nov 2022 22:09:11 -0800 Subject: [PATCH 4/7] Remove leaky macro usage --- src/printer/tir_text_printer_debug.cc | 24 ++---- src/printer/tir_text_printer_debug.h | 10 +-- src/tir/transforms/install_debug_spans.h | 55 +++++++++++++- src/tir/transforms/install_debug_spans_ops.h | 78 -------------------- 4 files changed, 64 insertions(+), 103 deletions(-) delete mode 100644 src/tir/transforms/install_debug_spans_ops.h diff --git a/src/printer/tir_text_printer_debug.cc b/src/printer/tir_text_printer_debug.cc index 646985a64df9..6c29558f722c 100644 --- a/src/printer/tir_text_printer_debug.cc +++ b/src/printer/tir_text_printer_debug.cc @@ -28,8 +28,6 @@ #include #include -#include "text_printer.h" - namespace tvm { namespace tir { @@ -85,21 +83,15 @@ Doc TIRTextPrinterDebug::NewLine() { return output; } -#define X(TypeName) \ - Doc TIRTextPrinterDebug::VisitExpr_(const TypeName##Node* op) { \ - exprs_by_line_.push_back(std::make_tuple(op, current_line_)); \ - return TIRTextPrinter::VisitExpr_(op); \ - } -TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS -#undef X +Doc TIRTextPrinterDebug::VisitStmt(const tvm::tir::Stmt& n) { + stmts_by_line_.push_back(std::make_tuple(n.get(), current_line_)); + return TIRTextPrinter::VisitStmt(n); +} -#define X(TypeName) \ - Doc TIRTextPrinterDebug::VisitStmt_(const TypeName##Node* op) { \ - stmts_by_line_.push_back(std::make_tuple(op, current_line_)); \ - return TIRTextPrinter::VisitStmt_(op); \ - } -TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS -#undef X +Doc TIRTextPrinterDebug::VisitExpr(const PrimExpr& e) { + exprs_by_line_.push_back(std::make_tuple(e.get(), current_line_)); + return TIRTextPrinter::VisitExpr(e); +} } // namespace tir } // namespace tvm diff --git a/src/printer/tir_text_printer_debug.h b/src/printer/tir_text_printer_debug.h index b6c77ce989ae..d0046034cfbf 100644 --- a/src/printer/tir_text_printer_debug.h +++ b/src/printer/tir_text_printer_debug.h @@ -29,7 +29,6 @@ #include #include -#include "../tir/transforms/install_debug_spans_ops.h" #include "text_printer.h" namespace tvm { @@ -49,13 +48,8 @@ class TIRTextPrinterDebug : public TIRTextPrinter { private: Doc NewLine() override; -#define X(TypeName) Doc VisitExpr_(const TypeName##Node* op) override; - TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS -#undef X - -#define X(TypeName) Doc VisitStmt_(const TypeName##Node* op) override; - TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS -#undef X + Doc VisitStmt(const tvm::tir::Stmt& n) override; + Doc VisitExpr(const PrimExpr& e) override; TextMetaDataContext meta_; diff --git a/src/tir/transforms/install_debug_spans.h b/src/tir/transforms/install_debug_spans.h index ab6910103b9d..836c95fb4834 100644 --- a/src/tir/transforms/install_debug_spans.h +++ b/src/tir/transforms/install_debug_spans.h @@ -33,7 +33,60 @@ #include #include -#include "install_debug_spans_ops.h" +#ifndef TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ + +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS \ + X(Call) \ + X(Add) \ + X(Sub) \ + X(Mul) \ + X(Div) \ + X(Mod) \ + X(FloorDiv) \ + X(FloorMod) \ + X(Min) \ + X(Max) \ + X(EQ) \ + X(NE) \ + X(LT) \ + X(LE) \ + X(GT) \ + X(GE) \ + X(And) \ + X(Or) \ + X(Reduce) \ + X(Cast) \ + X(Not) \ + X(Select) \ + X(Ramp) \ + X(Broadcast) \ + X(Shuffle) \ + X(IntImm) \ + X(FloatImm) \ + X(StringImm) + +#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS \ + X(AttrStmt) \ + X(IfThenElse) \ + X(LetStmt) \ + X(For) \ + X(While) \ + X(Allocate) \ + X(AllocateConst) \ + X(DeclBuffer) \ + X(Store) \ + X(BufferStore) \ + X(BufferRealize) \ + X(AssertStmt) \ + X(ProducerStore) \ + X(ProducerRealize) \ + X(Prefetch) \ + X(SeqStmt) \ + X(Evaluate) \ + X(BlockRealize) + +#endif // TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ namespace tvm { namespace tir { diff --git a/src/tir/transforms/install_debug_spans_ops.h b/src/tir/transforms/install_debug_spans_ops.h deleted file mode 100644 index 245c0d164e85..000000000000 --- a/src/tir/transforms/install_debug_spans_ops.h +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file install_debug_spans_ops.h - * \brief List of stmts and exprs supported by the debug info pass - */ - -#ifndef TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ -#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ - -#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS \ - X(Call) \ - X(Add) \ - X(Sub) \ - X(Mul) \ - X(Div) \ - X(Mod) \ - X(FloorDiv) \ - X(FloorMod) \ - X(Min) \ - X(Max) \ - X(EQ) \ - X(NE) \ - X(LT) \ - X(LE) \ - X(GT) \ - X(GE) \ - X(And) \ - X(Or) \ - X(Reduce) \ - X(Cast) \ - X(Not) \ - X(Select) \ - X(Ramp) \ - X(Broadcast) \ - X(Shuffle) \ - X(IntImm) \ - X(FloatImm) \ - X(StringImm) - -#define TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS \ - X(AttrStmt) \ - X(IfThenElse) \ - X(LetStmt) \ - X(For) \ - X(While) \ - X(Allocate) \ - X(AllocateConst) \ - X(DeclBuffer) \ - X(Store) \ - X(BufferStore) \ - X(BufferRealize) \ - X(AssertStmt) \ - X(ProducerStore) \ - X(ProducerRealize) \ - X(Prefetch) \ - X(SeqStmt) \ - X(Evaluate) \ - X(BlockRealize) - -#endif // TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_OPS_H_ From b89c185c69757254bb109cc54d24b998f372f623 Mon Sep 17 00:00:00 2001 From: driazati Date: Mon, 21 Nov 2022 15:26:23 -0800 Subject: [PATCH 5/7] Add unit test --- python/tvm/tir/transform/transform.py | 12 +++++++++++ tests/python/tir/test_debug_info.py | 29 ++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 9b0e5748bcc0..bc3ec5b2ad74 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1028,3 +1028,15 @@ def InstrumentProfileIntrinsics(): The result pass """ return _ffi_api.InstrumentProfileIntrinsics() # type: ignore + + +def InstallDebugSpans(): + """Add line information from the TIR printer as spans on each statement and + expression. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.InstallDebugSpans() # type: ignore diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py index 03524e8776d6..c09df012c8c5 100644 --- a/tests/python/tir/test_debug_info.py +++ b/tests/python/tir/test_debug_info.py @@ -40,7 +40,7 @@ def find_di_locations(source: str) -> Dict[int, int]: return result -def test_llvm_ir_debug_info(): +def _module(): @tvm.script.ir_module class MyModule: @T.prim_func @@ -58,6 +58,33 @@ def main(a: T.handle, b: T.handle): assert 1 == 0, "Some numbers" B[vi] = A[vi] + 1.0 + return MyModule + + +def test_tir_debug_info(): + """ + Test that Spans are correctly replaced with debug spans that reference + the printed TIR + """ + + def find_span(m): + func = next(m.functions.values()) + return func.body.block.body.span + + module_before = _module() + span_before = find_span(module_before) + assert span_before is None + + module_after = tir.transform.InstallDebugSpans()(module_before) + span_after = find_span(module_after) + + # Check that the module name has been added and a line number is present + assert span_after.source_name.name == "main.tir" + assert span_after.line == 4 + + +def test_llvm_ir_debug_info(): + MyModule = _module() with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}): runtime_module = tvm.build(MyModule, target="llvm") From 91220a7b8483a731eaec4361e14d357e76c8b0a4 Mon Sep 17 00:00:00 2001 From: driazati Date: Tue, 29 Nov 2022 11:42:21 -0800 Subject: [PATCH 6/7] Remove dead code --- src/tir/transforms/install_debug_spans.h | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/tir/transforms/install_debug_spans.h b/src/tir/transforms/install_debug_spans.h index 836c95fb4834..c71891aba5a6 100644 --- a/src/tir/transforms/install_debug_spans.h +++ b/src/tir/transforms/install_debug_spans.h @@ -122,15 +122,6 @@ class DebugInfoInstaller : public StmtExprMutator { std::unordered_map expr_lines_; std::string filename_; - template - Stmt add_span(const ObjectName* op) { - Stmt new_stmt = StmtExprMutator::VisitStmt_(op); - auto new_type = Downcast(new_stmt); - auto new_node = new_type.CopyOnWrite(); - new_node->span = MaybeSpan(op); - return new_type; - } - Span MaybeSpan(const StmtNode* op); Span MaybeSpan(const PrimExprNode* op); }; From b80864d8864d100d53ceb813590983bad892b07b Mon Sep 17 00:00:00 2001 From: driazati Date: Mon, 5 Dec 2022 14:00:02 -0800 Subject: [PATCH 7/7] Add accuracy test --- tests/python/tir/test_debug_info.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/python/tir/test_debug_info.py b/tests/python/tir/test_debug_info.py index c09df012c8c5..8ecabbd51a97 100644 --- a/tests/python/tir/test_debug_info.py +++ b/tests/python/tir/test_debug_info.py @@ -84,6 +84,9 @@ def find_span(m): def test_llvm_ir_debug_info(): + """ + Check that the right amount of debug locations are present + """ MyModule = _module() with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}): runtime_module = tvm.build(MyModule, target="llvm") @@ -94,5 +97,28 @@ def test_llvm_ir_debug_info(): assert len(locations) == 34 +def test_llvm_ir_debug_accuracy(): + """ + Check that the debug location on an assert is correct + """ + MyModule = _module() + with tvm.transform.PassContext(opt_level=3, config={"tir.enable_debug": True}): + runtime_module = tvm.build(MyModule, target="llvm") + source = runtime_module.get_source() + locations = find_di_locations(source) + + # Find the 'assert' from MyModule + debug_dir_match = re.search( + r"tail call void %0\(i8\* getelementptr inbounds .* !dbg !(\d+)\n", source + ) + + # Extract out the debug directive line + directive_idx = debug_dir_match.groups()[0] + + # Check that it matches the expected line number (in main.tir) + debug_line_no = int(locations[directive_idx]) + assert debug_line_no == 42 + + if __name__ == "__main__": tvm.testing.main()