From ee4e12226d53e4b616e06588cb31c75c5b92fbc6 Mon Sep 17 00:00:00 2001 From: Prashant Srivastava Date: Tue, 26 Jul 2022 16:09:20 -0700 Subject: [PATCH] patching register_default_jsonb from psycopg2.extras --- aws_xray_sdk/ext/psycopg2/patch.py | 22 +++++++++++++++++++--- tests/ext/psycopg2/test_psycopg2.py | 14 ++++++++++++++ 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/aws_xray_sdk/ext/psycopg2/patch.py b/aws_xray_sdk/ext/psycopg2/patch.py index 708bf45c..f6a1d4c6 100644 --- a/aws_xray_sdk/ext/psycopg2/patch.py +++ b/aws_xray_sdk/ext/psycopg2/patch.py @@ -7,7 +7,6 @@ def patch(): - wrapt.wrap_function_wrapper( 'psycopg2', 'connect', @@ -24,11 +23,16 @@ def patch(): _xray_register_type_fix ) + wrapt.wrap_function_wrapper( + 'psycopg2.extras', + 'register_default_jsonb', + _xray_register_default_jsonb_fix + ) -def _xray_traced_connect(wrapped, instance, args, kwargs): +def _xray_traced_connect(wrapped, instance, args, kwargs): conn = wrapped(*args, **kwargs) - parameterized_dsn = { c[0]: c[-1] for c in map(methodcaller('split', '='), conn.dsn.split(' '))} + parameterized_dsn = {c[0]: c[-1] for c in map(methodcaller('split', '='), conn.dsn.split(' '))} meta = { 'database_type': 'PostgreSQL', 'url': 'postgresql://{}@{}:{}/{}'.format( @@ -44,6 +48,7 @@ def _xray_traced_connect(wrapped, instance, args, kwargs): return XRayTracedConn(conn, meta) + def _xray_register_type_fix(wrapped, instance, args, kwargs): """Send the actual connection or curser to register type.""" our_args = list(copy.copy(args)) @@ -51,3 +56,14 @@ def _xray_register_type_fix(wrapped, instance, args, kwargs): our_args[1] = our_args[1].__wrapped__ return wrapped(*our_args, **kwargs) + + +def _xray_register_default_jsonb_fix(wrapped, instance, args, kwargs): + our_kwargs = dict() + for key, value in kwargs.items(): + if key == "conn_or_curs" and isinstance(value, (XRayTracedConn, XRayTracedCursor)): + # unwrap the connection or cursor to be sent to register_default_jsonb + value = value.__wrapped__ + our_kwargs[key] = value + + return wrapped(*args, **our_kwargs) diff --git a/tests/ext/psycopg2/test_psycopg2.py b/tests/ext/psycopg2/test_psycopg2.py index 4736b5c0..9ab80069 100644 --- a/tests/ext/psycopg2/test_psycopg2.py +++ b/tests/ext/psycopg2/test_psycopg2.py @@ -173,3 +173,17 @@ def test_query_as_string(): test_sql = psycopg2.sql.Identifier('test') assert test_sql.as_string(conn) assert test_sql.as_string(conn.cursor()) + + +def test_register_default_jsonb(): + with testing.postgresql.Postgresql() as postgresql: + url = postgresql.url() + dsn = postgresql.dsn() + conn = psycopg2.connect('dbname=' + dsn['database'] + + ' password=mypassword' + + ' host=' + dsn['host'] + + ' port=' + str(dsn['port']) + + ' user=' + dsn['user']) + + assert psycopg2.extras.register_default_jsonb(conn_or_curs=conn, loads=lambda x: x) + assert psycopg2.extras.register_default_jsonb(conn_or_curs=conn.cursor(), loads=lambda x: x)