Skip to content

Commit 9f9879c

Browse files
author
Rogan Carr
committed
Synchronizing seeds at each iteration. This allows newly provisioned workers to "fast-forward" to the latest iteration and have the same random number generator state as all other workers.
As part of this, the dropout parameters were moved into the constructor for `OptimizationAlgorithm` and made `private` and/or `protected`.
1 parent 81d40a9 commit 9f9879c

11 files changed

+73
-37
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,16 @@ protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel
6868
switch (Args.OptimizationAlgorithm)
6969
{
7070
case BoostedTreeArgs.OptimizationAlgorithmType.GradientDescent:
71-
optimizationAlgorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
71+
optimizationAlgorithm = new GradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper,
72+
dropoutRate: Args.DropoutRate, dropoutSeed: Args.RngSeed);
7273
break;
7374
case BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent:
74-
optimizationAlgorithm = new AcceleratedGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
75+
optimizationAlgorithm = new AcceleratedGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper,
76+
dropoutRate: Args.DropoutRate, dropoutSeed: Args.RngSeed);
7577
break;
7678
case BoostedTreeArgs.OptimizationAlgorithmType.ConjugateGradientDescent:
77-
optimizationAlgorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper);
79+
optimizationAlgorithm = new ConjugateGradientDescent(Ensemble, TrainSet, InitTrainScores, gradientWrapper,
80+
dropoutRate: Args.DropoutRate, dropoutSeed: Args.RngSeed);
7881
break;
7982
default:
8083
throw ch.Except("Unknown optimization algorithm '{0}'", Args.OptimizationAlgorithm);
@@ -83,8 +86,6 @@ protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel
8386
optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch);
8487
optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
8588
optimizationAlgorithm.Smoothing = Args.Smoothing;
86-
optimizationAlgorithm.DropoutRate = Args.DropoutRate;
87-
optimizationAlgorithm.DropoutRng = new Random(Args.RngSeed);
8889
optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;
8990

9091
return optimizationAlgorithm;

src/Microsoft.ML.FastTree/FastTree.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ protected bool[] GetActiveFeatures()
377377
if (Args.FeatureFraction < 1.0)
378378
{
379379
if (_featureSelectionRandom == null)
380-
_featureSelectionRandom = new Random(Args.FeatureSelectSeed);
380+
_featureSelectionRandom = new Random(Args.FeatureSelectSeed + Ensemble.NumTrees);
381381

382382
for (int i = 0; i < TrainSet.NumFeatures; ++i)
383383
{
@@ -614,6 +614,9 @@ protected virtual void Train(IChannel ch)
614614
{
615615
using (Timer.Time(TimerEvent.Iteration))
616616
{
617+
// Reset Seeds
618+
_featureSelectionRandom = new Random(Args.FeatureSelectSeed + Ensemble.NumTrees);
619+
617620
#if NO_STORE
618621
bool[] activeFeatures = GetActiveFeatures();
619622
#else

src/Microsoft.ML.FastTree/RandomForest.cs

-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,6 @@ protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel
2929
optimizationAlgorithm.TreeLearner = ConstructTreeLearner(ch);
3030
optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
3131
optimizationAlgorithm.Smoothing = Args.Smoothing;
32-
// No notion of dropout for non-boosting applications.
33-
optimizationAlgorithm.DropoutRate = 0;
34-
optimizationAlgorithm.DropoutRng = null;
3532
optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;
3633

3734
return optimizationAlgorithm;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ namespace Microsoft.ML.Runtime.FastTree.Internal
77
//Accelerated gradient descent score tracker
88
public class AcceleratedGradientDescent : GradientDescent
99
{
10-
public AcceleratedGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
11-
: base(ensemble, trainData, initTrainScores, gradientWrapper)
10+
public AcceleratedGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper,
11+
double dropoutRate = 0, int dropoutSeed = int.MinValue)
12+
: base(ensemble, trainData, initTrainScores, gradientWrapper, dropoutRate, dropoutSeed)
1213
{
1314
UseFastTrainingScoresUpdate = false;
1415
}

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ public class ConjugateGradientDescent : GradientDescent
1111
private double[] _currentGradient;
1212
private double[] _currentDk;
1313

14-
public ConjugateGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
15-
: base(ensemble, trainData, initTrainScores, gradientWrapper)
14+
public ConjugateGradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper,
15+
double dropoutRate = 0, int dropoutSeed = int.MinValue)
16+
: base(ensemble, trainData, initTrainScores, gradientWrapper, dropoutRate, dropoutSeed)
1617
{
1718
_currentDk = new double[trainData.NumDocs];
1819
}

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs

+15-10
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ public class GradientDescent : OptimizationAlgorithm
2121
private double[] _droppedScores;
2222
private double[] _scores;
2323

24-
public GradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
25-
: base(ensemble, trainData, initTrainScores)
24+
public GradientDescent(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper,
25+
double dropoutRate = 0, int dropoutSeed = int.MinValue)
26+
: base(ensemble, trainData, initTrainScores, dropoutRate, dropoutSeed)
2627
{
2728
_gradientWrapper = gradientWrapper;
2829
_treeScores = new List<double[]>();
@@ -36,7 +37,11 @@ protected override ScoreTracker ConstructScoreTracker(string name, Dataset set,
3637
protected virtual double[] GetGradient(IChannel ch)
3738
{
3839
Contracts.AssertValue(ch);
39-
if (DropoutRate > 0)
40+
41+
// Assumes that GetGradient is called at most once per iteration
42+
ResetDropoutSeed();
43+
44+
if (_dropoutRate > 0)
4045
{
4146
if (_droppedScores == null)
4247
_droppedScores = new double[TrainingScores.Scores.Length];
@@ -46,16 +51,16 @@ protected virtual double[] GetGradient(IChannel ch)
4651
_scores = new double[TrainingScores.Scores.Length];
4752
int numberOfTrees = Ensemble.NumTrees;
4853
int[] droppedTrees =
49-
Enumerable.Range(0, numberOfTrees).Where(t => (DropoutRng.NextDouble() < DropoutRate)).ToArray();
54+
Enumerable.Range(0, numberOfTrees).Where(t => (_dropoutRng.NextDouble() < _dropoutRate)).ToArray();
5055
_numberOfDroppedTrees = droppedTrees.Length;
5156
if ((_numberOfDroppedTrees == 0) && (numberOfTrees > 0))
5257
{
53-
droppedTrees = new int[] { DropoutRng.Next(numberOfTrees) };
58+
droppedTrees = new int[] { _dropoutRng.Next(numberOfTrees) };
5459
// force at least a single tree to be dropped
5560
_numberOfDroppedTrees = droppedTrees.Length;
5661
}
5762
ch.Trace("dropout: Dropping {0} trees of {1} for rate {2}",
58-
_numberOfDroppedTrees, numberOfTrees, DropoutRate);
63+
_numberOfDroppedTrees, numberOfTrees, _dropoutRate);
5964
foreach (int i in droppedTrees)
6065
{
6166
double[] s = _treeScores[i];
@@ -94,7 +99,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
9499
{
95100
Contracts.CheckValue(ch, nameof(ch));
96101
// Fit a regression tree to the gradient using least squares.
97-
RegressionTree tree = TreeLearner.FitTargets(ch, activeFeatures, AdjustTargetsAndSetWeights(ch));
102+
RegressionTree tree = TreeLearner.FitTargets(ch, activeFeatures, AdjustTargetsAndSetWeights(ch), iteration: Iteration);
98103
if (tree == null)
99104
return null; // Could not learn a tree. Exit.
100105

@@ -105,7 +110,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
105110
{
106111
double[] backupScores = null;
107112
// when doing dropouts we need to replace the TrainingScores with the scores without the dropped trees
108-
if (DropoutRate > 0)
113+
if (_dropoutRate > 0)
109114
{
110115
backupScores = TrainingScores.Scores;
111116
TrainingScores.Scores = _scores;
@@ -117,7 +122,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
117122
(ObjectiveFunction as IStepSearch).AdjustTreeOutputs(ch, tree, TreeLearner.Partitioning, TrainingScores);
118123
else
119124
throw ch.Except("No AdjustTreeOutputs defined. Objective function should define IStepSearch or AdjustTreeOutputsOverride should be set");
120-
if (DropoutRate > 0)
125+
if (_dropoutRate > 0)
121126
{
122127
// Returning the original scores.
123128
TrainingScores.Scores = backupScores;
@@ -128,7 +133,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
128133
SmoothTree(tree, Smoothing);
129134
UseFastTrainingScoresUpdate = false;
130135
}
131-
if (DropoutRate > 0)
136+
if (_dropoutRate > 0)
132137
{
133138
// Don't do shrinkage if you do dropouts.
134139
double scaling = (1.0 / (1.0 + _numberOfDroppedTrees));

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ public class RandomForestOptimizer : GradientDescent
1212
{
1313
private IGradientAdjuster _gradientWrapper;
1414
// REVIEW: When the FastTree application is decoupled from tree learner and boosting logic, this class should be removed.
15-
public RandomForestOptimizer(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
16-
: base(ensemble, trainData, initTrainScores, gradientWrapper)
15+
public RandomForestOptimizer(Ensemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper,
16+
double dropoutRate = 0, int dropoutSeed = int.MinValue)
17+
: base(ensemble, trainData, initTrainScores, gradientWrapper, dropoutRate, dropoutSeed)
1718
{
1819
_gradientWrapper = gradientWrapper;
1920
}
@@ -32,7 +33,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
3233
double[] targets = GetGradient(ch);
3334
double[] weightedTargets = _gradientWrapper.AdjustTargetAndSetWeights(targets, ObjectiveFunction, out sampleWeights);
3435
RegressionTree tree = ((RandomForestLeastSquaresTreeLearner)TreeLearner).FitTargets(ch, activeFeatures, weightedTargets,
35-
targets, sampleWeights);
36+
targets, sampleWeights, iteration: Iteration);
3637

3738
if (tree != null)
3839
Ensemble.AddTree(tree);

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/OptimizationAlgorithm.cs

+16-4
Original file line numberDiff line numberDiff line change
@@ -32,17 +32,22 @@ public abstract class OptimizationAlgorithm
3232

3333
public IStepSearch AdjustTreeOutputsOverride; // if set it overrides IStepSearch possibly implemented by ObjectiveFunctionBase
3434
public double Smoothing;
35-
public double DropoutRate;
36-
public Random DropoutRng;
35+
protected double _dropoutRate;
36+
protected Random _dropoutRng;
37+
private int _dropoutSeed;
3738
public bool UseFastTrainingScoresUpdate;
3839

39-
public OptimizationAlgorithm(Ensemble ensemble, Dataset trainData, double[] initTrainScores)
40+
public OptimizationAlgorithm(Ensemble ensemble, Dataset trainData, double[] initTrainScores,
41+
double dropoutRate = 0, int dropoutSeed = int.MinValue)
4042
{
4143
Ensemble = ensemble;
4244
TrainingScores = ConstructScoreTracker("train", trainData, initTrainScores);
4345
TrackedScores = new List<ScoreTracker>();
4446
TrackedScores.Add(TrainingScores);
45-
DropoutRng = new Random();
47+
if (dropoutSeed != int.MinValue)
48+
_dropoutRng = new Random(dropoutSeed);
49+
_dropoutSeed = dropoutSeed;
50+
_dropoutRate = dropoutRate;
4651
UseFastTrainingScoresUpdate = true;
4752
}
4853

@@ -117,5 +122,12 @@ public virtual void FinalizeLearning(int bestIteration)
117122
TrackedScores.Clear(); //Invalidate all precomputed scores as they are not valid anymore //slow method of score computation will be used instead
118123
}
119124
}
125+
126+
protected void ResetDropoutSeed()
127+
{
128+
_dropoutRng = new Random(_dropoutSeed + Iteration);
129+
}
130+
131+
protected int Iteration => Ensemble.NumTrees;
120132
}
121133
}

src/Microsoft.ML.FastTree/Training/TreeLearners/FastForestLeastSquaresTreeLearner.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ protected override RegressionTree NewTree()
2828
return new QuantileRegressionTree(NumLeaves);
2929
}
3030

31-
public RegressionTree FitTargets(IChannel ch, bool[] activeFeatures, Double[] weightedtargets, Double[] targets, Double[] weights)
31+
public RegressionTree FitTargets(IChannel ch, bool[] activeFeatures, Double[] weightedtargets, Double[] targets, Double[] weights, int iteration = 0)
3232
{
33-
var tree = (QuantileRegressionTree)FitTargets(ch, activeFeatures, weightedtargets);
33+
ResetRandomNumberGenerator(iteration);
34+
35+
var tree = (QuantileRegressionTree)FitTargets(ch, activeFeatures, weightedtargets, iteration: iteration);
3436
if (tree != null && _quantileEnabled)
3537
{
3638
Double[] distributionWeights = null;

src/Microsoft.ML.FastTree/Training/TreeLearners/LeastSquaresRegressionTreeLearner.cs

+17-4
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ public class LeastSquaresRegressionTreeLearner : TreeLearner
6262
// how many times each feature has been split on, for diversity penalty
6363
protected readonly int[] FeatureUseCount;
6464

65-
protected readonly Random Rand;
65+
protected Random Rand;
66+
private readonly int _randomSeed;
6667

6768
protected readonly double SplitFraction;
6869
protected readonly bool FilterZeros;
@@ -163,6 +164,7 @@ public LeastSquaresRegressionTreeLearner(Dataset trainData, int numLeaves, int m
163164

164165
FeatureUseCount = new int[TrainData.NumFeatures];
165166
Rand = new Random(randomSeed);
167+
_randomSeed = randomSeed;
166168
SplitFraction = splitFraction;
167169
FilterZeros = filterZeros;
168170
BsrMaxTreeOutput = bsrMaxTreeOutput;
@@ -213,7 +215,7 @@ protected virtual void MakeDummyRootSplit(RegressionTree tree, double rootTarget
213215
/// Learns a new tree for the current outputs
214216
/// </summary>
215217
/// <returns>A regression tree</returns>
216-
public sealed override RegressionTree FitTargets(IChannel ch, bool[] activeFeatures, double[] targets)
218+
public sealed override RegressionTree FitTargets(IChannel ch, bool[] activeFeatures, double[] targets, int iteration = 0)
217219
{
218220
int maxLeaves = base.NumLeaves;
219221
using (Timer.Time(TimerEvent.TreeLearnerGetTree))
@@ -224,7 +226,7 @@ public sealed override RegressionTree FitTargets(IChannel ch, bool[] activeFeatu
224226
tree.ActiveFeatures = (bool[])activeFeatures.Clone();
225227

226228
// clear memory
227-
Initialize(activeFeatures);
229+
Initialize(activeFeatures, iteration);
228230

229231
// find the best split of the root node.
230232
FindBestSplitOfRoot(targets);
@@ -318,14 +320,25 @@ protected virtual void PerformSplit(RegressionTree tree, int bestLeaf, double[]
318320
/// <summary>
319321
/// Clears data structures
320322
/// </summary>
321-
private void Initialize(bool[] activeFeatures)
323+
private void Initialize(bool[] activeFeatures, int iteration)
322324
{
325+
// Synchronize the Random Number Generator
326+
ResetRandomNumberGenerator(iteration);
323327
_parallelTraining.InitIteration(ref activeFeatures);
324328
ActiveFeatures = activeFeatures;
325329
HistogramArrayPool.Reset();
326330
Partitioning.Initialize();
327331
}
328332

333+
/// <summary>
334+
/// Reset the random number generator to synchronize at iteration boundaries
335+
/// </summary>
336+
/// <param name="iteration">The iteration</param>
337+
protected void ResetRandomNumberGenerator(int iteration)
338+
{
339+
Rand = new Random(_randomSeed + iteration);
340+
}
341+
329342
protected bool HasWeights => TrainData?.SampleWeights != null;
330343

331344
protected double[] GetTargetWeights()

src/Microsoft.ML.FastTree/Training/TreeLearners/TreeLearner.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ protected TreeLearner(Dataset trainData, int numLeaves)
2222

2323
public static string TargetWeightsDatasetName { get { return "TargetWeightsDataset"; } }
2424

25-
public abstract RegressionTree FitTargets(IChannel ch, bool[] activeFeatures, double[] targets);
25+
public abstract RegressionTree FitTargets(IChannel ch, bool[] activeFeatures, double[] targets, int iteration = 0);
2626

2727
/// <summary>
2828
/// Get size of reserved memory for the tree learner.

0 commit comments

Comments (0)