diff --git a/src/autoschedulers/adams2019/ASLog.cpp b/src/autoschedulers/adams2019/ASLog.cpp index 5b4754288545..9cbea935fa94 100644 --- a/src/autoschedulers/adams2019/ASLog.cpp +++ b/src/autoschedulers/adams2019/ASLog.cpp @@ -1,5 +1,8 @@ #include "ASLog.h" +#include +#include + namespace Halide { namespace Internal { @@ -50,5 +53,20 @@ int aslog::aslog_level() { return cached_aslog_level; } +std::string conform_name(const std::string &name, const std::string &prefix) { + auto invalid_contents = [](const char &c) { + return std::ispunct(c) || std::isspace(c); + }; + + auto invalid_prefix = [](const char &c) { + return (c != '_') && !(std::isalpha(c)); + }; + + std::string result(name); + std::replace_if(result.begin(), result.end(), invalid_contents, '_'); + if (invalid_prefix(result.front())) { result = std::string(prefix) + result; } + return result; +} + } // namespace Internal } // namespace Halide diff --git a/src/autoschedulers/adams2019/ASLog.h b/src/autoschedulers/adams2019/ASLog.h index 9ba9844ce342..afee33d4ff15 100644 --- a/src/autoschedulers/adams2019/ASLog.h +++ b/src/autoschedulers/adams2019/ASLog.h @@ -31,6 +31,9 @@ class aslog { static int aslog_level(); }; +// Conform the given name into a valid C++ identifier (eg for dumping a Func/Var inside a schedule to a header) +std::string conform_name(const std::string &name, const std::string &prefix = "_"); + } // namespace Internal } // namespace Halide diff --git a/src/autoschedulers/adams2019/LoopNest.cpp b/src/autoschedulers/adams2019/LoopNest.cpp index 9a87608f305e..b4db9b72f977 100644 --- a/src/autoschedulers/adams2019/LoopNest.cpp +++ b/src/autoschedulers/adams2019/LoopNest.cpp @@ -1709,7 +1709,7 @@ void LoopNest::apply(LoopLevel here, internal_assert(v.innermost_pure_dim && v.exists) << v.var.name() << "\n"; // Is the result of a split state.schedule_source - << "\n .vectorize(" << v.var.name() << ")"; + << "\n .vectorize(" << conform_name(v.var.name()) << ")"; s.vectorize(v.var); } } else { @@ -1760,9 +1760,9 @@ void LoopNest::apply(LoopLevel here, parent.exists = false; parent.extent = 1; } else { - VarOrRVar inner(Var(parent.var.name() + "i")); + VarOrRVar inner(Var(conform_name(parent.var.name() + "i"))); if (parent.var.is_rvar) { - inner = RVar(parent.var.name() + "i"); + inner = RVar(conform_name(parent.var.name() + "i", "r")); } auto tail_strategy = pure_var_tail_strategy; @@ -1779,8 +1779,8 @@ void LoopNest::apply(LoopLevel here, s.split(parent.var, parent.var, inner, (int)factor, tail_strategy); state.schedule_source << "\n .split(" - << parent.var.name() << ", " - << parent.var.name() << ", " + << conform_name(parent.var.name()) << ", " + << conform_name(parent.var.name()) << ", " << inner.name() << ", " << factor << ", " << "TailStrategy::" << tail_strategy << ")"; @@ -1819,7 +1819,7 @@ void LoopNest::apply(LoopLevel here, for (size_t i = 0; i < symbolic_loop.size(); i++) { if (state.vars[i].pure && state.vars[i].exists && state.vars[i].extent > 1) { s.unroll(state.vars[i].var); - state.schedule_source << "\n .unroll(" << state.vars[i].var.name() << ")"; + state.schedule_source << "\n .unroll(" << conform_name(state.vars[i].var.name()) << ")"; } } } @@ -1858,7 +1858,7 @@ void LoopNest::apply(LoopLevel here, if (here.is_root()) { loop_level = "_root()"; } else { - loop_level = "_at(" + here.func() + ", " + here.var().name() + ")"; + loop_level = "_at(" + here.func() + ", " + conform_name(here.var().name()) + ")"; } for (const auto &c : children) { if (c->node != node) {