Skip to content

Commit d3e8e16

Browse files
authored
Merge pull request #881 from PrefectHQ/instructions
Allow assistant instructions to be jinja and self-referential
2 parents 0da9680 + d52c07b commit d3e8e16

File tree

3 files changed

+51
-47
lines changed

3 files changed

+51
-47
lines changed

docs/docs/interactive/assistants.md

+3
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,9 @@ Each assistant can be given `instructions` that describe its purpose, personalit
171171
!!! success "Result"
172172
![](/assets/images/docs/assistants/instructions.png)
173173

174+
Instructions are rendered as a Jinja template, which means you can use variables and conditionals to customize the assistant's behavior. A special variable, `self_` is provided to the template, which represents the assistant object itself. This allows you to template the assistant's name, tools, or other attributes into the instructions.
175+
176+
174177
### Tools
175178

176179
Each assistant can be given a list of `tools` that it can use when responding to a message. Tools are a way to extend the assistant's capabilities beyond its default behavior, including giving it access to external systems like the internet, a database, your computer, or any API.

src/marvin/beta/assistants/assistants.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
run_async,
2121
run_sync,
2222
)
23+
from marvin.utilities.jinja import Environment as JinjaEnvironment
2324
from marvin.utilities.logging import get_logger
2425

2526
from .threads import Thread
@@ -92,7 +93,9 @@ def get_tools(self) -> list[AssistantTool]:
9293
]
9394

9495
def get_instructions(self, thread: Thread = None) -> str:
95-
return self.instructions or ""
96+
if self.instructions:
97+
return JinjaEnvironment.render(self.instructions, self_=self)
98+
return ""
9699

97100
@expose_sync_method("say")
98101
async def say_async(

tests/ai/beta/vision/test_extract.py

+44-46
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
class Location(BaseModel):
1010
city: str
11-
state: str = Field(description="The two letter abbreviation")
11+
state: str = Field(description="The two letter abbreviation for the state")
1212

1313

1414
@pytest.mark.flaky(max_runs=2)
@@ -17,45 +17,45 @@ def test_ny(self):
1717
img = marvin.beta.Image(
1818
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
1919
)
20-
result = marvin.beta.extract(img, target=Location)
21-
assert result in (
22-
[Location(city="New York", state="NY")],
23-
[Location(city="New York City", state="NY")],
24-
)
20+
locations = marvin.beta.extract(img, target=Location)
21+
assert len(locations) == 1
22+
location = locations[0]
23+
assert location.city.startswith("New York") or location.city == "Manhattan"
24+
assert location.state == "NY"
2525

2626
def test_ny_images_input(self):
2727
img = marvin.beta.Image(
2828
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
2929
)
30-
result = marvin.beta.extract(data=None, images=[img], target=Location)
31-
assert result in (
32-
[Location(city="New York", state="NY")],
33-
[Location(city="New York City", state="NY")],
34-
)
30+
locations = marvin.beta.extract(data=None, images=[img], target=Location)
31+
assert len(locations) == 1
32+
location = locations[0]
33+
assert location.city.startswith("New York") or location.city == "Manhattan"
34+
assert location.state == "NY"
3535

3636
def test_ny_image_input(self):
3737
img = marvin.beta.Image(
3838
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
3939
)
40-
result = marvin.beta.extract(data=img, target=Location)
41-
assert result in (
42-
[Location(city="New York", state="NY")],
43-
[Location(city="New York City", state="NY")],
44-
)
40+
locations = marvin.beta.extract(data=img, target=Location)
41+
assert len(locations) == 1
42+
location = locations[0]
43+
assert location.city.startswith("New York") or location.city == "Manhattan"
44+
assert location.state == "NY"
4545

4646
def test_ny_image_and_text(self):
4747
img = marvin.beta.Image(
4848
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
4949
)
50-
result = marvin.beta.extract(
50+
locations = marvin.beta.extract(
5151
data="I see the empire state building",
5252
images=[img],
5353
target=Location,
5454
)
55-
assert result in (
56-
[Location(city="New York", state="NY")],
57-
[Location(city="New York City", state="NY")],
58-
)
55+
assert len(locations) == 1
56+
location = locations[0]
57+
assert location.city.startswith("New York") or location.city == "Manhattan"
58+
assert location.state == "NY"
5959

6060
@pytest.mark.flaky(max_runs=3)
6161
def test_dog(self):
@@ -90,11 +90,11 @@ async def test_ny(self):
9090
img = marvin.beta.Image(
9191
"https://images.unsplash.com/photo-1568515387631-8b650bbcdb90"
9292
)
93-
result = await marvin.beta.extract_async(img, target=Location)
94-
assert result in (
95-
[Location(city="New York", state="NY")],
96-
[Location(city="New York City", state="NY")],
97-
)
93+
locations = await marvin.beta.extract_async(img, target=Location)
94+
assert len(locations) == 1
95+
location = locations[0]
96+
assert location.city.startswith("New York") or location.city == "Manhattan"
97+
assert location.state == "NY"
9898

9999

100100
class TestMapping:
@@ -105,16 +105,15 @@ def test_map(self):
105105
dc = marvin.beta.Image(
106106
"https://images.unsplash.com/photo-1617581629397-a72507c3de9e"
107107
)
108-
result = marvin.beta.extract.map([ny, dc], target=Location)
109-
assert isinstance(result, list)
110-
assert result[0][0] in (
111-
Location(city="New York", state="NY"),
112-
Location(city="New York City", state="NY"),
113-
)
114-
assert result[1][0] in (
115-
Location(city="Washington", state="DC"),
116-
Location(city="Washington", state="D.C."),
117-
)
108+
locations = marvin.beta.extract.map([ny, dc], target=Location)
109+
assert len(locations) == 2
110+
ny_location, dc_location = locations
111+
112+
assert ny_location[0].city.startswith("New York")
113+
assert ny_location[0].state == "NY"
114+
115+
assert dc_location[0].city == "Washington"
116+
assert dc_location[0].state.index("D") < dc_location[0].state.index("C")
118117

119118
async def test_async_map(self):
120119
ny = marvin.beta.Image(
@@ -123,13 +122,12 @@ async def test_async_map(self):
123122
dc = marvin.beta.Image(
124123
"https://images.unsplash.com/photo-1617581629397-a72507c3de9e"
125124
)
126-
result = await marvin.beta.extract_async.map([ny, dc], target=Location)
127-
assert isinstance(result, list)
128-
assert result[0][0] in (
129-
Location(city="New York", state="NY"),
130-
Location(city="New York City", state="NY"),
131-
)
132-
assert result[1][0] in (
133-
Location(city="Washington", state="DC"),
134-
Location(city="Washington", state="D.C."),
135-
)
125+
locations = await marvin.beta.extract_async.map([ny, dc], target=Location)
126+
assert len(locations) == 2
127+
ny_location, dc_location = locations
128+
129+
assert ny_location[0].city.startswith("New York")
130+
assert ny_location[0].state == "NY"
131+
132+
assert dc_location[0].city == "Washington"
133+
assert dc_location[0].state.index("D") < dc_location[0].state.index("C")

0 commit comments

Comments
 (0)