Skip to content

Commit 6588f14

Browse files
committed
Let multi statements be optional
- Disabling multi statements can help protect against SQL injection attacks.
1 parent 24aaa72 commit 6588f14

File tree

2 files changed

+41
-7
lines changed

2 files changed

+41
-7
lines changed

MySQLdb/connections.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ class object, used to create cursors (keyword only)
110110
:param int client_flag:
111111
flags to use or 0 (see MySQL docs or constants/CLIENTS.py)
112112
113+
:param bool multi_statements:
114+
If True, enable multi statements for clients >= 4.1 and multi
115+
results for clients >= 5.0. Set to False to disable it, which gives
116+
some protection against injection attacks. Defaults to True.
117+
113118
:param str ssl_mode:
114119
specify the security settings for connection to the server;
115120
see the MySQL documentation for more details
@@ -169,13 +174,16 @@ class object, used to create cursors (keyword only)
169174
self._binary_prefix = kwargs2.pop("binary_prefix", False)
170175

171176
client_flag = kwargs.get("client_flag", 0)
172-
client_version = tuple(
173-
[numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]]
174-
)
175-
if client_version >= (4, 1):
176-
client_flag |= CLIENT.MULTI_STATEMENTS
177-
if client_version >= (5, 0):
178-
client_flag |= CLIENT.MULTI_RESULTS
177+
178+
multi_statements = kwargs2.pop("multi_statements", True)
179+
if multi_statements:
180+
client_version = tuple(
181+
[numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]]
182+
)
183+
if client_version >= (4, 1):
184+
client_flag |= CLIENT.MULTI_STATEMENTS
185+
if client_version >= (5, 0):
186+
client_flag |= CLIENT.MULTI_RESULTS
179187

180188
kwargs2["client_flag"] = client_flag
181189

tests/test_connection.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
3+
from MySQLdb._exceptions import ProgrammingError
4+
5+
from configdb import connection_factory
6+
7+
8+
def test_multi_statements_default_true():
9+
conn = connection_factory()
10+
cursor = conn.cursor()
11+
12+
cursor.execute("select 17; select 2")
13+
rows = cursor.fetchall()
14+
assert rows == ((17,),)
15+
16+
17+
def test_multi_statements_false():
18+
conn = connection_factory(multi_statements=False)
19+
cursor = conn.cursor()
20+
21+
with pytest.raises(ProgrammingError):
22+
cursor.execute("select 17; select 2")
23+
24+
cursor.execute("select 17")
25+
rows = cursor.fetchall()
26+
assert rows == ((17,),)

0 commit comments

Comments
 (0)