1
+ import json
1
2
from langchain_core .messages import BaseMessage , AIMessage , AIMessageChunk , ToolMessage
2
3
from rich .console import Console , ConsoleDimensions
3
4
from rich .live import Live
4
5
from rich .markdown import Markdown
5
6
from rich .prompt import Confirm
6
7
7
8
class OutputHandler :
8
- def __init__ (self , text_only : bool = False ):
9
+ def __init__ (self , text_only : bool = False , only_last_message : bool = False ):
9
10
self .console = Console ()
10
11
self .text_only = text_only
12
+ self .only_last_message = only_last_message
13
+ self .last_message = ""
11
14
if self .text_only :
12
15
self .md = ""
13
16
else :
@@ -26,6 +29,9 @@ def start(self):
26
29
27
30
def update (self , chunk : any ):
28
31
self .md = self ._parse_chunk (chunk , self .md )
32
+ if (self .only_last_message and self .text_only ):
33
+ # when only_last_message, we print in finish()
34
+ return
29
35
if self .text_only :
30
36
self .console .print (self ._parse_chunk (chunk ), end = "" )
31
37
else :
@@ -36,9 +42,13 @@ def update(self, chunk: any):
36
42
37
43
def update_error (self , error : Exception ):
38
44
import traceback
39
- self .md += f"Error: { error } \n \n Stack trace:\n ```\n { traceback .format_exc ()} ```"
45
+ error = f"Error: { error } \n \n Stack trace:\n ```\n { traceback .format_exc ()} ```"
46
+ self .md += error ;
47
+ if (self .only_last_message ):
48
+ self .console .print (error )
49
+ return ;
40
50
if self .text_only :
41
- self .console .print (self .md )
51
+ self .console .print_exception (self .md )
42
52
else :
43
53
partial_md = self ._truncate_md_to_fit (self .md , self .console .size )
44
54
self ._live .update (Markdown (partial_md ), refresh = True )
@@ -63,9 +73,12 @@ def confirm_tool_call(self, config: dict, chunk: any) -> bool:
63
73
64
74
def finish (self ):
65
75
self .stop ()
66
- if not self .text_only :
76
+ to_print = self .last_message if self .only_last_message else Markdown (self .md )
77
+ if not self .text_only and not self .only_last_message :
67
78
self .console .clear ()
68
79
self .console .print (Markdown (self .md ))
80
+ if self .only_last_message :
81
+ self .console .print (to_print )
69
82
70
83
def _parse_chunk (self , chunk : any , md : str = "" ) -> str :
71
84
"""
@@ -77,6 +90,7 @@ def _parse_chunk(self, chunk: any, md: str = "") -> str:
77
90
message_chunk = chunk [1 ][0 ] # Get the message content
78
91
if isinstance (message_chunk , AIMessageChunk ):
79
92
content = message_chunk .content
93
+ self .last_message += content
80
94
if isinstance (content , str ):
81
95
md += content
82
96
elif isinstance (content , list ) and len (content ) > 0 and isinstance (content [0 ], dict ) and "text" in content [0 ]:
@@ -85,6 +99,7 @@ def _parse_chunk(self, chunk: any, md: str = "") -> str:
85
99
elif isinstance (chunk , dict ) and "messages" in chunk :
86
100
# Print a newline after the complete message
87
101
md += "\n "
102
+ self .last_message = ""
88
103
elif isinstance (chunk , tuple ) and chunk [0 ] == "values" :
89
104
message : BaseMessage = chunk [1 ]['messages' ][- 1 ]
90
105
if isinstance (message , AIMessage ) and message .tool_calls :
@@ -108,6 +123,7 @@ def _parse_chunk(self, chunk: any, md: str = "") -> str:
108
123
lines .append (f"{ arg } : { value } " )
109
124
lines .append ("```\n " )
110
125
md += "\n " .join (lines )
126
+ self .last_message = ""
111
127
elif isinstance (message , ToolMessage ) and message .status != "success" :
112
128
md += "Failed call with error:"
113
129
md += f"\n \n { message .content } "
0 commit comments