|
| 1 | +# ------------------------------------ |
| 2 | +# Copyright (c) Microsoft Corporation. |
| 3 | +# Licensed under the MIT License. |
| 4 | +# ------------------------------------ |
| 5 | +# mypy: disable-error-code="assignment,attr-defined,index,arg-type" |
| 6 | +# pylint: disable=line-too-long,R,consider-iterating-dictionary,raise-missing-from,dangerous-default-value |
| 7 | +from __future__ import annotations |
| 8 | +import os |
| 9 | +from dataclasses import dataclass, field, asdict |
| 10 | +from pathlib import Path |
| 11 | +from typing import Any, AsyncIterator, Dict, Iterator, List, Literal, Union |
| 12 | +from ._tracer import Tracer, to_dict |
| 13 | +from ._utils import load_json |
| 14 | + |
| 15 | + |
@dataclass
class ToolCall:
    """ToolCall class to define a tool call issued by the model

    Attributes
    ----------
    id : str
        The id of the tool call
    name : str
        The name of the tool being called
    arguments : str
        The arguments of the tool call, as a string
    """

    id: str
    name: str
    arguments: str
| 21 | + |
| 22 | + |
@dataclass
class PropertySettings:
    """Describes one prompty property (used for inputs and outputs).

    Attributes
    ----------
    type : str
        The type of the property; one of "string", "number",
        "array", "object" or "boolean"
    default : Any
        The default value of the property (None when unset)
    description : str
        The human-readable description of the property
    """

    type: Literal["string", "number", "array", "object", "boolean"]
    default: Union[str, int, float, List, Dict, bool, None] = None
    description: str = ""
| 40 | + |
| 41 | + |
@dataclass
class ModelSettings:
    """Model configuration carried by a prompty.

    Attributes
    ----------
    api : str
        The api of the model
    configuration : Dict
        The configuration of the model
    parameters : Dict
        The parameters of the model
    response : Dict
        The response of the model
    """

    api: str = ""
    configuration: Dict = field(default_factory=dict)
    parameters: Dict = field(default_factory=dict)
    response: Dict = field(default_factory=dict)
| 62 | + |
| 63 | + |
@dataclass
class TemplateSettings:
    """Template configuration carried by a prompty.

    Attributes
    ----------
    type : str
        The type of the template (defaults to "mustache")
    parser : str
        The parser of the template
    """

    type: str = "mustache"
    parser: str = ""
| 78 | + |
| 79 | + |
@dataclass
class Prompty:
    """Prompty class to define the prompty

    Attributes
    ----------
    name : str
        The name of the prompty
    description : str
        The description of the prompty
    authors : List[str]
        The authors of the prompty
    tags : List[str]
        The tags of the prompty
    version : str
        The version of the prompty
    base : str
        The base of the prompty
    basePrompty : Prompty
        The base prompty
    model : ModelSettings
        The model of the prompty
    sample : Dict
        The sample of the prompty
    inputs : Dict[str, PropertySettings]
        The inputs of the prompty
    outputs : Dict[str, PropertySettings]
        The outputs of the prompty
    template : TemplateSettings
        The template of the prompty
    file : FilePath
        The file of the prompty
    content : Union[str, List[str], Dict]
        The content of the prompty
    """

    # metadata
    name: str = field(default="")
    description: str = field(default="")
    authors: List[str] = field(default_factory=list)
    tags: List[str] = field(default_factory=list)
    version: str = field(default="")
    base: str = field(default="")
    basePrompty: Union[Prompty, None] = field(default=None)
    # model
    model: ModelSettings = field(default_factory=ModelSettings)

    # sample
    sample: Dict = field(default_factory=dict)

    # input / output
    inputs: Dict[str, PropertySettings] = field(default_factory=dict)
    outputs: Dict[str, PropertySettings] = field(default_factory=dict)

    # template
    template: TemplateSettings = field(default_factory=TemplateSettings)

    file: Union[Path, str] = field(default="")
    content: Union[str, List[str], Dict] = field(default="")

    def to_safe_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the prompty that is safe to log/trace.

        Secret-looking keys inside the model configuration are masked via
        ``_mask_secrets`` and the file path is rendered as a posix string.
        Only populated sections are included.
        """
        d: Dict[str, Any] = {}
        if self.model:
            d["model"] = asdict(self.model)
            _mask_secrets(d, ["model", "configuration"])
        if self.template:
            d["template"] = asdict(self.template)
        if self.inputs:
            d["inputs"] = {k: asdict(v) for k, v in self.inputs.items()}
        if self.outputs:
            d["outputs"] = {k: asdict(v) for k, v in self.outputs.items()}
        if self.file:
            d["file"] = str(self.file.as_posix()) if isinstance(self.file, Path) else self.file
        return d

    @staticmethod
    def hoist_base_prompty(top: Prompty, base: Prompty) -> Prompty:
        """Fill empty fields of *top* from *base* and merge their settings.

        *top* values win wherever both are set; authors and tags are
        unioned. *base* is recorded on ``top.basePrompty`` and the
        (mutated) *top* is returned.
        """
        top.name = base.name if top.name == "" else top.name
        top.description = base.description if top.description == "" else top.description
        top.authors = list(set(base.authors + top.authors))
        top.tags = list(set(base.tags + top.tags))
        top.version = base.version if top.version == "" else top.version

        top.model.api = base.model.api if top.model.api == "" else top.model.api
        top.model.configuration = param_hoisting(top.model.configuration, base.model.configuration)
        top.model.parameters = param_hoisting(top.model.parameters, base.model.parameters)
        top.model.response = param_hoisting(top.model.response, base.model.response)

        top.sample = param_hoisting(top.sample, base.sample)

        top.basePrompty = base

        return top

    @staticmethod
    def _process_file(file: str, parent: Path) -> Any:
        """Load a ``${file:...}`` reference (JSON) relative to *parent*.

        The loaded content is recursively normalized so that nested
        env/file references are resolved as well.

        Raises
        ------
        FileNotFoundError
            If the referenced file does not exist.
        """
        file_path = Path(parent / Path(file)).resolve().absolute()
        if not file_path.exists():
            raise FileNotFoundError(f"File {file} not found")
        items = load_json(file_path)
        if isinstance(items, list):
            return [Prompty.normalize(value, parent) for value in items]
        # use the concrete dict type, not the typing alias, for isinstance
        if isinstance(items, dict):
            return {key: Prompty.normalize(value, parent) for key, value in items.items()}
        return items

    @staticmethod
    def _process_env(variable: str, env_error=True, default: Union[str, None] = None) -> Any:
        """Resolve a ``${env:...}`` reference from the process environment.

        Falls back to *default* when the variable is unset. When there is
        no default either, raises ValueError if *env_error* is true,
        otherwise returns "".
        """
        if variable in os.environ:
            return os.environ[variable]
        # honor an explicitly supplied default, including the empty string
        if default is not None:
            return default
        if env_error:
            raise ValueError(f"Variable {variable} not found in environment")

        return ""

    @staticmethod
    def normalize(attribute: Any, parent: Path, env_error=True) -> Any:
        """Recursively resolve ``${env:NAME[:default]}`` and ``${file:path}``
        references inside *attribute*.

        Lists and dicts are normalized element-wise; strings are resolved
        when they are references; anything else is returned unchanged.

        Raises
        ------
        ValueError
            If a ``${...}`` string is not a valid env/file reference.
        """
        if isinstance(attribute, str):
            attribute = attribute.strip()
            if attribute.startswith("${") and attribute.endswith("}"):
                # check if env or file reference
                variable = attribute[2:-1].split(":")
                if variable[0] == "env" and len(variable) > 1:
                    return Prompty._process_env(
                        variable[1],
                        env_error,
                        variable[2] if len(variable) > 2 else None,
                    )
                if variable[0] == "file" and len(variable) > 1:
                    return Prompty._process_file(variable[1], parent)
                raise ValueError(f"Invalid attribute format ({attribute})")
            return attribute
        if isinstance(attribute, list):
            return [Prompty.normalize(value, parent) for value in attribute]
        if isinstance(attribute, dict):
            return {key: Prompty.normalize(value, parent) for key, value in attribute.items()}
        return attribute
| 225 | + |
| 226 | + |
def param_hoisting(top: Dict[str, Any], bottom: Dict[str, Any], top_key: Union[str, None] = None) -> Dict[str, Any]:
    """Merge two parameter dicts, with *top* winning on conflicts.

    Parameters
    ----------
    top : Dict[str, Any]
        The overriding dict (or its container, when *top_key* is given)
    bottom : Dict[str, Any]
        The fallback dict supplying keys missing from the top
    top_key : str, optional
        When given, the overriding values are taken from ``top[top_key]``
        (an empty dict if the key is absent)

    Returns
    -------
    Dict[str, Any]
        A new dict; neither input is mutated.
    """
    if top_key:
        new_dict = {**top[top_key]} if top_key in top else {}
    else:
        new_dict = {**top}
    for key, value in bottom.items():
        if key not in new_dict:
            new_dict[key] = value
    return new_dict
| 236 | + |
| 237 | + |
class PromptyStream(Iterator):
    """Wraps a synchronous LLM stream so Prompty can trace it.

    Every yielded item is accumulated; once the underlying iterator is
    exhausted, the collected items are reported to the tracer."""

    def __init__(self, name: str, iterator: Iterator):
        self.name = name
        self.iterator = iterator
        self.items: List[Any] = []
        self.__name__ = "PromptyStream"

    def __iter__(self):
        return self

    def __next__(self):
        try:
            item = next(self.iterator)
        except StopIteration:
            # underlying stream exhausted: flush collected items to the tracer
            if self.items:
                with Tracer.start("PromptyStream") as trace:
                    trace("signature", f"{self.name}.PromptyStream")
                    trace("inputs", "None")
                    trace("result", [to_dict(s) for s in self.items])
            raise

        self.items.append(item)
        return item
| 268 | + |
| 269 | + |
class AsyncPromptyStream(AsyncIterator):
    """Wraps an asynchronous LLM stream so Prompty can trace it.

    Every yielded item is accumulated; once the underlying iterator is
    exhausted, the collected items are reported to the tracer."""

    def __init__(self, name: str, iterator: AsyncIterator):
        self.name = name
        self.iterator = iterator
        self.items: List[Any] = []
        self.__name__ = "AsyncPromptyStream"

    def __aiter__(self):
        return self

    async def __anext__(self):
        try:
            item = await self.iterator.__anext__()
        except StopAsyncIteration:
            # underlying stream exhausted: flush collected items to the tracer
            if self.items:
                with Tracer.start("AsyncPromptyStream") as trace:
                    trace("signature", f"{self.name}.AsyncPromptyStream")
                    trace("inputs", "None")
                    trace("result", [to_dict(s) for s in self.items])
            raise

        self.items.append(item)
        return item
| 300 | + |
| 301 | + |
| 302 | +def _mask_secrets(d: Dict[str, Any], path: list[str], patterns: list[str] = ["key", "secret"]) -> bool: |
| 303 | + sub_d = d |
| 304 | + for key in path: |
| 305 | + if key not in sub_d: |
| 306 | + return False |
| 307 | + sub_d = sub_d[key] |
| 308 | + |
| 309 | + for k, v in sub_d.items(): |
| 310 | + if any([pattern in k.lower() for pattern in patterns]): |
| 311 | + sub_d[k] = "*" * len(v) |
| 312 | + return True |
0 commit comments