diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 7a263044f3e..d8684350361 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -41,6 +41,7 @@ Makefile @pcrespov @sanderegg /services/static-webserver/ @GitHK /services/static-webserver/client/ @odeimaiz /services/storage/ @sanderegg +/services/storage/modules/celery @giancarloromeo /services/web/server/ @pcrespov @sanderegg @GitHK @matusdrobuliak66 /tests/e2e-frontend/ @odeimaiz /tests/e2e-playwright/ @matusdrobuliak66 diff --git a/api/specs/web-server/_long_running_tasks.py b/api/specs/web-server/_long_running_tasks.py index 884c81708da..f204c1de5b4 100644 --- a/api/specs/web-server/_long_running_tasks.py +++ b/api/specs/web-server/_long_running_tasks.py @@ -58,7 +58,7 @@ def get_async_job_status( responses=_export_data_responses, status_code=status.HTTP_204_NO_CONTENT, ) -def abort_async_job( +def cancel_async_job( _path_params: Annotated[_PathParam, Depends()], ): ... diff --git a/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py b/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py index 080d5edf045..3186237eb7e 100644 --- a/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py +++ b/services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py @@ -36,7 +36,7 @@ async def cancel(app: FastAPI, job_id: AsyncJobId, job_id_data: AsyncJobNameData assert app # nosec assert job_id_data # nosec try: - await get_celery_client(app).abort_task( + await get_celery_client(app).cancel_task( task_context=job_id_data.model_dump(), task_uuid=job_id, ) diff --git a/services/storage/src/simcore_service_storage/modules/celery/_task.py b/services/storage/src/simcore_service_storage/modules/celery/_task.py index a6f7c1a365e..e367a3a73da 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/_task.py +++ b/services/storage/src/simcore_service_storage/modules/celery/_task.py @@ -67,6 +67,7 @@ async def abort_monitor(): main_task, max_delay=_DEFAULT_CANCEL_TASK_TIMEOUT.total_seconds(), ) + AbortableAsyncResult(task_id, app=app).forget() raise TaskAbortedError await asyncio.sleep( _DEFAULT_ABORT_TASK_TIMEOUT.total_seconds() diff --git a/services/storage/src/simcore_service_storage/modules/celery/client.py b/services/storage/src/simcore_service_storage/modules/celery/client.py index 305731f946a..f68baf558fe 100644 --- a/services/storage/src/simcore_service_storage/modules/celery/client.py +++ b/services/storage/src/simcore_service_storage/modules/celery/client.py @@ -35,7 +35,7 @@ class CeleryTaskClient: _celery_app: Celery _celery_settings: CelerySettings - _task_store: TaskInfoStore + _task_info_store: TaskInfoStore async def submit_task( self, @@ -63,22 +63,25 @@ async def submit_task( if task_metadata.ephemeral else self._celery_settings.CELERY_RESULT_EXPIRES ) - await self._task_store.create_task(task_id, task_metadata, expiry=expiry) + await self._task_info_store.create_task( + task_id, task_metadata, expiry=expiry + ) return task_uuid @make_async() - def _abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None: - AbortableAsyncResult( - build_task_id(task_context, task_uuid), app=self._celery_app - ).abort() + def _abort_task(self, task_id: TaskID) -> None: + AbortableAsyncResult(task_id, app=self._celery_app).abort() - async def abort_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None: + async def cancel_task(self, task_context: TaskContext, task_uuid: TaskUUID) -> None: with log_context( _logger, logging.DEBUG, - msg=f"Abort task: {task_context=} {task_uuid=}", + msg=f"task cancellation: {task_context=} {task_uuid=}", ): - await self._abort_task(task_context, task_uuid) + task_id = build_task_id(task_context, task_uuid) + if not (await self.get_task_status(task_context, task_uuid)).is_done: + await self._abort_task(task_id) + await self._task_info_store.remove_task(task_id) @make_async() def _forget_task(self, task_id: TaskID) -> None: @@ -96,10 +99,10 @@ async def get_task_result( async_result = self._celery_app.AsyncResult(task_id) result = async_result.result if async_result.ready(): - task_metadata = await self._task_store.get_task_metadata(task_id) + task_metadata = await self._task_info_store.get_task_metadata(task_id) if task_metadata is not None and task_metadata.ephemeral: - await self._task_store.remove_task(task_id) await self._forget_task(task_id) + await self._task_info_store.remove_task(task_id) return result async def _get_task_progress_report( @@ -107,7 +110,7 @@ async def _get_task_progress_report( ) -> ProgressReport: if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED): task_id = build_task_id(task_context, task_uuid) - progress = await self._task_store.get_task_progress(task_id) + progress = await self._task_info_store.get_task_progress(task_id) if progress is not None: return progress if task_state in ( @@ -124,10 +127,7 @@ async def _get_task_progress_report( ) @make_async() - def _get_task_celery_state( - self, task_context: TaskContext, task_uuid: TaskUUID - ) -> TaskState: - task_id = build_task_id(task_context, task_uuid) + def _get_task_celery_state(self, task_id: TaskID) -> TaskState: return TaskState(self._celery_app.AsyncResult(task_id).state) async def get_task_status( @@ -138,7 +138,8 @@ async def get_task_status( logging.DEBUG, msg=f"Getting task status: {task_context=} {task_uuid=}", ): - task_state = await self._get_task_celery_state(task_context, task_uuid) + task_id = build_task_id(task_context, task_uuid) + task_state = await self._get_task_celery_state(task_id) return TaskStatus( task_uuid=task_uuid, task_state=task_state, @@ -153,4 +154,4 @@ async def list_tasks(self, task_context: TaskContext) -> list[Task]: logging.DEBUG, msg=f"Listing tasks: {task_context=}", ): - return await self._task_store.list_tasks(task_context) + return await self._task_info_store.list_tasks(task_context) diff --git a/services/storage/tests/unit/test_async_jobs.py b/services/storage/tests/unit/test_async_jobs.py index 95319a6533f..36f29a15bd8 100644 --- a/services/storage/tests/unit/test_async_jobs.py +++ b/services/storage/tests/unit/test_async_jobs.py @@ -277,6 +277,14 @@ async def test_async_jobs_cancel( job_id_data=job_id_data, ) + jobs = await async_jobs.list_jobs( + storage_rabbitmq_rpc_client, + rpc_namespace=STORAGE_RPC_NAMESPACE, + filter_="", # currently not used + job_id_data=job_id_data, + ) + assert async_job_get.job_id not in [job.job_id for job in jobs] + with pytest.raises(JobAbortedError): await async_jobs.result( storage_rabbitmq_rpc_client, diff --git a/services/storage/tests/unit/test_modules_celery.py b/services/storage/tests/unit/test_modules_celery.py index d5f3ce70b98..b1819aabb44 100644 --- a/services/storage/tests/unit/test_modules_celery.py +++ b/services/storage/tests/unit/test_modules_celery.py @@ -166,7 +166,7 @@ async def test_submitting_task_with_failure_results_with_error( assert f"{raw_result}" == "Something strange happened: BOOM!" -async def test_aborting_task_results_with_aborted_state( +async def test_cancelling_a_running_task_aborts_and_deletes( celery_client: CeleryTaskClient, ): task_context = TaskContext(user_id=42) @@ -178,7 +178,7 @@ async def test_aborting_task_results_with_aborted_state( task_context=task_context, ) - await celery_client.abort_task(task_context, task_uuid) + await celery_client.cancel_task(task_context, task_uuid) for attempt in Retrying( retry=retry_if_exception_type(AssertionError), @@ -193,6 +193,8 @@ async def test_aborting_task_results_with_aborted_state( await celery_client.get_task_status(task_context, task_uuid) ).task_state == TaskState.ABORTED + assert task_uuid not in await celery_client.list_tasks(task_context) + async def test_listing_task_uuids_contains_submitted_task( celery_client: CeleryTaskClient, diff --git a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml index 73a7abf032c..3681c1e66e1 100644 --- a/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml +++ b/services/web/server/src/simcore_service_webserver/api/v0/openapi.yaml @@ -3129,7 +3129,7 @@ paths: - long-running-tasks summary: Cancel And Delete Task description: Cancels and deletes a task - operationId: abort_async_job + operationId: cancel_async_job parameters: - name: task_id in: path diff --git a/services/web/server/src/simcore_service_webserver/storage/_rest.py b/services/web/server/src/simcore_service_webserver/storage/_rest.py index 9fa80f4bc4d..b8a1f18a398 100644 --- a/services/web/server/src/simcore_service_webserver/storage/_rest.py +++ b/services/web/server/src/simcore_service_webserver/storage/_rest.py @@ -185,7 +185,7 @@ def _create_data_response_from_async_job( task_id=async_job_id, task_name=async_job_id, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=async_job_id)))}", - abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=async_job_id)))}", + abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=async_job_id)))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=async_job_id)))}", ), status=status.HTTP_202_ACCEPTED, @@ -505,7 +505,7 @@ def allow_only_simcore(cls, v: int) -> int: task_id=_job_id, task_name=_job_id, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=_job_id)))}", - abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=_job_id)))}", + abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=_job_id)))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=_job_id)))}", ), status=status.HTTP_202_ACCEPTED, diff --git a/services/web/server/src/simcore_service_webserver/tasks/_rest.py b/services/web/server/src/simcore_service_webserver/tasks/_rest.py index 71850d627a7..a4c95a6e1cc 100644 --- a/services/web/server/src/simcore_service_webserver/tasks/_rest.py +++ b/services/web/server/src/simcore_service_webserver/tasks/_rest.py @@ -87,7 +87,7 @@ async def get_async_jobs(request: web.Request) -> web.Response: task_id=f"{job.job_id}", task_name=job.job_name, status_href=f"{request.url.with_path(str(request.app.router['get_async_job_status'].url_for(task_id=str(job.job_id))))}", - abort_href=f"{request.url.with_path(str(request.app.router['abort_async_job'].url_for(task_id=str(job.job_id))))}", + abort_href=f"{request.url.with_path(str(request.app.router['cancel_async_job'].url_for(task_id=str(job.job_id))))}", result_href=f"{request.url.with_path(str(request.app.router['get_async_job_result'].url_for(task_id=str(job.job_id))))}", ) for job in user_async_jobs @@ -136,17 +136,18 @@ async def get_async_job_status(request: web.Request) -> web.Response: @routes.delete( _task_prefix + "/{task_id}", - name="abort_async_job", + name="cancel_async_job", ) @login_required @permission_required("storage.files.*") @handle_export_data_exceptions -async def abort_async_job(request: web.Request) -> web.Response: +async def cancel_async_job(request: web.Request) -> web.Response: _req_ctx = RequestContext.model_validate(request) rabbitmq_rpc_client = get_rabbitmq_rpc_client(request.app) async_job_get = parse_request_path_parameters_as(_StorageAsyncJobId, request) + await async_jobs.cancel( rabbitmq_rpc_client=rabbitmq_rpc_client, rpc_namespace=STORAGE_RPC_NAMESPACE, @@ -155,6 +156,7 @@ async def abort_async_job(request: web.Request) -> web.Response: user_id=_req_ctx.user_id, product_name=_req_ctx.product_name ), ) + return web.Response(status=status.HTTP_204_NO_CONTENT) diff --git a/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py b/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py index 73b6c3c086a..d7a0c1087b4 100644 --- a/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py +++ b/services/web/server/tests/unit/with_dbs/01/storage/test_storage.py @@ -538,7 +538,7 @@ async def test_get_async_jobs_status( ], ids=lambda x: type(x).__name__, ) -async def test_abort_async_jobs( +async def test_cancel_async_jobs( user_role: UserRole, logged_user: UserInfoDict, client: TestClient,