Skip to content

Commit cd6b024

Browse files
Vivanov98shalevr
andauthored
Fix async redis clients tracing (#1830)
* Fix async redis clients tracing * Update changelog * Add functional integration tests and fix linting issues --------- Co-authored-by: Shalev Roda <[email protected]>
1 parent e70437a commit cd6b024

File tree

4 files changed

+270
-34
lines changed

4 files changed

+270
-34
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818

1919
### Added
2020

21+
- Fix async redis clients not being traced correctly ([#1830](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1830))
2122
- Make Flask request span attributes available for `start_span`.
2223
([#1784](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1784))
2324
- Fix falcon instrumentation's usage of Span Status to only set the description if the status code is ERROR.

instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

+88-34
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,52 @@ def _set_connection_attributes(span, conn):
136136
span.set_attribute(key, value)
137137

138138

139+
def _build_span_name(instance, cmd_args):
140+
if len(cmd_args) > 0 and cmd_args[0]:
141+
name = cmd_args[0]
142+
else:
143+
name = instance.connection_pool.connection_kwargs.get("db", 0)
144+
return name
145+
146+
147+
def _build_span_meta_data_for_pipeline(instance):
148+
try:
149+
command_stack = (
150+
instance.command_stack
151+
if hasattr(instance, "command_stack")
152+
else instance._command_stack
153+
)
154+
155+
cmds = [
156+
_format_command_args(c.args if hasattr(c, "args") else c[0])
157+
for c in command_stack
158+
]
159+
resource = "\n".join(cmds)
160+
161+
span_name = " ".join(
162+
[
163+
(c.args[0] if hasattr(c, "args") else c[0][0])
164+
for c in command_stack
165+
]
166+
)
167+
except (AttributeError, IndexError):
168+
command_stack = []
169+
resource = ""
170+
span_name = ""
171+
172+
return command_stack, resource, span_name
173+
174+
175+
# pylint: disable=R0915
139176
def _instrument(
140177
tracer,
141178
request_hook: _RequestHookT = None,
142179
response_hook: _ResponseHookT = None,
143180
):
144181
def _traced_execute_command(func, instance, args, kwargs):
145182
query = _format_command_args(args)
183+
name = _build_span_name(instance, args)
146184

147-
if len(args) > 0 and args[0]:
148-
name = args[0]
149-
else:
150-
name = instance.connection_pool.connection_kwargs.get("db", 0)
151185
with tracer.start_as_current_span(
152186
name, kind=trace.SpanKind.CLIENT
153187
) as span:
@@ -163,31 +197,11 @@ def _traced_execute_command(func, instance, args, kwargs):
163197
return response
164198

165199
def _traced_execute_pipeline(func, instance, args, kwargs):
166-
try:
167-
command_stack = (
168-
instance.command_stack
169-
if hasattr(instance, "command_stack")
170-
else instance._command_stack
171-
)
172-
173-
cmds = [
174-
_format_command_args(
175-
c.args if hasattr(c, "args") else c[0],
176-
)
177-
for c in command_stack
178-
]
179-
resource = "\n".join(cmds)
180-
181-
span_name = " ".join(
182-
[
183-
(c.args[0] if hasattr(c, "args") else c[0][0])
184-
for c in command_stack
185-
]
186-
)
187-
except (AttributeError, IndexError):
188-
command_stack = []
189-
resource = ""
190-
span_name = ""
200+
(
201+
command_stack,
202+
resource,
203+
span_name,
204+
) = _build_span_meta_data_for_pipeline(instance)
191205

192206
with tracer.start_as_current_span(
193207
span_name, kind=trace.SpanKind.CLIENT
@@ -232,32 +246,72 @@ def _traced_execute_pipeline(func, instance, args, kwargs):
232246
"ClusterPipeline.execute",
233247
_traced_execute_pipeline,
234248
)
249+
250+
async def _async_traced_execute_command(func, instance, args, kwargs):
251+
query = _format_command_args(args)
252+
name = _build_span_name(instance, args)
253+
254+
with tracer.start_as_current_span(
255+
name, kind=trace.SpanKind.CLIENT
256+
) as span:
257+
if span.is_recording():
258+
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
259+
_set_connection_attributes(span, instance)
260+
span.set_attribute("db.redis.args_length", len(args))
261+
if callable(request_hook):
262+
request_hook(span, instance, args, kwargs)
263+
response = await func(*args, **kwargs)
264+
if callable(response_hook):
265+
response_hook(span, instance, response)
266+
return response
267+
268+
async def _async_traced_execute_pipeline(func, instance, args, kwargs):
269+
(
270+
command_stack,
271+
resource,
272+
span_name,
273+
) = _build_span_meta_data_for_pipeline(instance)
274+
275+
with tracer.start_as_current_span(
276+
span_name, kind=trace.SpanKind.CLIENT
277+
) as span:
278+
if span.is_recording():
279+
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
280+
_set_connection_attributes(span, instance)
281+
span.set_attribute(
282+
"db.redis.pipeline_length", len(command_stack)
283+
)
284+
response = await func(*args, **kwargs)
285+
if callable(response_hook):
286+
response_hook(span, instance, response)
287+
return response
288+
235289
if redis.VERSION >= _REDIS_ASYNCIO_VERSION:
236290
wrap_function_wrapper(
237291
"redis.asyncio",
238292
f"{redis_class}.execute_command",
239-
_traced_execute_command,
293+
_async_traced_execute_command,
240294
)
241295
wrap_function_wrapper(
242296
"redis.asyncio.client",
243297
f"{pipeline_class}.execute",
244-
_traced_execute_pipeline,
298+
_async_traced_execute_pipeline,
245299
)
246300
wrap_function_wrapper(
247301
"redis.asyncio.client",
248302
f"{pipeline_class}.immediate_execute_command",
249-
_traced_execute_command,
303+
_async_traced_execute_command,
250304
)
251305
if redis.VERSION >= _REDIS_ASYNCIO_CLUSTER_VERSION:
252306
wrap_function_wrapper(
253307
"redis.asyncio.cluster",
254308
"RedisCluster.execute_command",
255-
_traced_execute_command,
309+
_async_traced_execute_command,
256310
)
257311
wrap_function_wrapper(
258312
"redis.asyncio.cluster",
259313
"ClusterPipeline.execute",
260-
_traced_execute_pipeline,
314+
_async_traced_execute_pipeline,
261315
)
262316

263317

instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py

+49
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,36 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
from unittest import mock
1516

1617
import redis
18+
import redis.asyncio
1719

1820
from opentelemetry import trace
1921
from opentelemetry.instrumentation.redis import RedisInstrumentor
2022
from opentelemetry.test.test_base import TestBase
2123
from opentelemetry.trace import SpanKind
2224

2325

26+
class AsyncMock:
27+
"""A sufficient async mock implementation.
28+
29+
Python 3.7 doesn't have an inbuilt async mock class, so this is used.
30+
"""
31+
32+
def __init__(self):
33+
self.mock = mock.Mock()
34+
35+
async def __call__(self, *args, **kwargs):
36+
future = asyncio.Future()
37+
future.set_result("random")
38+
return future
39+
40+
def __getattr__(self, item):
41+
return AsyncMock()
42+
43+
2444
class TestRedis(TestBase):
2545
def setUp(self):
2646
super().setUp()
@@ -87,6 +107,35 @@ def test_instrument_uninstrument(self):
87107
spans = self.memory_exporter.get_finished_spans()
88108
self.assertEqual(len(spans), 1)
89109

110+
def test_instrument_uninstrument_async_client_command(self):
111+
redis_client = redis.asyncio.Redis()
112+
113+
with mock.patch.object(redis_client, "connection", AsyncMock()):
114+
asyncio.run(redis_client.get("key"))
115+
116+
spans = self.memory_exporter.get_finished_spans()
117+
self.assertEqual(len(spans), 1)
118+
self.memory_exporter.clear()
119+
120+
# Test uninstrument
121+
RedisInstrumentor().uninstrument()
122+
123+
with mock.patch.object(redis_client, "connection", AsyncMock()):
124+
asyncio.run(redis_client.get("key"))
125+
126+
spans = self.memory_exporter.get_finished_spans()
127+
self.assertEqual(len(spans), 0)
128+
self.memory_exporter.clear()
129+
130+
# Test instrument again
131+
RedisInstrumentor().instrument()
132+
133+
with mock.patch.object(redis_client, "connection", AsyncMock()):
134+
asyncio.run(redis_client.get("key"))
135+
136+
spans = self.memory_exporter.get_finished_spans()
137+
self.assertEqual(len(spans), 1)
138+
90139
def test_response_hook(self):
91140
redis_client = redis.Redis()
92141
connection = redis.connection.Connection()

0 commit comments

Comments
 (0)