@@ -331,8 +331,7 @@ private void TrainingIteration(int globalFeatureIndex, double[] gradient, double
331
331
TrainSet . Flocks [ flockIndex ] . Trust ( subFeatureIndex ) , 0 ) ;
332
332
333
333
// Adjust the model
334
- if ( _leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . Gain > 0 )
335
- ConvertTreeToGraph ( globalFeatureIndex , iteration ) ;
334
+ ConvertTreeToGraph ( globalFeatureIndex , iteration , _leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] . Gain > 0 ) ;
336
335
}
337
336
338
337
/// <summary>
@@ -550,14 +549,19 @@ private void CenterGraph()
550
549
}
551
550
}
552
551
553
- private void ConvertTreeToGraph ( int globalFeatureIndex , int iteration )
552
+ private void ConvertTreeToGraph ( int globalFeatureIndex , int iteration , bool useSplitValues )
554
553
{
555
- SplitInfo splitinfo = _leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] ;
556
-
557
- _splitPoint [ globalFeatureIndex ] [ iteration ] = splitinfo . Threshold ;
554
+ // Always define the graph
558
555
_splitValue [ globalFeatureIndex ] [ iteration ] = new double [ 2 ] ; // Easily extend to variable-length trees
559
- _splitValue [ globalFeatureIndex ] [ iteration ] [ 0 ] = Args . LearningRates * splitinfo . LteOutput ;
560
- _splitValue [ globalFeatureIndex ] [ iteration ] [ 1 ] = Args . LearningRates * splitinfo . GTOutput ;
556
+
557
+ // But only fill it in if some criteria were met
558
+ if ( useSplitValues )
559
+ {
560
+ SplitInfo splitinfo = _leafSplitCandidates . FeatureSplitInfo [ globalFeatureIndex ] ;
561
+ _splitPoint [ globalFeatureIndex ] [ iteration ] = splitinfo . Threshold ;
562
+ _splitValue [ globalFeatureIndex ] [ iteration ] [ 0 ] = Args . LearningRates * splitinfo . LteOutput ;
563
+ _splitValue [ globalFeatureIndex ] [ iteration ] [ 1 ] = Args . LearningRates * splitinfo . GTOutput ;
564
+ }
561
565
}
562
566
563
567
private void InitializeGamHistograms ( )
0 commit comments