15
15
using Microsoft . ML . Internal . Utilities ;
16
16
using Microsoft . ML . Runtime ;
17
17
18
- [ assembly: LoadableClass ( typeof ( RankingEvaluator ) , typeof ( RankingEvaluator ) , typeof ( RankingEvaluator . Arguments ) , typeof ( SignatureEvaluator ) ,
18
+ [ assembly: LoadableClass ( typeof ( RankingEvaluator ) , typeof ( RankingEvaluator ) , typeof ( RankingEvaluatorOptions ) , typeof ( SignatureEvaluator ) ,
19
19
"Ranking Evaluator" , RankingEvaluator . LoadName , "Ranking" , "rank" ) ]
20
20
21
21
[ assembly: LoadableClass ( typeof ( RankingMamlEvaluator ) , typeof ( RankingMamlEvaluator ) , typeof ( RankingMamlEvaluator . Arguments ) , typeof ( SignatureMamlEvaluator ) ,
26
26
27
27
namespace Microsoft . ML . Data
28
28
{
29
- [ BestFriend ]
30
- internal sealed class RankingEvaluator : EvaluatorBase < RankingEvaluator . Aggregator >
29
+ /// <summary>
30
+ /// Options to control the output of the RankingEvaluator
31
+ /// </summary>
32
+ public sealed class RankingEvaluatorOptions
31
33
{
32
- public sealed class Arguments
33
- {
34
- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Maximum truncation level for computing (N)DCG" , ShortName = "t" ) ]
35
- public int DcgTruncationLevel = 3 ;
34
+ /// <value>
35
+ /// Maximum truncation level for computing (N)DCG
36
+ /// </value>
37
+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Maximum truncation level for computing (N)DCG" , ShortName = "t" ) ]
38
+ public int DcgTruncationLevel = 3 ;
36
39
37
- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Label relevance gains" , ShortName = "gains" ) ]
38
- public string LabelGains = "0,3,7,15,31" ;
40
+ /// <value>
41
+ /// Label relevance gains
42
+ /// </value>
43
+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Label relevance gains" , ShortName = "gains" ) ]
44
+ public string LabelGains = "0,3,7,15,31" ;
39
45
40
- [ Argument ( ArgumentType . AtMostOnce , HelpText = "Generate per-group (N)DCG" , ShortName = "ogs" ) ]
41
- public bool OutputGroupSummary ;
42
- }
46
+ [ Argument ( ArgumentType . AtMostOnce , HelpText = "Generate per-group (N)DCG" , ShortName = "ogs" ) ]
47
+ internal bool OutputGroupSummary ;
48
+ }
43
49
50
+ [ BestFriend ]
51
+ internal sealed class RankingEvaluator : EvaluatorBase < RankingEvaluator . Aggregator >
52
+ {
44
53
internal const string LoadName = "RankingEvaluator" ;
45
54
46
55
public const string Ndcg = "NDCG" ;
@@ -60,24 +69,25 @@ public sealed class Arguments
60
69
private readonly bool _groupSummary ;
61
70
private readonly Double [ ] _labelGains ;
62
71
63
- public RankingEvaluator ( IHostEnvironment env , Arguments args )
72
+ public RankingEvaluator ( IHostEnvironment env , RankingEvaluatorOptions options )
64
73
: base ( env , LoadName )
65
74
{
66
75
// REVIEW: What kind of checking should be applied to labelGains?
67
- if ( args . DcgTruncationLevel <= 0 || args . DcgTruncationLevel > Aggregator . Counters . MaxTruncationLevel )
68
- throw Host . ExceptUserArg ( nameof ( args . DcgTruncationLevel ) , "DCG Truncation Level must be between 1 and {0}" , Aggregator . Counters . MaxTruncationLevel ) ;
69
- Host . CheckUserArg ( args . LabelGains != null , nameof ( args . LabelGains ) , "Label gains cannot be null" ) ;
76
+ // add the setter to utils here
77
+ if ( options . DcgTruncationLevel <= 0 )
78
+ throw Host . ExceptUserArg ( nameof ( options . DcgTruncationLevel ) , "DCG Truncation Level must be greater than 0" ) ;
79
+ Host . CheckUserArg ( options . LabelGains != null , nameof ( options . LabelGains ) , "Label gains cannot be null" ) ;
70
80
71
- _truncationLevel = args . DcgTruncationLevel ;
72
- _groupSummary = args . OutputGroupSummary ;
81
+ _truncationLevel = options . DcgTruncationLevel ;
82
+ _groupSummary = options . OutputGroupSummary ;
73
83
74
84
var labelGains = new List < Double > ( ) ;
75
- string [ ] gains = args . LabelGains . Split ( ',' ) ;
85
+ string [ ] gains = options . LabelGains . Split ( ',' ) ;
76
86
for ( int i = 0 ; i < gains . Length ; i ++ )
77
87
{
78
88
Double gain ;
79
89
if ( ! Double . TryParse ( gains [ i ] , out gain ) )
80
- throw Host . ExceptUserArg ( nameof ( args . LabelGains ) , "Label Gains must be of floating or integral type" , Aggregator . Counters . MaxTruncationLevel ) ;
90
+ throw Host . ExceptUserArg ( nameof ( options . LabelGains ) , "Label Gains must be of floating or integral type" ) ;
81
91
labelGains . Add ( gain ) ;
82
92
}
83
93
_labelGains = labelGains . ToArray ( ) ;
@@ -271,8 +281,6 @@ public sealed class Aggregator : AggregatorBase
271
281
{
272
282
public sealed class Counters
273
283
{
274
- public const int MaxTruncationLevel = 10 ;
275
-
276
284
public readonly int TruncationLevel ;
277
285
private readonly List < Double [ ] > _groupNdcg ;
278
286
private readonly List < Double [ ] > _groupDcg ;
@@ -287,6 +295,7 @@ public sealed class Counters
287
295
private readonly List < short > _queryLabels ;
288
296
private readonly List < Single > _queryOutputs ;
289
297
private readonly Double [ ] _labelGains ;
298
+ private readonly Double [ ] _discountMap ;
290
299
291
300
public bool GroupSummary { get { return _groupNdcg != null ; } }
292
301
@@ -348,6 +357,8 @@ public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
348
357
Contracts . AssertValue ( labelGains ) ;
349
358
350
359
TruncationLevel = truncationLevel ;
360
+ _discountMap = RankingUtils . GetDiscountMap ( truncationLevel ) ;
361
+
351
362
_sumDcgAtN = new Double [ TruncationLevel ] ;
352
363
_sumNdcgAtN = new Double [ TruncationLevel ] ;
353
364
@@ -373,15 +384,15 @@ public void Update(short label, Single output)
373
384
374
385
public void UpdateGroup ( Single weight )
375
386
{
376
- RankingUtils . QueryMaxDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupMaxDcgCur ) ;
387
+ RankingUtils . QueryMaxDcg ( _labelGains , TruncationLevel , _discountMap , _queryLabels , _queryOutputs , _groupMaxDcgCur ) ;
377
388
if ( _groupMaxDcg != null )
378
389
{
379
390
var maxDcg = new Double [ TruncationLevel ] ;
380
391
Array . Copy ( _groupMaxDcgCur , maxDcg , TruncationLevel ) ;
381
392
_groupMaxDcg . Add ( maxDcg ) ;
382
393
}
383
394
384
- RankingUtils . QueryDcg ( _labelGains , TruncationLevel , _queryLabels , _queryOutputs , _groupDcgCur ) ;
395
+ RankingUtils . QueryDcg ( _labelGains , TruncationLevel , _discountMap , _queryLabels , _queryOutputs , _groupDcgCur ) ;
385
396
if ( _groupDcg != null )
386
397
{
387
398
var groupDcg = new Double [ TruncationLevel ] ;
@@ -684,17 +695,19 @@ private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
684
695
685
696
private readonly Bindings _bindings ;
686
697
private readonly int _truncationLevel ;
698
+ private readonly Double [ ] _discountMap ;
687
699
private readonly Double [ ] _labelGains ;
688
700
689
701
public Transform ( IHostEnvironment env , IDataView input , string labelCol , string scoreCol , string groupCol ,
690
702
int truncationLevel , Double [ ] labelGains )
691
703
: base ( env , input , labelCol , scoreCol , groupCol , RegistrationName )
692
704
{
693
- Host . CheckParam ( 0 < truncationLevel && truncationLevel < 100 , nameof ( truncationLevel ) ,
694
- "Truncation level must be between 1 and 99 " ) ;
705
+ Host . CheckParam ( 0 < truncationLevel , nameof ( truncationLevel ) ,
706
+ "Truncation level must be greater than 0 " ) ;
695
707
Host . CheckValue ( labelGains , nameof ( labelGains ) ) ;
696
708
697
709
_truncationLevel = truncationLevel ;
710
+ _discountMap = RankingUtils . GetDiscountMap ( _truncationLevel ) ;
698
711
_labelGains = labelGains ;
699
712
_bindings = new Bindings ( Host , Source . Schema , true , LabelCol , ScoreCol , GroupCol , _truncationLevel ) ;
700
713
}
@@ -709,7 +722,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
709
722
// double[]: _labelGains
710
723
711
724
_truncationLevel = ctx . Reader . ReadInt32 ( ) ;
712
- Host . CheckDecode ( 0 < _truncationLevel && _truncationLevel < 100 ) ;
725
+ Host . CheckDecode ( 0 < _truncationLevel ) ;
713
726
_labelGains = ctx . Reader . ReadDoubleArray ( ) ;
714
727
_bindings = new Bindings ( Host , input . Schema , false , LabelCol , ScoreCol , GroupCol , _truncationLevel ) ;
715
728
}
@@ -725,7 +738,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
725
738
// double[]: _labelGains
726
739
727
740
base . SaveModel ( ctx ) ;
728
- Host . Assert ( 0 < _truncationLevel && _truncationLevel < 100 ) ;
741
+ Host . Assert ( 0 < _truncationLevel ) ;
729
742
ctx . Writer . Write ( _truncationLevel ) ;
730
743
ctx . Writer . WriteDoubleArray ( _labelGains ) ;
731
744
}
@@ -800,9 +813,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single
800
813
protected override void UpdateState ( RowCursorState state )
801
814
{
802
815
// Calculate the current group DCG, NDCG and MaxDcg.
803
- RankingUtils . QueryMaxDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs ,
816
+ RankingUtils . QueryMaxDcg ( _labelGains , _truncationLevel , _discountMap , state . QueryLabels , state . QueryOutputs ,
804
817
state . MaxDcgCur ) ;
805
- RankingUtils . QueryDcg ( _labelGains , _truncationLevel , state . QueryLabels , state . QueryOutputs , state . DcgCur ) ;
818
+ RankingUtils . QueryDcg ( _labelGains , _truncationLevel , _discountMap , state . QueryLabels , state . QueryOutputs , state . DcgCur ) ;
806
819
for ( int t = 0 ; t < _truncationLevel ; t ++ )
807
820
{
808
821
Double ndcg = state . MaxDcgCur [ t ] > 0 ? state . DcgCur [ t ] / state . MaxDcgCur [ t ] : 0 ;
@@ -823,7 +836,7 @@ public sealed class RowCursorState
823
836
824
837
public RowCursorState ( int truncationLevel )
825
838
{
826
- Contracts . Assert ( 0 < truncationLevel && truncationLevel < 100 ) ;
839
+ Contracts . Assert ( 0 < truncationLevel ) ;
827
840
828
841
QueryLabels = new List < short > ( ) ;
829
842
QueryOutputs = new List < Single > ( ) ;
@@ -867,12 +880,12 @@ public RankingMamlEvaluator(IHostEnvironment env, Arguments args)
867
880
Host . CheckValue ( args , nameof ( args ) ) ;
868
881
Utils . CheckOptionalUserDirectory ( args . GroupSummaryFilename , nameof ( args . GroupSummaryFilename ) ) ;
869
882
870
- var evalArgs = new RankingEvaluator . Arguments ( ) ;
871
- evalArgs . DcgTruncationLevel = args . DcgTruncationLevel ;
872
- evalArgs . LabelGains = args . LabelGains ;
873
- evalArgs . OutputGroupSummary = ! string . IsNullOrEmpty ( args . GroupSummaryFilename ) ;
883
+ var evalOpts = new RankingEvaluatorOptions ( ) ;
884
+ evalOpts . DcgTruncationLevel = args . DcgTruncationLevel ;
885
+ evalOpts . LabelGains = args . LabelGains ;
886
+ evalOpts . OutputGroupSummary = ! string . IsNullOrEmpty ( args . GroupSummaryFilename ) ;
874
887
875
- _evaluator = new RankingEvaluator ( Host , evalArgs ) ;
888
+ _evaluator = new RankingEvaluator ( Host , evalOpts ) ;
876
889
_groupSummaryFilename = args . GroupSummaryFilename ;
877
890
_groupIdCol = args . GroupIdColumn ;
878
891
}
@@ -946,30 +959,41 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
946
959
947
960
internal static class RankingUtils
948
961
{
949
- private static volatile Double [ ] _discountMap ;
950
- public static Double [ ] DiscountMap
962
+ // Truncation levels are typically less than 100. So we maintain a fixed discount map of size 100
963
+ // If truncation level greater than 100 is required, we build a new one and return that.
964
+ private const int FixedDiscountMapSize = 100 ;
965
+ private static Double [ ] _discountMapFixed ;
966
+
967
+ private static Double [ ] GetDiscountMapCore ( int truncationLevel )
951
968
{
952
- get
969
+ var discountMap = new Double [ truncationLevel ] ;
970
+
971
+ for ( int i = 0 ; i < discountMap . Length ; i ++ )
972
+ discountMap [ i ] = 1 / Math . Log ( 2 + i ) ;
973
+
974
+ return discountMap ;
975
+ }
976
+
977
+ public static Double [ ] GetDiscountMap ( int truncationLevel )
978
+ {
979
+ var discountMap = _discountMapFixed ;
980
+ if ( discountMap == null )
953
981
{
954
- double [ ] result = _discountMap ;
955
- if ( result == null )
956
- {
957
- var discountMap = new Double [ 100 ] ; //Hard to believe anyone would set truncation Level higher than 100
958
- for ( int i = 0 ; i < discountMap . Length ; i ++ )
959
- {
960
- discountMap [ i ] = 1 / Math . Log ( 2 + i ) ;
961
- }
962
- Interlocked . CompareExchange ( ref _discountMap , discountMap , null ) ;
963
- result = _discountMap ;
964
- }
965
- return result ;
982
+ discountMap = GetDiscountMapCore ( FixedDiscountMapSize ) ;
983
+ Interlocked . CompareExchange ( ref _discountMapFixed , discountMap , null ) ;
984
+ discountMap = _discountMapFixed ;
966
985
}
986
+
987
+ if ( truncationLevel <= discountMap . Length )
988
+ return discountMap ;
989
+
990
+ return GetDiscountMapCore ( truncationLevel ) ;
967
991
}
968
992
969
993
/// <summary>
970
994
/// Calculates natural-based max DCG at all truncations from 1 to truncationLevel.
971
995
/// </summary>
972
- public static void QueryMaxDcg ( Double [ ] labelGains , int truncationLevel ,
996
+ public static void QueryMaxDcg ( Double [ ] labelGains , int truncationLevel , Double [ ] discountMap ,
973
997
List < short > queryLabels , List < Single > queryOutputs , Double [ ] groupMaxDcgCur )
974
998
{
975
999
Contracts . Assert ( Utils . Size ( groupMaxDcgCur ) == truncationLevel ) ;
@@ -994,21 +1018,21 @@ public static void QueryMaxDcg(Double[] labelGains, int truncationLevel,
994
1018
while ( labelCounts [ topLabel ] == 0 )
995
1019
topLabel -- ;
996
1020
997
- groupMaxDcgCur [ 0 ] = labelGains [ topLabel ] * DiscountMap [ 0 ] ;
1021
+ groupMaxDcgCur [ 0 ] = labelGains [ topLabel ] * discountMap [ 0 ] ;
998
1022
labelCounts [ topLabel ] -- ;
999
1023
for ( int t = 1 ; t < maxTrunc ; t ++ )
1000
1024
{
1001
1025
while ( labelCounts [ topLabel ] == 0 )
1002
1026
topLabel -- ;
1003
- groupMaxDcgCur [ t ] = groupMaxDcgCur [ t - 1 ] + labelGains [ topLabel ] * DiscountMap [ t ] ;
1027
+ groupMaxDcgCur [ t ] = groupMaxDcgCur [ t - 1 ] + labelGains [ topLabel ] * discountMap [ t ] ;
1004
1028
labelCounts [ topLabel ] -- ;
1005
1029
}
1006
1030
for ( int t = maxTrunc ; t < truncationLevel ; t ++ )
1007
1031
groupMaxDcgCur [ t ] = groupMaxDcgCur [ t - 1 ] ;
1008
1032
}
1009
1033
}
1010
1034
1011
- public static void QueryDcg ( Double [ ] labelGains , int truncationLevel ,
1035
+ public static void QueryDcg ( Double [ ] labelGains , int truncationLevel , Double [ ] discountMap ,
1012
1036
List < short > queryLabels , List < Single > queryOutputs , Double [ ] groupDcgCur )
1013
1037
{
1014
1038
// calculate the permutation
@@ -1021,7 +1045,7 @@ public static void QueryDcg(Double[] labelGains, int truncationLevel,
1021
1045
Double dcg = 0 ;
1022
1046
for ( int t = 0 ; t < count ; ++ t )
1023
1047
{
1024
- dcg = dcg + labelGains [ queryLabels [ permutation [ t ] ] ] * DiscountMap [ t ] ;
1048
+ dcg = dcg + labelGains [ queryLabels [ permutation [ t ] ] ] * discountMap [ t ] ;
1025
1049
groupDcgCur [ t ] = dcg ;
1026
1050
}
1027
1051
for ( int t = count ; t < truncationLevel ; ++ t )
0 commit comments