@@ -385,6 +385,9 @@ def __init__(self, completekey='tab', stdin=None, stdout=None, skip=None,
385385self .commands_bnum = None # The breakpoint number for which we are
386386# defining a list
387387
388+ self .async_shim_frame = None
389+ self .async_awaitable = None
390+
388391self ._chained_exceptions = tuple ()
389392self ._chained_exception_index = 0
390393
@@ -400,6 +403,57 @@ def set_trace(self, frame=None, *, commands=None):
400403
401404super ().set_trace (frame )
402405
406+ async def set_trace_async (self ,frame = None ,* ,commands = None ):
407+ if self .async_awaitable is not None :
408+ # We are already in a set_trace_async call, do not mess with it
409+ return
410+
411+ if frame is None :
412+ frame = sys ._getframe ().f_back
413+
414+ # We need set_trace to set up the basics, however, this will call
415+ # set_stepinstr() will we need to compensate for, because we don't
416+ # want to trigger on calls
417+ self .set_trace (frame ,commands = commands )
418+ # Changing the stopframe will disable trace dispatch on calls
419+ self .stopframe = frame
420+ # We need to stop tracing because we don't have the privilege to avoid
421+ # triggering tracing functions as normal, as we are not already in
422+ # tracing functions
423+ self .stop_trace ()
424+
425+ self .async_shim_frame = sys ._getframe ()
426+ self .async_awaitable = None
427+
428+ while True :
429+ self .async_awaitable = None
430+ # Simulate a trace event
431+ # This should bring up pdb and make pdb believe it's debugging the
432+ # caller frame
433+ self .trace_dispatch (frame ,"opcode" ,None )
434+ if self .async_awaitable is not None :
435+ try :
436+ if self .breaks :
437+ with self .set_enterframe (frame ):
438+ # set_continue requires enterframe to work
439+ self .set_continue ()
440+ self .start_trace ()
441+ await self .async_awaitable
442+ except Exception :
443+ self ._error_exc ()
444+ else :
445+ break
446+
447+ self .async_shim_frame = None
448+
449+ # start the trace (the actual command is already set by set_* calls)
450+ if self .returnframe is None and self .stoplineno == - 1 and not self .breaks :
451+ # This means we did a continue without any breakpoints, we should not
452+ # start the trace
453+ return
454+
455+ self .start_trace ()
456+
403457def sigint_handler (self ,signum ,frame ):
404458if self .allow_kbdint :
405459raise KeyboardInterrupt
@@ -782,12 +836,25 @@ def _exec_in_closure(self, source, globals, locals):
782836
783837return True
784838
785- def default (self ,line ):
786- if line [:1 ]== '!' :line = line [1 :].strip ()
787- locals = self .curframe .f_locals
788- globals = self .curframe .f_globals
839+ def _exec_await (self ,source ,globals ,locals ):
840+ """ Run source code that contains await by playing with async shim frame"""
841+ # Put the source in an async function
842+ source_async = (
843+ "async def __pdb_await():\n " +
844+ textwrap .indent (source ," " )+ '\n ' +
845+ " __pdb_locals.update(locals())"
846+ )
847+ ns = globals | locals
848+ # We use __pdb_locals to do write back
849+ ns ["__pdb_locals" ]= locals
850+ exec (source_async ,ns )
851+ self .async_awaitable = ns ["__pdb_await" ]()
852+
853+ def _read_code (self ,line ):
854+ buffer = line
855+ is_await_code = False
856+ code = None
789857try :
790- buffer = line
791858if (code := codeop .compile_command (line + '\n ' ,'<stdin>' ,'single' ))is None :
792859# Multi-line mode
793860with self ._enable_multiline_completion ():
@@ -800,7 +867,7 @@ def default(self, line):
800867except (EOFError ,KeyboardInterrupt ):
801868self .lastcmd = ""
802869print ('\n ' )
803- return
870+ return None , None , False
804871else :
805872self .stdout .write (continue_prompt )
806873self .stdout .flush ()
@@ -809,20 +876,44 @@ def default(self, line):
809876self .lastcmd = ""
810877self .stdout .write ('\n ' )
811878self .stdout .flush ()
812- return
879+ return None , None , False
813880else :
814881line = line .rstrip ('\r \n ' )
815882buffer += '\n ' + line
816883self .lastcmd = buffer
884+ except SyntaxError as e :
885+ # Maybe it's an await expression/statement
886+ if (
887+ self .async_shim_frame is not None
888+ and e .msg == "'await' outside function"
889+ ):
890+ is_await_code = True
891+ else :
892+ raise
893+
894+ return code ,buffer ,is_await_code
895+
896+ def default (self ,line ):
897+ if line [:1 ]== '!' :line = line [1 :].strip ()
898+ locals = self .curframe .f_locals
899+ globals = self .curframe .f_globals
900+ try :
901+ code ,buffer ,is_await_code = self ._read_code (line )
902+ if buffer is None :
903+ return
817904save_stdout = sys .stdout
818905save_stdin = sys .stdin
819906save_displayhook = sys .displayhook
820907try :
821908sys .stdin = self .stdin
822909sys .stdout = self .stdout
823910sys .displayhook = self .displayhook
824- if not self ._exec_in_closure (buffer ,globals ,locals ):
825- exec (code ,globals ,locals )
911+ if is_await_code :
912+ self ._exec_await (buffer ,globals ,locals )
913+ return True
914+ else :
915+ if not self ._exec_in_closure (buffer ,globals ,locals ):
916+ exec (code ,globals ,locals )
826917finally :
827918sys .stdout = save_stdout
828919sys .stdin = save_stdin
@@ -2501,6 +2592,21 @@ def set_trace(*, header=None, commands=None):
25012592pdb .message (header )
25022593pdb .set_trace (sys ._getframe ().f_back ,commands = commands )
25032594
2595+ async def set_trace_async (* ,header = None ,commands = None ):
2596+ """Enter the debugger at the calling stack frame, but in async mode.
2597+
2598+ This should be used as await pdb.set_trace_async(). Users can do await
2599+ if they enter the debugger with this function. Otherwise it's the same
2600+ as set_trace().
2601+ """
2602+ if Pdb ._last_pdb_instance is not None :
2603+ pdb = Pdb ._last_pdb_instance
2604+ else :
2605+ pdb = Pdb (mode = 'inline' ,backend = 'monitoring' )
2606+ if header is not None :
2607+ pdb .message (header )
2608+ await pdb .set_trace_async (sys ._getframe ().f_back ,commands = commands )
2609+
25042610# Remote PDB
25052611
25062612class _PdbServer (Pdb ):