Skip to content

Commit e814140

Browse files
YusakuNo1 and l0lawrence
authored and committed
Prompt support for Inference SDK (Azure#37917)
* Prompty support within Azure AI Inference SDK * Fix unit test * Address PR feedback with copyright, merge PromptConfig to PromptTemplate * Add comment and set model_name as optional * Bug fixes * Updated parameter names from PM feedbacks * Improve sample code and unit tests * Update readme and comments * Rename files * Address PR comment * add Pydantic as dependency * Fix type errors * Fix spelling issues * Address PR comments and fix linter issues * Fix type import for "Self" * Change to keyword-only constructor and fix linter issues * Rename function `from_message` to `from_str`; `render` to `create_messages` * Change from `from_str` to `from_string` * Merge latest code from `microsoft/prompty` and resolve linter issues * Fix PR comment * Fix PR comments
1 parent 73362b3 commit e814140

21 files changed

+2751
-12
lines changed

.vscode/cspell.json

+13-5
Original file line numberDiff line numberDiff line change
@@ -1324,12 +1324,20 @@
13241324
{
13251325
"filename": "sdk/ai/azure-ai-inference/**",
13261326
"words": [
1327-
"ubinary",
1328-
"mros",
1329-
"Nify",
13301327
"ctxt",
1331-
"wday",
1332-
"dtype"
1328+
"dels",
1329+
"dtype",
1330+
"fmatter",
1331+
"fspath",
1332+
"fstring",
1333+
"ldel",
1334+
"mros",
1335+
"nify",
1336+
"okwargs",
1337+
"prompty",
1338+
"rdel",
1339+
"ubinary",
1340+
"wday"
13331341
]
13341342
},
13351343
{

sdk/ai/azure-ai-inference/azure/ai/inference/_model_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -674,7 +674,7 @@ def _get_deserialize_callable_from_annotation( # pylint: disable=R0911, R0915,
674674
except AttributeError:
675675
model_name = annotation
676676
if module is not None:
677-
annotation = _get_model(module, model_name)
677+
annotation = _get_model(module, model_name) # type: ignore
678678

679679
try:
680680
if module and _is_model(annotation):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
# pylint: disable=unused-import
6+
from ._patch import patch_sdk as _patch_sdk, PromptTemplate
7+
8+
_patch_sdk()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
# ------------------------------------
2+
# Copyright (c) Microsoft Corporation.
3+
# Licensed under the MIT License.
4+
# ------------------------------------
5+
# mypy: disable-error-code="assignment,attr-defined,index,arg-type"
6+
# pylint: disable=line-too-long,R,consider-iterating-dictionary,raise-missing-from,dangerous-default-value
7+
from __future__ import annotations

import os
from dataclasses import dataclass, field, asdict
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Sequence, Union

from ._tracer import Tracer, to_dict
from ._utils import load_json
14+
15+
16+
@dataclass
class ToolCall:
    """A single tool invocation requested by the model.

    Attributes
    ----------
    id : str
        Identifier of this tool call
    name : str
        Name of the tool being called
    arguments : str
        Arguments for the call, carried as a string
    """

    id: str
    name: str
    arguments: str
21+
22+
23+
@dataclass
class PropertySettings:
    """PropertySettings class to define the properties of the model

    Attributes
    ----------
    type : str
        The type of the property — one of "string", "number", "array",
        "object", or "boolean"
    default : Any
        The default value of the property (None when unset)
    description : str
        The description of the property
    """

    type: Literal["string", "number", "array", "object", "boolean"]
    default: Union[str, int, float, List, Dict, bool, None] = None
    description: str = ""
40+
41+
42+
@dataclass
class ModelSettings:
    """ModelSettings class to define the model of the prompty

    Attributes
    ----------
    api : str
        The api of the model
    configuration : Dict
        The configuration of the model
    parameters : Dict
        The parameters of the model
    response : Dict
        The response of the model
    """

    api: str = ""
    # Mutable defaults go through default_factory so each instance
    # gets its own dict.
    configuration: Dict = field(default_factory=dict)
    parameters: Dict = field(default_factory=dict)
    response: Dict = field(default_factory=dict)
62+
63+
64+
@dataclass
class TemplateSettings:
    """TemplateSettings class to define the template of the prompty

    Attributes
    ----------
    type : str
        The type of the template (defaults to "mustache")
    parser : str
        The parser of the template
    """

    type: str = "mustache"
    parser: str = ""
78+
79+
80+
@dataclass
class Prompty:
    """Prompty class to define the prompty

    Attributes
    ----------
    name : str
        The name of the prompty
    description : str
        The description of the prompty
    authors : List[str]
        The authors of the prompty
    tags : List[str]
        The tags of the prompty
    version : str
        The version of the prompty
    base : str
        The base of the prompty
    basePrompty : Prompty
        The base prompty
    model : ModelSettings
        The model of the prompty
    sample : Dict
        The sample of the prompty
    inputs : Dict[str, PropertySettings]
        The inputs of the prompty
    outputs : Dict[str, PropertySettings]
        The outputs of the prompty
    template : TemplateSettings
        The template of the prompty
    file : FilePath
        The file of the prompty
    content : Union[str, List[str], Dict]
        The content of the prompty
    """

    # metadata
    name: str = field(default="")
    description: str = field(default="")
    authors: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    version: str = field(default="")
    base: str = field(default="")
    basePrompty: Union[Prompty, None] = field(default=None)
    # model
    model: ModelSettings = field(default_factory=ModelSettings)

    # sample
    sample: Dict = field(default_factory=dict)

    # input / output
    inputs: Dict[str, PropertySettings] = field(default_factory=dict)
    outputs: Dict[str, PropertySettings] = field(default_factory=dict)

    # template
    template: TemplateSettings = field(default_factory=TemplateSettings)

    file: Union[Path, str] = field(default="")
    content: Union[str, List[str], Dict] = field(default="")

    def to_safe_dict(self) -> Dict[str, Any]:
        """Return a dict view of the prompty with secret-looking values in
        the model configuration masked — safe for logging or tracing."""
        d: Dict[str, Any] = {}
        if self.model:
            d["model"] = asdict(self.model)
            _mask_secrets(d, ["model", "configuration"])
        if self.template:
            d["template"] = asdict(self.template)
        if self.inputs:
            d["inputs"] = {k: asdict(v) for k, v in self.inputs.items()}
        if self.outputs:
            d["outputs"] = {k: asdict(v) for k, v in self.outputs.items()}
        if self.file:
            d["file"] = str(self.file.as_posix()) if isinstance(self.file, Path) else self.file
        return d

    @staticmethod
    def hoist_base_prompty(top: Prompty, base: Prompty) -> Prompty:
        """Fill empty fields of ``top`` from ``base`` (values already set on
        ``top`` win), record ``base`` as ``top.basePrompty``, and return ``top``."""
        top.name = base.name if top.name == "" else top.name
        top.description = base.description if top.description == "" else top.description
        # authors/tags are merged as sets, so the resulting order is unspecified
        top.authors = list(set(base.authors + top.authors))
        top.tags = list(set(base.tags + top.tags))
        top.version = base.version if top.version == "" else top.version

        top.model.api = base.model.api if top.model.api == "" else top.model.api
        top.model.configuration = param_hoisting(top.model.configuration, base.model.configuration)
        top.model.parameters = param_hoisting(top.model.parameters, base.model.parameters)
        top.model.response = param_hoisting(top.model.response, base.model.response)

        top.sample = param_hoisting(top.sample, base.sample)

        top.basePrompty = base

        return top

    @staticmethod
    def _process_file(file: str, parent: Path) -> Any:
        """Load the JSON file referenced by a ``${file:...}`` attribute
        (resolved relative to ``parent``) and normalize its contents.

        :raises FileNotFoundError: if the referenced file does not exist.
        """
        file_path = Path(parent / Path(file)).resolve().absolute()
        if not file_path.exists():
            raise FileNotFoundError(f"File {file} not found")
        items = load_json(file_path)
        # isinstance against builtin types, not typing aliases
        if isinstance(items, list):
            return [Prompty.normalize(value, parent) for value in items]
        if isinstance(items, dict):
            return {key: Prompty.normalize(value, parent) for key, value in items.items()}
        return items

    @staticmethod
    def _process_env(variable: str, env_error: bool = True, default: Union[str, None] = None) -> Any:
        """Resolve ``variable`` from the process environment.

        Falls back to ``default`` when the variable is unset; otherwise
        raises ValueError if ``env_error`` is True, or returns ``""``.
        """
        if variable in os.environ:
            return os.environ[variable]
        # `is not None` (not truthiness) so an explicit empty-string default
        # — e.g. from "${env:VAR:}" — is honored instead of raising.
        if default is not None:
            return default
        if env_error:
            raise ValueError(f"Variable {variable} not found in environment")

        return ""

    @staticmethod
    def normalize(attribute: Any, parent: Path, env_error: bool = True) -> Any:
        """Recursively resolve ``${env:NAME[:default]}`` and ``${file:path}``
        references in ``attribute`` (str, list, or dict); anything else is
        returned unchanged.

        :raises ValueError: for a ``${...}`` reference that is neither
            ``env`` nor ``file``.
        """
        if isinstance(attribute, str):
            attribute = attribute.strip()
            if attribute.startswith("${") and attribute.endswith("}"):
                # check if env or file
                variable = attribute[2:-1].split(":")
                if variable[0] == "env" and len(variable) > 1:
                    return Prompty._process_env(
                        variable[1],
                        env_error,
                        variable[2] if len(variable) > 2 else None,
                    )
                if variable[0] == "file" and len(variable) > 1:
                    return Prompty._process_file(variable[1], parent)
                raise ValueError(f"Invalid attribute format ({attribute})")
            return attribute
        if isinstance(attribute, list):
            return [Prompty.normalize(value, parent) for value in attribute]
        if isinstance(attribute, dict):
            return {key: Prompty.normalize(value, parent) for key, value in attribute.items()}
        return attribute
225+
226+
227+
def param_hoisting(top: Dict[str, Any], bottom: Dict[str, Any], top_key: Union[str, None] = None) -> Dict[str, Any]:
    """Merge ``bottom`` into ``top`` without overwriting.

    Returns a new dict containing every entry of ``top`` (or of
    ``top[top_key]`` when ``top_key`` is given and present), plus every
    entry of ``bottom`` whose key ``top`` did not already define.
    Neither input is mutated.
    """
    if top_key:
        hoisted = dict(top[top_key]) if top_key in top else {}
    else:
        hoisted = dict(top)
    for key, value in bottom.items():
        if key not in hoisted:  # top's values take precedence
            hoisted[key] = value
    return hoisted
236+
237+
238+
class PromptyStream(Iterator):
    """PromptyStream class to iterate over LLM stream.
    Necessary for Prompty to handle streaming data when tracing.

    Records every item yielded by the wrapped iterator so the full result
    can be emitted as a single trace once the stream is exhausted."""

    def __init__(self, name: str, iterator: Iterator):
        self.name = name
        self.iterator = iterator
        self.items: List[Any] = []
        self.__name__ = "PromptyStream"

    def __iter__(self):
        return self

    def __next__(self):
        try:
            item = next(self.iterator)
        except StopIteration:
            # Underlying stream is exhausted: trace everything that was
            # streamed (only if anything was), then re-signal exhaustion.
            if self.items:
                with Tracer.start("PromptyStream") as trace:
                    trace("signature", f"{self.name}.PromptyStream")
                    trace("inputs", "None")
                    trace("result", [to_dict(s) for s in self.items])

            raise StopIteration

        # Remember the item before handing it to the caller.
        self.items.append(item)
        return item
268+
269+
270+
class AsyncPromptyStream(AsyncIterator):
    """AsyncPromptyStream class to iterate over LLM stream.
    Necessary for Prompty to handle streaming data when tracing.

    Async counterpart of PromptyStream: records every item yielded by the
    wrapped async iterator so the full result can be traced on exhaustion."""

    def __init__(self, name: str, iterator: AsyncIterator):
        self.name = name
        self.iterator = iterator
        self.items: List[Any] = []
        self.__name__ = "AsyncPromptyStream"

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            item = await self.iterator.__anext__()
        except StopAsyncIteration:
            # Underlying stream is exhausted: trace everything that was
            # streamed (only if anything was), then re-signal exhaustion.
            if self.items:
                with Tracer.start("AsyncPromptyStream") as trace:
                    trace("signature", f"{self.name}.AsyncPromptyStream")
                    trace("inputs", "None")
                    trace("result", [to_dict(s) for s in self.items])

            raise StopAsyncIteration

        # Remember the item before handing it to the caller.
        self.items.append(item)
        return item
300+
301+
302+
def _mask_secrets(d: Dict[str, Any], path: list[str], patterns: list[str] = ["key", "secret"]) -> bool:
303+
sub_d = d
304+
for key in path:
305+
if key not in sub_d:
306+
return False
307+
sub_d = sub_d[key]
308+
309+
for k, v in sub_d.items():
310+
if any([pattern in k.lower() for pattern in patterns]):
311+
sub_d[k] = "*" * len(v)
312+
return True

0 commit comments

Comments
 (0)