5
5
from types import TracebackType
6
6
from typing import (
7
7
Any , AsyncContextManager , AsyncIterable , Awaitable , Callable , Dict ,
8
- FrozenSet , Generator , Iterator , MutableMapping , NamedTuple , Optional , Set ,
8
+ FrozenSet , Iterator , MutableMapping , NamedTuple , Optional , Set ,
9
9
Tuple , Type , TypeVar , Union ,
10
10
)
11
11
12
- import aiormq
12
+ import aiormq . abc
13
13
from aiormq .abc import ExceptionType
14
14
from pamqp .common import Arguments
15
15
from yarl import URL
16
16
17
17
from .pool import PoolInstance
18
- from .tools import CallbackCollection , CallbackSetType , CallbackType
19
-
18
+ from .tools import (
19
+ CallbackCollection , CallbackSetType , CallbackType , OneShotCallback
20
+ )
20
21
21
22
TimeoutType = Optional [Union [int , float ]]
22
23
@@ -219,7 +220,6 @@ async def __aexit__(
219
220
220
221
class AbstractQueue :
221
222
channel : "AbstractChannel"
222
- connection : "AbstractConnection"
223
223
name : str
224
224
durable : bool
225
225
exclusive : bool
@@ -228,6 +228,19 @@ class AbstractQueue:
228
228
passive : bool
229
229
declaration_result : aiormq .spec .Queue .DeclareOk
230
230
231
+ @abstractmethod
232
+ def __init__ (
233
+ self ,
234
+ channel : "AbstractChannel" ,
235
+ name : Optional [str ],
236
+ durable : bool ,
237
+ exclusive : bool ,
238
+ auto_delete : bool ,
239
+ arguments : Arguments ,
240
+ passive : bool = False ,
241
+ ):
242
+ raise NotImplementedError
243
+
231
244
@abstractmethod
232
245
async def declare (
233
246
self , timeout : TimeoutType = None ,
@@ -341,6 +354,21 @@ async def __anext__(self) -> AbstractIncomingMessage:
341
354
342
355
343
356
class AbstractExchange (ABC ):
357
+ @abstractmethod
358
+ def __init__ (
359
+ self ,
360
+ channel : "AbstractChannel" ,
361
+ name : str ,
362
+ type : Union [ExchangeType , str ] = ExchangeType .DIRECT ,
363
+ * ,
364
+ auto_delete : bool = False ,
365
+ durable : bool = False ,
366
+ internal : bool = False ,
367
+ passive : bool = False ,
368
+ arguments : Arguments = None
369
+ ):
370
+ raise NotImplementedError
371
+
344
372
@property
345
373
@abstractmethod
346
374
def channel (self ) -> "AbstractChannel" :
@@ -392,20 +420,46 @@ async def delete(
392
420
raise NotImplementedError
393
421
394
422
423
+ class UnderlayChannel (NamedTuple ):
424
+ channel : aiormq .abc .AbstractChannel
425
+ close_callback : OneShotCallback
426
+
427
+ @classmethod
428
+ async def create_channel (
429
+ cls , transport : "UnderlayConnection" ,
430
+ close_callback : Callable [..., Awaitable [Any ]], ** kwargs : Any
431
+ ) -> "UnderlayChannel" :
432
+ close_callback = OneShotCallback (close_callback )
433
+
434
+ await transport .connection .ready ()
435
+ transport .connection .closing .add_done_callback (close_callback )
436
+ channel = await transport .connection .channel (** kwargs )
437
+ channel .closing .add_done_callback (close_callback )
438
+
439
+ return cls (
440
+ channel = channel ,
441
+ close_callback = close_callback ,
442
+ )
443
+
444
+ async def close (self , exc : Optional [ExceptionType ] = None ) -> Any :
445
+ result : Any
446
+ result , _ = await asyncio .gather (
447
+ self .channel .close (exc ), self .close_callback .wait ()
448
+ )
449
+ return result
450
+
451
+
395
452
class AbstractChannel (PoolInstance , ABC ):
396
453
QUEUE_CLASS : Type [AbstractQueue ]
397
454
EXCHANGE_CLASS : Type [AbstractExchange ]
398
455
399
456
close_callbacks : CallbackCollection
400
457
return_callbacks : CallbackCollection
401
- connection : "AbstractConnection"
458
+ ready : asyncio . Event
402
459
loop : asyncio .AbstractEventLoop
403
460
default_exchange : AbstractExchange
404
461
405
- @property
406
- @abstractmethod
407
- def done_callbacks (self ) -> CallbackCollection :
408
- raise NotImplementedError
462
+ publisher_confirms : bool
409
463
410
464
@property
411
465
@abstractmethod
@@ -431,10 +485,6 @@ def channel(self) -> aiormq.abc.AbstractChannel:
431
485
def number (self ) -> Optional [int ]:
432
486
raise NotImplementedError
433
487
434
- @abstractmethod
435
- def __await__ (self ) -> Generator [Any , Any , "AbstractChannel" ]:
436
- raise NotImplementedError
437
-
438
488
@abstractmethod
439
489
async def __aenter__ (self ) -> "AbstractChannel" :
440
490
raise NotImplementedError
@@ -537,19 +587,50 @@ def transaction(self) -> AbstractTransaction:
537
587
async def flow (self , active : bool = True ) -> aiormq .spec .Channel .FlowOk :
538
588
raise NotImplementedError
539
589
590
+ @abstractmethod
591
+ def __await__ (self ) -> Awaitable ["AbstractChannel" ]:
592
+ raise NotImplementedError
593
+
594
+
595
+ class UnderlayConnection (NamedTuple ):
596
+ connection : aiormq .abc .AbstractConnection
597
+ close_callback : OneShotCallback
598
+
599
+ @classmethod
600
+ async def connect (
601
+ cls , url : URL , close_callback : Callable [..., Awaitable [Any ]],
602
+ timeout : TimeoutType = None , ** kwargs : Any
603
+ ) -> "UnderlayConnection" :
604
+ connection : aiormq .abc .AbstractConnection = await asyncio .wait_for (
605
+ aiormq .connect (url , ** kwargs ), timeout = timeout ,
606
+ )
607
+ close_callback = OneShotCallback (close_callback )
608
+ connection .closing .add_done_callback (close_callback )
609
+ await connection .ready ()
610
+ return cls (
611
+ connection = connection ,
612
+ close_callback = close_callback
613
+ )
614
+
615
+ async def close (self , exc : Optional [aiormq .abc .ExceptionType ]):
616
+ result , _ = await asyncio .gather (
617
+ self .connection .close (exc ), self .close_callback .wait ()
618
+ )
619
+ return result
620
+
540
621
541
622
class AbstractConnection (PoolInstance , ABC ):
542
623
loop : asyncio .AbstractEventLoop
543
624
close_callbacks : CallbackCollection
544
625
connected : asyncio .Event
545
- connection : aiormq . abc . AbstractConnection
626
+ transport : UnderlayConnection
546
627
547
628
@abstractmethod
548
629
def __init__ (
549
630
self , url : URL , loop : Optional [asyncio .AbstractEventLoop ] = None ,
550
631
** kwargs : Any
551
632
):
552
- NotImplementedError (
633
+ raise NotImplementedError (
553
634
f"Method not implemented, passed: url={ url } , loop={ loop !r} " ,
554
635
)
555
636
@@ -748,5 +829,7 @@ def channel(
748
829
"MILLISECONDS" ,
749
830
"TimeoutType" ,
750
831
"TransactionState" ,
832
+ "UnderlayChannel" ,
833
+ "UnderlayConnection" ,
751
834
"ZERO_TIME" ,
752
835
)
0 commit comments