|
28 | 28 |
|
29 | 29 | def func_returntext():
|
30 | 30 | return "foo"
|
| 31 | +def func_returntextwithnull(): |
| 32 | + return "1\x002" |
31 | 33 | def func_returnunicode():
|
32 | 34 | return "bar"
|
33 | 35 | def func_returnint():
|
@@ -138,11 +140,21 @@ def step(self, val):
|
138 | 140 | def finalize(self):
|
139 | 141 | return self.val
|
140 | 142 |
|
| 143 | +class AggrText: |
| 144 | + def __init__(self): |
| 145 | + self.txt = "" |
| 146 | + def step(self, txt): |
| 147 | + self.txt = self.txt + txt |
| 148 | + def finalize(self): |
| 149 | + return self.txt |
| 150 | + |
| 151 | + |
141 | 152 | class FunctionTests(unittest.TestCase):
|
142 | 153 | def setUp(self):
|
143 | 154 | self.con = sqlite.connect(":memory:")
|
144 | 155 |
|
145 | 156 | self.con.create_function("returntext", 0, func_returntext)
|
| 157 | + self.con.create_function("returntextwithnull", 0, func_returntextwithnull) |
146 | 158 | self.con.create_function("returnunicode", 0, func_returnunicode)
|
147 | 159 | self.con.create_function("returnint", 0, func_returnint)
|
148 | 160 | self.con.create_function("returnfloat", 0, func_returnfloat)
|
@@ -186,6 +198,12 @@ def test_func_return_text(self):
|
186 | 198 | self.assertEqual(type(val), str)
|
187 | 199 | self.assertEqual(val, "foo")
|
188 | 200 |
|
| 201 | + def test_func_return_text_with_null_char(self): |
| 202 | + cur = self.con.cursor() |
| 203 | + res = cur.execute("select returntextwithnull()").fetchone()[0] |
| 204 | + self.assertEqual(type(res), str) |
| 205 | + self.assertEqual(res, "1\x002") |
| 206 | + |
189 | 207 | def test_func_return_unicode(self):
|
190 | 208 | cur = self.con.cursor()
|
191 | 209 | cur.execute("select returnunicode()")
|
@@ -364,6 +382,7 @@ def setUp(self):
|
364 | 382 | self.con.create_aggregate("checkType", 2, AggrCheckType)
|
365 | 383 | self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
|
366 | 384 | self.con.create_aggregate("mysum", 1, AggrSum)
|
| 385 | + self.con.create_aggregate("aggtxt", 1, AggrText) |
367 | 386 |
|
368 | 387 | def tearDown(self):
|
369 | 388 | #self.cur.close()
|
@@ -457,6 +476,15 @@ def test_aggr_no_match(self):
|
457 | 476 | val = cur.fetchone()[0]
|
458 | 477 | self.assertIsNone(val)
|
459 | 478 |
|
| 479 | + def test_aggr_text(self): |
| 480 | + cur = self.con.cursor() |
| 481 | + for txt in ["foo", "1\x002"]: |
| 482 | + with self.subTest(txt=txt): |
| 483 | + cur.execute("select aggtxt(?) from test", (txt,)) |
| 484 | + val = cur.fetchone()[0] |
| 485 | + self.assertEqual(val, txt) |
| 486 | + |
| 487 | + |
460 | 488 | class AuthorizerTests(unittest.TestCase):
|
461 | 489 | @staticmethod
|
462 | 490 | def authorizer_cb(action, arg1, arg2, dbname, source):
|
|
0 commit comments