Skip to content

Commit 3f98485

Browse files
authored
Added RankingEvaluatorOptions and removed the truncation limit. (#4081)
Summary of changes: Added the RankingEvaluatorOptions class to control the output of evaluation. Removed the hard-coded limit on the maximum truncation level. Added corresponding unit tests and maml tests. Fixes #3993.
1 parent d3dca2e commit 3f98485

File tree

6 files changed

+233
-65
lines changed

6 files changed

+233
-65
lines changed

src/Microsoft.ML.Data/Evaluators/RankingEvaluator.cs

Lines changed: 80 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
using Microsoft.ML.Internal.Utilities;
1616
using Microsoft.ML.Runtime;
1717

18-
[assembly: LoadableClass(typeof(RankingEvaluator), typeof(RankingEvaluator), typeof(RankingEvaluator.Arguments), typeof(SignatureEvaluator),
18+
[assembly: LoadableClass(typeof(RankingEvaluator), typeof(RankingEvaluator), typeof(RankingEvaluatorOptions), typeof(SignatureEvaluator),
1919
"Ranking Evaluator", RankingEvaluator.LoadName, "Ranking", "rank")]
2020

2121
[assembly: LoadableClass(typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator),
@@ -26,21 +26,30 @@
2626

2727
namespace Microsoft.ML.Data
2828
{
29-
[BestFriend]
30-
internal sealed class RankingEvaluator : EvaluatorBase<RankingEvaluator.Aggregator>
29+
/// <summary>
30+
/// Options to control the output of the RankingEvaluator
31+
/// </summary>
32+
public sealed class RankingEvaluatorOptions
3133
{
32-
public sealed class Arguments
33-
{
34-
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")]
35-
public int DcgTruncationLevel = 3;
34+
/// <value>
35+
/// Maximum truncation level for computing (N)DCG
36+
/// </value>
37+
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")]
38+
public int DcgTruncationLevel = 3;
3639

37-
[Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")]
38-
public string LabelGains = "0,3,7,15,31";
40+
/// <value>
41+
/// Label relevance gains
42+
/// </value>
43+
[Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")]
44+
public string LabelGains = "0,3,7,15,31";
3945

40-
[Argument(ArgumentType.AtMostOnce, HelpText = "Generate per-group (N)DCG", ShortName = "ogs")]
41-
public bool OutputGroupSummary;
42-
}
46+
[Argument(ArgumentType.AtMostOnce, HelpText = "Generate per-group (N)DCG", ShortName = "ogs")]
47+
internal bool OutputGroupSummary;
48+
}
4349

50+
[BestFriend]
51+
internal sealed class RankingEvaluator : EvaluatorBase<RankingEvaluator.Aggregator>
52+
{
4453
internal const string LoadName = "RankingEvaluator";
4554

4655
public const string Ndcg = "NDCG";
@@ -60,24 +69,25 @@ public sealed class Arguments
6069
private readonly bool _groupSummary;
6170
private readonly Double[] _labelGains;
6271

63-
public RankingEvaluator(IHostEnvironment env, Arguments args)
72+
public RankingEvaluator(IHostEnvironment env, RankingEvaluatorOptions options)
6473
: base(env, LoadName)
6574
{
6675
// REVIEW: What kind of checking should be applied to labelGains?
67-
if (args.DcgTruncationLevel <= 0 || args.DcgTruncationLevel > Aggregator.Counters.MaxTruncationLevel)
68-
throw Host.ExceptUserArg(nameof(args.DcgTruncationLevel), "DCG Truncation Level must be between 1 and {0}", Aggregator.Counters.MaxTruncationLevel);
69-
Host.CheckUserArg(args.LabelGains != null, nameof(args.LabelGains), "Label gains cannot be null");
76+
// add the setter to utils here
77+
if (options.DcgTruncationLevel <= 0)
78+
throw Host.ExceptUserArg(nameof(options.DcgTruncationLevel), "DCG Truncation Level must be greater than 0");
79+
Host.CheckUserArg(options.LabelGains != null, nameof(options.LabelGains), "Label gains cannot be null");
7080

71-
_truncationLevel = args.DcgTruncationLevel;
72-
_groupSummary = args.OutputGroupSummary;
81+
_truncationLevel = options.DcgTruncationLevel;
82+
_groupSummary = options.OutputGroupSummary;
7383

7484
var labelGains = new List<Double>();
75-
string[] gains = args.LabelGains.Split(',');
85+
string[] gains = options.LabelGains.Split(',');
7686
for (int i = 0; i < gains.Length; i++)
7787
{
7888
Double gain;
7989
if (!Double.TryParse(gains[i], out gain))
80-
throw Host.ExceptUserArg(nameof(args.LabelGains), "Label Gains must be of floating or integral type", Aggregator.Counters.MaxTruncationLevel);
90+
throw Host.ExceptUserArg(nameof(options.LabelGains), "Label Gains must be of floating or integral type");
8191
labelGains.Add(gain);
8292
}
8393
_labelGains = labelGains.ToArray();
@@ -271,8 +281,6 @@ public sealed class Aggregator : AggregatorBase
271281
{
272282
public sealed class Counters
273283
{
274-
public const int MaxTruncationLevel = 10;
275-
276284
public readonly int TruncationLevel;
277285
private readonly List<Double[]> _groupNdcg;
278286
private readonly List<Double[]> _groupDcg;
@@ -287,6 +295,7 @@ public sealed class Counters
287295
private readonly List<short> _queryLabels;
288296
private readonly List<Single> _queryOutputs;
289297
private readonly Double[] _labelGains;
298+
private readonly Double[] _discountMap;
290299

291300
public bool GroupSummary { get { return _groupNdcg != null; } }
292301

@@ -348,6 +357,8 @@ public Counters(Double[] labelGains, int truncationLevel, bool groupSummary)
348357
Contracts.AssertValue(labelGains);
349358

350359
TruncationLevel = truncationLevel;
360+
_discountMap = RankingUtils.GetDiscountMap(truncationLevel);
361+
351362
_sumDcgAtN = new Double[TruncationLevel];
352363
_sumNdcgAtN = new Double[TruncationLevel];
353364

@@ -373,15 +384,15 @@ public void Update(short label, Single output)
373384

374385
public void UpdateGroup(Single weight)
375386
{
376-
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupMaxDcgCur);
387+
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur);
377388
if (_groupMaxDcg != null)
378389
{
379390
var maxDcg = new Double[TruncationLevel];
380391
Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel);
381392
_groupMaxDcg.Add(maxDcg);
382393
}
383394

384-
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupDcgCur);
395+
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur);
385396
if (_groupDcg != null)
386397
{
387398
var groupDcg = new Double[TruncationLevel];
@@ -684,17 +695,19 @@ private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst)
684695

685696
private readonly Bindings _bindings;
686697
private readonly int _truncationLevel;
698+
private readonly Double[] _discountMap;
687699
private readonly Double[] _labelGains;
688700

689701
public Transform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol,
690702
int truncationLevel, Double[] labelGains)
691703
: base(env, input, labelCol, scoreCol, groupCol, RegistrationName)
692704
{
693-
Host.CheckParam(0 < truncationLevel && truncationLevel < 100, nameof(truncationLevel),
694-
"Truncation level must be between 1 and 99");
705+
Host.CheckParam(0 < truncationLevel , nameof(truncationLevel),
706+
"Truncation level must be greater than 0");
695707
Host.CheckValue(labelGains, nameof(labelGains));
696708

697709
_truncationLevel = truncationLevel;
710+
_discountMap = RankingUtils.GetDiscountMap(_truncationLevel);
698711
_labelGains = labelGains;
699712
_bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, _truncationLevel);
700713
}
@@ -709,7 +722,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input)
709722
// double[]: _labelGains
710723

711724
_truncationLevel = ctx.Reader.ReadInt32();
712-
Host.CheckDecode(0 < _truncationLevel && _truncationLevel < 100);
725+
Host.CheckDecode(0 < _truncationLevel);
713726
_labelGains = ctx.Reader.ReadDoubleArray();
714727
_bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel);
715728
}
@@ -725,7 +738,7 @@ private protected override void SaveModel(ModelSaveContext ctx)
725738
// double[]: _labelGains
726739

727740
base.SaveModel(ctx);
728-
Host.Assert(0 < _truncationLevel && _truncationLevel < 100);
741+
Host.Assert(0 < _truncationLevel);
729742
ctx.Writer.Write(_truncationLevel);
730743
ctx.Writer.WriteDoubleArray(_labelGains);
731744
}
@@ -800,9 +813,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single
800813
protected override void UpdateState(RowCursorState state)
801814
{
802815
// Calculate the current group DCG, NDCG and MaxDcg.
803-
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs,
816+
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs,
804817
state.MaxDcgCur);
805-
RankingUtils.QueryDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, state.DcgCur);
818+
RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, state.DcgCur);
806819
for (int t = 0; t < _truncationLevel; t++)
807820
{
808821
Double ndcg = state.MaxDcgCur[t] > 0 ? state.DcgCur[t] / state.MaxDcgCur[t] : 0;
@@ -823,7 +836,7 @@ public sealed class RowCursorState
823836

824837
public RowCursorState(int truncationLevel)
825838
{
826-
Contracts.Assert(0 < truncationLevel && truncationLevel < 100);
839+
Contracts.Assert(0 < truncationLevel);
827840

828841
QueryLabels = new List<short>();
829842
QueryOutputs = new List<Single>();
@@ -867,12 +880,12 @@ public RankingMamlEvaluator(IHostEnvironment env, Arguments args)
867880
Host.CheckValue(args, nameof(args));
868881
Utils.CheckOptionalUserDirectory(args.GroupSummaryFilename, nameof(args.GroupSummaryFilename));
869882

870-
var evalArgs = new RankingEvaluator.Arguments();
871-
evalArgs.DcgTruncationLevel = args.DcgTruncationLevel;
872-
evalArgs.LabelGains = args.LabelGains;
873-
evalArgs.OutputGroupSummary = !string.IsNullOrEmpty(args.GroupSummaryFilename);
883+
var evalOpts = new RankingEvaluatorOptions();
884+
evalOpts.DcgTruncationLevel = args.DcgTruncationLevel;
885+
evalOpts.LabelGains = args.LabelGains;
886+
evalOpts.OutputGroupSummary = !string.IsNullOrEmpty(args.GroupSummaryFilename);
874887

875-
_evaluator = new RankingEvaluator(Host, evalArgs);
888+
_evaluator = new RankingEvaluator(Host, evalOpts);
876889
_groupSummaryFilename = args.GroupSummaryFilename;
877890
_groupIdCol = args.GroupIdColumn;
878891
}
@@ -946,30 +959,41 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM
946959

947960
internal static class RankingUtils
948961
{
949-
private static volatile Double[] _discountMap;
950-
public static Double[] DiscountMap
962+
// Truncation levels are typically less than 100. So we maintain a fixed discount map of size 100
963+
// If truncation level greater than 100 is required, we build a new one and return that.
964+
private const int FixedDiscountMapSize = 100;
965+
private static Double[] _discountMapFixed;
966+
967+
private static Double[] GetDiscountMapCore(int truncationLevel)
951968
{
952-
get
969+
var discountMap = new Double[truncationLevel];
970+
971+
for (int i = 0; i < discountMap.Length; i++)
972+
discountMap[i] = 1 / Math.Log(2 + i);
973+
974+
return discountMap;
975+
}
976+
977+
public static Double[] GetDiscountMap(int truncationLevel)
978+
{
979+
var discountMap = _discountMapFixed;
980+
if (discountMap == null)
953981
{
954-
double[] result = _discountMap;
955-
if (result == null)
956-
{
957-
var discountMap = new Double[100]; //Hard to believe anyone would set truncation Level higher than 100
958-
for (int i = 0; i < discountMap.Length; i++)
959-
{
960-
discountMap[i] = 1 / Math.Log(2 + i);
961-
}
962-
Interlocked.CompareExchange(ref _discountMap, discountMap, null);
963-
result = _discountMap;
964-
}
965-
return result;
982+
discountMap = GetDiscountMapCore(FixedDiscountMapSize);
983+
Interlocked.CompareExchange(ref _discountMapFixed, discountMap, null);
984+
discountMap = _discountMapFixed;
966985
}
986+
987+
if (truncationLevel <= discountMap.Length)
988+
return discountMap;
989+
990+
return GetDiscountMapCore(truncationLevel);
967991
}
968992

969993
/// <summary>
970994
/// Calculates natural-based max DCG at all truncations from 1 to truncationLevel.
971995
/// </summary>
972-
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel,
996+
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, Double[] discountMap,
973997
List<short> queryLabels, List<Single> queryOutputs, Double[] groupMaxDcgCur)
974998
{
975999
Contracts.Assert(Utils.Size(groupMaxDcgCur) == truncationLevel);
@@ -994,21 +1018,21 @@ public static void QueryMaxDcg(Double[] labelGains, int truncationLevel,
9941018
while (labelCounts[topLabel] == 0)
9951019
topLabel--;
9961020

997-
groupMaxDcgCur[0] = labelGains[topLabel] * DiscountMap[0];
1021+
groupMaxDcgCur[0] = labelGains[topLabel] * discountMap[0];
9981022
labelCounts[topLabel]--;
9991023
for (int t = 1; t < maxTrunc; t++)
10001024
{
10011025
while (labelCounts[topLabel] == 0)
10021026
topLabel--;
1003-
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * DiscountMap[t];
1027+
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * discountMap[t];
10041028
labelCounts[topLabel]--;
10051029
}
10061030
for (int t = maxTrunc; t < truncationLevel; t++)
10071031
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1];
10081032
}
10091033
}
10101034

1011-
public static void QueryDcg(Double[] labelGains, int truncationLevel,
1035+
public static void QueryDcg(Double[] labelGains, int truncationLevel, Double[] discountMap,
10121036
List<short> queryLabels, List<Single> queryOutputs, Double[] groupDcgCur)
10131037
{
10141038
// calculate the permutation
@@ -1021,7 +1045,7 @@ public static void QueryDcg(Double[] labelGains, int truncationLevel,
10211045
Double dcg = 0;
10221046
for (int t = 0; t < count; ++t)
10231047
{
1024-
dcg = dcg + labelGains[queryLabels[permutation[t]]] * DiscountMap[t];
1048+
dcg = dcg + labelGains[queryLabels[permutation[t]]] * discountMap[t];
10251049
groupDcgCur[t] = dcg;
10261050
}
10271051
for (int t = count; t < truncationLevel; ++t)

src/Microsoft.ML.Data/TrainCatalog.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,6 +638,21 @@ internal RankingTrainers(RankingCatalog catalog)
638638
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
639639
/// <returns>The evaluation results for the scored ranking data.</returns>
640640
public RankingMetrics Evaluate(IDataView data,
641+
string labelColumnName = DefaultColumnNames.Label,
642+
string rowGroupColumnName = DefaultColumnNames.GroupId,
643+
string scoreColumnName = DefaultColumnNames.Score) => Evaluate(data, null, labelColumnName, rowGroupColumnName, scoreColumnName);
644+
645+
/// <summary>
646+
/// Evaluates scored ranking data.
647+
/// </summary>
648+
/// <param name="data">The scored data.</param>
649+
/// <param name="options">Options to control the evaluation result.</param>
650+
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param>
651+
/// <param name="rowGroupColumnName">The name of the groupId column in <paramref name="data"/>.</param>
652+
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param>
653+
/// <returns>The evaluation results for the scored ranking data.</returns>
654+
public RankingMetrics Evaluate(IDataView data,
655+
RankingEvaluatorOptions options,
641656
string labelColumnName = DefaultColumnNames.Label,
642657
string rowGroupColumnName = DefaultColumnNames.GroupId,
643658
string scoreColumnName = DefaultColumnNames.Score)
@@ -647,7 +662,7 @@ public RankingMetrics Evaluate(IDataView data,
647662
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName));
648663
Environment.CheckNonEmpty(rowGroupColumnName, nameof(rowGroupColumnName));
649664

650-
var eval = new RankingEvaluator(Environment, new RankingEvaluator.Arguments() { });
665+
var eval = new RankingEvaluator(Environment, options ?? new RankingEvaluatorOptions() { });
651666
return eval.Evaluate(data, labelColumnName, rowGroupColumnName, scoreColumnName);
652667
}
653668
}

0 commit comments

Comments
 (0)