Skip to content

Commit 80a5c52

Browse files
committed
✨ Update AsyncSession with support for SQLAlchemy 2.0 and new Session
1 parent baa5e3a commit 80a5c52

File tree

1 file changed

+101
-43
lines changed

1 file changed

+101
-43
lines changed

Diff for: sqlmodel/ext/asyncio/session.py

+101-43
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,49 @@
1-
from typing import Any, Mapping, Optional, Sequence, TypeVar, Union, overload
1+
from typing import (
2+
Any,
3+
Dict,
4+
Mapping,
5+
Optional,
6+
Sequence,
7+
Type,
8+
TypeVar,
9+
Union,
10+
cast,
11+
overload,
12+
)
213

314
from sqlalchemy import util
15+
from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
16+
from sqlalchemy.engine.result import Result, ScalarResult, TupleResult
417
from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession
5-
from sqlalchemy.ext.asyncio import engine
6-
from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine
18+
from sqlalchemy.ext.asyncio.result import _ensure_sync_result
19+
from sqlalchemy.ext.asyncio.session import _EXECUTE_OPTIONS
20+
from sqlalchemy.orm._typing import OrmExecuteOptionsParameter
21+
from sqlalchemy.sql.base import Executable as _Executable
722
from sqlalchemy.util.concurrency import greenlet_spawn
23+
from typing_extensions import deprecated
824

9-
from ...engine.result import Result, ScalarResult
1025
from ...orm.session import Session
1126
from ...sql.base import Executable
1227
from ...sql.expression import Select, SelectOfScalar
1328

14-
_TSelectParam = TypeVar("_TSelectParam")
29+
_TSelectParam = TypeVar("_TSelectParam", bound=Any)
1530

1631

1732
class AsyncSession(_AsyncSession):
33+
sync_session_class: Type[Session] = Session
1834
sync_session: Session
1935

20-
def __init__(
21-
self,
22-
bind: Optional[Union[AsyncConnection, AsyncEngine]] = None,
23-
binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None,
24-
**kw: Any,
25-
):
26-
# All the same code of the original AsyncSession
27-
kw["future"] = True
28-
if bind:
29-
self.bind = bind
30-
bind = engine._get_sync_engine_or_connection(bind) # type: ignore
31-
32-
if binds:
33-
self.binds = binds
34-
binds = {
35-
key: engine._get_sync_engine_or_connection(b) # type: ignore
36-
for key, b in binds.items()
37-
}
38-
39-
self.sync_session = self._proxied = self._assign_proxied( # type: ignore
40-
Session(bind=bind, binds=binds, **kw) # type: ignore
41-
)
42-
4336
@overload
4437
async def exec(
4538
self,
4639
statement: Select[_TSelectParam],
4740
*,
4841
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
4942
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
50-
bind_arguments: Optional[Mapping[str, Any]] = None,
43+
bind_arguments: Optional[Dict[str, Any]] = None,
5144
_parent_execute_state: Optional[Any] = None,
5245
_add_event: Optional[Any] = None,
53-
**kw: Any,
54-
) -> Result[_TSelectParam]:
46+
) -> TupleResult[_TSelectParam]:
5547
...
5648

5749
@overload
@@ -61,10 +53,9 @@ async def exec(
6153
*,
6254
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
6355
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
64-
bind_arguments: Optional[Mapping[str, Any]] = None,
56+
bind_arguments: Optional[Dict[str, Any]] = None,
6557
_parent_execute_state: Optional[Any] = None,
6658
_add_event: Optional[Any] = None,
67-
**kw: Any,
6859
) -> ScalarResult[_TSelectParam]:
6960
...
7061

@@ -75,20 +66,87 @@ async def exec(
7566
SelectOfScalar[_TSelectParam],
7667
Executable[_TSelectParam],
7768
],
69+
*,
7870
params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None,
79-
execution_options: Mapping[Any, Any] = util.EMPTY_DICT,
80-
bind_arguments: Optional[Mapping[str, Any]] = None,
81-
**kw: Any,
82-
) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]:
83-
# TODO: the documentation says execution_options accepts a dict, but only
84-
# util.immutabledict has the union() method. Is this a bug in SQLAlchemy?
85-
execution_options = execution_options.union({"prebuffer_rows": True}) # type: ignore
86-
87-
return await greenlet_spawn(
71+
execution_options: Mapping[str, Any] = util.EMPTY_DICT,
72+
bind_arguments: Optional[Dict[str, Any]] = None,
73+
_parent_execute_state: Optional[Any] = None,
74+
_add_event: Optional[Any] = None,
75+
) -> Union[TupleResult[_TSelectParam], ScalarResult[_TSelectParam]]:
76+
if execution_options:
77+
execution_options = util.immutabledict(execution_options).union(
78+
_EXECUTE_OPTIONS
79+
)
80+
else:
81+
execution_options = _EXECUTE_OPTIONS
82+
83+
result = await greenlet_spawn(
8884
self.sync_session.exec,
8985
statement,
9086
params=params,
9187
execution_options=execution_options,
9288
bind_arguments=bind_arguments,
93-
**kw,
89+
_parent_execute_state=_parent_execute_state,
90+
_add_event=_add_event,
91+
)
92+
result_value = await _ensure_sync_result(
93+
cast(Result[_TSelectParam], result), self.exec
94+
)
95+
return result_value # type: ignore
96+
97+
@deprecated(
98+
"""
99+
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
100+
101+
This is the original SQLAlchemy `session.execute()` method that returns objects
102+
of type `Row`, and that you have to call `scalars()` to get the model objects.
103+
104+
For example:
105+
106+
```Python
107+
heroes = await session.execute(select(Hero)).scalars().all()
108+
```
109+
110+
instead you could use `exec()`:
111+
112+
```Python
113+
heroes = await session.exec(select(Hero)).all()
114+
```
115+
"""
116+
)
117+
async def execute( # type: ignore
118+
self,
119+
statement: _Executable,
120+
params: Optional[_CoreAnyExecuteParams] = None,
121+
*,
122+
execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
123+
bind_arguments: Optional[Dict[str, Any]] = None,
124+
_parent_execute_state: Optional[Any] = None,
125+
_add_event: Optional[Any] = None,
126+
) -> Result[Any]:
127+
"""
128+
🚨 You probably want to use `session.exec()` instead of `session.execute()`.
129+
130+
This is the original SQLAlchemy `session.execute()` method that returns objects
131+
of type `Row`, and that you have to call `scalars()` to get the model objects.
132+
133+
For example:
134+
135+
```Python
136+
heroes = await session.execute(select(Hero)).scalars().all()
137+
```
138+
139+
instead you could use `exec()`:
140+
141+
```Python
142+
heroes = await session.exec(select(Hero)).all()
143+
```
144+
"""
145+
return await super().execute(
146+
statement,
147+
params=params,
148+
execution_options=execution_options,
149+
bind_arguments=bind_arguments,
150+
_parent_execute_state=_parent_execute_state,
151+
_add_event=_add_event,
94152
)

0 commit comments

Comments
 (0)