diff --git a/lib/maths/CDataFrameUtils.cc b/lib/maths/CDataFrameUtils.cc index 83bfe51745..581cc4d998 100644 --- a/lib/maths/CDataFrameUtils.cc +++ b/lib/maths/CDataFrameUtils.cc @@ -1072,27 +1072,34 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, std::size_t numberSamples{ static_cast(std::min(1000.0, rowMask.manhattan()))}; - TStratifiedSamplerPtr sampler; - std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler( - numberThreads, frame, targetColumn, rng, numberSamples, rowMask); - - TSizeVec rowIndices; - frame.readRows(1, 0, frame.numberRows(), - [&](TRowItr beginRows, TRowItr endRows) { - for (auto row = beginRows; row != endRows; ++row) { - sampler->sample(*row); - } - }, - &rowMask); - sampler->finishSampling(rng, rowIndices); - std::sort(rowIndices.begin(), rowIndices.end()); core::CPackedBitVector sampleMask; - for (auto row : rowIndices) { - sampleMask.extend(false, row - sampleMask.size()); - sampleMask.extend(true); + + // No need to sample if we're going to use every row we've been given. 
+ if (numberSamples < static_cast(rowMask.manhattan())) { + TStratifiedSamplerPtr sampler; + std::tie(sampler, std::ignore) = classifierStratifiedCrossValidationRowSampler( + numberThreads, frame, targetColumn, rng, numberSamples, rowMask); + + TSizeVec rowIndices; + frame.readRows(1, 0, frame.numberRows(), + [&](TRowItr beginRows, TRowItr endRows) { + for (auto row = beginRows; row != endRows; ++row) { + sampler->sample(*row); + } + }, + &rowMask); + sampler->finishSampling(rng, rowIndices); + std::sort(rowIndices.begin(), rowIndices.end()); + LOG_TRACE(<< "# row indices = " << rowIndices.size()); + + for (auto row : rowIndices) { + sampleMask.extend(false, row - sampleMask.size()); + sampleMask.extend(true); + } + sampleMask.extend(false, rowMask.size() - sampleMask.size()); + } else { + sampleMask = rowMask; } - sampleMask.extend(false, rowMask.size() - sampleMask.size()); - LOG_TRACE(<< "# row indices = " << rowIndices.size()); // Compute the count of each class in the sample set. auto readClassCountsAndRecalls = core::bindRetrievableState( @@ -1119,6 +1126,11 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, copyClassCountsAndRecalls, reduceClassCountsAndRecalls, classCountsAndRecalls); TDoubleVector classCounts{classCountsAndRecalls.topRows(numberClasses)}; TDoubleVector classRecalls{classCountsAndRecalls.bottomRows(numberClasses)}; + // If a class count is zero the derivative of the loss function w.r.t. that + // class is not well defined. The objective is independent of such a class + // so the choice for its count and recall is not important. However, its + // count must be non-zero so that we don't run into a NaN cascade. 
+ classCounts = classCounts.cwiseMax(1.0); classRecalls.array() /= classCounts.array(); LOG_TRACE(<< "class counts = " << classCounts.transpose()); LOG_TRACE(<< "class recalls = " << classRecalls.transpose()); @@ -1219,7 +1231,7 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, // and use this to bias initial values. TMinAccumulator minLoss; - TDoubleVector minWeights; + TDoubleVector minLossWeights; TDoubleVector w0{TDoubleVector::Ones(numberClasses)}; for (std::size_t i = 0; i < 5; ++i) { CLbfgs lbfgs{5}; @@ -1227,7 +1239,7 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, std::tie(w0, loss) = lbfgs.minimize(objective, objectiveGradient, std::move(w0)); LOG_TRACE(<< "weights* = " << w0.transpose() << ", loss* = " << loss); if (minLoss.add(loss)) { - minWeights = std::move(w0); + minLossWeights = std::move(w0); } w0 = TDoubleVector::Ones(numberClasses); for (std::size_t j = 1; j < numberClasses; ++j) { @@ -1237,10 +1249,10 @@ CDataFrameUtils::maximizeMinimumRecallForMulticlass(std::size_t numberThreads, // Since we take argmax_i w_i p_i we can multiply by a constant. We arrange for // the largest weight to always be one. - minWeights.array() /= minWeights.maxCoeff(); - LOG_TRACE(<< "weights = " << minWeights.transpose()); + minLossWeights.array() /= minLossWeights.maxCoeff(); + LOG_TRACE(<< "weights = " << minLossWeights.transpose()); - return minWeights; + return minLossWeights; } void CDataFrameUtils::removeMetricColumns(const core::CDataFrame& frame, TSizeVec& columnMask) {