@@ -1013,6 +1013,26 @@ def test_schema_error_message(self):
1013
1013
)
1014
1014
1015
1015
1016
+ class MockImport (object ):
1017
+
1018
+ def __init__ (self , module , _mock ):
1019
+ self ._module = module
1020
+ self ._mock = _mock
1021
+ self ._orig_import = None
1022
+
1023
+ def __enter__ (self ):
1024
+ self ._orig_import = sys .modules .get (self ._module , None )
1025
+ sys .modules [self ._module ] = self ._mock
1026
+ return self ._mock
1027
+
1028
+ def __exit__ (self , * args ):
1029
+ if self ._orig_import is None :
1030
+ del sys .modules [self ._module ]
1031
+ else :
1032
+ sys .modules [self ._module ] = self ._orig_import
1033
+ return True
1034
+
1035
+
1016
1036
class TestRefResolver (TestCase ):
1017
1037
1018
1038
base_uri = ""
@@ -1062,7 +1082,7 @@ def test_it_retrieves_unstored_refs_via_requests(self):
1062
1082
ref = "http://bar#baz"
1063
1083
schema = {"baz" : 12 }
1064
1084
1065
- with mock . patch ( "jsonschema.validators. requests" ) as requests :
1085
+ with MockImport ( " requests", mock . Mock () ) as requests :
1066
1086
requests .get .return_value .json .return_value = schema
1067
1087
with self .resolver .resolving (ref ) as resolved :
1068
1088
self .assertEqual (resolved , 12 )
@@ -1072,7 +1092,7 @@ def test_it_retrieves_unstored_refs_via_urlopen(self):
1072
1092
ref = "http://bar#baz"
1073
1093
schema = {"baz" : 12 }
1074
1094
1075
- with mock . patch ( "jsonschema.validators. requests" , None ):
1095
+ with MockImport ( " requests" , None ):
1076
1096
with mock .patch ("jsonschema.validators.urlopen" ) as urlopen :
1077
1097
urlopen .return_value .read .return_value = (
1078
1098
json .dumps (schema ).encode ("utf8" ))
0 commit comments