Skip to content

Commit 90f22e2

Browse files
committed
Pass the exemplar into the NB restore constructor rather than persisting and restoring
1 parent a7fa732 commit 90f22e2

File tree

4 files changed

+16
-24
lines changed

4 files changed

+16
-24
lines changed

include/maths/CNaiveBayes.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ class MATHS_EXPORT CNaiveBayes {
149149
explicit CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
150150
double decayRate = 0.0,
151151
TOptionalDouble minMaxLogLikelihoodToUseFeature = TOptionalDouble());
152-
CNaiveBayes(const SDistributionRestoreParams& params,
152+
CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
153+
const SDistributionRestoreParams& params,
153154
core::CStateRestoreTraverser& traverser);
154155
CNaiveBayes(const CNaiveBayes& other);
155156

lib/maths/CNaiveBayes.cc

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ const std::string CLASS_MODEL_TAG{"c"};
3535
const std::string MIN_MAX_LOG_LIKELIHOOD_TO_USE_FEATURE_TAG{"d"};
3636
const std::string COUNT_TAG{"e"};
3737
const std::string CONDITIONAL_DENSITY_FROM_PRIOR_TAG{"f"};
38-
const std::string EXEMPLAR_FROM_PRIOR_TAG{"g"};
3938
}
4039

4140
CNaiveBayesFeatureDensityFromPrior::CNaiveBayesFeatureDensityFromPrior(const CPrior& prior)
@@ -125,9 +124,10 @@ CNaiveBayes::CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
125124
m_DecayRate{decayRate}, m_Exemplar{exemplar.clone()}, m_ClassConditionalDensities{2} {
126125
}
127126

128-
CNaiveBayes::CNaiveBayes(const SDistributionRestoreParams& params,
127+
CNaiveBayes::CNaiveBayes(const CNaiveBayesFeatureDensity& exemplar,
128+
const SDistributionRestoreParams& params,
129129
core::CStateRestoreTraverser& traverser)
130-
: m_DecayRate{params.s_DecayRate}, m_ClassConditionalDensities{2} {
130+
: m_DecayRate{params.s_DecayRate}, m_Exemplar{exemplar.clone()}, m_ClassConditionalDensities{2} {
131131
traverser.traverseSubLevel(boost::bind(&CNaiveBayes::acceptRestoreTraverser,
132132
this, boost::cref(params), _1));
133133
}
@@ -146,13 +146,6 @@ bool CNaiveBayes::acceptRestoreTraverser(const SDistributionRestoreParams& param
146146
do {
147147
const std::string& name{traverser.name()};
148148
RESTORE_BUILT_IN(CLASS_LABEL_TAG, label)
149-
RESTORE_SETUP_TEARDOWN(
150-
EXEMPLAR_FROM_PRIOR_TAG, CNaiveBayesFeatureDensityFromPrior density,
151-
traverser.traverseSubLevel(
152-
boost::bind(&CNaiveBayesFeatureDensityFromPrior::acceptRestoreTraverser,
153-
boost::ref(density), boost::cref(params), _1)),
154-
m_Exemplar.reset(density.clone()))
155-
// Add other implementations' restore code here.
156149
RESTORE_SETUP_TEARDOWN(
157150
CLASS_MODEL_TAG, CClass class_,
158151
traverser.traverseSubLevel(boost::bind(&CClass::acceptRestoreTraverser,
@@ -170,13 +163,6 @@ void CNaiveBayes::acceptPersistInserter(core::CStatePersistInserter& inserter) c
170163
using TSizeClassUMapCItr = TSizeClassUMap::const_iterator;
171164
using TSizeClassUMapCItrVec = std::vector<TSizeClassUMapCItr>;
172165

173-
if (dynamic_cast<const CNaiveBayesFeatureDensityFromPrior*>(m_Exemplar.get())) {
174-
inserter.insertLevel(EXEMPLAR_FROM_PRIOR_TAG,
175-
boost::bind(&CNaiveBayesFeatureDensity::acceptPersistInserter,
176-
m_Exemplar.get(), _1));
177-
}
178-
// Add other implementations' persist code here.
179-
180166
TSizeClassUMapCItrVec classes;
181167
classes.reserve(m_ClassConditionalDensities.size());
182168
for (auto i = m_ClassConditionalDensities.begin();

lib/maths/CTrendComponent.cc

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,14 @@ TOptionalDoubleDoublePr confidenceInterval(double prediction, double variance, d
6969
return TOptionalDoubleDoublePr{};
7070
}
7171

72+
CNaiveBayesFeatureDensityFromPrior naiveBayesExemplar(double decayRate) {
73+
return CNaiveBayesFeatureDensityFromPrior{CNormalMeanPrecConjugate::nonInformativePrior(
74+
maths_t::E_ContinuousData, TIME_SCALES[NUMBER_MODELS - 1] * decayRate)};
75+
}
76+
7277
CNaiveBayes initialProbabilityOfChangeModel(double decayRate) {
73-
decayRate *= TIME_SCALES[NUMBER_MODELS - 1];
74-
return CNaiveBayes{CNaiveBayesFeatureDensityFromPrior{CNormalMeanPrecConjugate::nonInformativePrior(
75-
maths_t::E_ContinuousData, decayRate)},
76-
decayRate, -20.0};
78+
return CNaiveBayes{naiveBayesExemplar(decayRate),
79+
TIME_SCALES[NUMBER_MODELS - 1] * decayRate, -20.0};
7780
}
7881

7982
CNormalMeanPrecConjugate initialMagnitudeOfChangeModel(double decayRate) {
@@ -157,7 +160,8 @@ bool CTrendComponent::acceptRestoreTraverser(const SDistributionRestoreParams& p
157160
RESTORE(VALUE_MOMENTS_TAG, m_ValueMoments.fromDelimited(traverser.value()))
158161
RESTORE_BUILT_IN(TIME_OF_LAST_LEVEL_CHANGE_TAG, m_TimeOfLastLevelChange)
159162
RESTORE_NO_ERROR(PROBABILITY_OF_LEVEL_CHANGE_MODEL_TAG,
160-
m_ProbabilityOfLevelChangeModel = CNaiveBayes(params, traverser))
163+
m_ProbabilityOfLevelChangeModel = std::move(CNaiveBayes(
164+
naiveBayesExemplar(m_DefaultDecayRate), params, traverser)))
161165
RESTORE_NO_ERROR(MAGNITUDE_OF_LEVEL_CHANGE_MODEL_TAG,
162166
m_MagnitudeOfLevelChangeModel =
163167
CNormalMeanPrecConjugate(params, traverser))

lib/maths/unittest/CNaiveBayesTest.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,8 @@ void CNaiveBayesTest::testPersist() {
306306
core::CRapidXmlStateRestoreTraverser traverser(parser);
307307

308308
maths::SDistributionRestoreParams params{maths_t::E_ContinuousData, 0.1, 0.0, 0.0, 0.0};
309-
maths::CNaiveBayes restoredNb{params, traverser};
309+
maths::CNaiveBayes restoredNb{maths::CNaiveBayesFeatureDensityFromPrior(normal),
310+
params, traverser};
310311

311312
CPPUNIT_ASSERT_EQUAL(origNb.checksum(), restoredNb.checksum());
312313

0 commit comments

Comments
 (0)