Skip to content

Commit d4f7b3c

Browse files
committed
Address comments.
1. Rename classes. 2. Use IReadOnlyList instead of ReadOnlySpan. 3. Remove new namespace, Representation.
1 parent 284e715 commit d4f7b3c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+241
-202
lines changed

src/Microsoft.ML.FastTree/FastTree.cs

+16-16
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
5656
protected readonly TArgs Args;
5757
protected readonly bool AllowGC;
5858
protected int FeatureCount;
59-
private protected TreeEnsemble TrainedEnsemble;
59+
private protected InternalTreeEnsemble TrainedEnsemble;
6060
private protected RoleMappedData ValidData;
6161
/// <summary>
6262
/// If not null, it's a test data set passed in from training context. It will be converted to one element in
@@ -88,7 +88,7 @@ public abstract class FastTreeTrainerBase<TArgs, TTransformer, TModel> :
8888
protected double[] InitValidScores;
8989
protected double[][] InitTestScores;
9090
//protected int Iteration;
91-
private protected TreeEnsemble Ensemble;
91+
private protected InternalTreeEnsemble Ensemble;
9292

9393
protected bool HasValidSet => ValidSet != null;
9494

@@ -488,7 +488,7 @@ protected bool AreSamplesWeighted(IChannel ch)
488488

489489
private void InitializeEnsemble()
490490
{
491-
Ensemble = new TreeEnsemble();
491+
Ensemble = new InternalTreeEnsemble();
492492
}
493493

494494
/// <summary>
@@ -793,7 +793,7 @@ private float GetMachineAvailableBytes()
793793

794794
// This method is called at the end of each training iteration, with the tree that was learnt on that iteration.
795795
// Note that this tree can be null if no tree was learnt this iteration.
796-
private protected virtual void CustomizedTrainingIteration(RegressionTree tree)
796+
private protected virtual void CustomizedTrainingIteration(InternalRegressionTree tree)
797797
{
798798
}
799799

@@ -924,7 +924,7 @@ internal abstract class DataConverter
924924
/// of features we actually trained on. This can be null in the event that no filtering
925925
/// occurred.
926926
/// </summary>
927-
/// <seealso cref="TreeEnsemble.RemapFeatures"/>
927+
/// <seealso cref="InternalTreeEnsemble.RemapFeatures"/>
928928
public int[] FeatureMap;
929929

930930
protected readonly IHost Host;
@@ -2811,13 +2811,13 @@ public abstract class TreeEnsembleModelParameters :
28112811
ISingleCanSaveOnnx
28122812
{
28132813
/// <summary>
2814-
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> <see cref="TreeEnsemble"/> in <see cref="ML.FastTree.Representation.TreeRegressorCollection"/>.
2814+
/// An ensemble of trees exposed to users. It is a wrapper on the <see langword="internal"/> <see cref="InternalTreeEnsemble"/> in <see cref="ML.FastTree.Representation.TreeEnsemble"/>.
28152815
/// </summary>
2816-
public ML.FastTree.Representation.TreeRegressorCollection TrainedTreeCollection { get; }
2816+
public ML.FastTree.Representation.TreeEnsemble TrainedTreeCollection { get; }
28172817

28182818
// The below two properties are necessary for tree Visualizer
28192819
[BestFriend]
2820-
internal TreeEnsemble TrainedEnsemble => TrainedTreeCollection.TreeEnsemble;
2820+
internal InternalTreeEnsemble TrainedEnsemble => TrainedTreeCollection.UnderlyingTreeEnsemble;
28212821
int ITreeEnsemble.NumTrees => TrainedEnsemble.NumTrees;
28222822

28232823
// Inner args is used only for documentation purposes when saving comments to INI files.
@@ -2861,7 +2861,7 @@ public abstract class TreeEnsembleModelParameters :
28612861

28622862
/// The following function is used in both FastTree and LightGBM so <see cref="BestFriendAttribute"/> is required.
28632863
[BestFriend]
2864-
internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
2864+
internal TreeEnsembleModelParameters(IHostEnvironment env, string name, InternalTreeEnsemble trainedEnsemble, int numFeatures, string innerArgs)
28652865
: base(env, name)
28662866
{
28672867
Host.CheckValue(trainedEnsemble, nameof(trainedEnsemble));
@@ -2871,7 +2871,7 @@ internal TreeEnsembleModelParameters(IHostEnvironment env, string name, TreeEnse
28712871
// REVIEW: When we make the predictor wrapper, we may want to further "optimize"
28722872
// the trained ensemble to, for instance, resize arrays so that they are of the length
28732873
// the actual number of leaves/nodes, or remove unnecessary arrays, and so forth.
2874-
TrainedTreeCollection = new ML.FastTree.Representation.TreeRegressorCollection(trainedEnsemble);
2874+
TrainedTreeCollection = new ML.FastTree.Representation.TreeEnsemble(trainedEnsemble);
28752875
InnerArgs = innerArgs;
28762876
NumFeatures = numFeatures;
28772877

@@ -2899,7 +2899,7 @@ protected TreeEnsembleModelParameters(IHostEnvironment env, string name, ModelLo
28992899
if (ctx.Header.ModelVerWritten >= VerCategoricalSplitSerialized)
29002900
categoricalSplits = true;
29012901

2902-
TrainedTreeCollection = new ML.FastTree.Representation.TreeRegressorCollection(new TreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
2902+
TrainedTreeCollection = new ML.FastTree.Representation.TreeEnsemble(new InternalTreeEnsemble(ctx, usingDefaultValues, categoricalSplits));
29032903
MaxSplitFeatIdx = TrainedEnsemble.GetMaxFeatureIndex();
29042904

29052905
InnerArgs = ctx.LoadStringOrNull();
@@ -3202,7 +3202,7 @@ private void SaveEnsembleAsCode(TextWriter writer, RoleMappedSchema schema)
32023202
MetadataUtils.GetSlotNames(schema, RoleMappedSchema.ColumnRole.Feature, NumFeatures, ref names);
32033203

32043204
int i = 0;
3205-
foreach (RegressionTree tree in TrainedEnsemble.Trees)
3205+
foreach (InternalRegressionTree tree in TrainedEnsemble.Trees)
32063206
{
32073207
writer.Write("double treeOutput{0}=", i);
32083208
SaveTreeAsCode(tree, writer, in names);
@@ -3218,13 +3218,13 @@ private void SaveEnsembleAsCode(TextWriter writer, RoleMappedSchema schema)
32183218
/// <summary>
32193219
/// Convert a single tree to code, called recursively
32203220
/// </summary>
3221-
private void SaveTreeAsCode(RegressionTree tree, TextWriter writer, in VBuffer<ReadOnlyMemory<char>> names)
3221+
private void SaveTreeAsCode(InternalRegressionTree tree, TextWriter writer, in VBuffer<ReadOnlyMemory<char>> names)
32223222
{
32233223
ToCSharp(tree, writer, 0, in names);
32243224
}
32253225

32263226
// converts a subtree into a C# expression
3227-
private void ToCSharp(RegressionTree tree, TextWriter writer, int node, in VBuffer<ReadOnlyMemory<char>> names)
3227+
private void ToCSharp(InternalRegressionTree tree, TextWriter writer, int node, in VBuffer<ReadOnlyMemory<char>> names)
32283228
{
32293229
if (node < 0)
32303230
{
@@ -3308,9 +3308,9 @@ Row ICanGetSummaryAsIRow.GetStatsIRowOrNull(RoleMappedSchema schema)
33083308

33093309
private sealed class Tree : ITree<VBuffer<Float>>
33103310
{
3311-
private readonly RegressionTree _regTree;
3311+
private readonly InternalRegressionTree _regTree;
33123312

3313-
public Tree(RegressionTree regTree)
3313+
public Tree(InternalRegressionTree regTree)
33143314
{
33153315
_regTree = regTree;
33163316
}

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ private static VersionInfo GetVersionInfo()
7070

7171
protected override uint VerCategoricalSplitSerialized => 0x00010005;
7272

73-
internal FastTreeBinaryModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
73+
internal FastTreeBinaryModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
7474
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
7575
{
7676
}
@@ -365,7 +365,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
365365
}
366366
}
367367

368-
public void AdjustTreeOutputs(IChannel ch, RegressionTree tree,
368+
public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree,
369369
DocumentPartitioning partitioning, ScoreTracker trainingScores)
370370
{
371371
const double epsilon = 1.4e-45;

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ protected override void Train(IChannel ch)
374374
PrintTestGraph(ch);
375375
}
376376

377-
private protected override void CustomizedTrainingIteration(RegressionTree tree)
377+
private protected override void CustomizedTrainingIteration(InternalRegressionTree tree)
378378
{
379379
Contracts.AssertValueOrNull(tree);
380380
if (tree != null && Args.CompressEnsemble)
@@ -992,7 +992,7 @@ protected override void GetGradientInOneQuery(int query, int threadIndex)
992992
}
993993
}
994994

995-
void IStepSearch.AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning,
995+
void IStepSearch.AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree, DocumentPartitioning partitioning,
996996
ScoreTracker trainingScores)
997997
{
998998
const double epsilon = 1.4e-45;
@@ -1131,7 +1131,7 @@ private static VersionInfo GetVersionInfo()
11311131

11321132
protected override uint VerCategoricalSplitSerialized => 0x00010005;
11331133

1134-
internal FastTreeRankingModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
1134+
internal FastTreeRankingModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
11351135
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
11361136
{
11371137
}

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,7 @@ public ObjectiveImpl(Dataset trainData, Options options)
416416
_labels = GetDatasetRegressionLabels(trainData);
417417
}
418418

419-
public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores)
419+
public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores)
420420
{
421421
double shrinkage = LearningRate * Shrinkage;
422422
for (int l = 0; l < tree.NumLeaves; ++l)
@@ -467,7 +467,7 @@ private static VersionInfo GetVersionInfo()
467467

468468
protected override uint VerCategoricalSplitSerialized => 0x00010005;
469469

470-
internal FastTreeRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
470+
internal FastTreeRegressionModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
471471
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
472472
{
473473
}

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ public ObjectiveImpl(Dataset trainData, Options options)
363363
_maxClamp = Math.Abs(options.MaxTreeOutput);
364364
}
365365

366-
public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores)
366+
public void AdjustTreeOutputs(IChannel ch, InternalRegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores)
367367
{
368368
double shrinkage = LearningRate * Shrinkage;
369369
var scores = trainingScores.Scores;
@@ -470,7 +470,7 @@ private static VersionInfo GetVersionInfo()
470470

471471
protected override uint VerCategoricalSplitSerialized => 0x00010003;
472472

473-
internal FastTreeTweedieModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
473+
internal FastTreeTweedieModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
474474
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
475475
{
476476
}

src/Microsoft.ML.FastTree/GamModelParameters.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ private void GetFeatureContributions(in VBuffer<float> features, ref VBuffer<flo
437437
void ICanSaveInIniFormat.SaveAsIni(TextWriter writer, RoleMappedSchema schema, ICalibrator calibrator)
438438
{
439439
Host.CheckValue(writer, nameof(writer), "writer must not be null");
440-
var ensemble = new TreeEnsemble();
440+
var ensemble = new InternalTreeEnsemble();
441441

442442
for (int featureIndex = 0; featureIndex < NumShapeFunctions; featureIndex++)
443443
{
@@ -525,11 +525,11 @@ private int CreateBalancedTreeRecursive(int lower, int upper,
525525
}
526526
}
527527

528-
private static RegressionTree CreateRegressionTree(
528+
private static InternalRegressionTree CreateRegressionTree(
529529
int numLeaves, int[] splitFeatures, float[] rawThresholds, int[] lteChild, int[] gtChild, double[] leafValues)
530530
{
531531
var numInternalNodes = numLeaves - 1;
532-
return RegressionTree.Create(
532+
return InternalRegressionTree.Create(
533533
numLeaves: numLeaves,
534534
splitFeatures: splitFeatures,
535535
rawThresholds: rawThresholds,

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ private static VersionInfo GetVersionInfo()
8080
/// </summary>
8181
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
8282

83-
internal FastForestClassificationModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
83+
internal FastForestClassificationModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs)
8484
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
8585
{ }
8686

src/Microsoft.ML.FastTree/RandomForestRegression.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ private static VersionInfo GetVersionInfo()
6060

6161
protected override uint VerCategoricalSplitSerialized => 0x00010006;
6262

63-
internal FastForestRegressionModelParameters(IHostEnvironment env, TreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
63+
internal FastForestRegressionModelParameters(IHostEnvironment env, InternalTreeEnsemble trainedEnsemble, int featureCount, string innerArgs, int samplesCount)
6464
: base(env, RegistrationName, trainedEnsemble, featureCount, innerArgs)
6565
{
6666
_quantileSampleCount = samplesCount;

0 commit comments

Comments
 (0)