@@ -79,6 +79,26 @@ def test_instrumentor_connect(self):
79
79
spans_list = self .memory_exporter .get_finished_spans ()
80
80
self .assertEqual (len (spans_list ), 1 )
81
81
82
+ def test_instrumentor_connect_ctx_manager (self ):
83
+ async def _ctx_manager_connect ():
84
+ AiopgInstrumentor ().instrument ()
85
+
86
+ async with aiopg .connect (database = "test" ) as cnx :
87
+ async with cnx .cursor () as cursor :
88
+ query = "SELECT * FROM test"
89
+ await cursor .execute (query )
90
+
91
+ spans_list = self .memory_exporter .get_finished_spans ()
92
+ self .assertEqual (len (spans_list ), 1 )
93
+ span = spans_list [0 ]
94
+
95
+ # Check version and name in span's instrumentation info
96
+ self .check_span_instrumentation_info (
97
+ span , opentelemetry .instrumentation .aiopg
98
+ )
99
+
100
+ async_call (_ctx_manager_connect ())
101
+
82
102
def test_instrumentor_create_pool (self ):
83
103
AiopgInstrumentor ().instrument ()
84
104
@@ -110,6 +130,27 @@ def test_instrumentor_create_pool(self):
110
130
spans_list = self .memory_exporter .get_finished_spans ()
111
131
self .assertEqual (len (spans_list ), 1 )
112
132
133
+ def test_instrumentor_create_pool_ctx_manager (self ):
134
+ async def _ctx_manager_pool ():
135
+ AiopgInstrumentor ().instrument ()
136
+
137
+ async with aiopg .create_pool (database = "test" ) as pool :
138
+ async with pool .acquire () as cnx :
139
+ async with cnx .cursor () as cursor :
140
+ query = "SELECT * FROM test"
141
+ await cursor .execute (query )
142
+
143
+ spans_list = self .memory_exporter .get_finished_spans ()
144
+ self .assertEqual (len (spans_list ), 1 )
145
+ span = spans_list [0 ]
146
+
147
+ # Check version and name in span's instrumentation info
148
+ self .check_span_instrumentation_info (
149
+ span , opentelemetry .instrumentation .aiopg
150
+ )
151
+
152
+ async_call (_ctx_manager_pool ())
153
+
113
154
def test_custom_tracer_provider_connect (self ):
114
155
resource = resources .Resource .create ({})
115
156
result = self .create_tracer_provider (resource = resource )
@@ -428,6 +469,12 @@ async def _acquire(self):
428
469
)
429
470
return connect
430
471
472
+ def close (self ):
473
+ pass
474
+
475
+ async def wait_closed (self ):
476
+ pass
477
+
431
478
432
479
class MockPsycopg2Connection :
433
480
def __init__ (self , database , server_port , server_host , user ):
@@ -471,6 +518,9 @@ async def callproc(self, query, params=None, throw_exception=False):
471
518
if throw_exception :
472
519
raise Exception ("Test Exception" )
473
520
521
+ def close (self ):
522
+ pass
523
+
474
524
475
525
class AiopgConnectionMock :
476
526
_conn = MagicMock ()
0 commit comments