16
16
# under the License.
17
17
18
18
import datafusion
19
- import pyarrow
19
+ import pyarrow as pa
20
20
import pyarrow .compute
21
21
from datafusion import Accumulator , col , udaf
22
22
@@ -26,48 +26,44 @@ class MyAccumulator(Accumulator):
26
26
Interface of a user-defined accumulation.
27
27
"""
28
28
29
def __init__(self) -> None:
    """Start the accumulator with a running sum of 0.0 held as a pyarrow scalar."""
    initial_total = 0.0
    self._sum = pa.scalar(initial_total)
31
31
32
def update(self, values: pa.Array) -> None:
    """Add the sum of ``values`` to the running total.

    Args:
        values: Batch of input values to fold into the total.

    A batch whose sum is null (``as_py()`` yields ``None``, e.g. an empty
    or all-null batch) leaves the total unchanged instead of raising
    ``TypeError`` on ``float + None``.
    """
    # pyarrow scalars can't be added directly, so go through Python values.
    batch_total = pa.compute.sum(values).as_py()
    if batch_total is None:
        # Nothing summable in this batch; keep the current total.
        return
    self._sum = pa.scalar(self._sum.as_py() + batch_total)
37
35
38
def merge(self, states: pa.Array) -> None:
    """Fold partial sums from other accumulators into the running total.

    Args:
        states: Array of partial sums to combine with this accumulator.

    If the sum of ``states`` is null (``as_py()`` yields ``None``, e.g. an
    empty or all-null array) the total is left unchanged instead of
    raising ``TypeError`` on ``float + None``.
    """
    # pyarrow scalars can't be added directly, so go through Python values.
    partial_total = pa.compute.sum(states).as_py()
    if partial_total is None:
        # No usable partial state; keep the current total.
        return
    self._sum = pa.scalar(self._sum.as_py() + partial_total)
43
39
44
def state(self) -> pa.Array:
    """Return the running sum wrapped in a one-element pyarrow array."""
    current_total = self._sum.as_py()
    return pa.array([current_total])
46
42
47
def evaluate(self) -> pa.Scalar:
    """Return the accumulated sum (a pyarrow scalar) as the aggregate result."""
    return self._sum
49
45
50
46
51
47
# Session context used to build the DataFrame and run the aggregation.
ctx = datafusion.SessionContext()

# One RecordBatch with two integer columns, "a" and "b", wrapped in a DataFrame.
batch = pa.RecordBatch.from_arrays(
    [
        pa.array([1, 2, 3]),
        pa.array([4, 5, 6]),
    ],
    names=["a", "b"],
)
df = ctx.create_dataframe([[batch]])

# Wrap MyAccumulator as an aggregate function.
# NOTE(review): the positional arguments are presumably input type, return
# type, state types and volatility -- confirm against the datafusion `udaf`
# signature for the installed version.
my_udaf = udaf(
    MyAccumulator,
    pa.float64(),
    pa.float64(),
    [pa.float64()],
    "stable",
)

# Aggregate column "a" over the whole frame (no grouping expressions).
df = df.aggregate([], [my_udaf(col("a"))])

# The first collected batch holds the single aggregated value: 1 + 2 + 3.
result = df.collect()[0]

assert result.column(0) == pa.array([6.0])
0 commit comments