@@ -88,8 +88,10 @@ public FastTreeRankingTrainer(IHostEnvironment env,
88
88
/// <summary>
89
89
/// Initializes a new instance of <see cref="FastTreeRankingTrainer"/> by using the <see cref="Options"/> class.
90
90
/// </summary>
91
- public FastTreeRankingTrainer ( IHostEnvironment env , Options args )
92
- : base ( env , args , TrainerUtils . MakeR4ScalarColumn ( args . LabelColumn ) )
91
+ /// <param name="env">The instance of <see cref="IHostEnvironment"/>.</param>
92
+ /// <param name="options">Algorithm advanced settings.</param>
93
+ public FastTreeRankingTrainer ( IHostEnvironment env , Options options )
94
+ : base ( env , options , TrainerUtils . MakeR4ScalarColumn ( options . LabelColumn ) )
93
95
{
94
96
}
95
97
@@ -546,14 +548,14 @@ private enum DupeIdInfo
546
548
// Keeps track of labels of top 3 documents per query
547
549
public short [ ] [ ] TrainQueriesTopLabels ;
548
550
549
- public LambdaRankObjectiveFunction ( Dataset trainset , short [ ] labels , Options args , IParallelTraining parallelTraining )
551
+ public LambdaRankObjectiveFunction ( Dataset trainset , short [ ] labels , Options options , IParallelTraining parallelTraining )
550
552
: base ( trainset ,
551
- args . LearningRates ,
552
- args . Shrinkage ,
553
- args . MaxTreeOutput ,
554
- args . GetDerivativesSampleRate ,
555
- args . BestStepRankingRegressionTrees ,
556
- args . RngSeed )
553
+ options . LearningRates ,
554
+ options . Shrinkage ,
555
+ options . MaxTreeOutput ,
556
+ options . GetDerivativesSampleRate ,
557
+ options . BestStepRankingRegressionTrees ,
558
+ options . RngSeed )
557
559
{
558
560
559
561
_labels = labels ;
@@ -567,8 +569,8 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
567
569
_labelCounts [ q ] = new int [ relevancyLevel ] ;
568
570
569
571
// precomputed arrays
570
- _maxDcgTruncationLevel = args . LambdaMartMaxTruncation ;
571
- _trainDcg = args . TrainDcg ;
572
+ _maxDcgTruncationLevel = options . LambdaMartMaxTruncation ;
573
+ _trainDcg = options . TrainDcg ;
572
574
if ( _trainDcg )
573
575
{
574
576
_inverseMaxDcgt = new double [ Dataset . NumQueries ] ;
@@ -583,7 +585,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
583
585
}
584
586
585
587
_discount = new double [ Dataset . MaxDocsPerQuery ] ;
586
- FillDiscounts ( args . PositionDiscountFreeform ) ;
588
+ FillDiscounts ( options . PositionDiscountFreeform ) ;
587
589
588
590
_oneTwoThree = new int [ Dataset . MaxDocsPerQuery ] ;
589
591
for ( int d = 0 ; d < Dataset . MaxDocsPerQuery ; ++ d )
@@ -593,7 +595,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
593
595
int numThreads = BlockingThreadPool . NumThreads ;
594
596
_comparers = new DcgPermutationComparer [ numThreads ] ;
595
597
for ( int i = 0 ; i < numThreads ; ++ i )
596
- _comparers [ i ] = DcgPermutationComparerFactory . GetDcgPermutationFactory ( args . SortingAlgorithm ) ;
598
+ _comparers [ i ] = DcgPermutationComparerFactory . GetDcgPermutationFactory ( options . SortingAlgorithm ) ;
597
599
598
600
_permutationBuffers = new int [ numThreads ] [ ] ;
599
601
for ( int i = 0 ; i < numThreads ; ++ i )
@@ -603,13 +605,13 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
603
605
FillGainLabels ( ) ;
604
606
605
607
#region parameters
606
- _sigmoidParam = args . LearningRates ;
607
- _costFunctionParam = args . CostFunctionParam ;
608
- _distanceWeight2 = args . DistanceWeight2 ;
609
- _normalizeQueryLambdas = args . NormalizeQueryLambdas ;
608
+ _sigmoidParam = options . LearningRates ;
609
+ _costFunctionParam = options . CostFunctionParam ;
610
+ _distanceWeight2 = options . DistanceWeight2 ;
611
+ _normalizeQueryLambdas = options . NormalizeQueryLambdas ;
610
612
611
- _useShiftedNdcg = args . ShiftedNdcg ;
612
- _filterZeroLambdas = args . FilterZeroLambdas ;
613
+ _useShiftedNdcg = options . ShiftedNdcg ;
614
+ _filterZeroLambdas = options . FilterZeroLambdas ;
613
615
#endregion
614
616
615
617
_scoresCopy = new double [ Dataset . NumDocs ] ;
@@ -620,7 +622,7 @@ public LambdaRankObjectiveFunction(Dataset trainset, short[] labels, Options arg
620
622
#if OLD_DATALOAD
621
623
SetupSecondaryGains ( cmd ) ;
622
624
#endif
623
- SetupBaselineRisk ( args ) ;
625
+ SetupBaselineRisk ( options ) ;
624
626
_parallelTraining = parallelTraining ;
625
627
}
626
628
@@ -644,18 +646,18 @@ private void SetupSecondaryGains(Arguments args)
644
646
}
645
647
#endif
646
648
647
- private void SetupBaselineRisk ( Options args )
649
+ private void SetupBaselineRisk ( Options options )
648
650
{
649
651
double [ ] scores = Dataset . Skeleton . GetData < double > ( "BaselineScores" ) ;
650
652
if ( scores == null )
651
653
return ;
652
654
653
655
// Calculate the DCG with the discounts as they exist in the objective function (this
654
656
// can differ versus the actual DCG discount)
655
- DcgCalculator calc = new DcgCalculator ( Dataset . MaxDocsPerQuery , args . SortingAlgorithm ) ;
657
+ DcgCalculator calc = new DcgCalculator ( Dataset . MaxDocsPerQuery , options . SortingAlgorithm ) ;
656
658
_baselineDcg = calc . DcgFromScores ( Dataset , scores , _discount ) ;
657
659
658
- IniFileParserInterface ffi = IniFileParserInterface . CreateFromFreeform ( string . IsNullOrEmpty ( args . BaselineAlphaRisk ) ? "0" : args . BaselineAlphaRisk ) ;
660
+ IniFileParserInterface ffi = IniFileParserInterface . CreateFromFreeform ( string . IsNullOrEmpty ( options . BaselineAlphaRisk ) ? "0" : options . BaselineAlphaRisk ) ;
659
661
IniFileParserInterface . FeatureEvaluator ffe = ffi . GetFeatureEvaluators ( ) [ 0 ] ;
660
662
IniFileParserInterface . FeatureMap ffmap = ffi . GetFeatureMap ( ) ;
661
663
string [ ] ffnames = Enumerable . Range ( 0 , ffmap . RawFeatureCount )
@@ -672,7 +674,7 @@ private void SetupBaselineRisk(Options args)
672
674
uint [ ] vals = new uint [ ffmap . RawFeatureCount ] ;
673
675
int iInd = Array . IndexOf ( ffnames , "I" ) ;
674
676
int tInd = Array . IndexOf ( ffnames , "T" ) ;
675
- int totalTrees = args . NumTrees ;
677
+ int totalTrees = options . NumTrees ;
676
678
if ( tInd >= 0 )
677
679
vals [ tInd ] = ( uint ) totalTrees ;
678
680
_baselineAlpha = Enumerable . Range ( 0 , totalTrees ) . Select ( i =>
0 commit comments