@@ -6,13 +6,12 @@
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Union, Any
+from typing import Any, Dict

-import aio_pika
-import aiofiles
-import attr
+# import aiofiles
 import docker
 from celery.utils.log import get_task_logger
+from pydantic import BaseModel
 from sqlalchemy import and_, exc

 from servicelib.utils import logged_gather
@@ -26,7 +25,7 @@
 from . import config
 from .rabbitmq import RabbitMQ
 from .utils import (DbSettings, DockerSettings, ExecutorSettings, S3Settings,
-                    find_entry_point, is_node_ready, safe_channel)
+                    find_entry_point, is_node_ready)

 log = get_task_logger(__name__)
 log.setLevel(config.SIDECAR_LOGLEVEL)
@@ -48,8 +47,7 @@ def session_scope(session_factory):
     finally:
         session.close()

-@attr.s(auto_attribs=True)
-class Sidecar:  # pylint: disable=too-many-instance-attributes
+class Sidecar(BaseModel):
     _rabbit_mq: RabbitMQ
     _docker: DockerSettings = DockerSettings()
     _s3: S3Settings = S3Settings()
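
Note on the attrs-to-pydantic switch above: pydantic v1 treats underscore-prefixed names as private attributes rather than validated fields, and plain classes like RabbitMQ are not pydantic types. A minimal sketch of the configuration such a model typically needs (assuming pydantic v1; illustrative only, not necessarily what this commit does):

    from pydantic import BaseModel, PrivateAttr

    class Sidecar(BaseModel):
        # private attribute: excluded from validation and serialization
        _rabbit_mq: RabbitMQ = PrivateAttr()

        class Config:
            # permit plain Python classes such as RabbitMQ as annotations
            arbitrary_types_allowed = True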
@@ -153,34 +151,6 @@ async def _pull_image(self):
             log.exception(msg)
             raise docker.errors.APIError(msg)

-    async def _post_log(
-        self, channel: pika.channel.Channel, msg: Union[str, List[str]]
-    ):
-        log_data = {
-            "Channel": "Log",
-            "Node": self._task.node_id,
-            "user_id": self._user_id,
-            "project_id": self._task.project_id,
-            "Messages": msg if isinstance(msg, list) else [msg],
-        }
-        log_body = json.dumps(log_data)
-        channel.basic_publish(
-            exchange=self._pika.log_channel, routing_key="", body=log_body
-        )
-
-    async def _post_progress(self, channel, progress):
-        prog_data = {
-            "Channel": "Progress",
-            "Node": self._task.node_id,
-            "user_id": self._user_id,
-            "project_id": self._task.project_id,
-            "Progress": progress,
-        }
-        prog_body = json.dumps(prog_data)
-        channel.basic_publish(
-            exchange=self._pika.progress_channel, routing_key="", body=prog_body
-        )
-
     async def log_file_processor(self, log_file: Path) -> None:
         """checks both container logs and the log_file if any
         """
@@ -208,7 +178,6 @@ async def log_file_processor(self, log_file: Path) -> None:


         # try:
-        #     import pdb; pdb.set_trace()
         #     TIME_BETWEEN_LOGS_S: int = 2
         #     time_logs_sent = time.monotonic()
         #     accumulated_logs = []
@@ -508,25 +477,17 @@ async def run(self):
             self._task.node_id,
             self._task.internal_id,
         )
-        # NOTE: the rabbit has a timeout of 60 seconds so blocking this channel for more is a no-go.
-
-        with safe_channel(self._pika) as (channel, _):
-            await self._post_log(channel, msg="Preprocessing start...")
-
+        await self._rabbit_mq.post_log_message("Preprocessing start...")
         await self.preprocess()
-
-        with safe_channel(self._pika) as (channel, _):
-            await self._post_log(channel, msg="...preprocessing end")
-            await self._post_log(channel, msg="Processing start...")
+        await self._rabbit_mq.post_log_message("...preprocessing end")
+
+        await self._rabbit_mq.post_log_message("Processing start...")
         await self.process()
+        await self._rabbit_mq.post_log_message("...processing end")

-        with safe_channel(self._pika) as (channel, _):
-            await self._post_log(channel, msg="...processing end")
-            await self._post_log(channel, msg="Postprocessing start...")
+        await self._rabbit_mq.post_log_message("Postprocessing start...")
         await self.postprocess()
-
-        with safe_channel(self._pika) as (channel, _):
-            await self._post_log(channel, msg="...postprocessing end")
+        await self._rabbit_mq.post_log_message("...postprocessing end")

         log.debug(
             "Running Pipeline DONE %s:node %s:internal id %s from container",
@@ -617,7 +578,7 @@ async def inspect(self, job_request_id: int, user_id: str, project_id: str, node
             return next_task_nodes

         # the task is ready!
-        task.job_id = celery_task.request.id
+        task.job_id = job_request_id
         _session.add(task)
         _session.commit()

@@ -632,7 +593,7 @@ async def inspect(self, job_request_id: int, user_id: str, project_id: str, node
             .one()
         )

-        if task.job_id != celery_task.request.id:
+        if task.job_id != job_request_id:
             # somebody else was faster
             return next_task_nodes
         task.state = RUNNING
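
The job_id handshake in these two hunks is a claim-then-verify idiom for arbitrating between concurrent workers: each worker stamps the task with its own job_request_id, commits, re-reads the row, and proceeds only if its claim survived, so the last committed write wins the task. Condensed (the model name ComputationalTask is illustrative; node_id and project_id come from the inspect() parameters):

    # claim: stamp the task with this worker's request id
    task.job_id = job_request_id
    _session.add(task)
    _session.commit()

    # verify: re-read and keep the task only if our claim survived
    task = _session.query(ComputationalTask).filter(
        and_(ComputationalTask.node_id == node_id,
             ComputationalTask.project_id == project_id)).one()
    if task.job_id != job_request_id:
        return next_task_nodes  # somebody else was faster
    task.state = RUNNING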