diff --git a/.changeset/cool-cars-promise.md b/.changeset/cool-cars-promise.md
new file mode 100644
index 000000000..e09127e17
--- /dev/null
+++ b/.changeset/cool-cars-promise.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Migrate AgentRunner to Agent Workflow (Python)
diff --git a/.github/workflows/e2e.yml b/.github/workflows/e2e.yml
index 2aa812c08..755e278ef 100644
--- a/.github/workflows/e2e.yml
+++ b/.github/workflows/e2e.yml
@@ -19,7 +19,7 @@ jobs:
         python-version: ["3.11"]
         os: [macos-latest, windows-latest, ubuntu-22.04]
         frameworks: ["fastapi"]
-        datasources: ["--no-files", "--example-file", "--llamacloud"]
+        datasources: ["--example-file", "--llamacloud"]
     defaults:
       run:
         shell: bash
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index aa6103339..cf557695d 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -483,11 +483,12 @@ const getSystemPromptEnv = (
     });
   }
   if (tools?.length == 0 && (dataSources?.length ?? 0 > 0)) {
-    const citationPrompt = `'You have provided information from a knowledge base that has been passed to you in nodes of information.
-Each node has useful metadata such as node ID, file name, page, etc.
-Please add the citation to the data node for each sentence or paragraph that you reference in the provided information.
-The citation format is: . [citation:<node_id>]()
-Where the <node_id> is the unique identifier of the data node.
+    const citationPrompt = `'You have been provided information from a knowledge base that separates the information into multiple nodes.
+Always add a citation to each sentence or paragraph that references the provided information, using the node_id field in the header of each node.
+
+The citation format is: [citation:<node_id>]
+Where the <node_id> is the node_id field in the header of each node.
+Always separate citations with a space.
 
 Example:
 We have two nodes:
@@ -497,11 +498,9 @@ We have two nodes:
 node_id: abc
 file_name: animal.pdf
 
-User question: Tell me a fun fact about Llama.
-Your answer:
-A baby llama is called "Cria" [citation:xyz]().
-It often live in desert [citation:abc]().
-It\\'s cute animal.
+Your answer with citations:
+A baby llama is called "Cria" [citation:xyz]
+It often lives in the desert [citation:abc] [citation:xyz]
 '`;
     systemPromptEnv.push({
       name: "SYSTEM_CITATION_PROMPT",
diff --git a/helpers/python.ts b/helpers/python.ts
index 3569a81df..701bc0832 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -444,49 +444,15 @@ export const installPythonTemplate = async ({
     cwd: path.join(compPath, "settings", "python"),
   });
 
-  // Copy services
   if (template == "streaming" || template == "multiagent") {
+    // Copy services
     await copy("**", path.join(root, "app", "api", "services"), {
       cwd: path.join(compPath, "services", "python"),
     });
-  }
 
-  // Copy engine code
-  if (template === "streaming" || template === "multiagent") {
-    // Select and copy engine code based on data sources and tools
-    let engine;
-    // Multiagent always uses agent engine
-    if (template === "multiagent") {
-      engine = "agent";
-    } else {
-      // For streaming, use chat engine by default
-      // Unless tools are selected, in which case use agent engine
-      if (dataSources.length > 0 && (!tools || tools.length === 0)) {
-        console.log(
-          "\nNo tools selected - use optimized context chat engine\n",
-        );
-        engine = "chat";
-      } else {
-        engine = "agent";
-      }
-    }
-
-    // Copy engine code
-    await copy("**", enginePath, {
-      parents: true,
-      cwd: path.join(compPath, "engines", "python", engine),
-    });
-
     // Copy router code
     await copyRouterCode(root, tools ??
[]); } - // Copy multiagents overrides - if (template === "multiagent") { - await copy("**", path.join(root), { - cwd: path.join(compPath, "multiagent", "python"), - }); - } - if (template === "multiagent" || template === "reflex") { if (useCase) { const sourcePath = diff --git a/questions/datasources.ts b/questions/datasources.ts index 1961e4c88..b750184c2 100644 --- a/questions/datasources.ts +++ b/questions/datasources.ts @@ -19,10 +19,12 @@ export const getDataSourceChoices = ( }); } if (selectedDataSource === undefined || selectedDataSource.length === 0) { - choices.push({ - title: "No datasource", - value: "none", - }); + if (framework !== "fastapi") { + choices.push({ + title: "No datasource", + value: "none", + }); + } choices.push({ title: process.platform !== "linux" diff --git a/templates/components/agents/python/deep_research/app/workflows/deep_research.py b/templates/components/agents/python/deep_research/app/workflows/deep_research.py index 6af650826..17df08f73 100644 --- a/templates/components/agents/python/deep_research/app/workflows/deep_research.py +++ b/templates/components/agents/python/deep_research/app/workflows/deep_research.py @@ -18,13 +18,13 @@ from app.engine.index import IndexConfig, get_index from app.workflows.agents import plan_research, research, write_report -from app.workflows.events import SourceNodesEvent from app.workflows.models import ( CollectAnswersEvent, DataEvent, PlanResearchEvent, ReportEvent, ResearchEvent, + SourceNodesEvent, ) logger = logging.getLogger("uvicorn") diff --git a/templates/components/agents/python/deep_research/app/workflows/models.py b/templates/components/agents/python/deep_research/app/workflows/models.py index 0fe25b47a..fa4414cef 100644 --- a/templates/components/agents/python/deep_research/app/workflows/models.py +++ b/templates/components/agents/python/deep_research/app/workflows/models.py @@ -4,6 +4,8 @@ from llama_index.core.workflow import Event from pydantic import BaseModel +from app.api.routers.models import SourceNodes + # Workflow events class PlanResearchEvent(Event): @@ -41,3 +43,18 @@ class DataEvent(Event): def to_response(self): return self.model_dump() + + +class SourceNodesEvent(Event): + nodes: List[NodeWithScore] + + def to_response(self): + return { + "type": "sources", + "data": { + "nodes": [ + SourceNodes.from_source_node(node).model_dump() + for node in self.nodes + ] + }, + } diff --git a/templates/components/multiagent/python/app/workflows/events.py b/templates/components/agents/python/financial_report/app/workflows/events.py similarity index 54% rename from templates/components/multiagent/python/app/workflows/events.py rename to templates/components/agents/python/financial_report/app/workflows/events.py index f74f26b7c..f40e9e1ab 100644 --- a/templates/components/multiagent/python/app/workflows/events.py +++ b/templates/components/agents/python/financial_report/app/workflows/events.py @@ -1,11 +1,8 @@ from enum import Enum -from typing import List, Optional +from typing import Optional -from llama_index.core.schema import NodeWithScore from llama_index.core.workflow import Event -from app.api.routers.models import SourceNodes - class AgentRunEventType(Enum): TEXT = "text" @@ -28,18 +25,3 @@ def to_response(self) -> dict: "data": self.data, }, } - - -class SourceNodesEvent(Event): - nodes: List[NodeWithScore] - - def to_response(self): - return { - "type": "sources", - "data": { - "nodes": [ - SourceNodes.from_source_node(node).model_dump() - for node in self.nodes - ] - }, - } diff --git 
a/templates/components/multiagent/python/app/workflows/tools.py b/templates/components/agents/python/financial_report/app/workflows/tools.py similarity index 100% rename from templates/components/multiagent/python/app/workflows/tools.py rename to templates/components/agents/python/financial_report/app/workflows/tools.py diff --git a/templates/components/agents/python/form_filling/app/workflows/events.py b/templates/components/agents/python/form_filling/app/workflows/events.py new file mode 100644 index 000000000..f40e9e1ab --- /dev/null +++ b/templates/components/agents/python/form_filling/app/workflows/events.py @@ -0,0 +1,27 @@ +from enum import Enum +from typing import Optional + +from llama_index.core.workflow import Event + + +class AgentRunEventType(Enum): + TEXT = "text" + PROGRESS = "progress" + + +class AgentRunEvent(Event): + name: str + msg: str + event_type: AgentRunEventType = AgentRunEventType.TEXT + data: Optional[dict] = None + + def to_response(self) -> dict: + return { + "type": "agent", + "data": { + "agent": self.name, + "type": self.event_type.value, + "text": self.msg, + "data": self.data, + }, + } diff --git a/templates/components/agents/python/form_filling/app/workflows/tools.py b/templates/components/agents/python/form_filling/app/workflows/tools.py new file mode 100644 index 000000000..faab45955 --- /dev/null +++ b/templates/components/agents/python/form_filling/app/workflows/tools.py @@ -0,0 +1,230 @@ +import logging +import uuid +from abc import ABC, abstractmethod +from typing import Any, AsyncGenerator, Callable, Optional + +from llama_index.core.base.llms.types import ChatMessage, ChatResponse, MessageRole +from llama_index.core.llms.function_calling import FunctionCallingLLM +from llama_index.core.tools import ( + BaseTool, + FunctionTool, + ToolOutput, + ToolSelection, +) +from llama_index.core.workflow import Context +from pydantic import BaseModel, ConfigDict + +from app.workflows.events import AgentRunEvent, AgentRunEventType + +logger = logging.getLogger("uvicorn") + + +class ContextAwareTool(FunctionTool, ABC): + @abstractmethod + async def acall(self, ctx: Context, input: Any) -> ToolOutput: # type: ignore + pass + + +class ChatWithToolsResponse(BaseModel): + """ + A tool call response from chat_with_tools. + """ + + tool_calls: Optional[list[ToolSelection]] + tool_call_message: Optional[ChatMessage] + generator: Optional[AsyncGenerator[ChatResponse | None, None]] + + model_config = ConfigDict(arbitrary_types_allowed=True) + + def is_calling_different_tools(self) -> bool: + tool_names = {tool_call.tool_name for tool_call in self.tool_calls} + return len(tool_names) > 1 + + def has_tool_calls(self) -> bool: + return self.tool_calls is not None and len(self.tool_calls) > 0 + + def tool_name(self) -> str: + assert self.has_tool_calls() + assert not self.is_calling_different_tools() + return self.tool_calls[0].tool_name + + async def full_response(self) -> str: + assert self.generator is not None + full_response = "" + async for chunk in self.generator: + content = chunk.message.content + if content: + full_response += content + return full_response + + +async def chat_with_tools( # type: ignore + llm: FunctionCallingLLM, + tools: list[BaseTool], + chat_history: list[ChatMessage], +) -> ChatWithToolsResponse: + """ + Request LLM to call tools or not. + This function doesn't change the memory. 
+ """ + generator = _tool_call_generator(llm, tools, chat_history) + is_tool_call = await generator.__anext__() + if is_tool_call: + # Last chunk is the full response + # Wait for the last chunk + full_response = None + async for chunk in generator: + full_response = chunk + assert isinstance(full_response, ChatResponse) + return ChatWithToolsResponse( + tool_calls=llm.get_tool_calls_from_response(full_response), + tool_call_message=full_response.message, + generator=None, + ) + else: + return ChatWithToolsResponse( + tool_calls=None, + tool_call_message=None, + generator=generator, + ) + + +async def call_tools( + ctx: Context, + agent_name: str, + tools: list[BaseTool], + tool_calls: list[ToolSelection], + emit_agent_events: bool = True, +) -> list[ChatMessage]: + if len(tool_calls) == 0: + return [] + + tools_by_name = {tool.metadata.get_name(): tool for tool in tools} + if len(tool_calls) == 1: + return [ + await call_tool( + ctx, + tools_by_name[tool_calls[0].tool_name], + tool_calls[0], + lambda msg: ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=msg, + ) + ), + ) + ] + # Multiple tool calls, show progress + tool_msgs: list[ChatMessage] = [] + + progress_id = str(uuid.uuid4()) + total_steps = len(tool_calls) + if emit_agent_events: + ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=f"Making {total_steps} tool calls", + ) + ) + for i, tool_call in enumerate(tool_calls): + tool = tools_by_name.get(tool_call.tool_name) + if not tool: + tool_msgs.append( + ChatMessage( + role=MessageRole.ASSISTANT, + content=f"Tool {tool_call.tool_name} does not exist", + ) + ) + continue + tool_msg = await call_tool( + ctx, + tool, + tool_call, + event_emitter=lambda msg: ctx.write_event_to_stream( + AgentRunEvent( + name=agent_name, + msg=msg, + event_type=AgentRunEventType.PROGRESS, + data={ + "id": progress_id, + "total": total_steps, + "current": i, + }, + ) + ), + ) + tool_msgs.append(tool_msg) + return tool_msgs + + +async def call_tool( + ctx: Context, + tool: BaseTool, + tool_call: ToolSelection, + event_emitter: Optional[Callable[[str], None]], +) -> ChatMessage: + if event_emitter: + event_emitter( + f"Calling tool {tool_call.tool_name}, {str(tool_call.tool_kwargs)}" + ) + try: + if isinstance(tool, ContextAwareTool): + if ctx is None: + raise ValueError("Context is required for context aware tool") + # inject context for calling an context aware tool + response = await tool.acall(ctx=ctx, **tool_call.tool_kwargs) + else: + response = await tool.acall(**tool_call.tool_kwargs) # type: ignore + return ChatMessage( + role=MessageRole.TOOL, + content=str(response.raw_output), + additional_kwargs={ + "tool_call_id": tool_call.tool_id, + "name": tool.metadata.get_name(), + }, + ) + except Exception as e: + logger.error(f"Got error in tool {tool_call.tool_name}: {str(e)}") + if event_emitter: + event_emitter(f"Got error in tool {tool_call.tool_name}: {str(e)}") + return ChatMessage( + role=MessageRole.TOOL, + content=f"Error: {str(e)}", + additional_kwargs={ + "tool_call_id": tool_call.tool_id, + "name": tool.metadata.get_name(), + }, + ) + + +async def _tool_call_generator( + llm: FunctionCallingLLM, + tools: list[BaseTool], + chat_history: list[ChatMessage], +) -> AsyncGenerator[ChatResponse | bool, None]: + response_stream = await llm.astream_chat_with_tools( + tools, + chat_history=chat_history, + allow_parallel_tool_calls=False, + ) + + full_response = None + yielded_indicator = False + async for chunk in response_stream: + if "tool_calls" not in 
chunk.message.additional_kwargs: + # Yield a boolean to indicate whether the response is a tool call + if not yielded_indicator: + yield False + yielded_indicator = True + + # if not a tool call, yield the chunks! + yield chunk # type: ignore + elif not yielded_indicator: + # Yield the indicator for a tool call + yield True + yielded_indicator = True + + full_response = chunk + + if full_response: + yield full_response # type: ignore diff --git a/templates/components/engines/python/chat/engine.py b/templates/components/engines/python/chat/engine.py deleted file mode 100644 index f3795afd3..000000000 --- a/templates/components/engines/python/chat/engine.py +++ /dev/null @@ -1,47 +0,0 @@ -import os - -from app.engine.index import IndexConfig, get_index -from app.engine.node_postprocessors import NodeCitationProcessor -from fastapi import HTTPException -from llama_index.core.callbacks import CallbackManager -from llama_index.core.chat_engine import CondensePlusContextChatEngine -from llama_index.core.memory import ChatMemoryBuffer -from llama_index.core.settings import Settings - - -def get_chat_engine(params=None, event_handlers=None, **kwargs): - system_prompt = os.getenv("SYSTEM_PROMPT") - citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None) - top_k = int(os.getenv("TOP_K", 0)) - llm = Settings.llm - memory = ChatMemoryBuffer.from_defaults( - token_limit=llm.metadata.context_window - 256 - ) - callback_manager = CallbackManager(handlers=event_handlers or []) - - node_postprocessors = [] - if citation_prompt: - node_postprocessors = [NodeCitationProcessor()] - system_prompt = f"{system_prompt}\n{citation_prompt}" - - index_config = IndexConfig(callback_manager=callback_manager, **(params or {})) - index = get_index(index_config) - if index is None: - raise HTTPException( - status_code=500, - detail=str( - "StorageContext is empty - call 'poetry run generate' to generate the storage first" - ), - ) - if top_k != 0 and kwargs.get("similarity_top_k") is None: - kwargs["similarity_top_k"] = top_k - retriever = index.as_retriever(**kwargs) - - return CondensePlusContextChatEngine( - llm=llm, - memory=memory, - system_prompt=system_prompt, - retriever=retriever, - node_postprocessors=node_postprocessors, # type: ignore - callback_manager=callback_manager, - ) diff --git a/templates/components/engines/python/chat/node_postprocessors.py b/templates/components/engines/python/chat/node_postprocessors.py deleted file mode 100644 index 336cd0edc..000000000 --- a/templates/components/engines/python/chat/node_postprocessors.py +++ /dev/null @@ -1,21 +0,0 @@ -from typing import List, Optional - -from llama_index.core import QueryBundle -from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.core.schema import NodeWithScore - - -class NodeCitationProcessor(BaseNodePostprocessor): - """ - Append node_id into metadata for citation purpose. - Config SYSTEM_CITATION_PROMPT in your runtime environment variable to enable this feature. 
- """ - - def _postprocess_nodes( - self, - nodes: List[NodeWithScore], - query_bundle: Optional[QueryBundle] = None, - ) -> List[NodeWithScore]: - for node_score in nodes: - node_score.node.metadata["node_id"] = node_score.node.node_id - return nodes diff --git a/templates/components/multiagent/python/app/api/routers/chat.py b/templates/components/multiagent/python/app/api/routers/chat.py deleted file mode 100644 index d7a44e691..000000000 --- a/templates/components/multiagent/python/app/api/routers/chat.py +++ /dev/null @@ -1,55 +0,0 @@ -import logging - -from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status - -from app.api.callbacks.llamacloud import LlamaCloudFileDownload -from app.api.callbacks.next_question import SuggestNextQuestions -from app.api.callbacks.stream_handler import StreamHandler -from app.api.routers.models import ( - ChatData, -) -from app.engine.query_filter import generate_filters -from app.workflows import create_workflow - -chat_router = r = APIRouter() - -logger = logging.getLogger("uvicorn") - - -@r.post("") -async def chat( - request: Request, - data: ChatData, - background_tasks: BackgroundTasks, -): - try: - last_message_content = data.get_last_message_content() - messages = data.get_history_messages(include_agent_messages=True) - - doc_ids = data.get_chat_document_ids() - filters = generate_filters(doc_ids) - params = data.data or {} - - workflow = create_workflow( - params=params, - filters=filters, - ) - - handler = workflow.run( - user_msg=last_message_content, - chat_history=messages, - stream=True, - ) - return StreamHandler.from_default( - handler=handler, - callbacks=[ - LlamaCloudFileDownload.from_default(background_tasks), - SuggestNextQuestions.from_default(data), - ], - ).vercel_stream() - except Exception as e: - logger.exception("Error in chat engine", exc_info=True) - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error in chat engine: {e}", - ) from e diff --git a/templates/components/multiagent/python/app/api/routers/vercel_response.py b/templates/components/multiagent/python/app/api/routers/vercel_response.py deleted file mode 100644 index a5f1d7a01..000000000 --- a/templates/components/multiagent/python/app/api/routers/vercel_response.py +++ /dev/null @@ -1,99 +0,0 @@ -import asyncio -import json -import logging -from typing import AsyncGenerator - -from fastapi.responses import StreamingResponse -from llama_index.core.agent.workflow.workflow_events import AgentStream -from llama_index.core.workflow import StopEvent - -from app.api.callbacks.stream_handler import StreamHandler - -logger = logging.getLogger("uvicorn") - - -class VercelStreamResponse(StreamingResponse): - """ - Converts preprocessed events into Vercel-compatible streaming response format. 
- """ - - TEXT_PREFIX = "0:" - DATA_PREFIX = "8:" - ERROR_PREFIX = "3:" - - def __init__( - self, - stream_handler: StreamHandler, - *args, - **kwargs, - ): - self.handler = stream_handler - super().__init__(content=self.content_generator()) - - async def content_generator(self): - """Generate Vercel-formatted content from preprocessed events.""" - stream_started = False - try: - async for event in self.handler.stream_events(): - if not stream_started: - # Start the stream with an empty message - stream_started = True - yield self.convert_text("") - - # Handle different types of events - if isinstance(event, (AgentStream, StopEvent)): - async for chunk in self._stream_text(event): - await self.handler.accumulate_text(chunk) - yield self.convert_text(chunk) - elif isinstance(event, dict): - yield self.convert_data(event) - elif hasattr(event, "to_response"): - event_response = event.to_response() - yield self.convert_data(event_response) - else: - yield self.convert_data(event.model_dump()) - - except asyncio.CancelledError: - logger.warning("Client cancelled the request!") - await self.handler.cancel_run() - except Exception as e: - logger.error(f"Error in stream response: {e}") - yield self.convert_error(str(e)) - await self.handler.cancel_run() - - async def _stream_text( - self, event: AgentStream | StopEvent - ) -> AsyncGenerator[str, None]: - """ - Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result - """ - if isinstance(event, AgentStream): - yield self.convert_text(event.delta) - elif isinstance(event, StopEvent): - if isinstance(event.result, str): - yield event.result - elif isinstance(event.result, AsyncGenerator): - async for chunk in event.result: - if isinstance(chunk, str): - yield chunk - elif hasattr(chunk, "delta"): - yield chunk.delta - - @classmethod - def convert_text(cls, token: str) -> str: - """Convert text event to Vercel format.""" - # Escape newlines and double quotes to avoid breaking the stream - token = json.dumps(token) - return f"{cls.TEXT_PREFIX}{token}\n" - - @classmethod - def convert_data(cls, data: dict) -> str: - """Convert data event to Vercel format.""" - data_str = json.dumps(data) - return f"{cls.DATA_PREFIX}[{data_str}]\n" - - @classmethod - def convert_error(cls, error: str) -> str: - """Convert error event to Vercel format.""" - error_str = json.dumps(error) - return f"{cls.ERROR_PREFIX}{error_str}\n" diff --git a/templates/components/multiagent/python/app/workflows/function_calling_agent.py b/templates/components/multiagent/python/app/workflows/function_calling_agent.py deleted file mode 100644 index 452fc5e7b..000000000 --- a/templates/components/multiagent/python/app/workflows/function_calling_agent.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import Any, List, Optional - -from app.workflows.events import AgentRunEvent -from app.workflows.tools import ToolCallResponse, call_tools, chat_with_tools -from llama_index.core.base.llms.types import ChatMessage -from llama_index.core.llms.function_calling import FunctionCallingLLM -from llama_index.core.memory import ChatMemoryBuffer -from llama_index.core.settings import Settings -from llama_index.core.tools.types import BaseTool -from llama_index.core.workflow import ( - Context, - Event, - StartEvent, - StopEvent, - Workflow, - step, -) - - -class InputEvent(Event): - input: list[ChatMessage] - - -class ToolCallEvent(Event): - input: ToolCallResponse - - -class FunctionCallingAgent(Workflow): - """ - A simple workflow to request LLM with tools 
independently. - You can share the previous chat history to provide the context for the LLM. - """ - - def __init__( - self, - *args: Any, - llm: FunctionCallingLLM | None = None, - chat_history: Optional[List[ChatMessage]] = None, - tools: List[BaseTool] | None = None, - system_prompt: str | None = None, - verbose: bool = False, - timeout: float = 360.0, - name: str, - write_events: bool = True, - **kwargs: Any, - ) -> None: - super().__init__(*args, verbose=verbose, timeout=timeout, **kwargs) # type: ignore - self.tools = tools or [] - self.name = name - self.write_events = write_events - - if llm is None: - llm = Settings.llm - self.llm = llm - if not self.llm.metadata.is_function_calling_model: - raise ValueError("The provided LLM must support function calling.") - - self.system_prompt = system_prompt - - self.memory = ChatMemoryBuffer.from_defaults( - llm=self.llm, chat_history=chat_history - ) - self.sources = [] # type: ignore - - @step() - async def prepare_chat_history(self, ctx: Context, ev: StartEvent) -> InputEvent: - # clear sources - self.sources = [] - - # set streaming - ctx.data["streaming"] = getattr(ev, "streaming", False) - - # set system prompt - if self.system_prompt is not None: - system_msg = ChatMessage(role="system", content=self.system_prompt) - self.memory.put(system_msg) - - # get user input - user_input = ev.input - user_msg = ChatMessage(role="user", content=user_input) - self.memory.put(user_msg) - - if self.write_events: - ctx.write_event_to_stream( - AgentRunEvent(name=self.name, msg=f"Start to work on: {user_input}") - ) - - return InputEvent(input=self.memory.get()) - - @step() - async def handle_llm_input( - self, - ctx: Context, - ev: InputEvent, - ) -> ToolCallEvent | StopEvent: - chat_history = ev.input - - response = await chat_with_tools( - self.llm, - self.tools, - chat_history, - ) - is_tool_call = isinstance(response, ToolCallResponse) - if not is_tool_call: - if ctx.data["streaming"]: - return StopEvent(result=response) - else: - full_response = "" - async for chunk in response.generator: - full_response += chunk.message.content - return StopEvent(result=full_response) - return ToolCallEvent(input=response) - - @step() - async def handle_tool_calls(self, ctx: Context, ev: ToolCallEvent) -> InputEvent: - tool_calls = ev.input.tool_calls - tool_call_message = ev.input.tool_call_message - self.memory.put(tool_call_message) - tool_messages = await call_tools(self.name, self.tools, ctx, tool_calls) - self.memory.put_messages(tool_messages) - return InputEvent(input=self.memory.get()) diff --git a/templates/components/multiagent/python/app/api/callbacks/__init__.py b/templates/types/streaming/fastapi/app/api/callbacks/__init__.py similarity index 100% rename from templates/components/multiagent/python/app/api/callbacks/__init__.py rename to templates/types/streaming/fastapi/app/api/callbacks/__init__.py diff --git a/templates/components/multiagent/python/app/api/callbacks/base.py b/templates/types/streaming/fastapi/app/api/callbacks/base.py similarity index 100% rename from templates/components/multiagent/python/app/api/callbacks/base.py rename to templates/types/streaming/fastapi/app/api/callbacks/base.py diff --git a/templates/components/multiagent/python/app/api/callbacks/llamacloud.py b/templates/types/streaming/fastapi/app/api/callbacks/llamacloud.py similarity index 100% rename from templates/components/multiagent/python/app/api/callbacks/llamacloud.py rename to templates/types/streaming/fastapi/app/api/callbacks/llamacloud.py diff --git 
a/templates/components/multiagent/python/app/api/callbacks/next_question.py b/templates/types/streaming/fastapi/app/api/callbacks/next_question.py similarity index 100% rename from templates/components/multiagent/python/app/api/callbacks/next_question.py rename to templates/types/streaming/fastapi/app/api/callbacks/next_question.py diff --git a/templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py b/templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py new file mode 100644 index 000000000..91bf3664f --- /dev/null +++ b/templates/types/streaming/fastapi/app/api/callbacks/source_nodes.py @@ -0,0 +1,62 @@ +import logging +import os +from typing import Any, Dict, Optional + +from app.api.callbacks.base import EventCallback +from app.config import DATA_DIR +from llama_index.core.agent.workflow.workflow_events import ToolCallResult + +logger = logging.getLogger("uvicorn") + + +class AddNodeUrl(EventCallback): + """ + Add URL to source nodes + """ + + async def run(self, event: Any) -> Any: + if self._is_retrieval_result_event(event): + for node_score in event.tool_output.raw_output.source_nodes: + node_score.node.metadata["url"] = self._get_url_from_metadata( + node_score.node.metadata + ) + return event + + def _is_retrieval_result_event(self, event: Any) -> bool: + if isinstance(event, ToolCallResult): + if event.tool_name == "query_index": + return True + return False + + def _get_url_from_metadata(self, metadata: Dict[str, Any]) -> Optional[str]: + url_prefix = os.getenv("FILESERVER_URL_PREFIX") + if not url_prefix: + logger.warning( + "Warning: FILESERVER_URL_PREFIX not set in environment variables. Can't use file server" + ) + file_name = metadata.get("file_name") + + if file_name and url_prefix: + # file_name exists and file server is configured + pipeline_id = metadata.get("pipeline_id") + if pipeline_id: + # file is from LlamaCloud + file_name = f"{pipeline_id}${file_name}" + return f"{url_prefix}/output/llamacloud/{file_name}" + is_private = metadata.get("private", "false") == "true" + if is_private: + # file is a private upload + return f"{url_prefix}/output/uploaded/{file_name}" + # file is from calling the 'generate' script + # Get the relative path of file_path to data_dir + file_path = metadata.get("file_path") + data_dir = os.path.abspath(DATA_DIR) + if file_path and data_dir: + relative_path = os.path.relpath(file_path, data_dir) + return f"{url_prefix}/data/{relative_path}" + # fallback to URL in metadata (e.g. 
for websites) + return metadata.get("URL") + + @classmethod + def from_default(cls) -> "AddNodeUrl": + return cls() diff --git a/templates/components/multiagent/python/app/api/callbacks/stream_handler.py b/templates/types/streaming/fastapi/app/api/callbacks/stream_handler.py similarity index 100% rename from templates/components/multiagent/python/app/api/callbacks/stream_handler.py rename to templates/types/streaming/fastapi/app/api/callbacks/stream_handler.py diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py index c024dad02..499a68875 100644 --- a/templates/types/streaming/fastapi/app/api/routers/chat.py +++ b/templates/types/streaming/fastapi/app/api/routers/chat.py @@ -1,25 +1,22 @@ import logging from fastapi import APIRouter, BackgroundTasks, HTTPException, Request, status -from llama_index.core.llms import MessageRole -from app.api.routers.events import EventCallbackHandler +from app.api.callbacks.llamacloud import LlamaCloudFileDownload +from app.api.callbacks.next_question import SuggestNextQuestions +from app.api.callbacks.source_nodes import AddNodeUrl +from app.api.callbacks.stream_handler import StreamHandler from app.api.routers.models import ( ChatData, - Message, - Result, - SourceNodes, ) -from app.api.routers.vercel_response import VercelStreamResponse -from app.engine.engine import get_chat_engine from app.engine.query_filter import generate_filters +from app.workflows import create_workflow chat_router = r = APIRouter() logger = logging.getLogger("uvicorn") -# streaming endpoint - delete if not needed @r.post("") async def chat( request: Request, @@ -28,50 +25,66 @@ async def chat( ): try: last_message_content = data.get_last_message_content() - messages = data.get_history_messages() + messages = data.get_history_messages(include_agent_messages=True) doc_ids = data.get_chat_document_ids() filters = generate_filters(doc_ids) params = data.data or {} - logger.info( - f"Creating chat engine with filters: {str(filters)}", - ) - event_handler = EventCallbackHandler() - chat_engine = get_chat_engine( - filters=filters, params=params, event_handlers=[event_handler] + + workflow = create_workflow( + params=params, + filters=filters, ) - response = chat_engine.astream_chat(last_message_content, messages) - return VercelStreamResponse( - request, event_handler, response, data, background_tasks + handler = workflow.run( + user_msg=last_message_content, + chat_history=messages, + stream=True, ) + return StreamHandler.from_default( + handler=handler, + callbacks=[ + LlamaCloudFileDownload.from_default(background_tasks), + SuggestNextQuestions.from_default(data), + AddNodeUrl.from_default(), + ], + ).vercel_stream() except Exception as e: - logger.exception("Error in chat engine", exc_info=True) + logger.exception("Error in chat", exc_info=True) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Error in chat engine: {e}", + detail=f"Error in chat: {e}", ) from e # non-streaming endpoint - delete if not needed @r.post("/request") async def chat_request( + request: Request, data: ChatData, -) -> Result: - last_message_content = data.get_last_message_content() - messages = data.get_history_messages() +): + try: + last_message_content = data.get_last_message_content() + messages = data.get_history_messages(include_agent_messages=True) - doc_ids = data.get_chat_document_ids() - filters = generate_filters(doc_ids) - params = data.data or {} - logger.info( - f"Creating chat 
engine with filters: {str(filters)}", - ) + doc_ids = data.get_chat_document_ids() + filters = generate_filters(doc_ids) + params = data.data or {} - chat_engine = get_chat_engine(filters=filters, params=params) + workflow = create_workflow( + params=params, + filters=filters, + ) - response = await chat_engine.achat(last_message_content, messages) - return Result( - result=Message(role=MessageRole.ASSISTANT, content=response.response), - nodes=SourceNodes.from_source_nodes(response.source_nodes), - ) + handler = workflow.run( + user_msg=last_message_content, + chat_history=messages, + stream=False, + ) + return await handler + except Exception as e: + logger.exception("Error in chat request", exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Error in chat request: {e}", + ) from e diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py index 31f2fa46f..0f647c325 100644 --- a/templates/types/streaming/fastapi/app/api/routers/models.py +++ b/templates/types/streaming/fastapi/app/api/routers/models.py @@ -103,7 +103,13 @@ def to_content(self) -> Optional[str]: class Message(BaseModel): role: MessageRole content: str - annotations: List[Annotation] | None = None + annotations: Optional[List[Annotation]] = None + + @validator("annotations", pre=True) + def validate_annotations(cls, v): + if v is None: + return v + return [item for item in v if isinstance(item, Annotation)] class ChatData(BaseModel): @@ -317,7 +323,6 @@ def from_source_nodes(cls, source_nodes: List[NodeWithScore]): class Result(BaseModel): result: Message - nodes: List[SourceNodes] class ChatConfig(BaseModel): diff --git a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py index 0d41d893e..a2d5c730d 100644 --- a/templates/types/streaming/fastapi/app/api/routers/vercel_response.py +++ b/templates/types/streaming/fastapi/app/api/routers/vercel_response.py @@ -1,23 +1,20 @@ +import asyncio import json import logging -from typing import Awaitable, List +from typing import AsyncGenerator -from aiostream import stream -from fastapi import BackgroundTasks, Request from fastapi.responses import StreamingResponse -from llama_index.core.chat_engine.types import StreamingAgentChatResponse -from llama_index.core.schema import NodeWithScore +from llama_index.core.agent.workflow.workflow_events import AgentStream +from llama_index.core.workflow import StopEvent -from app.api.routers.events import EventCallbackHandler -from app.api.routers.models import ChatData, Message, SourceNodes -from app.api.services.suggestion import NextQuestionSuggestion +from app.api.callbacks.stream_handler import StreamHandler logger = logging.getLogger("uvicorn") class VercelStreamResponse(StreamingResponse): """ - Class to convert the response from the chat engine to the streaming format expected by Vercel + Converts preprocessed events into Vercel-compatible streaming response format. 
""" TEXT_PREFIX = "0:" @@ -26,152 +23,80 @@ class VercelStreamResponse(StreamingResponse): def __init__( self, - request: Request, - event_handler: EventCallbackHandler, - response: Awaitable[StreamingAgentChatResponse], - chat_data: ChatData, - background_tasks: BackgroundTasks, + stream_handler: StreamHandler, + *args, + **kwargs, ): - content = VercelStreamResponse.content_generator( - request, event_handler, response, chat_data, background_tasks - ) - super().__init__(content=content) + self.handler = stream_handler + super().__init__(content=self.content_generator()) - @classmethod - async def content_generator( - cls, - request: Request, - event_handler: EventCallbackHandler, - response: Awaitable[StreamingAgentChatResponse], - chat_data: ChatData, - background_tasks: BackgroundTasks, - ): - chat_response_generator = cls._chat_response_generator( - response, background_tasks, event_handler, chat_data - ) - event_generator = cls._event_generator(event_handler) - - # Merge the chat response generator and the event generator - combine = stream.merge(chat_response_generator, event_generator) - is_stream_started = False + async def content_generator(self): + """Generate Vercel-formatted content from preprocessed events.""" + stream_started = False try: - async with combine.stream() as streamer: - async for output in streamer: - if await request.is_disconnected(): - break - - if not is_stream_started: - is_stream_started = True - # Stream a blank message to start displaying the response in the UI - yield cls.convert_text("") - - yield output - except Exception: - logger.exception("Error in stream response") - yield cls.convert_error( - "An unexpected error occurred while processing your request, preventing the creation of a final answer. Please try again." 
- ) - finally: - # Ensure event handler is marked as done even if connection breaks - event_handler.is_done = True - - @classmethod - async def _event_generator(cls, event_handler: EventCallbackHandler): - """ - Yield the events from the event handler - """ - async for event in event_handler.async_event_gen(): - event_response = event.to_response() - if event_response is not None: - yield cls.convert_data(event_response) - - @classmethod - async def _chat_response_generator( - cls, - response: Awaitable[StreamingAgentChatResponse], - background_tasks: BackgroundTasks, - event_handler: EventCallbackHandler, - chat_data: ChatData, - ): + async for event in self.handler.stream_events(): + if not stream_started: + # Start the stream with an empty message + stream_started = True + yield self.convert_text("") + + # Handle different types of events + if isinstance(event, (AgentStream, StopEvent)): + async for chunk in self._stream_text(event): + await self.handler.accumulate_text(chunk) + yield self.convert_text(chunk) + elif isinstance(event, dict): + yield self.convert_data(event) + elif hasattr(event, "to_response"): + event_response = event.to_response() + yield self.convert_data(event_response) + else: + yield self.convert_data(event.model_dump()) + + except asyncio.CancelledError: + logger.warning("Client cancelled the request!") + await self.handler.cancel_run() + except Exception as e: + logger.error(f"Error in stream response: {e}") + yield self.convert_error(str(e)) + await self.handler.cancel_run() + + async def _stream_text( + self, event: AgentStream | StopEvent + ) -> AsyncGenerator[str, None]: """ - Yield the text response and source nodes from the chat engine + Accept stream text from either AgentStream or StopEvent with string or AsyncGenerator result """ - # Wait for the response from the chat engine - result = await response - - # Once we got a source node, start a background task to download the files (if needed) - cls._process_response_nodes(result.source_nodes, background_tasks) - - # Yield the source nodes - yield cls.convert_data( - { - "type": "sources", - "data": { - "nodes": [ - SourceNodes.from_source_node(node).model_dump() - for node in result.source_nodes - ] - }, - } - ) - - final_response = "" - async for token in result.async_response_gen(): - final_response += token - yield cls.convert_text(token) - - # Generate next questions if next question prompt is configured - question_data = await cls._generate_next_questions( - chat_data.messages, final_response - ) - if question_data: - yield cls.convert_data(question_data) - - # the text_generator is the leading stream, once it's finished, also finish the event stream - event_handler.is_done = True + if isinstance(event, AgentStream): + if event.delta.strip(): # Only yield non-empty deltas + yield event.delta + elif isinstance(event, StopEvent): + if isinstance(event.result, str): + yield event.result + elif isinstance(event.result, AsyncGenerator): + async for chunk in event.result: + if isinstance(chunk, str): + yield chunk + elif ( + hasattr(chunk, "delta") and chunk.delta.strip() + ): # Only yield non-empty deltas + yield chunk.delta @classmethod - def convert_text(cls, token: str): + def convert_text(cls, token: str) -> str: + """Convert text event to Vercel format.""" # Escape newlines and double quotes to avoid breaking the stream token = json.dumps(token) return f"{cls.TEXT_PREFIX}{token}\n" @classmethod - def convert_data(cls, data: dict): + def convert_data(cls, data: dict) -> str: + """Convert data event to 
Vercel format.""" data_str = json.dumps(data) return f"{cls.DATA_PREFIX}[{data_str}]\n" @classmethod - def convert_error(cls, error: str): + def convert_error(cls, error: str) -> str: + """Convert error event to Vercel format.""" error_str = json.dumps(error) return f"{cls.ERROR_PREFIX}{error_str}\n" - - @staticmethod - def _process_response_nodes( - source_nodes: List[NodeWithScore], - background_tasks: BackgroundTasks, - ): - try: - # Start background tasks to download documents from LlamaCloud if needed - from app.engine.service import LLamaCloudFileService # type: ignore - - LLamaCloudFileService.download_files_from_nodes( - source_nodes, background_tasks - ) - except ImportError: - logger.debug( - "LlamaCloud is not configured. Skipping post processing of nodes" - ) - pass - - @staticmethod - async def _generate_next_questions(chat_history: List[Message], response: str): - questions = await NextQuestionSuggestion.suggest_next_questions( - chat_history, response - ) - if questions: - return { - "type": "suggested_questions", - "data": questions, - } - return None diff --git a/templates/components/engines/python/agent/tools/__init__.py b/templates/types/streaming/fastapi/app/engine/tools/__init__.py similarity index 100% rename from templates/components/engines/python/agent/tools/__init__.py rename to templates/types/streaming/fastapi/app/engine/tools/__init__.py diff --git a/templates/components/engines/python/agent/tools/artifact.py b/templates/types/streaming/fastapi/app/engine/tools/artifact.py similarity index 100% rename from templates/components/engines/python/agent/tools/artifact.py rename to templates/types/streaming/fastapi/app/engine/tools/artifact.py diff --git a/templates/components/engines/python/agent/tools/document_generator.py b/templates/types/streaming/fastapi/app/engine/tools/document_generator.py similarity index 100% rename from templates/components/engines/python/agent/tools/document_generator.py rename to templates/types/streaming/fastapi/app/engine/tools/document_generator.py diff --git a/templates/components/engines/python/agent/tools/duckduckgo.py b/templates/types/streaming/fastapi/app/engine/tools/duckduckgo.py similarity index 100% rename from templates/components/engines/python/agent/tools/duckduckgo.py rename to templates/types/streaming/fastapi/app/engine/tools/duckduckgo.py diff --git a/templates/components/engines/python/agent/tools/form_filling.py b/templates/types/streaming/fastapi/app/engine/tools/form_filling.py similarity index 100% rename from templates/components/engines/python/agent/tools/form_filling.py rename to templates/types/streaming/fastapi/app/engine/tools/form_filling.py diff --git a/templates/components/engines/python/agent/tools/img_gen.py b/templates/types/streaming/fastapi/app/engine/tools/img_gen.py similarity index 100% rename from templates/components/engines/python/agent/tools/img_gen.py rename to templates/types/streaming/fastapi/app/engine/tools/img_gen.py diff --git a/templates/components/engines/python/agent/tools/interpreter.py b/templates/types/streaming/fastapi/app/engine/tools/interpreter.py similarity index 100% rename from templates/components/engines/python/agent/tools/interpreter.py rename to templates/types/streaming/fastapi/app/engine/tools/interpreter.py diff --git a/templates/components/engines/python/agent/tools/openapi_action.py b/templates/types/streaming/fastapi/app/engine/tools/openapi_action.py similarity index 100% rename from templates/components/engines/python/agent/tools/openapi_action.py rename to 
templates/types/streaming/fastapi/app/engine/tools/openapi_action.py diff --git a/templates/components/engines/python/agent/tools/query_engine.py b/templates/types/streaming/fastapi/app/engine/tools/query_engine.py similarity index 100% rename from templates/components/engines/python/agent/tools/query_engine.py rename to templates/types/streaming/fastapi/app/engine/tools/query_engine.py diff --git a/templates/components/engines/python/agent/tools/weather.py b/templates/types/streaming/fastapi/app/engine/tools/weather.py similarity index 100% rename from templates/components/engines/python/agent/tools/weather.py rename to templates/types/streaming/fastapi/app/engine/tools/weather.py diff --git a/templates/types/streaming/fastapi/app/workflows/__init__.py b/templates/types/streaming/fastapi/app/workflows/__init__.py new file mode 100644 index 000000000..29c530646 --- /dev/null +++ b/templates/types/streaming/fastapi/app/workflows/__init__.py @@ -0,0 +1,4 @@ +from .agent import create_workflow + + +__all__ = ["create_workflow"] diff --git a/templates/components/engines/python/agent/engine.py b/templates/types/streaming/fastapi/app/workflows/agent.py similarity index 63% rename from templates/components/engines/python/agent/engine.py rename to templates/types/streaming/fastapi/app/workflows/agent.py index 3cc52e1ed..6dcfd76e5 100644 --- a/templates/components/engines/python/agent/engine.py +++ b/templates/types/streaming/fastapi/app/workflows/agent.py @@ -1,8 +1,7 @@ import os from typing import List -from llama_index.core.agent import AgentRunner -from llama_index.core.callbacks import CallbackManager +from llama_index.core.agent.workflow import AgentWorkflow from llama_index.core.settings import Settings from llama_index.core.tools import BaseTool @@ -11,13 +10,14 @@ from app.engine.tools.query_engine import get_query_engine_tool -def get_chat_engine(params=None, event_handlers=None, **kwargs): +def create_workflow(params=None, **kwargs): + if params is None: + params = {} system_prompt = os.getenv("SYSTEM_PROMPT") tools: List[BaseTool] = [] - callback_manager = CallbackManager(handlers=event_handlers or []) # Add query tool if index exists - index_config = IndexConfig(callback_manager=callback_manager, **(params or {})) + index_config = IndexConfig(**params) index = get_index(index_config) if index is not None: query_engine_tool = get_query_engine_tool(index, **kwargs) @@ -27,10 +27,11 @@ def get_chat_engine(params=None, event_handlers=None, **kwargs): configured_tools: List[BaseTool] = ToolFactory.from_env() tools.extend(configured_tools) - return AgentRunner.from_llm( + if len(tools) == 0: + raise RuntimeError("Please provide at least one tool!") + + return AgentWorkflow.from_tools_or_functions( + tools_or_functions=tools, # type: ignore llm=Settings.llm, - tools=tools, system_prompt=system_prompt, - callback_manager=callback_manager, - verbose=True, ) diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx index 9d065b1f2..aed2b0df9 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/chat-message-content.tsx @@ -1,17 +1,29 @@ import { ChatMessage } from "@llamaindex/chat-ui"; import { DeepResearchCard } from "./custom/deep-research-card"; +import { ArtifactToolComponent } from "./tools/artifact"; import { ToolAnnotations } from "./tools/chat-tools"; - +import { 
ChatSourcesComponent, RetrieverComponent } from "./tools/query-index"; +import { WeatherToolComponent } from "./tools/weather-card"; export function ChatMessageContent() { return ( + + + {/* For backward compatibility with the events from AgentRunner + * ToolAnnotations will be removed when we migrate to AgentWorkflow completely + */} + + + {/* For backward compatibility with the events from AgentRunner. + * The Source component will be removed when we migrate to AgentWorkflow completely + */} diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/custom/deep-research-card.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/custom/deep-research-card.tsx index bc6118e61..41aab5341 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/custom/deep-research-card.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/custom/deep-research-card.tsx @@ -157,53 +157,57 @@ export function DeepResearchCard({ className }: DeepResearchCardProps) { if (!state) return null; return ( - - - {state.retrieve.state !== null && ( - - - {state.retrieve.state === "inprogress" - ? "Searching..." - : "Search completed"} - - )} - {state.analyze.state !== null && ( - - - {state.analyze.state === "inprogress" ? "Analyzing..." : "Analysis"} - - )} - - - - {state.analyze.questions.length > 0 && ( - - {state.analyze.questions.map((question: QuestionState) => ( - - -
-
- {stateIcon[question.state]} + state.analyze.questions.length > 0 && ( + + + {state.retrieve.state !== null && ( + + + {state.retrieve.state === "inprogress" + ? "Searching..." + : "Search completed"} + + )} + {state.analyze.state !== null && ( + + + {state.analyze.state === "inprogress" + ? "Analyzing..." + : "Analysis"} + + )} + + + + {state.analyze.questions.length > 0 && ( + + {state.analyze.questions.map((question: QuestionState) => ( + + +
+
+ {stateIcon[question.state]} +
+ + {question.question} +
- - {question.question} - -
- - {question.answer && ( - - - - )} - - ))} - - )} - - + + {question.answer && ( + + + + )} + + ))} + + )} + + + ) ); } diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/tools/artifact.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/tools/artifact.tsx index fe6e81998..f3aff154e 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/tools/artifact.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/tools/artifact.tsx @@ -1,7 +1,13 @@ "use client"; +import { + getCustomAnnotation, + useChatMessage, + useChatUI, +} from "@llamaindex/chat-ui"; import { Check, ChevronDown, Code, Copy, Loader2 } from "lucide-react"; -import { useEffect, useRef, useState } from "react"; +import { useEffect, useMemo, useRef, useState } from "react"; +import { z } from "zod"; import { Button, buttonVariants } from "../../button"; import { Collapsible, @@ -386,3 +392,61 @@ function closePanel() { panel.classList.add("hidden"); }); } + +const ArtifactToolSchema = z.object({ + tool_name: z.literal("artifact"), + tool_kwargs: z.object({ + query: z.string(), + }), + tool_id: z.string(), + tool_output: z.object({ + content: z.string(), + tool_name: z.string(), + raw_input: z.object({ + args: z.array(z.unknown()), + kwargs: z.object({ + query: z.string(), + }), + }), + raw_output: z.custom(), + is_error: z.boolean(), + }), + return_direct: z.boolean().optional(), +}); + +type ArtifactTool = z.infer; + +export function ArtifactToolComponent() { + const { message } = useChatMessage(); + const { messages } = useChatUI(); + + const artifactOutputEvent = getCustomAnnotation( + message.annotations, + (annotation: unknown) => { + const result = ArtifactToolSchema.safeParse(annotation); + return result.success; + }, + ).at(0); + + const artifactVersion = useMemo(() => { + const artifactToolCalls = messages.filter((m) => + m.annotations?.some( + (a: unknown) => (a as ArtifactTool).tool_name === "artifact", + ), + ); + return artifactToolCalls.length; + }, [messages]); + + return ( + artifactOutputEvent && ( +
+ {artifactOutputEvent && ( + + )} +
+ ) + ); +} diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx new file mode 100644 index 000000000..23698d5d6 --- /dev/null +++ b/templates/types/streaming/nextjs/app/components/ui/chat/tools/query-index.tsx @@ -0,0 +1,131 @@ +"use client"; + +import { + getCustomAnnotation, + SourceNode, + useChatMessage, +} from "@llamaindex/chat-ui"; +import { ChatEvents, ChatSources } from "@llamaindex/chat-ui/widgets"; +import { useMemo } from "react"; +import { z } from "zod"; + +const QueryIndexSchema = z.object({ + tool_name: z.literal("query_index"), + tool_kwargs: z.object({ + input: z.string(), + }), + tool_id: z.string(), + tool_output: z.optional( + z.object({ + content: z.string(), + tool_name: z.string(), + raw_output: z.object({ + source_nodes: z.array( + z.object({ + node: z.object({ + id_: z.string(), + metadata: z.object({ + url: z.string(), + }), + text: z.string(), + }), + score: z.number(), + }), + ), + }), + is_error: z.boolean().optional(), + }), + ), + return_direct: z.boolean().optional(), +}); +type QueryIndex = z.infer; + +type GroupedIndexQuery = { + initial: QueryIndex; + output?: QueryIndex; +}; + +export function RetrieverComponent() { + const { message } = useChatMessage(); + + const queryIndexEvents = getCustomAnnotation( + message.annotations, + (annotation) => { + const result = QueryIndexSchema.safeParse(annotation); + return result.success; + }, + ); + + // Group events by tool_id and render them in a single ChatEvents component + const groupedIndexQueries = useMemo(() => { + const groups = new Map(); + + queryIndexEvents?.forEach((event) => { + groups.set(event.tool_id, { initial: event }); + }); + + return Array.from(groups.values()); + }, [queryIndexEvents]); + + return ( + groupedIndexQueries.length > 0 && ( +
+ {groupedIndexQueries.map(({ initial }) => { + const eventData = [ + { + title: `Searching index with query: ${initial.tool_kwargs.input}`, + }, + ]; + + if (initial.tool_output) { + eventData.push({ + title: `Got ${JSON.stringify(initial.tool_output?.raw_output.source_nodes?.length ?? 0)} sources for query: ${initial.tool_kwargs.input}`, + }); + } + + return ( + + ); + })} +
+ ) + ); +} + +/** + * Render the source nodes whenever we got query_index tool with output + */ +export function ChatSourcesComponent() { + const { message } = useChatMessage(); + + const queryIndexEvents = getCustomAnnotation( + message.annotations, + (annotation) => { + const result = QueryIndexSchema.safeParse(annotation); + return result.success && !!result.data.tool_output; + }, + ); + + const sources: SourceNode[] = useMemo(() => { + return ( + queryIndexEvents?.flatMap((event) => { + const sourceNodes = event.tool_output?.raw_output?.source_nodes || []; + return sourceNodes.map((node) => { + return { + id: node.node.id_, + metadata: node.node.metadata, + score: node.score, + text: node.node.text, + url: node.node.metadata.url, + }; + }); + }) || [] + ); + }, [queryIndexEvents]); + + return ; +} diff --git a/templates/types/streaming/nextjs/app/components/ui/chat/tools/weather-card.tsx b/templates/types/streaming/nextjs/app/components/ui/chat/tools/weather-card.tsx index 8720c042a..74e40cf70 100644 --- a/templates/types/streaming/nextjs/app/components/ui/chat/tools/weather-card.tsx +++ b/templates/types/streaming/nextjs/app/components/ui/chat/tools/weather-card.tsx @@ -1,3 +1,8 @@ +import { getCustomAnnotation, useChatMessage } from "@llamaindex/chat-ui"; +import { ChatEvents } from "@llamaindex/chat-ui/widgets"; +import { useMemo } from "react"; +import { z } from "zod"; + export interface WeatherData { latitude: number; longitude: number; @@ -177,37 +182,120 @@ export function WeatherCard({ data }: { data: WeatherData }) { ); return ( -
-
-
-
{currentDayString}
-
- - {data.current.temperature_2m} {data.current_units.temperature_2m} - - {weatherCodeDisplayMap[data.current.weather_code].icon} + data && ( +
+
+
+
{currentDayString}
+
+ + {data.current.temperature_2m}{" "} + {data.current_units.temperature_2m} + + {weatherCodeDisplayMap[data.current.weather_code].icon} +
+ + {weatherCodeDisplayMap[data.current.weather_code].status} + +
+
+ {data.daily.time.map((time, index) => { + if (index === 0) return null; // skip the current day + return ( +
+ {displayDay(time)} +
+ {weatherCodeDisplayMap[data.daily.weather_code[index]].icon} +
+ + {weatherCodeDisplayMap[data.daily.weather_code[index]].status} + +
+ ); + })}
- - {weatherCodeDisplayMap[data.current.weather_code].status} -
-
- {data.daily.time.map((time, index) => { - if (index === 0) return null; // skip the current day + ) + ); +} + +// A new component for the weather tool which uses the WeatherCard component with the new data schema from agent workflow events +const WeatherToolSchema = z.object({ + tool_name: z.literal("get_weather_information"), + tool_kwargs: z.object({ + location: z.string(), + }), + tool_id: z.string(), + tool_output: z.optional( + z + .object({ + content: z.string(), + tool_name: z.string(), + raw_input: z.record(z.unknown()), + raw_output: z.custom(), + is_error: z.boolean().optional(), + }) + .optional(), + ), + return_direct: z.boolean().optional(), +}); + +type WeatherTool = z.infer; + +type GroupedWeatherQuery = { + initial: WeatherTool; + output?: WeatherTool; +}; + +export function WeatherToolComponent() { + const { message } = useChatMessage(); + + const weatherEvents = getCustomAnnotation( + message.annotations, + (annotation: unknown) => { + const result = WeatherToolSchema.safeParse(annotation); + return result.success; + }, + ); + + // Group events by tool_id + const groupedWeatherQueries = useMemo(() => { + const groups = new Map(); + + weatherEvents?.forEach((event: WeatherTool) => { + groups.set(event.tool_id, { initial: event }); + }); + + return Array.from(groups.values()); + }, [weatherEvents]); + + return ( + groupedWeatherQueries.length > 0 && ( +
+ {groupedWeatherQueries.map(({ initial }) => { + if (!initial.tool_output?.raw_output) { + return ( + + ); + } + return ( -
- {displayDay(time)} -
- {weatherCodeDisplayMap[data.daily.weather_code[index]].icon} -
- - {weatherCodeDisplayMap[data.daily.weather_code[index]].status} - -
+ ); })}
-
+ ) ); } diff --git a/templates/types/streaming/nextjs/package.json b/templates/types/streaming/nextjs/package.json index 4407ce861..f4bfb1104 100644 --- a/templates/types/streaming/nextjs/package.json +++ b/templates/types/streaming/nextjs/package.json @@ -37,7 +37,8 @@ "tiktoken": "^1.0.15", "uuid": "^9.0.1", "marked": "^14.1.2", - "wikipedia": "^2.1.2" + "wikipedia": "^2.1.2", + "zod": "^3.24.2" }, "devDependencies": { "@types/node": "^20.10.3",