Skip to content

Commit 2ac989b

Browse files
committed
support types with comma inside map
1 parent 2e6331e commit 2ac989b

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

proton_driver/columns/mapcolumn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import re
12
from .base import Column
23
from .intcolumn import UInt64Column
34
from ..util.helpers import pairwise
45

6+
comma_re = re.compile(r',(?![^()]*\))')
7+
58

69
class MapColumn(Column):
710
py_types = (dict, )
@@ -51,7 +54,9 @@ def write_items(self, items, buf):
5154

5255

5356
def create_map_column(spec, column_by_spec_getter):
54-
key, value = spec[4:-1].split(',')
57+
# Match commas outside of parentheses, so we don't match the comma in
58+
# Decimal types.
59+
key, value = comma_re.split(spec[4:-1])
5560
key_column = column_by_spec_getter(key.strip())
5661
value_column = column_by_spec_getter(value.strip())
5762

tests/columns/test_map.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from tests.testcase import BaseTestCase
2+
from decimal import Decimal
23

34

45
class MapTestCase(BaseTestCase):
@@ -22,7 +23,7 @@ def _sorted_dicts(self, text):
2223
return '\n'.join(items) + '\n'
2324

2425
def test_simple(self):
25-
with self.create_stream('a map(string, uint64)'):
26+
with self.create_stream('a map(string, int)'):
2627
data = [
2728
({},),
2829
({'key1': 1}, ),
@@ -100,3 +101,22 @@ def test_array(self):
100101
)
101102
inserted = self.client.execute(query)
102103
self.assertEqual(inserted, data)
104+
105+
def test_decimal(self):
106+
with self.create_stream('a map(string, Decimal(9, 2))'):
107+
data = [
108+
({'key1': Decimal('123.45')}, ),
109+
({'key2': Decimal('234.56')}, ),
110+
({'key3': Decimal('345.67')}, )
111+
]
112+
self.client.execute('INSERT INTO test (a) VALUES', data)
113+
query = 'SELECT * FROM test'
114+
inserted = self.emit_cli(query)
115+
self.assertEqual(
116+
inserted,
117+
"{'key1':123.45}\n"
118+
"{'key2':234.56}\n"
119+
"{'key3':345.67}\n"
120+
)
121+
inserted = self.client.execute(query)
122+
self.assertEqual(inserted, data)

0 commit comments

Comments
 (0)