diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index f2d985279f12..38334de357d8 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -21,7 +21,7 @@ #include #include -#include +#include namespace tvm { @@ -38,8 +38,17 @@ std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional 0 && // + (std::isalpha(name[0]) || name[0] == '_') && // + std::all_of(name.begin() + 1, name.end(), + [](char c) { return std::isalnum(c) || c == '_'; }); } PrinterConfig::PrinterConfig(Map config_dict) { diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc index 9524c90b577c..7fb67d9376f5 100644 --- a/src/relax/ir/dataflow_matcher.cc +++ b/src/relax/ir/dataflow_matcher.cc @@ -36,7 +36,7 @@ #include #include #include -#include +#include #include #include #include @@ -557,8 +557,9 @@ bool DFPatternMatcher::VisitDFPattern_(const DataflowVarPatternNode* op, const E bool DFPatternMatcher::VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) { // GlobalVarPattern is not inherited from Var, so we need to handle it separately. if (const auto* var_node = expr.as()) { - std::regex pat{std::string(op->name_hint())}; - return "" == op->name_hint() || std::regex_search(std::string(var_node->name_hint), pat); + std::string pat = std::string(op->name_hint()); + std::string var_name = std::string(var_node->name_hint); + return pat.empty() || var_name.find(pat) != std::string::npos; } return false; } diff --git a/src/runtime/contrib/cublas/cublas_json_runtime.cc b/src/runtime/contrib/cublas/cublas_json_runtime.cc index c6916d4f86fa..23e35d2f7188 100644 --- a/src/runtime/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/contrib/cublas/cublas_json_runtime.cc @@ -26,7 +26,6 @@ #include #include -#include #include #include diff --git a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc index 58e4e59afcb9..7d701396d0ca 100644 --- a/src/runtime/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/contrib/cudnn/cudnn_json_runtime.cc @@ -26,7 +26,6 @@ #include #include -#include #include #include @@ -54,7 +53,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { stream = static_cast((*func)().operator void*()); auto attr_in_name = [](const std::string& op_name, const std::string& attr_name) { - return std::regex_search(op_name, std::regex(attr_name)); + return op_name.find(attr_name) != std::string::npos; }; auto vstr2vint = [](const JSONGraphNode& node, const std::string& attrStr) { diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py index 685a382ad7a5..edd3bd16104f 100644 --- a/tests/python/relax/test_dataflow_pattern.py +++ b/tests/python/relax/test_dataflow_pattern.py @@ -97,7 +97,9 @@ def test_dataflow_var_pattern(): def test_global_var_pattern(): assert is_gv("x").match(rx.GlobalVar("x")) - assert is_gv("x.*").match(rx.GlobalVar("x_2")) + # TODO: disabled as regex is not supported due to + # symbol conflict with PyTorch + # assert is_gv("x.*").match(rx.GlobalVar("x_2")) assert is_gv().match(rx.GlobalVar("x")) assert not is_gv("x").match(rx.GlobalVar("y")) assert not is_gv("x").match(rx.Var("x"))