diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 960f2492da..83eb666726 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio -from unittest import mock -from unittest.mock import AsyncMock +import pytest import redis -import redis.asyncio +import redis.asyncio as redis_async from redis.exceptions import WatchError +from unittest import mock, IsolatedAsyncioTestCase +from unittest.mock import AsyncMock from opentelemetry import trace from opentelemetry.instrumentation.redis import RedisInstrumentor @@ -27,7 +28,7 @@ SpanAttributes, ) from opentelemetry.test.test_base import TestBase -from opentelemetry.trace import SpanKind, StatusCode +from opentelemetry.trace import SpanKind class TestRedis(TestBase): @@ -144,7 +145,7 @@ def response_hook(span, conn, response): with mock.patch.object(connection, "send_command"): with mock.patch.object( - redis_client, "parse_response", return_value=test_value + redis_client, "parse_response", return_value=test_value ): redis_client.get("key") @@ -176,7 +177,7 @@ def request_hook(span, conn, args, kwargs): with mock.patch.object(connection, "send_command"): with mock.patch.object( - redis_client, "parse_response", return_value=test_value + redis_client, "parse_response", return_value=test_value ): redis_client.get("key") @@ -313,24 +314,26 @@ def test_attributes_unix_socket(self): NetTransportValues.OTHER.value, ) - def test_watch_error(self): + def test_successful_transaction(self): redis_client = redis.Redis() - # Mock the pipeline to raise a WatchError + # Create a mock pipeline mock_pipeline = mock.MagicMock() + mock_pipeline.__enter__.return_value = mock_pipeline # Ensure __enter__ returns the mock_pipeline mock_pipeline.watch.return_value = None mock_pipeline.multi.return_value = mock_pipeline - mock_pipeline.execute.side_effect = WatchError("Watched variable changed") + mock_pipeline.execute.return_value = ["OK"] # This is what we want to return with mock.patch.object(redis_client, "pipeline", return_value=mock_pipeline): - try: - with redis_client.pipeline() as pipe: - pipe.watch("key") - pipe.multi() - pipe.set("key", "value") - pipe.execute() - except WatchError: - pass # We expect this exception to be raised + with redis_client.pipeline() as pipe: + pipe.watch("key") + pipe.multi() + pipe.set("key", "value") + result = pipe.execute() + + # Check that the transaction was successful + print(f"Result: {result}") + self.assertEqual(result, ["OK"]) spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 1) @@ -339,14 +342,48 @@ def test_watch_error(self): # Check that the span is not marked as an error self.assertIsNone(span.status.status_code) - # Check that the WatchError is recorded as an event, not an exception + # Check that there are no exception events events = span.events - self.assertEqual(len(events), 1) - self.assertEqual(events[0].name, "exception") - self.assertEqual(events[0].attributes["exception.type"], "WatchError") - self.assertIn("Watched variable changed", events[0].attributes["exception.message"]) + self.assertEqual(len(events), 0) # Verify other span properties self.assertEqual(span.name, "MULTI") self.assertEqual(span.kind, SpanKind.CLIENT) self.assertEqual(span.attributes.get("db.system"), "redis") + + # Verify that the SET command is recorded in the span + self.assertIn("SET", span.attributes.get("db.statement", "")) + + # Optionally, check for any additional attributes specific to your instrumentation + # For example, you might want to verify that the database index is correctly recorded + self.assertEqual(span.attributes.get("db.redis.database_index"), 0) + + +class TestRedisAsync(TestBase, IsolatedAsyncioTestCase): + def setUp(self): + super().setUp() + RedisInstrumentor().instrument(tracer_provider=self.tracer_provider) + + def tearDown(self): + super().tearDown() + RedisInstrumentor().uninstrument() + + @pytest.mark.asyncio + async def test_redis_operations(self): + async def redis_operations(): + try: + r = redis_async.Redis() + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + await r.set("a", "bad") + pipe.multi() + await pipe.set("a", "1") + await pipe.execute() + except WatchError: + pass + + await redis_operations() + + spans = self.memory_exporter.get_finished_spans() + assert spans[-1].status.status_code == trace.StatusCode.UNSET + assert any(event.name == "WatchError" for event in spans[-1].events)