Skip to content

Commit c19c0e0

Browse files
authored
Supporting dataframe and update idempotent query example (#57)
* Update idempotent insert example * Update df(),pl(),arrow() and fetchnumpy()
1 parent 7efd9fe commit c19c0e0

File tree

8 files changed

+343
-4
lines changed

8 files changed

+343
-4
lines changed

README.rst

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ Insert Data
9292
9393
Pandas DataFrame
9494
----------------
95-
Big fan of Pandas? We too! You can mix SQL and Pandas API together:
95+
Big fan of Pandas? We too! You can mix SQL and Pandas API together. You can also convert query results to a variety of formats (e.g. NumPy arrays, Pandas DataFrame, Polars DataFrame, Arrow Table) through the DB-API cursor.
96+
9697

9798
.. code-block:: python
9899
@@ -128,3 +129,18 @@ Big fan of Pandas? We too! You can mix SQL and Pandas API together:
128129
df = c.query_dataframe('SELECT * FROM table(test)')
129130
print(df)
130131
print(df.describe())
132+
133+
# Converting query results to a variety of formats with dbapi
134+
with connect('proton://localhost') as conn:
135+
with conn.cursor() as cur:
136+
cur.execute('SELECT * FROM table(test)')
137+
print(cur.df()) # Pandas DataFrame
138+
139+
cur.execute('SELECT * FROM table(test)')
140+
print(cur.fetchnumpy()) # Numpy Arrays
141+
142+
cur.execute('SELECT * FROM table(test)')
143+
print(cur.pl()) # Polars DataFrame
144+
145+
cur.execute('SELECT * FROM table(test)')
146+
print(cur.arrow()) # Arrow Table

example/idempotent/idempotent.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from proton_driver import connect, Client
2+
from datetime import date
3+
from time import sleep
4+
5+
6+
# Create a test stream
7+
def create_test_stream(operator, table_name, table_columns):
    """Drop the stream if it exists, then create it again.

    :param operator: object exposing ``execute(sql)``; the examples in this
                     file pass either a DB-API cursor or a ``Client``.
    :param table_name: name of the stream to (re)create.
    :param table_columns: column definitions placed inside the parentheses
                          of the ``CREATE STREAM`` statement.
    """
    statements = (
        f'DROP STREAM IF EXISTS {table_name};',
        f'CREATE STREAM {table_name} ({table_columns})',
    )
    for sql in statements:
        operator.execute(sql)
10+
11+
12+
# Use dbapi to implement idempotent insertion
13+
def use_dbapi():
    """Show idempotent insertion through the DB-API interface.

    The same four-row batch is inserted ten times while the cursor carries
    a single ``idempotent_id``; the row count read back afterwards is
    printed and should reflect only the first insert.
    """
    insert_sql = 'INSERT INTO test_user (id, name, birthday) VALUES'
    data = [
        (123456, 'timeplus', date(2024, 10, 24)),
        (789012, 'stream ', date(2023, 10, 24)),
        (135790, 'proton ', date(2024, 10, 24)),
        (246801, 'database', date(2024, 10, 24)),
    ]

    with connect('proton://localhost') as conn:
        with conn.cursor() as cur:
            create_test_stream(
                cur,
                'test_user',
                'id int32, name string, birthday date'
            )
            # Every insert below is sent with this idempotent_id, so the
            # query result should contain only the first inserted batch.
            cur.set_settings(dict(idempotent_id='batch1'))

            for _ in range(10):
                cur.execute(insert_sql, data)
                cur.fetchall()

            # Wait 3 seconds so the data is available in the historical
            # store before it is queried with table().
            sleep(3)
            cur.execute('SELECT count() FROM table(test_user)')
            # Data is inserted only once, so this prints [(4,)].
            print(cur.fetchall())
44+
45+
46+
# Use Client to implement idempotent insertion
47+
def use_client():
    """Show idempotent insertion through the native ``Client``.

    The same eight-row batch is inserted ten times with one
    ``idempotent_id`` passed as a per-query setting; the row count read
    back afterwards is printed and should reflect only the first insert.
    The client connection is always released before returning.
    """
    cli = Client('localhost', 8463)
    try:
        create_test_stream(cli, 'test_stream', '`i` int, `v` string')
        setting = {
            'idempotent_id': 'batch1'
        }
        data = [
            (1, 'a'), (2, 'b'), (3, 'c'), (4, 'd'),
            (5, 'e'), (6, 'f'), (7, 'g'), (8, 'h')
        ]
        # Execute multiple insert operations with the same idempotent_id.
        for _ in range(10):
            cli.execute(
                'INSERT INTO test_stream (i, v) VALUES',
                data,
                settings=setting
            )
        # Wait 3 seconds so the data is available in the historical store.
        sleep(3)
        res = cli.execute('SELECT count() FROM table(test_stream)')
        # Data is inserted only once, so the printed count should be 8.
        print(res)
    finally:
        # Release the connection even if one of the statements above fails.
        cli.disconnect()
69+
70+
71+
if __name__ == "__main__":
    # Each helper prints the row count it reads back after ten inserts.
    use_dbapi()   # expected count: 4
    use_client()  # expected count: 8

example/pandas/dataframe.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pandas as pd
22
import time
33

4-
from proton_driver import client
4+
from proton_driver import client, connect
55

66
if __name__ == "__main__":
77
c = client.Client(host='127.0.0.1', port=8463)
@@ -37,3 +37,22 @@
3737
df = c.query_dataframe('SELECT * FROM table(test)')
3838
print(df)
3939
print(df.describe())
40+
41+
# Converting query results to a variety of formats with dbapi
42+
with connect('proton://localhost') as conn:
43+
with conn.cursor() as cur:
44+
cur.execute('SELECT * FROM table(test)')
45+
print('--------------Pandas DataFrame--------------')
46+
print(cur.df())
47+
48+
cur.execute('SELECT * FROM table(test)')
49+
print('----------------Numpy Arrays----------------')
50+
print(cur.fetchnumpy())
51+
52+
cur.execute('SELECT * FROM table(test)')
53+
print('--------------Polars DataFrame--------------')
54+
print(cur.pl())
55+
56+
cur.execute('SELECT * FROM table(test)')
57+
print('-----------------Arrow Table----------------')
58+
print(cur.arrow())

proton_driver/dbapi/cursor.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,78 @@ def fetchall(self):
208208
self._rows = []
209209
return rv
210210

211+
def df(self):
    """
    Fetch all (remaining) rows of a query result, returning them as
    a pandas DataFrame.

    The buffered rows are consumed: ``self._rows`` is reset to an empty
    list before returning.

    :return: Pandas DataFrame of fetched rows, with one column per name
             in ``self._columns``.
    """
    self._check_query_started()

    import pandas as pd

    # Materialise the rows once. Walking ``self._rows`` again for every
    # column would leave every column after the first empty whenever the
    # rows are a one-shot iterator instead of a list.
    rows = list(self._rows)

    rv = pd.DataFrame({
        name: [row[i] for row in rows] if name else None
        for i, name in enumerate(self._columns)
    })
    self._rows = []
    return rv
228+
229+
def fetchnumpy(self):
    """
    Fetch all (remaining) rows of a query result, returning
    them as a dictionary of NumPy arrays.

    The buffered rows are consumed: ``self._rows`` is reset to an empty
    list before returning.

    :return: Dictionary mapping each name in ``self._columns`` to a NumPy
             array holding that column's values.
    """
    self._check_query_started()

    import numpy as np

    # Materialise the rows once. Walking ``self._rows`` again for every
    # column would leave every column after the first empty whenever the
    # rows are a one-shot iterator instead of a list.
    rows = list(self._rows)

    rv = {
        name: np.array([row[i] for row in rows]) if name else None
        for i, name in enumerate(self._columns)
    }
    self._rows = []
    return rv
246+
247+
def pl(self):
    """
    Fetch all (remaining) rows of a query result, returning them as
    a Polars DataFrame.

    The buffered rows are consumed: ``self._rows`` is reset to an empty
    list before returning.

    :return: Polars DataFrame of fetched rows, with one column per name
             in ``self._columns``.
    """
    self._check_query_started()

    import polars as pl

    # Materialise the rows once. Walking ``self._rows`` again for every
    # column would leave every column after the first empty whenever the
    # rows are a one-shot iterator instead of a list.
    rows = list(self._rows)

    rv = pl.DataFrame({
        name: [row[i] for row in rows] if name else None
        for i, name in enumerate(self._columns)
    })
    self._rows = []
    return rv
264+
265+
def arrow(self):
    """
    Fetch all (remaining) rows of a query result, returning them as
    an Arrow Table.

    The buffered rows are consumed: ``self._rows`` is reset to an empty
    list before returning.

    :return: Arrow Table of fetched rows, with one column per name in
             ``self._columns``.
    """
    self._check_query_started()

    import pyarrow as pa

    # Materialise the rows once. Walking ``self._rows`` again for every
    # column would leave every column after the first empty whenever the
    # rows are a one-shot iterator instead of a list.
    rows = list(self._rows)

    rv = pa.table({
        name: [row[i] for row in rows] if name else None
        for i, name in enumerate(self._columns)
    })
    self._rows = []
    return rv
282+
211283
def setinputsizes(self, sizes):
212284
# Do nothing.
213285
pass

proton_driver/settings/available.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,4 +402,7 @@
402402
'format_regexp_escaping_rule': SettingString,
403403
'format_regexp_skip_unmatched': SettingBool,
404404
'output_format_enable_streaming': SettingBool,
405+
406+
'idempotent_id': SettingString,
407+
'enable_idempotent_processing': SettingBool,
405408
}

tests/numpy/test_generic.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88

99
from tests.testcase import BaseTestCase
1010
from tests.numpy.testcase import NumpyBaseTestCase
11+
from proton_driver import connect
12+
from datetime import datetime
13+
from decimal import Decimal
1114

1215

1316
class GenericTestCase(NumpyBaseTestCase):
@@ -171,3 +174,110 @@ def test_query_dataframe(self):
171174
self.assertEqual(
172175
'Extras for NumPy must be installed', str(e.exception)
173176
)
177+
178+
179+
class DataFrameDBAPITestCase(NumpyBaseTestCase):
    """Tests for the DB-API cursor converters ``fetchnumpy``, ``df``,
    ``pl`` and ``arrow``.

    ``setUp`` loads ``data`` into a fresh ``test`` stream and leaves a
    ``SELECT`` executed on ``self.cur``, so each test only has to call the
    converter under test and compare the result with ``data``.
    """

    # Column definitions used to create the ``test`` stream.
    types = \
        'a int64, b string, c datetime,' \
        'd fixed_string(10), e decimal(9, 5), f float64,' \
        'g low_cardinality(string), h nullable(int32)'

    # Comma-separated column names, in the same order as ``types``.
    columns = 'a,b,c,d,e,f,g,h'
    # Rows inserted in ``setUp``; also the expected content of every result.
    data = [
        [
            123, 'abc', datetime(2024, 5, 20, 12, 11, 10),
            'abcefgcxxx', Decimal('300.42'), 3.402823e12,
            '127001', 332
        ],
        [
            456, 'cde', datetime(2024, 6, 21, 12, 13, 50),
            '1234567890', Decimal('171.31'), -3.4028235e13,
            '127001', None
        ],
        [
            789, 'efg', datetime(1998, 7, 22, 12, 30, 10),
            'stream sql', Decimal('894.22'), float('inf'),
            '127001', None
        ],
    ]

    def setUp(self):
        """Create and fill the ``test`` stream, then run the ``SELECT``
        whose rows the individual tests convert."""
        super(DataFrameDBAPITestCase, self).setUp()
        self.conn = connect('proton://localhost')
        self.cur = self.conn.cursor()
        self.cur.execute('DROP STREAM IF EXISTS test')
        self.cur.execute(f'CREATE STREAM test ({self.types}) ENGINE = Memory')
        self.cur.execute(
            f'INSERT INTO test ({self.columns}) VALUES',
            self.data
        )
        self.cur.execute(f'SELECT {self.columns} FROM test')

    def tearDown(self):
        """Drop the ``test`` stream and release the cursor and connection
        opened in ``setUp`` before running the base-class teardown."""
        try:
            self.cur.execute('DROP STREAM test')
        finally:
            # Close our own DB-API resources even if the DROP failed, and
            # always let the base class clean up afterwards.
            try:
                self.cur.close()
                self.conn.close()
            finally:
                super(DataFrameDBAPITestCase, self).tearDown()

    def test_dbapi_fetchnumpy(self):
        """``fetchnumpy`` returns one array per column, equal to ``data``."""
        expect = {
            col: np.array([row[i] for row in self.data])
            for i, col in enumerate(self.columns.split(','))
        }
        rv = self.cur.fetchnumpy()
        for key, value in expect.items():
            self.assertIsNotNone(rv.get(key))
            self.assertarraysEqual(value, rv[key])

    def test_dbapi_df(self):
        """``df`` returns a pandas DataFrame with the expected shape,
        dtypes and content."""
        expect = pd.DataFrame(self.data, columns=self.columns.split(','))
        df = self.cur.df()

        self.assertIsInstance(df, pd.DataFrame)
        self.assertEqual(df.shape, (3, 8))
        self.assertEqual(
            [dtype.name for dtype in df.dtypes],
            ['int64', 'object', 'datetime64[ns]',
             'object', 'object', 'float64',
             'object', 'float64']
        )
        self.assertTrue(expect.equals(df))

    def test_dbapi_pl(self):
        """``pl`` returns a Polars DataFrame with the expected shape,
        schema and content; skipped when Polars is not installed."""
        try:
            import polars as pl
        except ImportError:
            self.skipTest('Polars extras are not installed')

        expect = pl.DataFrame({
            col: [row[i] for row in self.data]
            for i, col in enumerate(self.columns.split(','))
        })

        df = self.cur.pl()
        self.assertIsInstance(df, pl.DataFrame)
        self.assertEqual(df.shape, (3, 8))
        self.assertSequenceEqual(
            df.schema.dtypes(),
            [pl.Int64, pl.String, pl.Datetime, pl.String,
             pl.Decimal, pl.Float64, pl.String, pl.Int64]
        )
        self.assertTrue(expect.equals(df))

    def test_dbapi_arrow(self):
        """``arrow`` returns an Arrow Table with the expected shape,
        schema and content; skipped when PyArrow is not installed."""
        try:
            import pyarrow as pa
        except ImportError:
            self.skipTest('Pyarrow extras are not installed')

        expect = pa.table({
            col: [row[i] for row in self.data]
            for i, col in enumerate(self.columns.split(','))
        })
        at = self.cur.arrow()
        self.assertEqual(at.shape, (3, 8))
        self.assertSequenceEqual(
            at.schema.types,
            [pa.int64(), pa.string(), pa.timestamp('us'),
             pa.string(), pa.decimal128(5, 2), pa.float64(),
             pa.string(), pa.int64()]
        )
        self.assertTrue(expect.equals(at))

tests/test_dbapi.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from contextlib import contextmanager
44
import socket
55
from unittest.mock import patch
6-
6+
from time import sleep
77
from proton_driver import connect
88
from proton_driver.dbapi import (
99
ProgrammingError, InterfaceError, OperationalError
@@ -159,6 +159,27 @@ def test_execute_insert(self):
159159
cursor.execute('INSERT INTO test VALUES', [[4]])
160160
self.assertEqual(cursor.rowcount, 1)
161161

162+
def test_idempotent_insert(self):
    """Inserting the same batch ten times under one ``idempotent_id``
    must leave a single copy of the batch in the stream."""
    with self.created_cursor() as cursor:
        cursor.execute('CREATE STREAM test (i int, v string)')
        try:
            data = [
                (123, 'abc'), (456, 'def'), (789, 'ghi'),
                (987, 'ihg'), (654, 'fed'), (321, 'cba'),
            ]
            cursor.set_settings(dict(idempotent_id='batch1'))
            for _ in range(10):
                cursor.execute(
                    'INSERT INTO test (i, v) VALUES',
                    data
                )
                self.assertEqual(cursor.rowcount, 6)
            # Give the inserted rows time to become visible to table().
            sleep(3)
            cursor.execute('SELECT count(*) FROM table(test)')
            self.assertEqual(cursor.fetchall(), [(6,)])
        finally:
            # Drop the stream even when an assertion fails, so later
            # tests that create ``test`` do not find it already there.
            cursor.execute('DROP STREAM test')
182+
162183
def test_description(self):
163184
with self.created_cursor() as cursor:
164185
self.assertIsNone(cursor.description)

0 commit comments

Comments
 (0)