@@ -55,10 +55,6 @@ class AwsXRayFormat(TextMapPropagator):
55
55
IS_SAMPLED = "1"
56
56
NOT_SAMPLED = "0"
57
57
58
- # pylint: disable=too-many-locals
59
- # pylint: disable=too-many-return-statements
60
- # pylint: disable=too-many-branches
61
- # pylint: disable=too-many-statements
62
58
def extract (
63
59
self ,
64
60
getter : Getter [TextMapPropagatorT ],
@@ -79,73 +75,78 @@ def extract(
79
75
trace .INVALID_SPAN , context = context
80
76
)
81
77
78
+ trace_id , span_id , sampled , err = self .extract_span_properties (
79
+ trace_header
80
+ )
81
+
82
+ if err is not None :
83
+ return trace .set_span_in_context (
84
+ trace .INVALID_SPAN , context = context
85
+ )
86
+
87
+ options = 0
88
+ if sampled :
89
+ options |= trace .TraceFlags .SAMPLED
90
+
91
+ span_context = trace .SpanContext (
92
+ trace_id = trace_id ,
93
+ span_id = span_id ,
94
+ is_remote = True ,
95
+ trace_flags = trace .TraceFlags (options ),
96
+ trace_state = trace .TraceState (),
97
+ )
98
+
99
+ if not span_context .is_valid :
100
+ _logger .error (
101
+ "Invalid Span Extracted. Insertting INVALID span into provided context."
102
+ )
103
+ return trace .set_span_in_context (
104
+ trace .INVALID_SPAN , context = context
105
+ )
106
+
107
+ return trace .set_span_in_context (
108
+ trace .DefaultSpan (span_context ), context = context
109
+ )
110
+
111
+ def extract_span_properties (self , trace_header ):
82
112
trace_id = trace .INVALID_TRACE_ID
83
113
span_id = trace .INVALID_SPAN_ID
84
114
sampled = False
85
115
86
- next_kv_pair_start = 0
116
+ extract_err = None
87
117
88
- while next_kv_pair_start < len (trace_header ):
89
- try :
90
- kv_pair_delimiter_index = trace_header .index (
91
- self .KV_PAIR_DELIMITER , next_kv_pair_start
92
- )
93
- kv_pair_subset = trace_header [
94
- next_kv_pair_start :kv_pair_delimiter_index
95
- ]
96
- next_kv_pair_start = kv_pair_delimiter_index + 1
97
- except ValueError :
98
- kv_pair_subset = trace_header [next_kv_pair_start :]
99
- next_kv_pair_start = len (trace_header )
100
-
101
- stripped_kv_pair = kv_pair_subset .strip ()
118
+ for kv_pair_str in trace_header .split (self .KV_PAIR_DELIMITER ):
119
+ if extract_err :
120
+ break
102
121
103
122
try :
104
- key_and_value_delimiter_index = stripped_kv_pair . index (
123
+ key_str , value_str = kv_pair_str . split (
105
124
self .KEY_AND_VALUE_DELIMITER
106
125
)
126
+ key , value = key_str .strip (), value_str .strip ()
107
127
except ValueError :
108
128
_logger .error (
109
129
(
110
130
"Error parsing X-Ray trace header. Invalid key value pair: %s. Returning INVALID span context." ,
111
- kv_pair_subset ,
131
+ kv_pair_str ,
112
132
)
113
133
)
114
- return trace .set_span_in_context (
115
- trace .INVALID_SPAN , context = context
116
- )
134
+ return trace_id , span_id , sampled , extract_err
117
135
118
- value = stripped_kv_pair [key_and_value_delimiter_index + 1 :]
119
-
120
- if stripped_kv_pair .startswith (self .TRACE_ID_KEY ):
121
- if (
122
- len (value ) != self .TRACE_ID_LENGTH
123
- or not value .startswith (self .TRACE_ID_VERSION )
124
- or value [self .TRACE_ID_DELIMITER_INDEX_1 ]
125
- != self .TRACE_ID_DELIMITER
126
- or value [self .TRACE_ID_DELIMITER_INDEX_2 ]
127
- != self .TRACE_ID_DELIMITER
128
- ):
136
+ if key == self .TRACE_ID_KEY :
137
+ if not self .validate_trace_id (value ):
129
138
_logger .error (
130
139
(
131
140
"Invalid TraceId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context." ,
132
141
self .TRACE_HEADER_KEY ,
133
142
trace_header ,
134
143
)
135
144
)
136
- return trace .set_span_in_context (
137
- trace .INVALID_SPAN , context = context
138
- )
145
+ extract_err = True
146
+ break
139
147
140
- timestamp_subset = value [
141
- self .TRACE_ID_DELIMITER_INDEX_1
142
- + 1 : self .TRACE_ID_DELIMITER_INDEX_2
143
- ]
144
- unique_id_subset = value [
145
- self .TRACE_ID_DELIMITER_INDEX_2 + 1 : self .TRACE_ID_LENGTH
146
- ]
147
148
try :
148
- trace_id = int ( timestamp_subset + unique_id_subset , 16 )
149
+ trace_id = self . parse_trace_id ( value )
149
150
except ValueError :
150
151
_logger .error (
151
152
(
@@ -154,24 +155,21 @@ def extract(
154
155
trace_header ,
155
156
)
156
157
)
157
- return trace .set_span_in_context (
158
- trace .INVALID_SPAN , context = context
159
- )
160
- elif stripped_kv_pair .startswith (self .PARENT_ID_KEY ):
161
- if len (value ) != self .PARENT_ID_LENGTH :
158
+ extract_err = True
159
+ elif key == self .PARENT_ID_KEY :
160
+ if not self .validate_span_id (value ):
162
161
_logger .error (
163
162
(
164
163
"Invalid ParentId in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context." ,
165
164
self .TRACE_HEADER_KEY ,
166
165
trace_header ,
167
166
)
168
167
)
169
- return trace .set_span_in_context (
170
- trace .INVALID_SPAN , context = context
171
- )
168
+ extract_err = True
169
+ break
172
170
173
171
try :
174
- span_id = int (value , 16 )
172
+ span_id = AwsXRayFormat . parse_span_id (value )
175
173
except ValueError :
176
174
_logger .error (
177
175
(
@@ -180,60 +178,61 @@ def extract(
180
178
trace_header ,
181
179
)
182
180
)
183
- return trace .set_span_in_context (
184
- trace .INVALID_SPAN , context = context
185
- )
186
- elif stripped_kv_pair .startswith (self .SAMPLED_FLAG_KEY ):
187
- is_sampled_flag_valid = True
188
-
189
- if len (value ) != self .SAMPLED_FLAG_LENGTH :
190
- is_sampled_flag_valid = False
191
-
192
- if is_sampled_flag_valid :
193
- sampled_flag = value [0 ]
194
- if sampled_flag == self .IS_SAMPLED :
195
- sampled = True
196
- elif sampled_flag == self .NOT_SAMPLED :
197
- sampled = False
198
- else :
199
- is_sampled_flag_valid = False
200
-
201
- if not is_sampled_flag_valid :
181
+ extract_err = True
182
+ elif key == self .SAMPLED_FLAG_KEY :
183
+ if not self .validate_sampled_flag (value ):
202
184
_logger .error (
203
185
(
204
186
"Invalid Sampling flag in X-Ray trace header: '%s' with value '%s'. Returning INVALID span context." ,
205
187
self .TRACE_HEADER_KEY ,
206
188
trace_header ,
207
189
)
208
190
)
209
- return trace .set_span_in_context (
210
- trace .INVALID_SPAN , context = context
211
- )
191
+ extract_err = True
192
+ break
212
193
213
- options = 0
214
- if sampled :
215
- options |= trace .TraceFlags .SAMPLED
194
+ sampled = self .parse_sampled_flag (value )
216
195
217
- span_context = trace .SpanContext (
218
- trace_id = trace_id ,
219
- span_id = span_id ,
220
- is_remote = True ,
221
- trace_flags = trace .TraceFlags (options ),
222
- trace_state = trace .TraceState (),
223
- )
196
+ return trace_id , span_id , sampled , extract_err
224
197
225
- if not span_context .is_valid :
226
- _logger .error (
227
- "Invalid Span Extracted. Insertting INVALID span into provided context."
228
- )
229
- return trace .set_span_in_context (
230
- trace .INVALID_SPAN , context = context
231
- )
198
+ def validate_trace_id (self , trace_id_str ):
199
+ return (
200
+ len (trace_id_str ) == self .TRACE_ID_LENGTH
201
+ and trace_id_str .startswith (self .TRACE_ID_VERSION )
202
+ and trace_id_str [self .TRACE_ID_DELIMITER_INDEX_1 ]
203
+ == self .TRACE_ID_DELIMITER
204
+ and trace_id_str [self .TRACE_ID_DELIMITER_INDEX_2 ]
205
+ == self .TRACE_ID_DELIMITER
206
+ )
232
207
233
- return trace .set_span_in_context (
234
- trace .DefaultSpan (span_context ), context = context
208
+ def parse_trace_id (self , trace_id_str ):
209
+ timestamp_subset = trace_id_str [
210
+ self .TRACE_ID_DELIMITER_INDEX_1
211
+ + 1 : self .TRACE_ID_DELIMITER_INDEX_2
212
+ ]
213
+ unique_id_subset = trace_id_str [
214
+ self .TRACE_ID_DELIMITER_INDEX_2 + 1 : self .TRACE_ID_LENGTH
215
+ ]
216
+ return int (timestamp_subset + unique_id_subset , 16 )
217
+
218
+ def validate_span_id (self , span_id_str ):
219
+ return len (span_id_str ) == self .PARENT_ID_LENGTH
220
+
221
+ @staticmethod
222
+ def parse_span_id (span_id_str ):
223
+ return int (span_id_str , 16 )
224
+
225
+ def validate_sampled_flag (self , sampled_flag_str ):
226
+ return len (
227
+ sampled_flag_str
228
+ ) == self .SAMPLED_FLAG_LENGTH and sampled_flag_str in (
229
+ self .IS_SAMPLED ,
230
+ self .NOT_SAMPLED ,
235
231
)
236
232
233
+ def parse_sampled_flag (self , sampled_flag_str ):
234
+ return sampled_flag_str [0 ] == self .IS_SAMPLED
235
+
237
236
def inject (
238
237
self ,
239
238
set_in_carrier : Setter [TextMapPropagatorT ],
0 commit comments