Skip to content

Commit

Permalink
Stop using surrogate escape (#302)
Browse files Browse the repository at this point in the history
It was workaround for `bytes %`.
Since we dropped Python 3.4 support, we can use just `bytes %` now.
  • Loading branch information
methane authored Dec 6, 2018
1 parent 628bb1b commit 5e8eeac
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 125 deletions.
35 changes: 14 additions & 21 deletions MySQLdb/_mysql.c
Original file line number Diff line number Diff line change
Expand Up @@ -915,14 +915,15 @@ _mysql.string_literal(obj) cannot handle character sets.";
static PyObject *
_mysql_string_literal(
_mysql_ConnectionObject *self,
PyObject *args)
PyObject *o)
{
PyObject *str, *s, *o, *d;
PyObject *str, *s;
char *in, *out;
int len, size;

if (self && PyModule_Check((PyObject*)self))
self = NULL;
if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL;

if (PyBytes_Check(o)) {
s = o;
Py_INCREF(s);
Expand Down Expand Up @@ -965,33 +966,25 @@ static PyObject *_mysql_NULL;

static PyObject *
_escape_item(
PyObject *self,
PyObject *item,
PyObject *d)
{
PyObject *quoted=NULL, *itemtype, *itemconv;
if (!(itemtype = PyObject_Type(item)))
goto error;
if (!(itemtype = PyObject_Type(item))) {
return NULL;
}
itemconv = PyObject_GetItem(d, itemtype);
Py_DECREF(itemtype);
if (!itemconv) {
PyErr_Clear();
itemconv = PyObject_GetItem(d,
#ifdef IS_PY3K
(PyObject *) &PyUnicode_Type);
#else
(PyObject *) &PyString_Type);
#endif
}
if (!itemconv) {
PyErr_SetString(PyExc_TypeError,
"no default type converter defined");
goto error;
return _mysql_string_literal((_mysql_ConnectionObject*)self, item);
}
Py_INCREF(d);
quoted = PyObject_CallFunction(itemconv, "OO", item, d);
Py_DECREF(d);
Py_DECREF(itemconv);
error:

return quoted;
}

Expand All @@ -1013,14 +1006,14 @@ _mysql_escape(
"argument 2 must be a mapping");
return NULL;
}
return _escape_item(o, d);
return _escape_item(self, o, d);
} else {
if (!self) {
PyErr_SetString(PyExc_TypeError,
"argument 2 must be a mapping");
return NULL;
}
return _escape_item(o,
return _escape_item(self, o,
((_mysql_ConnectionObject *) self)->converter);
}
}
Expand Down Expand Up @@ -2264,7 +2257,7 @@ static PyMethodDef _mysql_ConnectionObject_methods[] = {
{
"string_literal",
(PyCFunction)_mysql_string_literal,
METH_VARARGS,
METH_O,
_mysql_string_literal__doc__},
{
"thread_id",
Expand Down Expand Up @@ -2587,7 +2580,7 @@ _mysql_methods[] = {
{
"string_literal",
(PyCFunction)_mysql_string_literal,
METH_VARARGS,
METH_O,
_mysql_string_literal__doc__
},
{
Expand Down
48 changes: 10 additions & 38 deletions MySQLdb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@
)


if not PY2:
if sys.version_info[:2] < (3, 6):
# See http://bugs.python.org/issue24870
_surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)]

def _fast_surrogateescape(s):
return s.decode('latin1').translate(_surrogateescape_table)
else:
def _fast_surrogateescape(s):
return s.decode('ascii', 'surrogateescape')


re_numeric_part = re.compile(r"^(\d+)")

def numeric_part(s):
Expand Down Expand Up @@ -183,21 +171,8 @@ class object, used to create cursors (keyword only)
self.encoding = 'ascii' # overridden in set_character_set()
db = proxy(self)

# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
def string_literal(obj, dummy=None):
return db.string_literal(obj)

if PY2:
# unicode_literal is called for only unicode object.
def unicode_literal(u, dummy=None):
return db.string_literal(u.encode(db.encoding))
else:
# unicode_literal() is called for arbitrary object.
def unicode_literal(u, dummy=None):
return db.string_literal(str(u).encode(db.encoding))

def bytes_literal(obj, dummy=None):
return b'_binary' + db.string_literal(obj)
def unicode_literal(u, dummy=None):
return db.string_literal(u.encode(db.encoding))

def string_decoder(s):
return s.decode(db.encoding)
Expand All @@ -214,7 +189,6 @@ def string_decoder(s):
FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.BLOB):
self.converter[t].append((None, string_decoder))

self.encoders[bytes] = string_literal
self.encoders[unicode] = unicode_literal
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
if self._transactional:
Expand Down Expand Up @@ -250,7 +224,7 @@ def _bytes_literal(self, bs):
return x

def _tuple_literal(self, t):
return "(%s)" % (','.join(map(self.literal, t)))
return b"(%s)" % (b','.join(map(self.literal, t)))

def literal(self, o):
"""If o is a single object, returns an SQL literal as a string.
Expand All @@ -260,29 +234,27 @@ def literal(self, o):
Non-standard. For internal use; do not use this in your
applications.
"""
if isinstance(o, bytearray):
if isinstance(o, unicode):
s = self.string_literal(o.encode(self.encoding))
elif isinstance(o, bytearray):
s = self._bytes_literal(o)
elif not PY2 and isinstance(o, bytes):
s = self._bytes_literal(o)
elif isinstance(o, (tuple, list)):
s = self._tuple_literal(o)
else:
s = self.escape(o, self.encoders)
# Python 3(~3.4) doesn't support % operation for bytes object.
# We should decode it before using %.
# Decoding with ascii and surrogateescape allows convert arbitrary
# bytes to unicode and back again.
# See http://python.org/dev/peps/pep-0383/
if not PY2 and isinstance(s, (bytes, bytearray)):
return _fast_surrogateescape(s)
if isinstance(s, unicode):
s = s.encode(self.encoding)
assert isinstance(s, bytes)
return s

def begin(self):
"""Explicitly begin a connection.
This method is not used when autocommit=False (default).
"""
self.query("BEGIN")
self.query(b"BEGIN")

if not hasattr(_mysql.connection, 'warning_count'):

Expand Down
4 changes: 2 additions & 2 deletions MySQLdb/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def Str2Set(s):

def Set2Str(s, d):
# Only support ascii string. Not tested.
return string_literal(','.join(s), d)
return string_literal(','.join(s))

def Thing2Str(s, d):
"""Convert something into a string via str()."""
Expand All @@ -80,7 +80,7 @@ def Thing2Literal(o, d):
MySQL-3.23 or newer, string_literal() is a method of the
_mysql.MYSQL object, and this function will be overridden with
that method when the connection is created."""
return string_literal(o, d)
return string_literal(o)

def Decimal2Literal(o, d):
return format(o, 'f')
Expand Down
90 changes: 28 additions & 62 deletions MySQLdb/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
NotSupportedError, ProgrammingError)


PY2 = sys.version_info[0] == 2
if PY2:
text_type = unicode
else:
text_type = str


#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
Expand Down Expand Up @@ -95,31 +88,28 @@ def __exit__(self, *exc_info):
del exc_info
self.close()

def _ensure_bytes(self, x, encoding=None):
if isinstance(x, text_type):
x = x.encode(encoding)
elif isinstance(x, (tuple, list)):
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
return x

def _escape_args(self, args, conn):
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
encoding = conn.encoding
literal = conn.literal

def ensure_bytes(x):
if isinstance(x, unicode):
return x.encode(encoding)
elif isinstance(x, tuple):
return tuple(map(ensure_bytes, x))
elif isinstance(x, list):
return list(map(ensure_bytes, x))
return x

if isinstance(args, (tuple, list)):
if PY2:
args = tuple(map(ensure_bytes, args))
return tuple(conn.literal(arg) for arg in args)
return tuple(literal(ensure_bytes(arg)) for arg in args)
elif isinstance(args, dict):
if PY2:
args = dict((ensure_bytes(key), ensure_bytes(val)) for
(key, val) in args.items())
return dict((key, conn.literal(val)) for (key, val) in args.items())
return {ensure_bytes(key): literal(ensure_bytes(val))
for (key, val) in args.items()}
else:
# If it's not a dictionary let's try escaping it anyways.
# Worst case it will throw a Value error
if PY2:
args = ensure_bytes(args)
return conn.literal(args)
return literal(ensure_bytes(args))

def _check_executed(self):
if not self._executed:
Expand Down Expand Up @@ -186,31 +176,20 @@ def execute(self, query, args=None):
pass
db = self._get_db()

# NOTE:
# Python 2: query should be bytes when executing %.
# All unicode in args should be encoded to bytes on Python 2.
# Python 3: query should be str (unicode) when executing %.
# All bytes in args should be decoded with ascii and surrogateescape on Python 3.
# db.literal(obj) always returns str.

if PY2 and isinstance(query, unicode):
if isinstance(query, unicode):
query = query.encode(db.encoding)

if args is not None:
if isinstance(args, dict):
args = dict((key, db.literal(item)) for key, item in args.items())
else:
args = tuple(map(db.literal, args))
if not PY2 and isinstance(query, (bytes, bytearray)):
query = query.decode(db.encoding)
try:
query = query % args
except TypeError as m:
raise ProgrammingError(str(m))

if isinstance(query, unicode):
query = query.encode(db.encoding, 'surrogateescape')

assert isinstance(query, (bytes, bytearray))
res = self._query(query)
return res

Expand Down Expand Up @@ -247,29 +226,19 @@ def executemany(self, query, args):
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
conn = self._get_db()
escape = self._escape_args
if isinstance(prefix, text_type):
if isinstance(prefix, unicode):
prefix = prefix.encode(encoding)
if PY2 and isinstance(values, text_type):
if isinstance(values, unicode):
values = values.encode(encoding)
if isinstance(postfix, text_type):
if isinstance(postfix, unicode):
postfix = postfix.encode(encoding)
sql = bytearray(prefix)
args = iter(args)
v = values % escape(next(args), conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
v = v.encode(encoding, 'surrogateescape')
sql += v
rows = 0
for arg in args:
v = values % escape(arg, conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
v = v.encode(encoding, 'surrogateescape')
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
rows += self.execute(sql + postfix)
sql = bytearray(prefix)
Expand Down Expand Up @@ -308,22 +277,19 @@ def callproc(self, procname, args=()):
to advance through all result sets; otherwise you may get
disconnected.
"""

db = self._get_db()
if isinstance(procname, unicode):
procname = procname.encode(db.encoding)
if args:
fmt = '@_{0}_%d=%s'.format(procname)
q = 'SET %s' % ','.join(fmt % (index, db.literal(arg))
for index, arg in enumerate(args))
if isinstance(q, unicode):
q = q.encode(db.encoding, 'surrogateescape')
fmt = b'@_' + procname + b'_%d=%s'
q = b'SET %s' % b','.join(fmt % (index, db.literal(arg))
for index, arg in enumerate(args))
self._query(q)
self.nextset()

q = "CALL %s(%s)" % (procname,
','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))]))
if isinstance(q, unicode):
q = q.encode(db.encoding, 'surrogateescape')
q = b"CALL %s(%s)" % (procname,
b','.join([b'@_%s_%d' % (procname, i)
for i in range(len(args))]))
self._query(q)
return args

Expand Down
4 changes: 2 additions & 2 deletions MySQLdb/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ def Date_or_None(s):

def DateTime2literal(d, c):
"""Format a DateTime object as an ISO timestamp."""
return string_literal(format_TIMESTAMP(d), c)
return string_literal(format_TIMESTAMP(d))

def DateTimeDelta2literal(d, c):
"""Format a DateTimeDelta object as a time."""
return string_literal(format_TIMEDELTA(d),c)
return string_literal(format_TIMEDELTA(d))

def mysql_timestamp_converter(s):
"""Convert a MySQL TIMESTAMP to a Timestamp object."""
Expand Down

0 comments on commit 5e8eeac

Please sign in to comment.