Skip to content

Bugfix 2639 dev #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 5, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from unittest import mock
from unittest.mock import AsyncMock

import pytest
import redis
import redis.asyncio
import redis.asyncio as redis_async
from redis.exceptions import WatchError
from unittest import mock, IsolatedAsyncioTestCase
from unittest.mock import AsyncMock

from opentelemetry import trace
from opentelemetry.instrumentation.redis import RedisInstrumentor
Expand All @@ -27,7 +28,7 @@
SpanAttributes,
)
from opentelemetry.test.test_base import TestBase
from opentelemetry.trace import SpanKind, StatusCode
from opentelemetry.trace import SpanKind


class TestRedis(TestBase):
Expand Down Expand Up @@ -144,7 +145,7 @@ def response_hook(span, conn, response):

with mock.patch.object(connection, "send_command"):
with mock.patch.object(
redis_client, "parse_response", return_value=test_value
redis_client, "parse_response", return_value=test_value
):
redis_client.get("key")

Expand Down Expand Up @@ -176,7 +177,7 @@ def request_hook(span, conn, args, kwargs):

with mock.patch.object(connection, "send_command"):
with mock.patch.object(
redis_client, "parse_response", return_value=test_value
redis_client, "parse_response", return_value=test_value
):
redis_client.get("key")

Expand Down Expand Up @@ -313,24 +314,26 @@ def test_attributes_unix_socket(self):
NetTransportValues.OTHER.value,
)

def test_watch_error(self):
def test_successful_transaction(self):
redis_client = redis.Redis()

# Mock the pipeline to raise a WatchError
# Create a mock pipeline
mock_pipeline = mock.MagicMock()
mock_pipeline.__enter__.return_value = mock_pipeline # Ensure __enter__ returns the mock_pipeline
mock_pipeline.watch.return_value = None
mock_pipeline.multi.return_value = mock_pipeline
mock_pipeline.execute.side_effect = WatchError("Watched variable changed")
mock_pipeline.execute.return_value = ["OK"] # This is what we want to return

with mock.patch.object(redis_client, "pipeline", return_value=mock_pipeline):
try:
with redis_client.pipeline() as pipe:
pipe.watch("key")
pipe.multi()
pipe.set("key", "value")
pipe.execute()
except WatchError:
pass # We expect this exception to be raised
with redis_client.pipeline() as pipe:
pipe.watch("key")
pipe.multi()
pipe.set("key", "value")
result = pipe.execute()

# Check that the transaction was successful
print(f"Result: {result}")
self.assertEqual(result, ["OK"])

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
Expand All @@ -339,14 +342,48 @@ def test_watch_error(self):
# Check that the span is not marked as an error
self.assertIsNone(span.status.status_code)

# Check that the WatchError is recorded as an event, not an exception
# Check that there are no exception events
events = span.events
self.assertEqual(len(events), 1)
self.assertEqual(events[0].name, "exception")
self.assertEqual(events[0].attributes["exception.type"], "WatchError")
self.assertIn("Watched variable changed", events[0].attributes["exception.message"])
self.assertEqual(len(events), 0)

# Verify other span properties
self.assertEqual(span.name, "MULTI")
self.assertEqual(span.kind, SpanKind.CLIENT)
self.assertEqual(span.attributes.get("db.system"), "redis")

# Verify that the SET command is recorded in the span
self.assertIn("SET", span.attributes.get("db.statement", ""))

# Optionally, check for any additional attributes specific to your instrumentation
# For example, you might want to verify that the database index is correctly recorded
self.assertEqual(span.attributes.get("db.redis.database_index"), 0)


class TestRedisAsync(TestBase, IsolatedAsyncioTestCase):
def setUp(self):
super().setUp()
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)

def tearDown(self):
super().tearDown()
RedisInstrumentor().uninstrument()

@pytest.mark.asyncio
async def test_redis_operations(self):
async def redis_operations():
try:
r = redis_async.Redis()
async with r.pipeline(transaction=False) as pipe:
await pipe.watch("a")
await r.set("a", "bad")
pipe.multi()
await pipe.set("a", "1")
await pipe.execute()
except WatchError:
pass

await redis_operations()

spans = self.memory_exporter.get_finished_spans()
assert spans[-1].status.status_code == trace.StatusCode.UNSET
assert any(event.name == "WatchError" for event in spans[-1].events)