|
| 1 | +import uuid |
| 2 | +from datetime import datetime, timedelta, timezone |
| 3 | +from typing import AsyncGenerator |
| 4 | + |
| 5 | +import pytest |
| 6 | +from fastapi_users.authentication.strategy.db.models import BaseAccessToken |
| 7 | +from pydantic import UUID4 |
| 8 | +from tortoise import Tortoise, fields |
| 9 | +from tortoise.contrib.pydantic import PydanticModel |
| 10 | +from tortoise.exceptions import IntegrityError |
| 11 | + |
| 12 | +from fastapi_users_db_tortoise import TortoiseBaseUserModel |
| 13 | +from fastapi_users_db_tortoise.access_token import ( |
| 14 | + TortoiseAccessTokenDatabase, |
| 15 | + TortoiseBaseAccessTokenModel, |
| 16 | +) |
| 17 | +from tests.conftest import UserDB as BaseUserDB |
| 18 | + |
| 19 | + |
| 20 | +class UserModel(TortoiseBaseUserModel): |
| 21 | + pass |
| 22 | + |
| 23 | + |
| 24 | +class UserDB(BaseUserDB, PydanticModel): |
| 25 | + class Config: |
| 26 | + orm_mode = True |
| 27 | + orig_model = UserModel |
| 28 | + |
| 29 | + |
| 30 | +class AccessTokenModel(TortoiseBaseAccessTokenModel): |
| 31 | + user = fields.ForeignKeyField("models.UserModel", related_name="access_tokens") |
| 32 | + |
| 33 | + |
| 34 | +class AccessToken(BaseAccessToken, PydanticModel): |
| 35 | + class Config: |
| 36 | + orm_mode = True |
| 37 | + orig_model = AccessTokenModel |
| 38 | + |
| 39 | + |
| 40 | +@pytest.fixture |
| 41 | +def user_id() -> UUID4: |
| 42 | + return uuid.uuid4() |
| 43 | + |
| 44 | + |
| 45 | +@pytest.fixture |
| 46 | +async def tortoise_access_token_db( |
| 47 | + user_id: UUID4, |
| 48 | +) -> AsyncGenerator[TortoiseAccessTokenDatabase, None]: |
| 49 | + DATABASE_URL = "sqlite://./test-tortoise-access-token.db" |
| 50 | + |
| 51 | + await Tortoise.init( |
| 52 | + db_url=DATABASE_URL, |
| 53 | + modules={"models": ["tests.test_access_token"]}, |
| 54 | + ) |
| 55 | + await Tortoise.generate_schemas() |
| 56 | + |
| 57 | + user = UserModel( |
| 58 | + id=user_id, |
| 59 | + |
| 60 | + hashed_password="guinevere", |
| 61 | + is_active=True, |
| 62 | + is_verified=True, |
| 63 | + is_superuser=False, |
| 64 | + ) |
| 65 | + await user.save() |
| 66 | + |
| 67 | + yield TortoiseAccessTokenDatabase(AccessToken, AccessTokenModel) |
| 68 | + |
| 69 | + await AccessTokenModel.all().delete() |
| 70 | + await UserModel.all().delete() |
| 71 | + await Tortoise.close_connections() |
| 72 | + |
| 73 | + |
| 74 | +@pytest.mark.asyncio |
| 75 | +@pytest.mark.db |
| 76 | +async def test_queries( |
| 77 | + tortoise_access_token_db: TortoiseAccessTokenDatabase[AccessToken], |
| 78 | + user_id: UUID4, |
| 79 | +): |
| 80 | + access_token = AccessToken(token="TOKEN", user_id=user_id) |
| 81 | + |
| 82 | + # Create |
| 83 | + access_token_db = await tortoise_access_token_db.create(access_token) |
| 84 | + assert access_token_db.token == "TOKEN" |
| 85 | + assert access_token_db.user_id == user_id |
| 86 | + |
| 87 | + # Update |
| 88 | + access_token_db.created_at = datetime.now(timezone.utc) |
| 89 | + await tortoise_access_token_db.update(access_token_db) |
| 90 | + |
| 91 | + # Get by token |
| 92 | + access_token_by_token = await tortoise_access_token_db.get_by_token( |
| 93 | + access_token_db.token |
| 94 | + ) |
| 95 | + assert access_token_by_token is not None |
| 96 | + |
| 97 | + # Get by token expired |
| 98 | + access_token_by_token = await tortoise_access_token_db.get_by_token( |
| 99 | + access_token_db.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) |
| 100 | + ) |
| 101 | + assert access_token_by_token is None |
| 102 | + |
| 103 | + # Get by token not expired |
| 104 | + access_token_by_token = await tortoise_access_token_db.get_by_token( |
| 105 | + access_token_db.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) |
| 106 | + ) |
| 107 | + assert access_token_by_token is not None |
| 108 | + |
| 109 | + # Get by token unknown |
| 110 | + access_token_by_token = await tortoise_access_token_db.get_by_token( |
| 111 | + "NOT_EXISTING_TOKEN" |
| 112 | + ) |
| 113 | + assert access_token_by_token is None |
| 114 | + |
| 115 | + # Exception when inserting existing token |
| 116 | + with pytest.raises(IntegrityError): |
| 117 | + await tortoise_access_token_db.create(access_token_db) |
| 118 | + |
| 119 | + # Delete token |
| 120 | + await tortoise_access_token_db.delete(access_token_db) |
| 121 | + deleted_access_token = await tortoise_access_token_db.get_by_token( |
| 122 | + access_token_db.token |
| 123 | + ) |
| 124 | + assert deleted_access_token is None |
0 commit comments