|
27 | 27 |
|
28 | 28 | def func_returntext():
|
29 | 29 | return "foo"
|
| 30 | +def func_returntextwithnull(): |
| 31 | + return "1\x002" |
30 | 32 | def func_returnunicode():
|
31 | 33 | return "bar"
|
32 | 34 | def func_returnint():
|
@@ -137,11 +139,21 @@ def step(self, val):
|
137 | 139 | def finalize(self):
|
138 | 140 | return self.val
|
139 | 141 |
|
| 142 | +class AggrText: |
| 143 | + def __init__(self): |
| 144 | + self.txt = "" |
| 145 | + def step(self, txt): |
| 146 | + self.txt = self.txt + txt |
| 147 | + def finalize(self): |
| 148 | + return self.txt |
| 149 | + |
| 150 | + |
140 | 151 | class FunctionTests(unittest.TestCase):
|
141 | 152 | def setUp(self):
|
142 | 153 | self.con = sqlite.connect(":memory:")
|
143 | 154 |
|
144 | 155 | self.con.create_function("returntext", 0, func_returntext)
|
| 156 | + self.con.create_function("returntextwithnull", 0, func_returntextwithnull) |
145 | 157 | self.con.create_function("returnunicode", 0, func_returnunicode)
|
146 | 158 | self.con.create_function("returnint", 0, func_returnint)
|
147 | 159 | self.con.create_function("returnfloat", 0, func_returnfloat)
|
@@ -185,6 +197,12 @@ def CheckFuncReturnText(self):
|
185 | 197 | self.assertEqual(type(val), str)
|
186 | 198 | self.assertEqual(val, "foo")
|
187 | 199 |
|
| 200 | + def CheckFuncReturnTextWithNullChar(self): |
| 201 | + cur = self.con.cursor() |
| 202 | + res = cur.execute("select returntextwithnull()").fetchone()[0] |
| 203 | + self.assertEqual(type(res), str) |
| 204 | + self.assertEqual(res, "1\x002") |
| 205 | + |
188 | 206 | def CheckFuncReturnUnicode(self):
|
189 | 207 | cur = self.con.cursor()
|
190 | 208 | cur.execute("select returnunicode()")
|
@@ -343,6 +361,7 @@ def setUp(self):
|
343 | 361 | self.con.create_aggregate("checkType", 2, AggrCheckType)
|
344 | 362 | self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
|
345 | 363 | self.con.create_aggregate("mysum", 1, AggrSum)
|
| 364 | + self.con.create_aggregate("aggtxt", 1, AggrText) |
346 | 365 |
|
347 | 366 | def tearDown(self):
|
348 | 367 | #self.cur.close()
|
@@ -431,6 +450,15 @@ def CheckAggrCheckAggrSum(self):
|
431 | 450 | val = cur.fetchone()[0]
|
432 | 451 | self.assertEqual(val, 60)
|
433 | 452 |
|
| 453 | + def CheckAggrText(self): |
| 454 | + cur = self.con.cursor() |
| 455 | + for txt in ["foo", "1\x002"]: |
| 456 | + with self.subTest(txt=txt): |
| 457 | + cur.execute("select aggtxt(?) from test", (txt,)) |
| 458 | + val = cur.fetchone()[0] |
| 459 | + self.assertEqual(val, txt) |
| 460 | + |
| 461 | + |
434 | 462 | class AuthorizerTests(unittest.TestCase):
|
435 | 463 | @staticmethod
|
436 | 464 | def authorizer_cb(action, arg1, arg2, dbname, source):
|
|
0 commit comments