@@ -64,7 +64,8 @@ using pyobj = _pyobj<PyObject>;
64
64
using pyarray = _pyobj<PyArrayObject>;
65
65
66
66
template <typename F>
67
- static always_inline double call_lap (int dim, const void *restrict cost_matrix, bool verbose,
67
+ static always_inline double call_lap (int dim, const void *restrict cost_matrix,
68
+ bool verbose, bool disable_avx,
68
69
int *restrict row_ind, int *restrict col_ind,
69
70
void *restrict u, void *restrict v) {
70
71
double lapcost;
@@ -76,10 +77,18 @@ static always_inline double call_lap(int dim, const void *restrict cost_matrix,
76
77
auto cost_matrix_typed = reinterpret_cast <const F*>(cost_matrix);
77
78
auto u_typed = reinterpret_cast <F*>(u);
78
79
auto v_typed = reinterpret_cast <F*>(v);
79
- if (hasAVX2) {
80
- lapcost = lap<true >(dim, cost_matrix_typed, verbose, row_ind, col_ind, u_typed, v_typed);
80
+ if (hasAVX2 && !disable_avx) {
81
+ if (verbose) {
82
+ lapcost = lap<true , true >(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
83
+ } else {
84
+ lapcost = lap<true , false >(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
85
+ }
81
86
} else {
82
- lapcost = lap<false >(dim, cost_matrix_typed, verbose, row_ind, col_ind, u_typed, v_typed);
87
+ if (verbose) {
88
+ lapcost = lap<false , true >(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
89
+ } else {
90
+ lapcost = lap<false , false >(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
91
+ }
83
92
}
84
93
Py_END_ALLOW_THREADS
85
94
return lapcost;
@@ -88,12 +97,13 @@ static always_inline double call_lap(int dim, const void *restrict cost_matrix,
88
97
static PyObject *py_lapjv (PyObject *self, PyObject *args, PyObject *kwargs) {
89
98
PyObject *cost_matrix_obj;
90
99
int verbose = 0 ;
100
+ int disable_avx = 0 ;
91
101
int force_doubles = 0 ;
92
102
static const char *kwlist[] = {
93
- " cost_matrix" , " verbose" , " force_doubles" , NULL };
103
+ " cost_matrix" , " verbose" , " disable_avx " , " force_doubles" , NULL };
94
104
if (!PyArg_ParseTupleAndKeywords (
95
- args, kwargs, " O|pb " , const_cast <char **>(kwlist),
96
- &cost_matrix_obj, &verbose, &force_doubles)) {
105
+ args, kwargs, " O|pbb " , const_cast <char **>(kwlist),
106
+ &cost_matrix_obj, &verbose, &disable_avx, & force_doubles)) {
97
107
return NULL ;
98
108
}
99
109
pyarray cost_matrix_array;
@@ -144,9 +154,9 @@ static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
144
154
auto u = PyArray_DATA (u_array.get ());
145
155
auto v = PyArray_DATA (v_array.get ());
146
156
if (float32) {
147
- lapcost = call_lap<float >(dim, cost_matrix, verbose, row_ind, col_ind, u, v);
157
+ lapcost = call_lap<float >(dim, cost_matrix, verbose, disable_avx, row_ind, col_ind, u, v);
148
158
} else {
149
- lapcost = call_lap<double >(dim, cost_matrix, verbose, row_ind, col_ind, u, v);
159
+ lapcost = call_lap<double >(dim, cost_matrix, verbose, disable_avx, row_ind, col_ind, u, v);
150
160
}
151
161
return Py_BuildValue (" (OO(dOO))" ,
152
162
row_ind_array.get (), col_ind_array.get (), lapcost,
0 commit comments