Skip to content

Commit a0f2908

Browse files
authored
[7.8][ML] Speed up the lat_long function (elastic#1118)
Backport elastic#1102.
1 parent d6143bf commit a0f2908

File tree

5 files changed

+94
-117
lines changed

5 files changed

+94
-117
lines changed

docs/CHANGELOG.asciidoc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333
=== Enhancements
3434

35+
* Speedup anomaly detection for the lat_long function. (See {ml-pull}1102[#1102].)
3536
* Reduce CPU scheduling priority of native analysis processes to favor the ES JVM
3637
when CPU is constrained. (See {ml-pull}1109[#1109].)
3738
* Take `training_percent` into account when estimating memory usage for classification and regression.

include/maths/CKMeans.h

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#ifndef INCLUDED_ml_maths_CKMeans_h
88
#define INCLUDED_ml_maths_CKMeans_h
99

10+
#include <core/CContainerPrinter.h>
1011
#include <core/CLogger.h>
1112

1213
#include <maths/CBasicStatistics.h>
@@ -23,6 +24,7 @@
2324
#include <algorithm>
2425
#include <cstddef>
2526
#include <cstdint>
27+
#include <limits>
2628
#include <sstream>
2729
#include <utility>
2830
#include <vector>
@@ -545,58 +547,41 @@ class CKMeansPlusPlusInitialization : private core::CNonCopyable {
545547
//! \param[out] result Filled in with the seed centres.
546548
template<typename ITR>
547549
void run(ITR beginPoints, ITR endPoints, std::size_t k, TPointVec& result) const {
550+
548551
result.clear();
549552
if (beginPoints == endPoints || k == 0) {
550553
return;
551554
}
552555

553-
using TPointCRef = std::reference_wrapper<const POINT>;
554-
using TPointCRefVec = std::vector<TPointCRef>;
555-
556-
std::size_t n = std::distance(beginPoints, endPoints);
556+
std::size_t n(std::distance(beginPoints, endPoints));
557557
LOG_TRACE(<< "# points = " << n);
558558

559-
std::size_t select = CSampling::uniformSample(m_Rng, std::size_t(0), n);
560-
LOG_TRACE(<< "select = " << select);
561-
562559
result.reserve(k);
563-
result.push_back(beginPoints[select]);
564-
LOG_TRACE(<< "selected to date = " << core::CContainerPrinter::print(result));
565-
566-
TPointCRefVec selected{std::cref(result.back())};
567-
selected.reserve(k);
568-
569-
CKdTree<TPointCRef> selectedLookup;
570-
selectedLookup.reserve(k);
571-
572-
TDoubleVec distances(n, 0.0);
573-
POINT distancesToHyperplanes{las::zero(result.back())};
560+
result.push_back(beginPoints[CSampling::uniformSample(m_Rng, std::size_t(0), n)]);
574561

562+
m_Distances.assign(n, std::numeric_limits<double>::max());
575563
for (std::size_t i = 1; i < k; ++i) {
564+
this->updateDistances(result.back(), beginPoints, endPoints);
565+
m_Probabilities.assign(m_Distances.begin(), m_Distances.end());
566+
result.push_back(beginPoints[CSampling::categoricalSample(m_Rng, m_Probabilities)]);
567+
}
568+
LOG_TRACE(<< "selected = " << core::CContainerPrinter::print(result));
569+
}
576570

577-
selectedLookup.build(selected);
578-
579-
std::size_t j{0};
580-
for (ITR point = beginPoints; point != endPoints; ++j, ++point) {
581-
las::setZero(distancesToHyperplanes);
582-
const auto* nn = selectedLookup.nearestNeighbour(*point, distancesToHyperplanes);
583-
distances[j] = nn != nullptr
584-
? CTools::pow2(las::distance(*point, nn->get()))
585-
: 0.0;
586-
}
587-
588-
select = CSampling::categoricalSample(m_Rng, distances);
589-
LOG_TRACE(<< "select = " << select);
590-
591-
result.push_back(beginPoints[select]);
592-
selected.push_back(std::cref(result.back()));
593-
LOG_TRACE(<< "selected to date = " << core::CContainerPrinter::print(result));
571+
private:
572+
template<typename ITR>
573+
void updateDistances(const POINT& selected, ITR beginPoints, ITR endPoints) const {
574+
std::size_t j{0};
575+
for (ITR point = beginPoints; point != endPoints; ++j, ++point) {
576+
m_Distances[j] = std::min(
577+
m_Distances[j], CTools::pow2(las::distance(*point, selected)));
594578
}
595579
}
596580

597581
private:
598-
//! The random number generator.
599582
RNG& m_Rng;
583+
mutable TDoubleVec m_Distances;
584+
mutable TDoubleVec m_Probabilities;
600585
};
601586
}
602587
}

include/maths/CKMeansOnline.h

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,15 @@ class CKMeansOnline {
126126
}
127127

128128
//! Construct with \p clusters.
129-
CKMeansOnline(std::size_t k,
129+
CKMeansOnline(TStoragePointMeanAccumulatorDoublePrVec& clusters,
130+
std::size_t k,
130131
double decayRate,
131-
double minClusterSize,
132-
TStoragePointMeanAccumulatorDoublePrVec& clusters)
133-
: CKMeansOnline{k, decayRate, minClusterSize} {
132+
double minClusterSize = MINIMUM_CATEGORY_COUNT,
133+
std::size_t bufferSize = BUFFER_SIZE,
134+
std::size_t numberSeeds = NUMBER_SEEDS,
135+
std::size_t maxIterations = MAX_ITERATIONS)
136+
: CKMeansOnline{k, decayRate, minClusterSize,
137+
bufferSize, numberSeeds, maxIterations} {
134138
m_Clusters.swap(clusters);
135139
m_Clusters.reserve(m_K + m_BufferSize + 1);
136140
}
@@ -265,9 +269,9 @@ class CKMeansOnline {
265269
CBasicStatistics::SMin<double>::TAccumulator minCost;
266270
TSphericalClusterVec centres;
267271
TSphericalClusterVecVec candidates;
272+
CKMeansPlusPlusInitialization<TSphericalCluster, RNG> seed{rng};
268273
for (std::size_t i = 0; i < numberSeeds; ++i) {
269-
CKMeansPlusPlusInitialization<TSphericalCluster, RNG> seedCentres(rng);
270-
seedCentres.run(kmeans.beginPoints(), kmeans.endPoints(), k, centres);
274+
seed.run(kmeans.beginPoints(), kmeans.endPoints(), k, centres);
271275
kmeans.setCentres(centres);
272276
kmeans.run(maxIterations);
273277
kmeans.clusters(candidates);
@@ -309,7 +313,8 @@ class CKMeansOnline {
309313
for (std::size_t j = 0u; j < split[i].size(); ++j) {
310314
clusters.push_back(m_Clusters[split[i][j]]);
311315
}
312-
result.emplace_back(m_K, m_DecayRate, m_MinClusterSize, clusters);
316+
result.emplace_back(clusters, m_K, m_DecayRate, m_MinClusterSize,
317+
m_BufferSize, m_NumberSeeds, m_MaxIterations);
313318
}
314319

315320
return true;
@@ -561,11 +566,10 @@ class CKMeansOnline {
561566
//! \note We assume \p points is small so the bruteforce approach is fast.
562567
static void deduplicate(TStoragePointMeanAccumulatorDoublePrVec& clusters) {
563568
if (clusters.size() > 1) {
564-
std::stable_sort(clusters.begin(), clusters.end(),
565-
[](const auto& lhs, const auto& rhs) {
566-
return CBasicStatistics::mean(lhs.first) <
567-
CBasicStatistics::mean(rhs.first);
568-
});
569+
std::sort(clusters.begin(), clusters.end(), [](const auto& lhs, const auto& rhs) {
570+
return CBasicStatistics::mean(lhs.first) <
571+
CBasicStatistics::mean(rhs.first);
572+
});
569573
auto back = clusters.begin();
570574
for (auto i = back + 1; i != clusters.end(); ++i) {
571575
if (CBasicStatistics::mean(back->first) == CBasicStatistics::mean(i->first)) {

0 commit comments

Comments
 (0)