diff --git a/CHANGELOG.md b/CHANGELOG.md index e7f76a9797f..62d96498441 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added ProxyTracerProvider and ProxyTracer implementations to allow fetching provider and tracer instances before a global provider is set up. ([#1726](https://github.com/open-telemetry/opentelemetry-python/pull/1726)) +- Added `__contains__` to `opentelementry.trace.span.TraceState`. + ([#1773](https://github.com/open-telemetry/opentelemetry-python/pull/1773)) ## [1.0.0](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.0.0) - 2021-03-26 diff --git a/opentelemetry-api/src/opentelemetry/trace/span.py b/opentelemetry-api/src/opentelemetry/trace/span.py index d04cdfa49dd..c4c713cf3e7 100644 --- a/opentelemetry-api/src/opentelemetry/trace/span.py +++ b/opentelemetry-api/src/opentelemetry/trace/span.py @@ -231,8 +231,11 @@ def __init__( "Invalid key/value pair (%s, %s) found.", key, value ) - def __getitem__(self, key: str) -> typing.Optional[str]: # type: ignore - return self._dict.get(key) + def __contains__(self, item: object) -> bool: + return item in self._dict + + def __getitem__(self, key: str) -> str: + return self._dict[key] def __iter__(self) -> typing.Iterator[str]: return iter(self._dict) diff --git a/opentelemetry-api/tests/trace/test_tracestate.py b/opentelemetry-api/tests/trace/test_tracestate.py index 6665dd612dd..625b260d548 100644 --- a/opentelemetry-api/tests/trace/test_tracestate.py +++ b/opentelemetry-api/tests/trace/test_tracestate.py @@ -96,3 +96,19 @@ def test_tracestate_order_changed(self): foo_place = entries.index(("foo", "bar33")) # type: ignore prev_first_place = entries.index(("1a-2f@foo", "bar1")) # type: ignore self.assertLessEqual(foo_place, prev_first_place) + + def test_trace_contains(self): + entries = [ + "1a-2f@foo=bar1", + "1a-_*/2b@foo=bar2", + "foo=bar3", + "foo-_*/bar=bar4", + ] + header_list = [",".join(entries)] + state = TraceState.from_header(header_list) + + self.assertTrue("foo" in state) + self.assertFalse("bar" in state) + self.assertIsNone(state.get("bar")) + with self.assertRaises(KeyError): + state["bar"] # pylint:disable=W0104