Skip to content

Commit 8f1fd73

Browse files
authored
multi statements can be disabled (#500)
1 parent 24aaa72 commit 8f1fd73

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

Diff for: MySQLdb/connections.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ class object, used to create cursors (keyword only)
110110
:param int client_flag:
111111
flags to use or 0 (see MySQL docs or constants/CLIENTS.py)
112112
113+
:param bool multi_statements:
114+
If True, enable multi statements for clients >= 4.1.
115+
Defaults to True.
116+
113117
:param str ssl_mode:
114118
specify the security settings for connection to the server;
115119
see the MySQL documentation for more details
@@ -169,11 +173,16 @@ class object, used to create cursors (keyword only)
169173
self._binary_prefix = kwargs2.pop("binary_prefix", False)
170174

171175
client_flag = kwargs.get("client_flag", 0)
176+
172177
client_version = tuple(
173178
[numeric_part(n) for n in _mysql.get_client_info().split(".")[:2]]
174179
)
175-
if client_version >= (4, 1):
176-
client_flag |= CLIENT.MULTI_STATEMENTS
180+
181+
multi_statements = kwargs2.pop("multi_statements", True)
182+
if multi_statements:
183+
if client_version >= (4, 1):
184+
client_flag |= CLIENT.MULTI_STATEMENTS
185+
177186
if client_version >= (5, 0):
178187
client_flag |= CLIENT.MULTI_RESULTS
179188

Diff for: tests/test_connection.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import pytest
2+
3+
from MySQLdb._exceptions import ProgrammingError
4+
5+
from configdb import connection_factory
6+
7+
8+
def test_multi_statements_default_true():
9+
conn = connection_factory()
10+
cursor = conn.cursor()
11+
12+
cursor.execute("select 17; select 2")
13+
rows = cursor.fetchall()
14+
assert rows == ((17,),)
15+
16+
17+
def test_multi_statements_false():
18+
conn = connection_factory(multi_statements=False)
19+
cursor = conn.cursor()
20+
21+
with pytest.raises(ProgrammingError):
22+
cursor.execute("select 17; select 2")
23+
24+
cursor.execute("select 17")
25+
rows = cursor.fetchall()
26+
assert rows == ((17,),)

0 commit comments

Comments
 (0)