Skip to content

[7.5][ML] Add state format version for resilience (#726) #727

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 41 additions & 32 deletions lib/maths/CBayesianOptimisation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ namespace ml {
namespace maths {

namespace {
const std::string VERSION_7_5_TAG{"7.5"};

const std::string MIN_BOUNDARY_TAG{"min_boundary"};
const std::string MAX_BOUNDARY_TAG{"max_boundary"};
const std::string ERROR_VARIANCES_TAG{"error_variances"};
Expand Down Expand Up @@ -443,6 +445,7 @@ double CBayesianOptimisation::kernel(const TVector& a, const TVector& x, const T

void CBayesianOptimisation::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
try {
core::CPersistUtils::persist(VERSION_7_5_TAG, "", inserter);
inserter.insertValue(RNG_TAG, m_Rng.toString());
core::CPersistUtils::persist(MIN_BOUNDARY_TAG, m_MinBoundary, inserter);
core::CPersistUtils::persist(MAX_BOUNDARY_TAG, m_MaxBoundary, inserter);
Expand All @@ -460,39 +463,45 @@ void CBayesianOptimisation::acceptPersistInserter(core::CStatePersistInserter& i
}

bool CBayesianOptimisation::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
try {
do {
const std::string& name = traverser.name();
RESTORE(RNG_TAG, m_Rng.fromString(traverser.value()))
RESTORE(MIN_BOUNDARY_TAG,
core::CPersistUtils::restore(MIN_BOUNDARY_TAG, m_MinBoundary, traverser))
RESTORE(MAX_BOUNDARY_TAG,
core::CPersistUtils::restore(MAX_BOUNDARY_TAG, m_MaxBoundary, traverser))
RESTORE(ERROR_VARIANCES_TAG,
core::CPersistUtils::restore(ERROR_VARIANCES_TAG, m_ErrorVariances, traverser))
RESTORE(RANGE_SHIFT_TAG,
core::CPersistUtils::restore(RANGE_SHIFT_TAG, m_RangeShift, traverser))
RESTORE(RANGE_SCALE_TAG,
core::CPersistUtils::restore(RANGE_SCALE_TAG, m_RangeScale, traverser))
RESTORE(RESTARTS_TAG,
core::CPersistUtils::restore(RESTARTS_TAG, m_Restarts, traverser))
RESTORE(KERNEL_PARAMETERS_TAG,
core::CPersistUtils::restore(KERNEL_PARAMETERS_TAG,
m_KernelParameters, traverser))
RESTORE(MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
core::CPersistUtils::restore(
MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
m_MinimumKernelCoordinateDistanceScale, traverser))
RESTORE(FUNCTION_MEAN_VALUES_TAG,
core::CPersistUtils::restore(FUNCTION_MEAN_VALUES_TAG,
m_FunctionMeanValues, traverser))
} while (traverser.next());
} catch (std::exception& e) {
LOG_ERROR(<< "Failed to restore state! " << e.what());
return false;
}
if (traverser.name() == VERSION_7_5_TAG) {
try {
do {
const std::string& name = traverser.name();
RESTORE(RNG_TAG, m_Rng.fromString(traverser.value()))
RESTORE(MIN_BOUNDARY_TAG,
core::CPersistUtils::restore(MIN_BOUNDARY_TAG, m_MinBoundary, traverser))
RESTORE(MAX_BOUNDARY_TAG,
core::CPersistUtils::restore(MAX_BOUNDARY_TAG, m_MaxBoundary, traverser))
RESTORE(ERROR_VARIANCES_TAG,
core::CPersistUtils::restore(ERROR_VARIANCES_TAG,
m_ErrorVariances, traverser))
RESTORE(RANGE_SHIFT_TAG,
core::CPersistUtils::restore(RANGE_SHIFT_TAG, m_RangeShift, traverser))
RESTORE(RANGE_SCALE_TAG,
core::CPersistUtils::restore(RANGE_SCALE_TAG, m_RangeScale, traverser))
RESTORE(RESTARTS_TAG,
core::CPersistUtils::restore(RESTARTS_TAG, m_Restarts, traverser))
RESTORE(KERNEL_PARAMETERS_TAG,
core::CPersistUtils::restore(KERNEL_PARAMETERS_TAG,
m_KernelParameters, traverser))
RESTORE(MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
core::CPersistUtils::restore(
MIN_KERNEL_COORDINATE_DISTANCE_SCALES_TAG,
m_MinimumKernelCoordinateDistanceScale, traverser))
RESTORE(FUNCTION_MEAN_VALUES_TAG,
core::CPersistUtils::restore(FUNCTION_MEAN_VALUES_TAG,
m_FunctionMeanValues, traverser))
} while (traverser.next());
} catch (std::exception& e) {
LOG_ERROR(<< "Failed to restore state! " << e.what());
return false;
}

return true;
return true;
}
LOG_ERROR(<< "Input error: unsupported state serialization version. Currently supported version: "
<< VERSION_7_5_TAG);
return false;
}

std::size_t CBayesianOptimisation::memoryUsage() const {
Expand Down
179 changes: 95 additions & 84 deletions lib/maths/CBoostedTreeImpl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -961,13 +961,15 @@ std::size_t CBoostedTreeImpl::maximumTreeSize(std::size_t numberRows) const {
const std::size_t CBoostedTreeImpl::PACKED_BIT_VECTOR_MAXIMUM_ROWS_PER_BYTE{256};

namespace {
const std::string VERSION_7_5_TAG{"7.5"};

const std::string BAYESIAN_OPTIMIZATION_TAG{"bayesian_optimization"};
const std::string BEST_FOREST_TAG{"best_forest"};
const std::string BEST_FOREST_TEST_LOSS_TAG{"best_forest_test_loss"};
const std::string BEST_HYPERPARAMETERS_TAG{"best_hyperparameters"};
const std::string CURRENT_ROUND_TAG{"current_round"};
const std::string DEPENDENT_VARIABLE_TAG{"dependent_variable"};
const std::string ENCODER_TAG{"encoder_tag"};
const std::string ENCODER_TAG{"encoder"};
const std::string ETA_GROWTH_RATE_PER_TREE_TAG{"eta_growth_rate_per_tree"};
const std::string ETA_OVERRIDE_TAG{"eta_override"};
const std::string ETA_TAG{"eta"};
Expand Down Expand Up @@ -1034,6 +1036,7 @@ void CBoostedTreeImpl::SHyperparameters::acceptPersistInserter(core::CStatePersi
}

void CBoostedTreeImpl::acceptPersistInserter(core::CStatePersistInserter& inserter) const {
core::CPersistUtils::persist(VERSION_7_5_TAG, "", inserter);
core::CPersistUtils::persist(BAYESIAN_OPTIMIZATION_TAG, *m_BayesianOptimization, inserter);
core::CPersistUtils::persist(BEST_FOREST_TEST_LOSS_TAG, m_BestForestTestLoss, inserter);
core::CPersistUtils::persist(CURRENT_ROUND_TAG, m_CurrentRound, inserter);
Expand Down Expand Up @@ -1120,89 +1123,97 @@ bool CBoostedTreeImpl::SHyperparameters::acceptRestoreTraverser(core::CStateRest
}

bool CBoostedTreeImpl::acceptRestoreTraverser(core::CStateRestoreTraverser& traverser) {
do {
const std::string& name = traverser.name();
RESTORE_NO_ERROR(BAYESIAN_OPTIMIZATION_TAG,
m_BayesianOptimization =
std::make_unique<CBayesianOptimisation>(traverser))
RESTORE(BEST_FOREST_TEST_LOSS_TAG,
core::CPersistUtils::restore(BEST_FOREST_TEST_LOSS_TAG,
m_BestForestTestLoss, traverser))
RESTORE(CURRENT_ROUND_TAG,
core::CPersistUtils::restore(CURRENT_ROUND_TAG, m_CurrentRound, traverser))
RESTORE(DEPENDENT_VARIABLE_TAG,
core::CPersistUtils::restore(DEPENDENT_VARIABLE_TAG,
m_DependentVariable, traverser))
RESTORE_NO_ERROR(ENCODER_TAG,
m_Encoder = std::make_unique<CDataFrameCategoryEncoder>(traverser))
RESTORE(ETA_GROWTH_RATE_PER_TREE_TAG,
core::CPersistUtils::restore(ETA_GROWTH_RATE_PER_TREE_TAG,
m_EtaGrowthRatePerTree, traverser))
RESTORE(ETA_TAG, core::CPersistUtils::restore(ETA_TAG, m_Eta, traverser))
RESTORE(FEATURE_BAG_FRACTION_TAG,
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_TAG,
m_FeatureBagFraction, traverser))
RESTORE(FEATURE_DATA_TYPES_TAG,
core::CPersistUtils::restore(FEATURE_DATA_TYPES_TAG,
m_FeatureDataTypes, traverser));
RESTORE(FEATURE_SAMPLE_PROBABILITIES_TAG,
core::CPersistUtils::restore(FEATURE_SAMPLE_PROBABILITIES_TAG,
m_FeatureSampleProbabilities, traverser))
RESTORE(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
core::CPersistUtils::restore(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
m_MaximumAttemptsToAddTree, traverser))
RESTORE(MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
core::CPersistUtils::restore(
MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
m_MaximumOptimisationRoundsPerHyperparameter, traverser))
RESTORE(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
core::CPersistUtils::restore(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
m_MaximumTreeSizeMultiplier, traverser))
RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
m_MissingFeatureRowMasks, traverser))
RESTORE(NUMBER_FOLDS_TAG,
core::CPersistUtils::restore(NUMBER_FOLDS_TAG, m_NumberFolds, traverser))
RESTORE(NUMBER_ROUNDS_TAG,
core::CPersistUtils::restore(NUMBER_ROUNDS_TAG, m_NumberRounds, traverser))
RESTORE(NUMBER_SPLITS_PER_FEATURE_TAG,
core::CPersistUtils::restore(NUMBER_SPLITS_PER_FEATURE_TAG,
m_NumberSplitsPerFeature, traverser))
RESTORE(NUMBER_THREADS_TAG,
core::CPersistUtils::restore(NUMBER_THREADS_TAG, m_NumberThreads, traverser))
RESTORE(RANDOM_NUMBER_GENERATOR_TAG, m_Rng.fromString(traverser.value()))
RESTORE(REGULARIZATION_TAG,
core::CPersistUtils::restore(REGULARIZATION_TAG, m_Regularization, traverser))
RESTORE(REGULARIZATION_OVERRIDE_TAG,
core::CPersistUtils::restore(REGULARIZATION_OVERRIDE_TAG,
m_RegularizationOverride, traverser))
RESTORE(ROWS_PER_FEATURE_TAG,
core::CPersistUtils::restore(ROWS_PER_FEATURE_TAG, m_RowsPerFeature, traverser))
RESTORE(TESTING_ROW_MASKS_TAG,
core::CPersistUtils::restore(TESTING_ROW_MASKS_TAG, m_TestingRowMasks, traverser))
RESTORE(MAXIMUM_NUMBER_TREES_TAG,
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_TAG,
m_MaximumNumberTrees, traverser))
RESTORE(TRAINING_ROW_MASKS_TAG,
core::CPersistUtils::restore(TRAINING_ROW_MASKS_TAG, m_TrainingRowMasks, traverser))
RESTORE(TRAINING_PROGRESS_TAG,
core::CPersistUtils::restore(TRAINING_PROGRESS_TAG, m_TrainingProgress, traverser))
RESTORE(BEST_FOREST_TAG,
core::CPersistUtils::restore(BEST_FOREST_TAG, m_BestForest, traverser))
RESTORE(BEST_HYPERPARAMETERS_TAG,
core::CPersistUtils::restore(BEST_HYPERPARAMETERS_TAG,
m_BestHyperparameters, traverser))
RESTORE(ETA_OVERRIDE_TAG,
core::CPersistUtils::restore(ETA_OVERRIDE_TAG, m_EtaOverride, traverser))
RESTORE(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
m_FeatureBagFractionOverride, traverser))
RESTORE(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
m_MaximumNumberTreesOverride, traverser))
RESTORE(LOSS_TAG, restoreLoss(m_Loss, traverser))
} while (traverser.next());
return true;
if (traverser.name() == VERSION_7_5_TAG) {
do {
const std::string& name = traverser.name();
RESTORE_NO_ERROR(BAYESIAN_OPTIMIZATION_TAG,
m_BayesianOptimization =
std::make_unique<CBayesianOptimisation>(traverser))
RESTORE(BEST_FOREST_TEST_LOSS_TAG,
core::CPersistUtils::restore(BEST_FOREST_TEST_LOSS_TAG,
m_BestForestTestLoss, traverser))
RESTORE(CURRENT_ROUND_TAG,
core::CPersistUtils::restore(CURRENT_ROUND_TAG, m_CurrentRound, traverser))
RESTORE(DEPENDENT_VARIABLE_TAG,
core::CPersistUtils::restore(DEPENDENT_VARIABLE_TAG,
m_DependentVariable, traverser))
RESTORE_NO_ERROR(ENCODER_TAG,
m_Encoder = std::make_unique<CDataFrameCategoryEncoder>(traverser))
RESTORE(ETA_GROWTH_RATE_PER_TREE_TAG,
core::CPersistUtils::restore(ETA_GROWTH_RATE_PER_TREE_TAG,
m_EtaGrowthRatePerTree, traverser))
RESTORE(ETA_TAG, core::CPersistUtils::restore(ETA_TAG, m_Eta, traverser))
RESTORE(FEATURE_BAG_FRACTION_TAG,
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_TAG,
m_FeatureBagFraction, traverser))
RESTORE(FEATURE_DATA_TYPES_TAG,
core::CPersistUtils::restore(FEATURE_DATA_TYPES_TAG,
m_FeatureDataTypes, traverser));
RESTORE(FEATURE_SAMPLE_PROBABILITIES_TAG,
core::CPersistUtils::restore(FEATURE_SAMPLE_PROBABILITIES_TAG,
m_FeatureSampleProbabilities, traverser))
RESTORE(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
core::CPersistUtils::restore(MAXIMUM_ATTEMPTS_TO_ADD_TREE_TAG,
m_MaximumAttemptsToAddTree, traverser))
RESTORE(MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
core::CPersistUtils::restore(
MAXIMUM_OPTIMISATION_ROUNDS_PER_HYPERPARAMETER_TAG,
m_MaximumOptimisationRoundsPerHyperparameter, traverser))
RESTORE(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
core::CPersistUtils::restore(MAXIMUM_TREE_SIZE_MULTIPLIER_TAG,
m_MaximumTreeSizeMultiplier, traverser))
RESTORE(MISSING_FEATURE_ROW_MASKS_TAG,
core::CPersistUtils::restore(MISSING_FEATURE_ROW_MASKS_TAG,
m_MissingFeatureRowMasks, traverser))
RESTORE(NUMBER_FOLDS_TAG,
core::CPersistUtils::restore(NUMBER_FOLDS_TAG, m_NumberFolds, traverser))
RESTORE(NUMBER_ROUNDS_TAG,
core::CPersistUtils::restore(NUMBER_ROUNDS_TAG, m_NumberRounds, traverser))
RESTORE(NUMBER_SPLITS_PER_FEATURE_TAG,
core::CPersistUtils::restore(NUMBER_SPLITS_PER_FEATURE_TAG,
m_NumberSplitsPerFeature, traverser))
RESTORE(NUMBER_THREADS_TAG,
core::CPersistUtils::restore(NUMBER_THREADS_TAG, m_NumberThreads, traverser))
RESTORE(RANDOM_NUMBER_GENERATOR_TAG, m_Rng.fromString(traverser.value()))
RESTORE(REGULARIZATION_TAG,
core::CPersistUtils::restore(REGULARIZATION_TAG, m_Regularization, traverser))
RESTORE(REGULARIZATION_OVERRIDE_TAG,
core::CPersistUtils::restore(REGULARIZATION_OVERRIDE_TAG,
m_RegularizationOverride, traverser))
RESTORE(ROWS_PER_FEATURE_TAG,
core::CPersistUtils::restore(ROWS_PER_FEATURE_TAG, m_RowsPerFeature, traverser))
RESTORE(TESTING_ROW_MASKS_TAG,
core::CPersistUtils::restore(TESTING_ROW_MASKS_TAG,
m_TestingRowMasks, traverser))
RESTORE(MAXIMUM_NUMBER_TREES_TAG,
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_TAG,
m_MaximumNumberTrees, traverser))
RESTORE(TRAINING_ROW_MASKS_TAG,
core::CPersistUtils::restore(TRAINING_ROW_MASKS_TAG,
m_TrainingRowMasks, traverser))
RESTORE(TRAINING_PROGRESS_TAG,
core::CPersistUtils::restore(TRAINING_PROGRESS_TAG,
m_TrainingProgress, traverser))
RESTORE(BEST_FOREST_TAG,
core::CPersistUtils::restore(BEST_FOREST_TAG, m_BestForest, traverser))
RESTORE(BEST_HYPERPARAMETERS_TAG,
core::CPersistUtils::restore(BEST_HYPERPARAMETERS_TAG,
m_BestHyperparameters, traverser))
RESTORE(ETA_OVERRIDE_TAG,
core::CPersistUtils::restore(ETA_OVERRIDE_TAG, m_EtaOverride, traverser))
RESTORE(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
core::CPersistUtils::restore(FEATURE_BAG_FRACTION_OVERRIDE_TAG,
m_FeatureBagFractionOverride, traverser))
RESTORE(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
core::CPersistUtils::restore(MAXIMUM_NUMBER_TREES_OVERRIDE_TAG,
m_MaximumNumberTreesOverride, traverser))
RESTORE(LOSS_TAG, restoreLoss(m_Loss, traverser))
} while (traverser.next());
return true;
}
LOG_ERROR(<< "Input error: unsupported state serialization version. Currently supported version: "
<< VERSION_7_5_TAG);
return false;
}

bool CBoostedTreeImpl::restoreLoss(CBoostedTree::TLossFunctionUPtr& loss,
Expand Down
Loading