11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ from typing import Coroutine
14
15
from unittest import mock
15
16
16
17
from sqlalchemy import create_engine
17
-
18
+ import sqlalchemy
18
19
from opentelemetry import trace
19
20
from opentelemetry .instrumentation .sqlalchemy import SQLAlchemyInstrumentor
20
21
from opentelemetry .test .test_base import TestBase
22
+ import asyncio
23
+
24
+
25
+ def _call_async (coro : Coroutine ):
26
+ return asyncio .get_event_loop ().run_until_complete (coro )
21
27
22
28
23
29
class TestSqlalchemyInstrumentation (TestBase ):
@@ -28,7 +34,8 @@ def tearDown(self):
28
34
def test_trace_integration (self ):
29
35
engine = create_engine ("sqlite:///:memory:" )
30
36
SQLAlchemyInstrumentor ().instrument (
31
- engine = engine , tracer_provider = self .tracer_provider ,
37
+ engine = engine ,
38
+ tracer_provider = self .tracer_provider ,
32
39
)
33
40
cnx = engine .connect ()
34
41
cnx .execute ("SELECT 1 + 1;" ).fetchall ()
@@ -38,6 +45,25 @@ def test_trace_integration(self):
38
45
self .assertEqual (spans [0 ].name , "SELECT :memory:" )
39
46
self .assertEqual (spans [0 ].kind , trace .SpanKind .CLIENT )
40
47
48
+ def test_async_trace_integration (self ):
49
+ if sqlalchemy .__version__ .startswith ("1.3" ):
50
+ return
51
+ from sqlalchemy .ext .asyncio import (
52
+ create_async_engine ,
53
+ ) # pylint: disable-all
54
+
55
+ engine = create_async_engine ("sqlite+aiosqlite:///:memory:" )
56
+ SQLAlchemyInstrumentor ().instrument (
57
+ engine = engine .sync_engine , tracer_provider = self .tracer_provider
58
+ )
59
+ cnx = _call_async (engine .connect ())
60
+ _call_async (cnx .execute (sqlalchemy .text ("SELECT 1 + 1;" ))).fetchall ()
61
+ _call_async (cnx .close ())
62
+ spans = self .memory_exporter .get_finished_spans ()
63
+ self .assertEqual (len (spans ), 1 )
64
+ self .assertEqual (spans [0 ].name , "SELECT :memory:" )
65
+ self .assertEqual (spans [0 ].kind , trace .SpanKind .CLIENT )
66
+
41
67
def test_not_recording (self ):
42
68
mock_tracer = mock .Mock ()
43
69
mock_span = mock .Mock ()
@@ -47,7 +73,8 @@ def test_not_recording(self):
47
73
tracer .return_value = mock_tracer
48
74
engine = create_engine ("sqlite:///:memory:" )
49
75
SQLAlchemyInstrumentor ().instrument (
50
- engine = engine , tracer_provider = self .tracer_provider ,
76
+ engine = engine ,
77
+ tracer_provider = self .tracer_provider ,
51
78
)
52
79
cnx = engine .connect ()
53
80
cnx .execute ("SELECT 1 + 1;" ).fetchall ()
0 commit comments