Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync with Make setters and getters optional #372

Merged
merged 28 commits into from
Mar 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
- 'release/*'
pull_request:
env:
CORE_REPO_SHA: d3694fc520f8542b232fd1065133286f4591dcec
CORE_REPO_SHA: 94d9c1eb77723123a779ddd18d2b12d4bec445fb

jobs:
build:
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python-contrib/compare/v0.18b0...HEAD)
- Make getters and setters optional
([#372](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/372))
- Updated instrumentations to use `opentelemetry.trace.use_span` instead of `Tracer.use_span()`
([#364](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/364))
- `opentelemetry-propagator-ot-trace` Do not throw an exception when headers are not present
Expand Down
10 changes: 8 additions & 2 deletions docs/nitpick-exceptions.ini
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
[default]
class_references=
; TODO: Understand why sphinx is not able to find this local class
opentelemetry.propagators.textmap.CarrierT
opentelemetry.propagators.textmap.Setter
opentelemetry.propagators.textmap.Getter
opentelemetry.propagators.textmap.TextMapPropagator
; - AwsXRayFormat
opentelemetry.propagators.textmap.DictGetter
opentelemetry.propagators.textmap.DefaultGetter
; API
opentelemetry.propagators.textmap.Getter
; - DatadogFormat
; - AWSXRayFormat
opentelemetry.sdk.trace.id_generator.IdGenerator
; - AwsXRayIdGenerator
TextMapPropagatorT
TextMapPropagator
CarrierT
Setter
Getter
; - AwsXRayFormat.extract

anys=
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from opentelemetry.context import Context
from opentelemetry.exporter.datadog import constants
from opentelemetry.propagators.textmap import (
CarrierT,
Getter,
Setter,
TextMapPropagator,
TextMapPropagatorT,
default_getter,
default_setter,
)
from opentelemetry.trace import get_current_span, set_span_in_context

Expand All @@ -36,9 +38,9 @@ class DatadogFormat(TextMapPropagator):

def extract(
self,
getter: Getter[TextMapPropagatorT],
carrier: TextMapPropagatorT,
carrier: CarrierT,
context: typing.Optional[Context] = None,
getter: Getter = default_getter,
) -> Context:
trace_id = extract_first_element(
getter.get(carrier, self.TRACE_ID_KEY)
Expand Down Expand Up @@ -81,28 +83,28 @@ def extract(

def inject(
self,
set_in_carrier: Setter[TextMapPropagatorT],
carrier: TextMapPropagatorT,
carrier: CarrierT,
context: typing.Optional[Context] = None,
setter: Setter = default_setter,
) -> None:
span = get_current_span(context)
span_context = span.get_span_context()
if span_context == trace.INVALID_SPAN_CONTEXT:
return
sampled = (trace.TraceFlags.SAMPLED & span.context.trace_flags) != 0
set_in_carrier(
setter.set(
carrier, self.TRACE_ID_KEY, format_trace_id(span.context.trace_id),
)
set_in_carrier(
setter.set(
carrier, self.PARENT_ID_KEY, format_span_id(span.context.span_id)
)
set_in_carrier(
setter.set(
carrier,
self.SAMPLING_PRIORITY_KEY,
str(constants.AUTO_KEEP if sampled else constants.AUTO_REJECT),
)
if constants.DD_ORIGIN in span.context.trace_state:
set_in_carrier(
setter.set(
carrier,
self.ORIGIN_KEY,
span.context.trace_state[constants.DD_ORIGIN],
Expand Down Expand Up @@ -134,8 +136,8 @@ def format_span_id(span_id: int) -> str:


def extract_first_element(
items: typing.Iterable[TextMapPropagatorT],
) -> typing.Optional[TextMapPropagatorT]:
items: typing.Iterable[CarrierT],
) -> typing.Optional[CarrierT]:
if items is None:
return None
return next(iter(items), None)
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

from opentelemetry import trace as trace_api
from opentelemetry.exporter.datadog import constants, propagator
from opentelemetry.propagators.textmap import DictGetter
from opentelemetry.sdk import trace
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
from opentelemetry.trace import get_current_span, set_span_in_context

FORMAT = propagator.DatadogFormat()

carrier_getter = DictGetter()


class TestDatadogFormat(unittest.TestCase):
@classmethod
Expand All @@ -45,7 +42,6 @@ def test_malformed_headers(self):
malformed_parent_id_key = FORMAT.PARENT_ID_KEY + "-x"
context = get_current_span(
FORMAT.extract(
carrier_getter,
{
malformed_trace_id_key: self.serialized_trace_id,
malformed_parent_id_key: self.serialized_parent_id,
Expand All @@ -63,7 +59,7 @@ def test_missing_trace_id(self):
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
}

ctx = FORMAT.extract(carrier_getter, carrier)
ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.trace_id, trace_api.INVALID_TRACE_ID)

Expand All @@ -73,15 +69,14 @@ def test_missing_parent_id(self):
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
}

ctx = FORMAT.extract(carrier_getter, carrier)
ctx = FORMAT.extract(carrier)
span_context = get_current_span(ctx).get_span_context()
self.assertEqual(span_context.span_id, trace_api.INVALID_SPAN_ID)

def test_context_propagation(self):
"""Test the propagation of Datadog headers."""
parent_span_context = get_current_span(
FORMAT.extract(
carrier_getter,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
Expand Down Expand Up @@ -118,7 +113,7 @@ def test_context_propagation(self):

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)
FORMAT.inject(child_carrier, context=child_context)

self.assertEqual(
child_carrier[FORMAT.TRACE_ID_KEY], self.serialized_trace_id
Expand All @@ -138,7 +133,6 @@ def test_sampling_priority_auto_reject(self):
"""Test sampling priority rejected."""
parent_span_context = get_current_span(
FORMAT.extract(
carrier_getter,
{
FORMAT.TRACE_ID_KEY: self.serialized_trace_id,
FORMAT.PARENT_ID_KEY: self.serialized_parent_id,
Expand All @@ -165,7 +159,7 @@ def test_sampling_priority_auto_reject(self):

child_carrier = {}
child_context = set_span_in_context(child)
FORMAT.inject(dict.__setitem__, child_carrier, context=child_context)
FORMAT.inject(child_carrier, context=child_context)

self.assertEqual(
child_carrier[FORMAT.SAMPLING_PRIORITY_KEY],
Expand All @@ -178,7 +172,7 @@ def test_fields(self, mock_get_current_span):

tracer = trace.TracerProvider().get_tracer("sdk_tracer_provider")

mock_set_in_carrier = Mock()
mock_setter = Mock()

mock_get_current_span.configure_mock(
**{
Expand All @@ -195,11 +189,11 @@ def test_fields(self, mock_get_current_span):

with tracer.start_as_current_span("parent"):
with tracer.start_as_current_span("child"):
FORMAT.inject(mock_set_in_carrier, {})
FORMAT.inject({}, setter=mock_setter)

inject_fields = set()

for call in mock_set_in_carrier.mock_calls:
for call in mock_setter.mock_calls:
inject_fields.add(call[1][1])

self.assertEqual(FORMAT.fields, inject_fields)
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ async def on_request_start(
trace.set_span_in_context(trace_config_ctx.span)
)

inject(type(params.headers).__setitem__, params.headers)
inject(params.headers)

async def on_request_end(
unused_session: aiohttp.ClientSession,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
from opentelemetry.instrumentation.asgi.version import __version__ # noqa
from opentelemetry.instrumentation.utils import http_status_to_status_code
from opentelemetry.propagate import extract
from opentelemetry.propagators.textmap import DictGetter
from opentelemetry.propagators.textmap import Getter
from opentelemetry.trace.status import Status, StatusCode


class CarrierGetter(DictGetter):
class ASGIGetter(Getter):
def get(
self, carrier: dict, key: str
) -> typing.Optional[typing.List[str]]:
Expand Down Expand Up @@ -62,8 +62,11 @@ def get(
return None
return decoded

def keys(self, carrier: dict) -> typing.List[str]:
return list(carrier.keys())

carrier_getter = CarrierGetter()

asgi_getter = ASGIGetter()


def collect_request_attributes(scope):
Expand All @@ -88,10 +91,10 @@ def collect_request_attributes(scope):
if http_method:
result["http.method"] = http_method

http_host_value_list = carrier_getter.get(scope, "host")
http_host_value_list = asgi_getter.get(scope, "host")
if http_host_value_list:
result["http.server_name"] = ",".join(http_host_value_list)
http_user_agent = carrier_getter.get(scope, "user-agent")
http_user_agent = asgi_getter.get(scope, "user-agent")
if http_user_agent:
result["http.user_agent"] = http_user_agent[0]

Expand Down Expand Up @@ -186,7 +189,7 @@ async def __call__(self, scope, receive, send):
if self.excluded_urls and self.excluded_urls.url_disabled(url):
return await self.app(scope, receive, send)

token = context.attach(extract(carrier_getter, scope))
token = context.attach(extract(scope, getter=asgi_getter))
span_name, additional_attributes = self.span_details_callback(scope)

try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,18 @@

from unittest import TestCase

from opentelemetry.instrumentation.asgi import CarrierGetter
from opentelemetry.instrumentation.asgi import ASGIGetter


class TestCarrierGetter(TestCase):
class TestASGIGetter(TestCase):
def test_get_none(self):
getter = CarrierGetter()
getter = ASGIGetter()
carrier = {}
val = getter.get(carrier, "test")
self.assertIsNone(val)

def test_get_(self):
getter = CarrierGetter()
getter = ASGIGetter()
carrier = {"headers": [(b"test-key", b"val")]}
expected_val = ["val"]
self.assertEqual(
Expand All @@ -45,6 +45,6 @@ def test_get_(self):
)

def test_keys(self):
getter = CarrierGetter()
getter = ASGIGetter()
keys = getter.keys({})
self.assertEqual(keys, [])
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
def _patched_endpoint_prepare_request(wrapped, instance, args, kwargs):
request = args[0]
headers = request.headers
inject(type(headers).__setitem__, headers)
inject(headers)
return wrapped(*args, **kwargs)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def add(x, y):
from opentelemetry.instrumentation.celery.version import __version__
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.propagate import extract, inject
from opentelemetry.propagators.textmap import DictGetter
from opentelemetry.propagators.textmap import Getter
from opentelemetry.trace.status import Status, StatusCode

logger = logging.getLogger(__name__)
Expand All @@ -78,7 +78,7 @@ def add(x, y):
_MESSAGE_ID_ATTRIBUTE_NAME = "messaging.message_id"


class CarrierGetter(DictGetter):
class CeleryGetter(Getter):
def get(self, carrier, key):
value = getattr(carrier, key, None)
if value is None:
Expand All @@ -91,7 +91,7 @@ def keys(self, carrier):
return []


carrier_getter = CarrierGetter()
celery_getter = CeleryGetter()


class CeleryInstrumentor(BaseInstrumentor):
Expand Down Expand Up @@ -128,7 +128,7 @@ def _trace_prerun(self, *args, **kwargs):
return

request = task.request
tracectx = extract(carrier_getter, request) or None
tracectx = extract(request, getter=celery_getter) or None

logger.debug("prerun signal start task_id=%s", task_id)

Expand Down Expand Up @@ -193,7 +193,7 @@ def _trace_before_publish(self, *args, **kwargs):

headers = kwargs.get("headers")
if headers:
inject(type(headers).__setitem__, headers)
inject(headers)

@staticmethod
def _trace_after_publish(*args, **kwargs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,31 @@

from unittest import TestCase, mock

from opentelemetry.instrumentation.celery import CarrierGetter
from opentelemetry.instrumentation.celery import CeleryGetter


class TestCarrierGetter(TestCase):
class TestCeleryGetter(TestCase):
def test_get_none(self):
getter = CarrierGetter()
getter = CeleryGetter()
carrier = {}
val = getter.get(carrier, "test")
self.assertIsNone(val)

def test_get_str(self):
mock_obj = mock.Mock()
getter = CarrierGetter()
getter = CeleryGetter()
mock_obj.test = "val"
val = getter.get(mock_obj, "test")
self.assertEqual(val, ("val",))

def test_get_iter(self):
mock_obj = mock.Mock()
getter = CarrierGetter()
getter = CeleryGetter()
mock_obj.test = ["val"]
val = getter.get(mock_obj, "test")
self.assertEqual(val, ["val"])

def test_keys(self):
getter = CarrierGetter()
getter = CeleryGetter()
keys = getter.keys({})
self.assertEqual(keys, [])
Loading