Skip to content

Commit 590c32c

Browse files
authored
Handle B3 trace_id and span_id correctly (open-telemetry#934)
1 parent 8b1da35 commit 590c32c

File tree

2 files changed

+64
-2
lines changed

2 files changed

+64
-2
lines changed

opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
# limitations under the License.
1414

1515
import typing
16+
from re import compile as re_compile
1617

1718
import opentelemetry.trace as trace
1819
from opentelemetry.context import Context
20+
from opentelemetry.sdk.trace import generate_span_id, generate_trace_id
1921
from opentelemetry.trace.propagation.httptextformat import (
2022
Getter,
2123
HTTPTextFormat,
@@ -37,6 +39,8 @@ class B3Format(HTTPTextFormat):
3739
SAMPLED_KEY = "x-b3-sampled"
3840
FLAGS_KEY = "x-b3-flags"
3941
_SAMPLE_PROPAGATE_VALUES = set(["1", "True", "true", "d"])
42+
_trace_id_regex = re_compile(r"[\da-fA-F]{16}|[\da-fA-F]{32}")
43+
_span_id_regex = re_compile(r"[\da-fA-F]{16}")
4044

4145
def extract(
4246
self,
@@ -95,19 +99,32 @@ def extract(
9599
or flags
96100
)
97101

102+
if (
103+
self._trace_id_regex.fullmatch(trace_id) is None
104+
or self._span_id_regex.fullmatch(span_id) is None
105+
):
106+
trace_id = generate_trace_id()
107+
span_id = generate_span_id()
108+
sampled = "0"
109+
110+
else:
111+
trace_id = int(trace_id, 16)
112+
span_id = int(span_id, 16)
113+
98114
options = 0
99115
# The b3 spec provides no defined behavior for both sample and
100116
# flag values set. Since the setting of at least one implies
101117
# the desire for some form of sampling, propagate if either
102118
# header is set to allow.
103119
if sampled in self._SAMPLE_PROPAGATE_VALUES or flags == "1":
104120
options |= trace.TraceFlags.SAMPLED
121+
105122
return trace.set_span_in_context(
106123
trace.DefaultSpan(
107124
trace.SpanContext(
108125
# trace an span ids are encoded in hex, so must be converted
109-
trace_id=int(trace_id, 16),
110-
span_id=int(span_id, 16),
126+
trace_id=trace_id,
127+
span_id=span_id,
111128
is_remote=True,
112129
trace_flags=trace.TraceFlags(options),
113130
trace_state=trace.TraceState(),

opentelemetry-sdk/tests/trace/propagation/test_b3_format.py

+45
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import unittest
16+
from unittest.mock import patch
1617

1718
import opentelemetry.sdk.trace as trace
1819
import opentelemetry.sdk.trace.propagation.b3_format as b3_format
@@ -245,6 +246,50 @@ def test_missing_trace_id(self):
245246
span_context = trace_api.get_current_span(ctx).get_context()
246247
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)
247248

249+
@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id")
250+
@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id")
251+
def test_invalid_trace_id(
252+
self, mock_generate_span_id, mock_generate_trace_id
253+
):
254+
"""If a trace id is invalid, generate a trace id."""
255+
256+
mock_generate_trace_id.configure_mock(return_value=1)
257+
mock_generate_span_id.configure_mock(return_value=2)
258+
259+
carrier = {
260+
FORMAT.TRACE_ID_KEY: "abc123",
261+
FORMAT.SPAN_ID_KEY: self.serialized_span_id,
262+
FORMAT.FLAGS_KEY: "1",
263+
}
264+
265+
ctx = FORMAT.extract(get_as_list, carrier)
266+
span_context = trace_api.get_current_span(ctx).get_context()
267+
268+
self.assertEqual(span_context.trace_id, 1)
269+
self.assertEqual(span_context.span_id, 2)
270+
271+
@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_trace_id")
272+
@patch("opentelemetry.sdk.trace.propagation.b3_format.generate_span_id")
273+
def test_invalid_span_id(
274+
self, mock_generate_span_id, mock_generate_trace_id
275+
):
276+
"""If a span id is invalid, generate a trace id."""
277+
278+
mock_generate_trace_id.configure_mock(return_value=1)
279+
mock_generate_span_id.configure_mock(return_value=2)
280+
281+
carrier = {
282+
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
283+
FORMAT.SPAN_ID_KEY: "abc123",
284+
FORMAT.FLAGS_KEY: "1",
285+
}
286+
287+
ctx = FORMAT.extract(get_as_list, carrier)
288+
span_context = trace_api.get_current_span(ctx).get_context()
289+
290+
self.assertEqual(span_context.trace_id, 1)
291+
self.assertEqual(span_context.span_id, 2)
292+
248293
def test_missing_span_id(self):
249294
"""If a trace id is missing, populate an invalid trace id."""
250295
carrier = {

0 commit comments

Comments
 (0)