Skip to content

Commit bfd47ba

Browse files
committed
Reduce memory consumption
1 parent 162474a commit bfd47ba

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

lap.h

+7-7
Original file line numberDiff line numberDiff line change
@@ -228,15 +228,14 @@ find_umins(
228228
/// @param u out dual variables, row reduction numbers / size dim
229229
/// @param v out dual variables, column reduction numbers / size dim
230230
/// @return achieved minimum assignment cost
231-
template <bool avx2, typename idx, typename cost>
232-
cost lap(int dim, const cost *restrict assign_cost, bool verbose,
231+
template <bool avx2, bool verbose, typename idx, typename cost>
232+
cost lap(int dim, const cost *restrict assign_cost,
233233
idx *restrict rowsol, idx *restrict colsol,
234234
cost *restrict u, cost *restrict v) {
235-
auto free = std::unique_ptr<idx[]>(new idx[dim]); // list of unassigned rows.
236-
auto collist = std::unique_ptr<idx[]>(new idx[dim]); // list of columns to be scanned in various ways.
237-
auto matches = std::unique_ptr<idx[]>(new idx[dim]); // counts how many times a row could be assigned.
238-
auto d = std::unique_ptr<cost[]>(new cost[dim]); // 'cost-distance' in augmenting path calculation.
239-
auto pred = std::unique_ptr<idx[]>(new idx[dim]); // row-predecessor of column in augmenting/alternating path.
235+
auto collist = std::make_unique<idx[]>(dim); // list of columns to be scanned in various ways.
236+
auto matches = std::make_unique<idx[]>(dim); // counts how many times a row could be assigned.
237+
auto d = std::make_unique<cost[]>(dim); // 'cost-distance' in augmenting path calculation.
238+
auto pred = std::make_unique<idx[]>(dim); // row-predecessor of column in augmenting/alternating path.
240239

241240
// init how many times a row will be assigned in the column reduction.
242241
#if _OPENMP >= 201307
@@ -273,6 +272,7 @@ cost lap(int dim, const cost *restrict assign_cost, bool verbose,
273272
}
274273

275274
// REDUCTION TRANSFER
275+
auto free = matches.get(); // list of unassigned rows.
276276
idx numfree = 0;
277277
for (idx i = 0; i < dim; i++) {
278278
const cost *local_cost = &assign_cost[i * dim];

python.cc

+19-9
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ using pyobj = _pyobj<PyObject>;
6464
using pyarray = _pyobj<PyArrayObject>;
6565

6666
template <typename F>
67-
static always_inline double call_lap(int dim, const void *restrict cost_matrix, bool verbose,
67+
static always_inline double call_lap(int dim, const void *restrict cost_matrix,
68+
bool verbose, bool disable_avx,
6869
int *restrict row_ind, int *restrict col_ind,
6970
void *restrict u, void *restrict v) {
7071
double lapcost;
@@ -76,10 +77,18 @@ static always_inline double call_lap(int dim, const void *restrict cost_matrix,
7677
auto cost_matrix_typed = reinterpret_cast<const F*>(cost_matrix);
7778
auto u_typed = reinterpret_cast<F*>(u);
7879
auto v_typed = reinterpret_cast<F*>(v);
79-
if (hasAVX2) {
80-
lapcost = lap<true>(dim, cost_matrix_typed, verbose, row_ind, col_ind, u_typed, v_typed);
80+
if (hasAVX2 && !disable_avx) {
81+
if (verbose) {
82+
lapcost = lap<true, true>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
83+
} else {
84+
lapcost = lap<true, false>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
85+
}
8186
} else {
82-
lapcost = lap<false>(dim, cost_matrix_typed, verbose, row_ind, col_ind, u_typed, v_typed);
87+
if (verbose) {
88+
lapcost = lap<false, true>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
89+
} else {
90+
lapcost = lap<false, false>(dim, cost_matrix_typed, row_ind, col_ind, u_typed, v_typed);
91+
}
8392
}
8493
Py_END_ALLOW_THREADS
8594
return lapcost;
@@ -88,12 +97,13 @@ static always_inline double call_lap(int dim, const void *restrict cost_matrix,
8897
static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
8998
PyObject *cost_matrix_obj;
9099
int verbose = 0;
100+
int disable_avx = 0;
91101
int force_doubles = 0;
92102
static const char *kwlist[] = {
93-
"cost_matrix", "verbose", "force_doubles", NULL};
103+
"cost_matrix", "verbose", "disable_avx", "force_doubles", NULL};
94104
if (!PyArg_ParseTupleAndKeywords(
95-
args, kwargs, "O|pb", const_cast<char**>(kwlist),
96-
&cost_matrix_obj, &verbose, &force_doubles)) {
105+
args, kwargs, "O|pbb", const_cast<char**>(kwlist),
106+
&cost_matrix_obj, &verbose, &disable_avx, &force_doubles)) {
97107
return NULL;
98108
}
99109
pyarray cost_matrix_array;
@@ -144,9 +154,9 @@ static PyObject *py_lapjv(PyObject *self, PyObject *args, PyObject *kwargs) {
144154
auto u = PyArray_DATA(u_array.get());
145155
auto v = PyArray_DATA(v_array.get());
146156
if (float32) {
147-
lapcost = call_lap<float>(dim, cost_matrix, verbose, row_ind, col_ind, u, v);
157+
lapcost = call_lap<float>(dim, cost_matrix, verbose, disable_avx, row_ind, col_ind, u, v);
148158
} else {
149-
lapcost = call_lap<double>(dim, cost_matrix, verbose, row_ind, col_ind, u, v);
159+
lapcost = call_lap<double>(dim, cost_matrix, verbose, disable_avx, row_ind, col_ind, u, v);
150160
}
151161
return Py_BuildValue("(OO(dOO))",
152162
row_ind_array.get(), col_ind_array.get(), lapcost,

0 commit comments

Comments
 (0)