Skip to content

Commit 5693dd4

Browse files
authored
Whitelist what pickle is serializing in context (#11688)
* Whitelist what pickle is serializing in context * Indent
1 parent 205a1cb commit 5693dd4

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

sdk/core/azure-core/azure/core/pipeline/__init__.py

+18
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class PipelineContext(dict):
6363
:param transport: The HTTP transport type.
6464
:param kwargs: Developer-defined keyword arguments.
6565
"""
66+
_PICKLE_CONTEXT = {
67+
'deserialized_data'
68+
}
6669

6770
def __init__(self, transport, **kwargs): # pylint: disable=super-init-not-called
6871
self.transport = transport
@@ -75,6 +78,21 @@ def __getstate__(self):
7578
del state['transport']
7679
return state
7780

81+
def __reduce__(self):
82+
reduced = super(PipelineContext, self).__reduce__()
83+
saved_context = {}
84+
for key, value in self.items():
85+
if key in self._PICKLE_CONTEXT:
86+
saved_context[key] = value
87+
# 1 is for from __reduce__ spec of pickle (generic args for recreation)
88+
# 2 is how dict is implementing __reduce__ (dict specific)
89+
# tuple are read-only, we use a list in the meantime
90+
reduced = list(reduced)
91+
dict_reduced_result = list(reduced[1])
92+
dict_reduced_result[2] = saved_context
93+
reduced[1] = tuple(dict_reduced_result)
94+
return tuple(reduced)
95+
7896
def __setstate__(self, state):
7997
self.__dict__.update(state)
8098
# Re-create the unpickable entries

sdk/core/azure-core/tests/test_universal_pipeline.py

+32
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#
2626
#--------------------------------------------------------------------------
2727
import logging
28+
import pickle
2829
try:
2930
from unittest import mock
3031
except ImportError:
@@ -56,6 +57,37 @@
5657
HTTPPolicy,
5758
)
5859

60+
def test_pipeline_context():
61+
kwargs={
62+
'stream':True,
63+
'cont_token':"bla"
64+
}
65+
context = PipelineContext('transport', **kwargs)
66+
context['foo'] = 'bar'
67+
context['xyz'] = '123'
68+
context['deserialized_data'] = 'marvelous'
69+
70+
assert context['foo'] == 'bar'
71+
assert context.options == kwargs
72+
73+
with pytest.raises(TypeError):
74+
context.clear()
75+
76+
with pytest.raises(TypeError):
77+
context.update({})
78+
79+
assert context.pop('foo') == 'bar'
80+
assert 'foo' not in context
81+
82+
serialized = pickle.dumps(context)
83+
84+
revived_context = pickle.loads(serialized)
85+
assert revived_context.options == kwargs
86+
assert revived_context.transport is None
87+
assert 'deserialized_data' in revived_context
88+
assert len(revived_context) == 1
89+
90+
5991
def test_request_history():
6092
class Non_deep_copiable(object):
6193
def __deepcopy__(self, memodict={}):

0 commit comments

Comments
 (0)