From 92c29e808b833b327f3912655f4f4b8ae988d56d Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 21 Jul 2022 07:38:33 -0700 Subject: [PATCH] [Target] Improve string interpretation in Target creation - SplitString now preserves escape sequences, but still observes quote characters. - Added function Interpret that transforms given string according to interpretation rules: - outermost quotes are removed (if present), - escape sequences inside quotes are preserved verbatim, - unquoted escape sequences produce the escaped character (the escape character (\) is removed. - Interpretation happens every time a value of any type is to be parsed from a string, e.g. Array will first be parsed as an array, then substrings of the input will be parsed as individual elements of that array. In this case, some parts of the initial input will be parsed (and interpreted) twice. - Implement corresponding stringification functionality. This new scheme enabled encoding nested arrays of string with any degree of nesting. For example "-field='\\'foo0\\',\\'bar0,bar1\\'','\\'zing0,zing1\\',\\'fred\\''" would correspond to the target kind attribute Array>>("field")) and have the value { { {foo0}, {bar0, bar1} }, { {zing0, zing1}, {fred} } } --- src/target/target.cc | 258 ++++++++++++++++++++++++++++----------- tests/cpp/target_test.cc | 86 +++++++++++++ 2 files changed, 274 insertions(+), 70 deletions(-) diff --git a/src/target/target.cc b/src/target/target.cc index 9ccd755540ca..6124cedc7256 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -30,8 +31,13 @@ #include #include -#include +#include +#include #include +#include +#include +#include +#include #include "../runtime/object_internal.h" @@ -62,6 +68,17 @@ class TargetInternal { private: static std::unordered_map QueryDevice(int device_id, const TargetNode* target); + static bool IsQuoted(const std::string& str); + static std::string Quote(const std::string& str); + static std::string JoinString(const std::vector& array, char separator); + static std::vector SplitString(const std::string& str, char separator); + static std::string Interpret(const std::string& str); + static std::string Uninterpret(const std::string& str); + static std::string StringifyAtomicType(const ObjectRef& obj); + static std::string StringifyArray(const ArrayNode& array); + + static constexpr char quote = '\''; + static constexpr char escape = '\\'; }; /********** Helper functions **********/ @@ -135,48 +152,50 @@ static std::string RemovePrefixDashes(const std::string& s) { return s.substr(n_dashes); } -static int FindFirstSubstr(const std::string& str, const std::string& substr) { - size_t pos = str.find_first_of(substr); - return pos == std::string::npos ? -1 : pos; -} - -static Optional JoinString(const std::vector& array, char separator) { - char escape = '\\'; - char quote = '\''; - - if (array.empty()) { - return NullOpt; +bool TargetInternal::IsQuoted(const std::string& str) { + std::string::size_type start = 0, end = str.size(); + if (end < 2 || str[start] != quote || str[end - 1] != quote) { + return false; } - - std::ostringstream os; - - for (size_t i = 0; i < array.size(); ++i) { - if (i > 0) { - os << separator; + bool escaping = false; + for (auto i = start + 1, e = end - 1; i < e; ++i) { + if (escaping) { + escaping = false; + } else if (str[i] == escape) { + escaping = true; + } else if (str[i] == quote) { + return false; } + } + // If the reduced string ends with \, then the terminating quote is escaped. + return !escaping; +} - std::string str = array[i]; +std::string TargetInternal::Quote(const std::string& str) { + std::string result(1, quote); + result.append(str); + result.push_back(quote); + return result; +} - if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) { - os << str; - } else { - os << quote; - for (char c : str) { - if (c == quote) { - os << escape; - } - os << c; - } - os << quote; +std::string TargetInternal::JoinString(const std::vector& array, char separator) { + std::string result; + ICHECK(separator != quote && separator != escape) + << "string join separator cannot be " << quote << " or " << escape; + + bool is_first = true; + for (const auto& s : array) { + if (!is_first) { + result.push_back(separator); } + result.append(s); + is_first = false; } - return String(os.str()); -} -static std::vector SplitString(const std::string& str, char separator) { - char escape = '\\'; - char quote = '\''; + return result; +} +std::vector TargetInternal::SplitString(const std::string& str, char separator) { std::vector output; const char* start = str.data(); @@ -199,10 +218,12 @@ static std::vector SplitString(const std::string& str, char separat if ((*pos == separator) && !pos_quoted) { finish_word(); pos++; - } else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) { - current_word << quote; + } else if (*pos == escape && pos + 1 < end) { + current_word << escape; + current_word << pos[1]; pos += 2; } else if (*pos == quote) { + current_word << quote; pos_quoted = !pos_quoted; pos++; } else { @@ -218,12 +239,91 @@ static std::vector SplitString(const std::string& str, char separat return output; } +std::string TargetInternal::Interpret(const std::string& str) { + // String interpretation deals with quotes (') and escapes(\). + // - An escape character must be followed by another character forming an + // "escape sequence". (Trailing escape is not allowed.) An escape prevents + // interpretation of the character that follows. This happens regardless of + // whether the escape sequence appears within quoted substring or not. + // - A quote character, when interpreted, marks the beginning or the end of a + // quoted substring. (A quoted substring cannot contain unescaped quotes.) + // - Any other character, when interpreted, represents itself. + // + // Interpretation happens in two steps: + // 1. If the entire string is quoted, the quotes are removed first, and the + // resulting string is treated as unquoted. + // 2. Each character or escape sequence is interpreted, and the result is copied + // to the result. When not inside a quoted substring, the interpretation of an + // escape sequence is the escaped character, otherwise it is the entire escape + // sequence. + // + // Examples: + // blah -> blah Nothing happened + // 'blah' -> blah Enclosing quotes removed + // 'bl'ah -> 'bl'ah Non-enclosing quotes remain + // '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes + // interpreted. + // '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above. + // + // Note that + // '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah' + + std::string result; + if (str.empty()) { + return result; + } + + // Check if the entire string is enclosed in quotes ''. If so, strip the quotes + // and treat the string as unquoted (so that escapes are interpreted). Doing that + // will allow '\'foo\'' to become 'foo', instead of \'foo\'. + std::string::size_type start = 0, end = str.size(); + if (IsQuoted(str)) { + start++; + end--; + } + + bool inside_quote = false; + bool escaping = false; + + for (auto i = start, e = end; i < e; ++i) { + std::string::value_type c = str[i]; + if (escaping) { + escaping = false; + } else if (c == escape) { + escaping = true; + if (!inside_quote) { + continue; + } + } else if (c == quote) { + inside_quote = !inside_quote; + } + result.push_back(c); + } + + return result; +} + +std::string TargetInternal::Uninterpret(const std::string& str) { + // Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str. + std::string result; + + for (std::string::size_type i = 0, e = str.size(); i < e; ++i) { + std::string::value_type c = str[i]; + if (c == escape || c == quote) { + result.push_back(escape); + } + result.push_back(c); + } + + return result; +} + static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key, std::string* value) { - int pos; + std::string::size_type pos; std::string& result_k = *key; std::string& result_v = *value; - if ((pos = FindFirstSubstr(s, "=")) != -1) { + if ((pos = s.find_first_of('=')) != std::string::npos) { // case 1. --key=value result_k = s.substr(0, pos); result_v = s.substr(pos + 1); @@ -267,13 +367,14 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi ObjectRef TargetInternal::ParseType(const std::string& str, const TargetKindNode::ValueTypeInfo& info) { + std::string interp_str = Interpret(str); if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing integer - std::istringstream is(str); + std::istringstream is(interp_str); int v; if (!(is >> v)) { - std::string lower(str.size(), '\x0'); - std::transform(str.begin(), str.end(), lower.begin(), + std::string lower(interp_str.size(), '\x0'); + std::transform(interp_str.begin(), interp_str.end(), lower.begin(), [](unsigned char c) { return std::tolower(c); }); // Bool is a subclass of IntImm, so allow textual boolean values. if (lower == "true") { @@ -281,23 +382,27 @@ ObjectRef TargetInternal::ParseType(const std::string& str, } else if (lower == "false") { v = 0; } else { - throw Error(": Cannot parse into type \"Integer\" from string: " + str); + throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str); } } return Integer(v); } else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) { - // Parsing string, strip leading/trailing spaces - auto start = str.find_first_not_of(' '); - auto end = str.find_last_not_of(' '); - return String(str.substr(start, (end - start + 1))); + // Parsing string, strip leading/trailing spaces, and enclosing quotes if any + auto start = interp_str.find_first_not_of(' '); + auto end = interp_str.find_last_not_of(' '); + if (start == std::string::npos || end == std::string::npos) { + // The whole string is made of spaces. + return String(); + } + return String(interp_str.substr(start, (end - start + 1))); } else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) { // Parsing target - return Target(TargetInternal::FromString(str)); + return Target(TargetInternal::FromString(interp_str)); } else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) { // Parsing array std::vector result; - for (const std::string& substr : SplitString(str, ',')) { + for (const std::string& substr : SplitString(interp_str, ',')) { try { ObjectRef parsed = TargetInternal::ParseType(substr, *info.key); result.push_back(parsed); @@ -308,7 +413,8 @@ ObjectRef TargetInternal::ParseType(const std::string& str, } return Array(result); } - throw Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str); + throw Error(": Unsupported type \"" + info.type_key + + "\" for parsing from string: " + interp_str); } ObjectRef TargetInternal::ParseType(const ObjectRef& obj, @@ -385,14 +491,35 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj, /********** Stringifying **********/ -static inline Optional StringifyAtomicType(const ObjectRef& obj) { +std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) { if (const auto* p = obj.as()) { - return String(std::to_string(p->value)); + return std::to_string(p->value); } if (const auto* p = obj.as()) { - return GetRef(p); + auto s = static_cast(GetRef(p)); + auto u = Uninterpret(s); + if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) { + u = Quote(u); + } + return u; } - return NullOpt; + LOG(FATAL) << "Cannot stringify this object"; + return ""; // unreachable +} + +std::string TargetInternal::StringifyArray(const ArrayNode& array) { + std::vector elements; + + for (const ObjectRef& item : array) { + std::string s = StringifyAtomicType(item); + std::string u = Uninterpret(s); + if (u.find_first_of(',') != std::string::npos && !IsQuoted(u)) { + u = Quote(u); + } + elements.push_back(u); + } + + return JoinString(elements, ','); } Optional TargetInternal::StringifyAttrsToRaw(const Map& attrs) { @@ -402,30 +529,21 @@ Optional TargetInternal::StringifyAttrsToRaw(const Map result; + std::vector result; + for (const auto& key : keys) { const ObjectRef& obj = attrs[key]; - Optional value = NullOpt; + std::string value; if (const auto* array = obj.as()) { - std::vector items; - for (const ObjectRef& item : *array) { - Optional str = StringifyAtomicType(item); - if (str.defined()) { - items.push_back(str.value()); - } else { - items.clear(); - break; - } - } - value = JoinString(items, ','); + value = String(StringifyArray(*array)); } else { value = StringifyAtomicType(obj); } - if (value.defined()) { - result.push_back("-" + key + "=" + value.value()); + if (!value.empty()) { + result.push_back("-" + key + "=" + value); } } - return JoinString(result, ' '); + return String(JoinString(result, ' ')); } const std::string& TargetNode::str() const { diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index cb5eaa18b576..4b4de2b5f44a 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -203,6 +203,92 @@ TEST(TargetCreation, ClashingTargetProcessing) { EXPECT_THROW(Target test("TestClashingPreprocessor -mcpu=woof -mattr=cake"), InternalError); } +TVM_REGISTER_TARGET_KIND("TestStringKind", kDLCPU) + .add_attr_option("single") + .add_attr_option>("array") + .add_attr_option>>("nested-array") + .add_attr_option>>>("nested2-array"); + +TEST(TargetCreation, ProcessStrings) { + Target test_target1("TestStringKind -single='\\'string with single quote'"); + ASSERT_TRUE(test_target1->GetAttr("single")); + String string1 = test_target1->GetAttr("single").value(); + ASSERT_EQ(string1, "'string with single quote"); + + Target test_target2("TestStringKind -single='\\\'\\\\\\'blah\\\\\\'\\\''"); + ASSERT_TRUE(test_target2->GetAttr("single")); + String string2 = test_target2->GetAttr("single").value(); + ASSERT_EQ(string2, "'\\\'blah\\\''"); + + Target test_target3("TestStringKind -array=-danny,-sammy=1,-kirby='string with space'"); + ASSERT_TRUE(test_target3->GetAttr>("array")); + Array array3 = test_target3->GetAttr>("array").value(); + ASSERT_EQ(array3[0], "-danny"); + ASSERT_EQ(array3[1], "-sammy=1"); + ASSERT_EQ(array3[2], "-kirby='string with space'"); + + Target test_target4("TestStringKind -array='fred, foo, bar',baz"); + ASSERT_TRUE(test_target4->GetAttr>("array")); + Array array4 = test_target4->GetAttr>("array").value(); + ASSERT_EQ(array4[0], "fred, foo, bar"); + ASSERT_EQ(array4[1], "baz"); + + Target test_target5("TestStringKind -array='fr\\'ed','f\\'oo',' bar,baz '"); + ASSERT_TRUE(test_target5->GetAttr>("array")); + Array array5 = test_target5->GetAttr>("array").value(); + ASSERT_EQ(array5[0], "fr'ed"); + ASSERT_EQ(array5[1], "f'oo"); + ASSERT_EQ(array5[2], "bar,baz"); + + Target test_target6("TestStringKind -nested-array='foo0,foo1,foo2','bar0,bar1,bar2','baz0,baz1'"); + ASSERT_TRUE(test_target6->GetAttr>>("nested-array")); + Array> array6 = test_target6->GetAttr>>("nested-array").value(); + ASSERT_EQ(array6[0][0], "foo0"); + ASSERT_EQ(array6[0][1], "foo1"); + ASSERT_EQ(array6[0][2], "foo2"); + ASSERT_EQ(array6[1][0], "bar0"); + ASSERT_EQ(array6[1][1], "bar1"); + ASSERT_EQ(array6[1][2], "bar2"); + ASSERT_EQ(array6[2][0], "baz0"); + ASSERT_EQ(array6[2][1], "baz1"); + + Target test_target7( + "TestStringKind -nested2-array=" + "'\\'foo0,foo1\\',\\'bar0,bar1\\',\\'baz0,baz1\\''," + "'\\'zing0,zing1\\',\\'fred\\''"); + + ASSERT_TRUE(test_target7->GetAttr>>>("nested2-array")); + Array>> array7 = + test_target7->GetAttr>>>("nested2-array").value(); + // { + // {foo0, foo1}, + // {bar0, bar1}, + // {baz0, baz1}, + // }, + // { + // {zing0, zing1}, + // {fred}, + // } + ASSERT_EQ(array7.size(), 2); + ASSERT_EQ(array7[0].size(), 3); + ASSERT_EQ(array7[0][0].size(), 2); + ASSERT_EQ(array7[0][1].size(), 2); + ASSERT_EQ(array7[0][2].size(), 2); + ASSERT_EQ(array7[1].size(), 2); + ASSERT_EQ(array7[1][0].size(), 2); + ASSERT_EQ(array7[1][1].size(), 1); + + ASSERT_EQ(array7[0][0][0], "foo0"); + ASSERT_EQ(array7[0][0][1], "foo1"); + ASSERT_EQ(array7[0][1][0], "bar0"); + ASSERT_EQ(array7[0][1][1], "bar1"); + ASSERT_EQ(array7[0][2][0], "baz0"); + ASSERT_EQ(array7[0][2][1], "baz1"); + ASSERT_EQ(array7[1][0][0], "zing0"); + ASSERT_EQ(array7[1][0][1], "zing1"); + ASSERT_EQ(array7[1][1][0], "fred"); +} + TVM_REGISTER_TARGET_KIND("test_external_codegen_0", kDLCUDA) .set_attr(tvm::attr::kIsExternalCodegen, Bool(true));