diff --git a/aws_xray_sdk/ext/psycopg2/patch.py b/aws_xray_sdk/ext/psycopg2/patch.py index c6d8d1b8..0efb27fa 100644 --- a/aws_xray_sdk/ext/psycopg2/patch.py +++ b/aws_xray_sdk/ext/psycopg2/patch.py @@ -1,8 +1,9 @@ +import copy import re import wrapt from operator import methodcaller -from aws_xray_sdk.ext.dbapi2 import XRayTracedConn +from aws_xray_sdk.ext.dbapi2 import XRayTracedConn, XRayTracedCursor def patch(): @@ -12,6 +13,11 @@ def patch(): 'connect', _xray_traced_connect ) + wrapt.wrap_function_wrapper( + 'psycopg2.extensions', + 'register_type', + _xray_register_type_fix + ) def _xray_traced_connect(wrapped, instance, args, kwargs): @@ -32,3 +38,11 @@ def _xray_traced_connect(wrapped, instance, args, kwargs): } return XRayTracedConn(conn, meta) + +def _xray_register_type_fix(wrapped, instance, args, kwargs): + """Send the actual connection or curser to register type.""" + our_args = list(copy.copy(args)) + if len(our_args) == 2 and isinstance(our_args[1], (XRayTracedConn, XRayTracedCursor)): + our_args[1] = our_args[1].__wrapped__ + + return wrapped(*our_args, **kwargs) diff --git a/tests/ext/psycopg2/test_psycopg2.py b/tests/ext/psycopg2/test_psycopg2.py index 941d8675..3d275408 100644 --- a/tests/ext/psycopg2/test_psycopg2.py +++ b/tests/ext/psycopg2/test_psycopg2.py @@ -1,4 +1,5 @@ import psycopg2 +import psycopg2.extras import psycopg2.pool import pytest @@ -144,3 +145,16 @@ def test_execute_bad_query(): exception = subsegment.cause['exceptions'][0] assert exception.type == 'ProgrammingError' + + +def test_register_extensions(): + with testing.postgresql.Postgresql() as postgresql: + url = postgresql.url() + dsn = postgresql.dsn() + conn = psycopg2.connect('dbname=' + dsn['database'] + + ' password=mypassword' + + ' host=' + dsn['host'] + + ' port=' + str(dsn['port']) + + ' user=' + dsn['user']) + assert psycopg2.extras.register_uuid(None, conn) + assert psycopg2.extras.register_uuid(None, conn.cursor()) \ No newline at end of file