Skip to content

Commit db636a4

Browse files
adding response_hook to redis instrumentor (#669)
1 parent 291e508 commit db636a4

File tree

3 files changed

+185
-71
lines changed

3 files changed

+185
-71
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- `opentelemetry-instrumentation-elasticsearch` Added `response_hook` and `request_hook` callbacks
1414
([#670](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/670))
1515

16+
### Added
17+
- `opentelemetry-instrumentation-redis` added request_hook and response_hook callbacks passed as arguments to the instrument method.
18+
([#669](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/669))
19+
1620
### Changed
1721
- `opentelemetry-instrumentation-botocore` Unpatch botocore Endpoint.prepare_request on uninstrument
1822
([#664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/664))

instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

+120-71
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,43 @@
3838
client = redis.StrictRedis(host="localhost", port=6379)
3939
client.get("my-key")
4040
41+
The `instrument` method accepts the following keyword args:
42+
43+
tracer_provider (TracerProvider) - an optional tracer provider
44+
45+
request_hook (Callable) - a function with extra user-defined logic to be performed before performing the request
46+
this function signature is: def request_hook(span: Span, instance: redis.connection.Connection, args, kwargs) -> None
47+
48+
response_hook (Callable) - a function with extra user-defined logic to be performed after performing the request
49+
this function signature is: def response_hook(span: Span, instance: redis.connection.Connection, response) -> None
50+
51+
for example:
52+
53+
.. code: python
54+
55+
from opentelemetry.instrumentation.redis import RedisInstrumentor
56+
import redis
57+
58+
def request_hook(span, instance, args, kwargs):
59+
if span and span.is_recording():
60+
span.set_attribute("custom_user_attribute_from_request_hook", "some-value")
61+
62+
def response_hook(span, instance, response):
63+
if span and span.is_recording():
64+
span.set_attribute("custom_user_attribute_from_response_hook", "some-value")
65+
66+
# Instrument redis with hooks
67+
RedisInstrumentor().instrument(request_hook=request_hook, response_hook=response_hook)
68+
69+
# This will report a span with the default settings and the custom attributes added from the hooks
70+
client = redis.StrictRedis(host="localhost", port=6379)
71+
client.get("my-key")
72+
4173
API
4274
---
4375
"""
44-
45-
from typing import Collection
76+
import typing
77+
from typing import Any, Collection
4678

4779
import redis
4880
from wrapt import wrap_function_wrapper
@@ -57,9 +89,19 @@
5789
from opentelemetry.instrumentation.redis.version import __version__
5890
from opentelemetry.instrumentation.utils import unwrap
5991
from opentelemetry.semconv.trace import SpanAttributes
92+
from opentelemetry.trace import Span
6093

6194
_DEFAULT_SERVICE = "redis"
6295

96+
_RequestHookT = typing.Optional[
97+
typing.Callable[
98+
[Span, redis.connection.Connection, typing.List, typing.Dict], None
99+
]
100+
]
101+
_ResponseHookT = typing.Optional[
102+
typing.Callable[[Span, redis.connection.Connection, Any], None]
103+
]
104+
63105

64106
def _set_connection_attributes(span, conn):
65107
if not span.is_recording():
@@ -70,42 +112,68 @@ def _set_connection_attributes(span, conn):
70112
span.set_attribute(key, value)
71113

72114

73-
def _traced_execute_command(func, instance, args, kwargs):
74-
tracer = getattr(redis, "_opentelemetry_tracer")
75-
query = _format_command_args(args)
76-
name = ""
77-
if len(args) > 0 and args[0]:
78-
name = args[0]
79-
else:
80-
name = instance.connection_pool.connection_kwargs.get("db", 0)
81-
with tracer.start_as_current_span(
82-
name, kind=trace.SpanKind.CLIENT
83-
) as span:
84-
if span.is_recording():
85-
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
86-
_set_connection_attributes(span, instance)
87-
span.set_attribute("db.redis.args_length", len(args))
88-
return func(*args, **kwargs)
89-
90-
91-
def _traced_execute_pipeline(func, instance, args, kwargs):
92-
tracer = getattr(redis, "_opentelemetry_tracer")
93-
94-
cmds = [_format_command_args(c) for c, _ in instance.command_stack]
95-
resource = "\n".join(cmds)
96-
97-
span_name = " ".join([args[0] for args, _ in instance.command_stack])
98-
99-
with tracer.start_as_current_span(
100-
span_name, kind=trace.SpanKind.CLIENT
101-
) as span:
102-
if span.is_recording():
103-
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
104-
_set_connection_attributes(span, instance)
105-
span.set_attribute(
106-
"db.redis.pipeline_length", len(instance.command_stack)
107-
)
108-
return func(*args, **kwargs)
115+
def _instrument(
116+
tracer,
117+
request_hook: _RequestHookT = None,
118+
response_hook: _ResponseHookT = None,
119+
):
120+
def _traced_execute_command(func, instance, args, kwargs):
121+
query = _format_command_args(args)
122+
name = ""
123+
if len(args) > 0 and args[0]:
124+
name = args[0]
125+
else:
126+
name = instance.connection_pool.connection_kwargs.get("db", 0)
127+
with tracer.start_as_current_span(
128+
name, kind=trace.SpanKind.CLIENT
129+
) as span:
130+
if span.is_recording():
131+
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
132+
_set_connection_attributes(span, instance)
133+
span.set_attribute("db.redis.args_length", len(args))
134+
if callable(request_hook):
135+
request_hook(span, instance, args, kwargs)
136+
response = func(*args, **kwargs)
137+
if callable(response_hook):
138+
response_hook(span, instance, response)
139+
return response
140+
141+
def _traced_execute_pipeline(func, instance, args, kwargs):
142+
cmds = [_format_command_args(c) for c, _ in instance.command_stack]
143+
resource = "\n".join(cmds)
144+
145+
span_name = " ".join([args[0] for args, _ in instance.command_stack])
146+
147+
with tracer.start_as_current_span(
148+
span_name, kind=trace.SpanKind.CLIENT
149+
) as span:
150+
if span.is_recording():
151+
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
152+
_set_connection_attributes(span, instance)
153+
span.set_attribute(
154+
"db.redis.pipeline_length", len(instance.command_stack)
155+
)
156+
response = func(*args, **kwargs)
157+
if callable(response_hook):
158+
response_hook(span, instance, response)
159+
return response
160+
161+
pipeline_class = (
162+
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
163+
)
164+
redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"
165+
166+
wrap_function_wrapper(
167+
"redis", f"{redis_class}.execute_command", _traced_execute_command
168+
)
169+
wrap_function_wrapper(
170+
"redis.client", f"{pipeline_class}.execute", _traced_execute_pipeline,
171+
)
172+
wrap_function_wrapper(
173+
"redis.client",
174+
f"{pipeline_class}.immediate_execute_command",
175+
_traced_execute_command,
176+
)
109177

110178

111179
class RedisInstrumentor(BaseInstrumentor):
@@ -117,41 +185,22 @@ def instrumentation_dependencies(self) -> Collection[str]:
117185
return _instruments
118186

119187
def _instrument(self, **kwargs):
188+
"""Instruments the redis module
189+
190+
Args:
191+
**kwargs: Optional arguments
192+
``tracer_provider``: a TracerProvider, defaults to global.
193+
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
194+
"""
120195
tracer_provider = kwargs.get("tracer_provider")
121-
setattr(
122-
redis,
123-
"_opentelemetry_tracer",
124-
trace.get_tracer(
125-
__name__, __version__, tracer_provider=tracer_provider,
126-
),
196+
tracer = trace.get_tracer(
197+
__name__, __version__, tracer_provider=tracer_provider
198+
)
199+
_instrument(
200+
tracer,
201+
request_hook=kwargs.get("request_hook"),
202+
response_hook=kwargs.get("response_hook"),
127203
)
128-
129-
if redis.VERSION < (3, 0, 0):
130-
wrap_function_wrapper(
131-
"redis", "StrictRedis.execute_command", _traced_execute_command
132-
)
133-
wrap_function_wrapper(
134-
"redis.client",
135-
"BasePipeline.execute",
136-
_traced_execute_pipeline,
137-
)
138-
wrap_function_wrapper(
139-
"redis.client",
140-
"BasePipeline.immediate_execute_command",
141-
_traced_execute_command,
142-
)
143-
else:
144-
wrap_function_wrapper(
145-
"redis", "Redis.execute_command", _traced_execute_command
146-
)
147-
wrap_function_wrapper(
148-
"redis.client", "Pipeline.execute", _traced_execute_pipeline
149-
)
150-
wrap_function_wrapper(
151-
"redis.client",
152-
"Pipeline.immediate_execute_command",
153-
_traced_execute_command,
154-
)
155204

156205
def _uninstrument(self, **kwargs):
157206
if redis.VERSION < (3, 0, 0):

instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py

+61
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,64 @@ def test_instrument_uninstrument(self):
8080

8181
spans = self.memory_exporter.get_finished_spans()
8282
self.assertEqual(len(spans), 1)
83+
84+
def test_response_hook(self):
85+
redis_client = redis.Redis()
86+
connection = redis.connection.Connection()
87+
redis_client.connection = connection
88+
89+
response_attribute_name = "db.redis.response"
90+
91+
def response_hook(span, conn, response):
92+
span.set_attribute(response_attribute_name, response)
93+
94+
RedisInstrumentor().uninstrument()
95+
RedisInstrumentor().instrument(
96+
tracer_provider=self.tracer_provider, response_hook=response_hook
97+
)
98+
99+
test_value = "test_value"
100+
101+
with mock.patch.object(connection, "send_command"):
102+
with mock.patch.object(
103+
redis_client, "parse_response", return_value=test_value
104+
):
105+
redis_client.get("key")
106+
107+
spans = self.memory_exporter.get_finished_spans()
108+
self.assertEqual(len(spans), 1)
109+
110+
span = spans[0]
111+
self.assertEqual(
112+
span.attributes.get(response_attribute_name), test_value
113+
)
114+
115+
def test_request_hook(self):
116+
redis_client = redis.Redis()
117+
connection = redis.connection.Connection()
118+
redis_client.connection = connection
119+
120+
custom_attribute_name = "my.request.attribute"
121+
122+
def request_hook(span, conn, args, kwargs):
123+
if span and span.is_recording():
124+
span.set_attribute(custom_attribute_name, args[0])
125+
126+
RedisInstrumentor().uninstrument()
127+
RedisInstrumentor().instrument(
128+
tracer_provider=self.tracer_provider, request_hook=request_hook
129+
)
130+
131+
test_value = "test_value"
132+
133+
with mock.patch.object(connection, "send_command"):
134+
with mock.patch.object(
135+
redis_client, "parse_response", return_value=test_value
136+
):
137+
redis_client.get("key")
138+
139+
spans = self.memory_exporter.get_finished_spans()
140+
self.assertEqual(len(spans), 1)
141+
142+
span = spans[0]
143+
self.assertEqual(span.attributes.get(custom_attribute_name), "GET")

0 commit comments

Comments
 (0)