Skip to content

Commit 2a5238c

Browse files
authored
Merge pull request #682 from guardrails-ai/async-func-calling
Better Async Function Calling
2 parents 9a4f436 + ab4e847 commit 2a5238c

File tree

1 file changed

+12
-23
lines changed

1 file changed

+12
-23
lines changed

guardrails/run/async_runner.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
from functools import partial
23
from typing import Any, Dict, List, Optional, Tuple, Type, Union
34

45
from pydantic import BaseModel
@@ -259,37 +260,25 @@ async def async_call(
259260
2. Convert the response string to a dict,
260261
3. Log the output
261262
"""
263+
# If the API supports a base model, pass it in.
264+
api_fn = api
265+
if api is not None:
266+
supports_base_model = getattr(api, "supports_base_model", False)
267+
if supports_base_model:
268+
api_fn = partial(api, base_model=self.base_model)
269+
262270
if output is not None:
263271
llm_response = LLMResponse(
264272
output=output,
265273
)
266-
elif api is None:
274+
elif api_fn is None:
267275
raise ValueError("Either API or output must be provided.")
268276
elif msg_history:
269-
try:
270-
llm_response = await api(
271-
msg_history=msg_history_source(msg_history),
272-
base_model=self.base_model,
273-
)
274-
except Exception:
275-
# If the API call fails, try calling again without the base model.
276-
llm_response = await api(msg_history=msg_history_source(msg_history))
277+
llm_response = await api_fn(msg_history=msg_history_source(msg_history))
277278
elif prompt and instructions:
278-
try:
279-
llm_response = await api(
280-
prompt.source,
281-
instructions=instructions.source,
282-
base_model=self.base_model,
283-
)
284-
except Exception:
285-
llm_response = await api(
286-
prompt.source, instructions=instructions.source
287-
)
279+
llm_response = await api_fn(prompt.source, instructions=instructions.source)
288280
elif prompt:
289-
try:
290-
llm_response = await api(prompt.source, base_model=self.base_model)
291-
except Exception:
292-
llm_response = await api(prompt.source)
281+
llm_response = await api_fn(prompt.source)
293282
else:
294283
raise ValueError("'output', 'prompt' or 'msg_history' must be provided.")
295284

0 commit comments

Comments (0)