Skip to content

Commit 0c17709

Browse files
committed
Change the code to save all listen params so we will know how to remove them when uninstument
1 parent 0bac6a0 commit 0c17709

File tree

3 files changed

+39
-17
lines changed

3 files changed

+39
-17
lines changed

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -160,22 +160,17 @@ def _instrument(self, **kwargs):
160160
"create_async_engine",
161161
_wrap_create_async_engine(tracer_provider, enable_commenter),
162162
)
163-
164-
self.engines = []
165163
if kwargs.get("engine") is not None:
166-
self.engines.append(
167-
EngineTracer(
168-
_get_tracer(tracer_provider),
169-
kwargs.get("engine"),
170-
kwargs.get("enable_commenter", False),
171-
kwargs.get("commenter_options", {}),
172-
)
164+
return EngineTracer(
165+
_get_tracer(tracer_provider),
166+
kwargs.get("engine"),
167+
kwargs.get("enable_commenter", False),
168+
kwargs.get("commenter_options", {}),
173169
)
174-
return self.engines[0]
175170
if kwargs.get("engines") is not None and isinstance(
176171
kwargs.get("engines"), Sequence
177172
):
178-
self.engines = [
173+
return [
179174
EngineTracer(
180175
_get_tracer(tracer_provider),
181176
engine,
@@ -184,7 +179,6 @@ def _instrument(self, **kwargs):
184179
)
185180
for engine in kwargs.get("engines")
186181
]
187-
return self.engines
188182

189183
return None
190184

@@ -194,5 +188,4 @@ def _uninstrument(self, **kwargs):
194188
unwrap(Engine, "connect")
195189
if parse_version(sqlalchemy.__version__).release >= (1, 4):
196190
unwrap(sqlalchemy.ext.asyncio, "create_async_engine")
197-
for engine in self.engines:
198-
engine.remove_event_listeners()
191+
EngineTracer.remove_all_event_listeners()

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def _wrap_connect_internal(func, module, args, kwargs):
9898

9999

100100
class EngineTracer:
101+
_removeEventListenerParams = []
102+
101103
def __init__(
102104
self, tracer, engine, enable_commenter=False, commenter_options=None
103105
):
@@ -108,11 +110,22 @@ def __init__(
108110
self.commenter_options = commenter_options if commenter_options else {}
109111
self._leading_comment_remover = re.compile(r"^/\*.*?\*/")
110112

111-
listen(
113+
self.register_event_listener(
112114
engine, "before_cursor_execute", self._before_cur_exec, retval=True
113115
)
114-
listen(engine, "after_cursor_execute", _after_cur_exec)
115-
listen(engine, "handle_error", _handle_error)
116+
self.register_event_listener(engine, "after_cursor_execute", _after_cur_exec)
117+
self.register_event_listener(engine, "handle_error", _handle_error)
118+
119+
@classmethod
120+
def register_event_listener(cls, target, identifier, fn, *args, **kw):
121+
listen(target, identifier, fn, *args, **kw)
122+
cls._removeEventListenerParams.append((target, identifier, fn))
123+
124+
@classmethod
125+
def remove_all_event_listeners(cls):
126+
for removeParams in cls._removeEventListenerParams:
127+
remove(*removeParams)
128+
cls._removeEventListenerParams.clear()
116129

117130
def remove_event_listeners(self):
118131
remove(self.engine, "before_cursor_execute", self._before_cur_exec)

Diff for: instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py

+16
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,22 @@ def test_uninstrument(self):
255255
spans = self.memory_exporter.get_finished_spans()
256256
self.assertEqual(len(spans), 0)
257257

258+
def test_uninstrument_without_engine(self):
259+
SQLAlchemyInstrumentor().instrument(tracer_provider=self.tracer_provider)
260+
from sqlalchemy import create_engine
261+
engine = create_engine("sqlite:///:memory:")
262+
263+
cnx = engine.connect()
264+
cnx.execute("SELECT 1 + 1;").fetchall()
265+
spans = self.memory_exporter.get_finished_spans()
266+
self.assertEqual(len(spans), 2)
267+
268+
self.memory_exporter.clear()
269+
SQLAlchemyInstrumentor().uninstrument()
270+
cnx.execute("SELECT 1 + 1;").fetchall()
271+
spans = self.memory_exporter.get_finished_spans()
272+
self.assertEqual(len(spans), 0)
273+
258274
def test_no_op_tracer_provider(self):
259275
engine = create_engine("sqlite:///:memory:")
260276
SQLAlchemyInstrumentor().instrument(

0 commit comments

Comments
 (0)