1
1
# SPDX-License-Identifier: Apache-2.0
2
2
import asyncio
3
+ import contextlib
3
4
import queue
4
5
import uuid
5
6
import weakref
6
7
from abc import ABC , abstractmethod
8
+ from collections import deque
7
9
from collections .abc import Awaitable , Sequence
8
10
from concurrent .futures import Future
9
11
from dataclasses import dataclass , field
@@ -396,6 +398,12 @@ def __init__(
396
398
self ._wait_for_engine_startup ()
397
399
398
400
self .utility_results : dict [int , AnyFuture ] = {}
401
+
402
+ # Request objects which may contain pytorch-allocated tensors
403
+ # that we need to keep references to until zmq is done with the
404
+ # underlying data.
405
+ self .pending_messages = deque [tuple [zmq .MessageTracker , Any ]]()
406
+
399
407
success = True
400
408
finally :
401
409
if not success :
@@ -459,6 +467,14 @@ def ensure_alive(self):
459
467
if self .resources .engine_dead :
460
468
raise EngineDeadError ()
461
469
470
+ def add_pending_message (self , tracker : zmq .MessageTracker , msg : Any ):
471
+ if not tracker .done :
472
+ self .pending_messages .appendleft ((tracker , msg ))
473
+
474
+ def free_pending_messages (self ):
475
+ while self .pending_messages and self .pending_messages [- 1 ][0 ].done :
476
+ self .pending_messages .pop ()
477
+
462
478
463
479
def _process_utility_output (output : UtilityOutput ,
464
480
utility_results : dict [int , AnyFuture ]):
@@ -544,10 +560,18 @@ def get_output(self) -> EngineCoreOutputs:
544
560
545
561
def _send_input (self , request_type : EngineCoreRequestType , request : Any ):
546
562
self .ensure_alive ()
563
+ self .free_pending_messages ()
547
564
# (Identity, RequestType, SerializedRequest)
548
565
msg = (self .core_engine .identity , request_type .value ,
549
566
* self .encoder .encode (request ))
550
- self .input_socket .send_multipart (msg , copy = False )
567
+
568
+ if len (msg ) <= 3 :
569
+ # No auxiliary buffers => no tensor backing buffers in request.
570
+ self .input_socket .send_multipart (msg , copy = False )
571
+ return
572
+
573
+ tracker = self .input_socket .send_multipart (msg , copy = False , track = True )
574
+ self .add_pending_message (tracker , request )
551
575
552
576
def call_utility (self , method : str , * args ) -> Any :
553
577
call_id = uuid .uuid1 ().int >> 64
@@ -698,19 +722,38 @@ async def get_output_async(self) -> EngineCoreOutputs:
698
722
def _send_input (self ,
699
723
request_type : EngineCoreRequestType ,
700
724
request : Any ,
701
- engine : Optional [CoreEngine ] = None ) -> Awaitable [None ]:
725
+ engine : Optional [CoreEngine ] = None ) -> Awaitable [Any ]:
702
726
self .ensure_alive ()
703
727
if engine is None :
704
728
engine = self .core_engine
705
729
706
730
message = (request_type .value , * self .encoder .encode (request ))
707
- return self ._send_input_message (message , engine )
708
-
709
- def _send_input_message (self , message : tuple [bytestr , ...],
710
- engine : CoreEngine ) -> Awaitable [None ]:
731
+ return self ._send_input_message (message , engine , request )
732
+
733
+ def _send_input_message (self , message : tuple [bytestr ,
734
+ ...], engine : CoreEngine ,
735
+ objects : Any ) -> Awaitable [Any ]:
736
+ """
737
+ objects is a reference to retain until zmq is finished with the
738
+ buffers, in case they were extracted from tensors in the request.
739
+ """
711
740
self .ensure_alive ()
712
- message = (engine .identity , ) + message
713
- return self .input_socket .send_multipart (message , copy = False )
741
+ self .free_pending_messages ()
742
+
743
+ msg = (engine .identity , ) + message
744
+ if not objects or len (msg ) <= 3 :
745
+ # No auxiliary buffers => no tensor backing buffers in request.
746
+ return self .input_socket .send_multipart (msg , copy = False )
747
+
748
+ future : asyncio .Future [zmq .MessageTracker ]
749
+ future = self .input_socket .send_multipart (msg , copy = False , track = True )
750
+
751
+ def add_pending (f : asyncio .Future [zmq .MessageTracker ]):
752
+ with contextlib .suppress (BaseException ):
753
+ self .add_pending_message (f .result (), objects )
754
+
755
+ future .add_done_callback (add_pending )
756
+ return future
714
757
715
758
async def call_utility_async (self , method : str , * args ) -> Any :
716
759
return await self ._call_utility_async (method ,
@@ -724,7 +767,7 @@ async def _call_utility_async(self, method: str, *args,
724
767
self .utility_results [call_id ] = future
725
768
message = (EngineCoreRequestType .UTILITY .value , * self .encoder .encode (
726
769
(call_id , method , args )))
727
- await self ._send_input_message (message , engine )
770
+ await self ._send_input_message (message , engine , args )
728
771
self ._ensure_output_queue_task ()
729
772
return await future
730
773
0 commit comments