@@ -21,8 +21,9 @@ public class GradientDescent : OptimizationAlgorithm
21
21
private double [ ] _droppedScores ;
22
22
private double [ ] _scores ;
23
23
24
- public GradientDescent ( Ensemble ensemble , Dataset trainData , double [ ] initTrainScores , IGradientAdjuster gradientWrapper )
25
- : base ( ensemble , trainData , initTrainScores )
24
+ public GradientDescent ( Ensemble ensemble , Dataset trainData , double [ ] initTrainScores , IGradientAdjuster gradientWrapper ,
25
+ double dropoutRate = 0 , int dropoutSeed = int . MinValue )
26
+ : base ( ensemble , trainData , initTrainScores , dropoutRate , dropoutSeed )
26
27
{
27
28
_gradientWrapper = gradientWrapper ;
28
29
_treeScores = new List < double [ ] > ( ) ;
@@ -36,7 +37,11 @@ protected override ScoreTracker ConstructScoreTracker(string name, Dataset set,
36
37
protected virtual double [ ] GetGradient ( IChannel ch )
37
38
{
38
39
Contracts . AssertValue ( ch ) ;
39
- if ( DropoutRate > 0 )
40
+
41
+ // Assumes that GetGradient is called at most once per iteration
42
+ ResetDropoutSeed ( ) ;
43
+
44
+ if ( _dropoutRate > 0 )
40
45
{
41
46
if ( _droppedScores == null )
42
47
_droppedScores = new double [ TrainingScores . Scores . Length ] ;
@@ -46,16 +51,16 @@ protected virtual double[] GetGradient(IChannel ch)
46
51
_scores = new double [ TrainingScores . Scores . Length ] ;
47
52
int numberOfTrees = Ensemble . NumTrees ;
48
53
int [ ] droppedTrees =
49
- Enumerable . Range ( 0 , numberOfTrees ) . Where ( t => ( DropoutRng . NextDouble ( ) < DropoutRate ) ) . ToArray ( ) ;
54
+ Enumerable . Range ( 0 , numberOfTrees ) . Where ( t => ( _dropoutRng . NextDouble ( ) < _dropoutRate ) ) . ToArray ( ) ;
50
55
_numberOfDroppedTrees = droppedTrees . Length ;
51
56
if ( ( _numberOfDroppedTrees == 0 ) && ( numberOfTrees > 0 ) )
52
57
{
53
- droppedTrees = new int [ ] { DropoutRng . Next ( numberOfTrees ) } ;
58
+ droppedTrees = new int [ ] { _dropoutRng . Next ( numberOfTrees ) } ;
54
59
// force at least a single tree to be dropped
55
60
_numberOfDroppedTrees = droppedTrees . Length ;
56
61
}
57
62
ch . Trace ( "dropout: Dropping {0} trees of {1} for rate {2}" ,
58
- _numberOfDroppedTrees , numberOfTrees , DropoutRate ) ;
63
+ _numberOfDroppedTrees , numberOfTrees , _dropoutRate ) ;
59
64
foreach ( int i in droppedTrees )
60
65
{
61
66
double [ ] s = _treeScores [ i ] ;
@@ -94,7 +99,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
94
99
{
95
100
Contracts . CheckValue ( ch , nameof ( ch ) ) ;
96
101
// Fit a regression tree to the gradient using least squares.
97
- RegressionTree tree = TreeLearner . FitTargets ( ch , activeFeatures , AdjustTargetsAndSetWeights ( ch ) ) ;
102
+ RegressionTree tree = TreeLearner . FitTargets ( ch , activeFeatures , AdjustTargetsAndSetWeights ( ch ) , iteration : Iteration ) ;
98
103
if ( tree == null )
99
104
return null ; // Could not learn a tree. Exit.
100
105
@@ -105,7 +110,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
105
110
{
106
111
double [ ] backupScores = null ;
107
112
// when doing dropouts we need to replace the TrainingScores with the scores without the dropped trees
108
- if ( DropoutRate > 0 )
113
+ if ( _dropoutRate > 0 )
109
114
{
110
115
backupScores = TrainingScores . Scores ;
111
116
TrainingScores . Scores = _scores ;
@@ -117,7 +122,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
117
122
( ObjectiveFunction as IStepSearch ) . AdjustTreeOutputs ( ch , tree , TreeLearner . Partitioning , TrainingScores ) ;
118
123
else
119
124
throw ch . Except ( "No AdjustTreeOutputs defined. Objective function should define IStepSearch or AdjustTreeOutputsOverride should be set" ) ;
120
- if ( DropoutRate > 0 )
125
+ if ( _dropoutRate > 0 )
121
126
{
122
127
// Returning the original scores.
123
128
TrainingScores . Scores = backupScores ;
@@ -128,7 +133,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
128
133
SmoothTree ( tree , Smoothing ) ;
129
134
UseFastTrainingScoresUpdate = false ;
130
135
}
131
- if ( DropoutRate > 0 )
136
+ if ( _dropoutRate > 0 )
132
137
{
133
138
// Don't do shrinkage if you do dropouts.
134
139
double scaling = ( 1.0 / ( 1.0 + _numberOfDroppedTrees ) ) ;
0 commit comments