
Scrub Fast Tree Family #2753

Merged 12 commits on Mar 5, 2019
12 changes: 1 addition & 11 deletions docs/samples/Microsoft.ML.Samples/Dynamic/FastTreeRegression.cs
@@ -30,22 +30,12 @@ public static void Example()
// We will train a FastTreeRegression model with 1 tree on these two columns to predict Age.
string outputColumnName = "Features";
var pipeline = ml.Transforms.Concatenate(outputColumnName, new[] { "Parity", "Induced" })
.Append(ml.Regression.Trainers.FastTree(labelColumnName: "Age", featureColumnName: outputColumnName, numTrees: 1, numLeaves: 2, minDatapointsInLeaves: 1));
.Append(ml.Regression.Trainers.FastTree(labelColumnName: "Age", featureColumnName: outputColumnName, numberOfTrees: 1, numberOfLeaves: 2, minimumExampleCountPerLeaf: 1));

var model = pipeline.Fit(trainData);

// Get the trained model parameters.
var modelParams = model.LastTransformer.Model;

// Let's see where an example with Parity = 1 and Induced = 1 would end up in the single trained tree.
var testRow = new VBuffer<float>(2, new[] { 1.0f, 1.0f });
// Use the path object to pass to GetLeaf, which will populate path with the IDs of the nodes from root to leaf.
List<int> path = default;
// Get the ID of the leaf this example ends up in tree 0.
var leafID = modelParams.GetLeaf(0, in testRow, ref path);
@wschin (Member, Author) commented on Feb 27, 2019:
Removed because our public APIs don't provide those functions. #Resolved

Contributor replied:
But you can still traverse the trees with the new structure, right?


@wschin (Member, Author) commented on Feb 28, 2019:
No, we cannot, but this definitely should be hidden here. @[email protected], do you think this feature is required in the trees' public APIs?


@TomFinley (Contributor) commented on Mar 1, 2019:
I think the trees should enable someone to explain what the tree is. I wouldn't make evaluation code part of the API, any more than we insist that linear model parameters expose "do dot product" as a method on the model parameters. Both things are potentially useful, of course, but we don't need them as part of our public API (at least at first).


@wschin (Member, Author) replied:
Sounds good. I feel the evaluation module should be independent of the representation of the trained objects. Thank you!


// Get the leaf value for this leaf ID in tree 0.
var leafValue = modelParams.GetLeafValue(0, leafID);
Console.WriteLine("The leaf value in tree 0 is: " + leafValue);
}
}
}
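A note on the thread above: the removed GetLeaf/GetLeafValue calls were about evaluation, but the trained trees themselves remain inspectable. The sketch below illustrates that idea against the scrubbed public surface; the TrainedTreeEnsemble property and the per-tree members used here (NumberOfNodes, NumericalSplitFeatureIndexes, NumericalSplitThresholds, LeafValues) are assumptions based on this scrub's direction, not lines from this diff.

```csharp
// Sketch only: inspect (rather than evaluate) the trained tree through the
// scrubbed public surface; member names here are assumed, verify against
// the final API. Builds on the `model` fitted in the sample above.
var ensemble = model.LastTransformer.Model.TrainedTreeEnsemble;
var tree = ensemble.Trees[0];
Console.WriteLine($"Tree 0 has {tree.NumberOfLeaves} leaves.");
// Internal nodes record the split feature and threshold, so the structure
// can be explained without any evaluation helper.
for (int node = 0; node < tree.NumberOfNodes; node++)
    Console.WriteLine($"Node {node}: feature {tree.NumericalSplitFeatureIndexes[node]}, threshold {tree.NumericalSplitThresholds[node]}");
// Leaf outputs are public as well.
for (int leaf = 0; leaf < tree.NumberOfLeaves; leaf++)
    Console.WriteLine($"Leaf {leaf}: value {tree.LeafValues[leaf]}");
```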
@@ -28,7 +28,7 @@ public static void Example()
.ToArray();
var pipeline = mlContext.Transforms.Concatenate("Features", featureNames)
.Append(mlContext.Regression.Trainers.GeneralizedAdditiveModels(
labelColumnName: labelName, featureColumnName: "Features", maxBins: 16));
labelColumnName: labelName, featureColumnName: "Features", maxBinCountPerFeature: 16));
var fitPipeline = pipeline.Fit(data);

// Extract the model from the pipeline
@@ -37,7 +37,7 @@ public static void Example()
// Now investigate the properties of the Generalized Additive Model: The intercept and shape functions.

// The intercept for the GAM model represents the average prediction for the training data
var intercept = gamModel.Intercept;
var intercept = gamModel.Bias;
// Expected output: Average predicted cost: 22.53
Console.WriteLine($"Average predicted cost: {intercept:0.00}");

@@ -93,7 +93,7 @@ public static void Example()
// Distillation." <a href='https://arxiv.org/abs/1710.06169'>arXiv:1710.06169</a>."
Console.WriteLine();
Console.WriteLine("Student-Teacher Ratio");
for (int i = 0; i < teacherRatioBinUpperBounds.Length; i++)
for (int i = 0; i < teacherRatioBinUpperBounds.Count; i++)
{
Console.WriteLine($"x < {teacherRatioBinUpperBounds[i]:0.00} => {teacherRatioBinEffects[i]:0.000}");
}
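For context, the loop above prints one feature's shape function. A minimal sketch of where those arrays come from, assuming the fitted GAM model keeps the per-feature accessors the sample implies (GetBinUpperBounds/GetBinEffects; exact names may differ in the final surface):

```csharp
// Sketch: read one feature's shape function from the fitted GAM model.
// featureIndex picks the column within the "Features" vector; the accessor
// names are assumed from the sample's usage.
int featureIndex = 5;
var binUpperBounds = gamModel.GetBinUpperBounds(featureIndex);
var binEffects = gamModel.GetBinEffects(featureIndex);
// Bin i contributes binEffects[i] whenever the feature value falls below
// binUpperBounds[i] (and at or above the previous bound).
for (int i = 0; i < binUpperBounds.Count; i++)
    Console.WriteLine($"x < {binUpperBounds[i]:0.00} => {binEffects[i]:0.000}");
```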
@@ -78,9 +78,9 @@ public static void FastTreeBinaryClassification()
Score: mlContext.BinaryClassification.Trainers.FastTree(
row.Label,
row.Features,
numTrees: 100, // try: (int) 20-2000
numLeaves: 20, // try: (int) 2-128
minDatapointsInLeaves: 10, // try: (int) 1-100
numberOfTrees: 100, // try: (int) 20-2000
numberOfLeaves: 20, // try: (int) 2-128
minimumExampleCountPerLeaf: 10, // try: (int) 1-100
learningRate: 0.2))) // try: (float) 0.025-0.4
.Append(row => (
Label: row.Label,
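The renamed hyperparameters carry over unchanged to the dynamic API. A hedged sketch of the equivalent call, assuming an MLContext and an IDataView trainData with Label and Features columns already in hand:

```csharp
// Sketch: the same renamed hyperparameters via the dynamic API. Assumes
// mlContext (MLContext) and trainData (IDataView with "Label"/"Features").
var trainer = mlContext.BinaryClassification.Trainers.FastTree(
    labelColumnName: "Label",
    featureColumnName: "Features",
    numberOfTrees: 100,             // try: (int) 20-2000
    numberOfLeaves: 20,             // try: (int) 2-128
    minimumExampleCountPerLeaf: 10, // try: (int) 1-100
    learningRate: 0.2);             // try: 0.025-0.4
var model = trainer.Fit(trainData);
```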
@@ -38,9 +38,9 @@ public static void FastTreeRegression()
.Append(r => (r.label, score: mlContext.Regression.Trainers.FastTree(
r.label,
r.features,
numTrees: 100, // try: (int) 20-2000
numLeaves: 20, // try: (int) 2-128
minDatapointsInLeaves: 10, // try: (int) 1-100
numberOfTrees: 100, // try: (int) 20-2000
numberOfLeaves: 20, // try: (int) 2-128
minimumExampleCountPerLeaf: 10, // try: (int) 1-100
learningRate: 0.2, // try: (float) 0.025-0.4
onFit: p => pred = p)
)
34 changes: 17 additions & 17 deletions src/Microsoft.ML.FastTree/BoostingFastTree.cs
@@ -18,16 +18,16 @@ private protected BoostingFastTreeTrainerBase(IHostEnvironment env, TOptions opt

private protected BoostingFastTreeTrainerBase(IHostEnvironment env,
SchemaShape.Column label,
string featureColumn,
string weightColumn,
string groupIdColumn,
int numLeaves,
int numTrees,
int minDatapointsInLeaves,
string featureColumnName,
string exampleWeightColumnName,
string rowGroupColumnName,
int numberOfLeaves,
int numberOfTrees,
int minimumExampleCountPerLeaf,
double learningRate)
: base(env, label, featureColumn, weightColumn, groupIdColumn, numLeaves, numTrees, minDatapointsInLeaves)
: base(env, label, featureColumnName, exampleWeightColumnName, rowGroupColumnName, numberOfLeaves, numberOfTrees, minimumExampleCountPerLeaf)
{
FastTreeTrainerOptions.LearningRates = learningRate;
FastTreeTrainerOptions.LearningRate = learningRate;
}

private protected override void CheckOptions(IChannel ch)
@@ -40,10 +40,10 @@ private protected override void CheckOptions(IChannel ch)
if (FastTreeTrainerOptions.CompressEnsemble && FastTreeTrainerOptions.WriteLastEnsemble)
throw ch.Except("Ensemble compression cannot be done when forcing to write last ensemble (hl)");

if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1)
throw ch.Except("Histogram pool size (ps) must be at least 2.");

if (FastTreeTrainerOptions.NumLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumLeaves - 1)
if (FastTreeTrainerOptions.NumberOfLeaves > 2 && FastTreeTrainerOptions.HistogramPoolSize > FastTreeTrainerOptions.NumberOfLeaves - 1)
throw ch.Except("Histogram pool size (ps) must be at most numLeaves - 1.");

if (FastTreeTrainerOptions.EnablePruning && !HasValidSet)
@@ -61,12 +61,12 @@ private protected override void CheckOptions(IChannel ch)
private protected override TreeLearner ConstructTreeLearner(IChannel ch)
{
return new LeastSquaresRegressionTreeLearner(
TrainSet, FastTreeTrainerOptions.NumLeaves, FastTreeTrainerOptions.MinDocumentsInLeafs, FastTreeTrainerOptions.EntropyCoefficient,
TrainSet, FastTreeTrainerOptions.NumberOfLeaves, FastTreeTrainerOptions.MinimumExampleCountPerLeaf, FastTreeTrainerOptions.EntropyCoefficient,
FastTreeTrainerOptions.FeatureFirstUsePenalty, FastTreeTrainerOptions.FeatureReusePenalty, FastTreeTrainerOptions.SoftmaxTemperature,
FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.RngSeed, FastTreeTrainerOptions.SplitFraction, FastTreeTrainerOptions.FilterZeroLambdas,
FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaxCategoricalGroupsPerNode,
FastTreeTrainerOptions.MaxCategoricalSplitPoints, BsrMaxTreeOutput(), ParallelTraining,
FastTreeTrainerOptions.MinDocsPercentageForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinDocsForCategoricalSplit, FastTreeTrainerOptions.Bias);
FastTreeTrainerOptions.HistogramPoolSize, FastTreeTrainerOptions.Seed, FastTreeTrainerOptions.FeatureFractionPerSplit, FastTreeTrainerOptions.FilterZeroLambdas,
FastTreeTrainerOptions.AllowEmptyTrees, FastTreeTrainerOptions.GainConfidenceLevel, FastTreeTrainerOptions.MaximumCategoricalGroupCountPerNode,
FastTreeTrainerOptions.MaximumCategoricalSplitPointCount, BsrMaxTreeOutput(), ParallelTraining,
FastTreeTrainerOptions.MinimumExampleFractionForCategoricalSplit, FastTreeTrainerOptions.Bundling, FastTreeTrainerOptions.MinimumExamplesForCategoricalSplit, FastTreeTrainerOptions.Bias);
}

private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(IChannel ch)
@@ -94,7 +94,7 @@ private protected override OptimizationAlgorithm ConstructOptimizationAlgorithm(
optimizationAlgorithm.ObjectiveFunction = ConstructObjFunc(ch);
optimizationAlgorithm.Smoothing = FastTreeTrainerOptions.Smoothing;
optimizationAlgorithm.DropoutRate = FastTreeTrainerOptions.DropoutRate;
optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.RngSeed);
optimizationAlgorithm.DropoutRng = new Random(FastTreeTrainerOptions.Seed);
optimizationAlgorithm.PreScoreUpdateEvent += PrintTestGraph;

return optimizationAlgorithm;
@@ -162,7 +162,7 @@ private protected override int GetBestIteration(IChannel ch)
private protected double BsrMaxTreeOutput()
{
if (FastTreeTrainerOptions.BestStepRankingRegressionTrees)
return FastTreeTrainerOptions.MaxTreeOutput;
return FastTreeTrainerOptions.MaximumTreeOutput;
else
return -1;
}
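For users, the renames in this file surface through the trainer Options types. A sketch of setting them explicitly, using the new names as they appear in this diff; the exact Options class members are assumed here, and mlContext/trainData are taken as given:

```csharp
// Sketch: the scrubbed option names applied on an Options object. Property
// names follow this diff; verify against the final Options classes.
var options = new FastTreeRegressionTrainer.Options
{
    NumberOfTrees = 100,
    NumberOfLeaves = 20,
    MinimumExampleCountPerLeaf = 10,
    LearningRate = 0.2,              // formerly LearningRates
    Seed = 123,                      // formerly RngSeed
    MaximumTreeOutput = 100,         // formerly MaxTreeOutput
    FeatureFractionPerSplit = 1.0    // formerly SplitFraction
};
var model = mlContext.Regression.Trainers.FastTree(options).Fit(trainData);
```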