Fix: deep research use case #493

Merged · 7 commits · Jan 22, 2025

Changes from 2 commits
@@ -16,7 +16,10 @@ class AnalysisDecision(BaseModel):
         description="Whether to continue research, write a report, or cancel the research after several retries"
     )
     research_questions: Optional[List[str]] = Field(
-        description="Questions to research if continuing research. Maximum 3 questions. Set to null or empty if writing a report.",
+        description="""
+        If the decision is to research, provide a list of questions to research that relate to the user request.
+        Maximum 3 questions. Set to null or empty if writing a report or cancelling the research.
+        """,
         default_factory=list,
     )
     cancel_reason: Optional[str] = Field(
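
For reference, a standalone sketch of the shape this model takes. The decision field's exact type sits above the diff context, so the Literal values below are an assumption inferred from the res.decision == "cancel" / "write" branches handled later in this PR:

from typing import List, Literal, Optional

from pydantic import BaseModel, Field


class AnalysisDecisionSketch(BaseModel):
    # "research" is assumed; "write" and "cancel" appear in the workflow below.
    decision: Literal["research", "write", "cancel"]
    research_questions: Optional[List[str]] = Field(default_factory=list)
    cancel_reason: Optional[str] = None


print(AnalysisDecisionSketch(decision="research", research_questions=["What is X?"]))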
@@ -29,32 +32,53 @@ async def plan_research(
     memory: SimpleComposableMemory,
     context_nodes: List[Node],
     user_request: str,
+    total_questions: int,
 ) -> AnalysisDecision:
-    analyze_prompt = PromptTemplate(
-        """
+    analyze_prompt = """
     You are a professor who is guiding a researcher to research a specific request/problem.
     Your task is to decide on a research plan for the researcher.

     The possible actions are:
     + Provide a list of questions for the researcher to investigate, with the purpose of clarifying the request.
     + Write a report if the researcher has already gathered enough research on the topic and can resolve the initial request.
     + Cancel the research if most of the answers from researchers indicate there is insufficient information to research the request. Do not attempt more than 3 research iterations or too many questions.

     The workflow should be:
     + Always begin by providing some initial questions for the researcher to investigate.
     + Analyze the provided answers against the initial topic/request. If the answers are insufficient to resolve the initial request, provide additional questions for the researcher to investigate.
     + If the answers are sufficient to resolve the initial request, instruct the researcher to write a report.
-    <User request>
-    {user_request}
-    </User request>

-    Here are the context:
     <Collected information>
     {context_str}
     </Collected information>

     <Conversation context>
     {conversation_context}
     </Conversation context>

+    {enhanced_prompt}
+
+    Now, provide your decision in the required format for this user request:
+    <User request>
+    {user_request}
+    </User request>
     """
-    )
+    # Manually craft the prompt to avoid LLM hallucination
+    enhanced_prompt = ""
+    if total_questions == 0:
+        # Avoid writing a report without any research context
+        enhanced_prompt = """
+        The student has no questions to research. Let's start by asking some questions.
+        """
+    elif total_questions > 6:
+        # Avoid asking too many questions (when the data is not ready for writing a report)
+        enhanced_prompt = f"""
+        The student has already researched {total_questions} questions. Cancel the research if the collected context is not enough to write a report.
+        """

     conversation_context = "\n".join(
         [f"{message.role}: {message.content}" for message in memory.get_all()]
     )
@@ -63,10 +87,11 @@ async def plan_research(
     )
     res = await Settings.llm.astructured_predict(
         output_cls=AnalysisDecision,
-        prompt=analyze_prompt,
+        prompt=PromptTemplate(template=analyze_prompt),
         user_request=user_request,
         context_str=context_str,
         conversation_context=conversation_context,
+        enhanced_prompt=enhanced_prompt,
     )
     return res

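As a quick illustration of the call above: astructured_predict formats the template with the given keyword arguments before invoking the LLM, which is why enhanced_prompt must now be passed explicitly. A minimal sketch with toy values (not part of the PR):

from llama_index.core.prompts import PromptTemplate

toy_template = PromptTemplate(
    template="{enhanced_prompt}\n<User request>\n{user_request}\n</User request>"
)
# The same substitution astructured_predict performs internally.
print(toy_template.format(enhanced_prompt="(no extra guidance)", user_request="Summarize topic X"))
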
@@ -89,10 +89,11 @@ def __init__(
     )

     @step
-    def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
+    async def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
         """
         Initiate the workflow: memory, tools, agent
         """
+        await ctx.set("total_questions", 0)
         self.user_request = ev.get("input")
         self.memory.put_messages(
             messages=[
@@ -132,9 +133,7 @@ def retrieve(self, ctx: Context, ev: StartEvent) -> PlanResearchEvent:
                 nodes=nodes,
             )
         )
-        return PlanResearchEvent(
-            context_nodes=self.context_nodes,
-        )
+        return PlanResearchEvent()

     @step
     async def analyze(
@@ -153,10 +152,12 @@ async def analyze(
                 },
             )
         )
+        total_questions = await ctx.get("total_questions")
         res = await plan_research(
             memory=self.memory,
             context_nodes=self.context_nodes,
             user_request=self.user_request,
+            total_questions=total_questions,
         )
         if res.decision == "cancel":
             ctx.write_event_to_stream(
@@ -172,6 +173,22 @@ async def analyze(
                     result=res.cancel_reason,
                 )
         elif res.decision == "write":
+            # Writing a report without any research context is not allowed.
+            # It's an LLM hallucination.
+            if total_questions == 0:
+                ctx.write_event_to_stream(
+                    DataEvent(
+                        type="deep_research_event",
+                        data={
+                            "event": "analyze",
+                            "state": "done",
+                        },
+                    )
+                )
+                return StopEvent(
+                    result="Sorry, I ran into a problem while analyzing the retrieved information. Please try again.",
+                )
+
             self.memory.put(
                 message=ChatMessage(
                     role=MessageRole.ASSISTANT,
@@ -180,7 +197,11 @@ async def analyze(
             )
             ctx.send_event(ReportEvent())
         else:
-            await ctx.set("n_questions", len(res.research_questions))
+            total_questions += len(res.research_questions)
+            await ctx.set("total_questions", total_questions)  # Running total across iterations
+            await ctx.set(
+                "waiting_questions", len(res.research_questions)
+            )  # Questions awaiting answers in this batch
             self.memory.put(
                 message=ChatMessage(
                     role=MessageRole.ASSISTANT,
@@ -270,7 +291,7 @@ async def collect_answers(
         """
         Collect answers to all questions
         """
-        num_questions = await ctx.get("n_questions")
+        num_questions = await ctx.get("waiting_questions")
         results = ctx.collect_events(
             ev,
             expected=[CollectAnswersEvent] * num_questions,
@@ -284,7 +305,7 @@ async def collect_answers(
                 content=f"<Question>{result.question}</Question>\n<Answer>{result.answer}</Answer>",
             )
         )
-        await ctx.set("n_questions", 0)
+        await ctx.set("waiting_questions", 0)
         self.memory.put(
             message=ChatMessage(
                 role=MessageRole.ASSISTANT,
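
Taken together, the two keys replace the old n_questions counter: total_questions accumulates across research iterations, while waiting_questions tracks only the current batch and is reset once its answers are collected. A toy, self-contained sketch of that pattern (the workflow name and flow are illustrative, not the PR's):

import asyncio

from llama_index.core.workflow import Context, StartEvent, StopEvent, Workflow, step


class CounterDemo(Workflow):
    @step
    async def run_round(self, ctx: Context, ev: StartEvent) -> StopEvent:
        await ctx.set("total_questions", 0)
        total = await ctx.get("total_questions")
        new_questions = 3  # e.g., three research questions planned this round
        total += new_questions
        await ctx.set("total_questions", total)  # lifetime counter, never reset
        await ctx.set("waiting_questions", new_questions)  # current batch
        # ... answers would be collected here ...
        await ctx.set("waiting_questions", 0)  # batch done; total stays at 3
        return StopEvent(result=total)


async def main():
    print(await CounterDemo(timeout=10).run())


asyncio.run(main())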
21 changes: 17 additions & 4 deletions templates/types/streaming/fastapi/main.py

@@ -1,20 +1,23 @@
 # flake8: noqa: E402
-from app.config import DATA_DIR, STATIC_DIR
 from dotenv import load_dotenv

+from app.config import DATA_DIR, STATIC_DIR
+
 load_dotenv()

 import logging
 import os

 import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import RedirectResponse
+from fastapi.staticfiles import StaticFiles
+
 from app.api.routers import api_router
 from app.middlewares.frontend import FrontendProxyMiddleware
 from app.observability import init_observability
 from app.settings import init_settings
-from fastapi import FastAPI
-from fastapi.responses import RedirectResponse
-from fastapi.staticfiles import StaticFiles

 servers = []
 app_name = os.getenv("FLY_APP_NAME")
@@ -28,6 +31,16 @@
 environment = os.getenv("ENVIRONMENT", "dev")  # Default to 'dev' if not set
 logger = logging.getLogger("uvicorn")

+# Add CORS middleware for development
+if environment == "dev":
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["http://localhost:*", "http://0.0.0.0:*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+

 def mount_static_files(directory, path, html=False):
     if os.path.exists(directory):
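
One caveat worth checking here: Starlette compares allow_origins entries literally, so a pattern like "http://localhost:*" may not match a real origin such as http://localhost:3000 and could require allow_origin_regex instead. A minimal sketch (the route and import path are assumptions, not part of the PR) for inspecting what the dev server returns to a preflight request:

from fastapi.testclient import TestClient

from main import app  # assumed import path for the template's FastAPI app

client = TestClient(app)
resp = client.options(
    "/api/chat",  # assumed route
    headers={
        "Origin": "http://localhost:3000",
        "Access-Control-Request-Method": "POST",
    },
)
# An empty access-control-allow-origin header means the origin was not matched.
print(resp.status_code, resp.headers.get("access-control-allow-origin"))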