Skip to content

Commit ed5259e

Browse files
committed
respect provided tracer provider when instrumenting sqlalchemy
This change updates the SQLALchemyInstrumentor to respect the tracer provider that is passed in through the kwargs when patching the `create_engine` functionality provided by SQLAlchemy. Previously, it would default to the global tracer provider.
1 parent 224780f commit ed5259e

File tree

3 files changed

+51
-17
lines changed

3 files changed

+51
-17
lines changed

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ def _instrument(self, **kwargs):
8888
Returns:
8989
An instrumented engine if passed in as an argument, None otherwise.
9090
"""
91-
_w("sqlalchemy", "create_engine", _wrap_create_engine)
92-
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine)
91+
_w("sqlalchemy", "create_engine", _wrap_create_engine(kwargs))
92+
_w("sqlalchemy.engine", "create_engine", _wrap_create_engine(kwargs))
9393
if parse_version(sqlalchemy.__version__).release >= (1, 4):
9494
_w(
9595
"sqlalchemy.ext.asyncio",
9696
"create_async_engine",
97-
_wrap_create_async_engine,
97+
_wrap_create_async_engine(kwargs),
9898
)
9999

100100
if kwargs.get("engine") is not None:

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py

+24-14
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,33 @@ def _get_tracer(engine, tracer_provider=None):
4343

4444

4545
# pylint: disable=unused-argument
46-
def _wrap_create_async_engine(func, module, args, kwargs):
47-
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
48-
object that will listen to SQLAlchemy events.
49-
"""
50-
engine = func(*args, **kwargs)
51-
EngineTracer(_get_tracer(engine), engine.sync_engine)
52-
return engine
46+
def _wrap_create_async_engine(kwargs):
47+
tracer_provider = kwargs.get("tracer_provider")
48+
49+
def _wrap_create_async_engine_internal(func, module, args, kwargs):
50+
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
51+
object that will listen to SQLAlchemy events.
52+
"""
53+
engine = func(*args, **kwargs)
54+
EngineTracer(_get_tracer(engine, tracer_provider), engine.sync_engine)
55+
return engine
56+
57+
return _wrap_create_async_engine_internal
5358

5459

5560
# pylint: disable=unused-argument
56-
def _wrap_create_engine(func, module, args, kwargs):
57-
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
58-
object that will listen to SQLAlchemy events.
59-
"""
60-
engine = func(*args, **kwargs)
61-
EngineTracer(_get_tracer(engine), engine)
62-
return engine
61+
def _wrap_create_engine(kwargs):
62+
tracer_provider = kwargs.get("tracer_provider")
63+
64+
def _wrap_create_engine_internal(func, module, args, kwargs):
65+
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
66+
object that will listen to SQLAlchemy events.
67+
"""
68+
engine = func(*args, **kwargs)
69+
EngineTracer(_get_tracer(engine, tracer_provider), engine)
70+
return engine
71+
72+
return _wrap_create_engine_internal
6373

6474

6575
class EngineTracer:

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py

+24
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from sqlalchemy import create_engine
2020

2121
from opentelemetry import trace
22+
from opentelemetry.sdk.trace import TracerProvider, export
23+
from opentelemetry.sdk.resources import Resource
2224
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
2325
from opentelemetry.test.test_base import TestBase
2426

@@ -95,6 +97,28 @@ def test_create_engine_wrapper(self):
9597
self.assertEqual(spans[0].name, "SELECT :memory:")
9698
self.assertEqual(spans[0].kind, trace.SpanKind.CLIENT)
9799

100+
def test_custom_tracer_provider(self):
101+
provider = TracerProvider(
102+
resource=Resource.create(
103+
{"service.name": "test", "deployment.environment": "env", "service.version": "1234"},
104+
),
105+
)
106+
provider.add_span_processor(export.SimpleSpanProcessor(self.memory_exporter))
107+
108+
SQLAlchemyInstrumentor().instrument(tracer_provider=provider)
109+
from sqlalchemy import create_engine # pylint: disable-all
110+
111+
engine = create_engine("sqlite:///:memory:")
112+
cnx = engine.connect()
113+
cnx.execute("SELECT 1 + 1;").fetchall()
114+
spans = self.memory_exporter.get_finished_spans()
115+
116+
self.assertEqual(len(spans), 1)
117+
self.assertEqual(spans[0].resource.attributes["service.name"], "test")
118+
self.assertEqual(spans[0].resource.attributes["deployment.environment"], "env")
119+
self.assertEqual(spans[0].resource.attributes["service.version"], "1234")
120+
121+
98122
@pytest.mark.skipif(
99123
not sqlalchemy.__version__.startswith("1.4"),
100124
reason="only run async tests for 1.4",

0 commit comments

Comments
 (0)