12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
import asyncio
15
- import logging
16
- from unittest import mock
17
- from unittest .mock import AsyncMock , patch
18
15
19
16
import pytest
20
17
import redis
21
- import redis .asyncio
18
+ import redis .asyncio as redis_async
22
19
from redis .exceptions import WatchError
20
+ from unittest import mock , IsolatedAsyncioTestCase
21
+ from unittest .mock import AsyncMock
23
22
24
23
from opentelemetry import trace
25
24
from opentelemetry .instrumentation .redis import RedisInstrumentor
29
28
SpanAttributes ,
30
29
)
31
30
from opentelemetry .test .test_base import TestBase
32
- from opentelemetry .trace import SpanKind , StatusCode
31
+ from opentelemetry .trace import SpanKind
33
32
34
33
35
34
class TestRedis (TestBase ):
@@ -359,114 +358,32 @@ def test_successful_transaction(self):
359
358
# For example, you might want to verify that the database index is correctly recorded
360
359
self .assertEqual (span .attributes .get ("db.redis.database_index" ), 0 )
361
360
362
- def test_watch_error (self ):
363
- redis_client = redis .Redis ()
364
-
365
- # Mock the pipeline to raise a WatchError
366
- mock_pipeline = mock .MagicMock ()
367
- mock_pipeline .watch .return_value = None
368
- mock_pipeline .multi .return_value = mock_pipeline
369
- mock_pipeline .execute .side_effect = WatchError ("Watched variable changed" )
370
-
371
- with mock .patch .object (redis_client , "pipeline" , return_value = mock_pipeline ):
372
- try :
373
- with redis_client .pipeline () as pipe :
374
- pipe .watch ("key" )
375
- pipe .multi ()
376
- pipe .set ("key" , "value" )
377
- pipe .execute ()
378
- except WatchError :
379
- pass # We expect this exception to be raised
380
-
381
- spans = self .memory_exporter .get_finished_spans ()
382
- self .assertEqual (len (spans ), 1 )
383
- span = spans [0 ]
384
-
385
- # Check that the span is not marked as an error
386
- self .assertIsNone (span .status .status_code )
387
-
388
- # Check that the WatchError is recorded as an event, not an exception
389
- events = span .events
390
- self .assertEqual (len (events ), 1 )
391
- self .assertEqual (events [0 ].name , "exception" )
392
- self .assertEqual (events [0 ].attributes ["exception.type" ], "WatchError" )
393
- self .assertIn ("Watched variable changed" , events [0 ].attributes ["exception.message" ])
394
-
395
- # Verify other span properties
396
- self .assertEqual (span .name , "MULTI" )
397
- self .assertEqual (span .kind , SpanKind .CLIENT )
398
- self .assertEqual (span .attributes .get ("db.system" ), "redis" )
399
-
400
-
401
- import pytest
402
- import redis .asyncio
403
- from redis .exceptions import WatchError
404
- from opentelemetry import trace
405
- from opentelemetry .semconv .trace import SpanAttributes
406
- from opentelemetry .instrumentation .redis import RedisInstrumentor
407
- from opentelemetry .sdk .trace import TracerProvider
408
- from opentelemetry .sdk .trace .export import SimpleSpanProcessor
409
- from opentelemetry .sdk .trace .export .in_memory_span_exporter import \
410
- InMemorySpanExporter # This is the correct import for MemorySpanExporter
411
-
412
-
413
- class Test_Redis :
414
- @pytest .fixture (autouse = True )
415
- def setup_and_teardown (self ):
416
- # Setup
417
- self .tracer_provider = TracerProvider ()
418
- self .memory_exporter = InMemorySpanExporter ()
419
- span_processor = SimpleSpanProcessor (self .memory_exporter )
420
- self .tracer_provider .add_span_processor (span_processor )
421
- trace .set_tracer_provider (self .tracer_provider )
422
361
362
+ class TestRedisAsync (TestBase , IsolatedAsyncioTestCase ):
363
+ def setUp (self ):
364
+ super ().setUp ()
423
365
RedisInstrumentor ().instrument (tracer_provider = self .tracer_provider )
424
366
425
- yield
426
-
427
- # Teardown
367
+ def tearDown (self ):
368
+ super ().tearDown ()
428
369
RedisInstrumentor ().uninstrument ()
429
370
430
371
@pytest .mark .asyncio
431
- async def test_watch_error (self ):
432
- r = redis .asyncio .Redis ()
433
- await r .set ("a" , "0" )
434
-
435
- try :
436
- async with r .pipeline (transaction = False ) as pipe :
437
- await pipe .watch ("a" )
438
- a = await pipe .get ("a" )
439
-
440
- # Simulate a change by another client
441
- await r .set ("a" , "bad" )
442
-
443
- pipe .multi ()
444
- await pipe .set ("a" , str (int (a ) + 1 ))
372
+ async def test_redis_operations (self ):
373
+ async def redis_operations ():
374
+ try :
375
+ r = redis_async .Redis ()
376
+ async with r .pipeline (transaction = False ) as pipe :
377
+ await pipe .watch ("a" )
378
+ await r .set ("a" , "bad" )
379
+ pipe .multi ()
380
+ await pipe .set ("a" , "1" )
381
+ await pipe .execute ()
382
+ except WatchError :
383
+ pass
445
384
446
- await pipe .execute ()
447
- except WatchError :
448
- print ("WatchError caught as expected" )
449
- else :
450
- pytest .fail ("WatchError was not raised" )
385
+ await redis_operations ()
451
386
452
387
spans = self .memory_exporter .get_finished_spans ()
453
- assert len (spans ) > 0 , "No spans were recorded"
454
-
455
- # Check the last span for WatchError evidence
456
- last_span = spans [- 1 ]
457
-
458
- # The span itself should not be marked as an error
459
- assert last_span .status .status_code is None
460
-
461
- # Check for WatchError in span events
462
- watch_error_events = [event for event in last_span .events
463
- if event .name == "exception" and
464
- event .attributes .get ("exception.type" ) == "WatchError" ]
465
- assert len (watch_error_events ) > 0 , "WatchError event not found in span"
466
-
467
- # Verify that the value in Redis wasn't changed due to the WatchError
468
- final_value = await r .get ("a" )
469
- assert final_value == b"bad"
470
-
471
- # Clean up
472
- await r .delete ("a" )
388
+ assert spans [- 1 ].status .status_code == trace .StatusCode .UNSET
389
+ assert any (event .name == "WatchError" for event in spans [- 1 ].events )
0 commit comments