@@ -1,4 +1,5 @@
 import copy
+from functools import partial
 from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 from pydantic import BaseModel
@@ -259,37 +260,25 @@ async def async_call(
         2. Convert the response string to a dict,
         3. Log the output
         """
+        # If the API supports a base model, pass it in.
+        api_fn = api
+        if api is not None:
+            supports_base_model = getattr(api, "supports_base_model", False)
+            if supports_base_model:
+                api_fn = partial(api, base_model=self.base_model)
+
         if output is not None:
             llm_response = LLMResponse(
                 output=output,
             )
-        elif api is None:
+        elif api_fn is None:
             raise ValueError("Either API or output must be provided.")
         elif msg_history:
-            try:
-                llm_response = await api(
-                    msg_history=msg_history_source(msg_history),
-                    base_model=self.base_model,
-                )
-            except Exception:
-                # If the API call fails, try calling again without the base model.
-                llm_response = await api(msg_history=msg_history_source(msg_history))
+            llm_response = await api_fn(msg_history=msg_history_source(msg_history))
         elif prompt and instructions:
-            try:
-                llm_response = await api(
-                    prompt.source,
-                    instructions=instructions.source,
-                    base_model=self.base_model,
-                )
-            except Exception:
-                llm_response = await api(
-                    prompt.source, instructions=instructions.source
-                )
+            llm_response = await api_fn(prompt.source, instructions=instructions.source)
         elif prompt:
-            try:
-                llm_response = await api(prompt.source, base_model=self.base_model)
-            except Exception:
-                llm_response = await api(prompt.source)
+            llm_response = await api_fn(prompt.source)
         else:
             raise ValueError("'output', 'prompt' or 'msg_history' must be provided.")
 
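For reference, a minimal sketch of the capability-flag pattern this hunk introduces: instead of optimistically passing base_model and falling back in an except block (which could swallow unrelated errors), the callable opts in via a supports_base_model attribute and the caller pre-binds the kwarg once with functools.partial. The callable my_api and the schema Answer below are hypothetical stand-ins; only the attribute check and the partial binding mirror the diff.

import asyncio
from functools import partial
from typing import Optional, Type

from pydantic import BaseModel


class Answer(BaseModel):
    # Hypothetical output schema, standing in for `self.base_model`.
    text: str


async def my_api(prompt: str, *, base_model: Optional[Type[BaseModel]] = None) -> str:
    # Hypothetical LLM call; just reports whether a schema was bound.
    return f"schema={'none' if base_model is None else base_model.__name__}"


# The callable opts in by advertising the capability as a runtime attribute.
my_api.supports_base_model = True

# Caller side, as in the hunk: bind the kwarg once up front, so every call
# site afterwards uses the same positional signature regardless of capability.
api_fn = my_api
if getattr(my_api, "supports_base_model", False):
    api_fn = partial(my_api, base_model=Answer)

print(asyncio.run(api_fn("What is 2 + 2?")))  # -> schema=Answer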
|