Skip to content

Commit 7f46e93

Browse files
author
Viktor Ivanov
committed
Fix async redis clients tracing
1 parent a679754 commit 7f46e93

File tree

2 files changed

+138
-34
lines changed

2 files changed

+138
-34
lines changed

instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

+89-34
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,44 @@ def _set_connection_attributes(span, conn):
157157
span.set_attribute(key, value)
158158

159159

160+
def _build_span_name(instance, cmd_args):
161+
if len(cmd_args) > 0 and cmd_args[0]:
162+
name = cmd_args[0]
163+
else:
164+
name = instance.connection_pool.connection_kwargs.get("db", 0)
165+
return name
166+
167+
168+
def _build_span_meta_data_for_pipeline(instance, sanitize_query):
169+
try:
170+
command_stack = (
171+
instance.command_stack
172+
if hasattr(instance, "command_stack")
173+
else instance._command_stack
174+
)
175+
176+
cmds = [
177+
_format_command_args(
178+
c.args if hasattr(c, "args") else c[0], sanitize_query
179+
)
180+
for c in command_stack
181+
]
182+
resource = "\n".join(cmds)
183+
184+
span_name = " ".join(
185+
[
186+
(c.args[0] if hasattr(c, "args") else c[0][0])
187+
for c in command_stack
188+
]
189+
)
190+
except (AttributeError, IndexError):
191+
command_stack = []
192+
resource = ""
193+
span_name = ""
194+
195+
return command_stack, resource, span_name
196+
197+
160198
def _instrument(
161199
tracer,
162200
request_hook: _RequestHookT = None,
@@ -165,11 +203,8 @@ def _instrument(
165203
):
166204
def _traced_execute_command(func, instance, args, kwargs):
167205
query = _format_command_args(args, sanitize_query)
206+
name = _build_span_name(instance, args)
168207

169-
if len(args) > 0 and args[0]:
170-
name = args[0]
171-
else:
172-
name = instance.connection_pool.connection_kwargs.get("db", 0)
173208
with tracer.start_as_current_span(
174209
name, kind=trace.SpanKind.CLIENT
175210
) as span:
@@ -185,31 +220,11 @@ def _traced_execute_command(func, instance, args, kwargs):
185220
return response
186221

187222
def _traced_execute_pipeline(func, instance, args, kwargs):
188-
try:
189-
command_stack = (
190-
instance.command_stack
191-
if hasattr(instance, "command_stack")
192-
else instance._command_stack
193-
)
194-
195-
cmds = [
196-
_format_command_args(
197-
c.args if hasattr(c, "args") else c[0], sanitize_query
198-
)
199-
for c in command_stack
200-
]
201-
resource = "\n".join(cmds)
202-
203-
span_name = " ".join(
204-
[
205-
(c.args[0] if hasattr(c, "args") else c[0][0])
206-
for c in command_stack
207-
]
208-
)
209-
except (AttributeError, IndexError):
210-
command_stack = []
211-
resource = ""
212-
span_name = ""
223+
(
224+
command_stack,
225+
resource,
226+
span_name,
227+
) = _build_span_meta_data_for_pipeline(instance, sanitize_query)
213228

214229
with tracer.start_as_current_span(
215230
span_name, kind=trace.SpanKind.CLIENT
@@ -254,32 +269,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
254269
"ClusterPipeline.execute",
255270
_traced_execute_pipeline,
256271
)
272+
273+
async def _async_traced_execute_command(func, instance, args, kwargs):
274+
query = _format_command_args(args, sanitize_query)
275+
name = _build_span_name(instance, args)
276+
277+
with tracer.start_as_current_span(
278+
name, kind=trace.SpanKind.CLIENT
279+
) as span:
280+
if span.is_recording():
281+
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
282+
_set_connection_attributes(span, instance)
283+
span.set_attribute("db.redis.args_length", len(args))
284+
if callable(request_hook):
285+
request_hook(span, instance, args, kwargs)
286+
response = await func(*args, **kwargs)
287+
if callable(response_hook):
288+
response_hook(span, instance, response)
289+
return response
290+
291+
async def _async_traced_execute_pipeline(func, instance, args, kwargs):
292+
(
293+
command_stack,
294+
resource,
295+
span_name,
296+
) = _build_span_meta_data_for_pipeline(instance, sanitize_query)
297+
298+
with tracer.start_as_current_span(
299+
span_name, kind=trace.SpanKind.CLIENT
300+
) as span:
301+
if span.is_recording():
302+
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
303+
_set_connection_attributes(span, instance)
304+
span.set_attribute(
305+
"db.redis.pipeline_length", len(command_stack)
306+
)
307+
response = await func(*args, **kwargs)
308+
if callable(response_hook):
309+
response_hook(span, instance, response)
310+
return response
311+
257312
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
258313
wrap_function_wrapper(
259314
"redis.asyncio",
260315
f"{redis_class}.execute_command",
261-
_traced_execute_command,
316+
_async_traced_execute_command,
262317
)
263318
wrap_function_wrapper(
264319
"redis.asyncio.client",
265320
f"{pipeline_class}.execute",
266-
_traced_execute_pipeline,
321+
_async_traced_execute_pipeline,
267322
)
268323
wrap_function_wrapper(
269324
"redis.asyncio.client",
270325
f"{pipeline_class}.immediate_execute_command",
271-
_traced_execute_command,
326+
_async_traced_execute_command,
272327
)
273328
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
274329
wrap_function_wrapper(
275330
"redis.asyncio.cluster",
276331
"RedisCluster.execute_command",
277-
_traced_execute_command,
332+
_async_traced_execute_command,
278333
)
279334
wrap_function_wrapper(
280335
"redis.asyncio.cluster",
281336
"ClusterPipeline.execute",
282-
_traced_execute_pipeline,
337+
_async_traced_execute_pipeline,
283338
)
284339

285340

instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py

+49
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,36 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
from unittest import mock
1516

1617
import redis
18+
import redis.asyncio
1719

1820
from opentelemetry import trace
1921
from opentelemetry.instrumentation.redis import RedisInstrumentor
2022
from opentelemetry.test.test_base import TestBase
2123
from opentelemetry.trace import SpanKind
2224

2325

26+
class AsyncMock:
27+
"""A sufficient async mock implementation.
28+
29+
Python 3.7 doesn't have an inbuilt async mock class, so this is used.
30+
"""
31+
32+
def __init__(self):
33+
self.mock = mock.Mock()
34+
35+
async def __call__(self, *args, **kwargs):
36+
f = asyncio.Future()
37+
f.set_result("random")
38+
return f
39+
40+
def __getattr__(self, item):
41+
return AsyncMock()
42+
43+
2444
class TestRedis(TestBase):
2545
def setUp(self):
2646
super().setUp()
@@ -87,6 +107,35 @@ def test_instrument_uninstrument(self):
87107
spans = self.memory_exporter.get_finished_spans()
88108
self.assertEqual(len(spans), 1)
89109

110+
def test_instrument_uninstrument_async_client_command(self):
111+
redis_client = redis.asyncio.Redis()
112+
113+
with mock.patch.object(redis_client, "connection", AsyncMock()):
114+
asyncio.run(redis_client.get("key"))
115+
116+
spans = self.memory_exporter.get_finished_spans()
117+
self.assertEqual(len(spans), 1)
118+
self.memory_exporter.clear()
119+
120+
# Test uninstrument
121+
RedisInstrumentor().uninstrument()
122+
123+
with mock.patch.object(redis_client, "connection", AsyncMock()):
124+
asyncio.run(redis_client.get("key"))
125+
126+
spans = self.memory_exporter.get_finished_spans()
127+
self.assertEqual(len(spans), 0)
128+
self.memory_exporter.clear()
129+
130+
# Test instrument again
131+
RedisInstrumentor().instrument()
132+
133+
with mock.patch.object(redis_client, "connection", AsyncMock()):
134+
asyncio.run(redis_client.get("key"))
135+
136+
spans = self.memory_exporter.get_finished_spans()
137+
self.assertEqual(len(spans), 1)
138+
90139
def test_response_hook(self):
91140
redis_client = redis.Redis()
92141
connection = redis.connection.Connection()

0 commit comments

Comments
 (0)