2
2
"""
3
3
4
4
import itertools
5
+ from dataclasses import dataclass
5
6
from typing import TypedDict
6
7
7
- from sqlalchemy . ext . asyncio import AsyncConnection , AsyncEngine
8
+ from aiopg . sa . connection import SAConnection
8
9
9
- from .base_repo import MinimalRepo , get_or_create_connection , transaction_context
10
- from .models .tags import tags as tags_table
11
10
from .tags_sql import (
12
11
count_users_with_access_rights_stmt ,
13
12
create_tag_stmt ,
@@ -50,46 +49,43 @@ class TagDict(TypedDict, total=True):
50
49
delete : bool
51
50
52
51
52
+ @dataclass (frozen = True )
53
53
class TagsRepo :
54
- def __init__ (self , engine : AsyncEngine ):
55
- self .engine = engine
56
- self ._impl = MinimalRepo (engine = engine , table = tags_table )
54
+ user_id : int # Determines access-rights
57
55
58
56
async def access_count (
59
57
self ,
60
- user_id : int ,
58
+ conn : SAConnection ,
61
59
tag_id : int ,
60
+ * ,
62
61
read : bool | None = None ,
63
62
write : bool | None = None ,
64
63
delete : bool | None = None ,
65
- connection : AsyncConnection | None = None ,
66
64
) -> int :
67
65
"""
68
66
Returns 0 if tag does not match access
69
67
Returns >0 if it does and represents the number of groups granting this access to the user
70
68
"""
71
- async with get_or_create_connection (self .engine , connection ) as conn :
72
- count_stmt = count_users_with_access_rights_stmt (
73
- user_id = user_id , tag_id = tag_id , read = read , write = write , delete = delete
74
- )
75
- permissions_count : int | None = await conn .scalar (count_stmt )
76
- return permissions_count if permissions_count else 0
69
+ count_stmt = count_users_with_access_rights_stmt (
70
+ user_id = self .user_id , tag_id = tag_id , read = read , write = write , delete = delete
71
+ )
72
+ permissions_count : int | None = await conn .scalar (count_stmt )
73
+ return permissions_count if permissions_count else 0
77
74
78
75
#
79
76
# CRUD operations
80
77
#
81
78
82
79
async def create (
83
80
self ,
81
+ conn : SAConnection ,
84
82
* ,
85
- user_id : int ,
86
83
name : str ,
87
84
color : str ,
88
85
description : str | None = None , # =nullable
89
86
read : bool = True ,
90
87
write : bool = True ,
91
88
delete : bool = True ,
92
- connection : AsyncConnection | None = None ,
93
89
) -> TagDict :
94
90
values = {
95
91
"name" : name ,
@@ -98,58 +94,44 @@ async def create(
98
94
if description :
99
95
values ["description" ] = description
100
96
101
- async with transaction_context ( self . engine , connection ) as conn :
97
+ async with conn . begin () :
102
98
# insert new tag
103
99
insert_stmt = create_tag_stmt (** values )
104
100
result = await conn .execute (insert_stmt )
105
- tag = result .first ()
101
+ tag = await result .first ()
106
102
assert tag # nosec
107
103
108
104
# take tag ownership
109
105
access_stmt = set_tag_access_rights_stmt (
110
106
tag_id = tag .id ,
111
- user_id = user_id ,
107
+ user_id = self . user_id ,
112
108
read = read ,
113
109
write = write ,
114
110
delete = delete ,
115
111
)
116
112
result = await conn .execute (access_stmt )
117
- access = result .first ()
113
+ access = await result .first ()
118
114
assert access
119
115
120
116
return TagDict (itertools .chain (tag .items (), access .items ())) # type: ignore
121
117
122
- async def list_all (
123
- self ,
124
- * ,
125
- user_id : int ,
126
- connection : AsyncConnection | None = None ,
127
- ) -> list [TagDict ]:
128
- async with get_or_create_connection (self .engine , connection ) as conn :
129
- stmt_list = list_tags_stmt (user_id = user_id )
130
- return [TagDict (row .items ()) async for row in conn .execute (stmt_list )] # type: ignore
131
-
132
- async def get (
133
- self ,
134
- user_id : int ,
135
- tag_id : int ,
136
- connection : AsyncConnection | None = None ,
137
- ) -> TagDict :
138
- async with get_or_create_connection (self .engine , connection ) as conn :
139
- stmt_get = get_tag_stmt (user_id = user_id , tag_id = tag_id )
140
- result = await conn .execute (stmt_get )
141
- row = result .first ()
142
- if not row :
143
- msg = f"{ tag_id = } not found: either no access or does not exists"
144
- raise TagNotFoundError (msg )
145
- return TagDict (row .items ()) # type: ignore
118
+ async def list_all (self , conn : SAConnection ) -> list [TagDict ]:
119
+ stmt_list = list_tags_stmt (user_id = self .user_id )
120
+ return [TagDict (row .items ()) async for row in conn .execute (stmt_list )] # type: ignore
121
+
122
+ async def get (self , conn : SAConnection , tag_id : int ) -> TagDict :
123
+ stmt_get = get_tag_stmt (user_id = self .user_id , tag_id = tag_id )
124
+ result = await conn .execute (stmt_get )
125
+ row = await result .first ()
126
+ if not row :
127
+ msg = f"{ tag_id = } not found: either no access or does not exists"
128
+ raise TagNotFoundError (msg )
129
+ return TagDict (row .items ()) # type: ignore
146
130
147
131
async def update (
148
132
self ,
149
- * ,
133
+ conn : SAConnection ,
150
134
tag_id : int ,
151
- user_id : int ,
152
- connection : AsyncConnection | None = None ,
153
135
** fields ,
154
136
) -> TagDict :
155
137
updates = {
@@ -160,28 +142,21 @@ async def update(
160
142
161
143
if not updates :
162
144
# no updates == get
163
- return await self .get (user_id = user_id , tag_id = tag_id , connection = connection )
145
+ return await self .get (conn , tag_id = tag_id )
164
146
165
- async with get_or_create_connection (self .engine , connection ) as conn :
166
- update_stmt = update_tag_stmt (user_id = user_id , tag_id = tag_id , ** updates )
167
- result = await conn .execute (update_stmt )
168
- row = result .first ()
169
- if not row :
170
- msg = f"{ tag_id = } not updated: either no access or not found"
171
- raise TagOperationNotAllowedError (msg )
147
+ update_stmt = update_tag_stmt (user_id = self .user_id , tag_id = tag_id , ** updates )
148
+ result = await conn .execute (update_stmt )
149
+ row = await result .first ()
150
+ if not row :
151
+ msg = f"{ tag_id = } not updated: either no access or not found"
152
+ raise TagOperationNotAllowedError (msg )
172
153
173
- return TagDict (row .items ()) # type: ignore
154
+ return TagDict (row .items ()) # type: ignore
174
155
175
- async def delete (
176
- self ,
177
- * ,
178
- user_id : int ,
179
- tag_id : int ,
180
- connection : AsyncConnection | None = None ,
181
- ) -> None :
182
- async with get_or_create_connection (self .engine , connection ) as conn :
183
- stmt_delete = delete_tag_stmt (user_id = user_id , tag_id = tag_id )
184
- deleted = await conn .scalar (stmt_delete )
185
- if not deleted :
186
- msg = f"Could not delete { tag_id = } . Not found or insuficient access."
187
- raise TagOperationNotAllowedError (msg )
156
+ async def delete (self , conn : SAConnection , tag_id : int ) -> None :
157
+ stmt_delete = delete_tag_stmt (user_id = self .user_id , tag_id = tag_id )
158
+
159
+ deleted = await conn .scalar (stmt_delete )
160
+ if not deleted :
161
+ msg = f"Could not delete { tag_id = } . Not found or insuficient access."
162
+ raise TagOperationNotAllowedError (msg )
0 commit comments