Skip to content

How to use PostgresChatMessageHistory with async connections when using RunnableWithMessageHistory? #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
lgabs opened this issue Sep 29, 2024 · 5 comments

Comments

@lgabs
Copy link

lgabs commented Sep 29, 2024

I'm attempting to implement an asynchronous approach using PostgresChatMessageHistory in combination with RunnableWithMessageHistory, but I'm encountering some challenges. While PostgresChatMessageHistory appears to support asynchronous connections with async methods, I haven't found clear guidance on how to properly use them. Additionally, I couldn't locate any relevant documentation addressing this use case.

Any ideas on how to resolve this problem?

Here is the sample code I'm using and the traceback error:

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory

from langchain_postgres import PostgresChatMessageHistory

import psycopg
from psycopg import AsyncConnection

import uuid

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),
        MessagesPlaceholder(variable_name="chat_history"),  # same key as history_messages_key used below
        ("human", "{question}"),  # same key as input_messages_key used below
    ]
)

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

chain = prompt | llm | StrOutputParser()  # pipeline: prompt, then model, then output parser

def get_session_history(session_id: str = None) -> PostgresChatMessageHistory:
    session_id = session_id or str(uuid.uuid4())  # no id given: fall back to a fresh random UUID string
    sync_connection = psycopg.connect(DATABASE_URL) # the db url was initialized before
    return PostgresChatMessageHistory(
        "chat_history", 
        session_id, 
        sync_connection=sync_connection  # opened above; this function never closes it
    )

async def aget_session_history(session_id: str = None) -> PostgresChatMessageHistory:  # NOTE: async def, so calling it without await yields a coroutine object
    session_id = session_id or str(uuid.uuid4())  # no id given: fall back to a fresh random UUID string
    async with await AsyncConnection.connect(DATABASE_URL) as async_connection:  # the return below exits this context
        return PostgresChatMessageHistory(
            "chat_history", 
            session_id, 
            async_connection=async_connection  # NOTE(review): presumably already closed when the caller gets the history; confirm
        )

chain_with_history_sync = (
    RunnableWithMessageHistory(
        chain,
        get_session_history,  # plain (synchronous) history factory
        input_messages_key="question",
        history_messages_key="chat_history",
    )
)

chain_with_history_async = (
    RunnableWithMessageHistory(
        chain,
        aget_session_history,  # coroutine function used as the history factory; this is the variant that fails
        input_messages_key="question",
        history_messages_key="chat_history",
    )
)

session_id = str(uuid.uuid4())

answer = chain_with_history_sync.invoke(
    {"question": "Good morning!"},
    {"configurable": {"session_id": session_id}}
)
print("Sync invoke:\n", answer)

session_id = str(uuid.uuid4())  # separate session for the async run

answer = await chain_with_history_async.ainvoke(  # top-level await (notebook cell); raises the AttributeError in the traceback
    {"question": "Good morning!"},
    {"configurable": {"session_id": session_id}}
)
print("Async invoke:\n", answer)

Traceback error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[2], line 71
     67 print("Sync invoke:\n", answer)
     69 session_id = str(uuid.uuid4())
---> 71 answer = await chain_with_history_async.ainvoke(
     72     {"question": "Good morning!"},
     73     {"configurable": {"session_id": session_id}}
     74 )
     75 print("Async invoke:\n", answer)

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:5105, in RunnableBindingBase.ainvoke(self, input, config, **kwargs)
   5099 async def ainvoke(
   5100     self,
   5101     input: Input,
   5102     config: Optional[RunnableConfig] = None,
   5103     **kwargs: Optional[Any],
   5104 ) -> Output:
-> 5105     return await self.bound.ainvoke(
   5106         input,
   5107         self._merge_configs(config),
   5108         **{**self.kwargs, **kwargs},
   5109     )

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:5105, in RunnableBindingBase.ainvoke(self, input, config, **kwargs)
   5099 async def ainvoke(
   5100     self,
   5101     input: Input,
   5102     config: Optional[RunnableConfig] = None,
   5103     **kwargs: Optional[Any],
   5104 ) -> Output:
-> 5105     return await self.bound.ainvoke(
   5106         input,
   5107         self._merge_configs(config),
   5108         **{**self.kwargs, **kwargs},
   5109     )

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:2921, in RunnableSequence.ainvoke(self, input, config, **kwargs)
   2919     part = functools.partial(step.ainvoke, input, config)
   2920 if asyncio_accepts_context():
-> 2921     input = await asyncio.create_task(part(), context=context)  # type: ignore
   2922 else:
   2923     input = await asyncio.create_task(part())

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:5105, in RunnableBindingBase.ainvoke(self, input, config, **kwargs)
   5099 async def ainvoke(
   5100     self,
   5101     input: Input,
   5102     config: Optional[RunnableConfig] = None,
   5103     **kwargs: Optional[Any],
   5104 ) -> Output:
-> 5105     return await self.bound.ainvoke(
   5106         input,
   5107         self._merge_configs(config),
   5108         **{**self.kwargs, **kwargs},
   5109     )

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/passthrough.py:523, in RunnableAssign.ainvoke(self, input, config, **kwargs)
    517 async def ainvoke(
    518     self,
    519     input: Dict[str, Any],
    520     config: Optional[RunnableConfig] = None,
    521     **kwargs: Any,
    522 ) -> Dict[str, Any]:
--> 523     return await self._acall_with_config(self._ainvoke, input, config, **kwargs)

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:1837, in Runnable._acall_with_config(self, func, input, config, run_type, serialized, **kwargs)
   1833 coro = acall_func_with_variable_args(
   1834     func, input, config, run_manager, **kwargs
   1835 )
   1836 if asyncio_accepts_context():
-> 1837     output: Output = await asyncio.create_task(coro, context=context)  # type: ignore
   1838 else:
   1839     output = await coro

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/passthrough.py:510, in RunnableAssign._ainvoke(self, input, run_manager, config, **kwargs)
    497 async def _ainvoke(
    498     self,
    499     input: Dict[str, Any],
   (...)
    502     **kwargs: Any,
    503 ) -> Dict[str, Any]:
    504     assert isinstance(
    505         input, dict
    506     ), "The input to RunnablePassthrough.assign() must be a dict."
    508     return {
    509         **input,
--> 510         **await self.mapper.ainvoke(
    511             input,
    512             patch_config(config, callbacks=run_manager.get_child()),
    513             **kwargs,
    514         ),
    515     }

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:3626, in RunnableParallel.ainvoke(self, input, config, **kwargs)
   3623 try:
   3624     # copy to avoid issues from the caller mutating the steps during invoke()
   3625     steps = dict(self.steps__)
-> 3626     results = await asyncio.gather(
   3627         *(
   3628             _ainvoke_step(
   3629                 step,
   3630                 input,
   3631                 # mark each step as a child run
   3632                 config,
   3633                 key,
   3634             )
   3635             for key, step in steps.items()
   3636         )
   3637     )
   3638     output = {key: value for key, value in zip(steps, results)}
   3639 # finish the root run

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:3616, in RunnableParallel.ainvoke.<locals>._ainvoke_step(step, input, config, key)
   3614 context.run(_set_config_context, child_config)
   3615 if asyncio_accepts_context():
-> 3616     return await asyncio.create_task(  # type: ignore
   3617         step.ainvoke(input, child_config), context=context
   3618     )
   3619 else:
   3620     return await asyncio.create_task(step.ainvoke(input, child_config))

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:5105, in RunnableBindingBase.ainvoke(self, input, config, **kwargs)
   5099 async def ainvoke(
   5100     self,
   5101     input: Input,
   5102     config: Optional[RunnableConfig] = None,
   5103     **kwargs: Optional[Any],
   5104 ) -> Output:
-> 5105     return await self.bound.ainvoke(
   5106         input,
   5107         self._merge_configs(config),
   5108         **{**self.kwargs, **kwargs},
   5109     )

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:4504, in RunnableLambda.ainvoke(self, input, config, **kwargs)
   4493 """Invoke this Runnable asynchronously.
   4494 
   4495 Args:
   (...)
   4501     The output of this Runnable.
   4502 """
   4503 the_func = self.afunc if hasattr(self, "afunc") else self.func
-> 4504 return await self._acall_with_config(
   4505     self._ainvoke,
   4506     input,
   4507     self._config(config, the_func),
   4508     **kwargs,
   4509 )

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:1837, in Runnable._acall_with_config(self, func, input, config, run_type, serialized, **kwargs)
   1833 coro = acall_func_with_variable_args(
   1834     func, input, config, run_manager, **kwargs
   1835 )
   1836 if asyncio_accepts_context():
-> 1837     output: Output = await asyncio.create_task(coro, context=context)  # type: ignore
   1838 else:
   1839     output = await coro

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/base.py:4430, in RunnableLambda._ainvoke(self, input, run_manager, config, **kwargs)
   4428                     output = chunk
   4429 else:
-> 4430     output = await acall_func_with_variable_args(
   4431         cast(Callable, afunc), input, config, run_manager, **kwargs
   4432     )
   4433 # If the output is a Runnable, invoke it
   4434 if isinstance(output, Runnable):

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/history.py:498, in RunnableWithMessageHistory._aenter_history(self, input, config)
    494 async def _aenter_history(
    495     self, input: Dict[str, Any], config: RunnableConfig
    496 ) -> List[BaseMessage]:
    497     hist: BaseChatMessageHistory = config["configurable"]["message_history"]
--> 498     messages = (await hist.aget_messages()).copy()
    500     if not self.history_messages_key:
    501         # return all messages
    502         input_val = (
    503             input if not self.input_messages_key else input[self.input_messages_key]
    504         )

AttributeError: 'coroutine' object has no attribute 'aget_messages'

The error indicates that 'coroutine' object has no attribute 'aget_messages', which suggests the asynchronous method handling might not be correctly invoked.

langchain==0.2.8
langchain-community==0.2.5
langchain-core==0.2.41
langchain-openai==0.1.8
langchain-postgres==0.0.7
@pprados
Copy link
Contributor

pprados commented Sep 30, 2024

Read the unit tests (TU) here

@lgabs
Copy link
Author

lgabs commented Oct 1, 2024

Thanks, @pprados

However, it's still unclear to me how to properly use RunnableWithMessageHistory. Even after reviewing the examples in the unit tests, I encountered the exact same error mentioned earlier (see the code below). Based on the error message, I believe the issue is that RunnableWithMessageHistory expects a function that returns an instance of BaseChatMessageHistory for a given session_id. The last few lines of the error indicate that the instance is being created with:

File /usr/local/lib/python3.12/site-packages/langchain_core/runnables/history.py:498, in RunnableWithMessageHistory._aenter_history(self, input, config)
    494 async def _aenter_history(
    495     self, input: Dict[str, Any], config: RunnableConfig
    496 ) -> List[BaseMessage]:
    497     hist: BaseChatMessageHistory = config["configurable"]["message_history"]
--> 498     messages = (await hist.aget_messages()).copy()
    500     if not self.history_messages_key:
    501         # return all messages
    502         input_val = (
    503             input if not self.input_messages_key else input[self.input_messages_key]

However, since hist is not awaited, its value remains a coroutine, which doesn't have the attribute aget_messages, as the error suggests.

Here's a new approach, still failing, based on the unit tests:

from contextlib import asynccontextmanager, contextmanager
from typing_extensions import AsyncGenerator

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory

from langchain_postgres import PostgresChatMessageHistory

import psycopg
from psycopg import AsyncConnection

import uuid

# Prompt layout: system instruction, a placeholder keyed "chat_history",
# then the human turn built from the "question" input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Pipeline: prompt, then model, then output parser.
chain = prompt | llm | StrOutputParser()

@asynccontextmanager
async def asyncpg_client() -> AsyncGenerator[psycopg.AsyncConnection, None]:
    """Yield a psycopg async connection to DATABASE_URL, closing it on exit.

    Yields:
        The connection returned by ``psycopg.AsyncConnection.connect``.
    """
    connection = await psycopg.AsyncConnection.connect(conninfo=DATABASE_URL)
    try:
        yield connection
    finally:
        # Runs when the `async with` body finishes, whether it returned
        # normally or raised.
        await connection.close()

async def aget_session_history(session_id: str = None) -> PostgresChatMessageHistory:
    """Build a PostgresChatMessageHistory on an async connection for a session.

    NOTE: this is a coroutine function, so calling it without ``await``
    yields a coroutine object rather than a history instance.

    Args:
        session_id: Session identifier; a random UUID string is generated
            when it is falsy.

    Returns:
        A PostgresChatMessageHistory for table "chat_history" constructed
        with the connection opened here.
    """
    session_id = session_id or str(uuid.uuid4())
    # The return statement leaves the `async with` block, so asyncpg_client's
    # `finally` clause closes the connection before the caller can use the
    # returned history object.
    async with asyncpg_client() as async_connection:
        return PostgresChatMessageHistory(
            "chat_history", 
            session_id, 
            async_connection=async_connection
        )

# Wrap the chain with message history; the history factory passed here is
# the coroutine function aget_session_history.
chain_with_history_async = (
    RunnableWithMessageHistory(
        chain,
        aget_session_history,
        input_messages_key="question",
        history_messages_key="chat_history",
    )
)

session_id = str(uuid.uuid4())

# Top-level await: this snippet is meant to run in a notebook cell or
# another context that already has a running event loop.
answer = await chain_with_history_async.ainvoke(
    {"question": "Good morning!"},
    {"configurable": {"session_id": session_id}}
)
print("Async invoke:\n", answer)

Let me know if there's something I'm still missing or if any adjustments are needed. Thanks!

@pprados
Copy link
Contributor

pprados commented Oct 1, 2024

Thank you for this analysis. Can you suggest a PR?

@jackadair
Copy link

@lgabs Any luck solving this issue?

@lgabs
Copy link
Author

lgabs commented Oct 4, 2024

I found a workaround for that problem, but before preparing a PR, I'll share the idea for discussion. I've made some changes to the async methods of PostgresChatMessageHistory so that they open an async connection themselves when the class instance was not given one at instantiation, using a new attribute conn_str. While this works, it opens a new connection every time, which is bad for connection reuse. What do you think?

Here is the full code:

from typing import List, Union, AsyncGenerator, Sequence
import json 
from contextlib import asynccontextmanager

import psycopg
from psycopg import sql, AsyncConnection

from langchain_postgres import PostgresChatMessageHistory
from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    SystemMessage,
    BaseMessage,
    messages_from_dict,
    message_to_dict,
)


from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory

from langchain_postgres import PostgresChatMessageHistory


import uuid

class CustomAsyncPostgresChatMessageHistory(PostgresChatMessageHistory):
    """Chat message history whose async methods open their own connection.

    The base class is initialised with a sync connection opened in
    ``__init__``. ``aget_messages`` and ``aadd_messages`` are overridden to
    open a short-lived async connection from ``conn_str`` on every call, so
    an instance can be returned from a plain (non-async) history factory.

    Known limitations: every async call opens and closes a new connection,
    and the sync connection opened in ``__init__`` is not closed by this
    class.
    """

    def __init__(
        self,
        session_id: str,
        conn_str: str,
        table_name: str = "chat_history",
    ):
        """Open a sync connection and remember ``conn_str`` for async calls.

        Args:
            session_id: Identifier of the chat session.
            conn_str: Connection string used for the sync connection opened
                here and for the per-call async connections.
            table_name: Table holding the messages. Defaults to
                "chat_history", the value that was previously hard-coded.
        """
        sync_connection = psycopg.connect(conn_str)
        super().__init__(table_name, session_id, sync_connection=sync_connection)
        self.conn_str = conn_str

    @staticmethod
    def _insert_message_query(table_name: str) -> sql.Composed:
        """Make a SQL query to insert a message."""
        return sql.SQL(
            "INSERT INTO {table_name} (session_id, message) VALUES (%s, %s)"
        ).format(table_name=sql.Identifier(table_name))

    @staticmethod
    def _get_messages_query(table_name: str) -> sql.Composed:
        """Make a SQL query to get messages for a given session."""
        return sql.SQL(
            "SELECT message "
            "FROM {table_name} "
            "WHERE session_id = %(session_id)s "
            "ORDER BY id;"
        ).format(table_name=sql.Identifier(table_name))

    @asynccontextmanager
    async def asyncpg_client(
        self,
        conn_str: str,
    ) -> AsyncGenerator[psycopg.AsyncConnection, None]:
        """Yield a psycopg async connection for ``conn_str``; close it on exit.

        Despite the name, the connection comes from
        ``psycopg.AsyncConnection.connect``.
        """
        conn = await psycopg.AsyncConnection.connect(conninfo=conn_str)
        try:
            yield conn
        finally:
            # Close the connection whether or not the caller's block raised.
            await conn.close()

    def _checked_conn_str(self) -> str:
        """Return ``self.conn_str``, raising ValueError when it is empty."""
        if not self.conn_str:
            raise ValueError(
                "Please initialize the PostgresChatMessageHistory "
                "with a connection string."
            )
        return self.conn_str

    async def aget_messages(self) -> List[BaseMessage]:
        """Retrieve this session's messages using a fresh async connection.

        Returns:
            The stored messages, in the order given by the query's
            ``ORDER BY id``.

        Raises:
            ValueError: If ``conn_str`` is empty.
        """
        conn_str = self._checked_conn_str()
        query = self._get_messages_query(self._table_name)

        async with self.asyncpg_client(conn_str) as async_connection:
            async with async_connection.cursor() as cursor:
                await cursor.execute(query, {"session_id": self._session_id})
                items = [record[0] for record in await cursor.fetchall()]

        return messages_from_dict(items)

    async def aadd_messages(self, messages: Sequence[BaseMessage]) -> None:
        """Insert ``messages`` for this session using a fresh async connection.

        Args:
            messages: Messages to store; each is serialised with
                ``message_to_dict`` and ``json.dumps``.

        Raises:
            ValueError: If ``conn_str`` is empty.
        """
        conn_str = self._checked_conn_str()
        if not messages:
            # Nothing to insert: do not open a connection for an empty batch.
            return

        values = [
            (self._session_id, json.dumps(message_to_dict(message)))
            for message in messages
        ]
        query = self._insert_message_query(self._table_name)

        async with self.asyncpg_client(conn_str) as async_connection:
            async with async_connection.cursor() as cursor:
                await cursor.executemany(query, values)
            # Commit explicitly: asyncpg_client only closes the connection.
            await async_connection.commit()


def get_message_history(session_id: str = None) -> CustomAsyncPostgresChatMessageHistory:
    """Return a CustomAsyncPostgresChatMessageHistory for the given session.

    Args:
        session_id: Session identifier; a random UUID string is generated
            when it is falsy.
    """
    if not session_id:
        session_id = str(uuid.uuid4())
    return CustomAsyncPostgresChatMessageHistory(session_id, conn_str=DATABASE_URL)

# Prompt layout: system instruction, a placeholder keyed "chat_history",
# then the human turn built from the "question" input.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant."),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Pipeline: prompt, then model, then output parser.
chain = prompt | llm | StrOutputParser()

# The history factory here is the plain function get_message_history, which
# returns a CustomAsyncPostgresChatMessageHistory instance directly.
chain_with_history_custom_async = (
    RunnableWithMessageHistory(
        chain,
        get_message_history,
        input_messages_key="question",
        history_messages_key="chat_history",
    )
)

# Async call (top-level await: notebook cell or an already-running event loop).
session_id = str(uuid.uuid4())
answer = await chain_with_history_custom_async.ainvoke(
    {"question": "Good morning!"},
    {"configurable": {"session_id": session_id}}
)
print("async answer: ", answer)

# Sync call through the same wrapper, using a separate session id.
session_id = str(uuid.uuid4())
answer2 = chain_with_history_custom_async.invoke(
    {"question": "Good morning!"},
    {"configurable": {"session_id": session_id}}
)
print("sync answer: ", answer2)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants