Skip to content

Commit 372dc7f

Browse files
Thommy257eyurtsev
andauthored
core[patch]: fix loss of partially initialized variables during prompt composition (#30096)
**Description:** This PR addresses the loss of partially initialised variables when composing different prompts. I.e. it allows the following snippet to run: ```python from langchain_core.prompts import ChatPromptTemplate prompt = ChatPromptTemplate.from_messages([('system', 'Prompt {x} {y}')]).partial(x='1') appendix = ChatPromptTemplate.from_messages([('system', 'Appendix {z}')]) (prompt + appendix).invoke({'y': '2', 'z': '3'}) ``` Previously, this would have raised a `KeyError`, stating that variable `x` remains undefined. **Issue** References issue #30049 **Todo** - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ --------- Co-authored-by: Eugene Yurtsev <[email protected]>
1 parent e7883d5 commit 372dc7f

File tree

2 files changed

+36
-4
lines changed

2 files changed

+36
-4
lines changed

Diff for: libs/core/langchain_core/prompts/chat.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -1040,19 +1040,34 @@ def __add__(self, other: Any) -> ChatPromptTemplate:
10401040
Returns:
10411041
Combined prompt template.
10421042
"""
1043+
partials = {**self.partial_variables}
1044+
1045+
# Need to check that other has partial variables since it may not be
1046+
# a ChatPromptTemplate.
1047+
if hasattr(other, "partial_variables") and other.partial_variables:
1048+
partials.update(other.partial_variables)
1049+
10431050
# Allow for easy combining
10441051
if isinstance(other, ChatPromptTemplate):
1045-
return ChatPromptTemplate(messages=self.messages + other.messages) # type: ignore[call-arg]
1052+
return ChatPromptTemplate(messages=self.messages + other.messages).partial(
1053+
**partials
1054+
) # type: ignore[call-arg]
10461055
elif isinstance(
10471056
other, (BaseMessagePromptTemplate, BaseMessage, BaseChatPromptTemplate)
10481057
):
1049-
return ChatPromptTemplate(messages=self.messages + [other]) # type: ignore[call-arg]
1058+
return ChatPromptTemplate(messages=self.messages + [other]).partial(
1059+
**partials
1060+
) # type: ignore[call-arg]
10501061
elif isinstance(other, (list, tuple)):
10511062
_other = ChatPromptTemplate.from_messages(other)
1052-
return ChatPromptTemplate(messages=self.messages + _other.messages) # type: ignore[call-arg]
1063+
return ChatPromptTemplate(messages=self.messages + _other.messages).partial(
1064+
**partials
1065+
) # type: ignore[call-arg]
10531066
elif isinstance(other, str):
10541067
prompt = HumanMessagePromptTemplate.from_template(other)
1055-
return ChatPromptTemplate(messages=self.messages + [prompt]) # type: ignore[call-arg]
1068+
return ChatPromptTemplate(messages=self.messages + [prompt]).partial(
1069+
**partials
1070+
) # type: ignore[call-arg]
10561071
else:
10571072
msg = f"Unsupported operand type for +: {type(other)}"
10581073
raise NotImplementedError(msg)

Diff for: libs/core/tests/unit_tests/prompts/test_chat.py

+17
Original file line numberDiff line numberDiff line change
@@ -582,6 +582,23 @@ def test_chat_message_partial() -> None:
582582
assert template2.format(input="hello") == get_buffer_string(expected)
583583

584584

585+
def test_chat_message_partial_composition() -> None:
586+
"""Test composition of partially initialized messages."""
587+
prompt = ChatPromptTemplate.from_messages([("system", "Prompt {x} {y}")]).partial(
588+
x="1"
589+
)
590+
591+
appendix = ChatPromptTemplate.from_messages([("system", "Appendix {z}")])
592+
593+
res = (prompt + appendix).format_messages(y="2", z="3")
594+
expected = [
595+
SystemMessage(content="Prompt 1 2"),
596+
SystemMessage(content="Appendix 3"),
597+
]
598+
599+
assert res == expected
600+
601+
585602
async def test_chat_tmpl_from_messages_multipart_text() -> None:
586603
template = ChatPromptTemplate.from_messages(
587604
[

0 commit comments

Comments
 (0)