1
- from typing import Any , Mapping , Optional , Sequence , TypeVar , Union , overload
1
+ from typing import (
2
+ Any ,
3
+ Dict ,
4
+ Mapping ,
5
+ Optional ,
6
+ Sequence ,
7
+ Type ,
8
+ TypeVar ,
9
+ Union ,
10
+ cast ,
11
+ overload ,
12
+ )
2
13
3
14
from sqlalchemy import util
15
+ from sqlalchemy .engine .interfaces import _CoreAnyExecuteParams
16
+ from sqlalchemy .engine .result import Result , ScalarResult , TupleResult
4
17
from sqlalchemy .ext .asyncio import AsyncSession as _AsyncSession
5
- from sqlalchemy .ext .asyncio import engine
6
- from sqlalchemy .ext .asyncio .engine import AsyncConnection , AsyncEngine
18
+ from sqlalchemy .ext .asyncio .result import _ensure_sync_result
19
+ from sqlalchemy .ext .asyncio .session import _EXECUTE_OPTIONS
20
+ from sqlalchemy .orm ._typing import OrmExecuteOptionsParameter
21
+ from sqlalchemy .sql .base import Executable as _Executable
7
22
from sqlalchemy .util .concurrency import greenlet_spawn
23
+ from typing_extensions import deprecated
8
24
9
- from ...engine .result import Result , ScalarResult
10
25
from ...orm .session import Session
11
26
from ...sql .base import Executable
12
27
from ...sql .expression import Select , SelectOfScalar
13
28
14
- _TSelectParam = TypeVar ("_TSelectParam" )
29
+ _TSelectParam = TypeVar ("_TSelectParam" , bound = Any )
15
30
16
31
17
32
class AsyncSession (_AsyncSession ):
33
+ sync_session_class : Type [Session ] = Session
18
34
sync_session : Session
19
35
20
- def __init__ (
21
- self ,
22
- bind : Optional [Union [AsyncConnection , AsyncEngine ]] = None ,
23
- binds : Optional [Mapping [object , Union [AsyncConnection , AsyncEngine ]]] = None ,
24
- ** kw : Any ,
25
- ):
26
- # All the same code of the original AsyncSession
27
- kw ["future" ] = True
28
- if bind :
29
- self .bind = bind
30
- bind = engine ._get_sync_engine_or_connection (bind ) # type: ignore
31
-
32
- if binds :
33
- self .binds = binds
34
- binds = {
35
- key : engine ._get_sync_engine_or_connection (b ) # type: ignore
36
- for key , b in binds .items ()
37
- }
38
-
39
- self .sync_session = self ._proxied = self ._assign_proxied ( # type: ignore
40
- Session (bind = bind , binds = binds , ** kw ) # type: ignore
41
- )
42
-
43
36
@overload
44
37
async def exec (
45
38
self ,
46
39
statement : Select [_TSelectParam ],
47
40
* ,
48
41
params : Optional [Union [Mapping [str , Any ], Sequence [Mapping [str , Any ]]]] = None ,
49
42
execution_options : Mapping [str , Any ] = util .EMPTY_DICT ,
50
- bind_arguments : Optional [Mapping [str , Any ]] = None ,
43
+ bind_arguments : Optional [Dict [str , Any ]] = None ,
51
44
_parent_execute_state : Optional [Any ] = None ,
52
45
_add_event : Optional [Any ] = None ,
53
- ** kw : Any ,
54
- ) -> Result [_TSelectParam ]:
46
+ ) -> TupleResult [_TSelectParam ]:
55
47
...
56
48
57
49
@overload
@@ -61,10 +53,9 @@ async def exec(
61
53
* ,
62
54
params : Optional [Union [Mapping [str , Any ], Sequence [Mapping [str , Any ]]]] = None ,
63
55
execution_options : Mapping [str , Any ] = util .EMPTY_DICT ,
64
- bind_arguments : Optional [Mapping [str , Any ]] = None ,
56
+ bind_arguments : Optional [Dict [str , Any ]] = None ,
65
57
_parent_execute_state : Optional [Any ] = None ,
66
58
_add_event : Optional [Any ] = None ,
67
- ** kw : Any ,
68
59
) -> ScalarResult [_TSelectParam ]:
69
60
...
70
61
@@ -75,20 +66,87 @@ async def exec(
75
66
SelectOfScalar [_TSelectParam ],
76
67
Executable [_TSelectParam ],
77
68
],
69
+ * ,
78
70
params : Optional [Union [Mapping [str , Any ], Sequence [Mapping [str , Any ]]]] = None ,
79
- execution_options : Mapping [Any , Any ] = util .EMPTY_DICT ,
80
- bind_arguments : Optional [Mapping [str , Any ]] = None ,
81
- ** kw : Any ,
82
- ) -> Union [Result [_TSelectParam ], ScalarResult [_TSelectParam ]]:
83
- # TODO: the documentation says execution_options accepts a dict, but only
84
- # util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
85
- execution_options = execution_options .union ({"prebuffer_rows" : True }) # type: ignore
86
-
87
- return await greenlet_spawn (
71
+ execution_options : Mapping [str , Any ] = util .EMPTY_DICT ,
72
+ bind_arguments : Optional [Dict [str , Any ]] = None ,
73
+ _parent_execute_state : Optional [Any ] = None ,
74
+ _add_event : Optional [Any ] = None ,
75
+ ) -> Union [TupleResult [_TSelectParam ], ScalarResult [_TSelectParam ]]:
76
+ if execution_options :
77
+ execution_options = util .immutabledict (execution_options ).union (
78
+ _EXECUTE_OPTIONS
79
+ )
80
+ else :
81
+ execution_options = _EXECUTE_OPTIONS
82
+
83
+ result = await greenlet_spawn (
88
84
self .sync_session .exec ,
89
85
statement ,
90
86
params = params ,
91
87
execution_options = execution_options ,
92
88
bind_arguments = bind_arguments ,
93
- ** kw ,
89
+ _parent_execute_state = _parent_execute_state ,
90
+ _add_event = _add_event ,
91
+ )
92
+ result_value = await _ensure_sync_result (
93
+ cast (Result [_TSelectParam ], result ), self .exec
94
+ )
95
+ return result_value # type: ignore
96
+
97
+ @deprecated (
98
+ """
99
+ 🚨 You probably want to use `session.exec()` instead of `session.execute()`.
100
+
101
+ This is the original SQLAlchemy `session.execute()` method that returns objects
102
+ of type `Row`, and that you have to call `scalars()` to get the model objects.
103
+
104
+ For example:
105
+
106
+ ```Python
107
+ heroes = await session.execute(select(Hero)).scalars().all()
108
+ ```
109
+
110
+ instead you could use `exec()`:
111
+
112
+ ```Python
113
+ heroes = await session.exec(select(Hero)).all()
114
+ ```
115
+ """
116
+ )
117
+ async def execute ( # type: ignore
118
+ self ,
119
+ statement : _Executable ,
120
+ params : Optional [_CoreAnyExecuteParams ] = None ,
121
+ * ,
122
+ execution_options : OrmExecuteOptionsParameter = util .EMPTY_DICT ,
123
+ bind_arguments : Optional [Dict [str , Any ]] = None ,
124
+ _parent_execute_state : Optional [Any ] = None ,
125
+ _add_event : Optional [Any ] = None ,
126
+ ) -> Result [Any ]:
127
+ """
128
+ 🚨 You probably want to use `session.exec()` instead of `session.execute()`.
129
+
130
+ This is the original SQLAlchemy `session.execute()` method that returns objects
131
+ of type `Row`, and that you have to call `scalars()` to get the model objects.
132
+
133
+ For example:
134
+
135
+ ```Python
136
+ heroes = await session.execute(select(Hero)).scalars().all()
137
+ ```
138
+
139
+ instead you could use `exec()`:
140
+
141
+ ```Python
142
+ heroes = await session.exec(select(Hero)).all()
143
+ ```
144
+ """
145
+ return await super ().execute (
146
+ statement ,
147
+ params = params ,
148
+ execution_options = execution_options ,
149
+ bind_arguments = bind_arguments ,
150
+ _parent_execute_state = _parent_execute_state ,
151
+ _add_event = _add_event ,
94
152
)
0 commit comments