
Commit 9e723c3

Standardize the code of workflow use cases (#495)
1 parent d5da55b commit 9e723c3

File tree: 12 files changed, +304 -177 lines changed

.changeset/kind-mice-repair.md (+5)

@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Standardize the code of the workflow use case (Python)

templates/components/agents/python/deep_research/app/workflows/deep_research.py (+9 -11)

@@ -32,7 +32,6 @@


def create_workflow(
-    chat_history: Optional[List[ChatMessage]] = None,
    params: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Workflow:
@@ -45,7 +44,6 @@ def create_workflow(

    return DeepResearchWorkflow(
        index=index,
-        chat_history=chat_history,
        timeout=120.0,
    )

@@ -73,28 +71,29 @@ class DeepResearchWorkflow(Workflow):
    def __init__(
        self,
        index: BaseIndex,
-        chat_history: Optional[List[ChatMessage]] = None,
-        stream: bool = True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.index = index
        self.context_nodes = []
-        self.stream = stream
-        self.chat_history = chat_history
        self.memory = SimpleComposableMemory.from_defaults(
-            primary_memory=ChatMemoryBuffer.from_defaults(
-                chat_history=chat_history,
-            ),
+            primary_memory=ChatMemoryBuffer.from_defaults(),
        )

    @step
    async def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
        """
        Initiate the workflow: memory, tools, agent
        """
+        self.stream = ev.get("stream", True)
+        self.user_request = ev.get("user_msg")
+        chat_history = ev.get("chat_history")
+        if chat_history is not None:
+            self.memory.put_messages(chat_history)
+
        await ctx.set("total_questions", 0)
-        self.user_request = ev.get("input")
+
+        # Add user message to memory
        self.memory.put_messages(
            messages=[
                ChatMessage(
@@ -319,7 +318,6 @@ async def report(self, ctx: Context, ev: ReportEvent) -> StopEvent:
        """
        Report the answers
        """
-        logger.info("Writing the report")
        res = await write_report(
            memory=self.memory,
            user_request=self.user_request,
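The standardization moves per-request state (user message, chat history, stream preference) off the constructor and onto the StartEvent, so a workflow instance no longer has to be rebuilt per conversation. A minimal sketch of calling the reworked workflow under the new contract; only create_workflow and the event field names come from the diff, the asyncio wrapper and message contents are illustrative:

import asyncio

from llama_index.core.base.llms.types import ChatMessage, MessageRole

from app.workflows.deep_research import create_workflow


async def main() -> None:
    workflow = create_workflow()  # per-request state is no longer passed here
    result = await workflow.run(
        user_msg="What are the key findings in the uploaded papers?",
        chat_history=[  # read via ev.get("chat_history") in retrieve()
            ChatMessage(role=MessageRole.USER, content="Hello"),
            ChatMessage(role=MessageRole.ASSISTANT, content="Hi, upload a paper to begin."),
        ],
        stream=False,  # read via ev.get("stream", True)
    )
    print(result)


asyncio.run(main())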

templates/components/agents/python/financial_report/app/workflows/financial_report.py (+24 -21)

@@ -1,13 +1,5 @@
from typing import Any, Dict, List, Optional

-from app.engine.index import IndexConfig, get_index
-from app.engine.tools import ToolFactory
-from app.engine.tools.query_engine import get_query_engine_tool
-from app.workflows.events import AgentRunEvent
-from app.workflows.tools import (
-    call_tools,
-    chat_with_tools,
-)
from llama_index.core import Settings
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.llms.function_calling import FunctionCallingLLM
@@ -22,9 +14,17 @@
    step,
)

+from app.engine.index import IndexConfig, get_index
+from app.engine.tools import ToolFactory
+from app.engine.tools.query_engine import get_query_engine_tool
+from app.workflows.events import AgentRunEvent
+from app.workflows.tools import (
+    call_tools,
+    chat_with_tools,
+)
+

def create_workflow(
-    chat_history: Optional[List[ChatMessage]] = None,
    params: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Workflow:
@@ -45,7 +45,6 @@ def create_workflow(
        query_engine_tool=query_engine_tool,
        code_interpreter_tool=code_interpreter_tool,
        document_generator_tool=document_generator_tool,
-        chat_history=chat_history,
    )

@@ -91,6 +90,7 @@ class FinancialReportWorkflow(Workflow):
    It's good to using appropriate tools for the user request and always use the information from the tools, don't make up anything yourself.
    For the query engine tool, you should break down the user request into a list of queries and call the tool with the queries.
    """
+    stream: bool = True

    def __init__(
        self,
@@ -99,12 +99,10 @@ def __init__(
        document_generator_tool: FunctionTool,
        llm: Optional[FunctionCallingLLM] = None,
        timeout: int = 360,
-        chat_history: Optional[List[ChatMessage]] = None,
        system_prompt: Optional[str] = None,
    ):
        super().__init__(timeout=timeout)
        self.system_prompt = system_prompt or self._default_system_prompt
-        self.chat_history = chat_history or []
        self.query_engine_tool = query_engine_tool
        self.code_interpreter_tool = code_interpreter_tool
        self.document_generator_tool = document_generator_tool
@@ -122,23 +120,26 @@ def __init__(
        ]
        self.llm: FunctionCallingLLM = llm or Settings.llm
        assert isinstance(self.llm, FunctionCallingLLM)
-        self.memory = ChatMemoryBuffer.from_defaults(
-            llm=self.llm, chat_history=self.chat_history
-        )
+        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)

    @step()
    async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent:
-        ctx.data["input"] = ev.input
+        self.stream = ev.get("stream", True)
+        user_msg = ev.get("user_msg")
+        chat_history = ev.get("chat_history")
+
+        if chat_history is not None:
+            self.memory.put_messages(chat_history)
+
+        # Add user message to memory
+        self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))

        if self.system_prompt:
            system_msg = ChatMessage(
                role=MessageRole.SYSTEM, content=self.system_prompt
            )
            self.memory.put(system_msg)

-        # Add user input to memory
-        self.memory.put(ChatMessage(role=MessageRole.USER, content=ev.input))
-
        return InputEvent(input=self.memory.get())

    @step()
@@ -160,8 +161,10 @@ async def handle_llm_input(  # type: ignore
            chat_history,
        )
        if not response.has_tool_calls():
-            # If no tool call, return the response generator
-            return StopEvent(result=response.generator)
+            if self.stream:
+                return StopEvent(result=response.generator)
+            else:
+                return StopEvent(result=await response.full_response())
        # calling different tools at the same time is not supported at the moment
        # add an error message to tell the AI to process step by step
        if response.is_calling_different_tools():
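The new class-level stream flag makes the response mode a per-run choice: with stream=True the StopEvent carries the generator from chat_with_tools, with stream=False it carries the awaited full_response() text. The same branch appears in the form-filling workflow below. A hedged sketch of consuming both modes; the queries, the chunk's .delta field, and the asyncio scaffolding are assumptions, not part of the commit:

import asyncio

from app.workflows.financial_report import create_workflow


async def main() -> None:
    workflow = create_workflow()

    # stream=False: the StopEvent result is the complete response text,
    # produced by `await response.full_response()` in handle_llm_input.
    answer = await workflow.run(user_msg="Summarize Q3 revenue", stream=False)
    print(answer)

    # stream=True (the default): the result is the generator from
    # chat_with_tools; iterate it for incremental output.
    generator = await workflow.run(user_msg="Summarize Q3 revenue")
    async for chunk in generator:
        # Assumes llama_index-style streaming chunks exposing `.delta`.
        print(chunk.delta, end="", flush=True)


asyncio.run(main())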

templates/components/agents/python/form_filling/app/workflows/form_filling.py (+15 -15)

@@ -25,7 +25,6 @@


def create_workflow(
-    chat_history: Optional[List[ChatMessage]] = None,
    params: Optional[Dict[str, Any]] = None,
    **kwargs,
) -> Workflow:
@@ -45,7 +44,6 @@ def create_workflow(
        query_engine_tool=query_engine_tool,
        extractor_tool=extractor_tool,  # type: ignore
        filling_tool=filling_tool,  # type: ignore
-        chat_history=chat_history,
    )

    return workflow
@@ -88,6 +86,7 @@ class FormFillingWorkflow(Workflow):
    Only use provided data - never make up any information yourself. Fill N/A if an answer is not found.
    If there is no query engine tool or the gathered information has many N/A values indicating the questions don't match the data, respond with a warning and ask the user to upload a different file or connect to a knowledge base.
    """
+    stream: bool = True

    def __init__(
        self,
@@ -96,12 +95,10 @@ def __init__(
        filling_tool: FunctionTool,
        llm: Optional[FunctionCallingLLM] = None,
        timeout: int = 360,
-        chat_history: Optional[List[ChatMessage]] = None,
        system_prompt: Optional[str] = None,
    ):
        super().__init__(timeout=timeout)
        self.system_prompt = system_prompt or self._default_system_prompt
-        self.chat_history = chat_history or []
        self.query_engine_tool = query_engine_tool
        self.extractor_tool = extractor_tool
        self.filling_tool = filling_tool
@@ -113,26 +110,26 @@ def __init__(
        self.llm: FunctionCallingLLM = llm or Settings.llm
        if not isinstance(self.llm, FunctionCallingLLM):
            raise ValueError("FormFillingWorkflow only supports FunctionCallingLLM.")
-        self.memory = ChatMemoryBuffer.from_defaults(
-            llm=self.llm, chat_history=self.chat_history
-        )
+        self.memory = ChatMemoryBuffer.from_defaults(llm=self.llm)

    @step()
    async def start(self, ctx: Context, ev: StartEvent) -> InputEvent:
-        ctx.data["input"] = ev.input
+        self.stream = ev.get("stream", True)
+        user_msg = ev.get("user_msg", "")
+        chat_history = ev.get("chat_history", [])
+
+        if chat_history:
+            self.memory.put_messages(chat_history)
+
+        self.memory.put(ChatMessage(role=MessageRole.USER, content=user_msg))

        if self.system_prompt:
            system_msg = ChatMessage(
                role=MessageRole.SYSTEM, content=self.system_prompt
            )
            self.memory.put(system_msg)

-        user_input = ev.input
-        user_msg = ChatMessage(role=MessageRole.USER, content=user_input)
-        self.memory.put(user_msg)
-
-        chat_history = self.memory.get()
-        return InputEvent(input=chat_history)
+        return InputEvent(input=self.memory.get())

    @step()
    async def handle_llm_input(  # type: ignore
@@ -150,7 +147,10 @@ async def handle_llm_input(  # type: ignore
            chat_history,
        )
        if not response.has_tool_calls():
-            return StopEvent(result=response.generator)
+            if self.stream:
+                return StopEvent(result=response.generator)
+            else:
+                return StopEvent(result=await response.full_response())
        # calling different tools at the same time is not supported at the moment
        # add an error message to tell the AI to process step by step
        if response.is_calling_different_tools():

templates/components/multiagent/python/app/api/callbacks/__init__.py

Whitespace-only changes.

templates/components/multiagent/python/app/api/callbacks/base.py (+32)

@@ -0,0 +1,32 @@
+import logging
+from abc import ABC, abstractmethod
+from typing import Any
+
+logger = logging.getLogger("uvicorn")
+
+
+class EventCallback(ABC):
+    """
+    Base class for event callbacks during event streaming.
+    """
+
+    async def run(self, event: Any) -> Any:
+        """
+        Called for each event in the stream.
+        Default behavior: pass through the event unchanged.
+        """
+        return event
+
+    async def on_complete(self, final_response: str) -> Any:
+        """
+        Called when the stream is complete.
+        Default behavior: return None.
+        """
+        return None
+
+    @abstractmethod
+    def from_default(self, *args, **kwargs) -> "EventCallback":
+        """
+        Create a new instance of the processor from default values.
+        """
+        pass
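The base class defines two hooks: run sees every streamed event and may transform or pass it through, and on_complete fires once with the final response text. As a hedged illustration of the extension point, a hypothetical callback that logs each event unchanged; nothing here beyond the EventCallback import is part of the commit:

import logging
from typing import Any

from app.api.callbacks.base import EventCallback

logger = logging.getLogger("uvicorn")


class EventLogger(EventCallback):
    """Hypothetical callback: log each streamed event, pass it through unchanged."""

    async def run(self, event: Any) -> Any:
        logger.info("event: %s", type(event).__name__)
        return event

    @classmethod
    def from_default(cls) -> "EventLogger":
        return cls()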
templates/components/multiagent/python/app/api/callbacks/llamacloud.py (+42)

@@ -0,0 +1,42 @@
+import logging
+from typing import Any, List
+
+from fastapi import BackgroundTasks
+from llama_index.core.schema import NodeWithScore
+
+from app.api.callbacks.base import EventCallback
+
+logger = logging.getLogger("uvicorn")
+
+
+class LlamaCloudFileDownload(EventCallback):
+    """
+    Processor for handling LlamaCloud file downloads from source nodes.
+    Only work if LlamaCloud service code is available.
+    """
+
+    def __init__(self, background_tasks: BackgroundTasks):
+        self.background_tasks = background_tasks
+
+    async def run(self, event: Any) -> Any:
+        if hasattr(event, "to_response"):
+            event_response = event.to_response()
+            if event_response.get("type") == "sources" and hasattr(event, "nodes"):
+                await self._process_response_nodes(event.nodes)
+        return event
+
+    async def _process_response_nodes(self, source_nodes: List[NodeWithScore]):
+        try:
+            from app.engine.service import LLamaCloudFileService  # type: ignore
+
+            LLamaCloudFileService.download_files_from_nodes(
+                source_nodes, self.background_tasks
+            )
+        except ImportError:
+            pass
+
+    @classmethod
+    def from_default(
+        cls, background_tasks: BackgroundTasks
+    ) -> "LlamaCloudFileDownload":
+        return cls(background_tasks=background_tasks)
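Because downloads are handed to FastAPI's BackgroundTasks, the callback is meant to be constructed per request via from_default. A hedged sketch of that wiring; the endpoint and the module path app.api.callbacks.llamacloud are assumptions, since the commit page does not show this file's name:

from fastapi import BackgroundTasks, FastAPI

# Module path assumed; the commit page does not show this file's name.
from app.api.callbacks.llamacloud import LlamaCloudFileDownload

app = FastAPI()


@app.post("/api/chat")
async def chat(background_tasks: BackgroundTasks) -> dict:
    # One callback per request, bound to this request's background-task queue,
    # so node downloads never block the streamed response.
    file_downloader = LlamaCloudFileDownload.from_default(background_tasks)
    # A real handler would pass each streamed event through file_downloader.run.
    return {"callbacks": [type(file_downloader).__name__]}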
templates/components/multiagent/python/app/api/callbacks/next_question.py (+34)

@@ -0,0 +1,34 @@
+import logging
+from typing import Any
+
+from app.api.callbacks.base import EventCallback
+from app.api.routers.models import ChatData
+from app.api.services.suggestion import NextQuestionSuggestion
+
+logger = logging.getLogger("uvicorn")
+
+
+class SuggestNextQuestions(EventCallback):
+    """Processor for generating next question suggestions."""
+
+    def __init__(self, chat_data: ChatData):
+        self.chat_data = chat_data
+        self.accumulated_text = ""
+
+    async def on_complete(self, final_response: str) -> Any:
+        if final_response == "":
+            return None
+
+        questions = await NextQuestionSuggestion.suggest_next_questions(
+            self.chat_data.messages, final_response
+        )
+        if questions:
+            return {
+                "type": "suggested_questions",
+                "data": questions,
+            }
+        return None
+
+    @classmethod
+    def from_default(cls, chat_data: ChatData) -> "SuggestNextQuestions":
+        return cls(chat_data=chat_data)
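Taken together, the callbacks form a pipeline over the event stream: every event flows through each callback's run, and each on_complete result (such as the suggested_questions payload) is emitted after the stream finishes. A hedged sketch of a driver loop for that contract; the function and its str-delta convention are illustrative, as the real stream handler is not shown in this commit:

from typing import Any, AsyncGenerator, List

from app.api.callbacks.base import EventCallback


async def run_with_callbacks(
    events: AsyncGenerator[Any, None],
    callbacks: List[EventCallback],
) -> AsyncGenerator[Any, None]:
    """Hypothetical driver: pipe events through callbacks, then fan out
    each callback's on_complete result once the stream ends."""
    final_response = ""
    async for event in events:
        for callback in callbacks:
            event = await callback.run(event)
        if isinstance(event, str):
            final_response += event  # illustrative: treat str events as text deltas
        yield event
    for callback in callbacks:
        extra = await callback.on_complete(final_response)
        if extra is not None:
            yield extra  # e.g. {"type": "suggested_questions", ...}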
