Commit cb78dee

review comments - 3: rename Options objects to "options" (instead of the "args" or "advancedSettings" names used so far)
1 parent 05018b8 commit cb78dee

10 files changed: +130 -118 lines changed


src/Microsoft.ML.FastTree/FastTreeClassification.cs

+4 -2
@@ -144,8 +144,10 @@ public FastTreeBinaryClassificationTrainer(IHostEnvironment env,
     /// <summary>
     /// Initializes a new instance of <see cref="FastTreeBinaryClassificationTrainer"/> by using the <see cref="Options"/> class.
     /// </summary>
-    public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options args)
-        : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
+    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
+    /// <param name="options">Algorithm advanced settings.</param>
+    public FastTreeBinaryClassificationTrainer(IHostEnvironment env, Options options)
+        : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
     {
         // Set the sigmoid parameter to the 2 * learning rate, for traditional FastTreeClassification loss
         _sigmoidParameter = 2.0 * Args.LearningRates;

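The net effect of this rename is visible at call sites: the public constructors now take the settings object under the parameter name options. A minimal caller-side sketch in C# (an illustration, not code from this commit: it assumes an IHostEnvironment instance named env is in scope, that the trainer and its nested Options class are reachable from the calling assembly in this ML.NET version, and the settings values shown are arbitrary):

// Sketch only: build the settings object and pass it as "options",
// the parameter name introduced by this commit.
var options = new FastTreeBinaryClassificationTrainer.Options
{
    LabelColumn = "Label",      // column names here are illustrative
    FeatureColumn = "Features",
    NumTrees = 100,
    LearningRates = 0.2
};

// "env" is an IHostEnvironment assumed to be available in scope.
var trainer = new FastTreeBinaryClassificationTrainer(env, options);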
src/Microsoft.ML.FastTree/FastTreeRanking.cs

+26 -24
@@ -88,8 +88,10 @@ public FastTreeRankingTrainer(IHostEnvironment env,
     /// <summary>
     /// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the <see cref="Options"/> class.
     /// </summary>
-    public FastTreeRankingTrainer(IHostEnvironment env, Options args)
-        : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn))
+    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
+    /// <param name="options">Algorithm advanced settings.</param>
+    public FastTreeRankingTrainer(IHostEnvironment env, Options options)
+        : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn))
     {
     }

@@ -546,14 +548,14 @@ private enum DupeIdInfo
     // Keeps track of labels of top 3 documents per query
     public short[][] TrainQueriesTopLabels;

-    public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options args, IParallelTraining parallelTraining)
+    public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options options, IParallelTraining parallelTraining)
         : base(trainset,
-            args.LearningRates,
-            args.Shrinkage,
-            args.MaxTreeOutput,
-            args.GetDerivativesSampleRate,
-            args.BestStepRankingRegressionTrees,
-            args.RngSeed)
+            options.LearningRates,
+            options.Shrinkage,
+            options.MaxTreeOutput,
+            options.GetDerivativesSampleRate,
+            options.BestStepRankingRegressionTrees,
+            options.RngSeed)
     {

         _labels = labels;
@@ -567,8 +569,8 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
            _labelCounts[q] = new int[relevancyLevel];

        // precomputed arrays
-        _maxDcgTruncationLevel = args.LambdaMartMaxTruncation;
-        _trainDcg = args.TrainDcg;
+        _maxDcgTruncationLevel = options.LambdaMartMaxTruncation;
+        _trainDcg = options.TrainDcg;
        if (_trainDcg)
        {
            _inverseMaxDcgt = new double[Dataset.NumQueries];
@@ -583,7 +585,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
        }

        _discount = new double[Dataset.MaxDocsPerQuery];
-        FillDiscounts(args.PositionDiscountFreeform);
+        FillDiscounts(options.PositionDiscountFreeform);

        _oneTwoThree = new int[Dataset.MaxDocsPerQuery];
        for (int d = 0; d < Dataset.MaxDocsPerQuery; ++d)
@@ -593,7 +595,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
        int numThreads = BlockingThreadPool.NumThreads;
        _comparers = new DcgPermutationComparer[numThreads];
        for (int i = 0; i < numThreads; ++i)
-            _comparers[i] = DcgPermutationComparerFactory.GetDcgPermutationFactory(args.SortingAlgorithm);
+            _comparers[i] = DcgPermutationComparerFactory.GetDcgPermutationFactory(options.SortingAlgorithm);

        _permutationBuffers = new int[numThreads][];
        for (int i = 0; i < numThreads; ++i)
@@ -603,13 +605,13 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
        FillGainLabels();

        #region parameters
-        _sigmoidParam = args.LearningRates;
-        _costFunctionParam = args.CostFunctionParam;
-        _distanceWeight2 = args.DistanceWeight2;
-        _normalizeQueryLambdas = args.NormalizeQueryLambdas;
+        _sigmoidParam = options.LearningRates;
+        _costFunctionParam = options.CostFunctionParam;
+        _distanceWeight2 = options.DistanceWeight2;
+        _normalizeQueryLambdas = options.NormalizeQueryLambdas;

-        _useShiftedNdcg = args.ShiftedNdcg;
-        _filterZeroLambdas = args.FilterZeroLambdas;
+        _useShiftedNdcg = options.ShiftedNdcg;
+        _filterZeroLambdas = options.FilterZeroLambdas;
        #endregion

        _scoresCopy = new double[Dataset.NumDocs];
@@ -620,7 +622,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
 #if OLD_DATALOAD
        SetupSecondaryGains(cmd);
 #endif
-        SetupBaselineRisk(args);
+        SetupBaselineRisk(options);
        _parallelTraining = parallelTraining;
     }

@@ -644,18 +646,18 @@ private void SetupSecondaryGains(Arguments args)
     }
 #endif

-    private void SetupBaselineRisk(Options args)
+    private void SetupBaselineRisk(Options options)
     {
        double[] scores = Dataset.Skeleton.GetData<double>("BaselineScores");
        if (scores == null)
            return;

        // Calculate the DCG with the discounts as they exist in the objective function (this
        // can differ versus the actual DCG discount)
-        DcgCalculator calc = new DcgCalculator(Dataset.MaxDocsPerQuery, args.SortingAlgorithm);
+        DcgCalculator calc = new DcgCalculator(Dataset.MaxDocsPerQuery, options.SortingAlgorithm);
        _baselineDcg = calc.DcgFromScores(Dataset, scores, _discount);

-        IniFileParserInterface ffi = IniFileParserInterface.CreateFromFreeform(string.IsNullOrEmpty(args.BaselineAlphaRisk) ? "0" : args.BaselineAlphaRisk);
+        IniFileParserInterface ffi = IniFileParserInterface.CreateFromFreeform(string.IsNullOrEmpty(options.BaselineAlphaRisk) ? "0" : options.BaselineAlphaRisk);
        IniFileParserInterface.FeatureEvaluator ffe = ffi.GetFeatureEvaluators()[0];
        IniFileParserInterface.FeatureMap ffmap = ffi.GetFeatureMap();
        string[] ffnames = Enumerable.Range(0, ffmap.RawFeatureCount)
@@ -672,7 +674,7 @@ private void SetupBaselineRisk(Options args)
        uint[] vals = new uint[ffmap.RawFeatureCount];
        int iInd = Array.IndexOf(ffnames, "I");
        int tInd = Array.IndexOf(ffnames, "T");
-        int totalTrees = args.NumTrees;
+        int totalTrees = options.NumTrees;
        if (tInd >= 0)
            vals[tInd] = (uint)totalTrees;
        _baselineAlpha = Enumerable.Range(0, totalTrees).Select(i =>

src/Microsoft.ML.FastTree/FastTreeRegression.cs

+12 -10
@@ -77,8 +77,10 @@ public FastTreeRegressionTrainer(IHostEnvironment env,
     /// <summary>
     /// Initializes a new instance of <see cref="FastTreeRegressionTrainer"/> by using the <see cref="Options"/> class.
     /// </summary>
-    public FastTreeRegressionTrainer(IHostEnvironment env, Options args)
-        : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn))
+    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
+    /// <param name="options">Algorithm advanced settings.</param>
+    public FastTreeRegressionTrainer(IHostEnvironment env, Options options)
+        : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn))
     {
     }

@@ -397,17 +399,17 @@ public ObjectiveImpl(Dataset trainData, RegressionGamTrainer.Arguments args) :
         _labels = GetDatasetRegressionLabels(trainData);
     }

-    public ObjectiveImpl(Dataset trainData, Options args)
+    public ObjectiveImpl(Dataset trainData, Options options)
         : base(
             trainData,
-            args.LearningRates,
-            args.Shrinkage,
-            args.MaxTreeOutput,
-            args.GetDerivativesSampleRate,
-            args.BestStepRankingRegressionTrees,
-            args.RngSeed)
+            options.LearningRates,
+            options.Shrinkage,
+            options.MaxTreeOutput,
+            options.GetDerivativesSampleRate,
+            options.BestStepRankingRegressionTrees,
+            options.RngSeed)
     {
-        if (args.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used.
+        if (options.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used.
            Shrinkage = 1.0 / LearningRate;

        _labels = GetDatasetRegressionLabels(trainData);

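A behavioral detail worth noting in the renamed code above (the same branch appears in the Tweedie objective below): when options.DropoutRate is positive, the objective replaces the configured shrinkage with 1.0 / LearningRate, matching the inline comment "Don't do shrinkage if dropouts are used" (presumably so the combined learning-rate/shrinkage factor comes out to 1). A standalone restatement of that branch, purely as an illustration; the helper name is hypothetical and not part of ML.NET:

// Hypothetical helper (not ML.NET API) mirroring the constructor branch:
// if (options.DropoutRate > 0 && LearningRate > 0) Shrinkage = 1.0 / LearningRate;
static double EffectiveShrinkage(double dropoutRate, double learningRate, double configuredShrinkage)
{
    if (dropoutRate > 0 && learningRate > 0)
        return 1.0 / learningRate;   // dropout enabled: cancel shrinkage
    return configuredShrinkage;      // otherwise keep the configured value
}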
src/Microsoft.ML.FastTree/FastTreeTweedie.cs

+15 -13
@@ -78,8 +78,10 @@ public FastTreeTweedieTrainer(IHostEnvironment env,
     /// <summary>
     /// Initializes a new instance of <see cref="FastTreeTweedieTrainer"/> by using the <see cref="Options"/> class.
     /// </summary>
-    public FastTreeTweedieTrainer(IHostEnvironment env, Options args)
-        : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn))
+    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
+    /// <param name="options">Algorithm advanced settings.</param>
+    public FastTreeTweedieTrainer(IHostEnvironment env, Options options)
+        : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn))
     {
         Initialize();
     }
@@ -334,17 +336,17 @@ private sealed class ObjectiveImpl : ObjectiveFunctionBase, IStepSearch
     private readonly Double _index2; // 2 minus the index parameter.
     private readonly Double _maxClamp;

-    public ObjectiveImpl(Dataset trainData, Options args)
+    public ObjectiveImpl(Dataset trainData, Options options)
         : base(
             trainData,
-            args.LearningRates,
-            args.Shrinkage,
-            args.MaxTreeOutput,
-            args.GetDerivativesSampleRate,
-            args.BestStepRankingRegressionTrees,
-            args.RngSeed)
+            options.LearningRates,
+            options.Shrinkage,
+            options.MaxTreeOutput,
+            options.GetDerivativesSampleRate,
+            options.BestStepRankingRegressionTrees,
+            options.RngSeed)
     {
-        if (args.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used.
+        if (options.DropoutRate > 0 && LearningRate > 0) // Don't do shrinkage if dropouts are used.
            Shrinkage = 1.0 / LearningRate;

        _labels = GetDatasetRegressionLabels(trainData);
@@ -355,9 +357,9 @@ public ObjectiveImpl(Dataset trainData, Options args)
            _labels[i] = 0;
        }

-        _index1 = 1 - args.Index;
-        _index2 = 2 - args.Index;
-        _maxClamp = Math.Abs(args.MaxTreeOutput);
+        _index1 = 1 - options.Index;
+        _index2 = 2 - options.Index;
+        _maxClamp = Math.Abs(options.MaxTreeOutput);
     }

     public void AdjustTreeOutputs(IChannel ch, RegressionTree tree, DocumentPartitioning partitioning, ScoreTracker trainingScores)

src/Microsoft.ML.FastTree/RandomForestClassification.cs

+6 -4
@@ -162,8 +162,10 @@ public FastForestClassification(IHostEnvironment env,
     /// <summary>
     /// Initializes a new instance of <see cref="FastForestClassification"/> by using the <see cref="Options"/> class.
     /// </summary>
-    public FastForestClassification(IHostEnvironment env, Options args)
-        : base(env, args, TrainerUtils.MakeBoolScalarLabel(args.LabelColumn))
+    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
+    /// <param name="options">Algorithm advanced settings.</param>
+    public FastForestClassification(IHostEnvironment env, Options options)
+        : base(env, options, TrainerUtils.MakeBoolScalarLabel(options.LabelColumn))
     {
     }

@@ -229,8 +231,8 @@ private sealed class ObjectiveFunctionImpl : RandomForestObjectiveFunction
     {
         private readonly bool[] _labels;

-        public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Options args)
-            : base(trainSet, args, args.MaxTreeOutput)
+        public ObjectiveFunctionImpl(Dataset trainSet, bool[] trainSetLabels, Options options)
+            : base(trainSet, options, options.MaxTreeOutput)
         {
             _labels = trainSetLabels;
         }

src/Microsoft.ML.FastTree/RandomForestRegression.cs

+17 -15
@@ -180,8 +180,10 @@ public FastForestRegression(IHostEnvironment env,
     /// <summary>
     /// Initializes a new instance of <see cref="FastForestRegression"/> by using the <see cref="Options"/> class.
     /// </summary>
-    public FastForestRegression(IHostEnvironment env, Options args)
-        : base(env, args, TrainerUtils.MakeR4ScalarColumn(args.LabelColumn), true)
+    /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
+    /// <param name="options">Algorithm advanced settings.</param>
+    public FastForestRegression(IHostEnvironment env, Options options)
+        : base(env, options, TrainerUtils.MakeR4ScalarColumn(options.LabelColumn), true)
     {
     }

@@ -237,15 +239,15 @@ private abstract class ObjectiveFunctionImplBase : RandomForestObjectiveFunction
     {
         private readonly float[] _labels;

-        public static ObjectiveFunctionImplBase Create(Dataset trainData, Options args)
+        public static ObjectiveFunctionImplBase Create(Dataset trainData, Options options)
         {
-            if (args.ShuffleLabels)
-                return new ShuffleImpl(trainData, args);
-            return new BasicImpl(trainData, args);
+            if (options.ShuffleLabels)
+                return new ShuffleImpl(trainData, options);
+            return new BasicImpl(trainData, options);
         }

-        private ObjectiveFunctionImplBase(Dataset trainData, Options args)
-            : base(trainData, args, double.MaxValue) // No notion of maximum step size.
+        private ObjectiveFunctionImplBase(Dataset trainData, Options options)
+            : base(trainData, options, double.MaxValue) // No notion of maximum step size.
         {
             _labels = FastTreeRegressionTrainer.GetDatasetRegressionLabels(trainData);
             Contracts.Assert(_labels.Length == trainData.NumDocs);
@@ -264,11 +266,11 @@ private sealed class ShuffleImpl : ObjectiveFunctionImplBase
         private readonly Random _rgen;
         private readonly int _labelLim;

-        public ShuffleImpl(Dataset trainData, Options args)
-            : base(trainData, args)
+        public ShuffleImpl(Dataset trainData, Options options)
+            : base(trainData, options)
         {
-            Contracts.AssertValue(args);
-            Contracts.Assert(args.ShuffleLabels);
+            Contracts.AssertValue(options);
+            Contracts.Assert(options.ShuffleLabels);

             _rgen = new Random(0); // Ideally we'd get this from the host.

@@ -277,7 +279,7 @@ public ShuffleImpl(Dataset trainData, Options args)
                var lab = _labels[i];
                if (!(0 <= lab && lab < Utils.ArrayMaxSize))
                {
-                    throw Contracts.ExceptUserArg(nameof(args.ShuffleLabels),
+                    throw Contracts.ExceptUserArg(nameof(options.ShuffleLabels),
                        "Label {0} for example {1} outside of allowed range" +
                        "[0,{2}) when doing shuffled labels", lab, i, Utils.ArrayMaxSize);
                }
@@ -302,8 +304,8 @@ public override double[] GetGradient(IChannel ch, double[] scores)

     private sealed class BasicImpl : ObjectiveFunctionImplBase
     {
-        public BasicImpl(Dataset trainData, Options args)
-            : base(trainData, args)
+        public BasicImpl(Dataset trainData, Options options)
+            : base(trainData, options)
         {
         }
     }
