Skip to content

Commit c59c305

Browse files
committed
improve test class
1 parent 4e3009a commit c59c305

File tree

1 file changed

+42
-45
lines changed

1 file changed

+42
-45
lines changed

test/test_change_stream.py

+42-45
Original file line numberDiff line numberDiff line change
@@ -1106,6 +1106,46 @@ def setFailPoint(self, scenario_dict):
11061106
client_context.client.admin.command,
11071107
'configureFailPoint', fail_cmd['configureFailPoint'], mode='off')
11081108

1109+
def assert_list_contents_are_subset(self, superlist, sublist):
1110+
"""Check that each element in sublist is a subset of the corresponding
1111+
element in superlist."""
1112+
self.assertEqual(len(superlist), len(sublist))
1113+
for sup, sub in zip(superlist, sublist):
1114+
if isinstance(sub, dict):
1115+
self.assert_dict_is_subset(sup, sub)
1116+
continue
1117+
if isinstance(sub, (list, tuple)):
1118+
self.assert_list_contents_are_subset(sup, sub)
1119+
continue
1120+
self.assertEqual(sup, sub)
1121+
1122+
def assert_dict_is_subset(self, superdict, subdict):
1123+
"""Check that subdict is a subset of superdict."""
1124+
exempt_fields = ["documentKey", "_id", "getMore"]
1125+
for key, value in iteritems(subdict):
1126+
if key not in superdict:
1127+
self.fail('Key %s not found in %s' % (key, superdict))
1128+
if isinstance(value, dict):
1129+
self.assert_dict_is_subset(superdict[key], value)
1130+
continue
1131+
if isinstance(value, (list, tuple)):
1132+
self.assert_list_contents_are_subset(superdict[key], value)
1133+
continue
1134+
if key in exempt_fields:
1135+
# Only check for presence of these exempt fields, but not value.
1136+
self.assertIn(key, superdict)
1137+
else:
1138+
self.assertEqual(superdict[key], value)
1139+
1140+
def check_event(self, event, expectation_dict):
1141+
if event is None:
1142+
self.fail()
1143+
for key, value in iteritems(expectation_dict):
1144+
if isinstance(value, dict):
1145+
self.assert_dict_is_subset(getattr(event, key), value)
1146+
else:
1147+
self.assertEqual(getattr(event, key), value)
1148+
11091149
def tearDown(self):
11101150
self.listener.results.clear()
11111151

@@ -1159,49 +1199,6 @@ def run_operation(client, operation):
11591199
return cmd(**arguments)
11601200

11611201

1162-
def assert_list_contents_are_subset(superlist, sublist):
1163-
assert len(superlist) == len(sublist)
1164-
for super, sub in zip(superlist, sublist):
1165-
if isinstance(sub, dict):
1166-
assert_dict_is_subset(super, sub)
1167-
continue
1168-
if isinstance(sub, (list, tuple)):
1169-
assert_list_contents_are_subset(super, sub)
1170-
continue
1171-
assert super == sub
1172-
1173-
1174-
def assert_dict_is_subset(superdict, subdict):
1175-
"""Check that subdict is a subset of superdict."""
1176-
exempt_fields = ["documentKey", "_id", "getMore"]
1177-
for key, value in iteritems(subdict):
1178-
if key not in superdict:
1179-
assert False
1180-
if isinstance(value, dict):
1181-
assert_dict_is_subset(superdict[key], value)
1182-
continue
1183-
if isinstance(value, (list, tuple)):
1184-
assert_list_contents_are_subset(superdict[key], value)
1185-
continue
1186-
if key in exempt_fields:
1187-
# Only check for presence of these exempt fields, but not value.
1188-
assert key in superdict
1189-
else:
1190-
assert superdict[key] == value
1191-
1192-
1193-
def check_event(event, expectation_dict):
1194-
if event is None:
1195-
raise AssertionError
1196-
for key, value in iteritems(expectation_dict):
1197-
if isinstance(value, dict):
1198-
assert_dict_is_subset(
1199-
getattr(event, key), value
1200-
)
1201-
else:
1202-
assert getattr(event, key) == value
1203-
1204-
12051202
def create_test(scenario_def, test):
12061203
def run_scenario(self):
12071204
# Set up
@@ -1232,7 +1229,7 @@ def run_scenario(self):
12321229
else:
12331230
# Check for expected output from change streams
12341231
for change, expected_changes in zip(changes, test["result"]["success"]):
1235-
assert_dict_is_subset(change, expected_changes)
1232+
self.assert_dict_is_subset(change, expected_changes)
12361233
self.assertEqual(len(changes), len(test["result"]["success"]))
12371234

12381235
finally:
@@ -1242,7 +1239,7 @@ def run_scenario(self):
12421239
for event_type, event_desc in iteritems(expectation):
12431240
results_key = event_type.split("_")[1]
12441241
event = results[results_key][idx] if len(results[results_key]) > idx else None
1245-
check_event(event, event_desc)
1242+
self.check_event(event, event_desc)
12461243

12471244
return run_scenario
12481245

0 commit comments

Comments
 (0)