12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- import unittest
16
- from unittest import mock
17
-
18
15
from opentelemetry import trace as trace_api
19
16
from opentelemetry .ext .dbapi import DatabaseApiIntegration
17
+ from opentelemetry .test .test_base import TestBase
20
18
21
19
22
- class TestDBApiIntegration (unittest . TestCase ):
20
+ class TestDBApiIntegration (TestBase ):
23
21
def setUp (self ):
24
- self .tracer = trace_api .DefaultTracer ()
25
- self .span = MockSpan ()
26
- self .start_current_span_patcher = mock .patch .object (
27
- self .tracer ,
28
- "start_as_current_span" ,
29
- autospec = True ,
30
- spec_set = True ,
31
- return_value = self .span ,
32
- )
33
-
34
- self .start_as_current_span = self .start_current_span_patcher .start ()
35
-
36
- def tearDown (self ):
37
- self .start_current_span_patcher .stop ()
22
+ super ().setUp ()
23
+ self .tracer = self .tracer_provider .get_tracer (__name__ )
38
24
39
25
def test_span_succeeded (self ):
40
26
connection_props = {
@@ -57,28 +43,25 @@ def test_span_succeeded(self):
57
43
)
58
44
cursor = mock_connection .cursor ()
59
45
cursor .execute ("Test query" , ("param1Value" , False ))
60
- self .assertTrue (self .start_as_current_span .called )
46
+ spans_list = self .memory_exporter .get_finished_spans ()
47
+ self .assertEqual (len (spans_list ), 1 )
48
+ span = spans_list [0 ]
49
+ self .assertEqual (span .name , "testcomponent.testdatabase" )
50
+ self .assertIs (span .kind , trace_api .SpanKind .CLIENT )
51
+
52
+ self .assertEqual (span .attributes ["component" ], "testcomponent" )
53
+ self .assertEqual (span .attributes ["db.type" ], "testtype" )
54
+ self .assertEqual (span .attributes ["db.instance" ], "testdatabase" )
55
+ self .assertEqual (span .attributes ["db.statement" ], "Test query" )
61
56
self .assertEqual (
62
- self .start_as_current_span .call_args [0 ][0 ],
63
- "testcomponent.testdatabase" ,
64
- )
65
- self .assertIs (
66
- self .start_as_current_span .call_args [1 ]["kind" ],
67
- trace_api .SpanKind .CLIENT ,
68
- )
69
- self .assertEqual (self .span .attributes ["component" ], "testcomponent" )
70
- self .assertEqual (self .span .attributes ["db.type" ], "testtype" )
71
- self .assertEqual (self .span .attributes ["db.instance" ], "testdatabase" )
72
- self .assertEqual (self .span .attributes ["db.statement" ], "Test query" )
73
- self .assertEqual (
74
- self .span .attributes ["db.statement.parameters" ],
57
+ span .attributes ["db.statement.parameters" ],
75
58
"('param1Value', False)" ,
76
59
)
77
- self .assertEqual (self . span .attributes ["db.user" ], "testuser" )
78
- self .assertEqual (self . span .attributes ["net.peer.name" ], "testhost" )
79
- self .assertEqual (self . span .attributes ["net.peer.port" ], 123 )
60
+ self .assertEqual (span .attributes ["db.user" ], "testuser" )
61
+ self .assertEqual (span .attributes ["net.peer.name" ], "testhost" )
62
+ self .assertEqual (span .attributes ["net.peer.port" ], 123 )
80
63
self .assertIs (
81
- self . span .status .canonical_code ,
64
+ span .status .canonical_code ,
82
65
trace_api .status .StatusCanonicalCode .OK ,
83
66
)
84
67
@@ -88,17 +71,18 @@ def test_span_failed(self):
88
71
mock_connect , {}, {}
89
72
)
90
73
cursor = mock_connection .cursor ()
91
- try :
74
+ with self . assertRaises ( Exception ) :
92
75
cursor .execute ("Test query" , throw_exception = True )
93
- except Exception : # pylint: disable=broad-except
94
- self .assertEqual (
95
- self .span .attributes ["db.statement" ], "Test query"
96
- )
97
- self .assertIs (
98
- self .span .status .canonical_code ,
99
- trace_api .status .StatusCanonicalCode .UNKNOWN ,
100
- )
101
- self .assertEqual (self .span .status .description , "Test Exception" )
76
+
77
+ spans_list = self .memory_exporter .get_finished_spans ()
78
+ self .assertEqual (len (spans_list ), 1 )
79
+ span = spans_list [0 ]
80
+ self .assertEqual (span .attributes ["db.statement" ], "Test query" )
81
+ self .assertIs (
82
+ span .status .canonical_code ,
83
+ trace_api .status .StatusCanonicalCode .UNKNOWN ,
84
+ )
85
+ self .assertEqual (span .status .description , "Test Exception" )
102
86
103
87
def test_executemany (self ):
104
88
db_integration = DatabaseApiIntegration (self .tracer , "testcomponent" )
@@ -107,8 +91,10 @@ def test_executemany(self):
107
91
)
108
92
cursor = mock_connection .cursor ()
109
93
cursor .executemany ("Test query" )
110
- self .assertTrue (self .start_as_current_span .called )
111
- self .assertEqual (self .span .attributes ["db.statement" ], "Test query" )
94
+ spans_list = self .memory_exporter .get_finished_spans ()
95
+ self .assertEqual (len (spans_list ), 1 )
96
+ span = spans_list [0 ]
97
+ self .assertEqual (span .attributes ["db.statement" ], "Test query" )
112
98
113
99
def test_callproc (self ):
114
100
db_integration = DatabaseApiIntegration (self .tracer , "testcomponent" )
@@ -117,9 +103,11 @@ def test_callproc(self):
117
103
)
118
104
cursor = mock_connection .cursor ()
119
105
cursor .callproc ("Test stored procedure" )
120
- self .assertTrue (self .start_as_current_span .called )
106
+ spans_list = self .memory_exporter .get_finished_spans ()
107
+ self .assertEqual (len (spans_list ), 1 )
108
+ span = spans_list [0 ]
121
109
self .assertEqual (
122
- self . span .attributes ["db.statement" ], "Test stored procedure"
110
+ span .attributes ["db.statement" ], "Test stored procedure"
123
111
)
124
112
125
113
@@ -159,23 +147,3 @@ def executemany(self, query, params=None, throw_exception=False):
159
147
def callproc (self , query , params = None , throw_exception = False ):
160
148
if throw_exception :
161
149
raise Exception ("Test Exception" )
162
-
163
-
164
- class MockSpan :
165
- def __enter__ (self ):
166
- return self
167
-
168
- def __exit__ (self , exc_type , exc_val , exc_tb ):
169
- return False
170
-
171
- def __init__ (self ):
172
- self .status = None
173
- self .name = ""
174
- self .kind = trace_api .SpanKind .INTERNAL
175
- self .attributes = {}
176
-
177
- def set_attribute (self , key , value ):
178
- self .attributes [key ] = value
179
-
180
- def set_status (self , status ):
181
- self .status = status
0 commit comments