@@ -866,96 +866,90 @@ cdef class Int64Factorizer:
866
866
self .count = len (self .uniques)
867
867
return labels
868
868
869
+ ctypedef fused kh_scalar64:
870
+ kh_int64_t
871
+ kh_float64_t
872
+
869
873
@ cython.boundscheck (False )
870
- cdef build_count_table_float64(float64_t[:] values, kh_float64_t * table, bint dropna):
874
+ cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values,
875
+ kh_scalar64 * table, bint dropna):
871
876
cdef:
872
877
khiter_t k
873
878
Py_ssize_t i, n = len (values)
874
- float64_t val
879
+ sixty_four_bit_scalar val
875
880
int ret = 0
876
881
877
- with nogil:
878
- kh_resize_float64(table, n)
882
+ if sixty_four_bit_scalar is float64_t and kh_scalar64 is kh_float64_t:
883
+ with nogil:
884
+ kh_resize_float64(table, n)
879
885
880
- for i in range (n):
881
- val = values[i]
882
- if val == val or not dropna:
883
- k = kh_get_float64(table, val)
886
+ for i in range (n):
887
+ val = values[i]
888
+ if val == val or not dropna:
889
+ k = kh_get_float64(table, val)
890
+ if k != table.n_buckets:
891
+ table.vals[k] += 1
892
+ else :
893
+ k = kh_put_float64(table, val, & ret)
894
+ table.vals[k] = 1
895
+ elif sixty_four_bit_scalar is int64_t and kh_scalar64 is kh_int64_t:
896
+ with nogil:
897
+ kh_resize_int64(table, n)
898
+
899
+ for i in range (n):
900
+ val = values[i]
901
+ k = kh_get_int64(table, val)
884
902
if k != table.n_buckets:
885
903
table.vals[k] += 1
886
904
else :
887
- k = kh_put_float64 (table, val, & ret)
905
+ k = kh_put_int64 (table, val, & ret)
888
906
table.vals[k] = 1
907
+ else :
908
+ raise ValueError (" Table type must match scalar type." )
909
+
910
+
889
911
890
912
@ cython.boundscheck (False )
891
- cpdef value_count_float64(float64_t [:] values, bint dropna):
913
+ cpdef value_count_scalar64(sixty_four_bit_scalar [:] values, bint dropna):
892
914
cdef:
893
915
Py_ssize_t i
894
- kh_float64_t * table
895
- float64_t[:] result_keys
916
+ kh_float64_t * ftable
917
+ kh_int64_t * itable
918
+ sixty_four_bit_scalar[:] result_keys
896
919
int64_t[:] result_counts
897
920
int k
898
921
899
- table = kh_init_float64()
900
- build_count_table_float64(values, table, dropna)
901
-
902
922
i = 0
903
- result_keys = np.empty(table.n_occupied, dtype = np.float64)
904
- result_counts = np.zeros(table.n_occupied, dtype = np.int64)
905
923
906
- with nogil:
907
- for k in range (table.n_buckets):
908
- if kh_exist_float64(table, k):
909
- result_keys[i] = table.keys[k]
910
- result_counts[i] = table.vals[k]
911
- i += 1
912
- kh_destroy_float64(table)
924
+ if sixty_four_bit_scalar is float64_t:
925
+ ftable = kh_init_float64()
926
+ build_count_table_scalar64(values, ftable, dropna)
913
927
914
- return np.asarray(result_keys), np.asarray(result_counts)
928
+ result_keys = np.empty(ftable.n_occupied, dtype = np.float64)
929
+ result_counts = np.zeros(ftable.n_occupied, dtype = np.int64)
915
930
916
- @ cython.boundscheck (False )
917
- cdef build_count_table_int64(int64_t[:] values, kh_int64_t * table):
918
- cdef:
919
- khiter_t k
920
- Py_ssize_t i, n = len (values)
921
- int64_t val
922
- int ret = 0
923
-
924
- with nogil:
925
- kh_resize_int64(table, n)
926
-
927
- for i in range (n):
928
- val = values[i]
929
- k = kh_get_int64(table, val)
930
- if k != table.n_buckets:
931
- table.vals[k] += 1
932
- else :
933
- k = kh_put_int64(table, val, & ret)
934
- table.vals[k] = 1
935
-
936
-
937
- @ cython.boundscheck (False )
938
- cpdef value_count_int64(int64_t[:] values):
939
- cdef:
940
- Py_ssize_t i
941
- kh_int64_t * table
942
- int64_t[:] result_keys, result_counts
943
- int k
931
+ with nogil:
932
+ for k in range (ftable.n_buckets):
933
+ if kh_exist_float64(ftable, k):
934
+ result_keys[i] = ftable.keys[k]
935
+ result_counts[i] = ftable.vals[k]
936
+ i += 1
937
+ kh_destroy_float64(ftable)
944
938
945
- table = kh_init_int64()
946
- build_count_table_int64(values, table)
939
+ elif sixty_four_bit_scalar is int64_t:
940
+ itable = kh_init_int64()
941
+ build_count_table_scalar64(values, itable, dropna)
947
942
948
- i = 0
949
- result_keys = np.empty(table.n_occupied, dtype = np.int64)
950
- result_counts = np.zeros(table.n_occupied, dtype = np.int64)
943
+ result_keys = np.empty(itable.n_occupied, dtype = np.int64)
944
+ result_counts = np.zeros(itable.n_occupied, dtype = np.int64)
951
945
952
- with nogil:
953
- for k in range (table .n_buckets):
954
- if kh_exist_int64(table , k):
955
- result_keys[i] = table .keys[k]
956
- result_counts[i] = table .vals[k]
957
- i += 1
958
- kh_destroy_int64(table )
946
+ with nogil:
947
+ for k in range (itable .n_buckets):
948
+ if kh_exist_int64(itable , k):
949
+ result_keys[i] = itable .keys[k]
950
+ result_counts[i] = itable .vals[k]
951
+ i += 1
952
+ kh_destroy_int64(itable )
959
953
960
954
return np.asarray(result_keys), np.asarray(result_counts)
961
955
@@ -1047,7 +1041,7 @@ def mode_int64(int64_t[:] values):
1047
1041
1048
1042
table = kh_init_int64()
1049
1043
1050
- build_count_table_int64 (values, table)
1044
+ build_count_table_scalar64 (values, table, 0 )
1051
1045
1052
1046
modes = np.empty(table.n_buckets, dtype = np.int64)
1053
1047
0 commit comments