Skip to content

Commit b7563a2

Browse files
committed
Address comments and internalize RegressionTree and TreeEnsemble.
A lot of things become internal and possibly best-friend because of RegressionTree and TreeEnsemble. Fix build and add some doc strings
1 parent c479a06 commit b7563a2

32 files changed

+183
-153
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ protected override void CheckArgs(IChannel ch)
6262
base.CheckArgs(ch);
6363
}
6464

65-
protected override TreeLearner ConstructTreeLearner(IChannel ch)
65+
internal override TreeLearner ConstructTreeLearner(IChannel ch)
6666
{
6767
return new LeastSquaresRegressionTreeLearner(
6868
TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,
@@ -73,7 +73,7 @@ protected override TreeLearner ConstructTreeLearner(IChannel ch)
7373
Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias);
7474
}
7575

76-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
76+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
7777
{
7878
Contracts.CheckValue(ch, nameof(ch));
7979
OptimizationAlgorithm optimizationAlgorithm;

src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public PerBinStats(Double sumTargets, Double sumWeights, int count)
5151
/// per flock. Note that feature indices, whenever present, refer to the feature within the
5252
/// particular flock the same as they do with <see cref="FeatureFlockBase"/>.
5353
/// </summary>
54-
public abstract class SufficientStatsBase
54+
internal abstract class SufficientStatsBase
5555
{
5656
// REVIEW: Holdover from histogram. I really don't like this. Figure out if
5757
// there's a better way.
@@ -929,7 +929,7 @@ public void FillSplitCandidatesCategoricalNeighborBundling(LeastSquaresRegressio
929929
/// </summary>
930930
/// <typeparam name="TSuffStats">The type of sufficient stats that we will be able to do
931931
/// "peer" operations against, like subtract. This will always be the derived class itself.</typeparam>
932-
public abstract class SufficientStatsBase<TSuffStats> : SufficientStatsBase
932+
internal abstract class SufficientStatsBase<TSuffStats> : SufficientStatsBase
933933
where TSuffStats : SufficientStatsBase<TSuffStats>
934934
{
935935
protected SufficientStatsBase(int features)
@@ -1005,7 +1005,7 @@ protected FeatureFlockBase(int count, bool categorical = false)
10051005
/// <param name="hasWeights">Whether structures related to tracking
10061006
/// example weights should be allocated</param>
10071007
/// <returns>A sufficient statistics object</returns>
1008-
public abstract SufficientStatsBase CreateSufficientStats(bool hasWeights);
1008+
internal abstract SufficientStatsBase CreateSufficientStats(bool hasWeights);
10091009

10101010
/// <summary>
10111011
/// Returns a forward indexer for a single feature. This has a default implementation that
@@ -1207,7 +1207,7 @@ public override long SizeInBytes()
12071207
+ sizeof(int) * HotFeatureStarts.Length;
12081208
}
12091209

1210-
public override SufficientStatsBase CreateSufficientStats(bool hasWeights)
1210+
internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
12111211
{
12121212
return new SufficientStats(this, hasWeights);
12131213
}

src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public override long SizeInBytes()
3434
return _bins.SizeInBytes() + sizeof(double) * _binUpperBounds.Length;
3535
}
3636

37-
public override SufficientStatsBase CreateSufficientStats(bool hasWeights)
37+
internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
3838
{
3939
return new SufficientStats(this, hasWeights);
4040
}

src/Microsoft.ML.FastTree/FastTree.cs

+12-10
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
5454
{
5555
protected readonly TArgs Args;
5656
protected readonly bool AllowGC;
57-
[BestFriend]
5857
internal TreeEnsemble TrainedEnsemble;
5958
protected int FeatureCount;
6059
private protected RoleMappedData ValidData;
@@ -63,8 +62,8 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
6362
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
6463
/// </summary>
6564
private protected RoleMappedData TestData;
66-
protected IParallelTraining ParallelTraining;
67-
protected OptimizationAlgorithm OptimizationAlgorithm;
65+
internal IParallelTraining ParallelTraining;
66+
internal OptimizationAlgorithm OptimizationAlgorithm;
6867
protected Dataset TrainSet;
6968
protected Dataset ValidSet;
7069
/// <summary>
@@ -175,8 +174,8 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
175174

176175
protected abstract Test ConstructTestForTrainingData();
177176

178-
protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
179-
protected abstract TreeLearner ConstructTreeLearner(IChannel ch);
177+
internal abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
178+
internal abstract TreeLearner ConstructTreeLearner(IChannel ch);
180179

181180
protected abstract ObjectiveFunctionBase ConstructObjFunc(IChannel ch);
182181

@@ -793,7 +792,7 @@ private float GetMachineAvailableBytes()
793792

794793
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
795794
// Note that this tree can be null if no tree was learnt this iteration.
796-
protected virtual void CustomizedTrainingIteration(RegressionTree tree)
795+
internal virtual void CustomizedTrainingIteration(RegressionTree tree)
797796
{
798797
}
799798

@@ -2810,10 +2809,13 @@ public abstract class TreeEnsembleModelParameters :
28102809
ISingleCanSavePfa,
28112810
ISingleCanSaveOnnx
28122811
{
2813-
public TreeEnsembleView TrainedTreeEnsembleView { get; }
2812+
/// <summary>
2813+
/// An ensemble of trees exposed to users. It is a wrapper on an <see langword="internal"/> <see cref="TreeEnsemble"/> in <see cref="TreeRegressorCollection"/>.
2814+
/// </summary>
2815+
public TreeRegressorCollection TrainedTreeCollection { get; }
28142816
//The below two properties are necessary for tree Visualizer
28152817
[BestFriend]
2816-
internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView.TreeEnsemble;
2818+
internal TreeEnsemble TrainedEnsemble => TrainedTreeCollection.TreeEnsemble;
28172819
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28182820

28192821
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2866,7 +2868,7 @@ internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnse
28662868
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28672869
// the trained ensemble to, for instance, resize arrays so that they are of the length
28682870
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2869-
TrainedTreeEnsembleView = new TreeEnsembleView(trainedEnsemble);
2871+
TrainedTreeCollection = new TreeRegressorCollection(trainedEnsemble);
28702872
InnerArgs = innerArgs;
28712873
NumFeatures = numFeatures;
28722874

@@ -2894,7 +2896,7 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28942896
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28952897
categoricalSplits = true;
28962898

2897-
TrainedTreeEnsembleView = new TreeEnsembleView(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2899+
TrainedTreeCollection = new TreeRegressorCollection(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
28982900
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
28992901

29002902
InnerArgs = ctx.LoadStringOrNull();

src/Microsoft.ML.FastTree/FastTreeArguments.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
156156
/// Allows to choose Parallel FastTree Learning Algorithm.
157157
/// </summary>
158158
[Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")]
159-
public ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();
159+
internal ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();
160160

161161
/// <summary>
162162
/// The number of threads to use.

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
204204
ParallelTraining);
205205
}
206206

207-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
207+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
208208
{
209209
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
210210
if (Args.UseLineSearch)

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
189189
return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, Args, ParallelTraining);
190190
}
191191

192-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
192+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
193193
{
194194
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
195195
if (Args.UseLineSearch)
@@ -373,7 +373,7 @@ protected override void Train(IChannel ch)
373373
PrintTestGraph(ch);
374374
}
375375

376-
protected override void CustomizedTrainingIteration(RegressionTree tree)
376+
internal override void CustomizedTrainingIteration(RegressionTree tree)
377377
{
378378
Contracts.AssertValueOrNull(tree);
379379
if (tree != null && Args.CompressEnsemble)
@@ -468,7 +468,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
468468
};
469469
}
470470

471-
public sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch
471+
internal sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch
472472
{
473473
private readonly short[] _labels;
474474

@@ -991,7 +991,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
991991
}
992992
}
993993

994-
public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning,
994+
void IStepSearch.AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning,
995995
ScoreTracker trainingScores)
996996
{
997997
const double epsilon = 1.4e-45;

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
123123
return new ObjectiveImpl(TrainSet, Args);
124124
}
125125

126-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
126+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
127127
{
128128
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
129129
if (Args.UseLineSearch)

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
128128
return new ObjectiveImpl(TrainSet, Args);
129129
}
130130

131-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
131+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
132132
{
133133
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
134134
if (Args.UseLineSearch)

src/Microsoft.ML.FastTree/GamTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
148148
public override TrainerInfo Info { get; }
149149
private protected virtual bool NeedCalibration => false;
150150

151-
protected IParallelTraining ParallelTraining;
151+
internal IParallelTraining ParallelTraining;
152152

153153
private protected GamTrainerBase(IHostEnvironment env,
154154
string name,

src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs

+1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.FastTree" + InternalPublicKey.Value)]
1616
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
1717
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
18+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
1819

1920
[assembly: WantsToBeBestFriends]

src/Microsoft.ML.FastTree/RandomForest.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ protected RandomForestTrainerBase(IHostEnvironment env,
4242
_quantileEnabled = quantileEnabled;
4343
}
4444

45-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
45+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
4646
{
4747
Host.CheckValue(ch, nameof(ch));
4848
IGradientAdjuster gradientWrapper = MakeGradientWrapper(ch);
@@ -63,7 +63,7 @@ protected override void InitializeTests()
6363
{
6464
}
6565

66-
protected override TreeLearner ConstructTreeLearner(IChannel ch)
66+
internal override TreeLearner ConstructTreeLearner(IChannel ch)
6767
{
6868
return new RandomForestLeastSquaresTreeLearner(
6969
TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,

src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public DocumentPartitioning(int[] documents, int numDocuments, int maxLeaves)
5050
/// Constructs partitioning object based on the documents and RegressionTree splits
5151
/// NOTE: It has been optimized for speed and multiprocs with 10x gain on naive LINQ implementation
5252
/// </summary>
53-
public DocumentPartitioning(RegressionTree tree, Dataset dataset)
53+
internal DocumentPartitioning(RegressionTree tree, Dataset dataset)
5454
: this(dataset.NumDocs, tree.NumLeaves)
5555
{
5656
using (Timer.Time(TimerEvent.DocumentPartitioningConstruction))

src/Microsoft.ML.FastTree/Training/EnsembleCompression/LassoBasedEnsembleCompressor.cs

-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
1515
/// https://www-stat.stanford.edu/~hastie/Papers/glmnet.pdf
1616
/// </summary>
1717
/// <remarks>Author was Yasser Ganjisaffar during his internship.</remarks>
18-
[BestFriend]
1918
internal class LassoBasedEnsembleCompressor : IEnsembleCompressor<short>
2019
{
2120
// This module shouldn't consume more than 4GB of memory
@@ -534,7 +533,6 @@ private unsafe void LoadTargets(double[] trainScores, int bestIteration)
534533
}
535534
}
536535

537-
[BestFriend]
538536
bool IEnsembleCompressor<short>.Compress(IChannel ch, TreeEnsemble ensemble, double[] trainScores, int bestIteration, int maxTreesAfterCompression)
539537
{
540538
LoadTargets(trainScores, bestIteration);
@@ -553,7 +551,6 @@ bool IEnsembleCompressor<short>.Compress(IChannel ch, TreeEnsemble ensemble, dou
553551
return true;
554552
}
555553

556-
[BestFriend]
557554
TreeEnsemble IEnsembleCompressor<short>.GetCompressedEnsemble()
558555
{
559556
return _compressedEnsemble;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/AcceleratedGradientDescent.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
namespace Microsoft.ML.Trainers.FastTree.Internal
66
{
77
//Accelerated gradient descent score tracker
8-
public class AcceleratedGradientDescent : GradientDescent
8+
internal class AcceleratedGradientDescent : GradientDescent
99
{
10-
[BestFriend]
1110
internal AcceleratedGradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
1211
: base(ensemble, trainData, initTrainScores, gradientWrapper)
1312
{
@@ -18,7 +17,7 @@ protected override ScoreTracker ConstructScoreTracker(string name, Dataset set,
1817
return new AgdScoreTracker(name, set, initScores);
1918
}
2019

21-
public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
20+
internal override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
2221
{
2322
Contracts.CheckValue(ch, nameof(ch));
2423
AgdScoreTracker trainingScores = TrainingScores as AgdScoreTracker;
@@ -52,7 +51,7 @@ public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatu
5251
return tree;
5352
}
5453

55-
public override void UpdateScores(ScoreTracker t, RegressionTree tree)
54+
internal override void UpdateScores(ScoreTracker t, RegressionTree tree)
5655
{
5756
if (t == TrainingScores)
5857
{

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/ConjugateGradientDescent.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace Microsoft.ML.Trainers.FastTree.Internal
66
{
77
// Conjugate gradient descent
8-
public class ConjugateGradientDescent : GradientDescent
8+
internal class ConjugateGradientDescent : GradientDescent
99
{
1010
private double[] _previousGradient;
1111
private double[] _currentGradient;
@@ -18,7 +18,7 @@ internal ConjugateGradientDescent(TreeEnsemble ensemble, Dataset trainData, doub
1818
_currentDk = new double[trainData.NumDocs];
1919
}
2020

21-
protected override double[] GetGradient(IChannel ch)
21+
internal override double[] GetGradient(IChannel ch)
2222
{
2323
Contracts.AssertValue(ch);
2424
_previousGradient = _currentGradient;

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/GradientDescent.cs

+3-4
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
namespace Microsoft.ML.Trainers.FastTree.Internal
1010
{
11-
public class GradientDescent : OptimizationAlgorithm
11+
internal class GradientDescent : OptimizationAlgorithm
1212
{
1313
private IGradientAdjuster _gradientWrapper;
1414

@@ -21,7 +21,6 @@ public class GradientDescent : OptimizationAlgorithm
2121
private double[] _droppedScores;
2222
private double[] _scores;
2323

24-
[BestFriend]
2524
internal GradientDescent(TreeEnsemble ensemble, Dataset trainData, double[] initTrainScores, IGradientAdjuster gradientWrapper)
2625
: base(ensemble, trainData, initTrainScores)
2726
{
@@ -34,7 +33,7 @@ protected override ScoreTracker ConstructScoreTracker(string name, Dataset set,
3433
return new ScoreTracker(name, set, initScores);
3534
}
3635

37-
protected virtual double[] GetGradient(IChannel ch)
36+
internal virtual double[] GetGradient(IChannel ch)
3837
{
3938
Contracts.AssertValue(ch);
4039
if (DropoutRate > 0)
@@ -91,7 +90,7 @@ protected virtual double[] AdjustTargetsAndSetWeights(IChannel ch)
9190
}
9291
}
9392

94-
public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
93+
internal override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
9594
{
9695
Contracts.CheckValue(ch, nameof(ch));
9796
// Fit a regression tree to the gradient using least squares.

src/Microsoft.ML.FastTree/Training/OptimizationAlgorithms/NoOptimizationAlgorithm.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ namespace Microsoft.ML.Trainers.FastTree.Internal
88
/// This is dummy optimizer. As Random forest does not have any boosting based optimization, this is place holder to be consistent
99
/// with other fast tree based applications
1010
/// </summary>
11-
public class RandomForestOptimizer : GradientDescent
11+
internal class RandomForestOptimizer : GradientDescent
1212
{
1313
private IGradientAdjuster _gradientWrapper;
1414
// REVIEW: When the FastTree appliation is decoupled with tree learner and boosting logic, this class should be removed.
@@ -25,7 +25,7 @@ protected override ScoreTracker ConstructScoreTracker(string name, Dataset set,
2525
return new ScoreTracker(name, set, initScores);
2626
}
2727

28-
public override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
28+
internal override RegressionTree TrainingIteration(IChannel ch, bool[] activeFeatures)
2929
{
3030
Contracts.CheckValue(ch, nameof(ch));
3131

0 commit comments

Comments
 (0)