Skip to content

Commit bfb57fc

Browse files
authored
[7.8][ML] Multiclass maximise minimum recall (#1113) (#1133)
1 parent 93c4d38 commit bfb57fc

13 files changed

+512
-122
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
when CPU is constrained. (See {ml-pull}1109[#1109].)
3838
* Take `training_percent` into account when estimating memory usage for classification and regression.
3939
(See {ml-pull}1111[#1111].)
40+
* Support maximizing minimum recall when assigning class labels for multiclass classification.
41+
(See {ml-pull}1113[#1113].)
42+
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)
4043
* Adds new `num_matches` and `preferred_to_categories` fields to category output.
4144
(See {ml-pull}1062[#1062])
4245
* Improve robustness of anomaly detection to bad input data. (See {ml-pull}1114[#1114].)

include/maths/CBoostedTreeImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ class MATHS_EXPORT CBoostedTreeImpl final {
112112
//! Get the number of columns training the model will add to the data frame.
113113
static std::size_t numberExtraColumnsForTrain(std::size_t numberLossParameters) {
114114
// We store as follows:
115-
// 1. The predicted values for the dependent variables
115+
// 1. The predicted values for the dependent variable
116116
// 2. The gradient of the loss function
117117
// 3. The upper triangle of the hessian of the loss function
118118
// 4. The example's weight

include/maths/CDataFrameUtils.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
6969
using TRowRef = core::CDataFrame::TRowRef;
7070
using TWeightFunc = std::function<double(const TRowRef&)>;
7171
using TDoubleVector = CDenseVector<double>;
72-
using TReadPredictionFunc = std::function<TDoubleVector(const TRowRef)>;
72+
using TMemoryMappedFloatVector = CMemoryMappedDenseVector<CFloatStorage>;
73+
using TReadPredictionFunc = std::function<TMemoryMappedFloatVector(const TRowRef&)>;
7374
using TQuantileSketchVec = std::vector<CQuantileSketch>;
7475
using TPackedBitVectorVec = std::vector<core::CPackedBitVector>;
7576

@@ -408,6 +409,19 @@ class MATHS_EXPORT CDataFrameUtils : private core::CNonInstantiatable {
408409
const core::CPackedBitVector& rowMask,
409410
const TSizeVec& columnMask,
410411
std::size_t numberSamples);
412+
static TDoubleVector
413+
maximizeMinimumRecallForBinary(std::size_t numberThreads,
414+
const core::CDataFrame& frame,
415+
const core::CPackedBitVector& rowMask,
416+
std::size_t targetColumn,
417+
const TReadPredictionFunc& readPrediction);
418+
static TDoubleVector
419+
maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
420+
const core::CDataFrame& frame,
421+
const core::CPackedBitVector& rowMask,
422+
std::size_t numberClasses,
423+
std::size_t targetColumn,
424+
const TReadPredictionFunc& readPrediction);
411425
static void removeMetricColumns(const core::CDataFrame& frame, TSizeVec& columnMask);
412426
static void removeCategoricalColumns(const core::CDataFrame& frame, TSizeVec& columnMask);
413427
static double unitWeight(const TRowRef&);

include/maths/CTools.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <cstring>
2727
#include <iosfwd>
2828
#include <limits>
29+
#include <numeric>
2930
#include <vector>
3031

3132
namespace ml {
@@ -684,7 +685,7 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
684685
return sigmoid(std::exp(std::copysign(1.0, sign) * (x - x0) / width));
685686
}
686687

687-
//! Compute the softmax from the multinomial logit values \p logit.
688+
//! Compute the softmax for the multinomial logit values \p logit.
688689
//!
689690
//! i.e. \f$[\sigma(z)]_i = \frac{exp(z_i)}{\sum_j exp(z_j)}\f$.
690691
//!
@@ -703,10 +704,29 @@ class MATHS_EXPORT CTools : private core::CNonInstantiatable {
703704
}
704705
}
705706

706-
//! Specialize the softmax for our dense vector type.
707+
//! Compute the log of the softmax for the multinomial logit values \p logit.
708+
template<typename COLLECTION>
709+
static void inplaceLogSoftmax(COLLECTION& z) {
710+
double zmax{*std::max_element(z.begin(), z.end())};
711+
for (auto& zi : z) {
712+
zi -= zmax;
713+
}
714+
double logZ{std::log(std::accumulate(
715+
z.begin(), z.end(), 0.0,
716+
[](double sum, const auto& zi) { return sum + std::exp(zi); }))};
717+
for (auto& zi : z) {
718+
zi -= logZ;
719+
}
720+
}
721+
722+
//! Specialize the softmax for CDenseVector.
707723
template<typename T>
708724
static void inplaceSoftmax(CDenseVector<T>& z);
709725

726+
//! Specialize the log(softmax) for CDenseVector.
727+
template<typename SCALAR>
728+
static void inplaceLogSoftmax(CDenseVector<SCALAR>& z);
729+
710730
//! Linearly interpolate a function on the interval [\p a, \p b].
711731
static double linearlyInterpolate(double a, double b, double fa, double fb, double x);
712732

include/maths/CToolsDetail.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <maths/CCompositeFunctions.h>
1313
#include <maths/CIntegration.h>
14+
#include <maths/CLinearAlgebraEigen.h>
1415
#include <maths/CMixtureDistribution.h>
1516
#include <maths/COrderings.h>
1617
#include <maths/CTools.h>
@@ -308,6 +309,15 @@ void CTools::inplaceSoftmax(CDenseVector<T>& z) {
308309
z.array() = z.array().exp();
309310
z /= z.sum();
310311
}
312+
313+
template<typename SCALAR>
314+
void CTools::inplaceLogSoftmax(CDenseVector<SCALAR>& z) {
315+
// Handle under/overflow when taking exponentials by subtracting zmax.
316+
double zmax{z.maxCoeff()};
317+
z.array() -= zmax;
318+
double Z{z.array().exp().sum()};
319+
z.array() -= std::log(Z);
320+
}
311321
}
312322
}
313323

include/test/CDataFrameAnalyzerTrainingFactory.h

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#define INCLUDED_ml_test_CDataFrameAnalyzerTrainingFactory_h
99

1010
#include <core/CDataFrame.h>
11+
#include <core/CSmallVector.h>
1112

1213
#include <maths/CBoostedTreeFactory.h>
1314
#include <maths/CBoostedTreeLoss.h>
@@ -122,13 +123,11 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
122123
auto prediction = tree->readAndAdjustPrediction(*row);
123124
switch (type) {
124125
case E_Regression:
125-
appendPrediction(*frame, weights.size(), prediction[0], expectedPredictions);
126+
appendPrediction(*frame, weights.size(), prediction, expectedPredictions);
126127
break;
127128
case E_BinaryClassification:
128-
appendPrediction(*frame, weights.size(), prediction[1], expectedPredictions);
129-
break;
130129
case E_MulticlassClassification:
131-
// TODO.
130+
appendPrediction(*frame, weights.size(), prediction, expectedPredictions);
132131
break;
133132
}
134133
}
@@ -149,15 +148,19 @@ class TEST_EXPORT CDataFrameAnalyzerTrainingFactory {
149148
TStrVec& targets);
150149

151150
private:
151+
using TDouble2Vec = core::CSmallVector<double, 2>;
152152
using TBoolVec = std::vector<bool>;
153153
using TRowItr = core::CDataFrame::TRowItr;
154154

155155
private:
156-
static void appendPrediction(core::CDataFrame&, std::size_t, double prediction, TDoubleVec& predictions);
156+
static void appendPrediction(core::CDataFrame&,
157+
std::size_t,
158+
const TDouble2Vec& prediction,
159+
TDoubleVec& predictions);
157160

158161
static void appendPrediction(core::CDataFrame& frame,
159162
std::size_t target,
160-
double class1Score,
163+
const TDouble2Vec& class1Score,
161164
TStrVec& predictions);
162165
};
163166
}

lib/maths/CBoostedTreeImpl.cc

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,12 +377,16 @@ void CBoostedTreeImpl::initializePerFoldTestLosses() {
377377
}
378378

379379
void CBoostedTreeImpl::computeClassificationWeights(const core::CDataFrame& frame) {
380+
381+
using TFloatStorageVec = std::vector<CFloatStorage>;
382+
380383
if (m_Loss->type() == CLoss::E_BinaryClassification ||
381384
m_Loss->type() == CLoss::E_MulticlassClassification) {
382385

383386
std::size_t numberClasses{m_Loss->type() == CLoss::E_BinaryClassification
384387
? 2
385388
: m_Loss->numberParameters()};
389+
TFloatStorageVec storage(2);
386390

387391
switch (m_ClassAssignmentObjective) {
388392
case CBoostedTree::E_Accuracy:
@@ -391,9 +395,20 @@ void CBoostedTreeImpl::computeClassificationWeights(const core::CDataFrame& fram
391395
case CBoostedTree::E_MinimumRecall:
392396
m_ClassificationWeights = CDataFrameUtils::maximumMinimumRecallClassWeights(
393397
m_NumberThreads, frame, this->allTrainingRowsMask(),
394-
numberClasses, m_DependentVariable, [this](const TRowRef& row) {
395-
return m_Loss->transform(readPrediction(
396-
row, m_NumberInputColumns, m_Loss->numberParameters()));
398+
numberClasses, m_DependentVariable,
399+
[storage, numberClasses, this](const TRowRef& row) mutable {
400+
if (m_Loss->type() == CLoss::E_BinaryClassification) {
401+
// We predict the log-odds but this is expected to return
402+
// the log of the predicted class probabilities.
403+
TMemoryMappedFloatVector result{&storage[0], 2};
404+
result.array() = m_Loss
405+
->transform(readPrediction(
406+
row, m_NumberInputColumns, numberClasses))
407+
.array()
408+
.log();
409+
return result;
410+
}
411+
return readPrediction(row, m_NumberInputColumns, numberClasses);
397412
});
398413
break;
399414
}

lib/maths/CBoostedTreeLoss.cc

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,6 @@ double logLogistic(double logOdds) {
3939
}
4040
return std::log(CTools::logisticFunction(logOdds));
4141
}
42-
43-
template<typename SCALAR>
44-
void inplaceLogSoftmax(CDenseVector<SCALAR>& z) {
45-
// Handle under/overflow when taking exponentials by subtracting zmax.
46-
double zmax{z.maxCoeff()};
47-
z.array() -= zmax;
48-
double Z{z.array().exp().sum()};
49-
z.array() -= std::log(Z);
50-
}
5142
}
5243

5344
namespace boosted_tree_detail {
@@ -332,7 +323,7 @@ CArgMinMultinomialLogisticLossImpl::objective() const {
332323
if (m_Centres.size() == 1) {
333324
return [logProbabilities, lambda, this](const TDoubleVector& weight) mutable {
334325
logProbabilities = m_Centres[0] + weight;
335-
inplaceLogSoftmax(logProbabilities);
326+
CTools::inplaceLogSoftmax(logProbabilities);
336327
return lambda * weight.squaredNorm() - m_ClassCounts.transpose() * logProbabilities;
337328
};
338329
}
@@ -341,7 +332,7 @@ CArgMinMultinomialLogisticLossImpl::objective() const {
341332
for (std::size_t i = 0; i < m_CentresClassCounts.size(); ++i) {
342333
if (m_CentresClassCounts[i].sum() > 0.0) {
343334
logProbabilities = m_Centres[i] + weight;
344-
inplaceLogSoftmax(logProbabilities);
335+
CTools::inplaceLogSoftmax(logProbabilities);
345336
loss -= m_CentresClassCounts[i].transpose() * logProbabilities;
346337
}
347338
}

0 commit comments

Comments
 (0)