
Commit c284b08

Make pipeline work with parallel runs (neo4j#119)
* Add failing test
* Define a "run_id" in Orchestrator - save results per run_id
* Make unit test work
* Make intermediate results accessible from outside pipeline for investigation
* Remove unused imports
* Update examples and CHANGELOG
* Cleaning: remove deprecated code
* Fix ruff
* Fix examples
* Fix examples again
* PR reviews
* Removing useless status assignment
1 parent 411b5ea commit c284b08
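The commit message above says the Orchestrator now keeps results keyed by a "run_id". The snippet below is a minimal sketch of that idea only, using hypothetical names (RunStore, run_once) that are not taken from the library; it just shows why storing intermediate results per run identifier keeps two parallel runs from overwriting each other.

import asyncio
import uuid
from collections import defaultdict
from typing import Any


class RunStore:
    """Hypothetical per-run result store: results are keyed by run_id so that
    two concurrent runs of the same pipeline never overwrite each other."""

    def __init__(self) -> None:
        self._results: dict[str, dict[str, Any]] = defaultdict(dict)

    def add(self, run_id: str, component_name: str, result: Any) -> None:
        self._results[run_id][component_name] = result

    def get(self, run_id: str, component_name: str) -> Any:
        return self._results[run_id][component_name]


async def run_once(store: RunStore, value: int) -> str:
    # each run gets its own id, as the commit message describes
    run_id = str(uuid.uuid4())
    store.add(run_id, "splitter", value * 2)
    return run_id


async def main() -> None:
    store = RunStore()
    # two "parallel runs": each run_id sees only its own intermediate results
    id_a, id_b = await asyncio.gather(run_once(store, 1), run_once(store, 2))
    assert store.get(id_a, "splitter") != store.get(id_b, "splitter")


asyncio.run(main())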

File tree

9 files changed (+173, -152 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -2,6 +2,9 @@
 
 ## Next
 
+## Fixed
+- Pipelines now return correct results when the same pipeline is run in parallel.
+
 ## 0.5.0
 
 ### Added
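To see what the new CHANGELOG entry means in practice, here is a rough sketch of running the same pipeline object twice, concurrently. It reuses only API surface visible in the diffs below (Component, DataModel, pipe.add_component, pipe.run, PipelineResult.result); the single-component EchoComponent pipeline and the no-argument Pipeline() constructor are assumptions for illustration, not code from this commit.

import asyncio

from neo4j_genai.experimental.pipeline import Component, Pipeline
from neo4j_genai.experimental.pipeline.component import DataModel
from neo4j_genai.experimental.pipeline.pipeline import PipelineResult


class EchoModel(DataModel):
    text: str


class EchoComponent(Component):
    async def run(self, text: str) -> EchoModel:
        await asyncio.sleep(0.1)  # give the two runs time to overlap
        return EchoModel(text=text)


async def main() -> None:
    pipe = Pipeline()
    pipe.add_component(EchoComponent(), "echo")
    # run the same pipeline object twice, concurrently
    res_a, res_b = await asyncio.gather(
        pipe.run({"echo": {"text": "run A"}}),
        pipe.run({"echo": {"text": "run B"}}),
    )
    # with per-run_id storage, each PipelineResult holds its own run's output
    print(res_a.result["echo"]["text"])  # run A
    print(res_b.result["echo"]["text"])  # run B


asyncio.run(main())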

examples/pipeline/kg_builder_from_pdf.py

Lines changed: 3 additions & 2 deletions
@@ -35,10 +35,11 @@
     LangChainTextSplitterAdapter,
 )
 from neo4j_genai.experimental.pipeline import Component, DataModel
+from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
 from neo4j_genai.llm import OpenAILLM
 from pydantic import BaseModel, validate_call
 
-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)
 
 
 class DocumentChunkModel(DataModel):
@@ -98,7 +99,7 @@ async def run(self, graph: Neo4jGraph) -> WriterModel:
     )
 
 
-async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
+async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
     from neo4j_genai.experimental.pipeline import Pipeline
 
     # Instantiate Entity and Relation objects
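Since main() now returns a PipelineResult instead of a dict, callers read the final component outputs from its .result attribute (as the rag.py diff further down does). Below is a short, hedged sketch of the calling side, with placeholder connection details; it assumes it sits at the bottom of the same example file where main() is defined.

import asyncio

import neo4j
from neo4j_genai.experimental.pipeline.pipeline import PipelineResult

# Placeholder connection details, for illustration only.
driver = neo4j.GraphDatabase.driver(
    "neo4j://localhost:7687", auth=("neo4j", "password")
)

# main() now returns a PipelineResult; the final component outputs
# are exposed on its .result attribute rather than being the return value itself.
res: PipelineResult = asyncio.run(main(driver))
print(res.result)

driver.close()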

examples/pipeline/kg_builder_from_text.py

Lines changed: 5 additions & 2 deletions
@@ -16,10 +16,11 @@
 
 import asyncio
 import logging.config
-from typing import Any
 
 import neo4j
 from langchain_text_splitters import CharacterTextSplitter
+from neo4j_genai.embeddings.openai import OpenAIEmbeddings
+from neo4j_genai.experimental.components.embedder import TextChunkEmbedder
 from neo4j_genai.experimental.components.entity_relation_extractor import (
     LLMEntityRelationExtractor,
     OnError,
@@ -35,6 +36,7 @@
     LangChainTextSplitterAdapter,
 )
 from neo4j_genai.experimental.pipeline import Pipeline
+from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
 from neo4j_genai.llm import OpenAILLM
 
 # set log level to DEBUG for all neo4j_genai.* loggers
@@ -58,7 +60,7 @@
 )
 
 
-async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
+async def main(neo4j_driver: neo4j.Driver) -> PipelineResult:
     """This is where we define and run the KG builder pipeline, instantiating a few
     components:
     - Text Splitter: in this example we use a text splitter from the LangChain package
@@ -80,6 +82,7 @@ async def main(neo4j_driver: neo4j.Driver) -> dict[str, Any]:
         ),
         "splitter",
     )
+    pipe.add_component(TextChunkEmbedder(embedder=OpenAIEmbeddings()), "chunk_embedder")
     pipe.add_component(SchemaBuilder(), "schema")
     pipe.add_component(
         LLMEntityRelationExtractor(
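The hunk above only registers the new chunk_embedder component; it still has to be wired between the splitter and the extractor elsewhere in main(). The lines below are a hedged sketch of that wiring, assuming a Pipeline.connect(start, end, input_config=...) method and the parameter names text_chunks / chunks, none of which appear in this hunk.

# Sketch only: wiring for the new component, continuing the example's main().
# Pipeline.connect and the parameter names below are assumptions, not shown in this diff.
pipe.connect("splitter", "chunk_embedder", input_config={"text_chunks": "splitter"})
pipe.connect("chunk_embedder", "extractor", input_config={"chunks": "chunk_embedder"})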

examples/pipeline/rag.py

Lines changed: 16 additions & 12 deletions
@@ -21,11 +21,13 @@
 from __future__ import annotations
 
 import asyncio
+from typing import List
 
 import neo4j
 from neo4j_genai.embeddings.openai import OpenAIEmbeddings
 from neo4j_genai.experimental.pipeline import Component, Pipeline
 from neo4j_genai.experimental.pipeline.component import DataModel
+from neo4j_genai.experimental.pipeline.pipeline import PipelineResult
 from neo4j_genai.experimental.pipeline.types import (
     ComponentConfig,
     ConnectionConfig,
@@ -37,35 +39,37 @@
 from neo4j_genai.retrievers.base import Retriever
 
 
-class StringDataModel(DataModel):
-    result: str
+class ComponentResultDataModel(DataModel):
+    """A simple DataModel with a single text field"""
+
+    text: str
 
 
 class RetrieverComponent(Component):
     def __init__(self, retriever: Retriever) -> None:
         self.retriever = retriever
 
-    async def run(self, query: str) -> StringDataModel:
+    async def run(self, query: str) -> ComponentResultDataModel:
         res = self.retriever.search(query_text=query)
-        return StringDataModel(result="\n".join(c.content for c in res.items))
+        return ComponentResultDataModel(text="\n".join(c.content for c in res.items))
 
 
 class PromptTemplateComponent(Component):
     def __init__(self, prompt: PromptTemplate) -> None:
         self.prompt = prompt
 
-    async def run(self, query: str, context: list[str]) -> StringDataModel:
+    async def run(self, query: str, context: List[str]) -> ComponentResultDataModel:
         prompt = self.prompt.format(query, context, examples="")
-        return StringDataModel(result=prompt)
+        return ComponentResultDataModel(text=prompt)
 
 
 class LLMComponent(Component):
     def __init__(self, llm: LLMInterface) -> None:
         self.llm = llm
 
-    async def run(self, prompt: str) -> StringDataModel:
+    async def run(self, prompt: str) -> ComponentResultDataModel:
         llm_response = self.llm.invoke(prompt)
-        return StringDataModel(result=llm_response.content)
+        return ComponentResultDataModel(text=llm_response.content)
 
 
 if __name__ == "__main__":
@@ -96,21 +100,21 @@ async def run(self, prompt: str) -> StringDataModel:
                 ConnectionConfig(
                     start="retrieve",
                     end="augment",
-                    input_config={"context": "retrieve.result"},
+                    input_config={"context": "retrieve.text"},
                 ),
                 ConnectionConfig(
                     start="augment",
                     end="generate",
-                    input_config={"prompt": "augment.result"},
+                    input_config={"prompt": "augment.text"},
                 ),
             ],
         )
    )
 
     query = "A movie about the US presidency"
-    result = asyncio.run(
+    pipe_output: PipelineResult = asyncio.run(
         pipe.run({"retrieve": {"query": query}, "augment": {"query": query}})
     )
-    print(result["generate"]["result"])
+    print(pipe_output.result["generate"]["text"])
 
     driver.close()
