|
18 | 18 | import redis
|
19 | 19 | import redis.asyncio
|
20 | 20 |
|
| 21 | +from redis.exceptions import ResponseError |
| 22 | +from redis.commands.search.indexDefinition import IndexDefinition, IndexType |
| 23 | +from redis.commands.search.aggregation import AggregateRequest |
| 24 | +from redis.commands.search.query import Query |
| 25 | +from redis.commands.search.field import ( |
| 26 | + TextField, |
| 27 | + VectorField, |
| 28 | +) |
| 29 | + |
21 | 30 | from opentelemetry import trace
|
22 | 31 | from opentelemetry.instrumentation.redis import RedisInstrumentor
|
23 | 32 | from opentelemetry.semconv.trace import SpanAttributes
|
@@ -614,3 +623,72 @@ def test_get(self):
|
614 | 623 | self.assertEqual(
|
615 | 624 | span.attributes.get(SpanAttributes.DB_STATEMENT), "GET ?"
|
616 | 625 | )
|
| 626 | + |
| 627 | + |
| 628 | +class TestRedisearchInstrument(TestBase): |
| 629 | + def setUp(self): |
| 630 | + super().setUp() |
| 631 | + self.redis_client = redis.Redis(port=6379) |
| 632 | + self.redis_client.flushall() |
| 633 | + self.embedding_dim = 256 |
| 634 | + RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) |
| 635 | + self.prepare_data() |
| 636 | + self.create_index() |
| 637 | + |
| 638 | + def tearDown(self): |
| 639 | + RedisInstrumentor().uninstrument() |
| 640 | + super().tearDown() |
| 641 | + |
| 642 | + def prepare_data(self): |
| 643 | + try: |
| 644 | + self.redis_client.ft("idx:test_vss").dropindex(True) |
| 645 | + except ResponseError: |
| 646 | + print("No such index") |
| 647 | + item = {"name": "test", |
| 648 | + "value": "test_value", |
| 649 | + "embeddings": [0.1] * 256} |
| 650 | + pipeline = self.redis_client.pipeline() |
| 651 | + pipeline.json().set(f"test:001", "$", item) |
| 652 | + res = pipeline.execute() |
| 653 | + assert False not in res |
| 654 | + |
| 655 | + def create_index(self): |
| 656 | + schema = ( |
| 657 | + TextField("$.name", no_stem=True, as_name="name"), |
| 658 | + TextField("$.value", no_stem=True, as_name="value"), |
| 659 | + VectorField("$.embeddings", |
| 660 | + "FLAT", |
| 661 | + { |
| 662 | + "TYPE": "FLOAT32", |
| 663 | + "DIM": self.embedding_dim, |
| 664 | + "DISTANCE_METRIC": "COSINE", |
| 665 | + }, |
| 666 | + as_name="vector",), |
| 667 | + ) |
| 668 | + definition = IndexDefinition(prefix=["test:"], index_type=IndexType.JSON) |
| 669 | + res = self.redis_client.ft("idx:test_vss").create_index(fields=schema, definition=definition) |
| 670 | + assert "OK" in str(res) |
| 671 | + |
| 672 | + def test_redis_create_index(self): |
| 673 | + spans = self.memory_exporter.get_finished_spans() |
| 674 | + span = next(span for span in spans if span.name == "redis.create_index") |
| 675 | + assert "redis.create_index.definition" in span.attributes |
| 676 | + assert "redis.create_index.fields" in span.attributes |
| 677 | + |
| 678 | + def test_redis_aggregate(self): |
| 679 | + query = "*" |
| 680 | + self.redis_client.ft("idx:test_vss").aggregate(AggregateRequest(query).load()) |
| 681 | + spans = self.memory_exporter.get_finished_spans() |
| 682 | + span = next(span for span in spans if span.name == "redis.aggregate") |
| 683 | + assert span.attributes.get("redis.commands.aggregate.query") == query |
| 684 | + assert "redis.commands.aggregate.results" in span.attributes |
| 685 | + |
| 686 | + def test_redis_query(self): |
| 687 | + query = "@name:test" |
| 688 | + res = self.redis_client.ft("idx:test_vss").search(Query(query)) |
| 689 | + |
| 690 | + spans = self.memory_exporter.get_finished_spans() |
| 691 | + span = next(span for span in spans if span.name == "redis.search") |
| 692 | + |
| 693 | + assert span.attributes.get("redis.commands.search.query") == query |
| 694 | + assert span.attributes.get("redis.commands.search.total") == 1 |
0 commit comments