Skip to content

Commit 5e8eeac

Browse files
authored
Stop using surrogate escape (#302)
It was a workaround for the lack of `bytes %` formatting. Since we dropped Python 3.4 support, we can just use `bytes %` now.
1 parent 628bb1b commit 5e8eeac

File tree

5 files changed

+56
-125
lines changed

5 files changed

+56
-125
lines changed

Diff for: MySQLdb/_mysql.c

+14-21
Original file line number | Diff line number | Diff line change
@@ -915,14 +915,15 @@ _mysql.string_literal(obj) cannot handle character sets.";
915915
static PyObject *
916916
_mysql_string_literal(
917917
_mysql_ConnectionObject *self,
918-
PyObject *args)
918+
PyObject *o)
919919
{
920-
PyObject *str, *s, *o, *d;
920+
PyObject *str, *s;
921921
char *in, *out;
922922
int len, size;
923+
923924
if (self && PyModule_Check((PyObject*)self))
924925
self = NULL;
925-
if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL;
926+
926927
if (PyBytes_Check(o)) {
927928
s = o;
928929
Py_INCREF(s);
@@ -965,33 +966,25 @@ static PyObject *_mysql_NULL;
965966

966967
static PyObject *
967968
_escape_item(
969+
PyObject *self,
968970
PyObject *item,
969971
PyObject *d)
970972
{
971973
PyObject *quoted=NULL, *itemtype, *itemconv;
972-
if (!(itemtype = PyObject_Type(item)))
973-
goto error;
974+
if (!(itemtype = PyObject_Type(item))) {
975+
return NULL;
976+
}
974977
itemconv = PyObject_GetItem(d, itemtype);
975978
Py_DECREF(itemtype);
976979
if (!itemconv) {
977980
PyErr_Clear();
978-
itemconv = PyObject_GetItem(d,
979-
#ifdef IS_PY3K
980-
(PyObject *) &PyUnicode_Type);
981-
#else
982-
(PyObject *) &PyString_Type);
983-
#endif
984-
}
985-
if (!itemconv) {
986-
PyErr_SetString(PyExc_TypeError,
987-
"no default type converter defined");
988-
goto error;
981+
return _mysql_string_literal((_mysql_ConnectionObject*)self, item);
989982
}
990983
Py_INCREF(d);
991984
quoted = PyObject_CallFunction(itemconv, "OO", item, d);
992985
Py_DECREF(d);
993986
Py_DECREF(itemconv);
994-
error:
987+
995988
return quoted;
996989
}
997990

@@ -1013,14 +1006,14 @@ _mysql_escape(
10131006
"argument 2 must be a mapping");
10141007
return NULL;
10151008
}
1016-
return _escape_item(o, d);
1009+
return _escape_item(self, o, d);
10171010
} else {
10181011
if (!self) {
10191012
PyErr_SetString(PyExc_TypeError,
10201013
"argument 2 must be a mapping");
10211014
return NULL;
10221015
}
1023-
return _escape_item(o,
1016+
return _escape_item(self, o,
10241017
((_mysql_ConnectionObject *) self)->converter);
10251018
}
10261019
}
@@ -2264,7 +2257,7 @@ static PyMethodDef _mysql_ConnectionObject_methods[] = {
22642257
{
22652258
"string_literal",
22662259
(PyCFunction)_mysql_string_literal,
2267-
METH_VARARGS,
2260+
METH_O,
22682261
_mysql_string_literal__doc__},
22692262
{
22702263
"thread_id",
@@ -2587,7 +2580,7 @@ _mysql_methods[] = {
25872580
{
25882581
"string_literal",
25892582
(PyCFunction)_mysql_string_literal,
2590-
METH_VARARGS,
2583+
METH_O,
25912584
_mysql_string_literal__doc__
25922585
},
25932586
{

Diff for: MySQLdb/connections.py

+10-38
Original file line number | Diff line number | Diff line change
@@ -16,18 +16,6 @@
1616
)
1717

1818

19-
if not PY2:
20-
if sys.version_info[:2] < (3, 6):
21-
# See http://bugs.python.org/issue24870
22-
_surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)]
23-
24-
def _fast_surrogateescape(s):
25-
return s.decode('latin1').translate(_surrogateescape_table)
26-
else:
27-
def _fast_surrogateescape(s):
28-
return s.decode('ascii', 'surrogateescape')
29-
30-
3119
re_numeric_part = re.compile(r"^(\d+)")
3220

3321
def numeric_part(s):
@@ -183,21 +171,8 @@ class object, used to create cursors (keyword only)
183171
self.encoding = 'ascii' # overridden in set_character_set()
184172
db = proxy(self)
185173

186-
# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
187-
def string_literal(obj, dummy=None):
188-
return db.string_literal(obj)
189-
190-
if PY2:
191-
# unicode_literal is called for only unicode object.
192-
def unicode_literal(u, dummy=None):
193-
return db.string_literal(u.encode(db.encoding))
194-
else:
195-
# unicode_literal() is called for arbitrary object.
196-
def unicode_literal(u, dummy=None):
197-
return db.string_literal(str(u).encode(db.encoding))
198-
199-
def bytes_literal(obj, dummy=None):
200-
return b'_binary' + db.string_literal(obj)
174+
def unicode_literal(u, dummy=None):
175+
return db.string_literal(u.encode(db.encoding))
201176

202177
def string_decoder(s):
203178
return s.decode(db.encoding)
@@ -214,7 +189,6 @@ def string_decoder(s):
214189
FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.BLOB):
215190
self.converter[t].append((None, string_decoder))
216191

217-
self.encoders[bytes] = string_literal
218192
self.encoders[unicode] = unicode_literal
219193
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
220194
if self._transactional:
@@ -250,7 +224,7 @@ def _bytes_literal(self, bs):
250224
return x
251225

252226
def _tuple_literal(self, t):
253-
return "(%s)" % (','.join(map(self.literal, t)))
227+
return b"(%s)" % (b','.join(map(self.literal, t)))
254228

255229
def literal(self, o):
256230
"""If o is a single object, returns an SQL literal as a string.
@@ -260,29 +234,27 @@ def literal(self, o):
260234
Non-standard. For internal use; do not use this in your
261235
applications.
262236
"""
263-
if isinstance(o, bytearray):
237+
if isinstance(o, unicode):
238+
s = self.string_literal(o.encode(self.encoding))
239+
elif isinstance(o, bytearray):
264240
s = self._bytes_literal(o)
265241
elif not PY2 and isinstance(o, bytes):
266242
s = self._bytes_literal(o)
267243
elif isinstance(o, (tuple, list)):
268244
s = self._tuple_literal(o)
269245
else:
270246
s = self.escape(o, self.encoders)
271-
# Python 3(~3.4) doesn't support % operation for bytes object.
272-
# We should decode it before using %.
273-
# Decoding with ascii and surrogateescape allows convert arbitrary
274-
# bytes to unicode and back again.
275-
# See http://python.org/dev/peps/pep-0383/
276-
if not PY2 and isinstance(s, (bytes, bytearray)):
277-
return _fast_surrogateescape(s)
247+
if isinstance(s, unicode):
248+
s = s.encode(self.encoding)
249+
assert isinstance(s, bytes)
278250
return s
279251

280252
def begin(self):
281253
"""Explicitly begin a connection.
282254
283255
This method is not used when autocommit=False (default).
284256
"""
285-
self.query("BEGIN")
257+
self.query(b"BEGIN")
286258

287259
if not hasattr(_mysql.connection, 'warning_count'):
288260

Diff for: MySQLdb/converters.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ def Str2Set(s):
5353

5454
def Set2Str(s, d):
5555
# Only support ascii string. Not tested.
56-
return string_literal(','.join(s), d)
56+
return string_literal(','.join(s))
5757

5858
def Thing2Str(s, d):
5959
"""Convert something into a string via str()."""
@@ -80,7 +80,7 @@ def Thing2Literal(o, d):
8080
MySQL-3.23 or newer, string_literal() is a method of the
8181
_mysql.MYSQL object, and this function will be overridden with
8282
that method when the connection is created."""
83-
return string_literal(o, d)
83+
return string_literal(o)
8484

8585
def Decimal2Literal(o, d):
8686
return format(o, 'f')

Diff for: MySQLdb/cursors.py

+28-62
Original file line number | Diff line number | Diff line change
@@ -15,13 +15,6 @@
1515
NotSupportedError, ProgrammingError)
1616

1717

18-
PY2 = sys.version_info[0] == 2
19-
if PY2:
20-
text_type = unicode
21-
else:
22-
text_type = str
23-
24-
2518
#: Regular expression for :meth:`Cursor.executemany`.
2619
#: executemany only supports simple bulk insert.
2720
#: You can use it to load large dataset.
@@ -95,31 +88,28 @@ def __exit__(self, *exc_info):
9588
del exc_info
9689
self.close()
9790

98-
def _ensure_bytes(self, x, encoding=None):
99-
if isinstance(x, text_type):
100-
x = x.encode(encoding)
101-
elif isinstance(x, (tuple, list)):
102-
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
103-
return x
104-
10591
def _escape_args(self, args, conn):
106-
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
92+
encoding = conn.encoding
93+
literal = conn.literal
94+
95+
def ensure_bytes(x):
96+
if isinstance(x, unicode):
97+
return x.encode(encoding)
98+
elif isinstance(x, tuple):
99+
return tuple(map(ensure_bytes, x))
100+
elif isinstance(x, list):
101+
return list(map(ensure_bytes, x))
102+
return x
107103

108104
if isinstance(args, (tuple, list)):
109-
if PY2:
110-
args = tuple(map(ensure_bytes, args))
111-
return tuple(conn.literal(arg) for arg in args)
105+
return tuple(literal(ensure_bytes(arg)) for arg in args)
112106
elif isinstance(args, dict):
113-
if PY2:
114-
args = dict((ensure_bytes(key), ensure_bytes(val)) for
115-
(key, val) in args.items())
116-
return dict((key, conn.literal(val)) for (key, val) in args.items())
107+
return {ensure_bytes(key): literal(ensure_bytes(val))
108+
for (key, val) in args.items()}
117109
else:
118110
# If it's not a dictionary let's try escaping it anyways.
119111
# Worst case it will throw a Value error
120-
if PY2:
121-
args = ensure_bytes(args)
122-
return conn.literal(args)
112+
return literal(ensure_bytes(args))
123113

124114
def _check_executed(self):
125115
if not self._executed:
@@ -186,31 +176,20 @@ def execute(self, query, args=None):
186176
pass
187177
db = self._get_db()
188178

189-
# NOTE:
190-
# Python 2: query should be bytes when executing %.
191-
# All unicode in args should be encoded to bytes on Python 2.
192-
# Python 3: query should be str (unicode) when executing %.
193-
# All bytes in args should be decoded with ascii and surrogateescape on Python 3.
194-
# db.literal(obj) always returns str.
195-
196-
if PY2 and isinstance(query, unicode):
179+
if isinstance(query, unicode):
197180
query = query.encode(db.encoding)
198181

199182
if args is not None:
200183
if isinstance(args, dict):
201184
args = dict((key, db.literal(item)) for key, item in args.items())
202185
else:
203186
args = tuple(map(db.literal, args))
204-
if not PY2 and isinstance(query, (bytes, bytearray)):
205-
query = query.decode(db.encoding)
206187
try:
207188
query = query % args
208189
except TypeError as m:
209190
raise ProgrammingError(str(m))
210191

211-
if isinstance(query, unicode):
212-
query = query.encode(db.encoding, 'surrogateescape')
213-
192+
assert isinstance(query, (bytes, bytearray))
214193
res = self._query(query)
215194
return res
216195

@@ -247,29 +226,19 @@ def executemany(self, query, args):
247226
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
248227
conn = self._get_db()
249228
escape = self._escape_args
250-
if isinstance(prefix, text_type):
229+
if isinstance(prefix, unicode):
251230
prefix = prefix.encode(encoding)
252-
if PY2 and isinstance(values, text_type):
231+
if isinstance(values, unicode):
253232
values = values.encode(encoding)
254-
if isinstance(postfix, text_type):
233+
if isinstance(postfix, unicode):
255234
postfix = postfix.encode(encoding)
256235
sql = bytearray(prefix)
257236
args = iter(args)
258237
v = values % escape(next(args), conn)
259-
if isinstance(v, text_type):
260-
if PY2:
261-
v = v.encode(encoding)
262-
else:
263-
v = v.encode(encoding, 'surrogateescape')
264238
sql += v
265239
rows = 0
266240
for arg in args:
267241
v = values % escape(arg, conn)
268-
if isinstance(v, text_type):
269-
if PY2:
270-
v = v.encode(encoding)
271-
else:
272-
v = v.encode(encoding, 'surrogateescape')
273242
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
274243
rows += self.execute(sql + postfix)
275244
sql = bytearray(prefix)
@@ -308,22 +277,19 @@ def callproc(self, procname, args=()):
308277
to advance through all result sets; otherwise you may get
309278
disconnected.
310279
"""
311-
312280
db = self._get_db()
281+
if isinstance(procname, unicode):
282+
procname = procname.encode(db.encoding)
313283
if args:
314-
fmt = '@_{0}_%d=%s'.format(procname)
315-
q = 'SET %s' % ','.join(fmt % (index, db.literal(arg))
316-
for index, arg in enumerate(args))
317-
if isinstance(q, unicode):
318-
q = q.encode(db.encoding, 'surrogateescape')
284+
fmt = b'@_' + procname + b'_%d=%s'
285+
q = b'SET %s' % b','.join(fmt % (index, db.literal(arg))
286+
for index, arg in enumerate(args))
319287
self._query(q)
320288
self.nextset()
321289

322-
q = "CALL %s(%s)" % (procname,
323-
','.join(['@_%s_%d' % (procname, i)
324-
for i in range(len(args))]))
325-
if isinstance(q, unicode):
326-
q = q.encode(db.encoding, 'surrogateescape')
290+
q = b"CALL %s(%s)" % (procname,
291+
b','.join([b'@_%s_%d' % (procname, i)
292+
for i in range(len(args))]))
327293
self._query(q)
328294
return args
329295

Diff for: MySQLdb/times.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -124,11 +124,11 @@ def Date_or_None(s):
124124

125125
def DateTime2literal(d, c):
126126
"""Format a DateTime object as an ISO timestamp."""
127-
return string_literal(format_TIMESTAMP(d), c)
127+
return string_literal(format_TIMESTAMP(d))
128128

129129
def DateTimeDelta2literal(d, c):
130130
"""Format a DateTimeDelta object as a time."""
131-
return string_literal(format_TIMEDELTA(d),c)
131+
return string_literal(format_TIMEDELTA(d))
132132

133133
def mysql_timestamp_converter(s):
134134
"""Convert a MySQL TIMESTAMP to a Timestamp object."""

0 commit comments

Comments
 (0)