Skip to content

Commit b255eb3

Browse files
committed
Pass objects as arguments instead of delegate
1 parent 9c9c114 commit b255eb3

File tree

14 files changed

+59
-65
lines changed

14 files changed

+59
-65
lines changed

src/Microsoft.ML.FastTree/BoostingFastTree.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ protected BoostingFastTreeTrainerBase(IHostEnvironment env,
4444
string featureColumn,
4545
string weightColumn,
4646
string groupIdColumn,
47-
Action<TArgs> advancedSettings)
47+
TArgs advancedSettings)
4848
: base(env, label, featureColumn, weightColumn, groupIdColumn, advancedSettings)
4949
{
5050
}

src/Microsoft.ML.FastTree/FastTree.cs

+2-5
Original file line numberDiff line numberDiff line change
@@ -155,13 +155,10 @@ private protected FastTreeTrainerBase(IHostEnvironment env,
155155
string featureColumn,
156156
string weightColumn,
157157
string groupIdColumn,
158-
Action<TArgs> advancedSettings)
158+
TArgs advancedSettings)
159159
: base(Contracts.CheckRef(env, nameof(env)).Register(RegisterName), TrainerUtils.MakeR4VecFeature(featureColumn), label, TrainerUtils.MakeR4ScalarWeightColumn(weightColumn), TrainerUtils.MakeU4ScalarColumn(groupIdColumn))
160160
{
161-
Args = new TArgs();
162-
163-
//apply the advanced args, if the user supplied any
164-
advancedSettings?.Invoke(Args);
161+
Args = advancedSettings;
165162

166163
Args.LabelColumn = label.Name;
167164
Args.FeatureColumn = featureColumn;

src/Microsoft.ML.FastTree/FastTreeClassification.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -148,12 +148,12 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
148148
/// <param name="labelColumn">The name of the label column.</param>
149149
/// <param name="featureColumn">The name of the feature column.</param>
150150
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
151-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
151+
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
152152
public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
153153
string labelColumn,
154154
string featureColumn,
155155
string weightColumn,
156-
Action<Options> advancedSettings)
156+
Options advancedSettings)
157157
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
158158
{
159159
// Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss

src/Microsoft.ML.FastTree/FastTreeRanking.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,13 @@ public FastTreeRankingTrainer(IHostEnvironment env,
9393
/// <param name="featureColumn">The name of the feature column.</param>
9494
/// <param name="groupIdColumn">The name for the column containing the group ID. </param>
9595
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
96-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
96+
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
9797
public FastTreeRankingTrainer(IHostEnvironment env,
9898
string labelColumn,
9999
string featureColumn,
100100
string groupIdColumn,
101101
string weightColumn,
102-
Action<Options> advancedSettings = null)
102+
Options advancedSettings)
103103
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, groupIdColumn, advancedSettings)
104104
{
105105
Host.CheckNonEmpty(groupIdColumn, nameof(groupIdColumn));

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@ public FastTreeRegressionTrainer(IHostEnvironment env,
8181
/// <param name="labelColumn">The name of the label column.</param>
8282
/// <param name="featureColumn">The name of the feature column.</param>
8383
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
84-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
84+
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
8585
public FastTreeRegressionTrainer(IHostEnvironment env,
8686
string labelColumn,
8787
string featureColumn,
8888
string weightColumn,
89-
Action<Options> advancedSettings = null)
89+
Options advancedSettings)
9090
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings)
9191
{
9292
}

src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ public FastTreeTweedieTrainer(IHostEnvironment env,
8282
/// <param name="labelColumn">The name of the label column.</param>
8383
/// <param name="featureColumn">The name of the feature column.</param>
8484
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
85-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
85+
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
8686
public FastTreeTweedieTrainer(IHostEnvironment env,
8787
string labelColumn,
8888
string featureColumn,
8989
string weightColumn,
90-
Action<Options> advancedSettings)
90+
Options advancedSettings)
9191
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings)
9292
{
9393
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

src/Microsoft.ML.FastTree/RandomForest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ protected RandomForestTrainerBase(IHostEnvironment env,
5050
string featureColumn,
5151
string weightColumn,
5252
string groupIdColumn,
53-
Action<TArgs> advancedSettings,
53+
TArgs advancedSettings,
5454
bool quantileEnabled = false)
5555
: base(env, label, featureColumn, weightColumn, null, advancedSettings)
5656
{

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,12 @@ public FastForestClassification(IHostEnvironment env,
166166
/// <param name="labelColumn">The name of the label column.</param>
167167
/// <param name="featureColumn">The name of the feature column.</param>
168168
/// <param name="weightColumn">The name for the column containing the initial weight.</param>
169-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
169+
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
170170
public FastForestClassification(IHostEnvironment env,
171171
string labelColumn,
172172
string featureColumn,
173173
string weightColumn,
174-
Action<Options> advancedSettings)
174+
Options advancedSettings)
175175
: base(env, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weightColumn, null, advancedSettings)
176176
{
177177
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

src/Microsoft.ML.FastTree/RandomForestRegression.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -184,12 +184,12 @@ public FastForestRegression(IHostEnvironment env,
184184
/// <param name="labelColumn">The name of the label column.</param>
185185
/// <param name="featureColumn">The name of the feature column.</param>
186186
/// <param name="weightColumn">The optional name for the column containing the initial weight.</param>
187-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
187+
/// <param name="advancedSettings">Advanced arguments to the algorithm.</param>
188188
public FastForestRegression(IHostEnvironment env,
189189
string labelColumn,
190190
string featureColumn,
191191
string weightColumn,
192-
Action<Options> advancedSettings)
192+
Options advancedSettings)
193193
: base(env, TrainerUtils.MakeR4ScalarColumn(labelColumn), featureColumn, weightColumn, null, advancedSettings)
194194
{
195195
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));

src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
5050
string labelColumn,
5151
string featureColumn,
5252
string weights,
53-
Action<FastTreeRegressionTrainer.Options> advancedSettings)
53+
FastTreeRegressionTrainer.Options advancedSettings)
5454
{
5555
Contracts.CheckValue(ctx, nameof(ctx));
5656
var env = CatalogUtils.GetEnvironment(ctx);
@@ -94,7 +94,7 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
9494
string labelColumn,
9595
string featureColumn,
9696
string weights,
97-
Action<FastTreeBinaryClassificationTrainer.Options> advancedSettings)
97+
FastTreeBinaryClassificationTrainer.Options advancedSettings)
9898
{
9999
Contracts.CheckValue(ctx, nameof(ctx));
100100
var env = CatalogUtils.GetEnvironment(ctx);
@@ -142,7 +142,7 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer
142142
string featureColumn,
143143
string groupId,
144144
string weights,
145-
Action<FastTreeRankingTrainer.Options> advancedSettings)
145+
FastTreeRankingTrainer.Options advancedSettings)
146146
{
147147
Contracts.CheckValue(ctx, nameof(ctx));
148148
var env = CatalogUtils.GetEnvironment(ctx);
@@ -236,7 +236,7 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
236236
string labelColumn,
237237
string featureColumn,
238238
string weights,
239-
Action<FastTreeTweedieTrainer.Options> advancedSettings = null)
239+
FastTreeTweedieTrainer.Options advancedSettings = null)
240240
{
241241
Contracts.CheckValue(ctx, nameof(ctx));
242242
var env = CatalogUtils.GetEnvironment(ctx);
@@ -280,7 +280,7 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT
280280
string labelColumn,
281281
string featureColumn,
282282
string weights,
283-
Action<FastForestRegression.Options> advancedSettings = null)
283+
FastForestRegression.Options advancedSettings = null)
284284
{
285285
Contracts.CheckValue(ctx, nameof(ctx));
286286
var env = CatalogUtils.GetEnvironment(ctx);
@@ -324,7 +324,7 @@ public static FastForestClassification FastForest(this BinaryClassificationConte
324324
string labelColumn,
325325
string featureColumn,
326326
string weights,
327-
Action<FastForestClassification.Options> advancedSettings)
327+
FastForestClassification.Options advancedSettings)
328328
{
329329
Contracts.CheckValue(ctx, nameof(ctx));
330330
var env = CatalogUtils.GetEnvironment(ctx);

src/Microsoft.ML.StaticPipe/TreeTrainersStatic.cs

+9-8
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,11 @@ public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers c
8484
/// </example>
8585
public static Scalar<float> FastTree(this RegressionContext.RegressionTrainers ctx,
8686
Scalar<float> label, Vector<float> features, Scalar<float> weights,
87-
Action<FastTreeRegressionTrainer.Options> advancedSettings,
87+
FastTreeRegressionTrainer.Options advancedSettings,
8888
Action<FastTreeRegressionModelParameters> onFit = null)
8989
{
90-
CheckUserValues(label, features, weights, advancedSettings, onFit);
90+
Contracts.CheckValueOrNull(advancedSettings);
91+
CheckUserValues(label, features, weights, onFit);
9192

9293
var rec = new TrainerEstimatorReconciler.Regression(
9394
(env, labelName, featuresName, weightsName) =>
@@ -175,10 +176,11 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
175176
/// </example>
176177
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) FastTree(this BinaryClassificationContext.BinaryClassificationTrainers ctx,
177178
Scalar<bool> label, Vector<float> features, Scalar<float> weights,
178-
Action<FastTreeBinaryClassificationTrainer.Options> advancedSettings,
179+
FastTreeBinaryClassificationTrainer.Options advancedSettings,
179180
Action<IPredictorWithFeatureWeights<float>> onFit = null)
180181
{
181-
CheckUserValues(label, features, weights, advancedSettings, onFit);
182+
Contracts.CheckValueOrNull(advancedSettings);
183+
CheckUserValues(label, features, weights, onFit);
182184

183185
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
184186
(env, labelName, featuresName, weightsName) =>
@@ -254,10 +256,11 @@ public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers c
254256
/// <returns>The Score output column indicating the predicted value.</returns>
255257
public static Scalar<float> FastTree<TVal>(this RankingContext.RankingTrainers ctx,
256258
Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights,
257-
Action<FastTreeRankingTrainer.Options> advancedSettings,
259+
FastTreeRankingTrainer.Options advancedSettings,
258260
Action<FastTreeRankingModelParameters> onFit = null)
259261
{
260-
CheckUserValues(label, features, weights, advancedSettings, onFit);
262+
Contracts.CheckValueOrNull(advancedSettings);
263+
CheckUserValues(label, features, weights, onFit);
261264

262265
var rec = new TrainerEstimatorReconciler.Ranker<TVal>(
263266
(env, labelName, featuresName, groupIdName, weightsName) =>
@@ -289,13 +292,11 @@ internal static void CheckUserValues(PipelineColumn label, Vector<float> feature
289292
}
290293

291294
internal static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,
292-
Delegate advancedSettings,
293295
Delegate onFit)
294296
{
295297
Contracts.CheckValue(label, nameof(label));
296298
Contracts.CheckValue(features, nameof(features));
297299
Contracts.CheckValueOrNull(weights);
298-
Contracts.CheckValueOrNull(advancedSettings);
299300
Contracts.CheckValueOrNull(onFit);
300301
}
301302
}

src/Microsoft.ML.Sweeper/Algorithms/SmacSweeper.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,11 @@ private FastForestRegressionModelParameters FitModel(IEnumerable<IRunResult> pre
136136
// Set relevant random forest arguments.
137137
// Train random forest.
138138
var trainer = new FastForestRegression(_host, DefaultColumnNames.Label, DefaultColumnNames.Features, null,
139-
advancedSettings: s =>
139+
new FastForestRegression.Options
140140
{
141-
s.FeatureFraction = _args.SplitRatio;
142-
s.NumTrees = _args.NumOfTrees;
143-
s.MinDocumentsInLeafs = _args.NMinForSplit;
141+
FeatureFraction = _args.SplitRatio,
142+
NumTrees = _args.NumOfTrees,
143+
MinDocumentsInLeafs = _args.NMinForSplit,
144144
});
145145
var predictor = trainer.Train(data);
146146

test/Microsoft.ML.Tests/Scenarios/OvaTest.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ public void OvaFastTree()
105105
var pipeline = new Ova(
106106
mlContext,
107107
new FastTreeBinaryClassificationTrainer(mlContext, DefaultColumnNames.Label, DefaultColumnNames.Features, null,
108-
advancedSettings: s => { s.NumThreads = 1; }),
108+
new FastTreeBinaryClassificationTrainer.Options { NumThreads = 1 }),
109109
useProbabilities: false);
110110

111111
var model = pipeline.Fit(data);

test/Microsoft.ML.Tests/TrainerEstimators/TreeEstimators.cs

+23-27
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ public void FastTreeBinaryEstimator()
2525
var (pipe, dataView) = GetBinaryClassificationPipeline();
2626

2727
var trainer = new FastTreeBinaryClassificationTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null,
28-
advancedSettings: s => {
29-
s.NumThreads = 1;
30-
s.NumTrees = 10;
31-
s.NumLeaves = 5;
28+
new FastTreeBinaryClassificationTrainer.Options {
29+
NumThreads = 1,
30+
NumTrees = 10,
31+
NumLeaves = 5,
3232
});
3333

3434
var pipeWithTrainer = pipe.Append(trainer);
@@ -83,11 +83,12 @@ public void FastForestClassificationEstimator()
8383
{
8484
var (pipe, dataView) = GetBinaryClassificationPipeline();
8585

86-
var trainer = new FastForestClassification(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s =>
87-
{
88-
s.NumLeaves = 10;
89-
s.NumTrees = 20;
90-
});
86+
var trainer = new FastForestClassification(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null,
87+
new FastForestClassification.Options {
88+
NumLeaves = 10,
89+
NumTrees = 20,
90+
});
91+
9192
var pipeWithTrainer = pipe.Append(trainer);
9293
TestEstimatorCore(pipeWithTrainer, dataView);
9394

@@ -104,8 +105,7 @@ public void FastTreeRankerEstimator()
104105
{
105106
var (pipe, dataView) = GetRankingPipeline();
106107

107-
var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", null,
108-
advancedSettings: s => { s.NumTrees = 10; });
108+
var trainer = new FastTreeRankingTrainer(Env, "Label0", "NumericFeatures", "Group", null, new FastTreeRankingTrainer.Options { NumTrees = 10 });
109109
var pipeWithTrainer = pipe.Append(trainer);
110110
TestEstimatorCore(pipeWithTrainer, dataView);
111111

@@ -139,12 +139,8 @@ public void LightGBMRankerEstimator()
139139
public void FastTreeRegressorEstimator()
140140
{
141141
var dataView = GetRegressionPipeline();
142-
var trainer = new FastTreeRegressionTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s =>
143-
{
144-
s.NumTrees = 10;
145-
s.NumThreads = 1;
146-
s.NumLeaves = 5;
147-
});
142+
var trainer = new FastTreeRegressionTrainer(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null,
143+
new FastTreeRegressionTrainer.Options { NumTrees = 10, NumThreads = 1, NumLeaves = 5 });
148144

149145
TestEstimatorCore(trainer, dataView);
150146
var model = trainer.Train(dataView, dataView);
@@ -196,11 +192,11 @@ public void GAMRegressorEstimator()
196192
public void TweedieRegressorEstimator()
197193
{
198194
var dataView = GetRegressionPipeline();
199-
var trainer = new FastTreeTweedieTrainer(Env, "Label", "Features", null, advancedSettings: s =>
200-
{
201-
s.EntropyCoefficient = 0.3;
202-
s.OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent;
203-
});
195+
var trainer = new FastTreeTweedieTrainer(Env, "Label", "Features", null,
196+
new FastTreeTweedieTrainer.Options {
197+
EntropyCoefficient = 0.3,
198+
OptimizationAlgorithm = BoostedTreeArgs.OptimizationAlgorithmType.AcceleratedGradientDescent,
199+
});
204200

205201
TestEstimatorCore(trainer, dataView);
206202
var model = trainer.Train(dataView, dataView);
@@ -214,11 +210,11 @@ public void TweedieRegressorEstimator()
214210
public void FastForestRegressorEstimator()
215211
{
216212
var dataView = GetRegressionPipeline();
217-
var trainer = new FastForestRegression(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null, advancedSettings: s =>
218-
{
219-
s.BaggingSize = 2;
220-
s.NumTrees = 10;
221-
});
213+
var trainer = new FastForestRegression(Env, DefaultColumnNames.Label, DefaultColumnNames.Features, null,
214+
new FastForestRegression.Options {
215+
BaggingSize = 2,
216+
NumTrees = 10,
217+
});
222218

223219
TestEstimatorCore(trainer, dataView);
224220
var model = trainer.Train(dataView, dataView);

0 commit comments

Comments
 (0)