Skip to content

Commit 7fc376e

Browse files
adding response_hook to redis instrumentor
1 parent 97e9f2f commit 7fc376e

File tree

3 files changed

+110
-71
lines changed

3 files changed

+110
-71
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
99

10+
### Added
11+
- `opentelemetry-instrumentation-redis` added response_hook callback passed as an argument to the instrument method.
12+
([#669](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/669))
13+
1014
### Changed
1115
- `opentelemetry-instrumentation-botocore` Unpatch botocore Endpoint.prepare_request on uninstrument
1216
([#664](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/664))

instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py

+75-71
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
API
4242
---
4343
"""
44-
45-
from typing import Collection
44+
import typing
45+
from typing import Any, Collection
4646

4747
import redis
4848
from wrapt import wrap_function_wrapper
@@ -57,9 +57,14 @@
5757
from opentelemetry.instrumentation.redis.version import __version__
5858
from opentelemetry.instrumentation.utils import unwrap
5959
from opentelemetry.semconv.trace import SpanAttributes
60+
from opentelemetry.trace import Span
6061

6162
_DEFAULT_SERVICE = "redis"
6263

64+
_ResponseHookT = typing.Optional[
65+
typing.Callable[[Span, redis.connection.Connection, Any], None]
66+
]
67+
6368

6469
def _set_connection_attributes(span, conn):
6570
if not span.is_recording():
@@ -70,42 +75,64 @@ def _set_connection_attributes(span, conn):
7075
span.set_attribute(key, value)
7176

7277

73-
def _traced_execute_command(func, instance, args, kwargs):
74-
tracer = getattr(redis, "_opentelemetry_tracer")
75-
query = _format_command_args(args)
76-
name = ""
77-
if len(args) > 0 and args[0]:
78-
name = args[0]
79-
else:
80-
name = instance.connection_pool.connection_kwargs.get("db", 0)
81-
with tracer.start_as_current_span(
82-
name, kind=trace.SpanKind.CLIENT
83-
) as span:
84-
if span.is_recording():
85-
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
86-
_set_connection_attributes(span, instance)
87-
span.set_attribute("db.redis.args_length", len(args))
88-
return func(*args, **kwargs)
89-
90-
91-
def _traced_execute_pipeline(func, instance, args, kwargs):
92-
tracer = getattr(redis, "_opentelemetry_tracer")
93-
94-
cmds = [_format_command_args(c) for c, _ in instance.command_stack]
95-
resource = "\n".join(cmds)
96-
97-
span_name = " ".join([args[0] for args, _ in instance.command_stack])
98-
99-
with tracer.start_as_current_span(
100-
span_name, kind=trace.SpanKind.CLIENT
101-
) as span:
102-
if span.is_recording():
103-
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
104-
_set_connection_attributes(span, instance)
105-
span.set_attribute(
106-
"db.redis.pipeline_length", len(instance.command_stack)
107-
)
108-
return func(*args, **kwargs)
78+
def _instrument(
79+
tracer, response_hook: _ResponseHookT = None,
80+
):
81+
def _traced_execute_command(func, instance, args, kwargs):
82+
query = _format_command_args(args)
83+
name = ""
84+
if len(args) > 0 and args[0]:
85+
name = args[0]
86+
else:
87+
name = instance.connection_pool.connection_kwargs.get("db", 0)
88+
with tracer.start_as_current_span(
89+
name, kind=trace.SpanKind.CLIENT
90+
) as span:
91+
if span.is_recording():
92+
span.set_attribute(SpanAttributes.DB_STATEMENT, query)
93+
_set_connection_attributes(span, instance)
94+
span.set_attribute("db.redis.args_length", len(args))
95+
response = func(*args, **kwargs)
96+
if callable(response_hook):
97+
response_hook(span, instance, response)
98+
return response
99+
100+
def _traced_execute_pipeline(func, instance, args, kwargs):
101+
cmds = [_format_command_args(c) for c, _ in instance.command_stack]
102+
resource = "\n".join(cmds)
103+
104+
span_name = " ".join([args[0] for args, _ in instance.command_stack])
105+
106+
with tracer.start_as_current_span(
107+
span_name, kind=trace.SpanKind.CLIENT
108+
) as span:
109+
if span.is_recording():
110+
span.set_attribute(SpanAttributes.DB_STATEMENT, resource)
111+
_set_connection_attributes(span, instance)
112+
span.set_attribute(
113+
"db.redis.pipeline_length", len(instance.command_stack)
114+
)
115+
response = func(*args, **kwargs)
116+
if callable(response_hook):
117+
response_hook(span, instance, response)
118+
return response
119+
120+
pipeline_class = (
121+
"BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline"
122+
)
123+
redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis"
124+
125+
wrap_function_wrapper(
126+
"redis", f"{redis_class}.execute_command", _traced_execute_command
127+
)
128+
wrap_function_wrapper(
129+
"redis.client", f"{pipeline_class}.execute", _traced_execute_pipeline,
130+
)
131+
wrap_function_wrapper(
132+
"redis.client",
133+
f"{pipeline_class}.immediate_execute_command",
134+
_traced_execute_command,
135+
)
109136

110137

111138
class RedisInstrumentor(BaseInstrumentor):
@@ -117,41 +144,18 @@ def instrumentation_dependencies(self) -> Collection[str]:
117144
return _instruments
118145

119146
def _instrument(self, **kwargs):
147+
"""Instruments the redis module
148+
149+
Args:
150+
**kwargs: Optional arguments
151+
``tracer_provider``: a TracerProvider, defaults to global.
152+
``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
153+
"""
120154
tracer_provider = kwargs.get("tracer_provider")
121-
setattr(
122-
redis,
123-
"_opentelemetry_tracer",
124-
trace.get_tracer(
125-
__name__, __version__, tracer_provider=tracer_provider,
126-
),
155+
tracer = trace.get_tracer(
156+
__name__, __version__, tracer_provider=tracer_provider
127157
)
128-
129-
if redis.VERSION < (3, 0, 0):
130-
wrap_function_wrapper(
131-
"redis", "StrictRedis.execute_command", _traced_execute_command
132-
)
133-
wrap_function_wrapper(
134-
"redis.client",
135-
"BasePipeline.execute",
136-
_traced_execute_pipeline,
137-
)
138-
wrap_function_wrapper(
139-
"redis.client",
140-
"BasePipeline.immediate_execute_command",
141-
_traced_execute_command,
142-
)
143-
else:
144-
wrap_function_wrapper(
145-
"redis", "Redis.execute_command", _traced_execute_command
146-
)
147-
wrap_function_wrapper(
148-
"redis.client", "Pipeline.execute", _traced_execute_pipeline
149-
)
150-
wrap_function_wrapper(
151-
"redis.client",
152-
"Pipeline.immediate_execute_command",
153-
_traced_execute_command,
154-
)
158+
_instrument(tracer, response_hook=kwargs.get("response_hook"))
155159

156160
def _uninstrument(self, **kwargs):
157161
if redis.VERSION < (3, 0, 0):

instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py

+31
Original file line numberDiff line numberDiff line change
@@ -80,3 +80,34 @@ def test_instrument_uninstrument(self):
8080

8181
spans = self.memory_exporter.get_finished_spans()
8282
self.assertEqual(len(spans), 1)
83+
84+
def test_response_hook(self):
85+
redis_client = redis.Redis()
86+
connection = redis.connection.Connection()
87+
redis_client.connection = connection
88+
89+
response_attribute_name = "db.redis.response"
90+
91+
def response_hook(span, conn, response):
92+
span.set_attribute(response_attribute_name, response)
93+
94+
RedisInstrumentor().uninstrument()
95+
RedisInstrumentor().instrument(
96+
tracer_provider=self.tracer_provider, response_hook=response_hook
97+
)
98+
99+
test_value = "test_value"
100+
101+
with mock.patch.object(connection, "send_command"):
102+
with mock.patch.object(
103+
redis_client, "parse_response", return_value=test_value
104+
):
105+
redis_client.get("key")
106+
107+
spans = self.memory_exporter.get_finished_spans()
108+
self.assertEqual(len(spans), 1)
109+
110+
span = spans[0]
111+
self.assertEqual(
112+
span.attributes.get(response_attribute_name), test_value
113+
)

0 commit comments

Comments
 (0)