diff --git a/CHANGELOG.md b/CHANGELOG.md index c408a041bb..35ea6ff194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ([#378](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/378)) - `opentelemetry-instrumentation-wsgi` Reimplement `keys` method to return actual keys from the carrier instead of an empty list. ([#379](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/379)) +- `opentelemetry-instrumentation-sqlalchemy` Fix multithreading issues in recording spans from SQLAlchemy + ([#315](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/315)) ### Changed - Rename `IdsGenerator` to `IdGenerator` diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index 032de562b0..683af77e45 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from threading import local +from weakref import WeakKeyDictionary + from sqlalchemy.event import listen # pylint: disable=no-name-in-module from opentelemetry import trace @@ -66,12 +69,21 @@ def __init__(self, tracer, engine): self.tracer = tracer self.engine = engine self.vendor = _normalize_vendor(engine.name) - self.current_span = None + self.cursor_mapping = WeakKeyDictionary() + self.local = local() listen(engine, "before_cursor_execute", self._before_cur_exec) listen(engine, "after_cursor_execute", self._after_cur_exec) listen(engine, "handle_error", self._handle_error) + @property + def current_thread_span(self): + return getattr(self.local, "current_span", None) + + @current_thread_span.setter + def current_thread_span(self, span): + setattr(self.local, "current_span", span) + def _operation_name(self, db_name, statement): parts = [] if isinstance(statement, str): @@ -94,34 +106,38 @@ def _before_cur_exec(self, conn, cursor, statement, *args): attrs = _get_attributes_from_cursor(self.vendor, cursor, attrs) db_name = attrs.get(_DB, "") - self.current_span = self.tracer.start_span( + span = self.tracer.start_span( self._operation_name(db_name, statement), kind=trace.SpanKind.CLIENT, ) - with trace.use_span(self.current_span, end_on_exit=False): - if self.current_span.is_recording(): - self.current_span.set_attribute(_STMT, statement) - self.current_span.set_attribute("db.system", self.vendor) + self.current_thread_span = self.cursor_mapping[cursor] = span + with trace.use_span(span, end_on_exit=False): + if span.is_recording(): + span.set_attribute(_STMT, statement) + span.set_attribute("db.system", self.vendor) for key, value in attrs.items(): - self.current_span.set_attribute(key, value) + span.set_attribute(key, value) # pylint: disable=unused-argument def _after_cur_exec(self, conn, cursor, statement, *args): - if self.current_span is None: + span = self.cursor_mapping.get(cursor, None) + if span is None: return - self.current_span.end() + + span.end() def _handle_error(self, context): - if self.current_span is None: + span = self.current_thread_span + if span is None: return try: - if self.current_span.is_recording(): - self.current_span.set_status( + if span.is_recording(): + span.set_status( Status(StatusCode.ERROR, str(context.original_exception),) ) finally: - self.current_span.end() + span.end() def _get_attributes_from_url(url):