Skip to content

Commit c8b6de6

Browse files
authored
Add support for SQLAlchemy 1.4 (#568)
1 parent d671a10 commit c8b6de6

File tree

6 files changed

+123
-48
lines changed

6 files changed

+123
-48
lines changed

Diff for: CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6666
([#563](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/563))
6767
- `opentelemetry-exporter-datadog` Datadog exporter should not use `unknown_service` as fallback resource service name.
6868
([#570](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/570))
69+
- Add support for the async extension of SQLAlchemy (>= 1.4)
70+
([#568](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/568))
6971

7072
### Added
7173
- `opentelemetry-instrumentation-httpx` Add `httpx` instrumentation

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py

+23
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,32 @@
3636
engine=engine,
3737
)
3838
39+
# of the async variant of SQLAlchemy
40+
41+
from sqlalchemy.ext.asyncio import create_async_engine
42+
43+
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
44+
import sqlalchemy
45+
46+
engine = create_async_engine("sqlite:///:memory:")
47+
SQLAlchemyInstrumentor().instrument(
48+
engine=engine.sync_engine
49+
)
50+
3951
API
4052
---
4153
"""
4254
from typing import Collection
4355

4456
import sqlalchemy
57+
from packaging.version import parse as parse_version
4558
from wrapt import wrap_function_wrapper as _w
4659

4760
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
4861
from opentelemetry.instrumentation.sqlalchemy.engine import (
4962
EngineTracer,
5063
_get_tracer,
64+
_wrap_create_async_engine,
5165
_wrap_create_engine,
5266
)
5367
from opentelemetry.instrumentation.sqlalchemy.package import _instruments
@@ -76,6 +90,13 @@ def _instrument(self, **kwargs):
7690
"""
7791
_w("sqlalchemy", "create_engine", _wrap_create_engine)
7892
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
93+
if parse_version(sqlalchemy.__version__).release >= (1, 4):
94+
_w(
95+
"sqlalchemy.ext.asyncio",
96+
"create_async_engine",
97+
_wrap_create_async_engine,
98+
)
99+
79100
if kwargs.get("engine") is not None:
80101
return EngineTracer(
81102
_get_tracer(
@@ -88,3 +109,5 @@ def _instrument(self, **kwargs):
88109
def _uninstrument(self, **kwargs):
89110
unwrap(sqlalchemy, "create_engine")
90111
unwrap(sqlalchemy.engine, "create_engine")
112+
if parse_version(sqlalchemy.__version__).release >= (1, 4):
113+
unwrap(sqlalchemy.ext.asyncio, "create_async_engine")

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py

+34-41
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from threading import local
16-
1715
from sqlalchemy.event import listen # pylint: disable=no-name-in-module
1816

1917
from opentelemetry import trace
@@ -44,6 +42,16 @@ def _get_tracer(engine, tracer_provider=None):
4442
)
4543

4644

45+
# pylint: disable=unused-argument
46+
def _wrap_create_async_engine(func, module, args, kwargs):
47+
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
48+
object that will listen to SQLAlchemy events.
49+
"""
50+
engine = func(*args, **kwargs)
51+
EngineTracer(_get_tracer(engine), engine.sync_engine)
52+
return engine
53+
54+
4755
# pylint: disable=unused-argument
4856
def _wrap_create_engine(func, module, args, kwargs):
4957
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
@@ -59,20 +67,10 @@ def __init__(self, tracer, engine):
5967
self.tracer = tracer
6068
self.engine = engine
6169
self.vendor = _normalize_vendor(engine.name)
62-
self.cursor_mapping = {}
63-
self.local = local()
6470

6571
listen(engine, "before_cursor_execute", self._before_cur_exec)
66-
listen(engine, "after_cursor_execute", self._after_cur_exec)
67-
listen(engine, "handle_error", self._handle_error)
68-
69-
@property
70-
def current_thread_span(self):
71-
return getattr(self.local, "current_span", None)
72-
73-
@current_thread_span.setter
74-
def current_thread_span(self, span):
75-
setattr(self.local, "current_span", span)
72+
listen(engine, "after_cursor_execute", _after_cur_exec)
73+
listen(engine, "handle_error", _handle_error)
7674

7775
def _operation_name(self, db_name, statement):
7876
parts = []
@@ -90,7 +88,9 @@ def _operation_name(self, db_name, statement):
9088
return " ".join(parts)
9189

9290
# pylint: disable=unused-argument
93-
def _before_cur_exec(self, conn, cursor, statement, *args):
91+
def _before_cur_exec(
92+
self, conn, cursor, statement, params, context, executemany
93+
):
9494
attrs, found = _get_attributes_from_url(conn.engine.url)
9595
if not found:
9696
attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs)
@@ -100,42 +100,35 @@ def _before_cur_exec(self, conn, cursor, statement, *args):
100100
self._operation_name(db_name, statement),
101101
kind=trace.SpanKind.CLIENT,
102102
)
103-
self.current_thread_span = self.cursor_mapping[cursor] = span
104103
with trace.use_span(span, end_on_exit=False):
105104
if span.is_recording():
106105
span.set_attribute(SpanAttributes.DB_STATEMENT, statement)
107106
span.set_attribute(SpanAttributes.DB_SYSTEM, self.vendor)
108107
for key, value in attrs.items():
109108
span.set_attribute(key, value)
110109

111-
# pylint: disable=unused-argument
112-
def _after_cur_exec(self, conn, cursor, statement, *args):
113-
span = self.cursor_mapping.get(cursor, None)
114-
if span is None:
115-
return
110+
context._otel_span = span
116111

117-
span.end()
118-
self._cleanup(cursor)
119112

120-
def _handle_error(self, context):
121-
span = self.current_thread_span
122-
if span is None:
123-
return
113+
# pylint: disable=unused-argument
114+
def _after_cur_exec(conn, cursor, statement, params, context, executemany):
115+
span = getattr(context, "_otel_span", None)
116+
if span is None:
117+
return
124118

125-
try:
126-
if span.is_recording():
127-
span.set_status(
128-
Status(StatusCode.ERROR, str(context.original_exception),)
129-
)
130-
finally:
131-
span.end()
132-
self._cleanup(context.cursor)
133-
134-
def _cleanup(self, cursor):
135-
try:
136-
del self.cursor_mapping[cursor]
137-
except KeyError:
138-
pass
119+
span.end()
120+
121+
122+
def _handle_error(context):
123+
span = getattr(context.execution_context, "_otel_span", None)
124+
if span is None:
125+
return
126+
127+
if span.is_recording():
128+
span.set_status(
129+
Status(StatusCode.ERROR, str(context.original_exception),)
130+
)
131+
span.end()
139132

140133

141134
def _get_attributes_from_url(url):

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py

+47
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,11 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
1415
from unittest import mock
1516

17+
import pytest
18+
import sqlalchemy
1619
from sqlalchemy import create_engine
1720

1821
from opentelemetry import trace
@@ -38,6 +41,29 @@ def test_trace_integration(self):
3841
self.assertEqual(spans[0].name, "SELECT :memory:")
3942
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
4043

44+
@pytest.mark.skipif(
45+
not sqlalchemy.__version__.startswith("1.4"),
46+
reason="only run async tests for 1.4",
47+
)
48+
def test_async_trace_integration(self):
49+
async def run():
50+
from sqlalchemy.ext.asyncio import ( # pylint: disable-all
51+
create_async_engine,
52+
)
53+
54+
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
55+
SQLAlchemyInstrumentor().instrument(
56+
engine=engine.sync_engine, tracer_provider=self.tracer_provider
57+
)
58+
async with engine.connect() as cnx:
59+
await cnx.execute(sqlalchemy.text("SELECT 1 + 1;"))
60+
spans = self.memory_exporter.get_finished_spans()
61+
self.assertEqual(len(spans), 1)
62+
self.assertEqual(spans[0].name, "SELECT :memory:")
63+
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
64+
65+
asyncio.get_event_loop().run_until_complete(run())
66+
4167
def test_not_recording(self):
4268
mock_tracer = mock.Mock()
4369
mock_span = mock.Mock()
@@ -68,3 +94,24 @@ def test_create_engine_wrapper(self):
6894
self.assertEqual(len(spans), 1)
6995
self.assertEqual(spans[0].name, "SELECT :memory:")
7096
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
97+
98+
@pytest.mark.skipif(
99+
not sqlalchemy.__version__.startswith("1.4"),
100+
reason="only run async tests for 1.4",
101+
)
102+
def test_create_async_engine_wrapper(self):
103+
async def run():
104+
SQLAlchemyInstrumentor().instrument()
105+
from sqlalchemy.ext.asyncio import ( # pylint: disable-all
106+
create_async_engine,
107+
)
108+
109+
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
110+
async with engine.connect() as cnx:
111+
await cnx.execute(sqlalchemy.text("SELECT 1 + 1;"))
112+
spans = self.memory_exporter.get_finished_spans()
113+
self.assertEqual(len(spans), 1)
114+
self.assertEqual(spans[0].name, "SELECT :memory:")
115+
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
116+
117+
asyncio.get_event_loop().run_until_complete(run())

Diff for: tests/opentelemetry-docker-tests/tests/sqlalchemy_tests/mixins.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import contextlib
1616
import logging
1717
import threading
18+
import unittest
1819

1920
from sqlalchemy import Column, Integer, String, create_engine, insert
2021
from sqlalchemy.ext.declarative import declarative_base
@@ -242,4 +243,10 @@ def insert_players(session):
242243
close_all_sessions()
243244

244245
spans = self.memory_exporter.get_finished_spans()
245-
self.assertEqual(len(spans), 5)
246+
247+
# SQLAlchemy 1.4 uses the `execute_values` extension of the psycopg2 dialect to
248+
# batch inserts together which means `insert_players` only generates one span.
249+
# See https://docs.sqlalchemy.org/en/14/changelog/migration_14.html#orm-batch-inserts-with-psycopg2-now-batch-statements-with-returning-in-most-cases
250+
self.assertEqual(
251+
len(spans), 5 if self.VENDOR not in ["postgresql"] else 3
252+
)

Diff for: tox.ini

+9-6
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ envlist =
122122
py3{6,7,8,9}-test-instrumentation-grpc
123123

124124
; opentelemetry-instrumentation-sqlalchemy
125-
py3{6,7,8,9}-test-instrumentation-sqlalchemy
126-
pypy3-test-instrumentation-sqlalchemy
125+
py3{6,7,8,9}-test-instrumentation-sqlalchemy{11,14}
126+
pypy3-test-instrumentation-sqlalchemy{11,14}
127127

128128
; opentelemetry-instrumentation-redis
129129
py3{6,7,8,9}-test-instrumentation-redis
@@ -173,6 +173,9 @@ deps =
173173
elasticsearch6: elasticsearch>=6.0,<7.0
174174
elasticsearch7: elasticsearch-dsl>=7.0,<8.0
175175
elasticsearch7: elasticsearch>=7.0,<8.0
176+
sqlalchemy11: sqlalchemy>=1.1,<1.2
177+
sqlalchemy14: aiosqlite
178+
sqlalchemy14: sqlalchemy~=1.4
176179

177180
; FIXME: add coverage testing
178181
; FIXME: add mypy testing
@@ -205,7 +208,7 @@ changedir =
205208
test-instrumentation-redis: instrumentation/opentelemetry-instrumentation-redis/tests
206209
test-instrumentation-requests: instrumentation/opentelemetry-instrumentation-requests/tests
207210
test-instrumentation-sklearn: instrumentation/opentelemetry-instrumentation-sklearn/tests
208-
test-instrumentation-sqlalchemy: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests
211+
test-instrumentation-sqlalchemy{11,14}: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests
209212
test-instrumentation-sqlite3: instrumentation/opentelemetry-instrumentation-sqlite3/tests
210213
test-instrumentation-starlette: instrumentation/opentelemetry-instrumentation-starlette/tests
211214
test-instrumentation-tornado: instrumentation/opentelemetry-instrumentation-tornado/tests
@@ -290,7 +293,7 @@ commands_pre =
290293

291294
sklearn: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sklearn[test]
292295

293-
sqlalchemy: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sqlalchemy[test]
296+
sqlalchemy{11,14}: pip install {toxinidir}/instrumentation/opentelemetry-instrumentation-sqlalchemy[test]
294297

295298
elasticsearch{2,5,6,7}: pip install {toxinidir}/opentelemetry-python-core/opentelemetry-instrumentation {toxinidir}/instrumentation/opentelemetry-instrumentation-elasticsearch[test]
296299

@@ -329,7 +332,7 @@ commands =
329332

330333
[testenv:lint]
331334
basepython: python3.9
332-
recreate = False
335+
recreate = False
333336
deps =
334337
-c dev-requirements.txt
335338
flaky
@@ -399,7 +402,7 @@ deps =
399402
PyMySQL ~= 0.10.1
400403
psycopg2 ~= 2.8.4
401404
aiopg >= 0.13.0, < 1.3.0
402-
sqlalchemy ~= 1.3.16
405+
sqlalchemy ~= 1.4
403406
redis ~= 3.3.11
404407
celery[pytest] >= 4.0, < 6.0
405408
protobuf>=3.13.0

0 commit comments

Comments
 (0)