-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Ranker train context and FastTree ranking extensions #1068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
f03e5c1
f9d629d
fdcfe0c
4857e80
4f257e6
95ea9da
52c4641
51dfd0a
fb8470e
a6da6c1
a29af1f
0b1f9b1
13e8d0b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,17 +41,17 @@ public sealed class Arguments | |
public bool OutputGroupSummary; | ||
} | ||
|
||
public const string LoadName = "RankingEvaluator"; | ||
internal const string LoadName = "RankingEvaluator"; | ||
|
||
public const string Ndcg = "NDCG"; | ||
public const string Dcg = "DCG"; | ||
public const string MaxDcg = "MaxDCG"; | ||
|
||
/// <summary> | ||
/// <value> | ||
/// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group. | ||
/// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one | ||
/// group in the scored data. | ||
/// </summary> | ||
/// </value> | ||
public const string GroupSummary = "GroupSummary"; | ||
|
||
private const string GroupId = "GroupId"; | ||
|
@@ -234,6 +234,40 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A | |
}; | ||
} | ||
|
||
/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="groupId">The name of the groupId column.</param>
/// <param name="score">The name of the predicted score column.</param>
/// <returns>The evaluation metrics for these outputs.</returns>
public Result Evaluate(IDataView data, string label, string groupId, string score)
{
    Host.CheckValue(data, nameof(data));
    Host.CheckNonEmpty(label, nameof(label));
    // The group column is bound as a role below, so validate its name up front
    // exactly like the label and score names.
    Host.CheckNonEmpty(groupId, nameof(groupId));
    Host.CheckNonEmpty(score, nameof(score));

    // Map the caller-supplied column names onto the roles the evaluator consumes.
    var roles = new RoleMappedData(data, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(label),
        RoleMappedSchema.ColumnRole.Group.Bind(groupId),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score));

    var resultDict = Evaluate(roles);
    Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
    var overall = resultDict[MetricKinds.OverallMetrics];

    // The overall-metrics view is expected to hold exactly one row: read it into
    // a Result and assert that nothing follows.
    Result result;
    using (var cursor = overall.GetRowCursor(i => true))
    {
        var moved = cursor.MoveNext();
        Host.Assert(moved);
        result = new Result(Host, cursor);
        moved = cursor.MoveNext();
        Host.Assert(!moved);
    }
    return result;
}
|
||
public sealed class Aggregator : AggregatorBase | ||
{ | ||
public sealed class Counters | ||
|
@@ -509,6 +543,48 @@ public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames) | |
slotNames = new VBuffer<ReadOnlyMemory<char>>(UnweightedCounters.TruncationLevel, values); | ||
} | ||
} | ||
|
||
public sealed class Result | ||
{ | ||
/// <summary> | ||
/// Normalized Discounted Cumulative Gain | ||
/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/ndcg.png"></a> | ||
/// </summary> | ||
public double[] Ndcg { get; } | ||
|
||
/// <summary> | ||
/// <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain">Discounted Cumulative gain</a> | ||
/// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1. | ||
/// Note that unlike the Wikipedia article, ML.Net uses the natural logarithm. | ||
/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/dcg.png"></a> | ||
/// </summary> | ||
public double[] Dcg { get; } | ||
|
||
/// <summary> | ||
/// MaxDcg is the value of <see cref="Dcg"/> when the documents are ordered in the ideal order from most relevant to least relevant. | ||
/// In case there are ties in scores, metrics are computed in a pessimistic fashion. In other words, if two or more results get the same score, | ||
/// for the purpose of computing DCG and NDCG they are ordered from least relevant to most relevant. | ||
/// </summary> | ||
public double[] MaxDcg { get; } | ||
|
||
private static T Fetch<T>(IExceptionContext ectx, IRow row, string name) | ||
{ | ||
if (!row.Schema.TryGetColumnIndex(name, out int col)) | ||
throw ectx.Except($"Could not find column '{name}'"); | ||
T val = default; | ||
row.GetGetter<T>(col)(ref val); | ||
return val; | ||
} | ||
|
||
internal Result(IExceptionContext ectx, IRow overallResult) | ||
{ | ||
VBuffer<double> Fetch(string name) => Fetch<VBuffer<double>>(ectx, overallResult, name); | ||
|
||
Dcg = Fetch(RankerEvaluator.Dcg).Values; | ||
Ndcg = Fetch(RankerEvaluator.Ndcg).Values; | ||
// MaxDcg = Fetch(RankerEvaluator.MaxDcg).Values; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Uncomment, and retrieve correctly. #Resolved There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed from metrics, since it is a dataset characteristic. In reply to: 220800890 [](ancestors = 220800890) |
||
} | ||
} | ||
} | ||
|
||
public sealed class RankerPerInstanceTransform : IDataTransform | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,48 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred | |
return rec.Output; | ||
} | ||
|
||
/// <summary> | ||
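/// FastTree <see cref="RankerContext"/> extension method. Trains a FastTree ranking model on the given label, features and groupId columns and outputs the predicted score column. | ||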
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Is this intended to be a placeholder? I'm not certain it's terribly helpful. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line appears to still be here. I like the additional line you've added though. In reply to: 221019258 [](ancestors = 221019258) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
/// </summary> | ||
/// <param name="ctx">The <see cref="RankerContext"/>.</param> | ||
/// <param name="label">The label column.</param> | ||
/// <param name="features">The features column.</param> | ||
/// <param name="groupId">The groupId column.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It's not really a name is it? The point of pigsty is you don't use names, you just pass in objects. #Closed |
||
/// <param name="weights">The optional weights column.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Maybe while you're at it, "the optional weights column" could make clear it's just fine to leave it |
||
/// <param name="numLeaves">The maximum number of leaves per decision tree.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
"Leaves to use" is kind of weird. Maybe "maximum number of leaves per decision tree" would be better. I might also prefer to have number of trees above number of leaves (that is, reverse the two args). #Closed |
||
/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param> | ||
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param> | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The term documents is a little unfortunate. It is a legacy of this package's roots in ranking documents for the web ranker, but nowhere else in ML.NET do we call the datapoints we have "documents" as far as I am aware. #Closed |
||
/// <param name="learningRate">The learning rate.</param> | ||
/// <param name="advancedSettings">Algorithm advanced settings.</param> | ||
/// <param name="onFit">A delegate that is called every time the | ||
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the | ||
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive | ||
/// the ranking model that was trained. Note that this action cannot change the result in any way; | ||
/// it is only a way for the caller to be informed about what was learnt.</param> | ||
/// <returns>The Score output column indicating the predicted value.</returns> | ||
public static Scalar<float> FastTree<TVal>(this RankerContext.RankerTrainers ctx, | ||
Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights = null, | ||
int numLeaves = Defaults.NumLeaves, | ||
int numTrees = Defaults.NumTrees, | ||
int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs, | ||
double learningRate = Defaults.LearningRates, | ||
Action<FastTreeRankingTrainer.Arguments> advancedSettings = null, | ||
Action<FastTreeRankingPredictor> onFit = null) | ||
{ | ||
CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit); | ||
|
||
var rec = new TrainerEstimatorReconciler.Ranker<TVal>( | ||
(env, labelName, featuresName, groupIdName, weightsName) => | ||
{ | ||
var trainer = new FastTreeRankingTrainer(env, labelName, featuresName, groupIdName, weightsName,advancedSettings); | ||
if (onFit != null) | ||
return trainer.WithOnFitDelegate(trans => onFit(trans.Model)); | ||
return trainer; | ||
}, label, features, groupId, weights); | ||
|
||
return rec.Score; | ||
} | ||
|
||
private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights, | ||
int numLeaves, | ||
int numTrees, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -451,5 +451,46 @@ public void LightGbmRegression() | |
Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5); | ||
Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity); | ||
} | ||
|
||
[Fact] | ||
public void FastTreeRanking() | ||
{ | ||
var env = new ConsoleEnvironment(seed: 0); | ||
var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename); | ||
var dataSource = new MultiFileSource(dataPath); | ||
|
||
var ctx = new RankerContext(env); | ||
|
||
var reader = TextLoader.CreateReader(env, | ||
c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)), | ||
separator: '\t', hasHeader: true); | ||
|
||
FastTreeRankingPredictor pred = null; | ||
|
||
var est = reader.MakeNewEstimator() | ||
.Append(r => (r.label, r.features, groupId: r.groupId.ToKey())) | ||
.Append(r => (r.label, r.groupId, score: ctx.Trainers.FastTree(r.label, r.features, r.groupId, onFit: (p) => { pred = p; }))); | ||
|
||
var pipe = reader.Append(est); | ||
|
||
Assert.Null(pred); | ||
var model = pipe.Fit(dataSource); | ||
Assert.NotNull(pred); | ||
|
||
var data = model.Read(dataSource); | ||
|
||
var metrics = ctx.Evaluate(data, r => r.label, r => r.groupId, r => r.score); | ||
Assert.NotNull(metrics); | ||
|
||
Assert.True(metrics.Ndcg.Length == metrics.Dcg.Length && metrics.Dcg.Length == 3); | ||
|
||
Assert.InRange(metrics.Dcg[0], 1.4, 1.6); | ||
Assert.InRange(metrics.Dcg[1], 1.4, 1.8); | ||
Assert.InRange(metrics.Dcg[2], 1.4, 1.8); | ||
|
||
Assert.InRange(metrics.Ndcg[0], 36.5, 37); | ||
Assert.InRange(metrics.Ndcg[1], 36.5, 37); | ||
Assert.InRange(metrics.Ndcg[2], 36.5, 37); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm. That's weird that they're all so close. Are we sure about this? They're not all identical, are they? If they are, that might be a sign that the group-id is different for each line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The values are: In reply to: 221056237 [](ancestors = 221056237) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm OK! A bit weird but we'll go with that. In reply to: 221145695 [](ancestors = 221145695,221056237) |
||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unlike