Skip to content

Commit 403e019

Browse files
committed
Address comments and internalize RegressionTree and TreeEnsemble.
A lot of things become internal and possibly best-friend because of RegressionTree and TreeEnsemble. Fix build and add some doc strings. Undo some best-friend annotations because they have no cross-assembly references.
1 parent 4093f28 commit 403e019

39 files changed

+208
-165
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ protected override void CheckArgs(IChannel ch)
6262
base.CheckArgs(ch);
6363
}
6464

65-
protected override TreeLearner ConstructTreeLearner(IChannel ch)
65+
internal override TreeLearner ConstructTreeLearner(IChannel ch)
6666
{
6767
return new LeastSquaresRegressionTreeLearner(
6868
TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,
@@ -73,7 +73,7 @@ protected override TreeLearner ConstructTreeLearner(IChannel ch)
7373
Args.MinDocsPercentageForCategoricalSplit, Args.Bundling, Args.MinDocsForCategoricalSplit, Args.Bias);
7474
}
7575

76-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
76+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
7777
{
7878
Contracts.CheckValue(ch, nameof(ch));
7979
OptimizationAlgorithm optimizationAlgorithm;

src/Microsoft.ML.FastTree/Dataset/FeatureFlock.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ public PerBinStats(Double sumTargets, Double sumWeights, int count)
5151
/// per flock. Note that feature indices, whenever present, refer to the feature within the
5252
/// particular flock the same as they do with <see cref="FeatureFlockBase"/>.
5353
/// </summary>
54-
public abstract class SufficientStatsBase
54+
internal abstract class SufficientStatsBase
5555
{
5656
// REVIEW: Holdover from histogram. I really don't like this. Figure out if
5757
// there's a better way.
@@ -929,7 +929,7 @@ public void FillSplitCandidatesCategoricalNeighborBundling(LeastSquaresRegressio
929929
/// </summary>
930930
/// <typeparam name="TSuffStats">The type of sufficient stats that we will be able to do
931931
/// "peer" operations against, like subtract. This will always be the derived class itself.</typeparam>
932-
public abstract class SufficientStatsBase<TSuffStats> : SufficientStatsBase
932+
internal abstract class SufficientStatsBase<TSuffStats> : SufficientStatsBase
933933
where TSuffStats : SufficientStatsBase<TSuffStats>
934934
{
935935
protected SufficientStatsBase(int features)
@@ -1005,7 +1005,7 @@ protected FeatureFlockBase(int count, bool categorical = false)
10051005
/// <param name="hasWeights">Whether structures related to tracking
10061006
/// example weights should be allocated</param>
10071007
/// <returns>A sufficient statistics object</returns>
1008-
public abstract SufficientStatsBase CreateSufficientStats(bool hasWeights);
1008+
internal abstract SufficientStatsBase CreateSufficientStats(bool hasWeights);
10091009

10101010
/// <summary>
10111011
/// Returns a forward indexer for a single feature. This has a default implementation that
@@ -1207,7 +1207,7 @@ public override long SizeInBytes()
12071207
+ sizeof(int) * HotFeatureStarts.Length;
12081208
}
12091209

1210-
public override SufficientStatsBase CreateSufficientStats(bool hasWeights)
1210+
internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
12111211
{
12121212
return new SufficientStats(this, hasWeights);
12131213
}

src/Microsoft.ML.FastTree/Dataset/SingletonFeatureFlock.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public override long SizeInBytes()
3434
return _bins.SizeInBytes() + sizeof(double) * _binUpperBounds.Length;
3535
}
3636

37-
public override SufficientStatsBase CreateSufficientStats(bool hasWeights)
37+
internal override SufficientStatsBase CreateSufficientStats(bool hasWeights)
3838
{
3939
return new SufficientStats(this, hasWeights);
4040
}

src/Microsoft.ML.FastTree/FastTree.cs

+20-13
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
5555
{
5656
protected readonly TArgs Args;
5757
protected readonly bool AllowGC;
58-
protected TreeEnsemble TrainedEnsemble;
58+
internal TreeEnsemble TrainedEnsemble;
5959
protected int FeatureCount;
6060
private protected RoleMappedData ValidData;
6161
/// <summary>
6262
/// If not null, it's a test data set passed in from training context. It will be converted to one element in
6363
/// <see cref="Tests"/> by calling <see cref="ExamplesToFastTreeBins.GetCompatibleDataset"/> in <see cref="InitializeTests"/>.
6464
/// </summary>
6565
private protected RoleMappedData TestData;
66-
protected IParallelTraining ParallelTraining;
67-
protected OptimizationAlgorithm OptimizationAlgorithm;
66+
internal IParallelTraining ParallelTraining;
67+
internal OptimizationAlgorithm OptimizationAlgorithm;
6868
protected Dataset TrainSet;
6969
protected Dataset ValidSet;
7070
/// <summary>
@@ -88,7 +88,8 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
8888
protected double[] InitValidScores;
8989
protected double[][] InitTestScores;
9090
//protected int Iteration;
91-
protected TreeEnsemble Ensemble;
91+
[BestFriend]
92+
internal TreeEnsemble Ensemble;
9293

9394
protected bool HasValidSet => ValidSet != null;
9495

@@ -174,8 +175,8 @@ private protected FastTreeTrainerBase(IHostEnvironment env, TArgs args, SchemaSh
174175

175176
protected abstract Test ConstructTestForTrainingData();
176177

177-
protected abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
178-
protected abstract TreeLearner ConstructTreeLearner(IChannel ch);
178+
internal abstract OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch);
179+
internal abstract TreeLearner ConstructTreeLearner(IChannel ch);
179180

180181
protected abstract ObjectiveFunctionBase ConstructObjFunc(IChannel ch);
181182

@@ -792,7 +793,7 @@ private float GetMachineAvailableBytes()
792793

793794
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
794795
// Note that this tree can be null if no tree was learnt this iteration.
795-
protected virtual void CustomizedTrainingIteration(RegressionTree tree)
796+
internal virtual void CustomizedTrainingIteration(RegressionTree tree)
796797
{
797798
}
798799

@@ -2809,10 +2810,14 @@ public abstract class TreeEnsembleModelParameters :
28092810
ISingleCanSavePfa,
28102811
ISingleCanSaveOnnx
28112812
{
2812-
public TreeEnsembleView TrainedTreeEnsembleView { get; }
2813-
//The below two properties are necessary for tree Visualizer
2813+
/// <summary>
2814+
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> <see cref="TreeEnsemble"/> in <see cref="TreeRegressorCollection"/>.
2815+
/// </summary>
2816+
public TreeRegressorCollection TrainedTreeCollection { get; }
2817+
2818+
// The below two properties are necessary for tree Visualizer
28142819
[BestFriend]
2815-
internal TreeEnsemble TrainedEnsemble => TrainedTreeEnsembleView.TreeEnsemble;
2820+
internal TreeEnsemble TrainedEnsemble => TrainedTreeCollection.TreeEnsemble;
28162821
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28172822

28182823
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2854,7 +2859,9 @@ public abstract class TreeEnsembleModelParameters :
28542859
/// </summary>
28552860
public FeatureContributionCalculator FeatureContributionCalculator => new FeatureContributionCalculator(this);
28562861

2857-
public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
2862+
/// The following function is used in both FastTree and LightGBM so <see cref="BestFriendAttribute"/> is required.
2863+
[BestFriend]
2864+
internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
28582865
: base(env, name)
28592866
{
28602867
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
@@ -2864,7 +2871,7 @@ public TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemb
28642871
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28652872
// the trained ensemble to, for instance, resize arrays so that they are of the length
28662873
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2867-
TrainedTreeEnsembleView = new TreeEnsembleView(trainedEnsemble);
2874+
TrainedTreeCollection = new TreeRegressorCollection(trainedEnsemble);
28682875
InnerArgs = innerArgs;
28692876
NumFeatures = numFeatures;
28702877

@@ -2892,7 +2899,7 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28922899
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
28932900
categoricalSplits = true;
28942901

2895-
TrainedTreeEnsembleView = new TreeEnsembleView(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2902+
TrainedTreeCollection = new TreeRegressorCollection(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
28962903
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
28972904

28982905
InnerArgs = ctx.LoadStringOrNull();

src/Microsoft.ML.FastTree/FastTreeArguments.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ public abstract class TreeArgs : LearnerInputBaseWithGroupId
156156
/// Allows to choose Parallel FastTree Learning Algorithm.
157157
/// </summary>
158158
[Argument(ArgumentType.Multiple, HelpText = "Allows to choose Parallel FastTree Learning Algorithm", ShortName = "parag")]
159-
public ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();
159+
internal ISupportParallelTraining ParallelTrainer = new SingleTrainerFactory();
160160

161161
/// <summary>
162162
/// The number of threads to use.

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ private static VersionInfo GetVersionInfo()
7070

7171
protected override uint VerCategoricalSplitSerialized => 0x00010005;
7272

73-
public FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
73+
internal FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
7474
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
7575
{
7676
}
@@ -204,7 +204,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
204204
ParallelTraining);
205205
}
206206

207-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
207+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
208208
{
209209
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
210210
if (Args.UseLineSearch)

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
190190
return new LambdaRankObjectiveFunction(TrainSet, TrainSet.Ratings, Args, ParallelTraining);
191191
}
192192

193-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
193+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
194194
{
195195
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
196196
if (Args.UseLineSearch)
@@ -374,7 +374,7 @@ protected override void Train(IChannel ch)
374374
PrintTestGraph(ch);
375375
}
376376

377-
protected override void CustomizedTrainingIteration(RegressionTree tree)
377+
internal override void CustomizedTrainingIteration(RegressionTree tree)
378378
{
379379
Contracts.AssertValueOrNull(tree);
380380
if (tree != null && Args.CompressEnsemble)
@@ -469,7 +469,7 @@ protected override SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSc
469469
};
470470
}
471471

472-
public sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch
472+
internal sealed class LambdaRankObjectiveFunction : ObjectiveFunctionBase, IStepSearch
473473
{
474474
private readonly short[] _labels;
475475

@@ -992,7 +992,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
992992
}
993993
}
994994

995-
public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning,
995+
void IStepSearch.AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning,
996996
ScoreTracker trainingScores)
997997
{
998998
const double epsilon = 1.4e-45;
@@ -1131,7 +1131,7 @@ private static VersionInfo GetVersionInfo()
11311131

11321132
protected override uint VerCategoricalSplitSerialized => 0x00010005;
11331133

1134-
public FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
1134+
internal FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
11351135
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
11361136
{
11371137
}

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
124124
return new ObjectiveImpl(TrainSet, Args);
125125
}
126126

127-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
127+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
128128
{
129129
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
130130
if (Args.UseLineSearch)
@@ -467,7 +467,7 @@ private static VersionInfo GetVersionInfo()
467467

468468
protected override uint VerCategoricalSplitSerialized => 0x00010005;
469469

470-
public FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
470+
internal FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
471471
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
472472
{
473473
}

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ protected override ObjectiveFunctionBase ConstructObjFunc(IChannel ch)
129129
return new ObjectiveImpl(TrainSet, Args);
130130
}
131131

132-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
132+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
133133
{
134134
OptimizationAlgorithm optimizationAlgorithm = base.ConstructOptimizationAlgorithm(ch);
135135
if (Args.UseLineSearch)
@@ -470,7 +470,7 @@ private static VersionInfo GetVersionInfo()
470470

471471
protected override uint VerCategoricalSplitSerialized => 0x00010003;
472472

473-
public FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
473+
internal FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
474474
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
475475
{
476476
}

src/Microsoft.ML.FastTree/GamTrainer.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ public abstract class ArgumentsBase : LearnerInputBaseWithWeight
148148
public override TrainerInfo Info { get; }
149149
private protected virtual bool NeedCalibration => false;
150150

151-
protected IParallelTraining ParallelTraining;
151+
internal IParallelTraining ParallelTraining;
152152

153153
private protected GamTrainerBase(IHostEnvironment env,
154154
string name,

src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs

+1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,6 @@
1515
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.FastTree" + InternalPublicKey.Value)]
1616
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
1717
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
18+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
1819

1920
[assembly: WantsToBeBestFriends]

src/Microsoft.ML.FastTree/RandomForest.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ protected RandomForestTrainerBase(IHostEnvironment env,
4242
_quantileEnabled = quantileEnabled;
4343
}
4444

45-
protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
45+
internal override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
4646
{
4747
Host.CheckValue(ch, nameof(ch));
4848
IGradientAdjuster gradientWrapper = MakeGradientWrapper(ch);
@@ -63,7 +63,7 @@ protected override void InitializeTests()
6363
{
6464
}
6565

66-
protected override TreeLearner ConstructTreeLearner(IChannel ch)
66+
internal override TreeLearner ConstructTreeLearner(IChannel ch)
6767
{
6868
return new RandomForestLeastSquaresTreeLearner(
6969
TrainSet, Args.NumLeaves, Args.MinDocumentsInLeafs, Args.EntropyCoefficient,

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ private static VersionInfo GetVersionInfo()
8080
/// </summary>
8181
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
8282

83-
public FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
83+
internal FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
8484
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
8585
{ }
8686

src/Microsoft.ML.FastTree/RandomForestRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ private static VersionInfo GetVersionInfo()
6060

6161
protected override uint VerCategoricalSplitSerialized => 0x00010006;
6262

63-
public FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
63+
internal FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
6464
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
6565
{
6666
_quantileSampleCount = samplesCount;

src/Microsoft.ML.FastTree/Training/BaggingProvider.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public int GetBagCount(int numTrees, int bagSize)
7575
// Divides output values of leaves to bag count.
7676
// This brings back the final scores generated by model on a same
7777
// range as when we didn't use bagging
78-
public void ScaleEnsembleLeaves(int numTrees, int bagSize, TreeEnsemble ensemble)
78+
internal void ScaleEnsembleLeaves(int numTrees, int bagSize, TreeEnsemble ensemble)
7979
{
8080
int bagCount = GetBagCount(numTrees, bagSize);
8181
for (int t = 0; t < ensemble.NumTrees; t++)

src/Microsoft.ML.FastTree/Training/DocumentPartitioning.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public DocumentPartitioning(int[] documents, int numDocuments, int maxLeaves)
5050
/// Constructs partitioning object based on the documents and RegressionTree splits
5151
/// NOTE: It has been optimized for speed and multiprocs with 10x gain on naive LINQ implementation
5252
/// </summary>
53-
public DocumentPartitioning(RegressionTree tree, Dataset dataset)
53+
internal DocumentPartitioning(RegressionTree tree, Dataset dataset)
5454
: this(dataset.NumDocs, tree.NumLeaves)
5555
{
5656
using (Timer.Time(TimerEvent.DocumentPartitioningConstruction))

src/Microsoft.ML.FastTree/Training/EnsembleCompression/IEnsembleCompressor.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace Microsoft.ML.Trainers.FastTree.Internal
66
{
7-
public interface IEnsembleCompressor<TLabel>
7+
internal interface IEnsembleCompressor<TLabel>
88
{
99
void Initialize(int numTrees, Dataset trainSet, TLabel[] labels, int randomSeed);
1010

0 commit comments

Comments
 (0)