Skip to content

Commit 4060a56

Browse files
committed
Fix psycopg2 register type
1 parent 180de75 commit 4060a56

File tree

2 files changed

+29
-1
lines changed

2 files changed

+29
-1
lines changed

aws_xray_sdk/ext/psycopg2/patch.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
import copy
12
import re
23
import wrapt
34
from operator import methodcaller
45

5-
from aws_xray_sdk.ext.dbapi2 import XRayTracedConn
6+
from aws_xray_sdk.ext.dbapi2 import XRayTracedConn, XRayTracedCursor
67

78

89
def patch():
@@ -12,6 +13,11 @@ def patch():
1213
'connect',
1314
_xray_traced_connect
1415
)
16+
wrapt.wrap_function_wrapper(
17+
'psycopg2.extensions',
18+
'register_type',
19+
_xray_register_type_fix
20+
)
1521

1622

1723
def _xray_traced_connect(wrapped, instance, args, kwargs):
@@ -32,3 +38,11 @@ def _xray_traced_connect(wrapped, instance, args, kwargs):
3238
}
3339

3440
return XRayTracedConn(conn, meta)
41+
42+
def _xray_register_type_fix(wrapped, instance, args, kwargs):
43+
"""Send the actual connection or curser to register type."""
44+
our_args = list(copy.copy(args))
45+
if len(our_args) == 2 and isinstance(our_args[1], (XRayTracedConn, XRayTracedCursor)):
46+
our_args[1] = our_args[1].__wrapped__
47+
48+
return wrapped(*our_args, **kwargs)

tests/ext/psycopg2/test_psycopg2.py

+14
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import psycopg2
2+
import psycopg2.extras
23
import psycopg2.pool
34

45
import pytest
@@ -144,3 +145,16 @@ def test_execute_bad_query():
144145

145146
exception = subsegment.cause['exceptions'][0]
146147
assert exception.type == 'ProgrammingError'
148+
149+
150+
def test_register_extensions():
151+
with testing.postgresql.Postgresql() as postgresql:
152+
url = postgresql.url()
153+
dsn = postgresql.dsn()
154+
conn = psycopg2.connect('dbname=' + dsn['database'] +
155+
' password=mypassword' +
156+
' host=' + dsn['host'] +
157+
' port=' + str(dsn['port']) +
158+
' user=' + dsn['user'])
159+
assert psycopg2.extras.register_uuid(None, conn)
160+
assert psycopg2.extras.register_uuid(None, conn.cursor())

0 commit comments

Comments
 (0)