Skip to content

Commit 0ea8009

Browse files
committed
OpenAIv1: Don't instantiate openai.completions.create if key isn't present in environ
1 parent f973a83 commit 0ea8009

File tree

12 files changed

+108
-73
lines changed

12 files changed

+108
-73
lines changed

guardrails/applications/text2sql.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from guardrails.document_store import DocumentStoreBase, EphemeralDocumentStore
88
from guardrails.embedding import EmbeddingBase, OpenAIEmbedding
99
from guardrails.guard import Guard
10-
from guardrails.utils.openai_utils import static_openai_create_func
10+
from guardrails.utils.openai_utils import get_static_openai_create_func
1111
from guardrails.utils.sql_utils import create_sql_driver
1212
from guardrails.vectordb import Faiss, VectorDBBase
1313

@@ -70,7 +70,7 @@ def __init__(
7070
rail_params: Optional[Dict] = None,
7171
example_formatter: Callable = example_formatter,
7272
reask_prompt: str = REASK_PROMPT,
73-
llm_api: Callable = static_openai_create_func,
73+
llm_api: Optional[Callable] = None,
7474
llm_api_kwargs: Optional[Dict] = None,
7575
num_relevant_examples: int = 2,
7676
):
@@ -87,6 +87,8 @@ def __init__(
8787
example_formatter: Fn to format examples. Defaults to example_formatter.
8888
reask_prompt: Prompt to use for reasking. Defaults to REASK_PROMPT.
8989
"""
90+
if llm_api is None:
91+
llm_api = get_static_openai_create_func()
9092

9193
self.example_formatter = example_formatter
9294
self.llm_api = llm_api
@@ -184,9 +186,10 @@ def __call__(self, text: str) -> Optional[str]:
184186
"Async API is not supported in Text2SQL application. "
185187
"Please use a synchronous API."
186188
)
187-
189+
if self.llm_api is None:
190+
return None
188191
try:
189-
output = self.guard(
192+
return self.guard(
190193
self.llm_api,
191194
prompt_params={
192195
"nl_instruction": text,
@@ -200,6 +203,4 @@ def __call__(self, text: str) -> Optional[str]:
200203
"generated_sql"
201204
]
202205
except TypeError:
203-
output = None
204-
205-
return output
206+
return None

guardrails/llm_providers.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from guardrails.utils.openai_utils import (
77
AsyncOpenAIClient,
88
OpenAIClient,
9-
static_openai_acreate_func,
10-
static_openai_chat_acreate_func,
11-
static_openai_chat_create_func,
12-
static_openai_create_func,
9+
get_static_openai_chat_acreate_func,
10+
get_static_openai_chat_create_func,
11+
get_static_openai_create_func,
12+
get_static_openai_acreate_func,
1313
)
1414
from guardrails.utils.pydantic_utils import convert_pydantic_model_to_openai_fn
1515

@@ -260,9 +260,9 @@ def _invoke_llm(self, *args, **kwargs) -> LLMResponse:
260260
def get_llm_ask(llm_api: Callable, *args, **kwargs) -> PromptCallableBase:
261261
if "temperature" not in kwargs:
262262
kwargs.update({"temperature": 0})
263-
if llm_api == static_openai_create_func:
263+
if llm_api == get_static_openai_create_func():
264264
return OpenAICallable(*args, **kwargs)
265-
if llm_api == static_openai_chat_create_func:
265+
if llm_api == get_static_openai_chat_create_func():
266266
return OpenAIChatCallable(*args, **kwargs)
267267

268268
try:
@@ -475,9 +475,9 @@ def get_async_llm_ask(
475475
) -> AsyncPromptCallableBase:
476476

477477
# these only work with openai v0 (None otherwise)
478-
if llm_api == static_openai_acreate_func:
478+
if llm_api == get_static_openai_acreate_func():
479479
return AsyncOpenAICallable(*args, **kwargs)
480-
if llm_api == static_openai_chat_acreate_func:
480+
if llm_api == get_static_openai_chat_acreate_func():
481481
return AsyncOpenAIChatCallable(*args, **kwargs)
482482

483483
try:

guardrails/utils/openai_utils/__init__.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,30 +7,30 @@
77
from .v0 import OpenAIClientV0 as OpenAIClient
88
from .v0 import (
99
OpenAIServiceUnavailableError,
10-
static_openai_acreate_func,
11-
static_openai_chat_acreate_func,
12-
static_openai_chat_create_func,
13-
static_openai_create_func,
10+
get_static_openai_chat_acreate_func,
11+
get_static_openai_chat_create_func,
12+
get_static_openai_create_func,
13+
get_static_openai_acreate_func,
1414
)
1515
else:
1616
from .v1 import AsyncOpenAIClientV1 as AsyncOpenAIClient
1717
from .v1 import OpenAIClientV1 as OpenAIClient
1818
from .v1 import (
1919
OpenAIServiceUnavailableError,
20-
static_openai_acreate_func,
21-
static_openai_chat_acreate_func,
22-
static_openai_chat_create_func,
23-
static_openai_create_func,
20+
get_static_openai_chat_acreate_func,
21+
get_static_openai_chat_create_func,
22+
get_static_openai_create_func,
23+
get_static_openai_acreate_func,
2424
)
2525

2626

2727
__all__ = [
2828
"OPENAI_VERSION",
2929
"AsyncOpenAIClient",
3030
"OpenAIClient",
31-
"static_openai_create_func",
32-
"static_openai_chat_create_func",
33-
"static_openai_acreate_func",
34-
"static_openai_chat_acreate_func",
31+
"get_static_openai_create_func",
32+
"get_static_openai_chat_create_func",
33+
"get_static_openai_acreate_func",
34+
"get_static_openai_chat_acreate_func",
3535
"OpenAIServiceUnavailableError",
3636
]

guardrails/utils/openai_utils/v0.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,23 @@
1010
BaseSyncOpenAIClient,
1111
)
1212

13-
static_openai_create_func = openai.Completion.create
14-
static_openai_chat_create_func = openai.ChatCompletion.create
15-
static_openai_acreate_func = openai.Completion.acreate
16-
static_openai_chat_acreate_func = openai.ChatCompletion.acreate
13+
14+
def get_static_openai_create_func():
15+
return openai.Completion.create
16+
17+
18+
def get_static_openai_chat_create_func():
19+
return openai.ChatCompletion.create
20+
21+
22+
def get_static_openai_acreate_func():
23+
return openai.Completion.acreate
24+
25+
26+
def get_static_openai_chat_acreate_func():
27+
return openai.ChatCompletion.acreate
28+
29+
1730
OpenAIServiceUnavailableError = openai.error.ServiceUnavailableError
1831

1932
OPENAI_RETRYABLE_ERRORS = [

guardrails/utils/openai_utils/v1.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,32 @@
1+
import os
12
from typing import Any, List
23

34
import openai
45

56
from guardrails.utils.llm_response import LLMResponse
67
from guardrails.utils.openai_utils.base import BaseOpenAIClient
78

8-
static_openai_create_func = openai.completions.create
9-
static_openai_chat_create_func = openai.chat.completions.create
10-
static_openai_acreate_func = None
11-
static_openai_chat_acreate_func = None
9+
10+
def get_static_openai_create_func():
11+
if "OPENAI_API_KEY" not in os.environ:
12+
return None
13+
return openai.completions.create
14+
15+
16+
def get_static_openai_chat_create_func():
17+
if "OPENAI_API_KEY" not in os.environ:
18+
return None
19+
return openai.chat.completions.create
20+
21+
22+
def get_static_openai_acreate_func():
23+
return None
24+
25+
26+
def get_static_openai_chat_acreate_func():
27+
return None
28+
29+
1230
OpenAIServiceUnavailableError = openai.APIError
1331

1432

guardrails/validators.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020

2121
from guardrails.utils.casting_utils import to_int
2222
from guardrails.utils.docs_utils import get_chunks_from_text, sentence_split
23-
from guardrails.utils.openai_utils import OpenAIClient, static_openai_chat_create_func
23+
from guardrails.utils.openai_utils import (
24+
OpenAIClient,
25+
get_static_openai_chat_create_func,
26+
)
2427
from guardrails.utils.sql_utils import SQLDriver, create_sql_driver
2528
from guardrails.utils.validator_utils import PROVENANCE_V1_PROMPT
2629
from guardrails.validator_base import (
@@ -1285,7 +1288,7 @@ def __init__(
12851288
)
12861289

12871290
self.llm_callable = (
1288-
llm_callable if llm_callable else static_openai_chat_create_func
1291+
llm_callable if llm_callable else get_static_openai_chat_create_func()
12891292
)
12901293

12911294
self._threshold = threshold
@@ -1397,7 +1400,7 @@ def __init__(
13971400
)
13981401

13991402
self.llm_callable = (
1400-
llm_callable if llm_callable else static_openai_chat_create_func
1403+
llm_callable if llm_callable else get_static_openai_chat_create_func()
14011404
)
14021405

14031406
def _selfeval(self, question: str, answer: str):

tests/integration_tests/test_guard.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import guardrails as gd
99
from guardrails.guard import Guard
1010
from guardrails.utils.openai_utils import (
11-
static_openai_chat_create_func,
12-
static_openai_create_func,
11+
get_static_openai_chat_create_func,
12+
get_static_openai_create_func,
1313
)
1414
from guardrails.utils.reask_utils import FieldReAsk
1515
from guardrails.validators import FailResult, OneLine
@@ -140,7 +140,7 @@ def test_entity_extraction_with_reask(
140140
guard = guard_initializer(rail, prompt)
141141

142142
_, final_output = guard(
143-
llm_api=static_openai_create_func,
143+
llm_api=get_static_openai_create_func(),
144144
prompt_params={"document": content[:6000]},
145145
num_reasks=1,
146146
max_tokens=2000,
@@ -219,7 +219,7 @@ def test_entity_extraction_with_noop(mocker, rail, prompt):
219219
content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
220220
guard = guard_initializer(rail, prompt)
221221
_, final_output = guard(
222-
llm_api=static_openai_create_func,
222+
llm_api=get_static_openai_create_func(),
223223
prompt_params={"document": content[:6000]},
224224
num_reasks=1,
225225
)
@@ -255,7 +255,7 @@ def test_entity_extraction_with_filter(mocker, rail, prompt):
255255
content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
256256
guard = guard_initializer(rail, prompt)
257257
_, final_output = guard(
258-
llm_api=static_openai_create_func,
258+
llm_api=get_static_openai_create_func(),
259259
prompt_params={"document": content[:6000]},
260260
num_reasks=1,
261261
)
@@ -290,7 +290,7 @@ def test_entity_extraction_with_fix(mocker, rail, prompt):
290290
content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
291291
guard = guard_initializer(rail, prompt)
292292
_, final_output = guard(
293-
llm_api=static_openai_create_func,
293+
llm_api=get_static_openai_create_func(),
294294
prompt_params={"document": content[:6000]},
295295
num_reasks=1,
296296
)
@@ -326,7 +326,7 @@ def test_entity_extraction_with_refrain(mocker, rail, prompt):
326326
content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
327327
guard = guard_initializer(rail, prompt)
328328
_, final_output = guard(
329-
llm_api=static_openai_create_func,
329+
llm_api=get_static_openai_create_func(),
330330
prompt_params={"document": content[:6000]},
331331
num_reasks=1,
332332
)
@@ -369,7 +369,7 @@ def test_entity_extraction_with_fix_chat_models(mocker, rail, prompt, instructio
369369
content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
370370
guard = guard_initializer(rail, prompt, instructions)
371371
_, final_output = guard(
372-
llm_api=static_openai_chat_create_func,
372+
llm_api=get_static_openai_chat_create_func(),
373373
prompt_params={"document": content[:6000]},
374374
num_reasks=1,
375375
)
@@ -399,7 +399,7 @@ def test_string_output(mocker):
399399

400400
guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING)
401401
_, final_output = guard(
402-
llm_api=static_openai_create_func,
402+
llm_api=get_static_openai_create_func(),
403403
prompt_params={"ingredients": "tomato, cheese, sour cream"},
404404
num_reasks=1,
405405
)
@@ -421,7 +421,7 @@ def test_string_reask(mocker):
421421

422422
guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_STRING_REASK)
423423
_, final_output = guard(
424-
llm_api=static_openai_create_func,
424+
llm_api=get_static_openai_create_func(),
425425
prompt_params={"ingredients": "tomato, cheese, sour cream"},
426426
num_reasks=1,
427427
max_tokens=100,
@@ -454,7 +454,7 @@ def test_skeleton_reask(mocker):
454454
content = gd.docs_utils.read_pdf("docs/examples/data/chase_card_agreement.pdf")
455455
guard = gd.Guard.from_rail_string(entity_extraction.RAIL_SPEC_WITH_SKELETON_REASK)
456456
_, final_output = guard(
457-
llm_api=static_openai_create_func,
457+
llm_api=get_static_openai_create_func(),
458458
prompt_params={"document": content[:6000]},
459459
max_tokens=1000,
460460
num_reasks=1,
@@ -497,7 +497,7 @@ def test_skeleton_reask(mocker):
497497
498498
guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_LIST)
499499
_, final_output = guard(
500-
llm_api=static_openai_create_func,
500+
llm_api=get_static_openai_create_func(),
501501
num_reasks=1,
502502
)
503503
assert final_output == string.LIST_LLM_OUTPUT
@@ -523,7 +523,7 @@ def test_skeleton_reask(mocker):
523523
entity_extraction.OPTIONAL_PROMPT_COMPLETION_MODEL,
524524
None,
525525
None,
526-
static_openai_create_func,
526+
get_static_openai_create_func(),
527527
entity_extraction.COMPILED_PROMPT,
528528
None,
529529
entity_extraction.COMPILED_PROMPT_REASK,
@@ -534,7 +534,7 @@ def test_skeleton_reask(mocker):
534534
entity_extraction.OPTIONAL_PROMPT_CHAT_MODEL,
535535
entity_extraction.OPTIONAL_INSTRUCTIONS_CHAT_MODEL,
536536
None,
537-
static_openai_chat_create_func,
537+
get_static_openai_chat_create_func(),
538538
entity_extraction.COMPILED_PROMPT_WITHOUT_INSTRUCTIONS,
539539
entity_extraction.COMPILED_INSTRUCTIONS,
540540
entity_extraction.COMPILED_PROMPT_REASK_WITHOUT_INSTRUCTIONS,
@@ -545,7 +545,7 @@ def test_skeleton_reask(mocker):
545545
None,
546546
None,
547547
entity_extraction.OPTIONAL_MSG_HISTORY,
548-
static_openai_chat_create_func,
548+
get_static_openai_chat_create_func(),
549549
None,
550550
None,
551551
entity_extraction.COMPILED_PROMPT_REASK_WITHOUT_INSTRUCTIONS,
@@ -566,7 +566,7 @@ def test_entity_extraction_with_reask_with_optional_prompts(
566566
expected_reask_instructions,
567567
):
568568
"""Test that the entity extraction works with re-asking."""
569-
if llm_api == static_openai_create_func:
569+
if llm_api == get_static_openai_create_func():
570570
mocker.patch("guardrails.llm_providers.OpenAICallable", new=MockOpenAICallable)
571571
else:
572572
mocker.patch(
@@ -653,7 +653,7 @@ def test_string_with_message_history_reask(mocker):
653653

654654
guard = gd.Guard.from_rail_string(string.RAIL_SPEC_FOR_MSG_HISTORY)
655655
_, final_output = guard(
656-
llm_api=static_openai_chat_create_func,
656+
llm_api=get_static_openai_chat_create_func(),
657657
msg_history=string.MOVIE_MSG_HISTORY,
658658
temperature=0.0,
659659
model="gpt-3.5-turbo",
@@ -689,7 +689,7 @@ def test_pydantic_with_message_history_reask(mocker):
689689

690690
guard = gd.Guard.from_pydantic(output_class=pydantic.WITH_MSG_HISTORY)
691691
raw_output, guarded_output = guard(
692-
llm_api=static_openai_chat_create_func,
692+
llm_api=get_static_openai_chat_create_func(),
693693
msg_history=string.MOVIE_MSG_HISTORY,
694694
temperature=0.0,
695695
model="gpt-3.5-turbo",
@@ -731,7 +731,7 @@ def test_sequential_validator_log_is_not_duplicated(mocker):
731731
)
732732

733733
_, final_output = guard(
734-
llm_api=static_openai_create_func,
734+
llm_api=get_static_openai_create_func(),
735735
prompt_params={"document": content[:6000]},
736736
num_reasks=1,
737737
)
@@ -765,7 +765,7 @@ def test_in_memory_validator_log_is_not_duplicated(mocker):
765765
)
766766

767767
_, final_output = guard(
768-
llm_api=static_openai_create_func,
768+
llm_api=get_static_openai_create_func(),
769769
prompt_params={"document": content[:6000]},
770770
num_reasks=1,
771771
)

0 commit comments

Comments (0)