
Commit 35d4af3

text_expansion query support (#1837) (#1838)

(cherry picked from commit c9612c1)
Co-authored-by: Miguel Grinberg <[email protected]>

1 parent aa02fe4 · commit 35d4af3

3 files changed: +397 -0 lines

Diff for: elasticsearch_dsl/query.py (+4 lines)
@@ -551,6 +551,10 @@ class TermsSet(Query):
     name = "terms_set"
 
 
+class TextExpansion(Query):
+    name = "text_expansion"
+
+
 class Wildcard(Query):
     name = "wildcard"
 
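With this class registered, the DSL can resolve the "text_expansion" query name. A minimal sketch of the two equivalent ways to build the new query (the field name "embedding" and the ELSER model id here are illustrative, borrowed from the example below, not part of this diff):

    from elasticsearch_dsl import Q
    from elasticsearch_dsl.query import TextExpansion

    # both forms serialize to
    # {"text_expansion": {"embedding": {"model_id": ..., "model_text": ...}}}
    q1 = Q(
        "text_expansion",
        embedding={"model_id": ".elser_model_2", "model_text": "work from home"},
    )
    q2 = TextExpansion(
        embedding={"model_id": ".elser_model_2", "model_text": "work from home"}
    )
    assert q1 == q2
    print(q1.to_dict())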

Diff for: examples/async/sparse_vectors.py (+197 lines, new file)
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
# Sparse vector database example

Requirements:

$ pip install nltk tqdm elasticsearch-dsl[async]

Before running this example, the ELSER v2 model must be downloaded and deployed
to the Elasticsearch cluster, and an ingest pipeline must be defined. This can
be done manually from Kibana, or with the following three curl commands from a
terminal, adjusting the endpoint as needed:

curl -X PUT \
    "http://localhost:9200/_ml/trained_models/.elser_model_2?wait_for_completion" \
    -H "Content-Type: application/json" \
    -d '{"input":{"field_names":["text_field"]}}'
curl -X POST \
    "http://localhost:9200/_ml/trained_models/.elser_model_2/deployment/_start?wait_for=fully_allocated"
curl -X PUT \
    "http://localhost:9200/_ingest/pipeline/elser_ingest_pipeline" \
    -H "Content-Type: application/json" \
    -d '{"processors":[{"foreach":{"field":"passages","processor":{"inference":{"model_id":".elser_model_2","input_output":[{"input_field":"_ingest._value.content","output_field":"_ingest._value.embedding"}]}}}}]}'

To run the example:

$ python sparse_vectors.py "text to search"

The index will be created automatically if it does not exist. Add
`--recreate-index` to regenerate it.

The example dataset includes a selection of workplace documents. The
following are good example queries to try out with this dataset:

$ python sparse_vectors.py "work from home"
$ python sparse_vectors.py "vacation time"
$ python sparse_vectors.py "can I bring a bird to work?"

When the index is created, the documents are split into short passages, and for
each passage a sparse embedding is generated using Elastic's ELSER v2 model.
The documents that are returned as search results are those that have the
highest scored passages. Add `--show-inner-hits` to the command to see
individual passage results as well.
"""

import argparse
import asyncio
import json
import os
from urllib.request import urlopen

import nltk
from tqdm import tqdm

from elasticsearch_dsl import (
    AsyncDocument,
    Date,
    InnerDoc,
    Keyword,
    Nested,
    Q,
    SparseVector,
    Text,
    async_connections,
)

DATASET_URL = "https://raw.githubusercontent.com/elastic/elasticsearch-labs/main/datasets/workplace-documents.json"

# initialize sentence tokenizer
nltk.download("punkt", quiet=True)


class Passage(InnerDoc):
    content = Text()
    embedding = SparseVector()


class WorkplaceDoc(AsyncDocument):
    class Index:
        name = "workplace_documents_sparse"
        settings = {"default_pipeline": "elser_ingest_pipeline"}

    name = Text()
    summary = Text()
    content = Text()
    created = Date()
    updated = Date()
    url = Keyword()
    category = Keyword()
    passages = Nested(Passage)

    _model = None

    def clean(self):
        # split the content into sentences
        passages = nltk.sent_tokenize(self.content)

        # generate an embedding for each passage and save it as a nested document
        for passage in passages:
            self.passages.append(Passage(content=passage))


async def create():

    # create the index
    await WorkplaceDoc._index.delete(ignore_unavailable=True)
    await WorkplaceDoc.init()

    # download the data
    dataset = json.loads(urlopen(DATASET_URL).read())

    # import the dataset
    for data in tqdm(dataset, desc="Indexing documents..."):
        doc = WorkplaceDoc(
            name=data["name"],
            summary=data["summary"],
            content=data["content"],
            created=data.get("created_on"),
            updated=data.get("updated_at"),
            url=data["url"],
            category=data["category"],
        )
        await doc.save()


async def search(query):
    return WorkplaceDoc.search()[:5].query(
        "nested",
        path="passages",
        query=Q(
            "text_expansion",
            passages__content={
                "model_id": ".elser_model_2",
                "model_text": query,
            },
        ),
        inner_hits={"size": 2},
    )


def parse_args():
    parser = argparse.ArgumentParser(description="Vector database with Elasticsearch")
    parser.add_argument(
        "--recreate-index", action="store_true", help="Recreate and populate the index"
    )
    parser.add_argument(
        "--show-inner-hits",
        action="store_true",
        help="Show results for individual passages",
    )
    parser.add_argument("query", action="store", help="The search query")
    return parser.parse_args()


async def main():
    args = parse_args()

    # initiate the default connection to elasticsearch
    async_connections.create_connection(hosts=[os.environ["ELASTICSEARCH_URL"]])

    if args.recreate_index or not await WorkplaceDoc._index.exists():
        await create()

    results = await search(args.query)

    async for hit in results:
        print(
            f"Document: {hit.name} [Category: {hit.category}] [Score: {hit.meta.score}]"
        )
        print(f"Summary: {hit.summary}")
        if args.show_inner_hits:
            for passage in hit.meta.inner_hits.passages:
                print(f"  - [Score: {passage.meta.score}] {passage.content!r}")
        print("")

    # close the connection
    await async_connections.get_connection().close()


if __name__ == "__main__":
    asyncio.run(main())
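For reference, a sketch of the request body that the `search()` function above should produce when executed, reconstructed by hand from the DSL calls rather than taken from the commit (the `model_text` value is a placeholder):

    {
      "query": {
        "nested": {
          "path": "passages",
          "inner_hits": {"size": 2},
          "query": {
            "text_expansion": {
              "passages.content": {
                "model_id": ".elser_model_2",
                "model_text": "work from home"
              }
            }
          }
        }
      },
      "size": 5
    }

The double underscore in the `passages__content` keyword argument is the DSL's usual shortcut for a dotted field name, so the new `TextExpansion` class needs no field-specific logic of its own.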
