Skip to content

Support cursor based queries #2501

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ suggestion-mode=yes
# active Python interpreter and may run arbitrary code.
unsafe-load-any-extension=no

# Run python dependant checks considering the baseline version
py-version=3.8

[MESSAGES CONTROL]

Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
([#2573](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2573))
- `opentelemetry-instrumentation-confluent-kafka` Add support for version 2.4.0 of confluent_kafka
([#2616](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2616))
- `opentelemetry-instrumentation-asyncpg` Add instrumentation to cursor based queries
([#2501](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/2501))

### Breaking changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,27 @@ def _instrument(self, **kwargs):
"asyncpg.connection", method, self._do_execute
)

def _uninstrument(self, **__):
for method in [
"execute",
"executemany",
"fetch",
"fetchval",
"fetchrow",
"Cursor.fetch",
"Cursor.forward",
"Cursor.fetchrow",
"CursorIterator.__anext__",
]:
unwrap(asyncpg.Connection, method)
wrapt.wrap_function_wrapper(
"asyncpg.cursor", method, self._do_cursor_execute
)

def _uninstrument(self, **__):
for cls, methods in [
(
asyncpg.connection.Connection,
("execute", "executemany", "fetch", "fetchval", "fetchrow"),
),
(asyncpg.cursor.Cursor, ("forward", "fetch", "fetchrow")),
(asyncpg.cursor.CursorIterator, ("__anext__",)),
]:
for method_name in methods:
unwrap(cls, method_name)

async def _do_execute(self, func, instance, args, kwargs):
exception = None
Expand Down Expand Up @@ -170,3 +182,49 @@ async def _do_execute(self, func, instance, args, kwargs):
span.set_status(Status(StatusCode.ERROR))

return result

async def _do_cursor_execute(self, func, instance, args, kwargs):
"""Wrap cursor based functions. For every call this will generate a new span."""
exception = None
params = getattr(instance._connection, "_params", {})
name = (
instance._query
if instance._query
else params.get("database", "postgresql")
)

try:
# Strip leading comments so we get the operation name.
name = self._leading_comment_remover.sub("", name).split()[0]
except IndexError:
name = ""

stop = False
with self._tracer.start_as_current_span(
f"CURSOR: {name}",
kind=SpanKind.CLIENT,
) as span:
if span.is_recording():
span_attributes = _hydrate_span_from_args(
instance._connection,
instance._query,
instance._args if self.capture_parameters else None,
)
for attribute, value in span_attributes.items():
span.set_attribute(attribute, value)

try:
result = await func(*args, **kwargs)
except StopAsyncIteration:
# Do not show this exception to the span
stop = True
except Exception as exc: # pylint: disable=W0703
exception = exc
raise
finally:
if span.is_recording() and exception is not None:
span.set_status(Status(StatusCode.ERROR))

if not stop:
return result
raise StopAsyncIteration
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from asyncpg import Connection
import asyncio
from unittest import mock

import pytest
from asyncpg import Connection, Record, cursor
from wrapt import ObjectProxy

from opentelemetry.instrumentation.asyncpg import AsyncPGInstrumentor
from opentelemetry.test.test_base import TestBase
Expand Down Expand Up @@ -34,3 +39,69 @@ def test_duplicated_uninstrumentation(self):
self.assertFalse(
hasattr(method, "_opentelemetry_ext_asyncpg_applied")
)

def test_cursor_instrumentation(self):
def assert_wrapped(assert_fnc):
for cls, methods in [
(cursor.Cursor, ("forward", "fetch", "fetchrow")),
(cursor.CursorIterator, ("__anext__",)),
]:
for method_name in methods:
method = getattr(cls, method_name, None)
assert_fnc(
isinstance(method, ObjectProxy),
f"{method} isinstance {type(method)}",
)

assert_wrapped(self.assertFalse)
AsyncPGInstrumentor().instrument()
assert_wrapped(self.assertTrue)
AsyncPGInstrumentor().uninstrument()
assert_wrapped(self.assertFalse)

def test_cursor_span_creation(self):
"""Test the cursor wrapper if it creates spans correctly."""

# Mock out all interaction with postgres
async def bind_mock(*args, **kwargs):
return []

async def exec_mock(*args, **kwargs):
return [], None, True

conn = mock.Mock()
conn.is_closed = lambda: False

conn._protocol = mock.Mock()
conn._protocol.bind = bind_mock
conn._protocol.execute = exec_mock
conn._protocol.bind_execute = exec_mock
conn._protocol.close_portal = bind_mock

state = mock.Mock()
state.closed = False

apg = AsyncPGInstrumentor()
apg.instrument(tracer_provider=self.tracer_provider)

# init the cursor and fetch a single record
crs = cursor.Cursor(conn, "SELECT * FROM test", state, [], Record)
asyncio.run(crs._init(1))
asyncio.run(crs.fetch(1))

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 1)
self.assertEqual(spans[0].name, "CURSOR: SELECT")
self.assertTrue(spans[0].status.is_ok)

# Now test that the StopAsyncIteration of the cursor does not get recorded as an ERROR
crs_iter = cursor.CursorIterator(
conn, "SELECT * FROM test", state, [], Record, 1, 1
)

with pytest.raises(StopAsyncIteration):
asyncio.run(crs_iter.__anext__())

spans = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans), 2)
self.assertEqual([span.status.is_ok for span in spans], [True, True])
Loading