Skip to content

Commit 8bb17cb

Browse files
CLN: Combined build_count_table_int64 and build_count_table_float64 into a
single function using fused types.
1 parent cf002dc commit 8bb17cb

File tree

1 file changed: +33 -32 lines changed

Diff for: pandas/hashtable.pyx

+33 -32
Original file line number | Diff line number | Diff line change
@@ -866,26 +866,47 @@ cdef class Int64Factorizer:
866866
self.count = len(self.uniques)
867867
return labels
868868

869+
ctypedef fused kh_scalar64:
870+
kh_int64_t
871+
kh_float64_t
872+
869873
@cython.boundscheck(False)
870-
cdef build_count_table_float64(float64_t[:] values, kh_float64_t *table, bint dropna):
874+
cdef build_count_table_scalar64(sixty_four_bit_scalar[:] values,
875+
kh_scalar64 *table, bint dropna):
871876
cdef:
872877
khiter_t k
873878
Py_ssize_t i, n = len(values)
874-
float64_t val
879+
sixty_four_bit_scalar val
875880
int ret = 0
876881

877-
with nogil:
878-
kh_resize_float64(table, n)
882+
if sixty_four_bit_scalar is float64_t and kh_scalar64 is kh_float64_t:
883+
with nogil:
884+
kh_resize_float64(table, n)
879885

880-
for i in range(n):
881-
val = values[i]
882-
if val == val or not dropna:
883-
k = kh_get_float64(table, val)
886+
for i in range(n):
887+
val = values[i]
888+
if val == val or not dropna:
889+
k = kh_get_float64(table, val)
890+
if k != table.n_buckets:
891+
table.vals[k] += 1
892+
else:
893+
k = kh_put_float64(table, val, &ret)
894+
table.vals[k] = 1
895+
elif sixty_four_bit_scalar is int64_t and kh_scalar64 is kh_int64_t:
896+
with nogil:
897+
kh_resize_int64(table, n)
898+
899+
for i in range(n):
900+
val = values[i]
901+
k = kh_get_int64(table, val)
884902
if k != table.n_buckets:
885903
table.vals[k] += 1
886904
else:
887-
k = kh_put_float64(table, val, &ret)
905+
k = kh_put_int64(table, val, &ret)
888906
table.vals[k] = 1
907+
else:
908+
raise ValueError("Table type must match scalar type.")
909+
889910

890911

891912
@cython.boundscheck(False)
@@ -902,7 +923,7 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
902923

903924
if sixty_four_bit_scalar is float64_t:
904925
ftable = kh_init_float64()
905-
build_count_table_float64(values, ftable, dropna)
926+
build_count_table_scalar64(values, ftable, dropna)
906927

907928
result_keys = np.empty(ftable.n_occupied, dtype=np.float64)
908929
result_counts = np.zeros(ftable.n_occupied, dtype=np.int64)
@@ -917,7 +938,7 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
917938

918939
elif sixty_four_bit_scalar is int64_t:
919940
itable = kh_init_int64()
920-
build_count_table_int64(values, itable)
941+
build_count_table_scalar64(values, itable, dropna)
921942

922943
result_keys = np.empty(itable.n_occupied, dtype=np.int64)
923944
result_counts = np.zeros(itable.n_occupied, dtype=np.int64)
@@ -932,26 +953,6 @@ cpdef value_count_scalar64(sixty_four_bit_scalar[:] values, bint dropna):
932953

933954
return np.asarray(result_keys), np.asarray(result_counts)
934955

935-
@cython.boundscheck(False)
936-
cdef build_count_table_int64(int64_t[:] values, kh_int64_t *table):
937-
cdef:
938-
khiter_t k
939-
Py_ssize_t i, n = len(values)
940-
int64_t val
941-
int ret = 0
942-
943-
with nogil:
944-
kh_resize_int64(table, n)
945-
946-
for i in range(n):
947-
val = values[i]
948-
k = kh_get_int64(table, val)
949-
if k != table.n_buckets:
950-
table.vals[k] += 1
951-
else:
952-
k = kh_put_int64(table, val, &ret)
953-
table.vals[k] = 1
954-
955956

956957
cdef build_count_table_object(ndarray[object] values,
957958
ndarray[uint8_t, cast=True] mask,
@@ -1040,7 +1041,7 @@ def mode_int64(int64_t[:] values):
10401041

10411042
table = kh_init_int64()
10421043

1043-
build_count_table_int64(values, table)
1044+
build_count_table_scalar64(values, table, 0)
10441045

10451046
modes = np.empty(table.n_buckets, dtype=np.int64)
10461047

0 commit comments

Comments (0)