Skip to content

Commit 33fe723

Browse files
committed
Options renaming
1 parent 9205525 commit 33fe723

File tree

11 files changed

+99
-99
lines changed

11 files changed

+99
-99
lines changed

src/Microsoft.ML.LightGBM.StaticPipe/LightGbmStaticExtensions.cs

+35-35
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers c
4545
int? numLeaves = null,
4646
int? minDataPerLeaf = null,
4747
double? learningRate = null,
48-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
48+
int numBoostRound = Options.Defaults.NumBoostRound,
4949
Action<LightGbmRegressionModelParameters> onFit = null)
5050
{
5151
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit);
@@ -70,7 +70,7 @@ public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers c
7070
/// <param name="label">The label column.</param>
7171
/// <param name="features">The features column.</param>
7272
/// <param name="weights">The weights column.</param>
73-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
73+
/// <param name="options">Algorithm advanced settings.</param>
7474
/// <param name="onFit">A delegate that is called every time the
7575
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
7676
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
@@ -79,19 +79,19 @@ public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers c
7979
/// <returns>The Score output column indicating the predicted value.</returns>
8080
public static Scalar<float> LightGbm(this RegressionCatalog.RegressionTrainers catalog,
8181
Scalar<float> label, Vector<float> features, Scalar<float> weights,
82-
LightGbmArguments advancedSettings,
82+
Options options,
8383
Action<LightGbmRegressionModelParameters> onFit = null)
8484
{
85-
CheckUserValues(label, features, weights, advancedSettings, onFit);
85+
CheckUserValues(label, features, weights, options, onFit);
8686

8787
var rec = new TrainerEstimatorReconciler.Regression(
8888
(env, labelName, featuresName, weightsName) =>
8989
{
90-
advancedSettings.LabelColumn = labelName;
91-
advancedSettings.FeatureColumn = featuresName;
92-
advancedSettings.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
90+
options.LabelColumn = labelName;
91+
options.FeatureColumn = featuresName;
92+
options.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
9393

94-
var trainer = new LightGbmRegressorTrainer(env, advancedSettings);
94+
var trainer = new LightGbmRegressorTrainer(env, options);
9595
if (onFit != null)
9696
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
9797
return trainer;
@@ -129,7 +129,7 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
129129
int? numLeaves = null,
130130
int? minDataPerLeaf = null,
131131
double? learningRate = null,
132-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
132+
int numBoostRound = Options.Defaults.NumBoostRound,
133133
Action<IPredictorWithFeatureWeights<float>> onFit = null)
134134
{
135135
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit);
@@ -156,7 +156,7 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
156156
/// <param name="label">The label column.</param>
157157
/// <param name="features">The features column.</param>
158158
/// <param name="weights">The weights column.</param>
159-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
159+
/// <param name="options">Algorithm advanced settings.</param>
160160
/// <param name="onFit">A delegate that is called every time the
161161
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
162162
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
@@ -166,19 +166,19 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
166166
/// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
167167
public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> predictedLabel) LightGbm(this BinaryClassificationCatalog.BinaryClassificationTrainers catalog,
168168
Scalar<bool> label, Vector<float> features, Scalar<float> weights,
169-
LightGbmArguments advancedSettings,
169+
Options options,
170170
Action<IPredictorWithFeatureWeights<float>> onFit = null)
171171
{
172-
CheckUserValues(label, features, weights, advancedSettings, onFit);
172+
CheckUserValues(label, features, weights, options, onFit);
173173

174174
var rec = new TrainerEstimatorReconciler.BinaryClassifier(
175175
(env, labelName, featuresName, weightsName) =>
176176
{
177-
advancedSettings.LabelColumn = labelName;
178-
advancedSettings.FeatureColumn = featuresName;
179-
advancedSettings.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
177+
options.LabelColumn = labelName;
178+
options.FeatureColumn = featuresName;
179+
options.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
180180

181-
var trainer = new LightGbmBinaryTrainer(env, advancedSettings);
181+
var trainer = new LightGbmBinaryTrainer(env, options);
182182

183183
if (onFit != null)
184184
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
@@ -213,7 +213,7 @@ public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers c
213213
int? numLeaves = null,
214214
int? minDataPerLeaf = null,
215215
double? learningRate = null,
216-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
216+
int numBoostRound = Options.Defaults.NumBoostRound,
217217
Action<LightGbmRankingModelParameters> onFit = null)
218218
{
219219
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit);
@@ -241,7 +241,7 @@ public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers c
241241
/// <param name="features">The features column.</param>
242242
/// <param name="groupId">The groupId column.</param>
243243
/// <param name="weights">The weights column.</param>
244-
/// <param name="advancedSettings">Algorithm advanced settings.</param>
244+
/// <param name="options">Algorithm advanced settings.</param>
245245
/// <param name="onFit">A delegate that is called every time the
246246
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
247247
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
@@ -251,21 +251,21 @@ public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers c
251251
/// from negative to positive infinity), the calibrated prediction (from 0 to 1), and the predicted label.</returns>
252252
public static Scalar<float> LightGbm<TVal>(this RankingCatalog.RankingTrainers catalog,
253253
Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights,
254-
LightGbmArguments advancedSettings,
254+
Options options,
255255
Action<LightGbmRankingModelParameters> onFit = null)
256256
{
257-
CheckUserValues(label, features, weights, advancedSettings, onFit);
257+
CheckUserValues(label, features, weights, options, onFit);
258258
Contracts.CheckValue(groupId, nameof(groupId));
259259

260260
var rec = new TrainerEstimatorReconciler.Ranker<TVal>(
261261
(env, labelName, featuresName, groupIdName, weightsName) =>
262262
{
263-
advancedSettings.LabelColumn = labelName;
264-
advancedSettings.FeatureColumn = featuresName;
265-
advancedSettings.GroupIdColumn = groupIdName;
266-
advancedSettings.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
263+
options.LabelColumn = labelName;
264+
options.FeatureColumn = featuresName;
265+
options.GroupIdColumn = groupIdName;
266+
options.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
267267

268-
var trainer = new LightGbmRankingTrainer(env, advancedSettings);
268+
var trainer = new LightGbmRankingTrainer(env, options);
269269

270270
if (onFit != null)
271271
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
@@ -307,7 +307,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
307307
int? numLeaves = null,
308308
int? minDataPerLeaf = null,
309309
double? learningRate = null,
310-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound,
310+
int numBoostRound = Options.Defaults.NumBoostRound,
311311
Action<OvaModelParameters> onFit = null)
312312
{
313313
CheckUserValues(label, features, weights, numLeaves, minDataPerLeaf, learningRate, numBoostRound, onFit);
@@ -333,7 +333,7 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
333333
/// <param name="label">The label, or dependent variable.</param>
334334
/// <param name="features">The features, or independent variables.</param>
335335
/// <param name="weights">The weights column.</param>
336-
/// <param name="advancedSettings">Advanced options to the algorithm.</param>
336+
/// <param name="options">Advanced options to the algorithm.</param>
337337
/// <param name="onFit">A delegate that is called every time the
338338
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}.Fit(DataView{TInShape})"/> method is called on the
339339
/// <see cref="Estimator{TInShape, TOutShape, TTransformer}"/> instance created out of this. This delegate will receive
@@ -345,19 +345,19 @@ public static (Vector<float> score, Key<uint, TVal> predictedLabel)
345345
Key<uint, TVal> label,
346346
Vector<float> features,
347347
Scalar<float> weights,
348-
LightGbmArguments advancedSettings,
348+
Options options,
349349
Action<OvaModelParameters> onFit = null)
350350
{
351-
CheckUserValues(label, features, weights, advancedSettings, onFit);
351+
CheckUserValues(label, features, weights, options, onFit);
352352

353353
var rec = new TrainerEstimatorReconciler.MulticlassClassifier<TVal>(
354354
(env, labelName, featuresName, weightsName) =>
355355
{
356-
advancedSettings.LabelColumn = labelName;
357-
advancedSettings.FeatureColumn = featuresName;
358-
advancedSettings.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
356+
options.LabelColumn = labelName;
357+
options.FeatureColumn = featuresName;
358+
options.WeightColumn = weightsName != null ? Optional<string>.Explicit(weightsName) : Optional<string>.Implicit(DefaultColumnNames.Weight);
359359

360-
var trainer = new LightGbmMulticlassTrainer(env, advancedSettings);
360+
var trainer = new LightGbmMulticlassTrainer(env, options);
361361

362362
if (onFit != null)
363363
return trainer.WithOnFitDelegate(trans => onFit(trans.Model));
@@ -385,13 +385,13 @@ private static void CheckUserValues(PipelineColumn label, Vector<float> features
385385
}
386386

387387
private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,
388-
LightGbmArguments advancedSettings,
388+
Options options,
389389
Delegate onFit)
390390
{
391391
Contracts.CheckValue(label, nameof(label));
392392
Contracts.CheckValue(features, nameof(features));
393393
Contracts.CheckValueOrNull(weights);
394-
Contracts.CheckValue(advancedSettings, nameof(advancedSettings));
394+
Contracts.CheckValue(options, nameof(options));
395395
Contracts.CheckValueOrNull(onFit);
396396
}
397397
}

src/Microsoft.ML.LightGBM/LightGbmArguments.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
using Microsoft.ML.Internal.Internallearn;
1212
using Microsoft.ML.LightGBM;
1313

14-
[assembly: LoadableClass(typeof(LightGbmArguments.TreeBooster), typeof(LightGbmArguments.TreeBooster.Arguments),
15-
typeof(SignatureLightGBMBooster), LightGbmArguments.TreeBooster.FriendlyName, LightGbmArguments.TreeBooster.Name)]
16-
[assembly: LoadableClass(typeof(LightGbmArguments.DartBooster), typeof(LightGbmArguments.DartBooster.Arguments),
17-
typeof(SignatureLightGBMBooster), LightGbmArguments.DartBooster.FriendlyName, LightGbmArguments.DartBooster.Name)]
18-
[assembly: LoadableClass(typeof(LightGbmArguments.GossBooster), typeof(LightGbmArguments.GossBooster.Arguments),
19-
typeof(SignatureLightGBMBooster), LightGbmArguments.GossBooster.FriendlyName, LightGbmArguments.GossBooster.Name)]
14+
[assembly: LoadableClass(typeof(Options.TreeBooster), typeof(Options.TreeBooster.Arguments),
15+
typeof(SignatureLightGBMBooster), Options.TreeBooster.FriendlyName, Options.TreeBooster.Name)]
16+
[assembly: LoadableClass(typeof(Options.DartBooster), typeof(Options.DartBooster.Arguments),
17+
typeof(SignatureLightGBMBooster), Options.DartBooster.FriendlyName, Options.DartBooster.Name)]
18+
[assembly: LoadableClass(typeof(Options.GossBooster), typeof(Options.GossBooster.Arguments),
19+
typeof(SignatureLightGBMBooster), Options.GossBooster.FriendlyName, Options.GossBooster.Name)]
2020

21-
[assembly: EntryPointModule(typeof(LightGbmArguments.TreeBooster.Arguments))]
22-
[assembly: EntryPointModule(typeof(LightGbmArguments.DartBooster.Arguments))]
23-
[assembly: EntryPointModule(typeof(LightGbmArguments.GossBooster.Arguments))]
21+
[assembly: EntryPointModule(typeof(Options.TreeBooster.Arguments))]
22+
[assembly: EntryPointModule(typeof(Options.DartBooster.Arguments))]
23+
[assembly: EntryPointModule(typeof(Options.GossBooster.Arguments))]
2424

2525
namespace Microsoft.ML.LightGBM
2626
{
@@ -39,7 +39,7 @@ public interface IBoosterParameter
3939
/// Parameters names comes from LightGBM library.
4040
/// See https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst.
4141
/// </summary>
42-
public sealed class LightGbmArguments : LearnerInputBaseWithGroupId
42+
public sealed class Options : LearnerInputBaseWithGroupId
4343
{
4444
public abstract class BoosterParameter<TArgs> : IBoosterParameter
4545
where TArgs : class, new()

src/Microsoft.ML.LightGBM/LightGbmBinaryTrainer.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
using Microsoft.ML.Trainers.FastTree.Internal;
1717
using Microsoft.ML.Training;
1818

19-
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(LightGbmArguments),
19+
[assembly: LoadableClass(LightGbmBinaryTrainer.Summary, typeof(LightGbmBinaryTrainer), typeof(Options),
2020
new[] { typeof(SignatureBinaryClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureTreeEnsembleTrainer) },
2121
LightGbmBinaryTrainer.UserName, LightGbmBinaryTrainer.LoadNameValue, LightGbmBinaryTrainer.ShortName, DocName = "trainer/LightGBM.md")]
2222

@@ -95,8 +95,8 @@ public sealed class LightGbmBinaryTrainer : LightGbmTrainerBase<float, BinaryPre
9595

9696
public override PredictionKind PredictionKind => PredictionKind.BinaryClassification;
9797

98-
internal LightGbmBinaryTrainer(IHostEnvironment env, LightGbmArguments args)
99-
: base(env, LoadNameValue, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
98+
internal LightGbmBinaryTrainer(IHostEnvironment env, Options options)
99+
: base(env, LoadNameValue, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
100100
{
101101
}
102102

@@ -118,7 +118,7 @@ internal LightGbmBinaryTrainer(IHostEnvironment env,
118118
int? numLeaves = null,
119119
int? minDataPerLeaf = null,
120120
double? learningRate = null,
121-
int numBoostRound = LightGbmArguments.Defaults.NumBoostRound)
121+
int numBoostRound = LightGBM.Options.Defaults.NumBoostRound)
122122
: base(env, LoadNameValue, TrainerUtils.MakeBoolScalarLabel(labelColumn), featureColumn, weights, null, numLeaves, minDataPerLeaf, learningRate, numBoostRound)
123123
{
124124
}
@@ -181,14 +181,14 @@ public static partial class LightGbm
181181
ShortName = LightGbmBinaryTrainer.ShortName,
182182
XmlInclude = new[] { @"<include file='../Microsoft.ML.LightGBM/doc.xml' path='doc/members/member[@name=""LightGBM""]/*' />",
183183
@"<include file='../Microsoft.ML.LightGBM/doc.xml' path='doc/members/example[@name=""LightGbmBinaryClassifier""]/*' />"})]
184-
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, LightGbmArguments input)
184+
public static CommonOutputs.BinaryClassificationOutput TrainBinary(IHostEnvironment env, Options input)
185185
{
186186
Contracts.CheckValue(env, nameof(env));
187187
var host = env.Register("TrainLightGBM");
188188
host.CheckValue(input, nameof(input));
189189
EntryPointUtils.CheckInputArgs(host, input);
190190

191-
return LearnerEntryPointsUtils.Train<LightGbmArguments, CommonOutputs.BinaryClassificationOutput>(host, input,
191+
return LearnerEntryPointsUtils.Train<Options, CommonOutputs.BinaryClassificationOutput>(host, input,
192192
() => new LightGbmBinaryTrainer(host, input),
193193
getLabel: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn),
194194
getWeight: () => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.WeightColumn));

0 commit comments

Comments
 (0)