1
1
""" Repository pattern, errors and data structures for models.tags
2
2
"""
3
-
4
- from typing import TypedDict
5
-
3
+ from common_library .errors_classes import OsparcErrorMixin
6
4
from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine
5
+ from typing_extensions import TypedDict
7
6
8
7
from .utils_repos import pass_or_acquire_connection , transaction_context
9
8
from .utils_tags_sql import (
9
+ TagAccessRightsDict ,
10
10
count_groups_with_given_access_rights_stmt ,
11
11
create_tag_stmt ,
12
+ delete_tag_access_rights_stmt ,
12
13
delete_tag_stmt ,
13
14
get_tag_stmt ,
15
+ has_access_rights_stmt ,
16
+ list_tag_group_access_stmt ,
14
17
list_tags_stmt ,
15
- set_tag_access_rights_stmt ,
16
18
update_tag_stmt ,
19
+ upsert_tags_access_rights_stmt ,
17
20
)
18
21
22
+ __all__ : tuple [str , ...] = ("TagAccessRightsDict" ,)
23
+
19
24
20
25
#
21
26
# Errors
22
27
#
23
- class BaseTagError ( Exception ):
24
- pass
28
+ class _BaseTagError ( OsparcErrorMixin , Exception ):
29
+ msg_template = "Tag repo error on tag {tag_id}"
25
30
26
31
27
- class TagNotFoundError (BaseTagError ):
32
+ class TagNotFoundError (_BaseTagError ):
28
33
pass
29
34
30
35
31
- class TagOperationNotAllowedError (BaseTagError ): # maps to AccessForbidden
36
+ class TagOperationNotAllowedError (_BaseTagError ): # maps to AccessForbidden
32
37
pass
33
38
34
39
@@ -108,7 +113,7 @@ async def create(
108
113
assert tag # nosec
109
114
110
115
# take tag ownership
111
- access_stmt = set_tag_access_rights_stmt (
116
+ access_stmt = upsert_tags_access_rights_stmt (
112
117
tag_id = tag .id ,
113
118
user_id = user_id ,
114
119
read = read ,
@@ -163,8 +168,7 @@ async def get(
163
168
result = await conn .execute (stmt_get )
164
169
row = result .first ()
165
170
if not row :
166
- msg = f"{ tag_id = } not found: either no access or does not exists"
167
- raise TagNotFoundError (msg )
171
+ raise TagNotFoundError (operation = "get" , tag_id = tag_id , user_id = user_id )
168
172
return TagDict (
169
173
id = row .id ,
170
174
name = row .name ,
@@ -198,8 +202,9 @@ async def update(
198
202
result = await conn .execute (update_stmt )
199
203
row = result .first ()
200
204
if not row :
201
- msg = f"{ tag_id = } not updated: either no access or not found"
202
- raise TagOperationNotAllowedError (msg )
205
+ raise TagOperationNotAllowedError (
206
+ operation = "update" , tag_id = tag_id , user_id = user_id
207
+ )
203
208
204
209
return TagDict (
205
210
id = row .id ,
@@ -222,44 +227,95 @@ async def delete(
222
227
async with transaction_context (self .engine , connection ) as conn :
223
228
deleted = await conn .scalar (stmt_delete )
224
229
if not deleted :
225
- msg = f"Could not delete { tag_id = } . Not found or insuficient access."
226
- raise TagOperationNotAllowedError (msg )
230
+ raise TagOperationNotAllowedError (
231
+ operation = "delete" , tag_id = tag_id , user_id = user_id
232
+ )
227
233
228
234
#
229
235
# ACCESS RIGHTS
230
236
#
231
237
232
- async def create_access_rights (
238
+ async def has_access_rights (
233
239
self ,
234
240
connection : AsyncConnection | None = None ,
235
241
* ,
236
242
user_id : int ,
237
243
tag_id : int ,
238
- group_id : int ,
239
- read : bool ,
240
- write : bool ,
241
- delete : bool ,
242
- ):
243
- raise NotImplementedError
244
+ read : bool = False ,
245
+ write : bool = False ,
246
+ delete : bool = False ,
247
+ ) -> bool :
248
+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
249
+ group_id_or_none = await conn .scalar (
250
+ has_access_rights_stmt (
251
+ tag_id = tag_id ,
252
+ caller_user_id = user_id ,
253
+ read = read ,
254
+ write = write ,
255
+ delete = delete ,
256
+ )
257
+ )
258
+ return bool (group_id_or_none )
244
259
245
- async def update_access_rights (
260
+ async def list_access_rights (
261
+ self ,
262
+ connection : AsyncConnection | None = None ,
263
+ * ,
264
+ tag_id : int ,
265
+ ) -> list [TagAccessRightsDict ]:
266
+ async with pass_or_acquire_connection (self .engine , connection ) as conn :
267
+ result = await conn .execute (list_tag_group_access_stmt (tag_id = tag_id ))
268
+ return [
269
+ TagAccessRightsDict (
270
+ tag_id = row .tag_id ,
271
+ group_id = row .group_id ,
272
+ read = row .read ,
273
+ write = row .write ,
274
+ delete = row .delete ,
275
+ )
276
+ for row in result .fetchall ()
277
+ ]
278
+
279
+ async def create_or_update_access_rights (
246
280
self ,
247
281
connection : AsyncConnection | None = None ,
248
282
* ,
249
- user_id : int ,
250
283
tag_id : int ,
251
284
group_id : int ,
252
285
read : bool ,
253
286
write : bool ,
254
287
delete : bool ,
255
- ):
256
- raise NotImplementedError
288
+ ) -> TagAccessRightsDict :
289
+ async with transaction_context (self .engine , connection ) as conn :
290
+ result = await conn .execute (
291
+ upsert_tags_access_rights_stmt (
292
+ tag_id = tag_id ,
293
+ group_id = group_id ,
294
+ read = read ,
295
+ write = write ,
296
+ delete = delete ,
297
+ )
298
+ )
299
+ row = result .first ()
300
+ assert row is not None
301
+
302
+ return TagAccessRightsDict (
303
+ tag_id = row .tag_id ,
304
+ group_id = row .group_id ,
305
+ read = row .read ,
306
+ write = row .write ,
307
+ delete = row .delete ,
308
+ )
257
309
258
310
async def delete_access_rights (
259
311
self ,
260
312
connection : AsyncConnection | None = None ,
261
313
* ,
262
- user_id : int ,
263
314
tag_id : int ,
264
- ):
265
- raise NotImplementedError
315
+ group_id : int ,
316
+ ) -> bool :
317
+ async with transaction_context (self .engine , connection ) as conn :
318
+ deleted : bool = await conn .scalar (
319
+ delete_tag_access_rights_stmt (tag_id = tag_id , group_id = group_id )
320
+ )
321
+ return deleted
0 commit comments