Skip to content

Commit c6200ea

Browse files
authored
[7.x][ML] Provide factory setup for creating models (elastic#1527) (elastic#1532)
Move boilerplate code for creating models to a base class method. This goes some way to reducing duplicated code and standardizing how models are created in the tests. Backports elastic#1527
1 parent 0cd1feb commit c6200ea

6 files changed

+386
-495
lines changed

lib/model/unittest/CCountingModelTest.cc

+37-49
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ using namespace model;
3131

3232
class CTestFixture : public CModelTestFixtureBase {
3333
protected:
34-
SModelParams::TStrDetectionRulePr
34+
static SModelParams::TStrDetectionRulePr
3535
makeScheduledEvent(const std::string& description, double start, double end) {
3636
CRuleCondition conditionGte;
3737
conditionGte.appliesTo(CRuleCondition::E_Time);
@@ -50,6 +50,13 @@ class CTestFixture : public CModelTestFixtureBase {
5050
SModelParams::TStrDetectionRulePr event = std::make_pair(description, rule);
5151
return event;
5252
}
53+
54+
void makeModel(const SModelParams& params,
55+
const model_t::TFeatureVec& features,
56+
core_t::TTime startTime) {
57+
this->makeModelT<CCountingModelFactory>(
58+
params, features, startTime, model_t::E_Counting, m_Gatherer, m_Model);
59+
}
5360
};
5461

5562
BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
@@ -66,14 +73,11 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
6673

6774
// Model where gap is not skipped
6875
{
69-
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
70-
CModelFactory::TDataGathererPtr gathererNoGap(
71-
factory.makeDataGatherer(gathererNoGapInitData));
72-
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererNoGap));
73-
CModelFactory::SModelInitializationData modelNoGapInitData(gathererNoGap);
74-
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
75-
CCountingModel* modelNoGap =
76-
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
76+
CModelFactory::TDataGathererPtr gathererNoGap;
77+
CModelFactory::TModelPtr modelNoGap;
78+
this->makeModelT<CCountingModelFactory>(
79+
params, features, startTime, model_t::E_Counting, gathererNoGap, modelNoGap);
80+
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", gathererNoGap));
7781

7882
// |2|2|0|0|1| -> 1.0 mean count
7983
this->addArrival(*gathererNoGap, 100, "p");
@@ -90,15 +94,12 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
9094

9195
// Model where gap is skipped
9296
{
93-
CModelFactory::SGathererInitializationData gathererWithGapInitData(startTime);
94-
CModelFactory::TDataGathererPtr gathererWithGap(
95-
factory.makeDataGatherer(gathererWithGapInitData));
96-
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p", gathererWithGap));
97-
CModelFactory::SModelInitializationData modelWithGapInitData(gathererWithGap);
98-
CAnomalyDetectorModel::TModelPtr modelHolderWithGap(
99-
factory.makeModel(modelWithGapInitData));
100-
CCountingModel* modelWithGap =
101-
dynamic_cast<CCountingModel*>(modelHolderWithGap.get());
97+
CModelFactory::TDataGathererPtr gathererWithGap;
98+
CModelFactory::TModelPtr modelWithGap;
99+
this->makeModelT<CCountingModelFactory>(params, features, startTime,
100+
model_t::E_Counting,
101+
gathererWithGap, modelWithGap);
102+
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", gathererWithGap));
102103

103104
// |2|2|0|0|1|
104105
// |2|X|X|X|1| -> 1.5 mean count where X means skipped bucket
@@ -109,7 +110,7 @@ BOOST_FIXTURE_TEST_CASE(testSkipSampling, CTestFixture) {
109110
this->addArrival(*gathererWithGap, 280, "p");
110111
modelWithGap->skipSampling(500);
111112
modelWithGap->prune(maxAgeBuckets);
112-
BOOST_REQUIRE_EQUAL(std::size_t(1), gathererWithGap->numberActivePeople());
113+
BOOST_REQUIRE_EQUAL(1, gathererWithGap->numberActivePeople());
113114
this->addArrival(*gathererWithGap, 500, "p");
114115
modelWithGap->sample(500, 600, m_ResourceMonitor);
115116

@@ -137,14 +138,10 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
137138
factory.features(features);
138139

139140
{
140-
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
141-
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
142-
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
143-
this->addArrival(*gatherer, 200, "p");
144-
145-
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
146-
CCountingModel* modelNoGap =
147-
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
141+
this->makeModel(params, features, startTime);
142+
CCountingModel* modelNoGap = dynamic_cast<CCountingModel*>(m_Model.get());
143+
BOOST_TEST_REQUIRE(modelNoGap);
144+
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", m_Gatherer));
148145

149146
SModelParams::TStrDetectionRulePrVec matchedEvents =
150147
modelNoGap->checkScheduledEvents(50);
@@ -186,14 +183,10 @@ BOOST_FIXTURE_TEST_CASE(testCheckScheduledEvents, CTestFixture) {
186183

187184
// Test sampleBucketStatistics
188185
{
189-
CModelFactory::SGathererInitializationData gathererNoGapInitData(startTime);
190-
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererNoGapInitData));
191-
CModelFactory::SModelInitializationData modelNoGapInitData(gatherer);
192-
this->addArrival(*gatherer, 100, "p");
193-
194-
CAnomalyDetectorModel::TModelPtr modelHolderNoGap(factory.makeModel(modelNoGapInitData));
195-
CCountingModel* modelNoGap =
196-
dynamic_cast<CCountingModel*>(modelHolderNoGap.get());
186+
this->makeModel(params, features, startTime);
187+
CCountingModel* modelNoGap = dynamic_cast<CCountingModel*>(m_Model.get());
188+
BOOST_TEST_REQUIRE(modelNoGap);
189+
BOOST_REQUIRE_EQUAL(0, this->addPerson("p", m_Gatherer));
197190

198191
// There are no events at this time
199192
modelNoGap->sampleBucketStatistics(0, 100, m_ResourceMonitor);
@@ -226,18 +219,13 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
226219

227220
SModelParams params(bucketLength);
228221
params.s_DecayRate = 0.001;
229-
auto interimBucketCorrector = std::make_shared<CInterimBucketCorrector>(bucketLength);
230-
CCountingModelFactory factory(params, interimBucketCorrector);
231-
model_t::TFeatureVec features{model_t::E_IndividualCountByBucketAndPerson};
232-
factory.features(features);
233222

234-
CModelFactory::SGathererInitializationData gathererInitData(time);
235-
CModelFactory::TDataGathererPtr gatherer(factory.makeDataGatherer(gathererInitData));
236-
BOOST_REQUIRE_EQUAL(std::size_t(0), this->addPerson("p1", gatherer));
237-
BOOST_REQUIRE_EQUAL(std::size_t(1), this->addPerson("p2", gatherer));
238-
CModelFactory::SModelInitializationData modelInitData(gatherer);
239-
CAnomalyDetectorModel::TModelPtr modelHolder(factory.makeModel(modelInitData));
240-
CCountingModel* model{dynamic_cast<CCountingModel*>(modelHolder.get())};
223+
this->makeModel(params, {model_t::E_IndividualCountByBucketAndPerson}, time);
224+
CCountingModel* model = dynamic_cast<CCountingModel*>(m_Model.get());
225+
BOOST_TEST_REQUIRE(model);
226+
227+
BOOST_REQUIRE_EQUAL(0, this->addPerson("p1", m_Gatherer));
228+
BOOST_REQUIRE_EQUAL(1, this->addPerson("p2", m_Gatherer));
241229

242230
test::CRandomNumbers rng;
243231

@@ -249,7 +237,7 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
249237
std::sort(offsets.begin(), offsets.end());
250238
for (auto offset : offsets) {
251239
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
252-
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offset),
240+
this->addArrival(*m_Gatherer, time + static_cast<core_t::TTime>(offset),
253241
uniform01[0] < 0.5 ? "p1" : "p2");
254242
}
255243
model->sample(time, time + bucketLength, m_ResourceMonitor);
@@ -260,11 +248,11 @@ BOOST_FIXTURE_TEST_CASE(testInterimBucketCorrector, CTestFixture) {
260248

261249
for (std::size_t i = 0u; i < offsets.size(); ++i) {
262250
rng.generateUniformSamples(0.0, 1.0, 1, uniform01);
263-
this->addArrival(*gatherer, time + static_cast<core_t::TTime>(offsets[i]),
251+
this->addArrival(*m_Gatherer, time + static_cast<core_t::TTime>(offsets[i]),
264252
uniform01[0] < 0.5 ? "p1" : "p2");
265253
model->sampleBucketStatistics(time, time + bucketLength, m_ResourceMonitor);
266254
BOOST_REQUIRE_EQUAL(static_cast<double>(i + 1) / 10.0,
267-
interimBucketCorrector->completeness());
255+
m_InterimBucketCorrector->completeness());
268256
}
269257
}
270258

0 commit comments

Comments
 (0)