From 20a67e47caefc3caf1d4fef9612c5caf6a4fe611 Mon Sep 17 00:00:00 2001 From: Nikolaj Bjorner Date: Sat, 22 May 2021 13:12:49 -0700 Subject: [PATCH] remove symbol -> zstring -> symbol round-trips Signed-off-by: Nikolaj Bjorner --- src/ast/ast.cpp | 7 +++++++ src/ast/ast.h | 11 ++++++++++- src/ast/rewriter/seq_axioms.cpp | 2 +- src/ast/rewriter/seq_rewriter.cpp | 8 ++++---- src/ast/seq_decl_plugin.cpp | 16 ++++------------ src/ast/seq_decl_plugin.h | 13 +++---------- src/model/seq_factory.h | 14 ++++++++------ src/parsers/smt2/smt2parser.cpp | 2 +- src/smt/theory_str.cpp | 3 +-- src/smt/theory_str.h | 11 ++++++----- src/smt/theory_str_mc.cpp | 20 ++++++++------------ src/util/zstring.cpp | 14 +++++++++----- src/util/zstring.h | 9 ++++++--- 13 files changed, 68 insertions(+), 62 deletions(-) diff --git a/src/ast/ast.cpp b/src/ast/ast.cpp index 96744bc9bc2..82f0d0c86f1 100644 --- a/src/ast/ast.cpp +++ b/src/ast/ast.cpp @@ -41,6 +41,9 @@ parameter::~parameter() { if (m_kind == PARAM_RATIONAL) { dealloc(m_rational); } + if (m_kind == PARAM_ZSTRING) { + dealloc(m_zstring); + } } parameter::parameter(parameter const& other) { @@ -64,6 +67,7 @@ parameter& parameter::operator=(parameter const& other) { case PARAM_RATIONAL: m_rational = alloc(rational, other.get_rational()); break; case PARAM_DOUBLE: m_dval = other.m_dval; break; case PARAM_EXTERNAL: m_ext_id = other.m_ext_id; break; + case PARAM_ZSTRING: m_zstring = alloc(zstring, other.get_zstring()); break; default: UNREACHABLE(); break; @@ -99,6 +103,7 @@ bool parameter::operator==(parameter const & p) const { case PARAM_RATIONAL: return get_rational() == p.get_rational(); case PARAM_DOUBLE: return m_dval == p.m_dval; case PARAM_EXTERNAL: return m_ext_id == p.m_ext_id; + case PARAM_ZSTRING: return get_zstring() == p.get_zstring(); default: UNREACHABLE(); return false; } } @@ -111,6 +116,7 @@ unsigned parameter::hash() const { case PARAM_SYMBOL: b = get_symbol().hash(); break; case PARAM_RATIONAL: b = 
get_rational().hash(); break; case PARAM_DOUBLE: b = static_cast<unsigned>(m_dval); break; + case PARAM_ZSTRING: b = get_zstring().hash(); break; case PARAM_EXTERNAL: b = m_ext_id; break; } return (b << 2) | m_kind; @@ -124,6 +130,7 @@ std::ostream& parameter::display(std::ostream& out) const { case PARAM_AST: return out << "#" << get_ast()->get_id(); case PARAM_DOUBLE: return out << m_dval; case PARAM_EXTERNAL: return out << "@" << m_ext_id; + case PARAM_ZSTRING: return out << get_zstring(); default: UNREACHABLE(); return out << "[invalid parameter]"; diff --git a/src/ast/ast.h b/src/ast/ast.h index 70e079605f7..cdf53d4b280 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -22,6 +22,7 @@ Revision History: #include "util/vector.h" #include "util/hashtable.h" #include "util/buffer.h" +#include "util/zstring.h" #include "util/symbol.h" #include "util/rational.h" #include "util/hash.h" @@ -100,6 +101,7 @@ class parameter { PARAM_INT, PARAM_AST, PARAM_SYMBOL, + PARAM_ZSTRING, PARAM_RATIONAL, PARAM_DOUBLE, // PARAM_EXTERNAL is used for handling decl_plugin specific parameters. 
@@ -119,6 +121,7 @@ class parameter { ast* m_ast; // for PARAM_AST symbol m_symbol; // for PARAM_SYMBOL rational* m_rational; // for PARAM_RATIONAL + zstring* m_zstring; // for PARAM_ZSTRING double m_dval; // for PARAM_DOUBLE (remark: this is not used in float_decl_plugin) unsigned m_ext_id; // for PARAM_EXTERNAL }; @@ -131,7 +134,9 @@ class parameter { explicit parameter(ast * p): m_kind(PARAM_AST), m_ast(p) {} explicit parameter(symbol const & s): m_kind(PARAM_SYMBOL), m_symbol(s) {} explicit parameter(rational const & r): m_kind(PARAM_RATIONAL), m_rational(alloc(rational, r)) {} - explicit parameter(rational && r) : m_kind(PARAM_RATIONAL), m_rational(alloc(rational, std::move(r))) {} + explicit parameter(rational && r) : m_kind(PARAM_RATIONAL), m_rational(alloc(rational, std::move(r))) {} + explicit parameter(zstring const& s): m_kind(PARAM_ZSTRING), m_zstring(alloc(zstring, s)) {} + explicit parameter(zstring && s): m_kind(PARAM_ZSTRING), m_zstring(alloc(zstring, std::move(s))) {} explicit parameter(double d):m_kind(PARAM_DOUBLE), m_dval(d) {} explicit parameter(const char *s):m_kind(PARAM_SYMBOL), m_symbol(symbol(s)) {} explicit parameter(const std::string &s):m_kind(PARAM_SYMBOL), m_symbol(symbol(s)) {} @@ -146,6 +151,7 @@ class parameter { case PARAM_RATIONAL: m_rational = nullptr; std::swap(m_rational, other.m_rational); break; case PARAM_DOUBLE: m_dval = other.m_dval; break; case PARAM_EXTERNAL: m_ext_id = other.m_ext_id; break; + case PARAM_ZSTRING: m_zstring = nullptr; std::swap(m_zstring, other.m_zstring); break; /* take ownership and clear other: ~parameter deallocs m_zstring, so sharing the pointer would free it twice */ default: UNREACHABLE(); break; @@ -163,6 +169,7 @@ class parameter { bool is_rational() const { return m_kind == PARAM_RATIONAL; } bool is_double() const { return m_kind == PARAM_DOUBLE; } bool is_external() const { return m_kind == PARAM_EXTERNAL; } + bool is_zstring() const { return m_kind == PARAM_ZSTRING; } bool is_int(int & i) const { return is_int() && (i = get_int(), true); } bool is_ast(ast * & a) const { return is_ast() && (a = get_ast(), true); } @@ -170,6 +177,7 
@@ class parameter { bool is_rational(rational & r) const { return is_rational() && (r = get_rational(), true); } bool is_double(double & d) const { return is_double() && (d = get_double(), true); } bool is_external(unsigned & id) const { return is_external() && (id = get_ext_id(), true); } + bool is_zstring(zstring& s) const { return is_zstring() && (s = get_zstring(), true); } /** \brief This method is invoked when the parameter is @@ -187,6 +195,7 @@ class parameter { ast * get_ast() const { SASSERT(is_ast()); return m_ast; } symbol get_symbol() const { SASSERT(is_symbol()); return m_symbol; } rational const & get_rational() const { SASSERT(is_rational()); return *m_rational; } + zstring const& get_zstring() const { SASSERT(is_zstring()); return *m_zstring; } double get_double() const { SASSERT(is_double()); return m_dval; } unsigned get_ext_id() const { SASSERT(is_external()); return m_ext_id; } diff --git a/src/ast/rewriter/seq_axioms.cpp b/src/ast/rewriter/seq_axioms.cpp index 93f1351765a..44389eb6ca4 100644 --- a/src/ast/rewriter/seq_axioms.cpp +++ b/src/ast/rewriter/seq_axioms.cpp @@ -682,7 +682,7 @@ namespace seq { // itos(n) does not start with "0" when n > 0 // n = 0 or at(itos(n),0) != "0" // alternative: n >= 0 => itos(stoi(itos(n))) = itos(n) - expr_ref zs(seq.str.mk_string(symbol("0")), m); + expr_ref zs(seq.str.mk_string("0"), m); m_rewrite(zs); expr_ref eq0 = mk_eq(n, zero); expr_ref at0 = mk_eq(seq.str.mk_at(e, zero), zs); diff --git a/src/ast/rewriter/seq_rewriter.cpp b/src/ast/rewriter/seq_rewriter.cpp index 8c0b729600f..0ca46f082dc 100644 --- a/src/ast/rewriter/seq_rewriter.cpp +++ b/src/ast/rewriter/seq_rewriter.cpp @@ -2159,7 +2159,7 @@ br_status seq_rewriter::mk_str_from_code(expr* a, expr_ref& result) { rational r; if (m_autil.is_numeral(a, r)) { if (r.is_neg() || r > u().max_char()) { - result = str().mk_string(symbol("")); + result = str().mk_string(zstring()); } else { unsigned num = r.get_unsigned(); @@ -2207,10 +2207,10 @@ br_status 
seq_rewriter::mk_str_itos(expr* a, expr_ref& result) { rational r; if (m_autil.is_numeral(a, r)) { if (r.is_int() && !r.is_neg()) { - result = str().mk_string(symbol(r.to_string())); + result = str().mk_string(zstring(r)); } else { - result = str().mk_string(symbol("")); + result = str().mk_string(zstring()); } return BR_DONE; } @@ -2225,7 +2225,7 @@ br_status seq_rewriter::mk_str_itos(expr* a, expr_ref& result) { eqs.push_back(m().mk_eq(b, str().mk_string(s))); } result = m().mk_or(eqs); - result = m().mk_ite(result, b, str().mk_string(symbol(""))); + result = m().mk_ite(result, b, str().mk_string(zstring())); return BR_REWRITE2; } return BR_FAILED; diff --git a/src/ast/seq_decl_plugin.cpp b/src/ast/seq_decl_plugin.cpp index 048e6ffed59..8536583e638 100644 --- a/src/ast/seq_decl_plugin.cpp +++ b/src/ast/seq_decl_plugin.cpp @@ -375,7 +375,7 @@ func_decl * seq_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, case OP_SEQ_EMPTY: match(*m_sigs[k], arity, domain, range, rng); if (rng == m_string) { - parameter param(symbol("")); + parameter param(zstring("")); return mk_func_decl(OP_STRING_CONST, 1, &param, 0, nullptr, m_string); } else { @@ -474,7 +474,7 @@ func_decl * seq_decl_plugin::mk_func_decl(decl_kind k, unsigned num_parameters, m.raise_exception("Incorrect arguments used for re.^. 
Expected one non-negative integer parameter"); case OP_STRING_CONST: - if (!(num_parameters == 1 && arity == 0 && parameters[0].is_symbol())) { + if (!(num_parameters == 1 && arity == 0 && parameters[0].is_zstring())) { m.raise_exception("invalid string declaration"); } return m.mk_const_decl(m_stringc_sym, m_string, @@ -630,16 +630,8 @@ void seq_decl_plugin::get_sort_names(svector<builtin_name> & sort_names, symbol sort_names.push_back(builtin_name("StringSequence", _STRING_SORT)); } -app* seq_decl_plugin::mk_string(symbol const& s) { - parameter param(s); - func_decl* f = m_manager->mk_const_decl(m_stringc_sym, m_string, - func_decl_info(m_family_id, OP_STRING_CONST, 1, &param)); - return m_manager->mk_const(f); -} - app* seq_decl_plugin::mk_string(zstring const& s) { - symbol sym(s.encode()); - parameter param(sym); + parameter param(s); func_decl* f = m_manager->mk_const_decl(m_stringc_sym, m_string, func_decl_info(m_family_id, OP_STRING_CONST, 1, &param)); return m_manager->mk_const(f); @@ -792,7 +784,7 @@ app* seq_util::mk_lt(expr* ch1, expr* ch2) const { bool seq_util::str::is_string(func_decl const* f, zstring& s) const { if (is_string(f)) { - s = zstring(f->get_parameter(0).get_symbol().bare_str()); + s = f->get_parameter(0).get_zstring(); return true; } else { diff --git a/src/ast/seq_decl_plugin.h b/src/ast/seq_decl_plugin.h index 16b8e76db16..9adb4df3213 100644 --- a/src/ast/seq_decl_plugin.h +++ b/src/ast/seq_decl_plugin.h @@ -191,7 +191,6 @@ class seq_decl_plugin : public decl_plugin { unsigned max_char() const { return get_char_plugin().max_char(); } unsigned num_bits() const { return get_char_plugin().num_bits(); } - app* mk_string(symbol const& s); app* mk_string(zstring const& s); app* mk_char(unsigned ch); @@ -262,9 +261,6 @@ class seq_util { ast_manager& m; family_id m_fid; - app* mk_string(char const* s) { return mk_string(symbol(s)); } - app* mk_string(std::string const& s) { return mk_string(symbol(s.c_str())); } - public: str(seq_util& u): u(u), m(u.m), 
m_fid(u.m_fid) {} @@ -273,7 +269,6 @@ class seq_util { sort* mk_string_sort() const { return m.mk_sort(m_fid, _STRING_SORT, 0, nullptr); } app* mk_empty(sort* s) const { return m.mk_const(m.mk_func_decl(m_fid, OP_SEQ_EMPTY, 0, nullptr, 0, (expr*const*)nullptr, s)); } app* mk_string(zstring const& s) const; - app* mk_string(symbol const& s) const { return u.seq.mk_string(s); } app* mk_char(unsigned ch) const; app* mk_concat(expr* a, expr* b) const { expr* es[2] = { a, b }; return m.mk_app(m_fid, OP_SEQ_CONCAT, 2, es); } app* mk_concat(expr* a, expr* b, expr* c) const { return mk_concat(a, mk_concat(b, c)); } @@ -313,14 +308,12 @@ class seq_util { bool is_skolem(func_decl const* f) const { return is_decl_of(f, m_fid, _OP_SEQ_SKOLEM); } bool is_string(expr const * n) const { return is_app_of(n, m_fid, OP_STRING_CONST); } - bool is_string(expr const* n, symbol& s) const { - return is_string(n) && (s = to_app(n)->get_decl()->get_parameter(0).get_symbol(), true); - } bool is_string(func_decl const* f) const { return is_decl_of(f, m_fid, OP_STRING_CONST); } bool is_string(expr const* n, zstring& s) const; bool is_string(func_decl const* f, zstring& s) const; - bool is_empty(expr const* n) const { symbol s; - return is_app_of(n, m_fid, OP_SEQ_EMPTY) || (is_string(n, s) && !s.is_numerical() && *s.bare_str() == 0); + bool is_empty(expr const* n) const { + zstring s; + return is_app_of(n, m_fid, OP_SEQ_EMPTY) || (is_string(n, s) && s.empty()); } bool is_concat(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_CONCAT); } bool is_length(expr const* n) const { return is_app_of(n, m_fid, OP_SEQ_LENGTH); } diff --git a/src/model/seq_factory.h b/src/model/seq_factory.h index d40a7f7bdc4..f5046b384ed 100644 --- a/src/model/seq_factory.h +++ b/src/model/seq_factory.h @@ -67,8 +67,8 @@ class seq_factory : public value_factory { } bool get_some_values(sort* s, expr_ref& v1, expr_ref& v2) override { if (u.is_string(s)) { - v1 = u.str.mk_string(symbol("a")); - v2 = 
u.str.mk_string(symbol("b")); + v1 = u.str.mk_string("a"); + v2 = u.str.mk_string("b"); return true; } sort* ch; @@ -94,10 +94,11 @@ class seq_factory : public value_factory { while (true) { std::ostringstream strm; strm << m_unique_delim << std::hex << m_next++ << std::dec << m_unique_delim; - symbol sym(strm.str()); + std::string s(strm.str()); + symbol sym(s); if (m_strings.contains(sym)) continue; m_strings.insert(sym); - return u.str.mk_string(sym); + return u.str.mk_string(s); } } sort* seq = nullptr, *ch = nullptr; @@ -131,8 +132,9 @@ class seq_factory : public value_factory { return nullptr; } void register_value(expr* n) override { - symbol sym; - if (u.str.is_string(n, sym)) { + zstring s; + if (u.str.is_string(n, s)) { + symbol sym(s.encode()); m_strings.insert(sym); if (sym.str().find(m_unique_delim) != std::string::npos) add_new_delim(); diff --git a/src/parsers/smt2/smt2parser.cpp b/src/parsers/smt2/smt2parser.cpp index 596cb8af90e..344c61283f9 100644 --- a/src/parsers/smt2/smt2parser.cpp +++ b/src/parsers/smt2/smt2parser.cpp @@ -1188,7 +1188,7 @@ namespace smt2 { void parse_string_const() { SASSERT(curr() == scanner::STRING_TOKEN); - zstring zs(m_scanner.get_string(), true); + zstring zs(m_scanner.get_string()); expr_stack().push_back(sutil().str.mk_string(zs)); TRACE("smt2parser", tout << "new string: " << mk_pp(expr_stack().back(), m()) << "\n";); next(); diff --git a/src/smt/theory_str.cpp b/src/smt/theory_str.cpp index 6ef086b6098..e5c1469a6fc 100644 --- a/src/smt/theory_str.cpp +++ b/src/smt/theory_str.cpp @@ -190,8 +190,7 @@ namespace smt { } expr * theory_str::mk_string(const char * str) { - symbol sym(str); - return u.str.mk_string(sym); + return u.str.mk_string(str); } void theory_str::collect_statistics(::statistics & st) const { diff --git a/src/smt/theory_str.h b/src/smt/theory_str.h index 14356f8a17f..d96a4e4aff1 100644 --- a/src/smt/theory_str.h +++ b/src/smt/theory_str.h @@ -53,11 +53,11 @@ class str_value_factory : public 
value_factory { u(m), delim("!"), m_next(0) {} ~str_value_factory() override {} expr * get_some_value(sort * s) override { - return u.str.mk_string(symbol("some value")); + return u.str.mk_string("some value"); } bool get_some_values(sort * s, expr_ref & v1, expr_ref & v2) override { - v1 = u.str.mk_string(symbol("value 1")); - v2 = u.str.mk_string(symbol("value 2")); + v1 = u.str.mk_string("value 1"); + v2 = u.str.mk_string("value 2"); return true; } expr * get_fresh_value(sort * s) override { @@ -65,10 +65,11 @@ class str_value_factory : public value_factory { while (true) { std::ostringstream strm; strm << delim << std::hex << (m_next++) << std::dec << delim; - symbol sym(strm.str()); + std::string s(strm.str()); + symbol sym(s); if (m_strings.contains(sym)) continue; m_strings.insert(sym); - return u.str.mk_string(sym); + return u.str.mk_string(s); } } sort* seq = nullptr; diff --git a/src/smt/theory_str_mc.cpp b/src/smt/theory_str_mc.cpp index 8979a7a379b..49efa107738 100644 --- a/src/smt/theory_str_mc.cpp +++ b/src/smt/theory_str_mc.cpp @@ -1337,9 +1337,8 @@ namespace smt { rw(arg_subst); TRACE("str_fl", tout << "ival = " << ival << ", string arg evaluates to " << mk_pp(arg_subst, m) << std::endl;); - symbol arg_str; - if (u.str.is_string(arg_subst, arg_str)) { - zstring arg_zstr(arg_str.bare_str()); + zstring arg_zstr; + if (u.str.is_string(arg_subst, arg_zstr)) { rational arg_value; if (string_integer_conversion_valid(arg_zstr, arg_value)) { if (ival != arg_value) { @@ -1365,9 +1364,8 @@ namespace smt { (*replacer)(arg, arg_subst); rw(arg_subst); TRACE("str_fl", tout << "ival = " << ival << ", string arg evaluates to " << mk_pp(arg_subst, m) << std::endl;); - symbol arg_str; - if (u.str.is_string(arg_subst, arg_str)) { - zstring arg_zstr(arg_str.bare_str()); + zstring arg_zstr; + if (u.str.is_string(arg_subst, arg_zstr)) { if (ival >= rational::zero() && ival <= rational(u.max_char())) { // check that arg_subst has length 1 and that the codepoints are the 
same if (arg_zstr.length() != 1 || rational(arg_zstr[0]) != ival) { @@ -1396,9 +1394,8 @@ namespace smt { rw(e_subst); TRACE("str_fl", tout << "ival = " << ival << ", string arg evaluates to " << mk_pp(e_subst, m) << std::endl;); - symbol e_str; - if (u.str.is_string(e_subst, e_str)) { - zstring e_zstr(e_str.bare_str()); + zstring e_zstr; + if (u.str.is_string(e_subst, e_zstr)) { // if arg is negative, e must be empty // if arg is non-negative, e must be valid AND cannot contain leading zeroes @@ -1436,9 +1433,8 @@ namespace smt { (*replacer)(e, e_subst); rw(e_subst); TRACE("str_fl", tout << "ival = " << ival << ", string arg evaluates to " << mk_pp(e_subst, m) << std::endl;); - symbol e_str; - if (u.str.is_string(e_subst, e_str)) { - zstring e_zstr(e_str.bare_str()); + zstring e_zstr; + if (u.str.is_string(e_subst, e_zstr)) { // if arg is out of range, e must be empty // if arg is in range, e must be valid if (ival <= rational::zero() || ival >= rational(u.max_char())) { diff --git a/src/util/zstring.cpp b/src/util/zstring.cpp index 5a57ec862c9..d5f12353317 100644 --- a/src/util/zstring.cpp +++ b/src/util/zstring.cpp @@ -33,7 +33,7 @@ static bool is_hex_digit(char ch, unsigned& d) { return false; } -bool zstring::is_escape_char(bool from_input, char const *& s, unsigned& result) { +bool zstring::is_escape_char(char const *& s, unsigned& result) { unsigned d; if (*s == '\\' && s[1] == 'u' && s[2] == '{' && s[3] != '}') { result = 0; @@ -55,8 +55,6 @@ bool zstring::is_escape_char(bool from_input, char const *& s, unsigned& result) } return false; } - if (!from_input) - return false; unsigned d1, d2, d3, d4; if (*s == '\\' && s[1] == 'u' && is_hex_digit(s[2], d1) && @@ -75,10 +73,10 @@ bool zstring::is_escape_char(bool from_input, char const *& s, unsigned& result) return false; } -zstring::zstring(char const* s, bool from_input) { +zstring::zstring(char const* s) { while (*s) { unsigned ch = 0; - if (is_escape_char(from_input, s, ch)) { + if (is_escape_char(s, ch)) 
{ m_buffer.push_back(ch); } else { @@ -89,6 +87,7 @@ zstring::zstring(char const* s, bool from_input) { SASSERT(well_formed()); } + bool zstring::uses_unicode() const { return gparams::get_value("unicode") != "false"; } @@ -236,12 +235,17 @@ zstring zstring::extract(unsigned offset, unsigned len) const { return result; } +unsigned zstring::hash() const { + return unsigned_ptr_hash(m_buffer.data(), m_buffer.size(), 23); +} + zstring zstring::operator+(zstring const& other) const { zstring result(*this); result.m_buffer.append(other.m_buffer); return result; } + bool zstring::operator==(const zstring& other) const { // two strings are equal iff they have the same length and characters if (length() != other.length()) { diff --git a/src/util/zstring.h b/src/util/zstring.h index f69a74ba5c8..7531390c86b 100644 --- a/src/util/zstring.h +++ b/src/util/zstring.h @@ -19,21 +19,23 @@ Module Name: #include <string> #include "util/vector.h" #include "util/buffer.h" +#include "util/rational.h" class zstring { private: buffer<unsigned> m_buffer; bool well_formed() const; bool uses_unicode() const; - bool is_escape_char(bool from_input, char const *& s, unsigned& result); + bool is_escape_char(char const *& s, unsigned& result); public: static unsigned unicode_max_char() { return 196607; } static unsigned unicode_num_bits() { return 18; } static unsigned ascii_max_char() { return 255; } static unsigned ascii_num_bits() { return 8; } zstring() {} - zstring(char const* s, bool from_input); - zstring(const std::string &str) : zstring(str.c_str(), false) {} + zstring(char const* s); + zstring(const std::string &str) : zstring(str.c_str()) {} + zstring(rational const& r): zstring(r.to_string()) {} zstring(unsigned sz, unsigned const* s) { m_buffer.append(sz, s); SASSERT(well_formed()); } zstring(unsigned ch); zstring replace(zstring const& src, zstring const& dst) const; @@ -51,6 +53,7 @@ class zstring { zstring operator+(zstring const& other) const; bool operator==(const zstring& other) const; bool 
operator!=(const zstring& other) const; + unsigned hash() const; friend std::ostream& operator<<(std::ostream &os, const zstring &str); friend bool operator<(const zstring& lhs, const zstring& rhs);