Skip to content

Commit df70eeb

Browse files
committed
Implement access token strategy db adapter
1 parent ff273b1 commit df70eeb

File tree

5 files changed

+193
-5
lines changed

5 files changed

+193
-5
lines changed
+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from datetime import datetime
2+
from typing import Generic, Optional, Type, cast
3+
4+
from fastapi_users.authentication.strategy.db import A, AccessTokenDatabase
5+
from tortoise import fields, models
6+
from tortoise.contrib.pydantic import PydanticModel
7+
8+
9+
class TortoiseBaseAccessTokenModel(models.Model):
10+
token = fields.CharField(pk=True, max_length=43)
11+
created_at = fields.DatetimeField(
12+
null=False,
13+
auto_now_add=True,
14+
)
15+
16+
17+
class TortoiseAccessTokenDatabase(AccessTokenDatabase, Generic[A]):
18+
"""
19+
Access token database adapter for Tortoise ORM.
20+
21+
:param access_token_model: Pydantic model of a DB representation of an access token.
22+
:param model: Tortoise ORM model.
23+
"""
24+
25+
def __init__(
26+
self, access_token_model: Type[A], model: Type[TortoiseBaseAccessTokenModel]
27+
):
28+
self.access_token_model = access_token_model
29+
self.model = model
30+
31+
async def get_by_token(
32+
self, token: str, max_age: Optional[datetime] = None
33+
) -> Optional[A]:
34+
query = self.model.filter(token=token)
35+
if max_age is not None:
36+
query = query.filter(created_at__gte=max_age)
37+
38+
access_token = await query.first()
39+
if access_token is not None:
40+
return await self._model_to_pydantic(access_token)
41+
return None
42+
43+
async def create(self, access_token: A) -> A:
44+
model = self.model(**access_token.dict())
45+
await model.save()
46+
await model.refresh_from_db()
47+
return await self._model_to_pydantic(model)
48+
49+
async def update(self, access_token: A) -> A:
50+
model = await self.model.get(token=access_token.token)
51+
for field, value in access_token.dict().items():
52+
setattr(model, field, value)
53+
await model.save()
54+
return await self._model_to_pydantic(model)
55+
56+
async def delete(self, access_token: A) -> None:
57+
await self.model.filter(token=access_token.token).delete()
58+
59+
async def _model_to_pydantic(self, model: TortoiseBaseAccessTokenModel) -> A:
60+
pydantic_access_token = await cast(
61+
PydanticModel, self.access_token_model
62+
).from_tortoise_orm(model)
63+
return cast(A, pydantic_access_token)

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
description-file = "README.md"
2323
requires-python = ">=3.7"
2424
requires = [
25-
"fastapi-users >= 6.1.2",
25+
"fastapi-users >= 9.1.0",
2626
"tortoise-orm >=0.17.6,<0.18.0"
2727
]
2828

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
fastapi-users >= 6.1.2
1+
fastapi-users >= 9.1.0
22
tortoise-orm >= 0.17.6,<0.18.0

tests/test_access_token.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
import uuid
2+
from datetime import datetime, timedelta, timezone
3+
from typing import AsyncGenerator
4+
5+
import pytest
6+
from fastapi_users.authentication.strategy.db.models import BaseAccessToken
7+
from pydantic import UUID4
8+
from tortoise import Tortoise, fields
9+
from tortoise.contrib.pydantic import PydanticModel
10+
from tortoise.exceptions import IntegrityError
11+
12+
from fastapi_users_db_tortoise import TortoiseBaseUserModel
13+
from fastapi_users_db_tortoise.access_token import (
14+
TortoiseAccessTokenDatabase,
15+
TortoiseBaseAccessTokenModel,
16+
)
17+
from tests.conftest import UserDB as BaseUserDB
18+
19+
20+
class UserModel(TortoiseBaseUserModel):
21+
pass
22+
23+
24+
class UserDB(BaseUserDB, PydanticModel):
25+
class Config:
26+
orm_mode = True
27+
orig_model = UserModel
28+
29+
30+
class AccessTokenModel(TortoiseBaseAccessTokenModel):
31+
user = fields.ForeignKeyField("models.UserModel", related_name="access_tokens")
32+
33+
34+
class AccessToken(BaseAccessToken, PydanticModel):
35+
class Config:
36+
orm_mode = True
37+
orig_model = AccessTokenModel
38+
39+
40+
@pytest.fixture
41+
def user_id() -> UUID4:
42+
return uuid.uuid4()
43+
44+
45+
@pytest.fixture
46+
async def tortoise_access_token_db(
47+
user_id: UUID4,
48+
) -> AsyncGenerator[TortoiseAccessTokenDatabase, None]:
49+
DATABASE_URL = "sqlite://./test-tortoise-access-token.db"
50+
51+
await Tortoise.init(
52+
db_url=DATABASE_URL,
53+
modules={"models": ["tests.test_access_token"]},
54+
)
55+
await Tortoise.generate_schemas()
56+
57+
user = UserModel(
58+
id=user_id,
59+
60+
hashed_password="guinevere",
61+
is_active=True,
62+
is_verified=True,
63+
is_superuser=False,
64+
)
65+
await user.save()
66+
67+
yield TortoiseAccessTokenDatabase(AccessToken, AccessTokenModel)
68+
69+
await AccessTokenModel.all().delete()
70+
await UserModel.all().delete()
71+
await Tortoise.close_connections()
72+
73+
74+
@pytest.mark.asyncio
75+
@pytest.mark.db
76+
async def test_queries(
77+
tortoise_access_token_db: TortoiseAccessTokenDatabase[AccessToken],
78+
user_id: UUID4,
79+
):
80+
access_token = AccessToken(token="TOKEN", user_id=user_id)
81+
82+
# Create
83+
access_token_db = await tortoise_access_token_db.create(access_token)
84+
assert access_token_db.token == "TOKEN"
85+
assert access_token_db.user_id == user_id
86+
87+
# Update
88+
access_token_db.created_at = datetime.now(timezone.utc)
89+
await tortoise_access_token_db.update(access_token_db)
90+
91+
# Get by token
92+
access_token_by_token = await tortoise_access_token_db.get_by_token(
93+
access_token_db.token
94+
)
95+
assert access_token_by_token is not None
96+
97+
# Get by token expired
98+
access_token_by_token = await tortoise_access_token_db.get_by_token(
99+
access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1)
100+
)
101+
assert access_token_by_token is None
102+
103+
# Get by token not expired
104+
access_token_by_token = await tortoise_access_token_db.get_by_token(
105+
access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1)
106+
)
107+
assert access_token_by_token is not None
108+
109+
# Get by token unknown
110+
access_token_by_token = await tortoise_access_token_db.get_by_token(
111+
"NOT_EXISTING_TOKEN"
112+
)
113+
assert access_token_by_token is None
114+
115+
# Exception when inserting existing token
116+
with pytest.raises(IntegrityError):
117+
await tortoise_access_token_db.create(access_token_db)
118+
119+
# Delete token
120+
await tortoise_access_token_db.delete(access_token_db)
121+
deleted_access_token = await tortoise_access_token_db.get_by_token(
122+
access_token_db.token
123+
)
124+
assert deleted_access_token is None

tests/test_fastapi_users_db_tortoise.py renamed to tests/test_users.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
TortoiseBaseUserModel,
1111
TortoiseUserDatabase,
1212
)
13-
from tests.conftest import UserDB as BaseUserDB, UserDBOAuth as BaseUserDBOAuth
13+
from tests.conftest import UserDB as BaseUserDB
14+
from tests.conftest import UserDBOAuth as BaseUserDBOAuth
1415

1516

1617
class User(TortoiseBaseUserModel):
@@ -39,7 +40,7 @@ async def tortoise_user_db() -> AsyncGenerator[TortoiseUserDatabase, None]:
3940

4041
await Tortoise.init(
4142
db_url=DATABASE_URL,
42-
modules={"models": ["tests.test_fastapi_users_db_tortoise"]},
43+
modules={"models": ["tests.test_users"]},
4344
)
4445
await Tortoise.generate_schemas()
4546

@@ -55,7 +56,7 @@ async def tortoise_user_db_oauth() -> AsyncGenerator[TortoiseUserDatabase, None]
5556

5657
await Tortoise.init(
5758
db_url=DATABASE_URL,
58-
modules={"models": ["tests.test_fastapi_users_db_tortoise"]},
59+
modules={"models": ["tests.test_users"]},
5960
)
6061
await Tortoise.generate_schemas()
6162

0 commit comments

Comments
 (0)