Skip to content

Commit e4b408f

Browse files
committed
update cli
1 parent 77ffbc9 commit e4b408f

File tree

10 files changed

+262
-673
lines changed

10 files changed

+262
-673
lines changed

examples/refresh_vectorstore/tpuf_namespace.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# ]
66
# ///
77

8+
import os
89
from datetime import timedelta
910

1011
from prefect import flow, task
@@ -57,14 +58,17 @@
5758
}
5859

5960

61+
def _cache_key_with_invalidation(context, parameters) -> str:
    """Build a task cache key that can be invalidated via an env var.

    Appends the ``RAGGY_CACHE_VERSION`` environment variable (default
    ``"0"``) to the normal ``task_input_hash`` key, so bumping the version
    invalidates every previously cached task result at once.

    Args:
        context: The Prefect task-run context passed to ``cache_key_fn``.
        parameters: The task's call parameters, hashed for the key.

    Returns:
        ``"<input-hash>:<cache-version>"``.
    """
    # Hoisted into a variable: the original nested double quotes inside a
    # double-quoted f-string, which is a SyntaxError on Python < 3.12
    # (quote reuse in f-strings only arrived with PEP 701).
    cache_version = os.getenv("RAGGY_CACHE_VERSION", "0")
    return f"{task_input_hash(context, parameters)}:{cache_version}"
63+
64+
6065
@task(
61-
retries=2,
62-
retry_delay_seconds=[3, 60],
63-
cache_key_fn=task_input_hash,
66+
retries=1,
67+
retry_delay_seconds=3,
68+
cache_key_fn=_cache_key_with_invalidation,
6469
cache_expiration=timedelta(days=1),
6570
task_run_name="Run {loader.__class__.__name__}",
6671
persist_result=True,
67-
# refresh_cache=True,
6872
)
6973
async def run_loader(loader: Loader) -> list[Document]:
7074
return await loader.load()

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dependencies = [
1919
"chardet",
2020
"fake-useragent",
2121
"gh-util",
22-
"openai>1.0.0",
22+
"pydantic-ai-slim[openai]",
2323
"pypdf",
2424
"tenacity",
2525
"tiktoken",

src/raggy/cli/__init__.py

+87-76
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import argparse
22
import os
33
import sys
4+
import asyncio
45
from datetime import datetime, timezone
56
from pathlib import Path
6-
7-
import openai
87
from prompt_toolkit import PromptSession
98
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
109
from prompt_toolkit.history import FileHistory
@@ -15,7 +14,10 @@
1514
from rich.syntax import Syntax
1615
from rich.text import Text
1716

17+
from pydantic_ai import Agent, result as pa_result
18+
1819

20+
# Prettify code fences with Rich
1921
class SimpleCodeBlock(CodeBlock):
2022
def __rich_console__(
2123
self, console: Console, options: ConsoleOptions
@@ -39,61 +41,53 @@ def app() -> int:
3941
parser = argparse.ArgumentParser(
4042
prog="aicli",
4143
description="""\
42-
OpenAI powered AI CLI (thank you samuelcolvin)
44+
Pydantic AI powered CLI
4345
4446
Special prompts:
4547
* `show-markdown` - show the markdown output from the previous response
4648
* `multiline` - toggle multiline mode
4749
""",
4850
)
49-
parser.add_argument(
50-
"prompt", nargs="?", help="AI Prompt, if omitted fall into interactive mode"
51-
)
52-
53-
parser.add_argument(
54-
"--no-stream",
55-
action="store_true",
56-
help="Whether to stream responses from OpenAI",
57-
)
58-
51+
parser.add_argument("prompt", nargs="?", help="AI Prompt, else interactive mode")
52+
parser.add_argument("--no-stream", action="store_true", help="Disable streaming")
5953
parser.add_argument("--version", action="store_true", help="Show version and exit")
6054

6155
args = parser.parse_args()
6256

6357
console = Console()
64-
console.print("OpenAI powered AI CLI", style="green bold", highlight=False)
58+
console.print("Pydantic AI CLI", style="green bold", highlight=False)
6559
if args.version:
6660
return 0
6761

68-
try:
69-
openai_api_key = os.environ["OPENAI_API_KEY"]
70-
except KeyError:
62+
# Check for an API key (e.g. OPENAI_API_KEY)
63+
if "OPENAI_API_KEY" not in os.environ:
7164
console.print(
7265
"You must set the OPENAI_API_KEY environment variable", style="red"
7366
)
7467
return 1
7568

76-
client = openai.OpenAI(api_key=openai_api_key)
77-
78-
now_utc = datetime.now(timezone.utc)
79-
t = now_utc.astimezone().tzinfo.tzname(now_utc) # type: ignore
80-
setup = f"""\
81-
Help the user by responding to their request, the output should
82-
be concise and always written in markdown. The current date and time
83-
is {datetime.now()} {t}. The user is running {sys.platform}."""
69+
# Create your agent; we set a global system prompt
70+
agent = Agent(
71+
"openai:gpt-4o",
72+
system_prompt="Be a helpful assistant and respond in concise markdown.",
73+
)
8474

75+
# We'll accumulate the conversation in here (both user and assistant messages)
76+
conversation = None
8577
stream = not args.no_stream
86-
messages = [{"role": "system", "content": setup}]
8778

79+
# If the user supplied a single prompt, just run once
8880
if args.prompt:
89-
messages.append({"role": "user", "content": args.prompt})
9081
try:
91-
ask_openai(client, messages, stream, console)
82+
asyncio.run(
83+
run_and_display(agent, args.prompt, conversation, stream, console)
84+
)
9285
except KeyboardInterrupt:
9386
pass
9487
return 0
9588

96-
history = Path().home() / ".openai-prompt-history.txt"
89+
# Otherwise, interactive mode with prompt_toolkit
90+
history = Path.home() / ".openai-prompt-history.txt"
9791
session = PromptSession(history=FileHistory(str(history)))
9892
multiline = False
9993

@@ -105,70 +99,87 @@ def app() -> int:
10599
except (KeyboardInterrupt, EOFError):
106100
return 0
107101

108-
if not text.strip():
102+
cmd = text.lower().strip()
103+
if not cmd:
109104
continue
110105

111-
ident_prompt = text.lower().strip(" ").replace(" ", "-")
112-
if ident_prompt == "show-markdown":
113-
last_content = messages[-1]["content"]
114-
console.print("[dim]Last markdown output of last question:[/dim]\n")
115-
console.print(
116-
Syntax(last_content, lexer="markdown", background_color="default")
117-
)
106+
if cmd == "show-markdown":
107+
# Show last assistant message
108+
if not conversation:
109+
console.print("No messages yet.", style="dim")
110+
continue
111+
# The last run result's assistant message is the last item
112+
# (the user might have broken the loop, so we search from end)
113+
assistant_msg = None
114+
for m in reversed(conversation):
115+
if m.kind == "response":
116+
# Collect text parts from the response
117+
text_part = "".join(
118+
p.content for p in m.parts if p.part_kind == "text"
119+
)
120+
assistant_msg = text_part
121+
break
122+
if assistant_msg:
123+
console.print("[dim]Last assistant markdown output:[/dim]\n")
124+
console.print(
125+
Syntax(assistant_msg, lexer="markdown", background_color="default")
126+
)
127+
else:
128+
console.print("No assistant response found.", style="dim")
118129
continue
119-
elif ident_prompt == "multiline":
130+
131+
elif cmd == "multiline":
120132
multiline = not multiline
121133
if multiline:
122134
console.print(
123135
"Enabling multiline mode. "
124-
"[dim]Press [Meta+Enter] or [Esc] followed by [Enter] to accept input.[/dim]"
136+
"[dim]Press [Meta+Enter] or [Esc] then [Enter] to submit.[/dim]"
125137
)
126138
else:
127139
console.print("Disabling multiline mode.")
128140
continue
129141

130-
messages.append({"role": "user", "content": text})
131-
142+
# Normal user prompt
132143
try:
133-
content = ask_openai(client, messages, stream, console)
144+
conversation = asyncio.run(
145+
run_and_display(agent, text, conversation, stream, console)
146+
)
134147
except KeyboardInterrupt:
135148
return 0
136-
messages.append({"role": "assistant", "content": content})
137-
138-
139-
def ask_openai(
140-
client: openai.OpenAI,
141-
messages: list[dict[str, str]],
142-
stream: bool,
143-
console: Console,
144-
) -> str:
145-
with Status("[dim]Working on it…[/dim]", console=console):
146-
response = client.chat.completions.create(
147-
model="gpt-4", messages=messages, stream=stream
148-
)
149149

150+
return 0
151+
152+
153+
async def run_and_display(
    agent: Agent, user_text: str, conversation, stream: bool, console: Console
):
    """Run the agent on ``user_text`` and render the reply to ``console``.

    When ``conversation`` is None the agent starts a fresh exchange
    (including its system prompt); otherwise ``conversation`` is passed as
    ``message_history`` so the dialogue continues. Returns the updated
    message list for the caller to carry forward.
    """
    console.print("\nResponse:", style="green")

    live_display = Live(
        "[dim]Working on it…[/dim]",
        console=console,
        refresh_per_second=15,
        vertical_overflow="visible",
    )
    with live_display as live:
        if not stream:
            # Single round trip: run to completion, then render once.
            run_result = await agent.run(user_text, message_history=conversation)
            live.update(Markdown(run_result.data))
            return run_result.all_messages()

        # Streaming: repaint the live region with each accumulated chunk.
        async with agent.run_stream(user_text, message_history=conversation) as run:
            try:
                async for chunk in run.stream_text():
                    live.update(Markdown(chunk))
            except Exception as e:
                # Best-effort: report the failure but still return whatever
                # messages the run accumulated so the conversation survives.
                console.print(f"Error: {e}", style="red")
            return run.all_messages()
172183

173184

174185
if __name__ == "__main__":

src/raggy/documents.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
import asyncio
22
import inspect
33
from functools import partial
4-
from typing import Annotated
4+
from typing import Annotated, Callable
55

66
from jinja2 import Environment, Template
7-
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
7+
from pydantic import (
8+
BaseModel,
9+
ConfigDict,
10+
Field,
11+
PrivateAttr,
12+
field_validator,
13+
model_validator,
14+
)
815

916
from raggy.utilities.ids import generate_prefixed_uuid
1017
from raggy.utilities.text import count_tokens, extract_keywords, hash_text, split_text
@@ -37,6 +44,8 @@ class Document(BaseModel):
3744
tokens: int | None = Field(default=None)
3845
keywords: list[str] = Field(default_factory=list)
3946

47+
_parent_document_id: str | None = PrivateAttr(default=None)
48+
4049
@field_validator("metadata", mode="before")
4150
@classmethod
4251
def ensure_metadata(cls, v):
@@ -76,6 +85,7 @@ async def document_to_excerpts(
7685
excerpt_template: Template | None = None,
7786
chunk_tokens: int = 300,
7887
overlap: Annotated[float, Field(strict=True, ge=0, le=1)] = 0.1,
88+
split_text_fn: Callable[..., list[str]] = split_text,
7989
**extra_template_kwargs,
8090
) -> list[Document]:
8191
"""
@@ -91,7 +101,7 @@ async def document_to_excerpts(
91101
if not excerpt_template:
92102
excerpt_template = EXCERPT_TEMPLATE
93103

94-
text_chunks: list[str] = split_text(
104+
text_chunks: list[str] = split_text_fn(
95105
text=document.text,
96106
chunk_size=chunk_tokens,
97107
chunk_overlap=overlap,
@@ -126,7 +136,7 @@ async def _create_excerpt(
126136
**extra_template_kwargs,
127137
)
128138
return Document(
129-
parent_document_id=document.id,
139+
_parent_document_id=document.id, # type: ignore[reportCallIssue]
130140
text=excerpt_text,
131141
keywords=keywords,
132142
metadata=document.metadata if document.metadata else {},

0 commit comments

Comments
 (0)