Skip to content

Bugfix: Pika basicConsume context propagation #766

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import wrapt
from pika.adapters import BlockingConnection
from pika.channel import Channel
from pika.adapters.blocking_connection import BlockingChannel

from opentelemetry import trace
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
Expand All @@ -35,18 +35,25 @@
class PikaInstrumentor(BaseInstrumentor): # type: ignore
# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_consumers(
consumers_dict: Dict[str, Callable[..., Any]], tracer: Tracer
def _instrument_blocking_channel_consumers(
channel: BlockingChannel, tracer: Tracer
) -> Any:
for key, callback in consumers_dict.items():
for consumer_tag, consumer_info in channel._consumer_infos.items():
decorated_callback = utils._decorate_callback(
callback, tracer, key
consumer_info.on_message_callback, tracer, consumer_tag
)
setattr(decorated_callback, "_original_callback", callback)
consumers_dict[key] = decorated_callback

setattr(
decorated_callback,
"_original_callback",
consumer_info.on_message_callback,
)
consumer_info.on_message_callback = decorated_callback

@staticmethod
def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:
def _instrument_basic_publish(
channel: BlockingChannel, tracer: Tracer
) -> None:
original_function = getattr(channel, "basic_publish")
decorated_function = utils._decorate_basic_publish(
original_function, channel, tracer
Expand All @@ -57,13 +64,13 @@ def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:

@staticmethod
def _instrument_channel_functions(
channel: Channel, tracer: Tracer
channel: BlockingChannel, tracer: Tracer
) -> None:
if hasattr(channel, "basic_publish"):
PikaInstrumentor._instrument_basic_publish(channel, tracer)

@staticmethod
def _uninstrument_channel_functions(channel: Channel) -> None:
def _uninstrument_channel_functions(channel: BlockingChannel) -> None:
for function_name in _FUNCTIONS_TO_UNINSTRUMENT:
if not hasattr(channel, function_name):
continue
Expand All @@ -73,8 +80,10 @@ def _uninstrument_channel_functions(channel: Channel) -> None:
unwrap(channel, "basic_consume")

@staticmethod
# Make sure that the spans are created inside hash them set as parent and not as brothers
def instrument_channel(
channel: Channel, tracer_provider: Optional[TracerProvider] = None,
channel: BlockingChannel,
tracer_provider: Optional[TracerProvider] = None,
) -> None:
if not hasattr(channel, "_is_instrumented_by_opentelemetry"):
channel._is_instrumented_by_opentelemetry = False
Expand All @@ -84,18 +93,14 @@ def instrument_channel(
)
return
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
if channel._impl._consumers:
PikaInstrumentor._instrument_consumers(
channel._impl._consumers, tracer
)
PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
)
PikaInstrumentor._decorate_basic_consume(channel, tracer)
PikaInstrumentor._instrument_channel_functions(channel, tracer)

@staticmethod
def uninstrument_channel(channel: Channel) -> None:
def uninstrument_channel(channel: BlockingChannel) -> None:
if (
not hasattr(channel, "_is_instrumented_by_opentelemetry")
or not channel._is_instrumented_by_opentelemetry
Expand All @@ -104,12 +109,12 @@ def uninstrument_channel(channel: Channel) -> None:
"Attempting to uninstrument Pika channel while already uninstrumented!"
)
return
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
for key, callback in channel._impl._consumers.items():
if hasattr(callback, "_original_callback"):
channel._impl._consumers[key] = callback._original_callback

for consumers_tag, client_info in channel._consumer_infos.items():
if hasattr(client_info.on_message_callback, "_original_callback"):
channel._consumer_infos[
consumers_tag
] = client_info.on_message_callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel)

def _decorate_channel_function(
Expand All @@ -123,28 +128,15 @@ def wrapper(wrapped, instance, args, kwargs):
wrapt.wrap_function_wrapper(BlockingConnection, "channel", wrapper)

@staticmethod
def _decorate_basic_consume(channel, tracer: Optional[Tracer]) -> None:
def _decorate_basic_consume(
channel: BlockingChannel, tracer: Optional[Tracer]
) -> None:
def wrapper(wrapped, instance, args, kwargs):
if not hasattr(channel, "_impl"):
_LOG.error(
"Could not find implementation for provided channel!"
)
return wrapped(*args, **kwargs)
current_keys = set(channel._impl._consumers.keys())
return_value = wrapped(*args, **kwargs)
new_key_list = list(
set(channel._impl._consumers.keys()) - current_keys
)
if not new_key_list:
_LOG.error("Could not find added callback")
return return_value
new_key = new_key_list[0]
callback = channel._impl._consumers[new_key]
decorated_callback = utils._decorate_callback(
callback, tracer, new_key

PikaInstrumentor._instrument_blocking_channel_consumers(
channel, tracer
)
setattr(decorated_callback, "_original_callback", callback)
channel._impl._consumers[new_key] = decorated_callback
return return_value

wrapt.wrap_function_wrapper(channel, "basic_consume", wrapper)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,18 @@ def decorated_callback(
ctx = propagate.extract(properties.headers, getter=_pika_getter)
if not ctx:
ctx = context.get_current()
token = context.attach(ctx)
span = _get_span(
tracer,
channel,
properties,
span_kind=SpanKind.CONSUMER,
task_name=task_name,
ctx=ctx,
operation=MessagingOperationValues.RECEIVE,
)
with trace.use_span(span, end_on_exit=True):
retval = callback(channel, method, properties, body)
context.detach(token)
return retval

return decorated_callback
Expand All @@ -78,14 +79,12 @@ def decorated_function(
properties = BasicProperties(headers={})
if properties.headers is None:
properties.headers = {}
ctx = context.get_current()
span = _get_span(
tracer,
channel,
properties,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
ctx=ctx,
operation=None,
)
if not span:
Expand All @@ -109,7 +108,6 @@ def _get_span(
properties: BasicProperties,
task_name: str,
span_kind: SpanKind,
ctx: context.Context,
operation: Optional[MessagingOperationValues] = None,
) -> Optional[Span]:
if context.get_value("suppress_instrumentation") or context.get_value(
Expand All @@ -118,9 +116,7 @@ def _get_span(
return None
task_name = properties.type if properties.type else task_name
span = tracer.start_span(
context=ctx,
name=_generate_span_name(task_name, operation),
kind=span_kind,
name=_generate_span_name("pika", operation), kind=span_kind,
)
if span.is_recording():
_enrich_span(span, channel, properties, task_name, operation)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
class TestPika(TestCase):
def setUp(self) -> None:
self.channel = mock.MagicMock(spec=Channel)
self.channel._impl = mock.MagicMock(spec=BaseConnection)
consumer_info = mock.MagicMock()
consumer_info.on_message_callback = mock.MagicMock()
self.channel._consumer_infos = {"consumer-tag": consumer_info}
self.mock_callback = mock.MagicMock()
self.channel._impl._consumers = {"mock_key": self.mock_callback}

def test_instrument_api(self) -> None:
instrumentation = PikaInstrumentor()
Expand All @@ -49,19 +50,19 @@ def test_instrument_api(self) -> None:
"opentelemetry.instrumentation.pika.PikaInstrumentor._decorate_basic_consume"
)
@mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_consumers"
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_blocking_channel_consumers"
)
def test_instrument(
self,
instrument_consumers: mock.MagicMock,
instrument_blocking_channel_consumers: mock.MagicMock,
instrument_basic_consume: mock.MagicMock,
instrument_channel_functions: mock.MagicMock,
):
PikaInstrumentor.instrument_channel(channel=self.channel)
assert hasattr(
self.channel, "_is_instrumented_by_opentelemetry"
), "channel is not marked as instrumented!"
instrument_consumers.assert_called_once()
instrument_blocking_channel_consumers.assert_called_once()
instrument_basic_consume.assert_called_once()
instrument_channel_functions.assert_called_once()

Expand All @@ -71,18 +72,18 @@ def test_instrument_consumers(
) -> None:
tracer = mock.MagicMock(spec=Tracer)
expected_decoration_calls = [
mock.call(value, tracer, key)
for key, value in self.channel._impl._consumers.items()
mock.call(value.on_message_callback, tracer, key)
for key, value in self.channel._consumer_infos.items()
]
PikaInstrumentor._instrument_consumers(
self.channel._impl._consumers, tracer
PikaInstrumentor._instrument_blocking_channel_consumers(
self.channel, tracer
)
decorate_callback.assert_has_calls(
calls=expected_decoration_calls, any_order=True
)
assert all(
hasattr(callback, "_original_callback")
for callback in self.channel._impl._consumers.values()
for callback in self.channel._consumer_infos.values()
)

@mock.patch(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,10 @@ def test_get_span(
task_name = "test.test"
span_kind = mock.MagicMock(spec=SpanKind)
get_value.return_value = None
ctx = mock.MagicMock()
_ = utils._get_span(
tracer, channel, properties, task_name, span_kind, ctx
)
_ = utils._get_span(tracer, channel, properties, task_name, span_kind)
generate_span_name.assert_called_once()
tracer.start_span.assert_called_once_with(
context=ctx, name=generate_span_name.return_value, kind=span_kind
name=generate_span_name.return_value, kind=span_kind
)
enrich_span.assert_called_once()

Expand Down Expand Up @@ -200,7 +197,6 @@ def test_decorate_callback(
properties,
span_kind=SpanKind.CONSUMER,
task_name=mock_task_name,
ctx=extract.return_value,
operation=MessagingOperationValues.RECEIVE,
)
use_span.assert_called_once_with(
Expand All @@ -213,12 +209,10 @@ def test_decorate_callback(

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.context.get_current")
@mock.patch("opentelemetry.trace.use_span")
def test_decorate_basic_publish(
self,
use_span: mock.MagicMock,
get_current: mock.MagicMock,
inject: mock.MagicMock,
get_span: mock.MagicMock,
) -> None:
Expand All @@ -234,14 +228,12 @@ def test_decorate_basic_publish(
retval = decorated_basic_publish(
channel, method, mock_body, properties
)
get_current.assert_called_once()
get_span.assert_called_once_with(
tracer,
channel,
properties,
span_kind=SpanKind.PRODUCER,
task_name="(temporary)",
ctx=get_current.return_value,
operation=None,
)
use_span.assert_called_once_with(
Expand All @@ -256,14 +248,12 @@ def test_decorate_basic_publish(

@mock.patch("opentelemetry.instrumentation.pika.utils._get_span")
@mock.patch("opentelemetry.propagate.inject")
@mock.patch("opentelemetry.context.get_current")
@mock.patch("opentelemetry.trace.use_span")
@mock.patch("pika.spec.BasicProperties.__new__")
def test_decorate_basic_publish_no_properties(
self,
basic_properties: mock.MagicMock,
use_span: mock.MagicMock,
get_current: mock.MagicMock,
inject: mock.MagicMock,
get_span: mock.MagicMock,
) -> None:
Expand All @@ -277,7 +267,6 @@ def test_decorate_basic_publish_no_properties(
)
retval = decorated_basic_publish(channel, method, body=mock_body)
basic_properties.assert_called_once_with(BasicProperties, headers={})
get_current.assert_called_once()
use_span.assert_called_once_with(
get_span.return_value, end_on_exit=True
)
Expand Down