-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Added RankingEvaluatorOptions and removed the truncation limit. #4081
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 55 commits
8749a10
fe25bf6
cb446be
b5ee220
b9a7471
80e238d
2ef424d
3958f01
56d4595
00bc7ef
d0462f1
87cefbc
c3a908b
c0a430a
0b55903
56983d5
3382d1d
8ca5d01
4ac459e
8f20ea4
f9f9e1d
21cb8f3
138f201
55e3460
e43bba3
421d713
4f4f81c
89082a5
f167af8
0d4d34f
6cd2f15
1424ab3
3ee03ca
34b7a91
5539127
02053a6
35ad3c0
0eb3e2b
a3291b1
37af437
68f1f35
5b90a34
0efe238
b6584aa
7d47832
0e99776
20a4490
0d111f4
72d1a4d
ea9ebed
013be4f
d2ae365
a9e6db8
5855f99
724bb12
d009f55
8f7b6cd
30d56a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,7 @@ | |
using Microsoft.ML.Internal.Utilities; | ||
using Microsoft.ML.Runtime; | ||
|
||
[assembly: LoadableClass(typeof(RankingEvaluator), typeof(RankingEvaluator), typeof(RankingEvaluator.Arguments), typeof(SignatureEvaluator), | ||
[assembly: LoadableClass(typeof(RankingEvaluator), typeof(RankingEvaluator), typeof(RankingEvaluatorOptions), typeof(SignatureEvaluator), | ||
"Ranking Evaluator", RankingEvaluator.LoadName, "Ranking", "rank")] | ||
|
||
[assembly: LoadableClass(typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator), typeof(RankingMamlEvaluator.Arguments), typeof(SignatureMamlEvaluator), | ||
|
@@ -26,21 +26,24 @@ | |
|
||
namespace Microsoft.ML.Data | ||
{ | ||
[BestFriend] | ||
internal sealed class RankingEvaluator : EvaluatorBase<RankingEvaluator.Aggregator> | ||
/// <summary> | ||
/// Options to control the output of the RankingEvaluator | ||
/// </summary> | ||
public sealed class RankingEvaluatorOptions | ||
{ | ||
public sealed class Arguments | ||
{ | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")] | ||
public int DcgTruncationLevel = 3; | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Maximum truncation level for computing (N)DCG", ShortName = "t")] | ||
public int DcgTruncationLevel = 3; | ||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")] | ||
public string LabelGains = "0,3,7,15,31"; | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Label relevance gains", ShortName = "gains")] | ||
public string LabelGains = "0,3,7,15,31"; | ||
|
||
[Argument(ArgumentType.AtMostOnce, HelpText = "Generate per-group (N)DCG", ShortName = "ogs")] | ||
public bool OutputGroupSummary; | ||
} | ||
[Argument(ArgumentType.AtMostOnce, HelpText = "Generate per-group (N)DCG", ShortName = "ogs")] | ||
internal bool OutputGroupSummary; | ||
} | ||
|
||
[BestFriend] | ||
internal sealed class RankingEvaluator : EvaluatorBase<RankingEvaluator.Aggregator> | ||
{ | ||
internal const string LoadName = "RankingEvaluator"; | ||
|
||
public const string Ndcg = "NDCG"; | ||
|
@@ -60,24 +63,25 @@ public sealed class Arguments | |
private readonly bool _groupSummary; | ||
private readonly Double[] _labelGains; | ||
|
||
public RankingEvaluator(IHostEnvironment env, Arguments args) | ||
public RankingEvaluator(IHostEnvironment env, RankingEvaluatorOptions options) | ||
: base(env, LoadName) | ||
{ | ||
// REVIEW: What kind of checking should be applied to labelGains? | ||
if (args.DcgTruncationLevel <= 0 || args.DcgTruncationLevel > Aggregator.Counters.MaxTruncationLevel) | ||
throw Host.ExceptUserArg(nameof(args.DcgTruncationLevel), "DCG Truncation Level must be between 1 and {0}", Aggregator.Counters.MaxTruncationLevel); | ||
Host.CheckUserArg(args.LabelGains != null, nameof(args.LabelGains), "Label gains cannot be null"); | ||
// add the setter to utils here | ||
if (options.DcgTruncationLevel <= 0) | ||
throw Host.ExceptUserArg(nameof(options.DcgTruncationLevel), "DCG Truncation Level must be greater than 0"); | ||
Host.CheckUserArg(options.LabelGains != null, nameof(options.LabelGains), "Label gains cannot be null"); | ||
|
||
_truncationLevel = args.DcgTruncationLevel; | ||
_groupSummary = args.OutputGroupSummary; | ||
_truncationLevel = options.DcgTruncationLevel; | ||
_groupSummary = options.OutputGroupSummary; | ||
|
||
var labelGains = new List<Double>(); | ||
string[] gains = args.LabelGains.Split(','); | ||
string[] gains = options.LabelGains.Split(','); | ||
for (int i = 0; i < gains.Length; i++) | ||
{ | ||
Double gain; | ||
if (!Double.TryParse(gains[i], out gain)) | ||
throw Host.ExceptUserArg(nameof(args.LabelGains), "Label Gains must be of floating or integral type", Aggregator.Counters.MaxTruncationLevel); | ||
throw Host.ExceptUserArg(nameof(options.LabelGains), "Label Gains must be of floating or integral type"); | ||
labelGains.Add(gain); | ||
} | ||
_labelGains = labelGains.ToArray(); | ||
|
@@ -271,8 +275,6 @@ public sealed class Aggregator : AggregatorBase | |
{ | ||
public sealed class Counters | ||
{ | ||
public const int MaxTruncationLevel = 10; | ||
|
||
public readonly int TruncationLevel; | ||
private readonly List<Double[]> _groupNdcg; | ||
private readonly List<Double[]> _groupDcg; | ||
|
@@ -287,6 +289,7 @@ public sealed class Counters | |
private readonly List<short> _queryLabels; | ||
private readonly List<Single> _queryOutputs; | ||
private readonly Double[] _labelGains; | ||
private readonly Double[] _discountMap; | ||
|
||
public bool GroupSummary { get { return _groupNdcg != null; } } | ||
|
||
|
@@ -348,6 +351,8 @@ public Counters(Double[] labelGains, int truncationLevel, bool groupSummary) | |
Contracts.AssertValue(labelGains); | ||
|
||
TruncationLevel = truncationLevel; | ||
_discountMap = RankingUtils.GetDiscountMap(truncationLevel); | ||
|
||
_sumDcgAtN = new Double[TruncationLevel]; | ||
_sumNdcgAtN = new Double[TruncationLevel]; | ||
|
||
|
@@ -373,15 +378,15 @@ public void Update(short label, Single output) | |
|
||
public void UpdateGroup(Single weight) | ||
{ | ||
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupMaxDcgCur); | ||
RankingUtils.QueryMaxDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupMaxDcgCur); | ||
if (_groupMaxDcg != null) | ||
{ | ||
var maxDcg = new Double[TruncationLevel]; | ||
Array.Copy(_groupMaxDcgCur, maxDcg, TruncationLevel); | ||
_groupMaxDcg.Add(maxDcg); | ||
} | ||
|
||
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _queryLabels, _queryOutputs, _groupDcgCur); | ||
RankingUtils.QueryDcg(_labelGains, TruncationLevel, _discountMap, _queryLabels, _queryOutputs, _groupDcgCur); | ||
if (_groupDcg != null) | ||
{ | ||
var groupDcg = new Double[TruncationLevel]; | ||
|
@@ -684,17 +689,19 @@ private void SlotNamesGetter(int iinfo, ref VBuffer<ReadOnlyMemory<char>> dst) | |
|
||
private readonly Bindings _bindings; | ||
private readonly int _truncationLevel; | ||
private readonly Double[] _discountMap; | ||
private readonly Double[] _labelGains; | ||
|
||
public Transform(IHostEnvironment env, IDataView input, string labelCol, string scoreCol, string groupCol, | ||
int truncationLevel, Double[] labelGains) | ||
: base(env, input, labelCol, scoreCol, groupCol, RegistrationName) | ||
{ | ||
Host.CheckParam(0 < truncationLevel && truncationLevel < 100, nameof(truncationLevel), | ||
"Truncation level must be between 1 and 99"); | ||
Host.CheckParam(0 < truncationLevel , nameof(truncationLevel), | ||
"Truncation level must be greater than 0"); | ||
Host.CheckValue(labelGains, nameof(labelGains)); | ||
|
||
_truncationLevel = truncationLevel; | ||
_discountMap = RankingUtils.GetDiscountMap(_truncationLevel); | ||
_labelGains = labelGains; | ||
_bindings = new Bindings(Host, Source.Schema, true, LabelCol, ScoreCol, GroupCol, _truncationLevel); | ||
} | ||
|
@@ -709,7 +716,7 @@ public Transform(IHostEnvironment env, ModelLoadContext ctx, IDataView input) | |
// double[]: _labelGains | ||
|
||
_truncationLevel = ctx.Reader.ReadInt32(); | ||
Host.CheckDecode(0 < _truncationLevel && _truncationLevel < 100); | ||
Host.CheckDecode(0 < _truncationLevel); | ||
_labelGains = ctx.Reader.ReadDoubleArray(); | ||
_bindings = new Bindings(Host, input.Schema, false, LabelCol, ScoreCol, GroupCol, _truncationLevel); | ||
} | ||
|
@@ -725,7 +732,7 @@ private protected override void SaveModel(ModelSaveContext ctx) | |
// double[]: _labelGains | ||
|
||
base.SaveModel(ctx); | ||
Host.Assert(0 < _truncationLevel && _truncationLevel < 100); | ||
Host.Assert(0 < _truncationLevel); | ||
ctx.Writer.Write(_truncationLevel); | ||
ctx.Writer.WriteDoubleArray(_labelGains); | ||
} | ||
|
@@ -800,9 +807,9 @@ protected override void ProcessExample(RowCursorState state, short label, Single | |
protected override void UpdateState(RowCursorState state) | ||
{ | ||
// Calculate the current group DCG, NDCG and MaxDcg. | ||
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, | ||
RankingUtils.QueryMaxDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, | ||
state.MaxDcgCur); | ||
RankingUtils.QueryDcg(_labelGains, _truncationLevel, state.QueryLabels, state.QueryOutputs, state.DcgCur); | ||
RankingUtils.QueryDcg(_labelGains, _truncationLevel, _discountMap, state.QueryLabels, state.QueryOutputs, state.DcgCur); | ||
for (int t = 0; t < _truncationLevel; t++) | ||
{ | ||
Double ndcg = state.MaxDcgCur[t] > 0 ? state.DcgCur[t] / state.MaxDcgCur[t] : 0; | ||
|
@@ -823,7 +830,7 @@ public sealed class RowCursorState | |
|
||
public RowCursorState(int truncationLevel) | ||
{ | ||
Contracts.Assert(0 < truncationLevel && truncationLevel < 100); | ||
Contracts.Assert(0 < truncationLevel); | ||
|
||
QueryLabels = new List<short>(); | ||
QueryOutputs = new List<Single>(); | ||
|
@@ -867,12 +874,12 @@ public RankingMamlEvaluator(IHostEnvironment env, Arguments args) | |
Host.CheckValue(args, nameof(args)); | ||
Utils.CheckOptionalUserDirectory(args.GroupSummaryFilename, nameof(args.GroupSummaryFilename)); | ||
|
||
var evalArgs = new RankingEvaluator.Arguments(); | ||
evalArgs.DcgTruncationLevel = args.DcgTruncationLevel; | ||
evalArgs.LabelGains = args.LabelGains; | ||
evalArgs.OutputGroupSummary = !string.IsNullOrEmpty(args.GroupSummaryFilename); | ||
var evalOpts = new RankingEvaluatorOptions(); | ||
evalOpts.DcgTruncationLevel = args.DcgTruncationLevel; | ||
evalOpts.LabelGains = args.LabelGains; | ||
evalOpts.OutputGroupSummary = !string.IsNullOrEmpty(args.GroupSummaryFilename); | ||
|
||
_evaluator = new RankingEvaluator(Host, evalArgs); | ||
_evaluator = new RankingEvaluator(Host, evalOpts); | ||
_groupSummaryFilename = args.GroupSummaryFilename; | ||
_groupIdCol = args.GroupIdColumn; | ||
} | ||
|
@@ -946,30 +953,41 @@ private protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleM | |
|
||
internal static class RankingUtils | ||
{ | ||
private static volatile Double[] _discountMap; | ||
public static Double[] DiscountMap | ||
// Truncation levels are typically less than 100. So we maintain a fixed discount map of size 100 | ||
// If truncation level greater than 100 is required, we build a new one and return that. | ||
private const int FixedDiscountMapSize = 100; | ||
private static Double[] _discountMapFixed; | ||
|
||
private static Double[] GetDiscountMapCore(int truncationLevel) | ||
{ | ||
var discountMap = new Double[FixedDiscountMapSize]; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should use the truncationLevel passed into the function. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
for (int i = 0; i < discountMap.Length; i++) | ||
discountMap[i] = 1 / Math.Log(2 + i); | ||
|
||
return discountMap; | ||
} | ||
|
||
public static Double[] GetDiscountMap(int truncationLevel) | ||
{ | ||
get | ||
var discountMap = _discountMapFixed; | ||
if (discountMap == null) | ||
{ | ||
double[] result = _discountMap; | ||
if (result == null) | ||
{ | ||
var discountMap = new Double[100]; //Hard to believe anyone would set truncation Level higher than 100 | ||
for (int i = 0; i < discountMap.Length; i++) | ||
{ | ||
discountMap[i] = 1 / Math.Log(2 + i); | ||
} | ||
Interlocked.CompareExchange(ref _discountMap, discountMap, null); | ||
result = _discountMap; | ||
} | ||
return result; | ||
discountMap = GetDiscountMapCore(truncationLevel); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. First check if the requested level is small enough for the “fixed” map, if it is small enough, use the cached one. If the cached one hasn’t been created yet, then create it using ‘FixedDiscountMapSize’. #Resolved |
||
Interlocked.CompareExchange(ref _discountMapFixed, discountMap, null); | ||
discountMap = _discountMapFixed; | ||
} | ||
|
||
if (truncationLevel <= discountMap.Length) | ||
return discountMap; | ||
|
||
return GetDiscountMapCore(truncationLevel); | ||
} | ||
|
||
/// <summary> | ||
/// Calculates natural-based max DCG at all truncations from 1 to truncationLevel. | ||
/// </summary> | ||
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, | ||
public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, Double[] discountMap, | ||
List<short> queryLabels, List<Single> queryOutputs, Double[] groupMaxDcgCur) | ||
{ | ||
Contracts.Assert(Utils.Size(groupMaxDcgCur) == truncationLevel); | ||
|
@@ -994,21 +1012,21 @@ public static void QueryMaxDcg(Double[] labelGains, int truncationLevel, | |
while (labelCounts[topLabel] == 0) | ||
topLabel--; | ||
|
||
groupMaxDcgCur[0] = labelGains[topLabel] * DiscountMap[0]; | ||
groupMaxDcgCur[0] = labelGains[topLabel] * discountMap[0]; | ||
labelCounts[topLabel]--; | ||
for (int t = 1; t < maxTrunc; t++) | ||
{ | ||
while (labelCounts[topLabel] == 0) | ||
topLabel--; | ||
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * DiscountMap[t]; | ||
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1] + labelGains[topLabel] * discountMap[t]; | ||
labelCounts[topLabel]--; | ||
} | ||
for (int t = maxTrunc; t < truncationLevel; t++) | ||
groupMaxDcgCur[t] = groupMaxDcgCur[t - 1]; | ||
} | ||
} | ||
|
||
public static void QueryDcg(Double[] labelGains, int truncationLevel, | ||
public static void QueryDcg(Double[] labelGains, int truncationLevel, Double[] discountMap, | ||
List<short> queryLabels, List<Single> queryOutputs, Double[] groupDcgCur) | ||
{ | ||
// calculate the permutation | ||
|
@@ -1021,7 +1039,7 @@ public static void QueryDcg(Double[] labelGains, int truncationLevel, | |
Double dcg = 0; | ||
for (int t = 0; t < count; ++t) | ||
{ | ||
dcg = dcg + labelGains[queryLabels[permutation[t]]] * DiscountMap[t]; | ||
dcg = dcg + labelGains[queryLabels[permutation[t]]] * discountMap[t]; | ||
groupDcgCur[t] = dcg; | ||
} | ||
for (int t = count; t < truncationLevel; ++t) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -638,6 +638,21 @@ internal RankingTrainers(RankingCatalog catalog) | |
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param> | ||
/// <returns>The evaluation results for these calibrated outputs.</returns> | ||
public RankingMetrics Evaluate(IDataView data, | ||
string labelColumnName = DefaultColumnNames.Label, | ||
string rowGroupColumnName = DefaultColumnNames.GroupId, | ||
string scoreColumnName = DefaultColumnNames.Score) => Evaluate(data, null, labelColumnName, rowGroupColumnName, scoreColumnName); | ||
|
||
/// <summary> | ||
/// Evaluates scored ranking data. | ||
/// </summary> | ||
/// <param name="data">The scored data.</param> | ||
/// <param name="options">Options to control the evaluation result.</param> | ||
/// <param name="labelColumnName">The name of the label column in <paramref name="data"/>.</param> | ||
/// <param name="rowGroupColumnName">The name of the groupId column in <paramref name="data"/>.</param> | ||
/// <param name="scoreColumnName">The name of the score column in <paramref name="data"/>.</param> | ||
/// <returns>The evaluation results for these calibrated outputs.</returns> | ||
public RankingMetrics Evaluate(IDataView data, | ||
RankingEvaluatorOptions options, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add at least one new test for this new API? #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. And ensure there is an existing unit test covering the existing MAML interface for It should look about like: Specifically the The output should end in a block of DCG/NDCG@N metrics:
And it should match the |
||
string labelColumnName = DefaultColumnNames.Label, | ||
string rowGroupColumnName = DefaultColumnNames.GroupId, | ||
string scoreColumnName = DefaultColumnNames.Score) | ||
|
@@ -647,7 +662,7 @@ public RankingMetrics Evaluate(IDataView data, | |
Environment.CheckNonEmpty(scoreColumnName, nameof(scoreColumnName)); | ||
Environment.CheckNonEmpty(rowGroupColumnName, nameof(rowGroupColumnName)); | ||
|
||
var eval = new RankingEvaluator(Environment, new RankingEvaluator.Arguments() { }); | ||
var eval = new RankingEvaluator(Environment, options ?? new RankingEvaluatorOptions() { }); | ||
return eval.Evaluate(data, labelColumnName, rowGroupColumnName, scoreColumnName); | ||
} | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.