Skip to content

Commit 40ea35c

Browse files
committed
review comments - 4. Single Constructor (FastTreeTweedie)
1 parent 5ba8d7e commit 40ea35c

File tree

3 files changed

+15
-30
lines changed

3 files changed

+15
-30
lines changed

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+1-28
Original file line numberDiff line numberDiff line change
@@ -48,39 +48,12 @@ public sealed partial class FastTreeTweedieTrainer
4848

4949
private SchemaShape.Column[] _outputColumns;
5050

51-
/// <summary>
52-
/// Initializes a new instance of <see cref="FastTreeTweedieTrainer"/>
53-
/// </summary>
54-
/// <param name="env">The private instance of <see cref="IHostEnvironment"/>.</param>
55-
/// <param name="labelColumn">The name of the label column.</param>
56-
/// <param name="featureColumn">The name of the feature column.</param>
57-
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
58-
/// <param name="learningRate">The learning rate.</param>
59-
/// <param name="minDatapointsInLeaves">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
60-
/// <param name="numLeaves">The max number of leaves in each regression tree.</param>
61-
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
62-
public FastTreeTweedieTrainer(IHostEnvironment env,
63-
string labelColumn = DefaultColumnNames.Label,
64-
string featureColumn = DefaultColumnNames.Features,
65-
string weightColumn = null,
66-
int numLeaves = Defaults.NumLeaves,
67-
int numTrees = Defaults.NumTrees,
68-
int minDatapointsInLeaves = Defaults.MinDocumentsInLeaves,
69-
double learningRate = Defaults.LearningRates)
70-
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, numLeaves, numTrees, minDatapointsInLeaves, learningRate)
71-
{
72-
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
73-
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
74-
75-
Initialize();
76-
}
77-
7851
/// <summary>
7952
/// Initializes a new instance of <see cref="FastTreeTweedieTrainer"/> by using the <see cref="Options"/> class.
8053
/// </summary>
8154
/// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
8255
/// <param name="options">Algorithm advanced settings.</param>
83-
public FastTreeTweedieTrainer(IHostEnvironment env, Options options)
56+
internal FastTreeTweedieTrainer(IHostEnvironment env, Options options)
8457
: base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn))
8558
{
8659
Initialize();

src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs

+13-1
Original file line numberDiff line numberDiff line change
@@ -225,7 +225,19 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
225225
{
226226
Contracts.CheckValue(ctx, nameof(ctx));
227227
var env = CatalogUtils.GetEnvironment(ctx);
228-
return new FastTreeTweedieTrainer(env, labelColumn, featureColumn, weights, numLeaves, numTrees, minDatapointsInLeaves, learningRate);
228+
229+
var options = new FastTreeTweedieTrainer.Options()
230+
{
231+
LabelColumn = labelColumn,
232+
FeatureColumn = featureColumn,
233+
WeightColumn = weights,
234+
NumLeaves = numLeaves,
235+
NumTrees = numTrees,
236+
MinDocumentsInLeafs = minDatapointsInLeaves,
237+
LearningRates = learningRate,
238+
};
239+
240+
return new FastTreeTweedieTrainer(env, options);
229241
}
230242

231243
/// <summary>

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public void GAMRegressorEstimator()
197197
public void TweedieRegressorEstimator()
198198
{
199199
var dataView = GetRegressionPipeline();
200-
var trainer = new FastTreeTweedieTrainer(Env,
200+
var trainer = ML.Regression.Trainers.FastTreeTweedie(
201201
new FastTreeTweedieTrainer.Options {
202202
EntropyCoefficient = 0.3,
203203
OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent,

0 commit comments

Comments
 (0)