Skip to content

Commit 1635575

Browse files
authored
[ML] Some corrections to error handling on fail to restore boosted tree training state (#577)
1 parent b341dbd commit 1635575

8 files changed

+305
-203
lines changed

include/core/RestoreMacros.h

Lines changed: 61 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -10,44 +10,60 @@
1010
namespace ml {
1111
namespace core {
1212

13-
#define RESTORE(tag, restore) \
14-
if (name == tag) { \
15-
if ((restore) == false) { \
16-
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
17-
return false; \
18-
} \
19-
continue; \
13+
#define RESTORE(tag, restore) \
14+
if (name == tag) { \
15+
if ((restore) == false) { \
16+
if (traverser.value().empty()) { \
17+
LOG_ERROR(<< "Failed to restore " #tag); \
18+
} else { \
19+
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
20+
} \
21+
return false; \
22+
} \
23+
continue; \
2024
}
2125

22-
#define RESTORE_BUILT_IN(tag, target) \
23-
if (name == tag) { \
24-
if (core::CStringUtils::stringToType(traverser.value(), target) == false) { \
25-
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
26-
return false; \
27-
} \
28-
continue; \
26+
#define RESTORE_BUILT_IN(tag, target) \
27+
if (name == tag) { \
28+
if (core::CStringUtils::stringToType(traverser.value(), target) == false) { \
29+
if (traverser.value().empty()) { \
30+
LOG_ERROR(<< "Failed to restore " #tag); \
31+
} else { \
32+
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
33+
} \
34+
return false; \
35+
} \
36+
continue; \
2937
}
3038

31-
#define RESTORE_BOOL(tag, target) \
32-
if (name == tag) { \
33-
int value; \
34-
if (core::CStringUtils::stringToType(traverser.value(), value) == false) { \
35-
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
36-
return false; \
37-
} \
38-
target = (value != 0); \
39-
continue; \
39+
#define RESTORE_BOOL(tag, target) \
40+
if (name == tag) { \
41+
int value; \
42+
if (core::CStringUtils::stringToType(traverser.value(), value) == false) { \
43+
if (traverser.value().empty()) { \
44+
LOG_ERROR(<< "Failed to restore " #tag); \
45+
} else { \
46+
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
47+
} \
48+
return false; \
49+
} \
50+
target = (value != 0); \
51+
continue; \
4052
}
4153

42-
#define RESTORE_ENUM(tag, target, enumtype) \
43-
if (name == tag) { \
44-
int value; \
45-
if (core::CStringUtils::stringToType(traverser.value(), value) == false) { \
46-
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
47-
return false; \
48-
} \
49-
target = enumtype(value); \
50-
continue; \
54+
#define RESTORE_ENUM(tag, target, enumtype) \
55+
if (name == tag) { \
56+
int value; \
57+
if (core::CStringUtils::stringToType(traverser.value(), value) == false) { \
58+
if (traverser.value().empty()) { \
59+
LOG_ERROR(<< "Failed to restore " #tag); \
60+
} else { \
61+
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
62+
} \
63+
return false; \
64+
} \
65+
target = enumtype(value); \
66+
continue; \
5167
}
5268

5369
#define RESTORE_ENUM_CHECKED(tag, target, enumtype, restoreSuccess) \
@@ -56,15 +72,19 @@ namespace core {
5672
RESTORE_ENUM(tag, target, enumtype) \
5773
}
5874

59-
#define RESTORE_SETUP_TEARDOWN(tag, setup, restore, teardown) \
60-
if (name == tag) { \
61-
setup; \
62-
if ((restore) == false) { \
63-
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
64-
return false; \
65-
} \
66-
teardown; \
67-
continue; \
75+
#define RESTORE_SETUP_TEARDOWN(tag, setup, restore, teardown) \
76+
if (name == tag) { \
77+
setup; \
78+
if ((restore) == false) { \
79+
if (traverser.value().empty()) { \
80+
LOG_ERROR(<< "Failed to restore " #tag); \
81+
} else { \
82+
LOG_ERROR(<< "Failed to restore " #tag ", got " << traverser.value()); \
83+
} \
84+
return false; \
85+
} \
86+
teardown; \
87+
continue; \
6888
}
6989

7090
#define RESTORE_NO_ERROR(tag, restore) \

include/core/UnwrapRef.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
//! file do not conform to the coding style to ease the eventual transition
1010
//! to std::unwrap_ref.
1111

12-
#ifndef INCLUDED_ml_core_UNWRAPREF_H_
13-
#define INCLUDED_ml_core_UNWRAPREF_H_
12+
#ifndef INCLUDED_ml_core_UnwrapRef_h
13+
#define INCLUDED_ml_core_UnwrapRef_h
1414

1515
#include <functional>
1616

@@ -44,4 +44,4 @@ typename unwrap_reference<T>::type& unwrap_ref(T& t) {
4444
}
4545
}
4646

47-
#endif /* INCLUDED_ml_core_UNWRAPREF_H_ */
47+
#endif // INCLUDED_ml_core_UnwrapRef_h

lib/maths/CBayesianOptimisation.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#include <boost/math/distributions/normal.hpp>
2222

23+
#include <exception>
24+
2325
namespace ml {
2426
namespace maths {
2527

@@ -48,8 +50,10 @@ CBayesianOptimisation::CBayesianOptimisation(TDoubleDoublePrVec parameterBounds)
4850
}
4951

5052
CBayesianOptimisation::CBayesianOptimisation(core::CStateRestoreTraverser& traverser) {
51-
traverser.traverseSubLevel(std::bind(&CBayesianOptimisation::acceptRestoreTraverser,
52-
this, std::placeholders::_1));
53+
if (traverser.traverseSubLevel(std::bind(&CBayesianOptimisation::acceptRestoreTraverser,
54+
this, std::placeholders::_1)) == false) {
55+
throw std::runtime_error{"failed to restore Bayesian optimisation"};
56+
}
5357
}
5458

5559
void CBayesianOptimisation::add(TVector x, double fx, double vx) {

lib/maths/CBoostedTree.cc

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,11 @@ const std::string BOOSTED_TREE_IMPL_TAG{"boosted_tree_impl"};
8888
}
8989

9090
bool CBoostedTree::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
91-
try {
92-
do {
93-
const std::string& name = traverser.name();
94-
RESTORE(BOOSTED_TREE_IMPL_TAG,
95-
core::CPersistUtils::restore(BOOSTED_TREE_IMPL_TAG, *m_Impl, traverser))
96-
} while (traverser.next());
97-
} catch (std::exception& e) {
98-
LOG_ERROR(<< "Failed to restore state! " << e.what());
99-
return false;
100-
}
101-
91+
do {
92+
const std::string& name = traverser.name();
93+
RESTORE(BOOSTED_TREE_IMPL_TAG,
94+
core::CPersistUtils::restore(BOOSTED_TREE_IMPL_TAG, *m_Impl, traverser))
95+
} while (traverser.next());
10296
return true;
10397
}
10498

lib/maths/CBoostedTreeFactory.cc

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -278,23 +278,30 @@ CBoostedTreeFactory::constructFromParameters(std::size_t numberThreads,
278278
CBoostedTreeFactory::TBoostedTreeUPtr
279279
CBoostedTreeFactory::constructFromString(std::stringstream& jsonStringStream,
280280
core::CDataFrame& frame) {
281-
TBoostedTreeUPtr treePtr{
282-
new CBoostedTree{frame, TBoostedTreeImplUPtr{new CBoostedTreeImpl{}}}};
283-
core::CJsonStateRestoreTraverser traverser(jsonStringStream);
284-
treePtr->acceptRestoreTraverser(traverser);
285-
return treePtr;
281+
try {
282+
TBoostedTreeUPtr treePtr{
283+
new CBoostedTree{frame, TBoostedTreeImplUPtr{new CBoostedTreeImpl{}}}};
284+
core::CJsonStateRestoreTraverser traverser(jsonStringStream);
285+
if (treePtr->acceptRestoreTraverser(traverser) == false || traverser.haveBadState()) {
286+
throw std::runtime_error{"failed to restore boosted tree"};
287+
}
288+
return treePtr;
289+
} catch (const std::exception& e) {
290+
HANDLE_FATAL(<< "Input error: '" << e.what() << "'. Check logs for more details.");
291+
}
292+
return nullptr;
286293
}
287294

288-
CBoostedTreeFactory::~CBoostedTreeFactory() = default;
295+
CBoostedTreeFactory::CBoostedTreeFactory(std::size_t numberThreads,
296+
CBoostedTree::TLossFunctionUPtr loss)
297+
: m_TreeImpl{std::make_unique<CBoostedTreeImpl>(numberThreads, std::move(loss))} {
298+
}
289299

290300
CBoostedTreeFactory::CBoostedTreeFactory(CBoostedTreeFactory&&) = default;
291301

292302
CBoostedTreeFactory& CBoostedTreeFactory::operator=(CBoostedTreeFactory&&) = default;
293303

294-
CBoostedTreeFactory::CBoostedTreeFactory(std::size_t numberThreads,
295-
CBoostedTree::TLossFunctionUPtr loss)
296-
: m_TreeImpl{std::make_unique<CBoostedTreeImpl>(numberThreads, std::move(loss))} {
297-
}
304+
CBoostedTreeFactory::~CBoostedTreeFactory() = default;
298305

299306
CBoostedTreeFactory& CBoostedTreeFactory::numberFolds(std::size_t numberFolds) {
300307
if (numberFolds < 2) {

0 commit comments

Comments
 (0)