Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Target] Improve string interpretation in Target creation (apache#12152)
Browse files Browse the repository at this point in the history
- SplitString now preserves escape sequences, but still observes
  quote characters.
- Added function Interpret that transforms given string according
  to interpretation rules:
  - outermost quotes are removed (if present),
  - escape sequences inside quotes are preserved verbatim,
  - unquoted escape sequences produce the escaped character (the
    escape character (\) is removed.
- Interpretation happens every time a value of any type is to be
  parsed from a string, e.g. Array<String> will first be parsed
  as an array, then substrings of the input will be parsed as
  individual elements of that array. In this case, some parts of
  the initial input will be parsed (and interpreted) twice.
- Implement corresponding stringification functionality.

This new scheme enabled encoding nested arrays of string with any
degree of nesting. For example
  "-field='\\'foo0\\',\\'bar0,bar1\\'','\\'zing0,zing1\\',\\'fred\\''"
would correspond to the target kind attribute
  Array<Array<Array<String>>>("field"))
and have the value
 { { {foo0}, {bar0, bar1} }, { {zing0, zing1}, {fred} } }
  • Loading branch information
Krzysztof Parzyszek authored and xinetzone committed Nov 25, 2022
1 parent 5318c7e commit 1d81f78
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 70 deletions.
258 changes: 188 additions & 70 deletions src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/registry.h>
#include <tvm/target/tag.h>
#include <tvm/target/target.h>
Expand All @@ -30,8 +31,13 @@

#include <algorithm>
#include <cctype>
#include <cstring>
#include <ios>
#include <sstream>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "../runtime/object_internal.h"

Expand Down Expand Up @@ -62,6 +68,17 @@ class TargetInternal {

private:
static std::unordered_map<String, ObjectRef> QueryDevice(int device_id, const TargetNode* target);
static bool IsQuoted(const std::string& str);
static std::string Quote(const std::string& str);
static std::string JoinString(const std::vector<std::string>& array, char separator);
static std::vector<std::string> SplitString(const std::string& str, char separator);
static std::string Interpret(const std::string& str);
static std::string Uninterpret(const std::string& str);
static std::string StringifyAtomicType(const ObjectRef& obj);
static std::string StringifyArray(const ArrayNode& array);

static constexpr char quote = '\'';
static constexpr char escape = '\\';
};

/********** Helper functions **********/
Expand Down Expand Up @@ -135,48 +152,50 @@ static std::string RemovePrefixDashes(const std::string& s) {
return s.substr(n_dashes);
}

static int FindFirstSubstr(const std::string& str, const std::string& substr) {
size_t pos = str.find_first_of(substr);
return pos == std::string::npos ? -1 : pos;
}

static Optional<String> JoinString(const std::vector<String>& array, char separator) {
char escape = '\\';
char quote = '\'';

if (array.empty()) {
return NullOpt;
bool TargetInternal::IsQuoted(const std::string& str) {
std::string::size_type start = 0, end = str.size();
if (end < 2 || str[start] != quote || str[end - 1] != quote) {
return false;
}

std::ostringstream os;

for (size_t i = 0; i < array.size(); ++i) {
if (i > 0) {
os << separator;
bool escaping = false;
for (auto i = start + 1, e = end - 1; i < e; ++i) {
if (escaping) {
escaping = false;
} else if (str[i] == escape) {
escaping = true;
} else if (str[i] == quote) {
return false;
}
}
// If the reduced string ends with \, then the terminating quote is escaped.
return !escaping;
}

std::string str = array[i];
std::string TargetInternal::Quote(const std::string& str) {
std::string result(1, quote);
result.append(str);
result.push_back(quote);
return result;
}

if ((str.find(separator) == std::string::npos) && (str.find(quote) == std::string::npos)) {
os << str;
} else {
os << quote;
for (char c : str) {
if (c == quote) {
os << escape;
}
os << c;
}
os << quote;
std::string TargetInternal::JoinString(const std::vector<std::string>& array, char separator) {
std::string result;
ICHECK(separator != quote && separator != escape)
<< "string join separator cannot be " << quote << " or " << escape;

bool is_first = true;
for (const auto& s : array) {
if (!is_first) {
result.push_back(separator);
}
result.append(s);
is_first = false;
}
return String(os.str());
}

static std::vector<std::string> SplitString(const std::string& str, char separator) {
char escape = '\\';
char quote = '\'';
return result;
}

std::vector<std::string> TargetInternal::SplitString(const std::string& str, char separator) {
std::vector<std::string> output;

const char* start = str.data();
Expand All @@ -199,10 +218,12 @@ static std::vector<std::string> SplitString(const std::string& str, char separat
if ((*pos == separator) && !pos_quoted) {
finish_word();
pos++;
} else if ((*pos == escape) && (pos + 1 < end) && (pos[1] == quote)) {
current_word << quote;
} else if (*pos == escape && pos + 1 < end) {
current_word << escape;
current_word << pos[1];
pos += 2;
} else if (*pos == quote) {
current_word << quote;
pos_quoted = !pos_quoted;
pos++;
} else {
Expand All @@ -218,12 +239,91 @@ static std::vector<std::string> SplitString(const std::string& str, char separat
return output;
}

std::string TargetInternal::Interpret(const std::string& str) {
// String interpretation deals with quotes (') and escapes(\).
// - An escape character must be followed by another character forming an
// "escape sequence". (Trailing escape is not allowed.) An escape prevents
// interpretation of the character that follows. This happens regardless of
// whether the escape sequence appears within quoted substring or not.
// - A quote character, when interpreted, marks the beginning or the end of a
// quoted substring. (A quoted substring cannot contain unescaped quotes.)
// - Any other character, when interpreted, represents itself.
//
// Interpretation happens in two steps:
// 1. If the entire string is quoted, the quotes are removed first, and the
// resulting string is treated as unquoted.
// 2. Each character or escape sequence is interpreted, and the result is copied
// to the result. When not inside a quoted substring, the interpretation of an
// escape sequence is the escaped character, otherwise it is the entire escape
// sequence.
//
// Examples:
// blah -> blah Nothing happened
// 'blah' -> blah Enclosing quotes removed
// 'bl'ah -> 'bl'ah Non-enclosing quotes remain
// '\'blah\'' -> 'blah' Enclosing quotes removed, escaped quotes
// interpreted.
// '\'\\\'blah\\\'\'' -> '\'blah\'' Same as above.
//
// Note that
// '\'\\\'blah\\\'\'' -> '\'blah\'' -> 'blah'

std::string result;
if (str.empty()) {
return result;
}

// Check if the entire string is enclosed in quotes ''. If so, strip the quotes
// and treat the string as unquoted (so that escapes are interpreted). Doing that
// will allow '\'foo\'' to become 'foo', instead of \'foo\'.
std::string::size_type start = 0, end = str.size();
if (IsQuoted(str)) {
start++;
end--;
}

bool inside_quote = false;
bool escaping = false;

for (auto i = start, e = end; i < e; ++i) {
std::string::value_type c = str[i];
if (escaping) {
escaping = false;
} else if (c == escape) {
escaping = true;
if (!inside_quote) {
continue;
}
} else if (c == quote) {
inside_quote = !inside_quote;
}
result.push_back(c);
}

return result;
}

std::string TargetInternal::Uninterpret(const std::string& str) {
// Do the opposite to `Interpret`, so that Interpret(Uninterpret(str)) == str.
std::string result;

for (std::string::size_type i = 0, e = str.size(); i < e; ++i) {
std::string::value_type c = str[i];
if (c == escape || c == quote) {
result.push_back(escape);
}
result.push_back(c);
}

return result;
}

static int ParseKVPair(const std::string& s, const std::string& s_next, std::string* key,
std::string* value) {
int pos;
std::string::size_type pos;
std::string& result_k = *key;
std::string& result_v = *value;
if ((pos = FindFirstSubstr(s, "=")) != -1) {
if ((pos = s.find_first_of('=')) != std::string::npos) {
// case 1. --key=value
result_k = s.substr(0, pos);
result_v = s.substr(pos + 1);
Expand Down Expand Up @@ -267,37 +367,42 @@ const TargetKindNode::ValueTypeInfo& TargetInternal::FindTypeInfo(const TargetKi

ObjectRef TargetInternal::ParseType(const std::string& str,
const TargetKindNode::ValueTypeInfo& info) {
std::string interp_str = Interpret(str);
if (info.type_index == Integer::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing integer
std::istringstream is(str);
std::istringstream is(interp_str);
int v;
if (!(is >> v)) {
std::string lower(str.size(), '\x0');
std::transform(str.begin(), str.end(), lower.begin(),
std::string lower(interp_str.size(), '\x0');
std::transform(interp_str.begin(), interp_str.end(), lower.begin(),
[](unsigned char c) { return std::tolower(c); });
// Bool is a subclass of IntImm, so allow textual boolean values.
if (lower == "true") {
v = 1;
} else if (lower == "false") {
v = 0;
} else {
throw Error(": Cannot parse into type \"Integer\" from string: " + str);
throw Error(": Cannot parse into type \"Integer\" from string: " + interp_str);
}
}
return Integer(v);
} else if (info.type_index == String::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing string, strip leading/trailing spaces
auto start = str.find_first_not_of(' ');
auto end = str.find_last_not_of(' ');
return String(str.substr(start, (end - start + 1)));
// Parsing string, strip leading/trailing spaces, and enclosing quotes if any
auto start = interp_str.find_first_not_of(' ');
auto end = interp_str.find_last_not_of(' ');
if (start == std::string::npos || end == std::string::npos) {
// The whole string is made of spaces.
return String();
}
return String(interp_str.substr(start, (end - start + 1)));

} else if (info.type_index == Target::ContainerType::_GetOrAllocRuntimeTypeIndex()) {
// Parsing target
return Target(TargetInternal::FromString(str));
return Target(TargetInternal::FromString(interp_str));
} else if (info.type_index == ArrayNode::_GetOrAllocRuntimeTypeIndex()) {
// Parsing array
std::vector<ObjectRef> result;
for (const std::string& substr : SplitString(str, ',')) {
for (const std::string& substr : SplitString(interp_str, ',')) {
try {
ObjectRef parsed = TargetInternal::ParseType(substr, *info.key);
result.push_back(parsed);
Expand All @@ -308,7 +413,8 @@ ObjectRef TargetInternal::ParseType(const std::string& str,
}
return Array<ObjectRef>(result);
}
throw Error(": Unsupported type \"" + info.type_key + "\" for parsing from string: " + str);
throw Error(": Unsupported type \"" + info.type_key +
"\" for parsing from string: " + interp_str);
}

ObjectRef TargetInternal::ParseType(const ObjectRef& obj,
Expand Down Expand Up @@ -385,14 +491,35 @@ ObjectRef TargetInternal::ParseType(const ObjectRef& obj,

/********** Stringifying **********/

static inline Optional<String> StringifyAtomicType(const ObjectRef& obj) {
std::string TargetInternal::StringifyAtomicType(const ObjectRef& obj) {
if (const auto* p = obj.as<IntImmNode>()) {
return String(std::to_string(p->value));
return std::to_string(p->value);
}
if (const auto* p = obj.as<StringObj>()) {
return GetRef<String>(p);
auto s = static_cast<std::string>(GetRef<String>(p));
auto u = Uninterpret(s);
if (u.find_first_of(' ') != std::string::npos && !IsQuoted(u)) {
u = Quote(u);
}
return u;
}
return NullOpt;
LOG(FATAL) << "Cannot stringify this object";
return ""; // unreachable
}

std::string TargetInternal::StringifyArray(const ArrayNode& array) {
std::vector<std::string> elements;

for (const ObjectRef& item : array) {
std::string s = StringifyAtomicType(item);
std::string u = Uninterpret(s);
if (u.find_first_of(',') != std::string::npos && !IsQuoted(u)) {
u = Quote(u);
}
elements.push_back(u);
}

return JoinString(elements, ',');
}

Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef>& attrs) {
Expand All @@ -402,30 +529,21 @@ Optional<String> TargetInternal::StringifyAttrsToRaw(const Map<String, ObjectRef
keys.push_back(kv.first);
}
std::sort(keys.begin(), keys.end());
std::vector<String> result;
std::vector<std::string> result;

for (const auto& key : keys) {
const ObjectRef& obj = attrs[key];
Optional<String> value = NullOpt;
std::string value;
if (const auto* array = obj.as<ArrayNode>()) {
std::vector<String> items;
for (const ObjectRef& item : *array) {
Optional<String> str = StringifyAtomicType(item);
if (str.defined()) {
items.push_back(str.value());
} else {
items.clear();
break;
}
}
value = JoinString(items, ',');
value = String(StringifyArray(*array));
} else {
value = StringifyAtomicType(obj);
}
if (value.defined()) {
result.push_back("-" + key + "=" + value.value());
if (!value.empty()) {
result.push_back("-" + key + "=" + value);
}
}
return JoinString(result, ' ');
return String(JoinString(result, ' '));
}

const std::string& TargetNode::str() const {
Expand Down
Loading

0 comments on commit 1d81f78

Please sign in to comment.