
Update embeddings_utils.py to set default model to text-embedding-ada-002 #604

Merged: 2 commits merged on Sep 26, 2023. The diff below shows the changes from all commits.
openai/embeddings_utils.py (12 changes: 6 additions & 6 deletions)
@@ -15,7 +15,7 @@


 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
-def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) -> List[float]:
+def get_embedding(text: str, engine="text-embedding-ada-002", **kwargs) -> List[float]:

     # replace newlines, which can negatively affect performance.
     text = text.replace("\n", " ")
@@ -25,7 +25,7 @@ def get_embedding(text: str, engine="text-similarity-davinci-001", **kwargs) ->

 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embedding(
-    text: str, engine="text-similarity-davinci-001", **kwargs
+    text: str, engine="text-embedding-ada-002", **kwargs
 ) -> List[float]:

     # replace newlines, which can negatively affect performance.
@@ -38,9 +38,9 @@ async def aget_embedding(

 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 def get_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
+    list_of_text: List[str], engine="text-embedding-ada-002", **kwargs
 ) -> List[List[float]]:
-    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
+    assert len(list_of_text) <= 8191, "The batch size should not be larger than 8191."

     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
@@ -51,9 +51,9 @@ def get_embeddings(

 @retry(wait=wait_random_exponential(min=1, max=20), stop=stop_after_attempt(6))
 async def aget_embeddings(
-    list_of_text: List[str], engine="text-similarity-babbage-001", **kwargs
+    list_of_text: List[str], engine="text-embedding-ada-002", **kwargs
 ) -> List[List[float]]:
-    assert len(list_of_text) <= 2048, "The batch size should not be larger than 2048."
+    assert len(list_of_text) <= 8191, "The batch size should not be larger than 8191."

     # replace newlines, which can negatively affect performance.
     list_of_text = [text.replace("\n", " ") for text in list_of_text]
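
For reference, a minimal usage sketch of the helpers this PR touches (not part of the diff). It assumes the 0.x openai Python SDK, where openai.embeddings_utils and its optional dependencies (numpy, pandas, scikit-learn, plotting libraries) are installed, and an OPENAI_API_KEY set in the environment. With the new defaults, the engine argument can simply be omitted.

import os

import openai
from openai.embeddings_utils import get_embedding, get_embeddings

openai.api_key = os.environ["OPENAI_API_KEY"]

# With this change, omitting `engine` selects text-embedding-ada-002
# instead of the older text-similarity-* defaults.
vector = get_embedding("hello world")
print(len(vector))  # text-embedding-ada-002 returns 1536-dimensional vectors

# Batch helper; the list length is bounded by the assertion shown above.
vectors = get_embeddings(["first document", "second document"])
print(len(vectors), len(vectors[0]))

# The previous behavior is still available by passing `engine` explicitly:
# get_embedding("some text", engine="text-similarity-davinci-001")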