Skip to content

Commit b39bddd

Browse files
hallacy, kennyhsu5, madeleineth
authored
Hallacy/11 4 release (#54)
* Make embeddings_utils be importable (#104) * Make embeddings_utils be importable * Small tweaks to dicts for typing * Remove default api_prefix and move v1 prefix to default api_base (#95) * make construct_from key argument optional (#92) * Split search.prepare_data into answers/classifications/search versions (#93) * Break out prepare_data into answers, classifications, and search * And cleaned up CLI * Validate search files (#69) * Add validators for search files * Clean up fields Co-authored-by: kennyhsu5 <[email protected]> Co-authored-by: Madeleine Thompson <[email protected]>
1 parent 88bbe08 commit b39bddd

19 files changed

+246
-115
lines changed

examples/embeddings/Classification.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@
9090
}
9191
],
9292
"source": [
93-
"from utils import plot_multiclass_precision_recall\n",
93+
"from openai.embeddings_utils import plot_multiclass_precision_recall\n",
9494
"\n",
9595
"plot_multiclass_precision_recall(probas, y_test, [1,2,3,4,5], clf)"
9696
]

examples/embeddings/Code_search.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@
185185
}
186186
],
187187
"source": [
188-
"from utils import get_embedding\n",
188+
"from openai.embeddings_utils import get_embedding\n",
189189
"\n",
190190
"df = pd.DataFrame(all_funcs)\n",
191191
"df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, engine='babbage-code-search-code'))\n",
@@ -231,7 +231,7 @@
231231
}
232232
],
233233
"source": [
234-
"from utils import cosine_similarity\n",
234+
"from openai.embeddings_utils import cosine_similarity\n",
235235
"\n",
236236
"def search_functions(df, code_query, n=3, pprint=True, n_lines=7):\n",
237237
" embedding = get_embedding(code_query, engine='babbage-code-search-text')\n",

examples/embeddings/Obtain_dataset.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@
156156
"metadata": {},
157157
"outputs": [],
158158
"source": [
159-
"from utils import get_embedding\n",
159+
"from openai.embeddings_utils import get_embedding\n",
160160
"\n",
161161
"# This will take just under 10 minutes\n",
162162
"df['babbage_similarity'] = df.combined.apply(lambda x: get_embedding(x, engine='babbage-similarity'))\n",

examples/embeddings/Semantic_text_search_using_embeddings.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
}
5050
],
5151
"source": [
52-
"from utils import get_embedding, cosine_similarity\n",
52+
"from openai.embeddings_utils import get_embedding, cosine_similarity\n",
5353
"\n",
5454
"# search through the reviews for a specific product\n",
5555
"def search_reviews(df, product_description, n=3, pprint=True):\n",

examples/embeddings/User_and_product_embeddings.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
"metadata": {},
7171
"outputs": [],
7272
"source": [
73-
"from utils import cosine_similarity\n",
73+
"from openai.embeddings_utils import cosine_similarity\n",
7474
"\n",
7575
"# evaluate embeddings as recommendations on X_test\n",
7676
"def evaluate_single_match(row):\n",

examples/embeddings/Zero-shot_classification.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@
7878
}
7979
],
8080
"source": [
81-
"from utils import cosine_similarity, get_embedding\n",
81+
"from openai.embeddings_utils import cosine_similarity, get_embedding\n",
8282
"from sklearn.metrics import PrecisionRecallDisplay\n",
8383
"\n",
8484
"def evaluate_emeddings_approach(\n",

openai/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
api_key_path: Optional[str] = os.environ.get("OPENAI_API_KEY_PATH")
2626

2727
organization = os.environ.get("OPENAI_ORGANIZATION")
28-
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com")
28+
api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
2929
api_version = None
3030
verify_ssl_certs = True # No effect. Certificates are always verified.
3131
proxy = None

openai/api_resources/abstract/api_resource.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
class APIResource(OpenAIObject):
8-
api_prefix = "v1"
8+
api_prefix = ""
99

1010
@classmethod
1111
def retrieve(cls, id, api_key=None, request_id=None, **params):
@@ -28,7 +28,9 @@ def class_url(cls):
2828
# Namespaces are separated in object names with periods (.) and in URLs
2929
# with forward slashes (/), so replace the former with the latter.
3030
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
31-
return "/%s/%ss" % (cls.api_prefix, base)
31+
if cls.api_prefix:
32+
return "/%s/%ss" % (cls.api_prefix, base)
33+
return "/%ss" % (base)
3234

3335
def instance_url(self):
3436
id = self.get("id")

openai/api_resources/abstract/engine_api_resource.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ def class_url(cls, engine: Optional[str] = None):
2222
# with forward slashes (/), so replace the former with the latter.
2323
base = cls.OBJECT_NAME.replace(".", "/") # type: ignore
2424
if engine is None:
25-
return "/%s/%ss" % (cls.api_prefix, base)
25+
return "/%ss" % (base)
2626

2727
extn = quote_plus(engine)
28-
return "/%s/engines/%s/%ss" % (cls.api_prefix, extn, base)
28+
return "/engines/%s/%ss" % (extn, base)
2929

3030
@classmethod
3131
def create(

openai/api_resources/answer.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33

44
class Answer(OpenAIObject):
5-
api_prefix = "v1"
6-
75
@classmethod
8-
def get_url(self, base):
9-
return "/%s/%s" % (self.api_prefix, base)
6+
def get_url(self):
7+
return "/answers"
108

119
@classmethod
1210
def create(cls, **params):
1311
instance = cls()
14-
return instance.request("post", cls.get_url("answers"), params)
12+
return instance.request("post", cls.get_url(), params)

openai/api_resources/classification.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22

33

44
class Classification(OpenAIObject):
5-
api_prefix = "v1"
6-
75
@classmethod
8-
def get_url(self, base):
9-
return "/%s/%s" % (self.api_prefix, base)
6+
def get_url(self):
7+
return "/classifications"
108

119
@classmethod
1210
def create(cls, **params):
1311
instance = cls()
14-
return instance.request("post", cls.get_url("classifications"), params)
12+
return instance.request("post", cls.get_url(), params)

openai/api_resources/search.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22

33

44
class Search(APIResource):
5-
api_prefix = "v1"
6-
OBJECT_NAME = "search_indices"
7-
85
@classmethod
96
def class_url(cls):
10-
return "/%s/%s" % (cls.api_prefix, cls.OBJECT_NAME)
7+
return "/search_indices/search"
118

129
@classmethod
1310
def create_alpha(cls, **params):
1411
instance = cls()
15-
return instance.request("post", f"{cls.class_url()}/search", params)
12+
return instance.request("post", cls.class_url(), params)

openai/cli.py

Lines changed: 96 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import signal
44
import sys
55
import warnings
6+
from functools import partial
67
from typing import Optional
78

89
import requests
@@ -11,10 +12,12 @@
1112
from openai.upload_progress import BufferReader
1213
from openai.validators import (
1314
apply_necessary_remediation,
14-
apply_optional_remediation,
15+
apply_validators,
16+
get_search_validators,
1517
get_validators,
1618
read_any_format,
1719
write_out_file,
20+
write_out_search_file,
1821
)
1922

2023

@@ -227,6 +230,40 @@ def list(cls, args):
227230

228231

229232
class Search:
233+
@classmethod
234+
def prepare_data(cls, args, purpose):
235+
236+
sys.stdout.write("Analyzing...\n")
237+
fname = args.file
238+
auto_accept = args.quiet
239+
240+
optional_fields = ["metadata"]
241+
242+
if purpose == "classifications":
243+
required_fields = ["text", "labels"]
244+
else:
245+
required_fields = ["text"]
246+
247+
df, remediation = read_any_format(
248+
fname, fields=required_fields + optional_fields
249+
)
250+
251+
if "metadata" not in df:
252+
df["metadata"] = None
253+
254+
apply_necessary_remediation(None, remediation)
255+
validators = get_search_validators(required_fields, optional_fields)
256+
257+
write_out_file_func = partial(
258+
write_out_search_file,
259+
purpose=purpose,
260+
fields=required_fields + optional_fields,
261+
)
262+
263+
apply_validators(
264+
df, fname, remediation, validators, auto_accept, write_out_file_func
265+
)
266+
230267
@classmethod
231268
def create_alpha(cls, args):
232269
resp = openai.Search.create_alpha(
@@ -489,49 +526,14 @@ def prepare_data(cls, args):
489526

490527
validators = get_validators()
491528

492-
optional_remediations = []
493-
if remediation is not None:
494-
optional_remediations.append(remediation)
495-
for validator in validators:
496-
remediation = validator(df)
497-
if remediation is not None:
498-
optional_remediations.append(remediation)
499-
df = apply_necessary_remediation(df, remediation)
500-
501-
any_optional_or_necessary_remediations = any(
502-
[
503-
remediation
504-
for remediation in optional_remediations
505-
if remediation.optional_msg is not None
506-
or remediation.necessary_msg is not None
507-
]
529+
apply_validators(
530+
df,
531+
fname,
532+
remediation,
533+
validators,
534+
auto_accept,
535+
write_out_file_func=write_out_file,
508536
)
509-
any_necessary_applied = any(
510-
[
511-
remediation
512-
for remediation in optional_remediations
513-
if remediation.necessary_msg is not None
514-
]
515-
)
516-
any_optional_applied = False
517-
518-
if any_optional_or_necessary_remediations:
519-
sys.stdout.write(
520-
"\n\nBased on the analysis we will perform the following actions:\n"
521-
)
522-
for remediation in optional_remediations:
523-
df, optional_applied = apply_optional_remediation(
524-
df, remediation, auto_accept
525-
)
526-
any_optional_applied = any_optional_applied or optional_applied
527-
else:
528-
sys.stdout.write("\n\nNo remediations found.\n")
529-
530-
any_optional_or_necessary_applied = (
531-
any_optional_applied or any_necessary_applied
532-
)
533-
534-
write_out_file(df, fname, any_optional_or_necessary_applied, auto_accept)
535537

536538

537539
def tools_register(parser):
@@ -561,6 +563,57 @@ def help(args):
561563
)
562564
sub.set_defaults(func=FineTune.prepare_data)
563565

566+
sub = subparsers.add_parser("search.prepare_data")
567+
sub.add_argument(
568+
"-f",
569+
"--file",
570+
required=True,
571+
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing text examples to be analyzed."
572+
"This should be the local file path.",
573+
)
574+
sub.add_argument(
575+
"-q",
576+
"--quiet",
577+
required=False,
578+
action="store_true",
579+
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
580+
)
581+
sub.set_defaults(func=partial(Search.prepare_data, purpose="search"))
582+
583+
sub = subparsers.add_parser("classifications.prepare_data")
584+
sub.add_argument(
585+
"-f",
586+
"--file",
587+
required=True,
588+
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing text-label examples to be analyzed."
589+
"This should be the local file path.",
590+
)
591+
sub.add_argument(
592+
"-q",
593+
"--quiet",
594+
required=False,
595+
action="store_true",
596+
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
597+
)
598+
sub.set_defaults(func=partial(Search.prepare_data, purpose="classification"))
599+
600+
sub = subparsers.add_parser("answers.prepare_data")
601+
sub.add_argument(
602+
"-f",
603+
"--file",
604+
required=True,
605+
help="JSONL, JSON, CSV, TSV, TXT or XLSX file containing text examples to be analyzed."
606+
"This should be the local file path.",
607+
)
608+
sub.add_argument(
609+
"-q",
610+
"--quiet",
611+
required=False,
612+
action="store_true",
613+
help="Auto accepts all suggestions, without asking for user input. To be used within scripts.",
614+
)
615+
sub.set_defaults(func=partial(Search.prepare_data, purpose="answer"))
616+
564617

565618
def api_register(parser):
566619
# Engine management

0 commit comments

Comments (0)