12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ from threading import local
16
+ from weakref import WeakKeyDictionary
17
+
15
18
from sqlalchemy .event import listen # pylint: disable=no-name-in-module
16
19
17
20
from opentelemetry import trace
@@ -66,12 +69,21 @@ def __init__(self, tracer, engine):
66
69
self .tracer = tracer
67
70
self .engine = engine
68
71
self .vendor = _normalize_vendor (engine .name )
69
- self .current_span = None
72
+ self .cursor_mapping = WeakKeyDictionary ()
73
+ self .local = local ()
70
74
71
75
listen (engine , "before_cursor_execute" , self ._before_cur_exec )
72
76
listen (engine , "after_cursor_execute" , self ._after_cur_exec )
73
77
listen (engine , "handle_error" , self ._handle_error )
74
78
79
+ @property
80
+ def current_thread_span (self ):
81
+ return getattr (self .local , "current_span" , None )
82
+
83
+ @current_thread_span .setter
84
+ def current_thread_span (self , span ):
85
+ setattr (self .local , "current_span" , span )
86
+
75
87
def _operation_name (self , db_name , statement ):
76
88
parts = []
77
89
if isinstance (statement , str ):
@@ -94,34 +106,38 @@ def _before_cur_exec(self, conn, cursor, statement, *args):
94
106
attrs = _get_attributes_from_cursor (self .vendor , cursor , attrs )
95
107
96
108
db_name = attrs .get (_DB , "" )
97
- self . current_span = self .tracer .start_span (
109
+ span = self .tracer .start_span (
98
110
self ._operation_name (db_name , statement ),
99
111
kind = trace .SpanKind .CLIENT ,
100
112
)
101
- with trace .use_span (self .current_span , end_on_exit = False ):
102
- if self .current_span .is_recording ():
103
- self .current_span .set_attribute (_STMT , statement )
104
- self .current_span .set_attribute ("db.system" , self .vendor )
113
+ self .current_thread_span = self .cursor_mapping [cursor ] = span
114
+ with trace .use_span (span , end_on_exit = False ):
115
+ if span .is_recording ():
116
+ span .set_attribute (_STMT , statement )
117
+ span .set_attribute ("db.system" , self .vendor )
105
118
for key , value in attrs .items ():
106
- self . current_span .set_attribute (key , value )
119
+ span .set_attribute (key , value )
107
120
108
121
# pylint: disable=unused-argument
109
122
def _after_cur_exec (self , conn , cursor , statement , * args ):
110
- if self .current_span is None :
123
+ span = self .cursor_mapping .get (cursor , None )
124
+ if span is None :
111
125
return
112
- self .current_span .end ()
126
+
127
+ span .end ()
113
128
114
129
def _handle_error (self , context ):
115
- if self .current_span is None :
130
+ span = self .current_thread_span
131
+ if span is None :
116
132
return
117
133
118
134
try :
119
- if self . current_span .is_recording ():
120
- self . current_span .set_status (
135
+ if span .is_recording ():
136
+ span .set_status (
121
137
Status (StatusCode .ERROR , str (context .original_exception ),)
122
138
)
123
139
finally :
124
- self . current_span .end ()
140
+ span .end ()
125
141
126
142
127
143
def _get_attributes_from_url (url ):
0 commit comments