diff --git a/src/target/target.cc b/src/target/target.cc index 9ccd755540ca..6124cedc7256 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -30,8 +31,13 @@ #include #include -#include +#include +#include #include +#include +#include +#include +#include #include "../runtime/object_internal.h" @@ -62,6 +68,17 @@ class TargetInternal { private: static std::unordered_map QueryDevice(int device_id, const TargetNode* target); + static bool IsQuoted(const std::string& str); + static std::string Quote(const std::string& str); + static std::string JoinString(const std::vector& array, char separator); + static std::vector SplitString(const std::string& str, char separator); + static std::string Interpret(const std::string& str); + static std::string Uninterpret(const std::string& str); + static std::string StringifyAtomicType(const ObjectRef& obj); + static std::string StringifyArray(const ArrayNode& array); + + static constexpr char quote = '\''; + static constexpr char escape = '\\'; }; /********** Helper functions **********/ @@ -135,48 +152,50 @@ static std::string RemovePrefixDashes(const std::string& s) { return s.substr(n_dashes); } -static int FindFirstSubstr(const std::string& str, const std::string& substr) { - size_t pos = str.find_first_of(substr); - return pos == std::string::npos ? -1 : pos; -} - -static Optional JoinString(const std::vector& array, char separator) { - char escape = '\\'; - char quote = '\''; - - if (array.empty()) { - return NullOpt; +bool TargetInternal::IsQuoted(const std::string& str) { + std::string::size_type start = 0, end = str.size(); + if (end < 2 || str[start] != quote || str[end - 1] != quote) { + return false; } - - std::ostringstream os; - - for (size_t i = 0; i < array.size(); ++i) { - if (i > 0) { - os << separator; + bool escaping = false; + for (auto i = start + 1, e = end - 1; i < e; ++i) { + if (escaping) { + escaping = false; + } else if (str[i] == escape) { + escaping = true; + } else if (str[i] == quote) { + return false; } + } + // If the reduced string ends with \, then the terminating quote is escaped. + return !escaping; +} - std::string str = array[i]; +std::string TargetInternal::Quote(const std::string& str) { + std::string result(1, quote); + result.append(str); + result.push_back(quote); + return result; +} - if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) { - os << str; - } else { - os << quote; - for (char c : str) { - if (c == quote) { - os << escape; - } - os << c; - } - os << quote; +std::string TargetInternal::JoinString(const std::vector& array, char separator) { + std::string result; + ICHECK(separator != quote && separator != escape) + << "string join separator cannot be " << quote << " or " << escape; + + bool is_first = true; + for (const auto& s : array) { + if (!is_first) { + result.push_back(separator); } + result.append(s); + is_first = false; } - return String(os.str()); -} -static std::vector SplitString(const std::string& str, char separator) { - char escape = '\\'; - char quote = '\''; + return result; +} +std::vector TargetInternal::SplitString(const std::string& str, char separator) { std::vector output; const char* start = str.data(); @@ -199,10 +218,12 @@ static std::vector SplitString(const std::string& str, char separat if ((*pos == separator) && !pos_quoted) { finish_word(); pos++; - } else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) { - current_word << quote; + } else if (*pos == escape && pos + 1 < end) { + current_word << escape; + current_word << pos[1]; pos += 2; } else if (*pos == quote) { + current_word << quote; pos_quoted = !pos_quoted; pos++; } else { @@ -218,12 +239,91 @@ static std::vector SplitString(const std::string& str, char separat return output; } +std::string TargetInternal::Interpret(const std::string& str) { + // String interpretation deals with quotes (') and escapes(\). + // - An escape character must be followed by another character forming an + // "escape sequence". (Trailing escape is not allowed.) An escape prevents + // interpretation of the character that follows. This happens regardless of + // whether the escape sequence appears within quoted substring or not. + // - A quote character, when interpreted, marks the beginning or the end of a + // quoted substring. (A quoted substring cannot contain unescaped quotes.) + // - Any other character, when interpreted, represents itself. + // + // Interpretation happens in two steps: + // 1. If the entire string is quoted, the quotes are removed first, and the + // resulting string is treated as unquoted. + // 2. Each character or escape sequence is interpreted, and the result is copied + // to the result. When not inside a quoted substring, the interpretation of an + // escape sequence is the escaped character, otherwise it is the entire escape + // sequence. + // + // Examples: + // blah -> blah Nothing happened + // 'blah' -> blah Enclosing quotes removed + // 'bl'ah -> 'bl'ah Non-enclosing quotes remain + // '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes + // interpreted. + // '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above. + // + // Note that + // '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah' + + std::string result; + if (str.empty()) { + return result; + } + + // Check if the entire string is enclosed in quotes ''. If so, strip the quotes + // and treat the string as unquoted (so that escapes are interpreted). Doing that + // will allow '\'foo\'' to become 'foo', instead of \'foo\'. + std::string::size_type start = 0, end = str.size(); + if (IsQuoted(str)) { + start++; + end--; + } + + bool inside_quote = false; + bool escaping = false; + + for (auto i = start, e = end; i < e; ++i) { + std::string::value_type c = str[i]; + if (escaping) { + escaping = false; + } else if (c == escape) { + escaping = true; + if (!inside_quote) { + continue; + } + } else if (c == quote) { + inside_quote = !inside_quote; + } + result.push_back(c); + } + + return result; +} + +std::string TargetInternal::Uninterpret(const std::string& str) { + // Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str. + std::string result; + + for (std::string::size_type i = 0, e = str.size(); i < e; ++i) { + std::string::value_type c = str[i]; + if (c == escape || c == quote) { + result.push_back(escape); + } + result.push_back(c); + } + + return result; +} + static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, std::string* value) { - int pos; + std::string::size_type pos; std::string& result_k = *key; std::string& result_v = *value; - if ((pos = FindFirstSubstr(s, "=")) != -1) { + if ((pos = s.find_first_of('=')) != std::string::npos) { // case 1. --key=value result_k = s.substr(0, pos); result_v = s.substr(pos + 1); @@ -267,13 +367,14 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { + std::string interp_str = Interpret(str); if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - std::istringstream is(str); + std::istringstream is(interp_str); int v; if (!(is >> v)) { - std::string lower(str.size(), '\x0'); - std::transform(str.begin(), str.end(), lower.begin(), + std::string lower(interp_str.size(), '\x0'); + std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); // Bool is a subclass of IntImm, so allow textual boolean values. if (lower == "true") { @@ -281,23 +382,27 @@ ObjectRef TargetInternal::ParseType(const std::string& str, } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + str); + throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); } } return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing string, strip leading/trailing spaces - auto start = str.find_first_not_of(' '); - auto end = str.find_last_not_of(' '); - return String(str.substr(start, (end - start + 1))); + // Parsing string, strip leading/trailing spaces, and enclosing quotes if any + auto start = interp_str.find_first_not_of(' '); + auto end = interp_str.find_last_not_of(' '); + if (start == std::string::npos || end == std::string::npos) { + // The whole string is made of spaces. + return String(); + } + return String(interp_str.substr(start, (end - start + 1))); } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing target - return Target(TargetInternal::FromString(str)); + return Target(TargetInternal::FromString(interp_str)); } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { // Parsing array std::vector result; - for (const std::string& substr : SplitString(str, ',')) { + for (const std::string& substr : SplitString(interp_str, ',')) { try { ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); result.push_back(parsed); @@ -308,7 +413,8 @@ ObjectRef TargetInternal::ParseType(const std::string& str, } return Array(result); } - throw Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str); + throw Error(": Unsupported type \"" + info.type_key + + "\" for parsing from string: " + interp_str); } ObjectRef TargetInternal::ParseType(const ObjectRef& obj, @@ -385,14 +491,35 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ -static inline Optional StringifyAtomicType(const ObjectRef& obj) { +std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { if (const auto* p = obj.as()) { - return String(std::to_string(p->value)); + return std::to_string(p->value); } if (const auto* p = obj.as()) { - return GetRef(p); + auto s = static_cast(GetRef(p)); + auto u = Uninterpret(s); + if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { + u = Quote(u); + } + return u; } - return NullOpt; + LOG(FATAL) << "Cannot stringify this object"; + return ""; // unreachable +} + +std::string TargetInternal::StringifyArray(const ArrayNode& array) { + std::vector elements; + + for (const ObjectRef& item : array) { + std::string s = StringifyAtomicType(item); + std::string u = Uninterpret(s); + if (u.find_first_of(',') != std::string::npos && !IsQuoted(u)) { + u = Quote(u); + } + elements.push_back(u); + } + + return JoinString(elements, ','); } Optional TargetInternal::StringifyAttrsToRaw(const Map& attrs) { @@ -402,30 +529,21 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map result; + std::vector result; + for (const auto& key : keys) { const ObjectRef& obj = attrs[key]; - Optional value = NullOpt; + std::string value; if (const auto* array = obj.as()) { - std::vector items; - for (const ObjectRef& item : *array) { - Optional str = StringifyAtomicType(item); - if (str.defined()) { - items.push_back(str.value()); - } else { - items.clear(); - break; - } - } - value = JoinString(items, ','); + value = String(StringifyArray(*array)); } else { value = StringifyAtomicType(obj); } - if (value.defined()) { - result.push_back("-" + key + "=" + value.value()); + if (!value.empty()) { + result.push_back("-" + key + "=" + value); } } - return JoinString(result, ' '); + return String(JoinString(result, ' ')); } const std::string& TargetNode::str() const { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index cb5eaa18b576..4b4de2b5f44a 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -203,6 +203,92 @@ TEST(TargetCreation, ClashingTargetProcessing) { EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError); } +TVM_REGISTER_TARGET_KIND("TestStringKind", kDLCPU) + .add_attr_option("single") + .add_attr_option>("array") + .add_attr_option>>("nested-array") + .add_attr_option>>>("nested2-array"); + +TEST(TargetCreation, ProcessStrings) { + Target test_target1("TestStringKind -single='\\'string with single quote'"); + ASSERT_TRUE(test_target1->GetAttr("single")); + String string1 = test_target1->GetAttr("single").value(); + ASSERT_EQ(string1, "'string with single quote"); + + Target test_target2("TestStringKind -single='\\\'\\\\\\'blah\\\\\\'\\\''"); + ASSERT_TRUE(test_target2->GetAttr("single")); + String string2 = test_target2->GetAttr("single").value(); + ASSERT_EQ(string2, "'\\\'blah\\\''"); + + Target test_target3("TestStringKind -array=-danny,-sammy=1,-kirby='string with space'"); + ASSERT_TRUE(test_target3->GetAttr>("array")); + Array array3 = test_target3->GetAttr>("array").value(); + ASSERT_EQ(array3[0], "-danny"); + ASSERT_EQ(array3[1], "-sammy=1"); + ASSERT_EQ(array3[2], "-kirby='string with space'"); + + Target test_target4("TestStringKind -array='fred, foo, bar',baz"); + ASSERT_TRUE(test_target4->GetAttr>("array")); + Array array4 = test_target4->GetAttr>("array").value(); + ASSERT_EQ(array4[0], "fred, foo, bar"); + ASSERT_EQ(array4[1], "baz"); + + Target test_target5("TestStringKind -array='fr\\'ed','f\\'oo',' bar,baz '"); + ASSERT_TRUE(test_target5->GetAttr>("array")); + Array array5 = test_target5->GetAttr>("array").value(); + ASSERT_EQ(array5[0], "fr'ed"); + ASSERT_EQ(array5[1], "f'oo"); + ASSERT_EQ(array5[2], "bar,baz"); + + Target test_target6("TestStringKind -nested-array='foo0,foo1,foo2','bar0,bar1,bar2','baz0,baz1'"); + ASSERT_TRUE(test_target6->GetAttr>>("nested-array")); + Array> array6 = test_target6->GetAttr>>("nested-array").value(); + ASSERT_EQ(array6[0][0], "foo0"); + ASSERT_EQ(array6[0][1], "foo1"); + ASSERT_EQ(array6[0][2], "foo2"); + ASSERT_EQ(array6[1][0], "bar0"); + ASSERT_EQ(array6[1][1], "bar1"); + ASSERT_EQ(array6[1][2], "bar2"); + ASSERT_EQ(array6[2][0], "baz0"); + ASSERT_EQ(array6[2][1], "baz1"); + + Target test_target7( + "TestStringKind -nested2-array=" + "'\\'foo0,foo1\\',\\'bar0,bar1\\',\\'baz0,baz1\\''," + "'\\'zing0,zing1\\',\\'fred\\''"); + + ASSERT_TRUE(test_target7->GetAttr>>>("nested2-array")); + Array>> array7 = + test_target7->GetAttr>>>("nested2-array").value(); + // { + // {foo0, foo1}, + // {bar0, bar1}, + // {baz0, baz1}, + // }, + // { + // {zing0, zing1}, + // {fred}, + // } + ASSERT_EQ(array7.size(), 2); + ASSERT_EQ(array7[0].size(), 3); + ASSERT_EQ(array7[0][0].size(), 2); + ASSERT_EQ(array7[0][1].size(), 2); + ASSERT_EQ(array7[0][2].size(), 2); + ASSERT_EQ(array7[1].size(), 2); + ASSERT_EQ(array7[1][0].size(), 2); + ASSERT_EQ(array7[1][1].size(), 1); + + ASSERT_EQ(array7[0][0][0], "foo0"); + ASSERT_EQ(array7[0][0][1], "foo1"); + ASSERT_EQ(array7[0][1][0], "bar0"); + ASSERT_EQ(array7[0][1][1], "bar1"); + ASSERT_EQ(array7[0][2][0], "baz0"); + ASSERT_EQ(array7[0][2][1], "baz1"); + ASSERT_EQ(array7[1][0][0], "zing0"); + ASSERT_EQ(array7[1][0][1], "zing1"); + ASSERT_EQ(array7[1][1][0], "fred"); +} + TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true));