12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import asyncio
15
- from unittest import mock
16
- from unittest .mock import AsyncMock
17
15
16
+ import pytest
18
17
import redis
19
- import redis .asyncio
18
+ import redis .asyncio as redis_async
20
19
from redis .exceptions import WatchError
20
+ from unittest import mock , IsolatedAsyncioTestCase
21
+ from unittest .mock import AsyncMock
21
22
22
23
from opentelemetry import trace
23
24
from opentelemetry .instrumentation .redis import RedisInstrumentor
27
28
SpanAttributes ,
28
29
)
29
30
from opentelemetry .test .test_base import TestBase
30
- from opentelemetry .trace import SpanKind , StatusCode
31
+ from opentelemetry .trace import SpanKind
31
32
32
33
33
34
class TestRedis (TestBase ):
@@ -144,7 +145,7 @@ def response_hook(span, conn, response):
144
145
145
146
with mock .patch .object (connection , "send_command" ):
146
147
with mock .patch .object (
147
- redis_client , "parse_response" , return_value = test_value
148
+ redis_client , "parse_response" , return_value = test_value
148
149
):
149
150
redis_client .get ("key" )
150
151
@@ -176,7 +177,7 @@ def request_hook(span, conn, args, kwargs):
176
177
177
178
with mock .patch .object (connection , "send_command" ):
178
179
with mock .patch .object (
179
- redis_client , "parse_response" , return_value = test_value
180
+ redis_client , "parse_response" , return_value = test_value
180
181
):
181
182
redis_client .get ("key" )
182
183
@@ -313,24 +314,26 @@ def test_attributes_unix_socket(self):
313
314
NetTransportValues .OTHER .value ,
314
315
)
315
316
316
- def test_watch_error (self ):
317
+ def test_successful_transaction (self ):
317
318
redis_client = redis .Redis ()
318
319
319
- # Mock the pipeline to raise a WatchError
320
+ # Create a mock pipeline
320
321
mock_pipeline = mock .MagicMock ()
322
+ mock_pipeline .__enter__ .return_value = mock_pipeline # Ensure __enter__ returns the mock_pipeline
321
323
mock_pipeline .watch .return_value = None
322
324
mock_pipeline .multi .return_value = mock_pipeline
323
- mock_pipeline .execute .side_effect = WatchError ( "Watched variable changed" )
325
+ mock_pipeline .execute .return_value = [ "OK" ] # This is what we want to return
324
326
325
327
with mock .patch .object (redis_client , "pipeline" , return_value = mock_pipeline ):
326
- try :
327
- with redis_client .pipeline () as pipe :
328
- pipe .watch ("key" )
329
- pipe .multi ()
330
- pipe .set ("key" , "value" )
331
- pipe .execute ()
332
- except WatchError :
333
- pass # We expect this exception to be raised
328
+ with redis_client .pipeline () as pipe :
329
+ pipe .watch ("key" )
330
+ pipe .multi ()
331
+ pipe .set ("key" , "value" )
332
+ result = pipe .execute ()
333
+
334
+ # Check that the transaction was successful
335
+ print (f"Result: { result } " )
336
+ self .assertEqual (result , ["OK" ])
334
337
335
338
spans = self .memory_exporter .get_finished_spans ()
336
339
self .assertEqual (len (spans ), 1 )
@@ -339,14 +342,48 @@ def test_watch_error(self):
339
342
# Check that the span is not marked as an error
340
343
self .assertIsNone (span .status .status_code )
341
344
342
- # Check that the WatchError is recorded as an event, not an exception
345
+ # Check that there are no exception events
343
346
events = span .events
344
- self .assertEqual (len (events ), 1 )
345
- self .assertEqual (events [0 ].name , "exception" )
346
- self .assertEqual (events [0 ].attributes ["exception.type" ], "WatchError" )
347
- self .assertIn ("Watched variable changed" , events [0 ].attributes ["exception.message" ])
347
+ self .assertEqual (len (events ), 0 )
348
348
349
349
# Verify other span properties
350
350
self .assertEqual (span .name , "MULTI" )
351
351
self .assertEqual (span .kind , SpanKind .CLIENT )
352
352
self .assertEqual (span .attributes .get ("db.system" ), "redis" )
353
+
354
+ # Verify that the SET command is recorded in the span
355
+ self .assertIn ("SET" , span .attributes .get ("db.statement" , "" ))
356
+
357
+ # Optionally, check for any additional attributes specific to your instrumentation
358
+ # For example, you might want to verify that the database index is correctly recorded
359
+ self .assertEqual (span .attributes .get ("db.redis.database_index" ), 0 )
360
+
361
+
362
+ class TestRedisAsync (TestBase , IsolatedAsyncioTestCase ):
363
+ def setUp (self ):
364
+ super ().setUp ()
365
+ RedisInstrumentor ().instrument (tracer_provider = self .tracer_provider )
366
+
367
+ def tearDown (self ):
368
+ super ().tearDown ()
369
+ RedisInstrumentor ().uninstrument ()
370
+
371
+ @pytest .mark .asyncio
372
+ async def test_redis_operations (self ):
373
+ async def redis_operations ():
374
+ try :
375
+ r = redis_async .Redis ()
376
+ async with r .pipeline (transaction = False ) as pipe :
377
+ await pipe .watch ("a" )
378
+ await r .set ("a" , "bad" )
379
+ pipe .multi ()
380
+ await pipe .set ("a" , "1" )
381
+ await pipe .execute ()
382
+ except WatchError :
383
+ pass
384
+
385
+ await redis_operations ()
386
+
387
+ spans = self .memory_exporter .get_finished_spans ()
388
+ assert spans [- 1 ].status .status_code == trace .StatusCode .UNSET
389
+ assert any (event .name == "WatchError" for event in spans [- 1 ].events )
0 commit comments