Skip to content

Commit 499338d

Browse files
authored
[ML] Move model test helper functions to base class (elastic#1523)
Move existing model test helper functions to base class and share functionality where possible.
1 parent 3fc14ae commit 499338d

8 files changed

+1315
-1462
lines changed

lib/model/unittest/CCountingModelTest.cc

+42-70
Original file line numberDiff line numberDiff line change
@@ -29,53 +29,28 @@ BOOST_AUTO_TEST_SUITE(CCountingModelTest)
2929
using namespace ml;
3030
using namespace model;
3131

32-
namespace {
33-
std::size_t addPerson(const std::string& p,
34-
const CModelFactory::TDataGathererPtr& gatherer,
35-
CResourceMonitor& resourceMonitor) {
36-
CDataGatherer::TStrCPtrVec person;
37-
person.push_back(&p);
38-
CEventData result;
39-
gatherer->processFields(person, result, resourceMonitor);
40-
return *result.personId();
41-
}
42-
43-
void addArrival(CDataGatherer& gatherer,
44-
CResourceMonitor& resourceMonitor,
45-
core_t::TTime time,
46-
const std::string& person) {
47-
CDataGatherer::TStrCPtrVec fieldValues;
48-
fieldValues.push_back(&person);
49-
50-
CEventData eventData;
51-
eventData.time(time);
52-
gatherer.addArrival(fieldValues, eventData, resourceMonitor);
53-
}
54-
55-
SModelParams::TStrDetectionRulePr
56-
makeScheduledEvent(const std::string& description, double start, double end) {
57-
CRuleCondition conditionGte;
58-
conditionGte.appliesTo(CRuleCondition::E_Time);
59-
conditionGte.op(CRuleCondition::E_GTE);
60-
conditionGte.value(start);
61-
CRuleCondition conditionLt;
62-
conditionLt.appliesTo(CRuleCondition::E_Time);
63-
conditionLt.op(CRuleCondition::E_LT);
64-
conditionLt.value(end);
65-
66-
CDetectionRule rule;
67-
rule.action(CDetectionRule::E_SkipModelUpdate);
68-
rule.addCondition(conditionGte);
69-
rule.addCondition(conditionLt);
70-
71-
SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule);
72-
return event;
73-
}
74-
75-
const std::string EMPTY_STRING;
76-
}
77-
78-
class CTestFixture : public CModelTestFixtureBase {};
32+
class CTestFixture : public CModelTestFixtureBase {
33+
protected:
34+
SModelParams::TStrDetectionRulePr
35+
makeScheduledEvent(const std::string& description, double start, double end) {
36+
CRuleCondition conditionGte;
37+
conditionGte.appliesTo(CRuleCondition::E_Time);
38+
conditionGte.op(CRuleCondition::E_GTE);
39+
conditionGte.value(start);
40+
CRuleCondition conditionLt;
41+
conditionLt.appliesTo(CRuleCondition::E_Time);
42+
conditionLt.op(CRuleCondition::E_LT);
43+
conditionLt.value(end);
44+
45+
CDetectionRule rule;
46+
rule.action(CDetectionRule::E_SkipModelUpdate);
47+
rule.addCondition(conditionGte);
48+
rule.addCondition(conditionLt);
49+
50+
SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule);
51+
return event;
52+
}
53+
};
7954

8055
BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
8156
core_t::TTime startTime{100};
@@ -94,20 +69,20 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
9469
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
9570
CModelFactory::TDataGathererPtr gathererNoGap(
9671
factory.makeDataGatherer(gathererNoGapInitData));
97-
BOOST_REQUIRE_EQUAL(std::size_t(0), addPerson("p", gathererNoGap, m_ResourceMonitor));
72+
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererNoGap));
9873
CModelFactory::SModelInitializationData modelNoGapInitData(gathererNoGap);
9974
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
10075
CCountingModel* modelNoGap =
10176
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
10277

10378
// |2|2|0|0|1| -> 1.0 mean count
104-
addArrival(*gathererNoGap, m_ResourceMonitor, 100, "p");
105-
addArrival(*gathererNoGap, m_ResourceMonitor, 110, "p");
79+
this->addArrival(*gathererNoGap, 100, "p");
80+
this->addArrival(*gathererNoGap, 110, "p");
10681
modelNoGap->sample(100, 200, m_ResourceMonitor);
107-
addArrival(*gathererNoGap, m_ResourceMonitor, 250, "p");
108-
addArrival(*gathererNoGap, m_ResourceMonitor, 280, "p");
82+
this->addArrival(*gathererNoGap, 250, "p");
83+
this->addArrival(*gathererNoGap, 280, "p");
10984
modelNoGap->sample(200, 500, m_ResourceMonitor);
110-
addArrival(*gathererNoGap, m_ResourceMonitor, 500, "p");
85+
this->addArrival(*gathererNoGap, 500, "p");
11186
modelNoGap->sample(500, 600, m_ResourceMonitor);
11287

11388
BOOST_REQUIRE_EQUAL(1.0, *modelNoGap->baselineBucketCount(0));
@@ -118,8 +93,7 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
11893
CModelFactory::SGathererInitializationData gathererWithGapInitData(startTime);
11994
CModelFactory::TDataGathererPtr gathererWithGap(
12095
factory.makeDataGatherer(gathererWithGapInitData));
121-
BOOST_REQUIRE_EQUAL(std::size_t(0),
122-
addPerson("p", gathererWithGap, m_ResourceMonitor));
96+
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererWithGap));
12397
CModelFactory::SModelInitializationData modelWithGapInitData(gathererWithGap);
12498
CAnomalyDetectorModel::TModelPtr modelHolderWithGap(
12599
factory.makeModel(modelWithGapInitData));
@@ -128,15 +102,15 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
128102

129103
// |2|2|0|0|1|
130104
// |2|X|X|X|1| -> 1.5 mean count where X means skipped bucket
131-
addArrival(*gathererWithGap, m_ResourceMonitor, 100, "p");
132-
addArrival(*gathererWithGap, m_ResourceMonitor, 110, "p");
105+
this->addArrival(*gathererWithGap, 100, "p");
106+
this->addArrival(*gathererWithGap, 110, "p");
133107
modelWithGap->sample(100, 200, m_ResourceMonitor);
134-
addArrival(*gathererWithGap, m_ResourceMonitor, 250, "p");
135-
addArrival(*gathererWithGap, m_ResourceMonitor, 280, "p");
108+
this->addArrival(*gathererWithGap, 250, "p");
109+
this->addArrival(*gathererWithGap, 280, "p");
136110
modelWithGap->skipSampling(500);
137111
modelWithGap->prune(maxAgeBuckets);
138112
BOOST_REQUIRE_EQUAL(std::size_t(1), gathererWithGap->numberActivePeople());
139-
addArrival(*gathererWithGap, m_ResourceMonitor, 500, "p");
113+
this->addArrival(*gathererWithGap, 500, "p");
140114
modelWithGap->sample(500, 600, m_ResourceMonitor);
141115

142116
BOOST_REQUIRE_EQUAL(1.5, *modelWithGap->baselineBucketCount(0));
@@ -166,7 +140,7 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
166140
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
167141
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
168142
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
169-
addArrival(*gatherer, m_ResourceMonitor, 200, "p");
143+
this->addArrival(*gatherer, 200, "p");
170144

171145
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
172146
CCountingModel* modelNoGap =
@@ -215,7 +189,7 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
215189
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
216190
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
217191
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
218-
addArrival(*gatherer, m_ResourceMonitor, 100, "p");
192+
this->addArrival(*gatherer, 100, "p");
219193

220194
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
221195
CCountingModel* modelNoGap =
@@ -259,8 +233,8 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
259233

260234
CModelFactory::SGathererInitializationData gathererInitData(time);
261235
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererInitData));
262-
BOOST_REQUIRE_EQUAL(std::size_t(0), addPerson("p1", gatherer, m_ResourceMonitor));
263-
BOOST_REQUIRE_EQUAL(std::size_t(1), addPerson("p2", gatherer, m_ResourceMonitor));
236+
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p1", gatherer));
237+
BOOST_REQUIRE_EQUAL(std::size_t(1), this->addPerson("p2", gatherer));
264238
CModelFactory::SModelInitializationData modelInitData(gatherer);
265239
CAnomalyDetectorModel::TModelPtr modelHolder(factory.makeModel(modelInitData));
266240
CCountingModel* model{dynamic_cast<CCountingModel*>(modelHolder.get())};
@@ -275,9 +249,8 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
275249
std::sort(offsets.begin(), offsets.end());
276250
for (auto offset : offsets) {
277251
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
278-
addArrival(*gatherer, m_ResourceMonitor,
279-
time + static_cast<core_t::TTime>(offset),
280-
uniform01[0] < 0.5 ? "p1" : "p2");
252+
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offset),
253+
uniform01[0] < 0.5 ? "p1" : "p2");
281254
}
282255
model->sample(time, time + bucketLength, m_ResourceMonitor);
283256
}
@@ -287,9 +260,8 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
287260

288261
for (std::size_t i = 0u; i < offsets.size(); ++i) {
289262
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
290-
addArrival(*gatherer, m_ResourceMonitor,
291-
time + static_cast<core_t::TTime>(offsets[i]),
292-
uniform01[0] < 0.5 ? "p1" : "p2");
263+
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offsets[i]),
264+
uniform01[0] < 0.5 ? "p1" : "p2");
293265
model->sampleBucketStatistics(time, time + bucketLength, m_ResourceMonitor);
294266
BOOST_REQUIRE_EQUAL(static_cast<double>(i + 1) / 10.0,
295267
interimBucketCorrector->completeness());

0 commit comments

Comments
 (0)