
Commit 7615c7b (parent e686342)

Rename to use BaseChatMessage and BaseAgentEvent. Bring back union types. (#6144)
Rename the `ChatMessage` and `AgentEvent` base classes to `BaseChatMessage` and `BaseAgentEvent`. Bring back `ChatMessage` and `AgentEvent` as unions of the built-in concrete types, to avoid breaking existing applications that depend on Pydantic serialization.

Why? Much existing code uses containers like this:

```python
class AppMessage(BaseModel):
    name: str
    message: ChatMessage

# Serialization is this:
m = AppMessage(...)
m.model_dump_json()
# Fields like HandoffMessage.target will be lost, because message is now
# treated as a base class without content or target fields.
```

The assumption that `ChatMessage` and `AgentEvent` are unions of concrete types may be baked into many existing code bases. So this PR brings back the union types, while keeping method type hints, such as those on `on_messages`, on the `BaseChatMessage` and `BaseAgentEvent` base classes for flexibility.
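For illustration, a minimal sketch of the round trip this restores, assuming the post-rename `autogen_agentchat.messages` module where `ChatMessage` is again a union of the built-in concrete types; the `AppMessage` container mirrors the one in the description above, and the field values are placeholders:

```python
from pydantic import BaseModel

# ChatMessage is the restored union of concrete message types.
from autogen_agentchat.messages import ChatMessage, HandoffMessage


class AppMessage(BaseModel):
    name: str
    message: ChatMessage  # union field: concrete subtype fields survive serialization


m = AppMessage(
    name="router",  # illustrative values
    message=HandoffMessage(source="planner", target="coder", content="take over"),
)
restored = AppMessage.model_validate_json(m.model_dump_json())

# Because ChatMessage is a union rather than the base class,
# HandoffMessage.target round-trips intact.
assert isinstance(restored.message, HandoffMessage)
assert restored.message.target == "coder"
```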


49 files changed (+1533 -1443 lines)

python/packages/agbench/benchmarks/GAIA/Templates/SelectorGroupChat/scenario.py (+2 -2)

```diff
@@ -16,7 +16,7 @@
 from autogen_ext.agents.web_surfer import MultimodalWebSurfer
 from autogen_ext.agents.file_surfer import FileSurfer
 from autogen_agentchat.agents import CodeExecutorAgent
-from autogen_agentchat.messages import TextMessage, AgentEvent, ChatMessage, HandoffMessage, MultiModalMessage, StopMessage
+from autogen_agentchat.messages import TextMessage, BaseAgentEvent, BaseChatMessage, HandoffMessage, MultiModalMessage, StopMessage
 from autogen_core.models import LLMMessage, UserMessage, AssistantMessage
 
 # Suppress warnings about the requests.Session() not being closed
@@ -141,7 +141,7 @@ def __init__(self, prompt: str, model_client: ChatCompletionClient, termination_
     def terminated(self) -> bool:
         return self._terminated
 
-    async def __call__(self, messages: Sequence[AgentEvent | ChatMessage]) -> StopMessage | None:
+    async def __call__(self, messages: Sequence[BaseAgentEvent | BaseChatMessage]) -> StopMessage | None:
         if self._terminated:
             raise TerminatedException("Termination condition has already been reached")
 
```

python/packages/agbench/src/agbench/linter/__init__.py (+1 -1)

```diff
@@ -1,4 +1,4 @@
 # __init__.py
-from ._base import Code, Document, CodedDocument, BaseQualitativeCoder
+from ._base import BaseQualitativeCoder, Code, CodedDocument, Document
 
 __all__ = ["Code", "Document", "CodedDocument", "BaseQualitativeCoder"]
```

python/packages/agbench/src/agbench/linter/_base.py (+3 -2)

```diff
@@ -1,7 +1,8 @@
-import json
 import hashlib
+import json
 import re
-from typing import Protocol, List, Set, Optional
+from typing import List, Optional, Protocol, Set
+
 from pydantic import BaseModel, Field
 
 
```

python/packages/agbench/src/agbench/linter/cli.py (+5 -3)

```diff
@@ -1,8 +1,10 @@
-import os
 import argparse
-from typing import List, Sequence, Optional
+import os
+from typing import List, Optional, Sequence
+
 from openai import OpenAI
-from ._base import Document, CodedDocument
+
+from ._base import CodedDocument, Document
 from .coders.oai_coder import OAIQualitativeCoder
 
 
```

python/packages/agbench/src/agbench/linter/coders/oai_coder.py (+3 -5)

```diff
@@ -1,13 +1,11 @@
 import os
 import re
-
-from typing import List, Set, Optional
-from pydantic import BaseModel
+from typing import List, Optional, Set
 
 from openai import OpenAI
+from pydantic import BaseModel
 
-from .._base import CodedDocument, Document, Code
-from .._base import BaseQualitativeCoder
+from .._base import BaseQualitativeCoder, Code, CodedDocument, Document
 
 
 class CodeList(BaseModel):
```

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py (+14 -14)

```diff
@@ -40,8 +40,8 @@
 from ..base import Handoff as HandoffBase
 from ..base import Response
 from ..messages import (
-    AgentEvent,
-    ChatMessage,
+    BaseAgentEvent,
+    BaseChatMessage,
     HandoffMessage,
     MemoryQueryEvent,
     ModelClientStreamingChunkEvent,
@@ -697,8 +697,8 @@ def __init__(
         self._is_running = False
 
     @property
-    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
-        message_types: List[type[ChatMessage]] = [TextMessage]
+    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
+        message_types: List[type[BaseChatMessage]] = [TextMessage]
         if self._handoffs:
             message_types.append(HandoffMessage)
         if self._tools:
@@ -712,15 +712,15 @@ def model_context(self) -> ChatCompletionContext:
         """
         return self._model_context
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
         async for message in self.on_messages_stream(messages, cancellation_token):
             if isinstance(message, Response):
                 return message
         raise AssertionError("The stream should have returned the final result.")
 
     async def on_messages_stream(
-        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
-    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
+        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
         """
         Process the incoming messages with the assistant agent and yield events/responses as they happen.
         """
@@ -745,7 +745,7 @@ async def on_messages_stream(
         )
 
         # STEP 2: Update model context with any relevant memory
-        inner_messages: List[AgentEvent | ChatMessage] = []
+        inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
         for event_msg in await self._update_model_context_with_memory(
             memory=memory,
             model_context=model_context,
@@ -810,7 +810,7 @@ async def on_messages_stream(
     @staticmethod
     async def _add_messages_to_context(
         model_context: ChatCompletionContext,
-        messages: Sequence[ChatMessage],
+        messages: Sequence[BaseChatMessage],
     ) -> None:
         """
         Add incoming messages to the model context.
@@ -886,7 +886,7 @@ async def _call_llm(
     async def _process_model_result(
         cls,
         model_result: CreateResult,
-        inner_messages: List[AgentEvent | ChatMessage],
+        inner_messages: List[BaseAgentEvent | BaseChatMessage],
         cancellation_token: CancellationToken,
         agent_name: str,
         system_messages: List[SystemMessage],
@@ -898,7 +898,7 @@ async def _process_model_result(
         model_client_stream: bool,
         reflect_on_tool_use: bool,
         tool_call_summary_format: str,
-    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
+    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
         """
         Handle final or partial responses from model_result, including tool calls, handoffs,
         and reflection if needed.
@@ -992,7 +992,7 @@ async def _process_model_result(
     def _check_and_handle_handoff(
         model_result: CreateResult,
         executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]],
-        inner_messages: List[AgentEvent | ChatMessage],
+        inner_messages: List[BaseAgentEvent | BaseChatMessage],
         handoffs: Dict[str, HandoffBase],
         agent_name: str,
     ) -> Optional[Response]:
@@ -1061,7 +1061,7 @@ async def _reflect_on_tool_use_flow(
         model_client_stream: bool,
         model_context: ChatCompletionContext,
         agent_name: str,
-        inner_messages: List[AgentEvent | ChatMessage],
+        inner_messages: List[BaseAgentEvent | BaseChatMessage],
     ) -> AsyncGenerator[Response | ModelClientStreamingChunkEvent | ThoughtEvent, None]:
         """
         If reflect_on_tool_use=True, we do another inference based on tool results
@@ -1113,7 +1113,7 @@ async def _reflect_on_tool_use_flow(
     @staticmethod
     def _summarize_tool_use(
         executed_calls_and_results: List[Tuple[FunctionCall, FunctionExecutionResult]],
-        inner_messages: List[AgentEvent | ChatMessage],
+        inner_messages: List[BaseAgentEvent | BaseChatMessage],
         handoffs: Dict[str, HandoffBase],
         tool_call_summary_format: str,
         agent_name: str,
```
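Downstream code that consumes the stream is unaffected by the widened hints: items are still either events, messages, or the final `Response`. A hedged sketch of such a consumer, assuming a configured `OpenAIChatCompletionClient` (the model name and task text are placeholders):

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    agent = AssistantAgent(
        "assistant",
        model_client=OpenAIChatCompletionClient(model="gpt-4o"),  # placeholder model
    )
    stream = agent.on_messages_stream(
        [TextMessage(content="Summarize the rename.", source="user")],
        CancellationToken(),
    )
    async for item in stream:
        if isinstance(item, Response):
            # Final response; item.chat_message is a concrete BaseChatMessage subtype.
            print(item.chat_message)
        else:
            # Intermediate BaseAgentEvent | BaseChatMessage items.
            print(type(item).__name__)


asyncio.run(main())
```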

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py (+18 -18)

```diff
@@ -6,8 +6,8 @@
 
 from ..base import ChatAgent, Response, TaskResult
 from ..messages import (
-    AgentEvent,
-    ChatMessage,
+    BaseAgentEvent,
+    BaseChatMessage,
     ModelClientStreamingChunkEvent,
     TextMessage,
 )
@@ -59,13 +59,13 @@ def description(self) -> str:
 
     @property
     @abstractmethod
-    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
+    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
         """The types of messages that the agent produces in the
-        :attr:`Response.chat_message` field. They must be :class:`ChatMessage` types."""
+        :attr:`Response.chat_message` field. They must be :class:`BaseChatMessage` types."""
         ...
 
     @abstractmethod
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
         """Handles incoming messages and returns a response.
 
         .. note::
@@ -81,8 +81,8 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
         ...
 
     async def on_messages_stream(
-        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
-    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
+        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
         """Handles incoming messages and returns a stream of messages and
         and the final item is the response. The base implementation in
         :class:`BaseChatAgent` simply calls :meth:`on_messages` and yields
@@ -106,29 +106,29 @@ async def on_messages_stream(
     async def run(
         self,
         *,
-        task: str | ChatMessage | Sequence[ChatMessage] | None = None,
+        task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
         cancellation_token: CancellationToken | None = None,
     ) -> TaskResult:
         """Run the agent with the given task and return the result."""
         if cancellation_token is None:
             cancellation_token = CancellationToken()
-        input_messages: List[ChatMessage] = []
-        output_messages: List[AgentEvent | ChatMessage] = []
+        input_messages: List[BaseChatMessage] = []
+        output_messages: List[BaseAgentEvent | BaseChatMessage] = []
         if task is None:
             pass
         elif isinstance(task, str):
             text_msg = TextMessage(content=task, source="user")
             input_messages.append(text_msg)
             output_messages.append(text_msg)
-        elif isinstance(task, ChatMessage):
+        elif isinstance(task, BaseChatMessage):
             input_messages.append(task)
             output_messages.append(task)
         else:
             if not task:
                 raise ValueError("Task list cannot be empty.")
             # Task is a sequence of messages.
             for msg in task:
-                if isinstance(msg, ChatMessage):
+                if isinstance(msg, BaseChatMessage):
                     input_messages.append(msg)
                     output_messages.append(msg)
                 else:
@@ -142,31 +142,31 @@ async def run(
     async def run_stream(
         self,
         *,
-        task: str | ChatMessage | Sequence[ChatMessage] | None = None,
+        task: str | BaseChatMessage | Sequence[BaseChatMessage] | None = None,
         cancellation_token: CancellationToken | None = None,
-    ) -> AsyncGenerator[AgentEvent | ChatMessage | TaskResult, None]:
+    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | TaskResult, None]:
         """Run the agent with the given task and return a stream of messages
         and the final task result as the last item in the stream."""
         if cancellation_token is None:
             cancellation_token = CancellationToken()
-        input_messages: List[ChatMessage] = []
-        output_messages: List[AgentEvent | ChatMessage] = []
+        input_messages: List[BaseChatMessage] = []
+        output_messages: List[BaseAgentEvent | BaseChatMessage] = []
         if task is None:
             pass
         elif isinstance(task, str):
             text_msg = TextMessage(content=task, source="user")
             input_messages.append(text_msg)
             output_messages.append(text_msg)
             yield text_msg
-        elif isinstance(task, ChatMessage):
+        elif isinstance(task, BaseChatMessage):
             input_messages.append(task)
             output_messages.append(task)
             yield task
         else:
             if not task:
                 raise ValueError("Task list cannot be empty.")
             for msg in task:
-                if isinstance(msg, ChatMessage):
+                if isinstance(msg, BaseChatMessage):
                     input_messages.append(msg)
                     output_messages.append(msg)
                     yield msg
```
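Because `on_messages` and `produced_message_types` are hinted with the base classes, custom agents subclass exactly as before, just with the new names. A minimal sketch under that assumption; the `EchoAgent` class is hypothetical, and it uses the refactored message API's `to_text()` helper:

```python
from typing import Sequence

from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import Response
from autogen_agentchat.messages import BaseChatMessage, TextMessage
from autogen_core import CancellationToken


class EchoAgent(BaseChatAgent):
    """Hypothetical agent that echoes the last message back to the caller."""

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    async def on_messages(
        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
    ) -> Response:
        last = messages[-1].to_text() if messages else ""
        return Response(chat_message=TextMessage(content=last, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass  # stateless, nothing to reset


# Usage: EchoAgent("echo", description="Echoes the last message.")
```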

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_code_executor_agent.py (+3 -3)

```diff
@@ -7,7 +7,7 @@
 from typing_extensions import Self
 
 from ..base import Response
-from ..messages import ChatMessage, TextMessage
+from ..messages import BaseChatMessage, TextMessage
 from ._base_chat_agent import BaseChatAgent
 
 
@@ -119,11 +119,11 @@ def __init__(
         self._sources = sources
 
     @property
-    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
+    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
         """The types of messages that the code executor agent produces."""
         return (TextMessage,)
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
         # Extract code blocks from the messages.
         code_blocks: List[CodeBlock] = []
         for msg in messages:
```

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_society_of_mind_agent.py (+8 -8)

```diff
@@ -10,8 +10,8 @@
 
 from ..base import TaskResult, Team
 from ..messages import (
-    AgentEvent,
-    ChatMessage,
+    BaseAgentEvent,
+    BaseChatMessage,
     ModelClientStreamingChunkEvent,
     TextMessage,
 )
@@ -122,10 +122,10 @@ def __init__(
         self._response_prompt = response_prompt
 
     @property
-    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
+    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
         return (TextMessage,)
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
         # Call the stream method and collect the messages.
         response: Response | None = None
         async for msg in self.on_messages_stream(messages, cancellation_token):
@@ -135,14 +135,14 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
         return response
 
     async def on_messages_stream(
-        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
-    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
+        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
         # Prepare the task for the team of agents.
         task = list(messages)
 
         # Run the team of agents.
         result: TaskResult | None = None
-        inner_messages: List[AgentEvent | ChatMessage] = []
+        inner_messages: List[BaseAgentEvent | BaseChatMessage] = []
         count = 0
         async for inner_msg in self._team.run_stream(task=task, cancellation_token=cancellation_token):
             if isinstance(inner_msg, TaskResult):
@@ -167,7 +167,7 @@ async def on_messages_stream(
         # Generate a response using the model client.
         llm_messages: List[LLMMessage] = [SystemMessage(content=self._instruction)]
         for message in messages:
-            if isinstance(message, ChatMessage):
+            if isinstance(message, BaseChatMessage):
                 llm_messages.append(message.to_model_message())
         llm_messages.append(SystemMessage(content=self._response_prompt))
         completion = await self._model_client.create(messages=llm_messages, cancellation_token=cancellation_token)
```

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_user_proxy_agent.py (+6 -6)

```diff
@@ -10,7 +10,7 @@
 from typing_extensions import Self
 
 from ..base import Response
-from ..messages import AgentEvent, ChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
+from ..messages import BaseAgentEvent, BaseChatMessage, HandoffMessage, TextMessage, UserInputRequestedEvent
 from ._base_chat_agent import BaseChatAgent
 
 SyncInputFunc = Callable[[str], str]
@@ -170,11 +170,11 @@ def __init__(
         self._is_async = iscoroutinefunction(self.input_func)
 
     @property
-    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
+    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
         """Message types this agent can produce."""
         return (TextMessage, HandoffMessage)
 
-    def _get_latest_handoff(self, messages: Sequence[ChatMessage]) -> Optional[HandoffMessage]:
+    def _get_latest_handoff(self, messages: Sequence[BaseChatMessage]) -> Optional[HandoffMessage]:
         """Find the HandoffMessage in the message sequence that addresses this agent."""
         if len(messages) > 0 and isinstance(messages[-1], HandoffMessage):
             if messages[-1].target == self.name:
@@ -201,15 +201,15 @@ async def _get_input(self, prompt: str, cancellation_token: Optional[Cancellatio
         except Exception as e:
             raise RuntimeError(f"Failed to get user input: {str(e)}") from e
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
         async for message in self.on_messages_stream(messages, cancellation_token):
             if isinstance(message, Response):
                 return message
         raise AssertionError("The stream should have returned the final result.")
 
     async def on_messages_stream(
-        self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
-    ) -> AsyncGenerator[AgentEvent | ChatMessage | Response, None]:
+        self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken
+    ) -> AsyncGenerator[BaseAgentEvent | BaseChatMessage | Response, None]:
         """Handle incoming messages by requesting user input."""
         try:
             # Check for handoff first
```
