Skip to content

Commit 874eb2e

Browse files
authored
[7.5][ML] Add state format version for resilience (#726) (#727)
This PR adds the version number of the current release (7.5) to the format of the training state used for resilience. In case of a version mismatch, deserialization of the training state will fail and training will start from scratch. I also took the opportunity to refactor the test for persistence state error handling to have more granular control over the reasons for failure.
1 parent 1da7698 commit 874eb2e

7 files changed

+558
-169
lines changed

lib/maths/CBayesianOptimisation.cc

Lines changed: 41 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,8 @@ namespace ml {
2727
namespace maths {
2828

2929
namespace {
30+
const std::string VERSION_7_5_TAG{"7.5"};
31+
3032
const std::string MIN_BOUNDARY_TAG{"min_boundary"};
3133
const std::string MAX_BOUNDARY_TAG{"max_boundary"};
3234
const std::string ERROR_VARIANCES_TAG{"error_variances"};
@@ -443,6 +445,7 @@ double CBayesianOptimisation::kernel(const TVector& a, const TVector& x, const T
443445

444446
void CBayesianOptimisation::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
445447
try {
448+
core::CPersistUtils::persist(VERSION_7_5_TAG, "", inserter);
446449
inserter.insertValue(RNG_TAG, m_Rng.toString());
447450
core::CPersistUtils::persist(MIN_BOUNDARY_TAG, m_MinBoundary, inserter);
448451
core::CPersistUtils::persist(MAX_BOUNDARY_TAG, m_MaxBoundary, inserter);
@@ -460,39 +463,45 @@ void CBayesianOptimisation::acceptPersistInserter(core::CStatePersistInserter& i
460463
}
461464

462465
bool CBayesianOptimisation::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
463-
try {
464-
do {
465-
const std::string& name = traverser.name();
466-
RESTORE(RNG_TAG, m_Rng.fromString(traverser.value()))
467-
RESTORE(MIN_BOUNDARY_TAG,
468-
core::CPersistUtils::restore(MIN_BOUNDARY_TAG, m_MinBoundary, traverser))
469-
RESTORE(MAX_BOUNDARY_TAG,
470-
core::CPersistUtils::restore(MAX_BOUNDARY_TAG, m_MaxBoundary, traverser))
471-
RESTORE(ERROR_VARIANCES_TAG,
472-
core::CPersistUtils::restore(ERROR_VARIANCES_TAG, m_ErrorVariances, traverser))
473-
RESTORE(RANGE_SHIFT_TAG,
474-
core::CPersistUtils::restore(RANGE_SHIFT_TAG, m_RangeShift, traverser))
475-
RESTORE(RANGE_SCALE_TAG,
476-
core::CPersistUtils::restore(RANGE_SCALE_TAG, m_RangeScale, traverser))
477-
RESTORE(RESTARTS_TAG,
478-
core::CPersistUtils::restore(RESTARTS_TAG, m_Restarts, traverser))
479-
RESTORE(KERNEL_PARAMETERS_TAG,
480-
core::CPersistUtils::restore(KERNEL_PARAMETERS_TAG,
481-
m_KernelParameters, traverser))
482-
RESTORE(MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
483-
core::CPersistUtils::restore(
484-
MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
485-
m_MinimumKernelCoordinateDistanceScale, traverser))
486-
RESTORE(FUNCTION_MEAN_VALUES_TAG,
487-
core::CPersistUtils::restore(FUNCTION_MEAN_VALUES_TAG,
488-
m_FunctionMeanValues, traverser))
489-
} while (traverser.next());
490-
} catch (std::exception& e) {
491-
LOG_ERROR(<< "Failed to restore state! " << e.what());
492-
return false;
493-
}
466+
if (traverser.name() == VERSION_7_5_TAG) {
467+
try {
468+
do {
469+
const std::string& name = traverser.name();
470+
RESTORE(RNG_TAG, m_Rng.fromString(traverser.value()))
471+
RESTORE(MIN_BOUNDARY_TAG,
472+
core::CPersistUtils::restore(MIN_BOUNDARY_TAG, m_MinBoundary, traverser))
473+
RESTORE(MAX_BOUNDARY_TAG,
474+
core::CPersistUtils::restore(MAX_BOUNDARY_TAG, m_MaxBoundary, traverser))
475+
RESTORE(ERROR_VARIANCES_TAG,
476+
core::CPersistUtils::restore(ERROR_VARIANCES_TAG,
477+
m_ErrorVariances, traverser))
478+
RESTORE(RANGE_SHIFT_TAG,
479+
core::CPersistUtils::restore(RANGE_SHIFT_TAG, m_RangeShift, traverser))
480+
RESTORE(RANGE_SCALE_TAG,
481+
core::CPersistUtils::restore(RANGE_SCALE_TAG, m_RangeScale, traverser))
482+
RESTORE(RESTARTS_TAG,
483+
core::CPersistUtils::restore(RESTARTS_TAG, m_Restarts, traverser))
484+
RESTORE(KERNEL_PARAMETERS_TAG,
485+
core::CPersistUtils::restore(KERNEL_PARAMETERS_TAG,
486+
m_KernelParameters, traverser))
487+
RESTORE(MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
488+
core::CPersistUtils::restore(
489+
MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
490+
m_MinimumKernelCoordinateDistanceScale, traverser))
491+
RESTORE(FUNCTION_MEAN_VALUES_TAG,
492+
core::CPersistUtils::restore(FUNCTION_MEAN_VALUES_TAG,
493+
m_FunctionMeanValues, traverser))
494+
} while (traverser.next());
495+
} catch (std::exception& e) {
496+
LOG_ERROR(<< "Failed to restore state! " << e.what());
497+
return false;
498+
}
494499

495-
return true;
500+
return true;
501+
}
502+
LOG_ERROR(<< "Input error: unsupported state serialization version. Currently supported version: "
503+
<< VERSION_7_5_TAG);
504+
return false;
496505
}
497506

498507
std::size_t CBayesianOptimisation::memoryUsage() const {

lib/maths/CBoostedTreeImpl.cc

Lines changed: 95 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -961,13 +961,15 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t numberRows) const {
961961
const std::size_t CBoostedTreeImpl::PACKED_BIT_VECTOR_MAXIMUM_ROWS_PER_BYTE{256};
962962

963963
namespace {
964+
const std::string VERSION_7_5_TAG{"7.5"};
965+
964966
const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"};
965967
const std::string BEST_FOREST_TAG{"best_forest"};
966968
const std::string BEST_FOREST_TEST_LOSS_TAG{"best_forest_test_loss"};
967969
const std::string BEST_HYPERPARAMETERS_TAG{"best_hyperparameters"};
968970
const std::string CURRENT_ROUND_TAG{"current_round"};
969971
const std::string DEPENDENT_VARIABLE_TAG{"dependent_variable"};
970-
const std::string ENCODER_TAG{"encoder_tag"};
972+
const std::string ENCODER_TAG{"encoder"};
971973
const std::string ETA_GROWTH_RATE_PER_TREE_TAG{"eta_growth_rate_per_tree"};
972974
const std::string ETA_OVERRIDE_TAG{"eta_override"};
973975
const std::string ETA_TAG{"eta"};
@@ -1034,6 +1036,7 @@ void CBoostedTreeImpl::SHyperparameters::acceptPersistInserter(core::CStatePersi
10341036
}
10351037

10361038
void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
1039+
core::CPersistUtils::persist(VERSION_7_5_TAG, "", inserter);
10371040
core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter);
10381041
core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter);
10391042
core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter);
@@ -1120,89 +1123,97 @@ bool CBoostedTreeImpl::SHyperparameters::acceptRestoreTraverser(core::CStateRest
11201123
}
11211124

11221125
bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
1123-
do {
1124-
const std::string& name = traverser.name();
1125-
RESTORE_NO_ERROR(BAYESIAN_OPTIMIZATION_TAG,
1126-
m_BayesianOptimization =
1127-
std::make_unique<CBayesianOptimisation>(traverser))
1128-
RESTORE(BEST_FOREST_TEST_LOSS_TAG,
1129-
core::CPersistUtils::restore(BEST_FOREST_TEST_LOSS_TAG,
1130-
m_BestForestTestLoss, traverser))
1131-
RESTORE(CURRENT_ROUND_TAG,
1132-
core::CPersistUtils::restore(CURRENT_ROUND_TAG, m_CurrentRound, traverser))
1133-
RESTORE(DEPENDENT_VARIABLE_TAG,
1134-
core::CPersistUtils::restore(DEPENDENT_VARIABLE_TAG,
1135-
m_DependentVariable, traverser))
1136-
RESTORE_NO_ERROR(ENCODER_TAG,
1137-
m_Encoder = std::make_unique<CDataFrameCategoryEncoder>(traverser))
1138-
RESTORE(ETA_GROWTH_RATE_PER_TREE_TAG,
1139-
core::CPersistUtils::restore(ETA_GROWTH_RATE_PER_TREE_TAG,
1140-
m_EtaGrowthRatePerTree, traverser))
1141-
RESTORE(ETA_TAG, core::CPersistUtils::restore(ETA_TAG, m_Eta, traverser))
1142-
RESTORE(FEATURE_BAG_FRACTION_TAG,
1143-
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_TAG,
1144-
m_FeatureBagFraction, traverser))
1145-
RESTORE(FEATURE_DATA_TYPES_TAG,
1146-
core::CPersistUtils::restore(FEATURE_DATA_TYPES_TAG,
1147-
m_FeatureDataTypes, traverser));
1148-
RESTORE(FEATURE_SAMPLE_PROBABILITIES_TAG,
1149-
core::CPersistUtils::restore(FEATURE_SAMPLE_PROBABILITIES_TAG,
1150-
m_FeatureSampleProbabilities, traverser))
1151-
RESTORE(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
1152-
core::CPersistUtils::restore(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
1153-
m_MaximumAttemptsToAddTree, traverser))
1154-
RESTORE(MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
1155-
core::CPersistUtils::restore(
1156-
MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
1157-
m_MaximumOptimisationRoundsPerHyperparameter, traverser))
1158-
RESTORE(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
1159-
core::CPersistUtils::restore(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
1160-
m_MaximumTreeSizeMultiplier, traverser))
1161-
RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
1162-
core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
1163-
m_MissingFeatureRowMasks, traverser))
1164-
RESTORE(NUMBER_FOLDS_TAG,
1165-
core::CPersistUtils::restore(NUMBER_FOLDS_TAG, m_NumberFolds, traverser))
1166-
RESTORE(NUMBER_ROUNDS_TAG,
1167-
core::CPersistUtils::restore(NUMBER_ROUNDS_TAG, m_NumberRounds, traverser))
1168-
RESTORE(NUMBER_SPLITS_PER_FEATURE_TAG,
1169-
core::CPersistUtils::restore(NUMBER_SPLITS_PER_FEATURE_TAG,
1170-
m_NumberSplitsPerFeature, traverser))
1171-
RESTORE(NUMBER_THREADS_TAG,
1172-
core::CPersistUtils::restore(NUMBER_THREADS_TAG, m_NumberThreads, traverser))
1173-
RESTORE(RANDOM_NUMBER_GENERATOR_TAG, m_Rng.fromString(traverser.value()))
1174-
RESTORE(REGULARIZATION_TAG,
1175-
core::CPersistUtils::restore(REGULARIZATION_TAG, m_Regularization, traverser))
1176-
RESTORE(REGULARIZATION_OVERRIDE_TAG,
1177-
core::CPersistUtils::restore(REGULARIZATION_OVERRIDE_TAG,
1178-
m_RegularizationOverride, traverser))
1179-
RESTORE(ROWS_PER_FEATURE_TAG,
1180-
core::CPersistUtils::restore(ROWS_PER_FEATURE_TAG, m_RowsPerFeature, traverser))
1181-
RESTORE(TESTING_ROW_MASKS_TAG,
1182-
core::CPersistUtils::restore(TESTING_ROW_MASKS_TAG, m_TestingRowMasks, traverser))
1183-
RESTORE(MAXIMUM_NUMBER_TREES_TAG,
1184-
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_TAG,
1185-
m_MaximumNumberTrees, traverser))
1186-
RESTORE(TRAINING_ROW_MASKS_TAG,
1187-
core::CPersistUtils::restore(TRAINING_ROW_MASKS_TAG, m_TrainingRowMasks, traverser))
1188-
RESTORE(TRAINING_PROGRESS_TAG,
1189-
core::CPersistUtils::restore(TRAINING_PROGRESS_TAG, m_TrainingProgress, traverser))
1190-
RESTORE(BEST_FOREST_TAG,
1191-
core::CPersistUtils::restore(BEST_FOREST_TAG, m_BestForest, traverser))
1192-
RESTORE(BEST_HYPERPARAMETERS_TAG,
1193-
core::CPersistUtils::restore(BEST_HYPERPARAMETERS_TAG,
1194-
m_BestHyperparameters, traverser))
1195-
RESTORE(ETA_OVERRIDE_TAG,
1196-
core::CPersistUtils::restore(ETA_OVERRIDE_TAG, m_EtaOverride, traverser))
1197-
RESTORE(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
1198-
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
1199-
m_FeatureBagFractionOverride, traverser))
1200-
RESTORE(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
1201-
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
1202-
m_MaximumNumberTreesOverride, traverser))
1203-
RESTORE(LOSS_TAG, restoreLoss(m_Loss, traverser))
1204-
} while (traverser.next());
1205-
return true;
1126+
if (traverser.name() == VERSION_7_5_TAG) {
1127+
do {
1128+
const std::string& name = traverser.name();
1129+
RESTORE_NO_ERROR(BAYESIAN_OPTIMIZATION_TAG,
1130+
m_BayesianOptimization =
1131+
std::make_unique<CBayesianOptimisation>(traverser))
1132+
RESTORE(BEST_FOREST_TEST_LOSS_TAG,
1133+
core::CPersistUtils::restore(BEST_FOREST_TEST_LOSS_TAG,
1134+
m_BestForestTestLoss, traverser))
1135+
RESTORE(CURRENT_ROUND_TAG,
1136+
core::CPersistUtils::restore(CURRENT_ROUND_TAG, m_CurrentRound, traverser))
1137+
RESTORE(DEPENDENT_VARIABLE_TAG,
1138+
core::CPersistUtils::restore(DEPENDENT_VARIABLE_TAG,
1139+
m_DependentVariable, traverser))
1140+
RESTORE_NO_ERROR(ENCODER_TAG,
1141+
m_Encoder = std::make_unique<CDataFrameCategoryEncoder>(traverser))
1142+
RESTORE(ETA_GROWTH_RATE_PER_TREE_TAG,
1143+
core::CPersistUtils::restore(ETA_GROWTH_RATE_PER_TREE_TAG,
1144+
m_EtaGrowthRatePerTree, traverser))
1145+
RESTORE(ETA_TAG, core::CPersistUtils::restore(ETA_TAG, m_Eta, traverser))
1146+
RESTORE(FEATURE_BAG_FRACTION_TAG,
1147+
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_TAG,
1148+
m_FeatureBagFraction, traverser))
1149+
RESTORE(FEATURE_DATA_TYPES_TAG,
1150+
core::CPersistUtils::restore(FEATURE_DATA_TYPES_TAG,
1151+
m_FeatureDataTypes, traverser));
1152+
RESTORE(FEATURE_SAMPLE_PROBABILITIES_TAG,
1153+
core::CPersistUtils::restore(FEATURE_SAMPLE_PROBABILITIES_TAG,
1154+
m_FeatureSampleProbabilities, traverser))
1155+
RESTORE(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
1156+
core::CPersistUtils::restore(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
1157+
m_MaximumAttemptsToAddTree, traverser))
1158+
RESTORE(MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
1159+
core::CPersistUtils::restore(
1160+
MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
1161+
m_MaximumOptimisationRoundsPerHyperparameter, traverser))
1162+
RESTORE(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
1163+
core::CPersistUtils::restore(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
1164+
m_MaximumTreeSizeMultiplier, traverser))
1165+
RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
1166+
core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
1167+
m_MissingFeatureRowMasks, traverser))
1168+
RESTORE(NUMBER_FOLDS_TAG,
1169+
core::CPersistUtils::restore(NUMBER_FOLDS_TAG, m_NumberFolds, traverser))
1170+
RESTORE(NUMBER_ROUNDS_TAG,
1171+
core::CPersistUtils::restore(NUMBER_ROUNDS_TAG, m_NumberRounds, traverser))
1172+
RESTORE(NUMBER_SPLITS_PER_FEATURE_TAG,
1173+
core::CPersistUtils::restore(NUMBER_SPLITS_PER_FEATURE_TAG,
1174+
m_NumberSplitsPerFeature, traverser))
1175+
RESTORE(NUMBER_THREADS_TAG,
1176+
core::CPersistUtils::restore(NUMBER_THREADS_TAG, m_NumberThreads, traverser))
1177+
RESTORE(RANDOM_NUMBER_GENERATOR_TAG, m_Rng.fromString(traverser.value()))
1178+
RESTORE(REGULARIZATION_TAG,
1179+
core::CPersistUtils::restore(REGULARIZATION_TAG, m_Regularization, traverser))
1180+
RESTORE(REGULARIZATION_OVERRIDE_TAG,
1181+
core::CPersistUtils::restore(REGULARIZATION_OVERRIDE_TAG,
1182+
m_RegularizationOverride, traverser))
1183+
RESTORE(ROWS_PER_FEATURE_TAG,
1184+
core::CPersistUtils::restore(ROWS_PER_FEATURE_TAG, m_RowsPerFeature, traverser))
1185+
RESTORE(TESTING_ROW_MASKS_TAG,
1186+
core::CPersistUtils::restore(TESTING_ROW_MASKS_TAG,
1187+
m_TestingRowMasks, traverser))
1188+
RESTORE(MAXIMUM_NUMBER_TREES_TAG,
1189+
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_TAG,
1190+
m_MaximumNumberTrees, traverser))
1191+
RESTORE(TRAINING_ROW_MASKS_TAG,
1192+
core::CPersistUtils::restore(TRAINING_ROW_MASKS_TAG,
1193+
m_TrainingRowMasks, traverser))
1194+
RESTORE(TRAINING_PROGRESS_TAG,
1195+
core::CPersistUtils::restore(TRAINING_PROGRESS_TAG,
1196+
m_TrainingProgress, traverser))
1197+
RESTORE(BEST_FOREST_TAG,
1198+
core::CPersistUtils::restore(BEST_FOREST_TAG, m_BestForest, traverser))
1199+
RESTORE(BEST_HYPERPARAMETERS_TAG,
1200+
core::CPersistUtils::restore(BEST_HYPERPARAMETERS_TAG,
1201+
m_BestHyperparameters, traverser))
1202+
RESTORE(ETA_OVERRIDE_TAG,
1203+
core::CPersistUtils::restore(ETA_OVERRIDE_TAG, m_EtaOverride, traverser))
1204+
RESTORE(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
1205+
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
1206+
m_FeatureBagFractionOverride, traverser))
1207+
RESTORE(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
1208+
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
1209+
m_MaximumNumberTreesOverride, traverser))
1210+
RESTORE(LOSS_TAG, restoreLoss(m_Loss, traverser))
1211+
} while (traverser.next());
1212+
return true;
1213+
}
1214+
LOG_ERROR(<< "Input error: unsupported state serialization version. Currently supported version: "
1215+
<< VERSION_7_5_TAG);
1216+
return false;
12061217
}
12071218

12081219
bool CBoostedTreeImpl::restoreLoss(CBoostedTree::TLossFunctionUPtr& loss,

0 commit comments

Comments
 (0)