Skip to content

Add support for HNSW and IVFFlat indexes, add tests #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 134 additions & 1 deletion langchain_postgres/vectorstores.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from langchain_core.indexing import UpsertResponse
from langchain_core.utils import get_from_dict_or_env
from langchain_core.vectorstores import VectorStore
from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select
from sqlalchemy import SQLColumnExpression, cast, create_engine, delete, func, select, text
from sqlalchemy.dialects.postgresql import JSON, JSONB, JSONPATH, UUID, insert
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.ext.asyncio import (
Expand Down Expand Up @@ -97,6 +97,124 @@ class DistanceStrategy(str, enum.Enum):
.union(SPECIAL_CASED_OPERATORS)
)

class IndexManager:
    """Create, list, and look up pgvector indexes on the embedding column.

    Every statement targets the ``langchain_pg_embedding`` table. Each
    operation has a synchronous method and an ``a``-prefixed asynchronous
    twin. The synchronous methods enter the engine with ``with`` and the
    asynchronous ones with ``async with``, so the engine held by the manager
    must match the family of methods being called.

    Args:
        connection: A database URL, or an existing ``Engine`` / ``AsyncEngine``.
        async_mode: When ``connection`` is a URL, selects whether an async
            engine (True) or a sync engine (False) is created. The flag is
            also forwarded to the ``PGVector`` built by ``get_index`` /
            ``aget_index``. Defaults to False.

    Raises:
        ValueError: If ``connection`` is neither a string nor an engine.
    """

    # Table whose ``embedding`` column is indexed.
    _TABLE_NAME = "langchain_pg_embedding"
    # Access methods accepted by ``create_index`` / ``acreate_index``.
    _SUPPORTED_INDEX_TYPES = ("hnsw", "ivfflat")
    # pgvector names its inner-product operator class ``vector_ip_ops``.
    # NOTE(review): assumes the inner-product DistanceStrategy value is
    # "inner" -- confirm against the enum. Other values pass through as-is.
    _OPS_SUFFIX_BY_STRATEGY = {"inner": "ip"}

    _LIST_INDEXES_SQL = "SELECT * FROM pg_indexes WHERE tablename = :table_name"
    _GET_INDEX_SQL = (
        "SELECT * FROM pg_indexes "
        "WHERE tablename = :table_name AND indexname = :index_name"
    )

    def __init__(self, connection: Union[str, Engine, AsyncEngine], async_mode: bool = False):
        self.async_mode = async_mode
        if isinstance(connection, str):
            if async_mode:
                self._engine = create_async_engine(connection)
            else:
                self._engine = create_engine(connection)
        elif isinstance(connection, (Engine, AsyncEngine)):
            self._engine = connection
        else:
            raise ValueError("Invalid connection type")

    # ------------------------------------------------------------------
    # Pure helpers shared by the sync and async code paths.
    # ------------------------------------------------------------------

    @classmethod
    def _normalize_index_type(cls, index_type: str) -> str:
        """Return ``index_type`` lower-cased; reject unsupported access methods."""
        normalized = str(index_type).strip().lower()
        if normalized not in cls._SUPPORTED_INDEX_TYPES:
            raise ValueError(
                f"Unsupported index type {index_type!r}; "
                f"expected one of {cls._SUPPORTED_INDEX_TYPES}"
            )
        return normalized

    @staticmethod
    def _strategy_value(distance_strategy: DistanceStrategy) -> str:
        """Return the strategy's ``.value`` after checking it is a plain identifier."""
        value = str(distance_strategy.value)
        if not (value.isascii() and value.isidentifier()):
            raise ValueError(f"Invalid distance strategy value: {value!r}")
        return value

    @classmethod
    def _index_name(cls, index_type: str, distance_strategy: DistanceStrategy) -> str:
        """Return the index name: ``<index_type>_<strategy value>_index``."""
        access_method = cls._normalize_index_type(index_type)
        return f"{access_method}_{cls._strategy_value(distance_strategy)}_index"

    @classmethod
    def _ops_class(cls, distance_strategy: DistanceStrategy) -> str:
        """Return the pgvector operator class name for ``distance_strategy``."""
        value = cls._strategy_value(distance_strategy)
        suffix = cls._OPS_SUFFIX_BY_STRATEGY.get(value, value)
        return f"vector_{suffix}_ops"

    @classmethod
    def _create_index_statement(
        cls,
        index_type: str,
        distance_strategy: DistanceStrategy,
        params: Dict[str, Any],
    ) -> str:
        """Build the ``CREATE INDEX`` statement.

        DDL cannot take bind parameters for identifiers or storage options,
        so every interpolated piece is validated first: the access method
        against an allow-list, option names as ASCII identifiers, and option
        values as numbers.

        Raises:
            ValueError: On an unsupported index type, a non-identifier option
                name, or a non-numeric option value.
        """
        access_method = cls._normalize_index_type(index_type)
        index_name = cls._index_name(access_method, distance_strategy)
        ops_class = cls._ops_class(distance_strategy)

        options: List[str] = []
        for key, value in params.items():
            if not (isinstance(key, str) and key.isascii() and key.isidentifier()):
                raise ValueError(f"Invalid index parameter name: {key!r}")
            # bool is a subclass of int; it is never a meaningful option value.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                raise ValueError(
                    f"Index parameter {key!r} must be a number, got {value!r}"
                )
            options.append(f"{key} = {value}")

        statement = (
            f"CREATE INDEX IF NOT EXISTS {index_name} ON {cls._TABLE_NAME} "
            f"USING {access_method} (embedding {ops_class})"
        )
        # An empty ``WITH ()`` clause is a syntax error, so omit it entirely.
        if options:
            statement += f" WITH ({', '.join(options)})"
        return statement

    @classmethod
    def _strategy_value_from_indexdef(cls, indexdef: str) -> str:
        """Extract the distance-strategy value from a ``pg_indexes.indexdef`` string.

        The operator class is the last token of the parenthesised column list
        that follows ``USING <method>``, e.g. ``(embedding vector_l2_ops)``.

        Raises:
            ValueError: If the definition carries no ``vector_*_ops`` class.
        """
        _, _, after_using = indexdef.partition(" USING ")
        _, _, after_paren = after_using.partition("(")
        column_spec, _, _ = after_paren.partition(")")
        tokens = column_spec.split()
        ops_class = tokens[-1] if tokens else ""

        prefix, postfix = "vector_", "_ops"
        if not (
            ops_class.startswith(prefix)
            and ops_class.endswith(postfix)
            and len(ops_class) > len(prefix) + len(postfix)
        ):
            raise ValueError(
                f"Index definition has no pgvector operator class: {indexdef!r}"
            )
        suffix = ops_class[len(prefix):-len(postfix)]

        # Undo the strategy-value -> operator-class-suffix mapping.
        for value, mapped in cls._OPS_SUFFIX_BY_STRATEGY.items():
            if mapped == suffix:
                return value
        return suffix

    def _vectorstore_for(
        self, indexdef: str, embeddings: Embeddings, collection_name: str
    ) -> PGVector:
        """Build a ``PGVector`` whose distance strategy matches ``indexdef``."""
        distance_strategy = DistanceStrategy(
            self._strategy_value_from_indexdef(indexdef)
        )
        return PGVector(
            embeddings=embeddings,
            connection=self._engine,
            collection_name=collection_name,
            distance_strategy=distance_strategy,
            async_mode=self.async_mode,
        )

    # ------------------------------------------------------------------
    # Public API.
    # ------------------------------------------------------------------

    def list_indexes(self) -> List[Dict[str, Any]]:
        """Return one dict per ``pg_indexes`` row for the embedding table."""
        with self._engine.connect() as conn:
            result = conn.execute(
                text(self._LIST_INDEXES_SQL), {"table_name": self._TABLE_NAME}
            )
            # Rows are tuple-like; ``mappings()`` exposes them as column->value.
            return [dict(row) for row in result.mappings()]

    async def alist_indexes(self) -> List[Dict[str, Any]]:
        """Asynchronously return one dict per ``pg_indexes`` row for the embedding table."""
        async with self._engine.connect() as conn:
            result = await conn.execute(
                text(self._LIST_INDEXES_SQL), {"table_name": self._TABLE_NAME}
            )
            return [dict(row) for row in result.mappings()]

    def create_index(self, index_type: str, distance_strategy: DistanceStrategy, **kwargs: Any) -> str:
        """Create an index (HNSW or IVFFlat) on the embedding column.

        The statement uses ``IF NOT EXISTS``, so calling this again for the
        same index type and strategy is a no-op rather than an error.

        Args:
            index_type: The type of index to create ('hnsw' or 'ivfflat').
            distance_strategy: Strategy whose ``.value`` selects the operator
                class and is embedded in the index name.
            kwargs: Numeric index options (e.g., m, ef_construction, lists).

        Returns:
            The name of the index.

        Raises:
            ValueError: If ``index_type`` or any option fails validation.
        """
        statement = self._create_index_statement(index_type, distance_strategy, kwargs)
        index_name = self._index_name(index_type, distance_strategy)
        # ``begin()`` commits on successful exit; a bare ``connect()`` would
        # roll the DDL back when the connection closes.
        with self._engine.begin() as conn:
            conn.execute(text(statement))
        return index_name

    async def acreate_index(self, index_type: str, distance_strategy: DistanceStrategy, **kwargs: Any) -> str:
        """Asynchronously create an index (HNSW or IVFFlat) on the embedding column.

        The statement uses ``IF NOT EXISTS``, so calling this again for the
        same index type and strategy is a no-op rather than an error.

        Args:
            index_type: The type of index to create ('hnsw' or 'ivfflat').
            distance_strategy: Strategy whose ``.value`` selects the operator
                class and is embedded in the index name.
            kwargs: Numeric index options (e.g., m, ef_construction, lists).

        Returns:
            The name of the index.

        Raises:
            ValueError: If ``index_type`` or any option fails validation.
        """
        statement = self._create_index_statement(index_type, distance_strategy, kwargs)
        index_name = self._index_name(index_type, distance_strategy)
        # See ``create_index``: ``begin()`` is what makes the DDL persist.
        async with self._engine.begin() as conn:
            await conn.execute(text(statement))
        return index_name

    def get_index(self, index_name: str, embeddings: Embeddings, collection_name: str) -> Optional[PGVector]:
        """Look up an index on the embedding table and return a matching PGVector.

        Args:
            index_name: Name of the index to look up.
            embeddings: Embedding function handed to the returned store.
            collection_name: Collection the returned store operates on.

        Returns:
            A ``PGVector`` using the distance strategy the index was built
            with, or None if no such index exists on the embedding table.

        Raises:
            ValueError: If the index exists but is not a pgvector index with
                a recognised distance strategy.
        """
        with self._engine.connect() as conn:
            result = conn.execute(
                text(self._GET_INDEX_SQL),
                {"table_name": self._TABLE_NAME, "index_name": index_name},
            )
            row = result.mappings().first()
        if row is None:
            return None
        return self._vectorstore_for(row["indexdef"], embeddings, collection_name)

    async def aget_index(self, index_name: str, embeddings: Embeddings, collection_name: str) -> Optional[PGVector]:
        """Asynchronously look up an index and return a matching PGVector.

        Args:
            index_name: Name of the index to look up.
            embeddings: Embedding function handed to the returned store.
            collection_name: Collection the returned store operates on.

        Returns:
            A ``PGVector`` using the distance strategy the index was built
            with, or None if no such index exists on the embedding table.

        Raises:
            ValueError: If the index exists but is not a pgvector index with
                a recognised distance strategy.
        """
        async with self._engine.connect() as conn:
            result = await conn.execute(
                text(self._GET_INDEX_SQL),
                {"table_name": self._TABLE_NAME, "index_name": index_name},
            )
            row = result.mappings().first()
        if row is None:
            return None
        return self._vectorstore_for(row["indexdef"], embeddings, collection_name)


def _get_embedding_collection_store(vector_dimension: Optional[int] = None) -> Any:
global _classes
Expand Down Expand Up @@ -355,6 +473,8 @@ def __init__(
use_jsonb: bool = True,
create_extension: bool = True,
async_mode: bool = False,
index_type: Optional[str] = None,
index_params: Optional[Dict[str, Any]] = None,
) -> None:
"""Initialize the PGVector store.
For an async version, use `PGVector.acreate()` instead.
Expand Down Expand Up @@ -383,6 +503,8 @@ def __init__(
create_extension: If True, will create the vector extension if it
doesn't exist. disabling creation is useful when using ReadOnly
Databases.
index_type: The type of index to create. (default: None)
index_params: The parameters for the index. (default: None)
"""
self.async_mode = async_mode
self.embedding_function = embeddings
Expand All @@ -396,6 +518,8 @@ def __init__(
self._engine: Optional[Engine] = None
self._async_engine: Optional[AsyncEngine] = None
self._async_init = False
self.index_type = index_type
self.index_params = index_params or {}

if isinstance(connection, str):
if async_mode:
Expand Down Expand Up @@ -427,6 +551,9 @@ def __init__(
if not use_jsonb:
# Replace with a deprecation warning.
raise NotImplementedError("use_jsonb=False is no longer supported.")

self.index_manager = IndexManager(connection=self._engine, async_mode=self.async_mode)

if not self.async_mode:
self.__post_init__()

Expand All @@ -445,6 +572,9 @@ def __post_init__(
self.create_tables_if_not_exists()
self.create_collection()

if self.index_type:
self.index_manager.create_index(self.index_type, self._distance_strategy, **self.index_params)

async def __apost_init__(
self,
) -> None:
Expand All @@ -464,6 +594,9 @@ async def __apost_init__(
await self.acreate_tables_if_not_exists()
await self.acreate_collection()

if self.index_type:
await self.index_manager.acreate_index(self.index_type, self._distance_strategy, **self.index_params)

@property
def embeddings(self) -> Embeddings:
    """Return the embedding function stored on this instance."""
    embedding_fn = self.embedding_function
    return embedding_fn
Expand Down
85 changes: 85 additions & 0 deletions tests/unit_tests/test_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from langchain_postgres.vectorstores import (
SUPPORTED_OPERATORS,
DistanceStrategy,
IndexManager,
PGVector,
)
from tests.unit_tests.fake_embeddings import FakeEmbeddings
Expand Down Expand Up @@ -1068,3 +1070,86 @@ def test_validate_operators() -> None:
"$not",
"$or",
]

def test_pgvector_with_hnsw_index() -> None:
    """Build a store with an HNSW index and check the top hit for "foo"."""
    hnsw_settings = {"m": 16, "ef_construction": 64}
    # NOTE(review): confirm DistanceStrategy actually defines an ``L2`` member.
    store = PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection_hnsw",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type="hnsw",
        index_params=hnsw_settings,
        distance_strategy=DistanceStrategy.L2,
    )
    hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", id=AnyStr())]
    assert hits == expected

@pytest.mark.asyncio
async def test_async_pgvector_with_hnsw_index() -> None:
    """Async variant: build a store with an HNSW index and check the top hit for "foo"."""
    hnsw_settings = {"m": 16, "ef_construction": 64}
    # NOTE(review): confirm DistanceStrategy actually defines an ``L2`` member.
    store = await PGVector.afrom_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection_hnsw",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type="hnsw",
        index_params=hnsw_settings,
        distance_strategy=DistanceStrategy.L2,
    )
    hits = await store.asimilarity_search("foo", k=1)
    expected = [Document(page_content="foo", id=AnyStr())]
    assert hits == expected

def test_pgvector_with_ivfflat_index() -> None:
    """Build a store with an IVFFlat index and check the top hit for "foo"."""
    ivfflat_settings = {"lists": 100}
    store = PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection_ivfflat",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type="ivfflat",
        index_params=ivfflat_settings,
        distance_strategy=DistanceStrategy.COSINE,
    )
    hits = store.similarity_search("foo", k=1)
    expected = [Document(page_content="foo", id=AnyStr())]
    assert hits == expected

@pytest.mark.asyncio
async def test_async_pgvector_with_ivfflat_index() -> None:
    """Async variant: build a store with an IVFFlat index and check the top hit for "foo"."""
    ivfflat_settings = {"lists": 100}
    store = await PGVector.afrom_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection_ivfflat",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type="ivfflat",
        index_params=ivfflat_settings,
        distance_strategy=DistanceStrategy.COSINE,
    )
    hits = await store.asimilarity_search("foo", k=1)
    expected = [Document(page_content="foo", id=AnyStr())]
    assert hits == expected

def test_get_index() -> None:
    """Look up an index by name and search through the returned store."""
    # Build the collection and request its HNSW index here so the test does
    # not depend on another test having run first.
    # NOTE(review): confirm DistanceStrategy actually defines an ``L2`` member.
    PGVector.from_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection_hnsw",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type="hnsw",
        index_params={"m": 16, "ef_construction": 64},
        distance_strategy=DistanceStrategy.L2,
    )

    index_manager = IndexManager(connection=CONNECTION_STRING)
    vectorstore = index_manager.get_index(
        "hnsw_l2_index",
        FakeEmbeddingsWithAdaDimension(),
        "test_collection_hnsw",
    )
    assert vectorstore is not None
    output = vectorstore.similarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]

@pytest.mark.asyncio
async def test_async_get_index() -> None:
    """Asynchronously look up an index by name and search through the returned store."""
    # Build the collection and request its HNSW index here so the test does
    # not depend on another test having run first.
    # NOTE(review): confirm DistanceStrategy actually defines an ``L2`` member.
    await PGVector.afrom_texts(
        texts=["foo", "bar", "baz"],
        collection_name="test_collection_hnsw",
        embedding=FakeEmbeddingsWithAdaDimension(),
        connection=CONNECTION_STRING,
        pre_delete_collection=True,
        index_type="hnsw",
        index_params={"m": 16, "ef_construction": 64},
        distance_strategy=DistanceStrategy.L2,
    )

    index_manager = IndexManager(connection=CONNECTION_STRING, async_mode=True)
    vectorstore = await index_manager.aget_index(
        "hnsw_l2_index",
        FakeEmbeddingsWithAdaDimension(),
        "test_collection_hnsw",
    )
    assert vectorstore is not None
    output = await vectorstore.asimilarity_search("foo", k=1)
    assert output == [Document(page_content="foo", id=AnyStr())]