Adding the extension methods for FastTree classification and regression #1009
Changes from 5 commits
Commits: 67906db, 4b023b6, 93751c9, 3dc333a, 4d86fcf, c9abb48, 1d6ade6, 538ecbf, 7194439, ce9301b, 2cbcaba
@@ -57,15 +57,41 @@ public sealed partial class FastTreeRegressionTrainer
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
/// <param name="labelColumn">The name of the label column.</param>
/// <param name="featureColumn">The name of the feature column.</param>
/// <param name="groupIdColumn">The name for the column containing the group ID.</param>
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
public FastTreeRegressionTrainer(IHostEnvironment env, string labelColumn, string featureColumn,
    string weightColumn = null, string groupIdColumn = null, Action<Arguments> advancedSettings = null)
    : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
/// <param name="learningRates">The learning rate.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
public FastTreeRegressionTrainer(IHostEnvironment env,
    string labelColumn,
    string featureColumn,
    string weightColumn = null,
    int numLeaves = 20,
Review comment: Should this be part of a static Defaults class? #Closed
Review comment (in reply to 220063796): You have a Defaults class, but you don't use it here.
(A sketch of such a Defaults class appears after this hunk.)
    int numTrees = 100,
    int minDocumentsInLeafs = 10,
    double learningRates = 0.2,
    Action<Arguments> advancedSettings = null)
    : base(env, TrainerUtils.MakeR4ScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
{
    Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
    Host.CheckNonEmpty(featureColumn, nameof(featureColumn));

    if (advancedSettings != null)
    {
        using (var ch = Host.Start("Validating advanced settings."))
        {
            // Take a quick snapshot of the defaults, for comparison with the current args values.
            var snapshot = new Arguments();

            // Check that the user didn't supply different parameters in the args from what they specified directly.
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numLeaves, snapshot.NumLeaves, Args.NumLeaves, nameof(numLeaves));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, numTrees, snapshot.NumTrees, Args.NumTrees, nameof(numTrees));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, minDocumentsInLeafs, snapshot.MinDocumentsInLeafs, Args.MinDocumentsInLeafs, nameof(minDocumentsInLeafs));
            TrainerUtils.CheckArgsAndAdvancedSettingMismatch(ch, learningRates, snapshot.LearningRates, Args.LearningRates, nameof(learningRates));
            ch.Done();
        }
    }
}

/// <summary>
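The review thread above asks whether the literal defaults (20 leaves, 100 trees, 10 documents per leaf, 0.2 learning rate) should come from a static Defaults class; the static extensions later in this diff already reference members with exactly these names. A minimal sketch of what such a class could look like, assuming an internal nested class (its placement and visibility are assumptions, not part of this PR):

    // Hypothetical sketch only: mirrors the constructor defaults above and the
    // Defaults.* members referenced by the static extension methods below.
    internal static class Defaults
    {
        internal const int NumTrees = 100;
        internal const int NumLeaves = 20;
        internal const int MinDocumentsInLeafs = 10;
        internal const double LearningRates = 0.2;
    }

The constructor could then declare, for example, int numLeaves = Defaults.NumLeaves instead of the literal 20, which is what the reviewer is pointing at.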
@@ -0,0 +1,86 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data.StaticPipe.Runtime;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Internal.Internallearn;
using Microsoft.ML.Runtime.Training;
using System;

namespace Microsoft.ML.Trainers
{
    public static class FastTreeStatic
    {
        public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers ctx,
Review comment: It's my turn to comment that you need comments on public methods! Oh, the joy of revenge! #Closed
Review comment (in reply to 220283064): I added them before you commented! See the next push :)
            Scalar<float> label, Vector<float> features, Scalar<float> weights = null,
            int numLeaves = Defaults.NumLeaves,
            int numTrees = Defaults.NumTrees,
            int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
            double learningRate = Defaults.LearningRates,
            Action<FastTreeRegressionTrainer.Arguments> advancedSettings = null,
            Action<FastTreeRegressionPredictor> onFit = null)
        {
            CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit);

Review comment: Don't duplicate. #Resolved
            var rec = new TrainerEstimatorReconciler.Regression(
                (env, labelName, featuresName, weightsName) =>
                {
                    var trainer = new FastTreeRegressionTrainer(env, labelName, featuresName, weightsName, numLeaves,
                        numTrees, minDocumentsInLeafs, learningRate, advancedSettings);
                    if (onFit != null)
                        return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                    return trainer;
                }, label, features, weights);

            return rec.Score;
        }

        public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
            Scalar<bool> label, Vector<float> features, Scalar<float> weights = null,
            int numLeaves = Defaults.NumLeaves,
            int numTrees = Defaults.NumTrees,
            int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
            double learningRate = Defaults.LearningRates,
            Action<FastTreeBinaryClassificationTrainer.Arguments> advancedSettings = null,
            Action<IPredictorWithFeatureWeights<float>> onFit = null)
Review comment: Is that the best we can do here? I would argue that we should expose some form of tree ensemble predictor. #Pending
Review comment (in reply to 220030085): I'll leave it for now and log an issue about it. Just casting didn't work.
        {
            CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit);

            var rec = new TrainerEstimatorReconciler.BinaryClassifier(
                (env, labelName, featuresName, weightsName) =>
                {
                    var trainer = new FastTreeBinaryClassificationTrainer(env, labelName, featuresName, weightsName, numLeaves,
                        numTrees, minDocumentsInLeafs, learningRate, advancedSettings);

                    if (onFit != null)
                        return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
                    else
                        return trainer;
                }, label, features, weights);

            return rec.Output;
        }

        private static void CheckUserValues<TVal, TArgs, TPred>(Scalar<TVal> label, Vector<float> features, Scalar<float> weights = null,
            int numLeaves = Defaults.NumLeaves,
            int numTrees = Defaults.NumTrees,
            int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
            double learningRate = Defaults.LearningRates,
Review comment: Fix. #Resolved
            Action<TArgs> advancedSettings = null,
            Action<TPred> onFit = null)
        {
            Contracts.CheckValue(label, nameof(label));
            Contracts.CheckValue(features, nameof(features));
            Contracts.CheckValueOrNull(weights);
            Contracts.CheckParam(numLeaves >= 2, nameof(numLeaves), "Must be at least 2.");
            Contracts.CheckParam(numTrees > 0, nameof(numTrees), "Must be positive");
            Contracts.CheckParam(minDocumentsInLeafs > 0, nameof(minDocumentsInLeafs), "Must be positive");
            Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive");
            Contracts.CheckValueOrNull(advancedSettings);
            Contracts.CheckValueOrNull(onFit);
        }
    }
}
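For context, a rough sketch of how the two new extension methods would be called from the static pipeline; regression, binaryClassification, label, boolLabel, features, and pred are placeholder names for values produced elsewhere (for example by a statically typed reader) and are not taken from this PR:

    // Regression: returns the score column; onFit captures the trained predictor.
    FastTreeRegressionPredictor pred = null;
    Scalar<float> score = regression.Trainers.FastTree(label, features,
        numTrees: 100,
        numLeaves: 20,
        learningRate: 0.2,
        onFit: p => pred = p);

    // Binary classification: returns (score, probability, predictedLabel).
    var (clfScore, probability, predictedLabel) =
        binaryClassification.Trainers.FastTree(boolLabel, features, numTrees: 50);

This assumes RegressionContext and BinaryClassificationContext expose their trainer catalogs through a Trainers property, as the this-parameter types of the extension methods suggest.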
@@ -59,10 +59,10 @@ public static (Scalar<float> score, Scalar<bool> predictedLabel) FieldAwareFacto
            var trainer = new FieldAwareFactorizationMachineTrainer(env, labelCol, featureCols, advancedSettings:
                args =>
                {
                    advancedSettings?.Invoke(args);
Review comment: My order was intentional :) #Pending
(An illustration of what the reordering changes appears after this hunk.)
                    args.LearningRate = learningRate;
                    args.Iters = numIterations;
                    args.LatentDim = numLatentDimensions;
                    advancedSettings?.Invoke(args);
                });
            if (onFit != null)
                return trainer.WithOnFitDelegate(trans => onFit(trans.Model));

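The hunk above moves advancedSettings?.Invoke(args) from before the explicit assignments to after them, which is what the "my order was intentional" comment refers to. A small illustration of the behavioral difference (not code from this PR); the two orderings only disagree when the user's delegate touches the same field as an explicit argument:

    // Delegate first (the original order): the explicit arguments run afterwards
    // and overwrite whatever the delegate set, so the learningRate parameter wins.
    advancedSettings?.Invoke(args);
    args.LearningRate = learningRate;

    // Delegate last (the order in this diff): the delegate runs after the explicit
    // assignments, so a LearningRate set inside advancedSettings wins.
    args.LearningRate = learningRate;
    advancedSettings?.Invoke(args);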
Review comment: Move this into FastTreeUtils and re-use it. #Resolved
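It is not clear from this page exactly which lines that comment is attached to, but if the suggestion is to hoist the shared argument validation into a common helper, it might look roughly like the following; the class name FastTreeUtils comes from the comment, while the method name and signature are assumptions:

    // Hypothetical shared helper, not part of this PR.
    internal static class FastTreeUtils
    {
        internal static void CheckTreeArgs(int numLeaves, int numTrees, int minDocumentsInLeafs, double learningRate)
        {
            Contracts.CheckParam(numLeaves >= 2, nameof(numLeaves), "Must be at least 2.");
            Contracts.CheckParam(numTrees > 0, nameof(numTrees), "Must be positive");
            Contracts.CheckParam(minDocumentsInLeafs > 0, nameof(minDocumentsInLeafs), "Must be positive");
            Contracts.CheckParam(learningRate > 0, nameof(learningRate), "Must be positive");
        }
    }

CheckUserValues in the static extensions (and any trainer-side validation of the same parameters) could then delegate to it.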