@@ -1072,27 +1072,34 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
1072
1072
std::size_t numberSamples{
1073
1073
static_cast <std::size_t >(std::min (1000.0 , rowMask.manhattan ()))};
1074
1074
1075
- TStratifiedSamplerPtr sampler;
1076
- std::tie (sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler (
1077
- numberThreads, frame, targetColumn, rng, numberSamples, rowMask);
1078
-
1079
- TSizeVec rowIndices;
1080
- frame.readRows (1 , 0 , frame.numberRows (),
1081
- [&](TRowItr beginRows, TRowItr endRows) {
1082
- for (auto row = beginRows; row != endRows; ++row) {
1083
- sampler->sample (*row);
1084
- }
1085
- },
1086
- &rowMask);
1087
- sampler->finishSampling (rng, rowIndices);
1088
- std::sort (rowIndices.begin (), rowIndices.end ());
1089
1075
core::CPackedBitVector sampleMask;
1090
- for (auto row : rowIndices) {
1091
- sampleMask.extend (false , row - sampleMask.size ());
1092
- sampleMask.extend (true );
1076
+
1077
+ // No need to sample if we're going to use every row we've been given.
1078
+ if (numberSamples < static_cast <std::size_t >(rowMask.manhattan ())) {
1079
+ TStratifiedSamplerPtr sampler;
1080
+ std::tie (sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler (
1081
+ numberThreads, frame, targetColumn, rng, numberSamples, rowMask);
1082
+
1083
+ TSizeVec rowIndices;
1084
+ frame.readRows (1 , 0 , frame.numberRows (),
1085
+ [&](TRowItr beginRows, TRowItr endRows) {
1086
+ for (auto row = beginRows; row != endRows; ++row) {
1087
+ sampler->sample (*row);
1088
+ }
1089
+ },
1090
+ &rowMask);
1091
+ sampler->finishSampling (rng, rowIndices);
1092
+ std::sort (rowIndices.begin (), rowIndices.end ());
1093
+ LOG_TRACE (<< " # row indices = " << rowIndices.size ());
1094
+
1095
+ for (auto row : rowIndices) {
1096
+ sampleMask.extend (false , row - sampleMask.size ());
1097
+ sampleMask.extend (true );
1098
+ }
1099
+ sampleMask.extend (false , rowMask.size () - sampleMask.size ());
1100
+ } else {
1101
+ sampleMask = rowMask;
1093
1102
}
1094
- sampleMask.extend (false , rowMask.size () - sampleMask.size ());
1095
- LOG_TRACE (<< " # row indices = " << rowIndices.size ());
1096
1103
1097
1104
// Compute the count of each class in the sample set.
1098
1105
auto readClassCountsAndRecalls = core::bindRetrievableState (
@@ -1119,6 +1126,11 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
1119
1126
copyClassCountsAndRecalls, reduceClassCountsAndRecalls, classCountsAndRecalls);
1120
1127
TDoubleVector classCounts{classCountsAndRecalls.topRows (numberClasses)};
1121
1128
TDoubleVector classRecalls{classCountsAndRecalls.bottomRows (numberClasses)};
1129
+ // If a class count is zero the derivative of the loss function w.r.t. that
1130
+ // class is not well defined. The objective is independent of such a class
1131
+ // so the choice for its count and recall is not important. However, its
1132
+ // count must be non-zero so that we don't run into a NaN cascade.
1133
+ classCounts = classCounts.cwiseMax (1.0 );
1122
1134
classRecalls.array () /= classCounts.array ();
1123
1135
LOG_TRACE (<< " class counts = " << classCounts.transpose ());
1124
1136
LOG_TRACE (<< " class recalls = " << classRecalls.transpose ());
@@ -1219,15 +1231,15 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
1219
1231
// and use this to bias initial values.
1220
1232
1221
1233
TMinAccumulator minLoss;
1222
- TDoubleVector minWeights ;
1234
+ TDoubleVector minLossWeights ;
1223
1235
TDoubleVector w0{TDoubleVector::Ones (numberClasses)};
1224
1236
for (std::size_t i = 0 ; i < 5 ; ++i) {
1225
1237
CLbfgs<TDoubleVector> lbfgs{5 };
1226
1238
double loss;
1227
1239
std::tie (w0, loss) = lbfgs.minimize (objective, objectiveGradient, std::move (w0));
1228
1240
LOG_TRACE (<< " weights* = " << w0.transpose () << " , loss* = " << loss);
1229
1241
if (minLoss.add (loss)) {
1230
- minWeights = std::move (w0);
1242
+ minLossWeights = std::move (w0);
1231
1243
}
1232
1244
w0 = TDoubleVector::Ones (numberClasses);
1233
1245
for (std::size_t j = 1 ; j < numberClasses; ++j) {
@@ -1237,10 +1249,10 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads,
1237
1249
1238
1250
// Since we take argmax_i w_i p_i we can multiply by a constant. We arrange for
1239
1251
// the largest weight to always be one.
1240
- minWeights .array () /= minWeights .maxCoeff ();
1241
- LOG_TRACE (<< " weights = " << minWeights .transpose ());
1252
+ minLossWeights .array () /= minLossWeights .maxCoeff ();
1253
+ LOG_TRACE (<< " weights = " << minLossWeights .transpose ());
1242
1254
1243
- return minWeights ;
1255
+ return minLossWeights ;
1244
1256
}
1245
1257
1246
1258
void CDataFrameUtils::removeMetricColumns (const core::CDataFrame& frame, TSizeVec& columnMask) {
0 commit comments