diff --git a/src/target/llvm/llvm_instance.cc b/src/target/llvm/llvm_instance.cc index 772e71b28724..19ff480452e8 100644 --- a/src/target/llvm/llvm_instance.cc +++ b/src/target/llvm/llvm_instance.cc @@ -39,6 +39,7 @@ #include #endif #include +#include #include #include #include @@ -56,9 +57,14 @@ #include #include +#include +#include +#include +#include #include #include #include +#include namespace tvm { namespace codegen { @@ -136,10 +142,27 @@ std::unique_ptr LLVMInstance::ParseBuffer(const llvm::MemoryBuffer return module; } -// LLVMTarget +// LLVMTargetInfo + +std::ostream& operator<<(std::ostream& os, const LLVMTargetInfo::Option& opt) { + os << '-' << opt.name; + switch (opt.type) { + case LLVMTargetInfo::Option::OptType::Bool: + return os << ":bool=" << (opt.value.b ? "true" : "false"); + case LLVMTargetInfo::Option::OptType::Int: + return os << ":int=" << opt.value.i; + case LLVMTargetInfo::Option::OptType::UInt: + return os << ":uint=" << opt.value.u; + case LLVMTargetInfo::Option::OptType::String: + return os << ":string=" << opt.value.s; + default: + os << ":?(" << static_cast(opt.type) << ")"; + break; + } + return os; +} -LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) - : instance_(instance), ctx_(instance.GetContext()) { +LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& instance, const Target& target) { triple_ = target->GetAttr("mtriple").value_or("default"); if (triple_.empty() || triple_ == "default") { @@ -153,6 +176,26 @@ LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) } } + if (const Optional>& v = target->GetAttr>("cl-opt")) { + llvm::StringMap& options = llvm::cl::getRegisteredOptions(); + bool parse_error = false; + for (const String& s : v.value()) { + Option opt = ParseOptionString(s); + if (opt.type == Option::OptType::Invalid) { + parse_error = true; + continue; + } + if (options.count(opt.name)) { + llvm_options_.push_back(opt); + } else { + // Flag an error, but don't abort. 
LLVM flags may change, and this would + // give the code a chance to run even if the option no longer applies. + LOG(ERROR) << "\"" << opt.name << "\" is not an LLVM option, option ignored"; + } + } + ICHECK(!parse_error) << "there were errors parsing command-line options"; + } + llvm::FloatABI::ABIType float_abi = llvm::FloatABI::Default; if (const Optional& v = target->GetAttr("mfloat-abi")) { String value = v.value(); @@ -238,17 +281,12 @@ LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) } } -LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str) - : LLVMTarget(scope, Target(target_str)) {} +LLVMTargetInfo::LLVMTargetInfo(LLVMInstance& scope, const std::string& target_str) + : LLVMTargetInfo(scope, Target(target_str)) {} -LLVMTarget::~LLVMTarget() = default; - -llvm::LLVMContext* LLVMTarget::GetContext() const { - ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; - return ctx_.lock().get(); -} +LLVMTargetInfo::~LLVMTargetInfo() = default; -llvm::TargetMachine* LLVMTarget::GetOrCreateTargetMachine(bool allow_missing) { +llvm::TargetMachine* LLVMTargetInfo::GetOrCreateTargetMachine(bool allow_missing) { if (target_machine_) return target_machine_.get(); std::string error; @@ -264,11 +302,11 @@ llvm::TargetMachine* LLVMTarget::GetOrCreateTargetMachine(bool allow_missing) { return target_machine_.get(); } -std::string LLVMTarget::GetTargetFeatureString() const { // +std::string LLVMTargetInfo::GetTargetFeatureString() const { // return Join(",", attrs_); } -std::string LLVMTarget::str() const { +std::string LLVMTargetInfo::str() const { std::ostringstream os; os << "llvm"; if (!triple_.empty()) { @@ -340,9 +378,324 @@ std::string LLVMTarget::str() const { } } + if (size_t num = llvm_options_.size(); num > 0) { + os << " -cl-opt="; + std::vector opts; + for (const Option& opt : llvm_options_) { + std::stringstream os; + os << opt; + opts.emplace_back(os.str()); + } + auto* quote = num > 1 ? 
"'" : ""; + os << quote << Join(",", opts) << quote; + } + return os.str(); } +LLVMTargetInfo::Option LLVMTargetInfo::ParseOptionString(const std::string& str) { + Option opt; + opt.type = Option::OptType::Invalid; + + // Option string: "-"+ <name> ":" <type> "=" <value> + // + // Note: "-"+ means 1 or more dashes, but only "-" and "--" are valid. + + // The first step is to do "lexing" of the option string, i.e. to break + // it up into parts (like "tokens") according to the syntax above. These + // parts will be non-overlapping substrings of the option string, and + // concatenated together, they will be equal to the option string. + // The literal elements are parts on their own. + // + // Note that the option string may be malformed, so any of the literal + // elements in the syntax may be missing. + + std::vector parts; + + auto find_first_of = [](const std::string& str, const std::string& chars, auto start = 0) { + auto pos = str.find_first_of(chars, start); + return pos != std::string::npos ? pos : str.size(); + }; + auto find_first_not_of = [](const std::string& str, const std::string& chars, auto start = 0) { + auto pos = str.find_first_not_of(chars, start); + return pos != std::string::npos ? pos : str.size(); + }; + + // "-"+ + std::string::size_type pos_start = 0, pos_end = str.size(); + std::string::size_type pos_at = find_first_not_of(str, "-", pos_start); + if (pos_at > 0) { + parts.push_back(str.substr(pos_start, pos_at)); + } + // <name>, always present, may be empty string + pos_start = pos_at; + pos_at = find_first_of(str, ":=", pos_start); + parts.push_back(str.substr(pos_start, pos_at - pos_start)); + + // ":" or "=", if any + pos_start = pos_at; + char c = pos_start < pos_end ? str[pos_start] : 0; + if (c != 0) { + parts.emplace_back(1, c); + pos_start++; + } + // If the character found in the previous step wasn't '=', look for '='. 
+ if (c == ':') { + // <type> + pos_at = find_first_of(str, "=", pos_start); + if (pos_at > pos_start) { // if non-empty + parts.push_back(str.substr(pos_start, pos_at - pos_start)); + } + + // "=" + if (pos_at < pos_end) { + parts.emplace_back(1, str[pos_at]); + pos_start = pos_at + 1; + } + } + if (pos_start < pos_end) { + // <value> + parts.push_back(str.substr(pos_start)); + } + + // After breaking up the option string, examine and validate the individual + // parts. + + int part_this = 0, part_end = parts.size(); + + const std::string error_header = "while parsing option \"" + str + "\": "; + + // Check for "-" or "--". + if (part_this < part_end) { + auto& p = parts[part_this++]; + if ((p.size() != 1 && p.size() != 2) || p.find_first_not_of('-') != std::string::npos) { + LOG(ERROR) << error_header << "option must start with \"-\" or \"--\""; + return opt; + } + } + + // Validate option name. + if (part_this < part_end) { + auto& p = parts[part_this++]; + if (p.empty()) { + LOG(ERROR) << error_header << "option name must not be empty"; + return opt; + } + opt.name = std::move(p); + } + + // Check type, if present. + Option::OptType type = Option::OptType::Invalid; + if (part_this < part_end) { + auto& p0 = parts[part_this]; + if (p0 == ":") { + part_this++; // Only advance if we saw ":". + if (part_this < part_end) { + auto& p1 = parts[part_this]; + ICHECK(!p1.empty()) << "tokenizing error"; // This shouldn't happen. + if (p1 != "=") { + part_this++; + if (p1 == "bool") { + type = Option::OptType::Bool; + } else if (p1 == "int") { + type = Option::OptType::Int; + } else if (p1 == "uint") { + type = Option::OptType::UInt; + } else if (p1 == "string") { + type = Option::OptType::String; + } + } + } + // If there was ":", there must be a type. + if (type == Option::OptType::Invalid) { + LOG(ERROR) << error_header << "invalid type"; + return opt; + } + } + } + + // Check value, if present. 
+ std::optional value; + if (part_this < part_end) { + auto& p0 = parts[part_this]; + if (p0 == "=") { + part_this++; + if (part_this < part_end) { + value = std::move(parts[part_this]); + } else { + value = ""; + } + } else { + // If there are still any parts left to be processed, there must be "=". + LOG(ERROR) << error_header << "expecting \"=\""; + return opt; + } + } + + // NOLINTNEXTLINE(runtime/int) + auto to_integer = [](const std::string& s) -> std::optional { + // std::stoll takes "long long" + long long number; // NOLINT(runtime/int) + size_t pos; + try { + number = std::stoll(s, &pos); + } catch (...) { + return std::nullopt; + } + if (pos == s.size()) { + return number; + } else { + return std::nullopt; + } + }; + + auto to_boolean = [&to_integer](const std::string& s) -> std::optional { + // Return true or false if the string corresponds to a valid boolean value, + // otherwise return std::nullopt. + auto ti = to_integer(s); + if (ti.has_value() && (ti.value() == 0 || ti.value() == 1)) { + return static_cast(ti.value()); + } + + std::string lower; + std::transform(s.begin(), s.end(), std::back_inserter(lower), + [](unsigned char c) { return std::tolower(c); }); + if (lower == "true") { + return true; + } else if (lower == "false") { + return false; + } + return std::nullopt; + }; + + if (value.has_value()) { + if (type == Option::OptType::Int || type == Option::OptType::UInt) { + auto v = to_integer(value.value()); + if (!v.has_value()) { + LOG(ERROR) << error_header << "invalid integer value \"" << value.value() << "\""; + return opt; + } + if (type == Option::OptType::Int) { + opt.value.i = static_cast(v.value()); + if (opt.value.i != v.value()) { + LOG(WARNING) << error_header << "value exceeds int range, assuming " << opt.value.i; + } + } else { + // NOLINTNEXTLINE(runtime/int) + opt.value.u = static_cast(static_cast(v.value())); + if (opt.value.u != static_cast(v.value())) { // NOLINT(runtime/int) + LOG(WARNING) << error_header << "value exceeds int range, assuming " 
<< opt.value.u; + } + } + } else if (type == Option::OptType::String) { + opt.value.s = std::move(value.value()); + } else { + // "type" is either Bool (given explicitly) or Invalid (type not present in string) + auto v = to_boolean(value.value()); + if (!v.has_value()) { + LOG(ERROR) << error_header << "invalid boolean value \"" << value.value() << "\""; + return opt; + } + opt.value.b = v.value(); + type = Option::OptType::Bool; + } + } else { + // Value was not present in string. Assume "true" if "type" is Bool or Invalid + if (type == Option::OptType::Bool || type == Option::OptType::Invalid) { + opt.value.b = true; + type = Option::OptType::Bool; + } else { + LOG(ERROR) << error_header << "must have a value"; + return opt; + } + } + + ICHECK(type != Option::OptType::Invalid); + opt.type = type; + return opt; +} + +bool LLVMTargetInfo::MatchesGlobalState() const { + for (const Option& opt : GetCommandLineOptions()) { + Option current_opt = opt; + GetOptionValue(&current_opt); + ICHECK(current_opt.type != Option::OptType::Invalid); + switch (current_opt.type) { + case Option::OptType::Bool: + if (current_opt.value.b != opt.value.b) return false; + continue; + case Option::OptType::Int: + if (current_opt.value.i != opt.value.i) return false; + continue; + case Option::OptType::UInt: + if (current_opt.value.u != opt.value.u) return false; + continue; + case Option::OptType::String: + if (current_opt.value.s != opt.value.s) return false; + continue; + default:; // NOLINT(whitespace/semicolon) + } + } + return true; +} + +void LLVMTargetInfo::GetOptionValue(LLVMTargetInfo::Option* opt) const { + llvm::StringMap& options = llvm::cl::getRegisteredOptions(); + llvm::cl::Option* base_op = options[opt->name]; + + if (opt->type == Option::OptType::Bool) { + auto* bool_op = static_cast*>(base_op); + opt->value.b = bool_op->getValue(); + } else if (opt->type == Option::OptType::Int) { + auto* int_op = static_cast*>(base_op); + opt->value.i = int_op->getValue(); + } else if 
(opt->type == Option::OptType::UInt) { + auto* uint_op = static_cast*>(base_op); + opt->value.u = uint_op->getValue(); + } else if (opt->type == Option::OptType::String) { + auto* str_op = static_cast*>(base_op); + opt->value.s = str_op->getValue(); + } else { + opt->type = Option::OptType::Invalid; + } +} + +// LLVMTarget + +bool LLVMTarget::modified_llvm_state_ = false; + +LLVMTarget::LLVMTarget(LLVMInstance& instance, const LLVMTargetInfo& target_info) + : LLVMTargetInfo(target_info), instance_(instance), ctx_(instance.GetContext()) { + // Populate the list of saved options with the current values. + for (const Option& opt : GetCommandLineOptions()) { + GetOptionValue(&saved_llvm_options_.emplace_back(opt)); + } + + if (modified_llvm_state_) { + ICHECK(!ApplyLLVMOptions(true)); + } else { + modified_llvm_state_ = ApplyLLVMOptions(true); + } +} + +LLVMTarget::LLVMTarget(LLVMInstance& instance, const Target& target) + : LLVMTarget(instance, LLVMTargetInfo(instance, target)) {} + +LLVMTarget::LLVMTarget(LLVMInstance& scope, const std::string& target_str) + : LLVMTarget(scope, Target(target_str)) {} + +LLVMTarget::~LLVMTarget() { + // Revert all applied LLVM options. 
+ if (ApplyLLVMOptions(false)) { + modified_llvm_state_ = false; + } +} + +llvm::LLVMContext* LLVMTarget::GetContext() const { + ICHECK(!ctx_.expired()) << "LLVM scope has been deleted"; + return ctx_.lock().get(); +} + std::string LLVMTarget::GetTargetMetadata(const llvm::Module& module) { if (llvm::Metadata* tvm_target = module.getModuleFlag("tvm_target")) { auto* mdstr = llvm::cast(tvm_target); @@ -359,6 +712,55 @@ void LLVMTarget::SetTargetMetadata(llvm::Module* module) const { llvm::MDString::get(*GetContext(), str())); } +bool LLVMTarget::ApplyLLVMOptions(bool apply_otherwise_revert, bool dry_run) { + llvm::StringMap& options = llvm::cl::getRegisteredOptions(); + bool changed = false; + +#define HANDLE_OPTION_VALUE(option, new_val, saved_val) \ + do { \ + auto current = (option)->getValue(); \ + auto replacement = apply_otherwise_revert ? (new_val) : (saved_val); \ + if (current != replacement) { \ + changed = true; \ + if (!dry_run) { \ + (option)->setValue(replacement); \ + } \ + } \ + } while (false); + + const auto& new_options = GetCommandLineOptions(); + for (size_t i = 0, e = saved_llvm_options_.size(); i != e; ++i) { + const Option& new_opt = new_options[i]; + const Option& saved_opt = saved_llvm_options_[i]; + + llvm::cl::Option* base_op = options[new_opt.name]; + + if (new_opt.type == Option::OptType::Bool) { + auto* bool_op = static_cast*>(base_op); + HANDLE_OPTION_VALUE(bool_op, new_opt.value.b, saved_opt.value.b); + } else if (new_opt.type == Option::OptType::Int) { + auto* int_op = static_cast*>(base_op); + HANDLE_OPTION_VALUE(int_op, new_opt.value.i, saved_opt.value.i); + } else if (new_opt.type == Option::OptType::UInt) { + auto* uint_op = static_cast*>(base_op); + HANDLE_OPTION_VALUE(uint_op, new_opt.value.u, saved_opt.value.u); + } else if (new_opt.type == Option::OptType::String) { + auto* str_op = static_cast*>(base_op); + HANDLE_OPTION_VALUE(str_op, new_opt.value.s, saved_opt.value.s); + } else { + LOG(FATAL) << "unexpected type in option 
" << new_opt; + } + + if (dry_run && changed) { + return true; + } + } + +#undef HANDLE_OPTION_VALUE + + return changed; +} + } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/llvm_instance.h b/src/target/llvm/llvm_instance.h index afb6e58deb1f..217db63aad7a 100644 --- a/src/target/llvm/llvm_instance.h +++ b/src/target/llvm/llvm_instance.h @@ -38,6 +38,7 @@ #include #include +#include #include #include #include @@ -57,8 +58,9 @@ class LLVMTarget; /*! * \class LLVMInstance - * \brief LLVMInstance is a class that (conceptually) starts and stops LLVM. All - * uses of LLVM should take place within a lifetime of an object of this class. + * \brief LLVMInstance is a class that (conceptually) starts and stops LLVM. + * All uses of LLVM should take place within a lifetime of an object + * of this class. * * E.g. * ```{.cpp} @@ -128,60 +130,48 @@ class LLVMInstance { }; /*! - * \class LLVMTarget - * \brief Information used by LLVM for code generation for particular target + * \class LLVMTargetInfo + * \brief Summary of information for this TVM target relevant to LLVM code + * generation. * * This class contains all information that LLVM needs for code generation for - * a particular target. Since Target in TVM will soon contain command line - * flags for LLVM, objects of this class will handle saving and restoring - * global LLVM state that may be affected by these flags. This way, code - * generation for each LLVM-based target in TVM will start with the same LLVM - * global state. + * a particular target. The purpose of this class is only to provide information + * in an easily-accessible form (for example for querying the target properties). * * Note that objects of this class must be created within the lifetime of an * LLVMInstance object. */ -class LLVMTarget { +class LLVMTargetInfo { public: /*! 
- * \brief Constructs LLVMTarget from `Target` + * \brief Constructs LLVMTargetInfo from `Target` * \param scope LLVMInstance object * \param target TVM Target object for target "llvm" */ - LLVMTarget(LLVMInstance& scope, const Target& target); // NOLINT(runtime/references) + LLVMTargetInfo(LLVMInstance& scope, const Target& target); // NOLINT(runtime/references) /*! - * \brief Constructs LLVMTarget from target string + * \brief Constructs LLVMTargetInfo from target string * \param scope LLVMInstance object * \param target TVM target string for target "llvm" */ - LLVMTarget(LLVMInstance& scope, const std::string& target_str); // NOLINT(runtime/references) + // NOLINTNEXTLINE(runtime/references) + LLVMTargetInfo(LLVMInstance& scope, const std::string& target_str); /*! - * \brief Destroys LLVMTarget object + * \brief Destroys LLVMTargetInfo object */ - ~LLVMTarget(); + ~LLVMTargetInfo(); /*! - * \brief Returns string representation (as TVM target) of the LLVMTarget + * \brief Returns string representation (as TVM target) of the LLVMTargetInfo * \return Target string * - * Note: If the LLVMTarget object was created from a string `s`, the string + * Note: If the LLVMTargetInfo object was created from a string `s`, the string * returned here may not be exactly equal to `s`. For example, if the CPU * was "default", the returned string will have CPU set to the detected host * CPU. */ std::string str() const; - /*! - * \brief Get the LLVMInstance object from which the LLVMTarget object was - * created - * \return The enclosing LLVMInstance object - */ - const LLVMInstance& GetInstance() const { return instance_; } - /*! - * \brief Get the current LLVM context - * \return the current LLVM context - */ - llvm::LLVMContext* GetContext() const; /*! 
* \brief Return LLVM's `TargetMachine`, or nullptr * \param allow_missing do not abort if the target machine cannot be created, @@ -228,6 +218,125 @@ class LLVMTarget { */ llvm::CodeGenOpt::Level GetOptLevel() const { return opt_level_; } + /*! + * \class Option + * \brief Internal representation of command-line option + */ + struct Option { + enum class OptType { + Invalid = 0, //!< placeholder, indicates parsing error + Bool, //!< enum value corresponding to type string "bool" + Int, //!< enum value corresponding to type string "int" + UInt, //!< enum value corresponding to type string "uint" + String, //!< enum value corresponding to type string "string" + }; + std::string name; //!< option name + OptType type; //!< type of the option value + struct { + union { + bool b; //!< bool option value + int i; //!< int option value + unsigned u = 0; //!< unsigned option value + }; + std::string s; //!< string option value + } value; //!< option value specified in the option string + }; + + /*! + * \brief Get LLVM command line options + * \return the list of LLVM command line options specified for this target + */ + const std::vector