Skip to content

Commit acff3f4

Browse files
committed
Instrument RedisCluster clients
1 parent e267ebc commit acff3f4

File tree

4 files changed

+201
-2
lines changed

4 files changed

+201
-2
lines changed

Diff for: instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

+56-1
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,9 @@ def response_hook(span, instance, response):
122122
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
123123
import redis.asyncio
124124

125+
_REDIS_CLUSTER_VERSION = (4, 1, 0)
126+
_REDIS_ASYNCIO_CLUSTER_VERSION = (4, 3, 0)
127+
125128

126129
def _set_connection_attributes(span, conn):
127130
if not span.is_recording():
@@ -149,7 +152,8 @@ def _traced_execute_command(func, instance, args, kwargs):
149152
) as span:
150153
if span.is_recording():
151154
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
152-
_set_connection_attributes(span, instance)
155+
if hasattr(instance, "connection_pool"):
156+
_set_connection_attributes(span, instance)
153157
span.set_attribute("db.redis.args_length", len(args))
154158
if callable(request_hook):
155159
request_hook(span, instance, args, kwargs)
@@ -178,6 +182,27 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
178182
response_hook(span, instance, response)
179183
return response
180184

185+
def _traced_execute_cluster_pipeline(func, instance, args, kwargs):
186+
cmds = [_format_command_args(c.args) for c in (instance.command_stack if hasattr(instance, "command_stack") else instance._command_stack)]
187+
resource = "\n".join(cmds)
188+
189+
span_name = " ".join([c.args[0] for c in (instance.command_stack if hasattr(instance, "command_stack") else instance._command_stack)])
190+
191+
with tracer.start_as_current_span(
192+
span_name, kind=trace.SpanKind.CLIENT
193+
) as span:
194+
if span.is_recording():
195+
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
196+
if hasattr(instance, "connection_pool"):
197+
_set_connection_attributes(span, instance)
198+
span.set_attribute(
199+
"db.redis.pipeline_length", len(instance.command_stack) if hasattr(instance, "command_stack") else len(instance._command_stack)
200+
)
201+
response = func(*args, **kwargs)
202+
if callable(response_hook):
203+
response_hook(span, instance, response)
204+
return response
205+
181206
pipeline_class = (
182207
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
183208
)
@@ -196,6 +221,17 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
196221
f"{pipeline_class}.immediate_execute_command",
197222
_traced_execute_command,
198223
)
224+
if redis.VERSION >= _REDIS_CLUSTER_VERSION:
225+
wrap_function_wrapper(
226+
"redis.cluster",
227+
"RedisCluster.execute_command",
228+
_traced_execute_command,
229+
)
230+
wrap_function_wrapper(
231+
"redis.cluster",
232+
"ClusterPipeline.execute",
233+
_traced_execute_cluster_pipeline,
234+
)
199235
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
200236
wrap_function_wrapper(
201237
"redis.asyncio",
@@ -212,6 +248,17 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
212248
f"{pipeline_class}.immediate_execute_command",
213249
_traced_execute_command,
214250
)
251+
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
252+
wrap_function_wrapper(
253+
"redis.asyncio.cluster",
254+
"RedisCluster.execute_command",
255+
_traced_execute_command,
256+
)
257+
wrap_function_wrapper(
258+
"redis.asyncio.cluster",
259+
"ClusterPipeline.execute",
260+
_traced_execute_cluster_pipeline,
261+
)
215262

216263

217264
class RedisInstrumentor(BaseInstrumentor):
@@ -258,8 +305,16 @@ def _uninstrument(self, **kwargs):
258305
unwrap(redis.Redis, "pipeline")
259306
unwrap(redis.client.Pipeline, "execute")
260307
unwrap(redis.client.Pipeline, "immediate_execute_command")
308+
if redis.VERSION >= _REDIS_CLUSTER_VERSION:
309+
unwrap(redis.cluster.RedisCluster, "execute_command")
310+
unwrap(redis.cluster.RedisCluster, "pipeline")
311+
unwrap(redis.cluster.ClusterPipeline, "execute")
261312
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
262313
unwrap(redis.asyncio.Redis, "execute_command")
263314
unwrap(redis.asyncio.Redis, "pipeline")
264315
unwrap(redis.asyncio.client.Pipeline, "execute")
265316
unwrap(redis.asyncio.client.Pipeline, "immediate_execute_command")
317+
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
318+
unwrap(redis.asyncio.cluster.RedisCluster, "execute_command")
319+
unwrap(redis.asyncio.cluster.RedisCluster, "pipeline")
320+
unwrap(redis.asyncio.cluster.ClusterPipeline, "execute")

Diff for: tests/opentelemetry-docker-tests/tests/docker-compose.yml

+11
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,17 @@ services:
2727
image: redis:4.0-alpine
2828
ports:
2929
- "127.0.0.1:6379:6379"
30+
otrediscluster:
31+
image: grokzen/redis-cluster:6.2.0
32+
environment:
33+
- IP=0.0.0.0
34+
ports:
35+
- "127.0.0.1:7000:7000"
36+
- "127.0.0.1:7001:7001"
37+
- "127.0.0.1:7002:7002"
38+
- "127.0.0.1:7003:7003"
39+
- "127.0.0.1:7004:7004"
40+
- "127.0.0.1:7005:7005"
3041
otjaeger:
3142
image: jaegertracing/all-in-one:1.8
3243
environment:

Diff for: tests/opentelemetry-docker-tests/tests/redis/test_redis_functional.py

+133
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,70 @@ def test_parent(self):
124124
self.assertEqual(child_span.name, "GET")
125125

126126

127+
class TestRedisClusterInstrument(TestBase):
128+
def setUp(self):
129+
super().setUp()
130+
self.redis_client = redis.cluster.RedisCluster(host="localhost", port=7000)
131+
self.redis_client.flushall()
132+
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)
133+
134+
def tearDown(self):
135+
super().tearDown()
136+
RedisInstrumentor().uninstrument()
137+
138+
def _check_span(self, span, name):
139+
self.assertEqual(span.name, name)
140+
self.assertIs(span.status.status_code, trace.StatusCode.UNSET)
141+
142+
def test_basics(self):
143+
self.assertIsNone(self.redis_client.get("cheese"))
144+
spans = self.memory_exporter.get_finished_spans()
145+
self.assertEqual(len(spans), 1)
146+
span = spans[0]
147+
self._check_span(span, "GET")
148+
self.assertEqual(
149+
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
150+
)
151+
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)
152+
153+
def test_pipeline_traced(self):
154+
with self.redis_client.pipeline(transaction=False) as pipeline:
155+
pipeline.set("blah", 32)
156+
pipeline.rpush("foo", "éé")
157+
pipeline.hgetall("xxx")
158+
pipeline.execute()
159+
160+
spans = self.memory_exporter.get_finished_spans()
161+
self.assertEqual(len(spans), 1)
162+
span = spans[0]
163+
self._check_span(span, "SET RPUSH HGETALL")
164+
self.assertEqual(
165+
span.attributes.get(SpanAttributes.DB_STATEMENT),
166+
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
167+
)
168+
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)
169+
170+
def test_parent(self):
171+
"""Ensure OpenTelemetry works with redis."""
172+
ot_tracer = trace.get_tracer("redis_svc")
173+
174+
with ot_tracer.start_as_current_span("redis_get"):
175+
self.assertIsNone(self.redis_client.get("cheese"))
176+
177+
spans = self.memory_exporter.get_finished_spans()
178+
self.assertEqual(len(spans), 2)
179+
child_span, parent_span = spans[0], spans[1]
180+
181+
# confirm the parenting
182+
self.assertIsNone(parent_span.parent)
183+
self.assertIs(child_span.parent, parent_span.get_span_context())
184+
185+
self.assertEqual(parent_span.name, "redis_get")
186+
self.assertEqual(parent_span.instrumentation_info.name, "redis_svc")
187+
188+
self.assertEqual(child_span.name, "GET")
189+
190+
127191
def async_call(coro):
128192
loop = asyncio.get_event_loop()
129193
return loop.run_until_complete(coro)
@@ -238,6 +302,75 @@ def test_parent(self):
238302
self.assertEqual(child_span.name, "GET")
239303

240304

305+
class TestAsyncRedisClusterInstrument(TestBase):
306+
def setUp(self):
307+
super().setUp()
308+
self.redis_client = redis.asyncio.cluster.RedisCluster(host="localhost", port=7000)
309+
async_call(self.redis_client.flushall())
310+
RedisInstrumentor().instrument(tracer_provider=self.tracer_provider)
311+
312+
def tearDown(self):
313+
super().tearDown()
314+
RedisInstrumentor().uninstrument()
315+
316+
def _check_span(self, span, name):
317+
self.assertEqual(span.name, name)
318+
self.assertIs(span.status.status_code, trace.StatusCode.UNSET)
319+
320+
def test_basics(self):
321+
self.assertIsNone(async_call(self.redis_client.get("cheese")))
322+
spans = self.memory_exporter.get_finished_spans()
323+
self.assertEqual(len(spans), 1)
324+
span = spans[0]
325+
self._check_span(span, "GET")
326+
self.assertEqual(
327+
span.attributes.get(SpanAttributes.DB_STATEMENT), "GET cheese"
328+
)
329+
self.assertEqual(span.attributes.get("db.redis.args_length"), 2)
330+
331+
def test_pipeline_traced(self):
332+
async def pipeline_simple():
333+
async with self.redis_client.pipeline(
334+
transaction=False
335+
) as pipeline:
336+
pipeline.set("blah", 32)
337+
pipeline.rpush("foo", "éé")
338+
pipeline.hgetall("xxx")
339+
await pipeline.execute()
340+
341+
async_call(pipeline_simple())
342+
343+
spans = self.memory_exporter.get_finished_spans()
344+
self.assertEqual(len(spans), 1)
345+
span = spans[0]
346+
self._check_span(span, "SET RPUSH HGETALL")
347+
self.assertEqual(
348+
span.attributes.get(SpanAttributes.DB_STATEMENT),
349+
"SET blah 32\nRPUSH foo éé\nHGETALL xxx",
350+
)
351+
self.assertEqual(span.attributes.get("db.redis.pipeline_length"), 3)
352+
353+
def test_parent(self):
354+
"""Ensure OpenTelemetry works with redis."""
355+
ot_tracer = trace.get_tracer("redis_svc")
356+
357+
with ot_tracer.start_as_current_span("redis_get"):
358+
self.assertIsNone(async_call(self.redis_client.get("cheese")))
359+
360+
spans = self.memory_exporter.get_finished_spans()
361+
self.assertEqual(len(spans), 2)
362+
child_span, parent_span = spans[0], spans[1]
363+
364+
# confirm the parenting
365+
self.assertIsNone(parent_span.parent)
366+
self.assertIs(child_span.parent, parent_span.get_span_context())
367+
368+
self.assertEqual(parent_span.name, "redis_get")
369+
self.assertEqual(parent_span.instrumentation_info.name, "redis_svc")
370+
371+
self.assertEqual(child_span.name, "GET")
372+
373+
241374
class TestRedisDBIndexInstrument(TestBase):
242375
def setUp(self):
243376
super().setUp()

Diff for: tox.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ deps =
500500
psycopg2 ~= 2.8.4
501501
aiopg >= 0.13.0, < 1.3.0
502502
sqlalchemy ~= 1.4
503-
redis ~= 4.2
503+
redis ~= 4.3
504504
celery[pytest] >= 4.0, < 6.0
505505
protobuf~=3.13
506506
requests==2.25.0

0 commit comments

Comments
 (0)