diff --git a/docs/conf.py b/docs/conf.py index 68b871aaac2..b227cb51818 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -102,6 +102,7 @@ ("py:class", "ObjectProxy"), # TODO: Understand why sphinx is not able to find this local class ("py:class", "opentelemetry.trace.propagation.textmap.TextMapPropagator",), + ("py:class", "opentelemetry.trace.propagation.textmap.DictGetter",), ( "any", "opentelemetry.trace.propagation.textmap.TextMapPropagator.extract", diff --git a/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py index 3ad9fa1ae21..ab1468c54ab 100644 --- a/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py +++ b/exporter/opentelemetry-exporter-datadog/src/opentelemetry/exporter/datadog/propagator.py @@ -39,25 +39,23 @@ class DatadogFormat(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + getter: Getter[TextMapPropagatorT], carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: trace_id = extract_first_element( - get_from_carrier(carrier, self.TRACE_ID_KEY) + getter.get(carrier, self.TRACE_ID_KEY) ) span_id = extract_first_element( - get_from_carrier(carrier, self.PARENT_ID_KEY) + getter.get(carrier, self.PARENT_ID_KEY) ) sampled = extract_first_element( - get_from_carrier(carrier, self.SAMPLING_PRIORITY_KEY) + getter.get(carrier, self.SAMPLING_PRIORITY_KEY) ) - origin = extract_first_element( - get_from_carrier(carrier, self.ORIGIN_KEY) - ) + origin = extract_first_element(getter.get(carrier, self.ORIGIN_KEY)) trace_flags = trace.TraceFlags() if sampled and int(sampled) in ( diff --git a/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py b/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py index 8480374471b..bb41fef49b1 100644 --- a/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py +++ b/exporter/opentelemetry-exporter-datadog/tests/test_datadog_format.py @@ -18,13 +18,11 @@ from opentelemetry.exporter.datadog import constants, propagator from opentelemetry.sdk import trace from opentelemetry.trace import get_current_span, set_span_in_context +from opentelemetry.trace.propagation.textmap import DictGetter FORMAT = propagator.DatadogFormat() - -def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] +carrier_getter = DictGetter() class TestDatadogFormat(unittest.TestCase): @@ -45,7 +43,7 @@ def test_malformed_headers(self): malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x" context = get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { malformed_trace_id_key: self.serialized_trace_id, malformed_parent_id_key: self.serialized_parent_id, @@ -63,7 +61,7 @@ def test_missing_trace_id(self): FORMAT.PARENT_ID_KEY: self.serialized_parent_id, } - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) @@ -73,7 +71,7 @@ def test_missing_parent_id(self): FORMAT.TRACE_ID_KEY: self.serialized_trace_id, } - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = get_current_span(ctx).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -81,7 +79,7 @@ def test_context_propagation(self): """Test the propagation of Datadog headers.""" parent_span_context = get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.PARENT_ID_KEY: self.serialized_parent_id, @@ -138,7 +136,7 @@ def test_sampling_priority_auto_reject(self): """Test sampling priority rejected.""" parent_span_context = get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { FORMAT.TRACE_ID_KEY: self.serialized_trace_id, FORMAT.PARENT_ID_KEY: self.serialized_parent_id, diff --git a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py index 1a0bb47a644..1c442889a1b 100644 --- a/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-asgi/src/opentelemetry/instrumentation/asgi/__init__.py @@ -29,21 +29,31 @@ from opentelemetry import context, propagators, trace from opentelemetry.instrumentation.asgi.version import __version__ # noqa from opentelemetry.instrumentation.utils import http_status_to_status_code +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status, StatusCode -def get_header_from_scope(scope: dict, header_name: str) -> typing.List[str]: - """Retrieve a HTTP header value from the ASGI scope. +class CarrierGetter(DictGetter): + def get(self, carrier: dict, key: str) -> typing.List[str]: + """Getter implementation to retrieve a HTTP header value from the ASGI + scope. - Returns: - A list with a single string with the header value if it exists, else an empty list. - """ - headers = scope.get("headers") - return [ - value.decode("utf8") - for (key, value) in headers - if key.decode("utf8") == header_name - ] + Args: + carrier: ASGI scope object + key: header name in scope + Returns: + A list with a single string with the header value if it exists, + else an empty list. + """ + headers = carrier.get("headers") + return [ + _value.decode("utf8") + for (_key, _value) in headers + if _key.decode("utf8") == key + ] + + +carrier_getter = CarrierGetter() def collect_request_attributes(scope): @@ -72,10 +82,10 @@ def collect_request_attributes(scope): http_method = scope.get("method") if http_method: result["http.method"] = http_method - http_host_value = ",".join(get_header_from_scope(scope, "host")) + http_host_value = ",".join(carrier_getter.get(scope, "host")) if http_host_value: result["http.server_name"] = http_host_value - http_user_agent = get_header_from_scope(scope, "user-agent") + http_user_agent = carrier_getter.get(scope, "user-agent") if len(http_user_agent) > 0: result["http.user_agent"] = http_user_agent[0] @@ -154,9 +164,7 @@ async def __call__(self, scope, receive, send): if scope["type"] not in ("http", "websocket"): return await self.app(scope, receive, send) - token = context.attach( - propagators.extract(get_header_from_scope, scope) - ) + token = context.attach(propagators.extract(carrier_getter, scope)) span_name, additional_attributes = self.span_details_callback(scope) try: diff --git a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py index a5897f8b6e5..d225e6bd069 100644 --- a/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-celery/src/opentelemetry/instrumentation/celery/__init__.py @@ -67,7 +67,7 @@ def add(x, y): from opentelemetry.instrumentation.celery import utils from opentelemetry.instrumentation.celery.version import __version__ from opentelemetry.instrumentation.instrumentor import BaseInstrumentor -from opentelemetry.trace.propagation import get_current_span +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status, StatusCode logger = logging.getLogger(__name__) @@ -84,6 +84,20 @@ def add(x, y): _MESSAGE_ID_ATTRIBUTE_NAME = "messaging.message_id" +class CarrierGetter(DictGetter): + def get(self, carrier, key): + value = getattr(carrier, key, []) + if isinstance(value, str) or not isinstance(value, Iterable): + value = (value,) + return value + + def keys(self, carrier): + return [] + + +carrier_getter = CarrierGetter() + + class CeleryInstrumentor(BaseInstrumentor): def _instrument(self, **kwargs): tracer_provider = kwargs.get("tracer_provider") @@ -118,7 +132,7 @@ def _trace_prerun(self, *args, **kwargs): return request = task.request - tracectx = propagators.extract(carrier_extractor, request) or None + tracectx = propagators.extract(carrier_getter, request) or None logger.debug("prerun signal start task_id=%s", task_id) @@ -246,10 +260,3 @@ def _trace_retry(*args, **kwargs): # Use `str(reason)` instead of `reason.message` in case we get # something that isn't an `Exception` span.set_attribute(_TASK_RETRY_REASON_KEY, str(reason)) - - -def carrier_extractor(carrier, key): - value = getattr(carrier, key, []) - if isinstance(value, str) or not isinstance(value, Iterable): - value = (value,) - return value diff --git a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py index 5c27ed289a2..1f465ca57a7 100644 --- a/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py +++ b/instrumentation/opentelemetry-instrumentation-django/src/opentelemetry/instrumentation/django/middleware.py @@ -23,8 +23,8 @@ from opentelemetry.instrumentation.utils import extract_attributes_from_object from opentelemetry.instrumentation.wsgi import ( add_response_attributes, + carrier_getter, collect_request_attributes, - get_header_from_environ, ) from opentelemetry.propagators import extract from opentelemetry.trace import SpanKind, get_tracer @@ -125,7 +125,7 @@ def process_request(self, request): environ = request.META - token = attach(extract(get_header_from_environ, environ)) + token = attach(extract(carrier_getter, environ)) tracer = get_tracer(__name__, __version__) diff --git a/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py b/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py index 66e6563dfff..55f8e98dcb1 100644 --- a/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-falcon/src/opentelemetry/instrumentation/falcon/__init__.py @@ -115,7 +115,7 @@ def __call__(self, env, start_response): start_time = time_ns() token = context.attach( - propagators.extract(otel_wsgi.get_header_from_environ, env) + propagators.extract(otel_wsgi.carrier_getter, env) ) span = self._tracer.start_span( otel_wsgi.get_default_span_name(env), diff --git a/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py b/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py index c88a7609059..1235b09a307 100644 --- a/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-flask/src/opentelemetry/instrumentation/flask/__init__.py @@ -118,7 +118,7 @@ def _before_request(): if span_name is None: span_name = otel_wsgi.get_default_span_name(environ) token = context.attach( - propagators.extract(otel_wsgi.get_header_from_environ, environ) + propagators.extract(otel_wsgi.carrier_getter, environ) ) tracer = trace.get_tracer(__name__, __version__) diff --git a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py index 5927c99b51b..087cf4f9ccb 100644 --- a/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py +++ b/instrumentation/opentelemetry-instrumentation-grpc/src/opentelemetry/instrumentation/grpc/_server.py @@ -23,12 +23,12 @@ import logging from contextlib import contextmanager -from typing import List import grpc from opentelemetry import propagators, trace from opentelemetry.context import attach, detach +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status, StatusCode logger = logging.getLogger(__name__) @@ -163,18 +163,14 @@ class OpenTelemetryServerInterceptor(grpc.ServerInterceptor): def __init__(self, tracer): self._tracer = tracer + self._carrier_getter = DictGetter() @contextmanager def _set_remote_context(self, servicer_context): metadata = servicer_context.invocation_metadata() if metadata: md_dict = {md.key: md.value for md in metadata} - - def get_from_grpc_metadata(metadata, key) -> List[str]: - return [md_dict[key]] if key in md_dict else [] - - # Update the context with the traceparent from the RPC metadata. - ctx = propagators.extract(get_from_grpc_metadata, metadata) + ctx = propagators.extract(self._carrier_getter, md_dict) token = attach(ctx) try: yield diff --git a/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py b/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py index 63be13fe489..00a0b8d0cb8 100644 --- a/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-opentracing-shim/src/opentelemetry/instrumentation/opentracing_shim/__init__.py @@ -112,6 +112,7 @@ get_current_span, set_span_in_context, ) +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.util.types import Attributes ValueT = TypeVar("ValueT", int, float, bool, str) @@ -527,6 +528,7 @@ def __init__(self, tracer: OtelTracer): Format.TEXT_MAP, Format.HTTP_HEADERS, ) + self._carrier_getter = DictGetter() def unwrap(self): """Returns the :class:`opentelemetry.trace.Tracer` object that is @@ -710,12 +712,8 @@ def extract(self, format: object, carrier: object): if format not in self._supported_formats: raise UnsupportedFormatException - def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] - propagator = propagators.get_global_textmap() - ctx = propagator.extract(get_as_list, carrier) + ctx = propagator.extract(self._carrier_getter, carrier) span = get_current_span(ctx) if span is not None: otel_context = span.get_span_context() diff --git a/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py b/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py index ada239b8e31..e7110bd2b55 100644 --- a/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py +++ b/instrumentation/opentelemetry-instrumentation-pyramid/src/opentelemetry/instrumentation/pyramid/callbacks.py @@ -70,7 +70,7 @@ def _before_traversal(event): start_time = environ.get(_ENVIRON_STARTTIME_KEY) token = context.attach( - propagators.extract(otel_wsgi.get_header_from_environ, environ) + propagators.extract(otel_wsgi.carrier_getter, environ) ) tracer = trace.get_tracer(__name__, __version__) diff --git a/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py b/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py index 9a7959ab956..6bb956ecb58 100644 --- a/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-tornado/src/opentelemetry/instrumentation/tornado/__init__.py @@ -54,6 +54,7 @@ def get(self): http_status_to_status_code, unwrap, ) +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status from opentelemetry.util import ExcludeList, time_ns @@ -84,6 +85,8 @@ def get_traced_request_attrs(): _excluded_urls = get_excluded_urls() _traced_attrs = get_traced_request_attrs() +carrier_getter = DictGetter() + class TornadoInstrumentor(BaseInstrumentor): patched_handlers = [] @@ -185,13 +188,6 @@ def _log_exception(tracer, func, handler, args, kwargs): return func(*args, **kwargs) -def _get_header_from_request_headers( - headers: dict, header_name: str -) -> typing.List[str]: - header = headers.get(header_name) - return [header] if header else [] - - def _get_attributes_from_request(request): attrs = { "component": "tornado", @@ -218,9 +214,7 @@ def _get_operation_name(handler, request): def _start_span(tracer, handler, start_time) -> _TraceContext: token = context.attach( - propagators.extract( - _get_header_from_request_headers, handler.request.headers, - ) + propagators.extract(carrier_getter, handler.request.headers,) ) span = tracer.start_span( diff --git a/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py b/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py index 62eef43251e..e1ef92c6baf 100644 --- a/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-wsgi/src/opentelemetry/instrumentation/wsgi/__init__.py @@ -61,24 +61,35 @@ def hello(): from opentelemetry import context, propagators, trace from opentelemetry.instrumentation.utils import http_status_to_status_code from opentelemetry.instrumentation.wsgi.version import __version__ +from opentelemetry.trace.propagation.textmap import DictGetter from opentelemetry.trace.status import Status, StatusCode _HTTP_VERSION_PREFIX = "HTTP/" -def get_header_from_environ( - environ: dict, header_name: str -) -> typing.List[str]: - """Retrieve a HTTP header value from the PEP3333-conforming WSGI environ. +class CarrierGetter(DictGetter): + def get(self, carrier: dict, key: str) -> typing.List[str]: + """Getter implementation to retrieve a HTTP header value from the + PEP3333-conforming WSGI environ - Returns: - A list with a single string with the header value if it exists, else an empty list. - """ - environ_key = "HTTP_" + header_name.upper().replace("-", "_") - value = environ.get(environ_key) - if value is not None: - return [value] - return [] + Args: + carrier: WSGI environ object + key: header name in environ object + Returns: + A list with a single string with the header value if it exists, + else an empty list. + """ + environ_key = "HTTP_" + key.upper().replace("-", "_") + value = carrier.get(environ_key) + if value is not None: + return [value] + return [] + + def keys(self, carrier): + return [] + + +carrier_getter = CarrierGetter() def setifnotnone(dic, key, value): @@ -195,9 +206,7 @@ def __call__(self, environ, start_response): start_response: The WSGI start_response callable. """ - token = context.attach( - propagators.extract(get_header_from_environ, environ) - ) + token = context.attach(propagators.extract(carrier_getter, environ)) span_name = self.name_callback(environ) span = self.tracer.start_span( diff --git a/opentelemetry-api/CHANGELOG.md b/opentelemetry-api/CHANGELOG.md index 977dd6375c0..175eba10c32 100644 --- a/opentelemetry-api/CHANGELOG.md +++ b/opentelemetry-api/CHANGELOG.md @@ -20,6 +20,8 @@ Released 2020-10-13 ([#1134](https://github.com/open-telemetry/opentelemetry-python/pull/1134)) - Parent is now always passed in via Context, intead of Span or SpanContext ([#1146](https://github.com/open-telemetry/opentelemetry-python/pull/1146)) +- Add keys method to TextMap propagator Getter + ([#1196](https://github.com/open-telemetry/opentelemetry-python/issues/1196)) ## Version 0.13b0 diff --git a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py index d0920e68590..70e73a3b23b 100644 --- a/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py +++ b/opentelemetry-api/src/opentelemetry/baggage/propagation/__init__.py @@ -29,7 +29,7 @@ class BaggagePropagator(textmap.TextMapPropagator): def extract( self, - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + getter: textmap.Getter[textmap.TextMapPropagatorT], carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -43,7 +43,7 @@ def extract( context = get_current() header = _extract_first_element( - get_from_carrier(carrier, self._BAGGAGE_HEADER_NAME) + getter.get(carrier, self._BAGGAGE_HEADER_NAME) ) if not header or len(header) > self.MAX_HEADER_LENGTH: diff --git a/opentelemetry-api/src/opentelemetry/propagators/__init__.py b/opentelemetry-api/src/opentelemetry/propagators/__init__.py index 18c322f7939..fb2863ac8ad 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/__init__.py +++ b/opentelemetry-api/src/opentelemetry/propagators/__init__.py @@ -82,24 +82,24 @@ def example_route(): def extract( - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + getter: textmap.Getter[textmap.TextMapPropagatorT], carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: """ Uses the configured propagator to extract a Context from the carrier. Args: - get_from_carrier: a function that can retrieve zero - or more values from the carrier. In the case that - the value does not exist, return an empty list. + getter: an object which contains a get function that can retrieve zero + or more values from the carrier and a keys function that can get all the keys + from carrier. carrier: and object which contains values that are used to construct a Context. This object - must be paired with an appropriate get_from_carrier + must be paired with an appropriate getter which understands how to extract a value from it. context: an optional Context to use. Defaults to current context if not set. """ - return get_global_textmap().extract(get_from_carrier, carrier, context) + return get_global_textmap().extract(getter, carrier, context) def inject( diff --git a/opentelemetry-api/src/opentelemetry/propagators/composite.py b/opentelemetry-api/src/opentelemetry/propagators/composite.py index 3499d2ea08a..441098ec1f0 100644 --- a/opentelemetry-api/src/opentelemetry/propagators/composite.py +++ b/opentelemetry-api/src/opentelemetry/propagators/composite.py @@ -35,7 +35,7 @@ def __init__( def extract( self, - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + getter: textmap.Getter[textmap.TextMapPropagatorT], carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -47,7 +47,7 @@ def extract( See `opentelemetry.trace.propagation.textmap.TextMapPropagator.extract` """ for propagator in self._propagators: - context = propagator.extract(get_from_carrier, carrier, context) + context = propagator.extract(getter, carrier, context) return context # type: ignore def inject( diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py index 6f9ed897e11..ec505135bec 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/textmap.py @@ -18,9 +18,54 @@ from opentelemetry.context.context import Context TextMapPropagatorT = typing.TypeVar("TextMapPropagatorT") +CarrierValT = typing.Union[typing.List[str], str] Setter = typing.Callable[[TextMapPropagatorT, str, str], None] -Getter = typing.Callable[[TextMapPropagatorT, str], typing.List[str]] + + +class Getter(typing.Generic[TextMapPropagatorT]): + """This class implements a Getter that enables extracting propagated + fields from a carrier + + """ + + def get(self, carrier: TextMapPropagatorT, key: str) -> typing.List[str]: + """Function that can retrieve zero + or more values from the carrier. In the case that + the value does not exist, returns an empty list. + + Args: + carrier: An object which contains values that are used to + construct a Context. + key: key of a field in carrier. + Returns: first value of the propagation key or an empty list if the + key doesn't exist. + """ + raise NotImplementedError() + + def keys(self, carrier: TextMapPropagatorT) -> typing.List[str]: + """Function that can retrieve all the keys in a carrier object. + + Args: + carrier: An object which contains values that are + used to construct a Context. + Returns: + list of keys from the carrier. + """ + raise NotImplementedError() + + +class DictGetter(Getter[typing.Dict[str, CarrierValT]]): + def get( + self, carrier: typing.Dict[str, CarrierValT], key: str + ) -> typing.List[str]: + val = carrier.get(key, []) + if isinstance(val, typing.Iterable) and not isinstance(val, str): + return list(val) + return [val] + + def keys(self, carrier: typing.Dict[str, CarrierValT]) -> typing.List[str]: + return list(carrier.keys()) class TextMapPropagator(abc.ABC): @@ -35,23 +80,23 @@ class TextMapPropagator(abc.ABC): @abc.abstractmethod def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + getter: Getter[TextMapPropagatorT], carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: """Create a Context from values in the carrier. The extract function should retrieve values from the carrier - object using get_from_carrier, and use values to populate a + object using getter, and use values to populate a Context value and return it. Args: - get_from_carrier: a function that can retrieve zero + getter: a function that can retrieve zero or more values from the carrier. In the case that the value does not exist, return an empty list. carrier: and object which contains values that are used to construct a Context. This object - must be paired with an appropriate get_from_carrier + must be paired with an appropriate getter which understands how to extract a value from it. context: an optional Context to use. Defaults to current context if not set. diff --git a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py index 4b77246ac5e..57933ef9c41 100644 --- a/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py +++ b/opentelemetry-api/src/opentelemetry/trace/propagation/tracecontext.py @@ -26,7 +26,6 @@ # .. _W3C Trace Context - Tracestate: # https://www.w3.org/TR/trace-context/#tracestate-field - _KEY_WITHOUT_VENDOR_FORMAT = r"[a-z][_0-9a-z\-\*\/]{0,255}" _KEY_WITH_VENDOR_FORMAT = ( r"[a-z0-9][_0-9a-z\-\*\/]{0,240}@[a-z][_0-9a-z\-\*\/]{0,13}" @@ -60,7 +59,7 @@ class TraceContextTextMapPropagator(textmap.TextMapPropagator): def extract( self, - get_from_carrier: textmap.Getter[textmap.TextMapPropagatorT], + getter: textmap.Getter[textmap.TextMapPropagatorT], carrier: textmap.TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -68,7 +67,7 @@ def extract( See `opentelemetry.trace.propagation.textmap.TextMapPropagator.extract` """ - header = get_from_carrier(carrier, self._TRACEPARENT_HEADER_NAME) + header = getter.get(carrier, self._TRACEPARENT_HEADER_NAME) if not header: return trace.set_span_in_context(trace.INVALID_SPAN, context) @@ -91,9 +90,7 @@ def extract( if version == "ff": return trace.set_span_in_context(trace.INVALID_SPAN, context) - tracestate_headers = get_from_carrier( - carrier, self._TRACESTATE_HEADER_NAME - ) + tracestate_headers = getter.get(carrier, self._TRACESTATE_HEADER_NAME) tracestate = _parse_tracestate(tracestate_headers) span_context = trace.SpanContext( diff --git a/opentelemetry-api/tests/baggage/test_baggage_propagation.py b/opentelemetry-api/tests/baggage/test_baggage_propagation.py index d5c16ead5d4..4c7b3de215a 100644 --- a/opentelemetry-api/tests/baggage/test_baggage_propagation.py +++ b/opentelemetry-api/tests/baggage/test_baggage_propagation.py @@ -18,12 +18,9 @@ from opentelemetry import baggage from opentelemetry.baggage.propagation import BaggagePropagator from opentelemetry.context import get_current +from opentelemetry.trace.propagation.textmap import DictGetter - -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - return dict_object.get(key, []) +carrier_getter = DictGetter() class TestBaggagePropagation(unittest.TestCase): @@ -33,7 +30,7 @@ def setUp(self): def _extract(self, header_value): """Test helper""" header = {"baggage": [header_value]} - return baggage.get_all(self.propagator.extract(get_as_list, header)) + return baggage.get_all(self.propagator.extract(carrier_getter, header)) def _inject(self, values): """Test helper""" @@ -46,7 +43,7 @@ def _inject(self, values): def test_no_context_header(self): baggage_entries = baggage.get_all( - self.propagator.extract(get_as_list, {}) + self.propagator.extract(carrier_getter, {}) ) self.assertEqual(baggage_entries, {}) diff --git a/opentelemetry-api/tests/propagators/test_global_httptextformat.py b/opentelemetry-api/tests/propagators/test_global_httptextformat.py index 2668be27c3d..b704207ed5f 100644 --- a/opentelemetry-api/tests/propagators/test_global_httptextformat.py +++ b/opentelemetry-api/tests/propagators/test_global_httptextformat.py @@ -18,13 +18,9 @@ from opentelemetry import baggage, trace from opentelemetry.propagators import extract, inject from opentelemetry.trace import get_current_span, set_span_in_context +from opentelemetry.trace.propagation.textmap import DictGetter - -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - value = dict_object.get(key) - return value if value is not None else [] +carrier_getter = DictGetter() class TestDefaultGlobalPropagator(unittest.TestCase): @@ -44,7 +40,7 @@ def test_propagation(self): "traceparent": [traceparent_value], "tracestate": [tracestate_value], } - ctx = extract(get_as_list, headers) + ctx = extract(carrier_getter, headers) baggage_entries = baggage.get_all(context=ctx) expected = {"key1": "val1", "key2": "val2"} self.assertEqual(baggage_entries, expected) diff --git a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py index 295e3971a39..f012be2a233 100644 --- a/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py +++ b/opentelemetry-api/tests/trace/propagation/test_tracecontexthttptextformat.py @@ -17,15 +17,11 @@ from opentelemetry import trace from opentelemetry.trace.propagation import tracecontext +from opentelemetry.trace.propagation.textmap import DictGetter FORMAT = tracecontext.TraceContextTextMapPropagator() - -def get_as_list( - dict_object: typing.Dict[str, typing.List[str]], key: str -) -> typing.List[str]: - value = dict_object.get(key) - return value if value is not None else [] +carrier_getter = DictGetter() class TestTraceContextFormat(unittest.TestCase): @@ -42,7 +38,7 @@ def test_no_traceparent_header(self): trace-id and parent-id that represents the current request. """ output = {} # type:typing.Dict[str, typing.List[str]] - span = trace.get_current_span(FORMAT.extract(get_as_list, output)) + span = trace.get_current_span(FORMAT.extract(carrier_getter, output)) self.assertIsInstance(span.get_span_context(), trace.SpanContext) def test_headers_with_tracestate(self): @@ -56,7 +52,7 @@ def test_headers_with_tracestate(self): tracestate_value = "foo=1,bar=2,baz=3" span_context = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [traceparent_value], "tracestate": [tracestate_value], @@ -100,7 +96,7 @@ def test_invalid_trace_id(self): """ span = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [ "00-00000000000000000000000000000000-1234567890123456-00" @@ -131,7 +127,7 @@ def test_invalid_parent_id(self): """ span = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [ "00-00000000000000000000000000000000-0000000000000000-00" @@ -169,7 +165,7 @@ def test_format_not_supported(self): """ span = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-" @@ -193,7 +189,7 @@ def test_tracestate_empty_header(self): """ span = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-1234567890123456-00" @@ -209,7 +205,7 @@ def test_tracestate_header_with_trailing_comma(self): """ span = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-1234567890123456-00" @@ -233,7 +229,7 @@ def test_tracestate_keys(self): ) span = trace.get_current_span( FORMAT.extract( - get_as_list, + carrier_getter, { "traceparent": [ "00-12345678901234567890123456789012-1234567890123456-00" diff --git a/opentelemetry-sdk/CHANGELOG.md b/opentelemetry-sdk/CHANGELOG.md index 2c1c3d971d0..d1aa2261148 100644 --- a/opentelemetry-sdk/CHANGELOG.md +++ b/opentelemetry-sdk/CHANGELOG.md @@ -43,6 +43,8 @@ Released 2020-10-13 ([#1209](https://github.com/open-telemetry/opentelemetry-python/pull/1209)) - Parent is now always passed in via Context, intead of Span or SpanContext ([#1146](https://github.com/open-telemetry/opentelemetry-python/pull/1146)) +- Add keys method to TextMap propagator Getter + ([#1196](https://github.com/open-telemetry/opentelemetry-python/issues/1196)) ## Version 0.13b0 diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py index 8a6b8e2247c..c629d107d36 100644 --- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py +++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/propagation/b3_format.py @@ -43,7 +43,7 @@ class B3Format(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + getter: Getter[TextMapPropagatorT], carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -53,7 +53,7 @@ def extract( flags = None single_header = _extract_first_element( - get_from_carrier(carrier, self.SINGLE_HEADER_KEY) + getter.get(carrier, self.SINGLE_HEADER_KEY) ) if single_header: # The b3 spec calls for the sampling state to be @@ -74,27 +74,19 @@ def extract( return trace.set_span_in_context(trace.INVALID_SPAN) else: trace_id = ( - _extract_first_element( - get_from_carrier(carrier, self.TRACE_ID_KEY) - ) + _extract_first_element(getter.get(carrier, self.TRACE_ID_KEY)) or trace_id ) span_id = ( - _extract_first_element( - get_from_carrier(carrier, self.SPAN_ID_KEY) - ) + _extract_first_element(getter.get(carrier, self.SPAN_ID_KEY)) or span_id ) sampled = ( - _extract_first_element( - get_from_carrier(carrier, self.SAMPLED_KEY) - ) + _extract_first_element(getter.get(carrier, self.SAMPLED_KEY)) or sampled ) flags = ( - _extract_first_element( - get_from_carrier(carrier, self.FLAGS_KEY) - ) + _extract_first_element(getter.get(carrier, self.FLAGS_KEY)) or flags ) diff --git a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py index a788a812c76..79c4618aee7 100644 --- a/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py +++ b/opentelemetry-sdk/tests/trace/propagation/test_b3_format.py @@ -19,18 +19,17 @@ import opentelemetry.sdk.trace.propagation.b3_format as b3_format import opentelemetry.trace as trace_api from opentelemetry.context import get_current +from opentelemetry.trace.propagation.textmap import DictGetter FORMAT = b3_format.B3Format() -def get_as_list(dict_object, key): - value = dict_object.get(key) - return [value] if value is not None else [] +carrier_getter = DictGetter() def get_child_parent_new_carrier(old_carrier): - ctx = FORMAT.extract(get_as_list, old_carrier) + ctx = FORMAT.extract(carrier_getter, old_carrier) parent_span_context = trace_api.get_current_span(ctx).get_span_context() parent = trace._Span("parent", parent_span_context) @@ -231,7 +230,7 @@ def test_invalid_single_header(self): invalid SpanContext. """ carrier = {FORMAT.SINGLE_HEADER_KEY: "0-1-2-3-4-5-6-7"} - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -243,7 +242,7 @@ def test_missing_trace_id(self): FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID) @@ -267,7 +266,7 @@ def test_invalid_trace_id( FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) @@ -293,7 +292,7 @@ def test_invalid_span_id( FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.trace_id, 1) @@ -306,7 +305,7 @@ def test_missing_span_id(self): FORMAT.FLAGS_KEY: "1", } - ctx = FORMAT.extract(get_as_list, carrier) + ctx = FORMAT.extract(carrier_getter, carrier) span_context = trace_api.get_current_span(ctx).get_span_context() self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID) @@ -321,11 +320,12 @@ def test_inject_empty_context(): def test_default_span(): """Make sure propagator does not crash when working with DefaultSpan""" - def getter(carrier, key): - return carrier.get(key, None) + class CarrierGetter(DictGetter): + def get(self, carrier, key): + return carrier.get(key, None) def setter(carrier, key, value): carrier[key] = value - ctx = FORMAT.extract(getter, {}) + ctx = FORMAT.extract(CarrierGetter(), {}) FORMAT.inject(setter, {}, ctx) diff --git a/tests/util/src/opentelemetry/test/mock_textmap.py b/tests/util/src/opentelemetry/test/mock_textmap.py index bf46ec32fa1..1fee6d2a9ca 100644 --- a/tests/util/src/opentelemetry/test/mock_textmap.py +++ b/tests/util/src/opentelemetry/test/mock_textmap.py @@ -33,7 +33,7 @@ class NOOPTextMapPropagator(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + getter: Getter[TextMapPropagatorT], carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: @@ -56,12 +56,12 @@ class MockTextMapPropagator(TextMapPropagator): def extract( self, - get_from_carrier: Getter[TextMapPropagatorT], + getter: Getter[TextMapPropagatorT], carrier: TextMapPropagatorT, context: typing.Optional[Context] = None, ) -> Context: - trace_id_list = get_from_carrier(carrier, self.TRACE_ID_KEY) - span_id_list = get_from_carrier(carrier, self.SPAN_ID_KEY) + trace_id_list = getter.get(carrier, self.TRACE_ID_KEY) + span_id_list = getter.get(carrier, self.SPAN_ID_KEY) if not trace_id_list or not span_id_list: return trace.set_span_in_context(trace.INVALID_SPAN)