41
41
API
42
42
---
43
43
"""
44
-
45
- from typing import Collection
44
+ import typing
45
+ from typing import Any , Collection
46
46
47
47
import redis
48
48
from wrapt import wrap_function_wrapper
57
57
from opentelemetry .instrumentation .redis .version import __version__
58
58
from opentelemetry .instrumentation .utils import unwrap
59
59
from opentelemetry .semconv .trace import SpanAttributes
60
+ from opentelemetry .trace import Span
60
61
61
62
_DEFAULT_SERVICE = "redis"
62
63
64
+ _ResponseHookT = typing .Optional [
65
+ typing .Callable [[Span , redis .connection .Connection , Any ], None ]
66
+ ]
67
+
63
68
64
69
def _set_connection_attributes (span , conn ):
65
70
if not span .is_recording ():
@@ -70,42 +75,64 @@ def _set_connection_attributes(span, conn):
70
75
span .set_attribute (key , value )
71
76
72
77
73
- def _traced_execute_command (func , instance , args , kwargs ):
74
- tracer = getattr (redis , "_opentelemetry_tracer" )
75
- query = _format_command_args (args )
76
- name = ""
77
- if len (args ) > 0 and args [0 ]:
78
- name = args [0 ]
79
- else :
80
- name = instance .connection_pool .connection_kwargs .get ("db" , 0 )
81
- with tracer .start_as_current_span (
82
- name , kind = trace .SpanKind .CLIENT
83
- ) as span :
84
- if span .is_recording ():
85
- span .set_attribute (SpanAttributes .DB_STATEMENT , query )
86
- _set_connection_attributes (span , instance )
87
- span .set_attribute ("db.redis.args_length" , len (args ))
88
- return func (* args , ** kwargs )
89
-
90
-
91
- def _traced_execute_pipeline (func , instance , args , kwargs ):
92
- tracer = getattr (redis , "_opentelemetry_tracer" )
93
-
94
- cmds = [_format_command_args (c ) for c , _ in instance .command_stack ]
95
- resource = "\n " .join (cmds )
96
-
97
- span_name = " " .join ([args [0 ] for args , _ in instance .command_stack ])
98
-
99
- with tracer .start_as_current_span (
100
- span_name , kind = trace .SpanKind .CLIENT
101
- ) as span :
102
- if span .is_recording ():
103
- span .set_attribute (SpanAttributes .DB_STATEMENT , resource )
104
- _set_connection_attributes (span , instance )
105
- span .set_attribute (
106
- "db.redis.pipeline_length" , len (instance .command_stack )
107
- )
108
- return func (* args , ** kwargs )
78
+ def _instrument (
79
+ tracer , response_hook : _ResponseHookT = None ,
80
+ ):
81
+ def _traced_execute_command (func , instance , args , kwargs ):
82
+ query = _format_command_args (args )
83
+ name = ""
84
+ if len (args ) > 0 and args [0 ]:
85
+ name = args [0 ]
86
+ else :
87
+ name = instance .connection_pool .connection_kwargs .get ("db" , 0 )
88
+ with tracer .start_as_current_span (
89
+ name , kind = trace .SpanKind .CLIENT
90
+ ) as span :
91
+ if span .is_recording ():
92
+ span .set_attribute (SpanAttributes .DB_STATEMENT , query )
93
+ _set_connection_attributes (span , instance )
94
+ span .set_attribute ("db.redis.args_length" , len (args ))
95
+ response = func (* args , ** kwargs )
96
+ if callable (response_hook ):
97
+ response_hook (span , instance , response )
98
+ return response
99
+
100
+ def _traced_execute_pipeline (func , instance , args , kwargs ):
101
+ cmds = [_format_command_args (c ) for c , _ in instance .command_stack ]
102
+ resource = "\n " .join (cmds )
103
+
104
+ span_name = " " .join ([args [0 ] for args , _ in instance .command_stack ])
105
+
106
+ with tracer .start_as_current_span (
107
+ span_name , kind = trace .SpanKind .CLIENT
108
+ ) as span :
109
+ if span .is_recording ():
110
+ span .set_attribute (SpanAttributes .DB_STATEMENT , resource )
111
+ _set_connection_attributes (span , instance )
112
+ span .set_attribute (
113
+ "db.redis.pipeline_length" , len (instance .command_stack )
114
+ )
115
+ response = func (* args , ** kwargs )
116
+ if callable (response_hook ):
117
+ response_hook (span , instance , response )
118
+ return response
119
+
120
+ pipeline_class = (
121
+ "BasePipeline" if redis .VERSION < (3 , 0 , 0 ) else "Pipeline"
122
+ )
123
+ redis_class = "StrictRedis" if redis .VERSION < (3 , 0 , 0 ) else "Redis"
124
+
125
+ wrap_function_wrapper (
126
+ "redis" , f"{ redis_class } .execute_command" , _traced_execute_command
127
+ )
128
+ wrap_function_wrapper (
129
+ "redis.client" , f"{ pipeline_class } .execute" , _traced_execute_pipeline ,
130
+ )
131
+ wrap_function_wrapper (
132
+ "redis.client" ,
133
+ f"{ pipeline_class } .immediate_execute_command" ,
134
+ _traced_execute_command ,
135
+ )
109
136
110
137
111
138
class RedisInstrumentor (BaseInstrumentor ):
@@ -117,41 +144,18 @@ def instrumentation_dependencies(self) -> Collection[str]:
117
144
return _instruments
118
145
119
146
def _instrument (self , ** kwargs ):
147
+ """Instruments the redis module
148
+
149
+ Args:
150
+ **kwargs: Optional arguments
151
+ ``tracer_provider``: a TracerProvider, defaults to global.
152
+ ``response_hook``: An optional callback which is invoked right before the span is finished processing a response.
153
+ """
120
154
tracer_provider = kwargs .get ("tracer_provider" )
121
- setattr (
122
- redis ,
123
- "_opentelemetry_tracer" ,
124
- trace .get_tracer (
125
- __name__ , __version__ , tracer_provider = tracer_provider ,
126
- ),
155
+ tracer = trace .get_tracer (
156
+ __name__ , __version__ , tracer_provider = tracer_provider
127
157
)
128
-
129
- if redis .VERSION < (3 , 0 , 0 ):
130
- wrap_function_wrapper (
131
- "redis" , "StrictRedis.execute_command" , _traced_execute_command
132
- )
133
- wrap_function_wrapper (
134
- "redis.client" ,
135
- "BasePipeline.execute" ,
136
- _traced_execute_pipeline ,
137
- )
138
- wrap_function_wrapper (
139
- "redis.client" ,
140
- "BasePipeline.immediate_execute_command" ,
141
- _traced_execute_command ,
142
- )
143
- else :
144
- wrap_function_wrapper (
145
- "redis" , "Redis.execute_command" , _traced_execute_command
146
- )
147
- wrap_function_wrapper (
148
- "redis.client" , "Pipeline.execute" , _traced_execute_pipeline
149
- )
150
- wrap_function_wrapper (
151
- "redis.client" ,
152
- "Pipeline.immediate_execute_command" ,
153
- _traced_execute_command ,
154
- )
158
+ _instrument (tracer , response_hook = kwargs .get ("response_hook" ))
155
159
156
160
def _uninstrument (self , ** kwargs ):
157
161
if redis .VERSION < (3 , 0 , 0 ):
0 commit comments