@@ -866,26 +866,47 @@ cdef class Int64Factorizer:
866
866
self .count = len (self .uniques)
867
867
return labels
868
868
869
# Fused type over the khash table structs for the two supported 64-bit key
# types; build_count_table_scalar64 below is specialised over it.
ctypedef fused kh_scalar64:
    kh_int64_t
    kh_float64_t


@cython.boundscheck(False)
cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values,
                                kh_scalar64 *table, bint dropna):
    """
    Tally how often each element of ``values`` occurs, writing the tallies
    into ``table`` in place.

    Parameters
    ----------
    values : sixty_four_bit_scalar[:]
        Values to count (float64_t or int64_t specialisation).
    table : kh_scalar64 *
        Hash table that receives the counts. Its struct type has to pair
        with the element type of ``values`` (float64_t with kh_float64_t,
        int64_t with kh_int64_t).
    dropna : bint
        Only consulted in the float64 specialisation: when true, elements
        that do not compare equal to themselves (NaN) are not counted.
        The int64 specialisation does not read it.

    Raises
    ------
    ValueError
        If the table type does not pair with the scalar type.
    """
    cdef:
        khiter_t slot
        Py_ssize_t idx, count = len(values)
        sixty_four_bit_scalar key
        int put_status = 0

    if sixty_four_bit_scalar is float64_t and kh_scalar64 is kh_float64_t:
        with nogil:
            # Resize once up front, using the number of input values.
            kh_resize_float64(table, count)

            for idx in range(count):
                key = values[idx]
                # key != key holds only for NaN.
                if dropna and key != key:
                    continue

                slot = kh_get_float64(table, key)
                if slot == table.n_buckets:
                    # A lookup result equal to n_buckets is treated as
                    # "key not present": insert it with a count of one.
                    slot = kh_put_float64(table, key, &put_status)
                    table.vals[slot] = 1
                else:
                    table.vals[slot] += 1
    elif sixty_four_bit_scalar is int64_t and kh_scalar64 is kh_int64_t:
        with nogil:
            # Resize once up front, using the number of input values.
            kh_resize_int64(table, count)

            for idx in range(count):
                key = values[idx]

                slot = kh_get_int64(table, key)
                if slot == table.n_buckets:
                    # A lookup result equal to n_buckets is treated as
                    # "key not present": insert it with a count of one.
                    slot = kh_put_int64(table, key, &put_status)
                    table.vals[slot] = 1
                else:
                    table.vals[slot] += 1
    else:
        raise ValueError(" Table type must match scalar type.")
909
+
889
910
890
911
891
912
@ cython.boundscheck (False )
@@ -902,7 +923,7 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
902
923
903
924
if sixty_four_bit_scalar is float64_t:
904
925
ftable = kh_init_float64()
905
- build_count_table_float64 (values, ftable, dropna)
926
+ build_count_table_scalar64 (values, ftable, dropna)
906
927
907
928
result_keys = np.empty(ftable.n_occupied, dtype = np.float64)
908
929
result_counts = np.zeros(ftable.n_occupied, dtype = np.int64)
@@ -917,7 +938,7 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
917
938
918
939
elif sixty_four_bit_scalar is int64_t:
919
940
itable = kh_init_int64()
920
- build_count_table_int64 (values, itable)
941
+ build_count_table_scalar64 (values, itable, dropna )
921
942
922
943
result_keys = np.empty(itable.n_occupied, dtype = np.int64)
923
944
result_counts = np.zeros(itable.n_occupied, dtype = np.int64)
@@ -932,26 +953,6 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
932
953
933
954
return np.asarray(result_keys), np.asarray(result_counts)
934
955
935
- @ cython.boundscheck (False )
936
- cdef build_count_table_int64(int64_t[:] values, kh_int64_t * table):
937
- cdef:
938
- khiter_t k
939
- Py_ssize_t i, n = len (values)
940
- int64_t val
941
- int ret = 0
942
-
943
- with nogil:
944
- kh_resize_int64(table, n)
945
-
946
- for i in range (n):
947
- val = values[i]
948
- k = kh_get_int64(table, val)
949
- if k != table.n_buckets:
950
- table.vals[k] += 1
951
- else :
952
- k = kh_put_int64(table, val, & ret)
953
- table.vals[k] = 1
954
-
955
956
956
957
cdef build_count_table_object(ndarray[object ] values,
957
958
ndarray[uint8_t, cast= True ] mask,
@@ -1040,7 +1041,7 @@ def mode_int64(int64_t[:] values):
1040
1041
1041
1042
table = kh_init_int64()
1042
1043
1043
- build_count_table_int64 (values, table)
1044
+ build_count_table_scalar64 (values, table, 0 )
1044
1045
1045
1046
modes = np.empty(table.n_buckets, dtype = np.int64)
1046
1047
0 commit comments