@@ -385,6 +385,9 @@ def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None,
385
385
self .commands_bnum = None # The breakpoint number for which we are
386
386
# defining a list
387
387
388
+ self .async_shim_frame = None
389
+ self .async_awaitable = None
390
+
388
391
self ._chained_exceptions = tuple ()
389
392
self ._chained_exception_index = 0
390
393
@@ -400,6 +403,57 @@ def set_trace(self, frame=None, *, commands=None):
400
403
401
404
super ().set_trace (frame )
402
405
406
+ async def set_trace_async (self , frame = None , * , commands = None ):
407
+ if self .async_awaitable is not None :
408
+ # We are already in a set_trace_async call, do not mess with it
409
+ return
410
+
411
+ if frame is None :
412
+ frame = sys ._getframe ().f_back
413
+
414
+ # We need set_trace to set up the basics, however, this will call
415
+ # set_stepinstr() will we need to compensate for, because we don't
416
+ # want to trigger on calls
417
+ self .set_trace (frame , commands = commands )
418
+ # Changing the stopframe will disable trace dispatch on calls
419
+ self .stopframe = frame
420
+ # We need to stop tracing because we don't have the privilege to avoid
421
+ # triggering tracing functions as normal, as we are not already in
422
+ # tracing functions
423
+ self .stop_trace ()
424
+
425
+ self .async_shim_frame = sys ._getframe ()
426
+ self .async_awaitable = None
427
+
428
+ while True :
429
+ self .async_awaitable = None
430
+ # Simulate a trace event
431
+ # This should bring up pdb and make pdb believe it's debugging the
432
+ # caller frame
433
+ self .trace_dispatch (frame , "opcode" , None )
434
+ if self .async_awaitable is not None :
435
+ try :
436
+ if self .breaks :
437
+ with self .set_enterframe (frame ):
438
+ # set_continue requires enterframe to work
439
+ self .set_continue ()
440
+ self .start_trace ()
441
+ await self .async_awaitable
442
+ except Exception :
443
+ self ._error_exc ()
444
+ else :
445
+ break
446
+
447
+ self .async_shim_frame = None
448
+
449
+ # start the trace (the actual command is already set by set_* calls)
450
+ if self .returnframe is None and self .stoplineno == - 1 and not self .breaks :
451
+ # This means we did a continue without any breakpoints, we should not
452
+ # start the trace
453
+ return
454
+
455
+ self .start_trace ()
456
+
403
457
def sigint_handler (self , signum , frame ):
404
458
if self .allow_kbdint :
405
459
raise KeyboardInterrupt
@@ -782,12 +836,25 @@ def _exec_in_closure(self, source, globals, locals):
782
836
783
837
return True
784
838
785
- def default (self , line ):
786
- if line [:1 ] == '!' : line = line [1 :].strip ()
787
- locals = self .curframe .f_locals
788
- globals = self .curframe .f_globals
839
+ def _exec_await (self , source , globals , locals ):
840
+ """ Run source code that contains await by playing with async shim frame"""
841
+ # Put the source in an async function
842
+ source_async = (
843
+ "async def __pdb_await():\n " +
844
+ textwrap .indent (source , " " ) + '\n ' +
845
+ " __pdb_locals.update(locals())"
846
+ )
847
+ ns = globals | locals
848
+ # We use __pdb_locals to do write back
849
+ ns ["__pdb_locals" ] = locals
850
+ exec (source_async , ns )
851
+ self .async_awaitable = ns ["__pdb_await" ]()
852
+
853
+ def _read_code (self , line ):
854
+ buffer = line
855
+ is_await_code = False
856
+ code = None
789
857
try :
790
- buffer = line
791
858
if (code := codeop .compile_command (line + '\n ' , '<stdin>' , 'single' )) is None :
792
859
# Multi-line mode
793
860
with self ._enable_multiline_completion ():
@@ -800,7 +867,7 @@ def default(self, line):
800
867
except (EOFError , KeyboardInterrupt ):
801
868
self .lastcmd = ""
802
869
print ('\n ' )
803
- return
870
+ return None , None , False
804
871
else :
805
872
self .stdout .write (continue_prompt )
806
873
self .stdout .flush ()
@@ -809,20 +876,44 @@ def default(self, line):
809
876
self .lastcmd = ""
810
877
self .stdout .write ('\n ' )
811
878
self .stdout .flush ()
812
- return
879
+ return None , None , False
813
880
else :
814
881
line = line .rstrip ('\r \n ' )
815
882
buffer += '\n ' + line
816
883
self .lastcmd = buffer
884
+ except SyntaxError as e :
885
+ # Maybe it's an await expression/statement
886
+ if (
887
+ self .async_shim_frame is not None
888
+ and e .msg == "'await' outside function"
889
+ ):
890
+ is_await_code = True
891
+ else :
892
+ raise
893
+
894
+ return code , buffer , is_await_code
895
+
896
+ def default (self , line ):
897
+ if line [:1 ] == '!' : line = line [1 :].strip ()
898
+ locals = self .curframe .f_locals
899
+ globals = self .curframe .f_globals
900
+ try :
901
+ code , buffer , is_await_code = self ._read_code (line )
902
+ if buffer is None :
903
+ return
817
904
save_stdout = sys .stdout
818
905
save_stdin = sys .stdin
819
906
save_displayhook = sys .displayhook
820
907
try :
821
908
sys .stdin = self .stdin
822
909
sys .stdout = self .stdout
823
910
sys .displayhook = self .displayhook
824
- if not self ._exec_in_closure (buffer , globals , locals ):
825
- exec (code , globals , locals )
911
+ if is_await_code :
912
+ self ._exec_await (buffer , globals , locals )
913
+ return True
914
+ else :
915
+ if not self ._exec_in_closure (buffer , globals , locals ):
916
+ exec (code , globals , locals )
826
917
finally :
827
918
sys .stdout = save_stdout
828
919
sys .stdin = save_stdin
@@ -2501,6 +2592,21 @@ def set_trace(*, header=None, commands=None):
2501
2592
pdb .message (header )
2502
2593
pdb .set_trace (sys ._getframe ().f_back , commands = commands )
2503
2594
2595
+ async def set_trace_async (* , header = None , commands = None ):
2596
+ """Enter the debugger at the calling stack frame, but in async mode.
2597
+
2598
+ This should be used as await pdb.set_trace_async(). Users can do await
2599
+ if they enter the debugger with this function. Otherwise it's the same
2600
+ as set_trace().
2601
+ """
2602
+ if Pdb ._last_pdb_instance is not None :
2603
+ pdb = Pdb ._last_pdb_instance
2604
+ else :
2605
+ pdb = Pdb (mode = 'inline' , backend = 'monitoring' )
2606
+ if header is not None :
2607
+ pdb .message (header )
2608
+ await pdb .set_trace_async (sys ._getframe ().f_back , commands = commands )
2609
+
2504
2610
# Remote PDB
2505
2611
2506
2612
class _PdbServer (Pdb ):
0 commit comments