Skip to content

Commit d46b221

Browse files
[3.9] bpo-45500: Rewrite test_dbm (GH-29002) (GH-29074)
* Generate test classes at import time. It allows to filter them when run with unittest. E.g: "./python -m unittest test.test_dbm.TestCase_gnu -v". * Create a database class in a new directory which will be removed after test. It guarantees that all created files and directories be removed and will not conflict with other dbm tests. * Restore dbm._defaultmod after tests. Previously it was set to the last dbm module (dbm.dumb) which affected other tests. * Enable the whichdb test for dbm.dumb. * Move test_keys to the correct test class. It does not test whichdb(). * Remove some outdated code and comments.. (cherry picked from commit 975b94b) Co-authored-by: Serhiy Storchaka <[email protected]> Co-authored-by: Serhiy Storchaka <[email protected]>
1 parent a18e4e9 commit d46b221

File tree

1 file changed

+50
-64
lines changed

1 file changed

+50
-64
lines changed

Lib/test/test_dbm.py

Lines changed: 50 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,20 @@
11
"""Test script for the dbm.open function based on testdumbdbm.py"""
22

33
import unittest
4-
import glob
4+
import dbm
5+
import os
56
import test.support
67

7-
# Skip tests if dbm module doesn't exist.
8-
dbm = test.support.import_module('dbm')
9-
108
try:
119
from dbm import ndbm
1210
except ImportError:
1311
ndbm = None
1412

15-
_fname = test.support.TESTFN
13+
dirname = test.support.TESTFN
14+
_fname = os.path.join(dirname, test.support.TESTFN)
1615

1716
#
18-
# Iterates over every database module supported by dbm currently available,
19-
# setting dbm to use each in turn, and yielding that module
17+
# Iterates over every database module supported by dbm currently available.
2018
#
2119
def dbm_iterator():
2220
for name in dbm._names:
@@ -30,11 +28,12 @@ def dbm_iterator():
3028
#
3129
# Clean up all scratch databases we might have created during testing
3230
#
33-
def delete_files():
34-
# we don't know the precise name the underlying database uses
35-
# so we use glob to locate all names
36-
for f in glob.glob(glob.escape(_fname) + "*"):
37-
test.support.unlink(f)
31+
def cleaunup_test_dir():
32+
test.support.rmtree(dirname)
33+
34+
def setup_test_dir():
35+
cleaunup_test_dir()
36+
os.mkdir(dirname)
3837

3938

4039
class AnyDBMTestCase:
@@ -133,80 +132,67 @@ def read_helper(self, f):
133132
for key in self._dict:
134133
self.assertEqual(self._dict[key], f[key.encode("ascii")])
135134

136-
def tearDown(self):
137-
delete_files()
135+
def test_keys(self):
136+
with dbm.open(_fname, 'c') as d:
137+
self.assertEqual(d.keys(), [])
138+
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
139+
for k, v in a:
140+
d[k] = v
141+
self.assertEqual(sorted(d.keys()), sorted(k for (k, v) in a))
142+
for k, v in a:
143+
self.assertIn(k, d)
144+
self.assertEqual(d[k], v)
145+
self.assertNotIn(b'xxx', d)
146+
self.assertRaises(KeyError, lambda: d[b'xxx'])
138147

139148
def setUp(self):
149+
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
140150
dbm._defaultmod = self.module
141-
delete_files()
151+
self.addCleanup(cleaunup_test_dir)
152+
setup_test_dir()
142153

143154

144155
class WhichDBTestCase(unittest.TestCase):
145156
def test_whichdb(self):
157+
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
146158
for module in dbm_iterator():
147159
# Check whether whichdb correctly guesses module name
148160
# for databases opened with "module" module.
149-
# Try with empty files first
150161
name = module.__name__
151-
if name == 'dbm.dumb':
152-
continue # whichdb can't support dbm.dumb
153-
delete_files()
154-
f = module.open(_fname, 'c')
155-
f.close()
162+
setup_test_dir()
163+
dbm._defaultmod = module
164+
# Try with empty files first
165+
with module.open(_fname, 'c'): pass
156166
self.assertEqual(name, self.dbm.whichdb(_fname))
157167
# Now add a key
158-
f = module.open(_fname, 'w')
159-
f[b"1"] = b"1"
160-
# and test that we can find it
161-
self.assertIn(b"1", f)
162-
# and read it
163-
self.assertEqual(f[b"1"], b"1")
164-
f.close()
168+
with module.open(_fname, 'w') as f:
169+
f[b"1"] = b"1"
170+
# and test that we can find it
171+
self.assertIn(b"1", f)
172+
# and read it
173+
self.assertEqual(f[b"1"], b"1")
165174
self.assertEqual(name, self.dbm.whichdb(_fname))
166175

167176
@unittest.skipUnless(ndbm, reason='Test requires ndbm')
168177
def test_whichdb_ndbm(self):
169178
# Issue 17198: check that ndbm which is referenced in whichdb is defined
170-
db_file = '{}_ndbm.db'.format(_fname)
171-
with open(db_file, 'w'):
172-
self.addCleanup(test.support.unlink, db_file)
173-
self.assertIsNone(self.dbm.whichdb(db_file[:-3]))
174-
175-
def tearDown(self):
176-
delete_files()
179+
with open(_fname + '.db', 'wb'): pass
180+
self.assertIsNone(self.dbm.whichdb(_fname))
177181

178182
def setUp(self):
179-
delete_files()
180-
self.filename = test.support.TESTFN
181-
self.d = dbm.open(self.filename, 'c')
182-
self.d.close()
183+
self.addCleanup(cleaunup_test_dir)
184+
setup_test_dir()
183185
self.dbm = test.support.import_fresh_module('dbm')
184186

185-
def test_keys(self):
186-
self.d = dbm.open(self.filename, 'c')
187-
self.assertEqual(self.d.keys(), [])
188-
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
189-
for k, v in a:
190-
self.d[k] = v
191-
self.assertEqual(sorted(self.d.keys()), sorted(k for (k, v) in a))
192-
for k, v in a:
193-
self.assertIn(k, self.d)
194-
self.assertEqual(self.d[k], v)
195-
self.assertNotIn(b'xxx', self.d)
196-
self.assertRaises(KeyError, lambda: self.d[b'xxx'])
197-
self.d.close()
198-
199-
200-
def load_tests(loader, tests, pattern):
201-
classes = []
202-
for mod in dbm_iterator():
203-
classes.append(type("TestCase-" + mod.__name__,
204-
(AnyDBMTestCase, unittest.TestCase),
205-
{'module': mod}))
206-
suites = [unittest.makeSuite(c) for c in classes]
207-
208-
tests.addTests(suites)
209-
return tests
187+
188+
for mod in dbm_iterator():
189+
assert mod.__name__.startswith('dbm.')
190+
suffix = mod.__name__[4:]
191+
testname = f'TestCase_{suffix}'
192+
globals()[testname] = type(testname,
193+
(AnyDBMTestCase, unittest.TestCase),
194+
{'module': mod})
195+
210196

211197
if __name__ == "__main__":
212198
unittest.main()

0 commit comments

Comments
 (0)