Skip to content

Commit e3eb630

Browse files
committed
Unpatch dbapi2, patch use custom cursor for Django and chunked_cursor
1 parent 1bf9472 commit e3eb630

File tree

5 files changed

+142
-37
lines changed

5 files changed

+142
-37
lines changed

aws_xray_sdk/ext/dbapi2.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,23 @@ def __enter__(self):
4343
@xray_recorder.capture()
4444
def execute(self, query, *args, **kwargs):
4545

46-
add_sql_meta(self._xray_meta, query)
46+
add_sql_meta(self._xray_meta)
4747
return self.__wrapped__.execute(query, *args, **kwargs)
4848

4949
@xray_recorder.capture()
5050
def executemany(self, query, *args, **kwargs):
5151

52-
add_sql_meta(self._xray_meta, query)
52+
add_sql_meta(self._xray_meta)
5353
return self.__wrapped__.executemany(query, *args, **kwargs)
5454

5555
@xray_recorder.capture()
5656
def callproc(self, proc, args):
5757

58-
add_sql_meta(self._xray_meta, proc)
58+
add_sql_meta(self._xray_meta)
5959
return self.__wrapped__.callproc(proc, args)
6060

6161

62-
def add_sql_meta(meta, query):
62+
def add_sql_meta(meta):
6363

6464
subsegment = xray_recorder.current_subsegment()
6565

@@ -72,7 +72,5 @@ def add_sql_meta(meta, query):
7272
sql_meta = copy.copy(meta)
7373
if sql_meta.get('name', None):
7474
del sql_meta['name']
75-
if xray_recorder.stream_sql:
76-
sql_meta['sanitized_query'] = query
7775
subsegment.set_sql(sql_meta)
7876
subsegment.namespace = 'remote'

aws_xray_sdk/ext/django/db.py

+47-9
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,62 @@
1+
import copy
12
import logging
23
import importlib
34

45
from django.db import connections
56

7+
from aws_xray_sdk.core import xray_recorder
68
from aws_xray_sdk.ext.dbapi2 import XRayTracedCursor
79

810
log = logging.getLogger(__name__)
911

1012

1113
def patch_db():
12-
1314
for conn in connections.all():
1415
module = importlib.import_module(conn.__module__)
1516
_patch_conn(getattr(module, conn.__class__.__name__))
1617

1718

18-
def _patch_conn(conn):
19-
20-
attr = '_xray_original_cursor'
19+
class DjangoXRayTracedCursor(XRayTracedCursor):
20+
def execute(self, query, *args, **kwargs):
21+
if xray_recorder.stream_sql:
22+
_previous_meta = copy.copy(self._xray_meta)
23+
self._xray_meta['sanitized_query'] = query
24+
result = super(DjangoXRayTracedCursor, self).execute(query, *args, **kwargs)
25+
if xray_recorder.stream_sql:
26+
self._xray_meta = _previous_meta
27+
return result
28+
29+
def executemany(self, query, *args, **kwargs):
30+
if xray_recorder.stream_sql:
31+
_previous_meta = copy.copy(self._xray_meta)
32+
self._xray_meta['sanitized_query'] = query
33+
result = super(DjangoXRayTracedCursor, self).executemany(query, *args, **kwargs)
34+
if xray_recorder.stream_sql:
35+
self._xray_meta = _previous_meta
36+
return result
37+
38+
def callproc(self, proc, args):
39+
if xray_recorder.stream_sql:
40+
_previous_meta = copy.copy(self._xray_meta)
41+
self._xray_meta['sanitized_query'] = proc
42+
result = super(DjangoXRayTracedCursor, self).callproc(proc, args)
43+
if xray_recorder.stream_sql:
44+
self._xray_meta = _previous_meta
45+
return result
46+
47+
48+
def _patch_cursor(cursor_name, conn):
49+
attr = '_xray_original_{}'.format(cursor_name)
2150

2251
if hasattr(conn, attr):
23-
log.debug('django built-in db already patched')
52+
log.debug('django built-in db {} already patched'.format(cursor_name))
53+
return
54+
55+
if not hasattr(conn, cursor_name):
56+
log.debug('django built-in db does not have {}'.format(cursor_name))
2457
return
2558

26-
setattr(conn, attr, conn.cursor)
59+
setattr(conn, attr, getattr(conn, cursor_name))
2760

2861
meta = {}
2962

@@ -45,7 +78,12 @@ def cursor(self, *args, **kwargs):
4578
if user:
4679
meta['user'] = user
4780

48-
return XRayTracedCursor(
49-
self._xray_original_cursor(*args, **kwargs), meta)
81+
original_cursor = getattr(self, attr)(*args, **kwargs)
82+
return DjangoXRayTracedCursor(original_cursor, meta)
83+
84+
setattr(conn, cursor_name, cursor)
5085

51-
conn.cursor = cursor
86+
87+
def _patch_conn(conn):
88+
_patch_cursor('cursor', conn)
89+
_patch_cursor('chunked_cursor', conn)

tests/ext/django/test_db.py

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import django
2+
3+
import pytest
4+
5+
from aws_xray_sdk.core import xray_recorder
6+
from aws_xray_sdk.core.context import Context
7+
from aws_xray_sdk.ext.django.db import patch_db
8+
9+
10+
@pytest.fixture(scope='module', autouse=True)
11+
def setup():
12+
django.setup()
13+
xray_recorder.configure(context=Context(),
14+
context_missing='LOG_ERROR')
15+
patch_db()
16+
17+
18+
@pytest.fixture(scope='module')
19+
def user_class(setup):
20+
from django.db import models
21+
from django_fake_model import models as f
22+
23+
class User(f.FakeModel):
24+
name = models.CharField(max_length=255)
25+
password = models.CharField(max_length=255)
26+
27+
return User
28+
29+
30+
@pytest.fixture(
31+
autouse=True,
32+
params=[
33+
False,
34+
True,
35+
]
36+
)
37+
@pytest.mark.django_db
38+
def func_setup(request, user_class):
39+
xray_recorder.stream_sql = request.param
40+
xray_recorder.clear_trace_entities()
41+
xray_recorder.begin_segment('name')
42+
try:
43+
user_class.create_table()
44+
yield
45+
finally:
46+
xray_recorder.clear_trace_entities()
47+
try:
48+
user_class.delete_table()
49+
finally:
50+
xray_recorder.end_segment()
51+
52+
53+
def _assert_query(sql_meta):
54+
if xray_recorder.stream_sql:
55+
assert 'sanitized_query' in sql_meta
56+
assert sql_meta['sanitized_query']
57+
assert sql_meta['sanitized_query'].startswith('SELECT')
58+
else:
59+
if 'sanitized_query' in sql_meta:
60+
assert sql_meta['sanitized_query']
61+
# Django internally executes queries for table checks, ignore those
62+
assert not sql_meta['sanitized_query'].startswith('SELECT')
63+
64+
65+
def test_all(user_class):
66+
""" Test calling all() on get all records.
67+
Verify we run the query and return the SQL as metadata"""
68+
# Materialising the query executes the SQL
69+
list(user_class.objects.all())
70+
subsegment = xray_recorder.current_segment().subsegments[-1]
71+
sql = subsegment.sql
72+
assert sql['database_type'] == 'sqlite'
73+
_assert_query(sql)
74+
75+
76+
def test_filter(user_class):
77+
""" Test calling filter() to get filtered records.
78+
Verify we run the query and return the SQL as metadata"""
79+
# Materialising the query executes the SQL
80+
list(user_class.objects.filter(password='mypassword!').all())
81+
subsegment = xray_recorder.current_segment().subsegments[-1]
82+
sql = subsegment.sql
83+
assert sql['database_type'] == 'sqlite'
84+
_assert_query(sql)
85+
if xray_recorder.stream_sql:
86+
assert 'mypassword!' not in sql['sanitized_query']
87+
assert '"password" = %s' in sql['sanitized_query']

tests/ext/psycopg2/test_psycopg2.py

+3-22
Original file line numberDiff line numberDiff line change
@@ -12,34 +12,20 @@
1212
patch(('psycopg2',))
1313

1414

15-
@pytest.fixture(
16-
autouse=True,
17-
params=[
18-
False,
19-
True,
20-
],
21-
)
22-
def construct_ctx(request):
15+
@pytest.fixture(autouse=True)
16+
def construct_ctx():
2317
"""
2418
Clean up context storage on each test run and begin a segment
2519
so that later subsegment can be attached. After each test run
2620
it cleans up context storage again.
2721
"""
28-
xray_recorder.configure(service='test', sampling=False, context=Context(), stream_sql=request.param)
22+
xray_recorder.configure(service='test', sampling=False, context=Context())
2923
xray_recorder.clear_trace_entities()
3024
xray_recorder.begin_segment('name')
3125
yield
3226
xray_recorder.clear_trace_entities()
3327

3428

35-
def _assert_query(sql_meta, query):
36-
if xray_recorder.stream_sql:
37-
assert 'sanitized_query' in sql_meta
38-
assert sql_meta['sanitized_query'] == query
39-
else:
40-
assert 'sanitized_query' not in sql_meta
41-
42-
4329
def test_execute_dsn_kwargs():
4430
q = 'SELECT 1'
4531
with testing.postgresql.Postgresql() as postgresql:
@@ -60,7 +46,6 @@ def test_execute_dsn_kwargs():
6046
assert sql['user'] == dsn['user']
6147
assert sql['url'] == url
6248
assert sql['database_version']
63-
_assert_query(sql, q)
6449

6550

6651
def test_execute_dsn_kwargs_alt_dbname():
@@ -87,7 +72,6 @@ def test_execute_dsn_kwargs_alt_dbname():
8772
assert sql['user'] == dsn['user']
8873
assert sql['url'] == url
8974
assert sql['database_version']
90-
_assert_query(sql, q)
9175

9276

9377
def test_execute_dsn_string():
@@ -110,7 +94,6 @@ def test_execute_dsn_string():
11094
assert sql['user'] == dsn['user']
11195
assert sql['url'] == url
11296
assert sql['database_version']
113-
_assert_query(sql, q)
11497

11598

11699
def test_execute_in_pool():
@@ -134,7 +117,6 @@ def test_execute_in_pool():
134117
assert sql['user'] == dsn['user']
135118
assert sql['url'] == url
136119
assert sql['database_version']
137-
_assert_query(sql, q)
138120

139121

140122
def test_execute_bad_query():
@@ -163,7 +145,6 @@ def test_execute_bad_query():
163145

164146
exception = subsegment.cause['exceptions'][0]
165147
assert exception.type == 'ProgrammingError'
166-
_assert_query(sql, q)
167148

168149

169150
def test_register_extensions():

tox.ini

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ deps =
1818
future
1919
# the sdk doesn't support earlier version of django
2020
django >= 1.10, <2.0
21+
django-fake-model
2122
pynamodb >= 3.3.1
2223
psycopg2
2324
pg8000

0 commit comments

Comments
 (0)