|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
15 | 15 | import unittest
|
| 16 | +from unittest.mock import patch |
16 | 17 |
|
17 | 18 | import opentelemetry.sdk.trace as trace
|
18 | 19 | import opentelemetry.sdk.trace.propagation.b3_format as b3_format
|
@@ -245,6 +246,50 @@ def test_missing_trace_id(self):
|
245 | 246 | span_context = trace_api.get_current_span(ctx).get_context()
|
246 | 247 | self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)
|
247 | 248 |
|
| 249 | + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id") |
| 250 | + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id") |
| 251 | + def test_invalid_trace_id( |
| 252 | + self, mock_generate_span_id, mock_generate_trace_id |
| 253 | + ): |
| 254 | + """If a trace id is invalid, generate a trace id.""" |
| 255 | + |
| 256 | + mock_generate_trace_id.configure_mock(return_value=1) |
| 257 | + mock_generate_span_id.configure_mock(return_value=2) |
| 258 | + |
| 259 | + carrier = { |
| 260 | + FORMAT.TRACE_ID_KEY: "abc123", |
| 261 | + FORMAT.SPAN_ID_KEY: self.serialized_span_id, |
| 262 | + FORMAT.FLAGS_KEY: "1", |
| 263 | + } |
| 264 | + |
| 265 | + ctx = FORMAT.extract(get_as_list, carrier) |
| 266 | + span_context = trace_api.get_current_span(ctx).get_context() |
| 267 | + |
| 268 | + self.assertEqual(span_context.trace_id, 1) |
| 269 | + self.assertEqual(span_context.span_id, 2) |
| 270 | + |
| 271 | + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id") |
| 272 | + @patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id") |
| 273 | + def test_invalid_span_id( |
| 274 | + self, mock_generate_span_id, mock_generate_trace_id |
| 275 | + ): |
| 276 | + """If a span id is invalid, generate a trace id.""" |
| 277 | + |
| 278 | + mock_generate_trace_id.configure_mock(return_value=1) |
| 279 | + mock_generate_span_id.configure_mock(return_value=2) |
| 280 | + |
| 281 | + carrier = { |
| 282 | + FORMAT.TRACE_ID_KEY: self.serialized_trace_id, |
| 283 | + FORMAT.SPAN_ID_KEY: "abc123", |
| 284 | + FORMAT.FLAGS_KEY: "1", |
| 285 | + } |
| 286 | + |
| 287 | + ctx = FORMAT.extract(get_as_list, carrier) |
| 288 | + span_context = trace_api.get_current_span(ctx).get_context() |
| 289 | + |
| 290 | + self.assertEqual(span_context.trace_id, 1) |
| 291 | + self.assertEqual(span_context.span_id, 2) |
| 292 | + |
248 | 293 | def test_missing_span_id(self):
|
249 | 294 | """If a trace id is missing, populate an invalid trace id."""
|
250 | 295 | carrier = {
|
|
0 commit comments