Skip to content

Ranker train context and FastTree ranking extensions #1068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 29, 2018
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/images/DCG.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/NDCG.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,40 @@ public static RegressionEvaluator.Result Evaluate<T>(
args.LossFunction = new TrivialRegressionLossFactory(loss);
return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName);
}

/// <summary>
/// Computes ranking metrics over a statically-typed, already-scored data view.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <typeparam name="TVal">The type of the group values, before being converted to a key.</typeparam>
/// <param name="ctx">The ranking context.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">Delegate that picks the label column out of the shape.</param>
/// <param name="groupId">Delegate that picks the groupId column out of the shape.</param>
/// <param name="score">Delegate that picks the predicted score column out of the shape.</param>
/// <returns>The evaluation metrics.</returns>
public static RankerEvaluator.Result Evaluate<T, TVal>(
    this RankerContext ctx,
    DataView<T> data,
    Func<T, Scalar<float>> label,
    Func<T, Key<uint, TVal>> groupId,
    Func<T, Scalar<float>> score)
{
    Contracts.CheckValue(data, nameof(data));

    // The environment is recovered from the data view; once we have it, it is used for the remaining checks.
    var host = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(host);
    host.CheckValue(label, nameof(label));
    host.CheckValue(groupId, nameof(groupId));
    host.CheckValue(score, nameof(score));

    // Translate each column selected by the delegates into its column name.
    var columns = StaticPipeUtils.GetIndexer(data);
    string labelColumn = columns.Get(label(columns.Indices));
    string scoreColumn = columns.Get(score(columns.Indices));
    string groupIdColumn = columns.Get(groupId(columns.Indices));

    var evaluator = new RankerEvaluator(host, new RankerEvaluator.Arguments() { });
    return evaluator.Evaluate(data.AsDynamic, labelColumn, groupIdColumn, scoreColumn);
}
}
}
82 changes: 79 additions & 3 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ public sealed class Arguments
public bool OutputGroupSummary;
}

public const string LoadName = "RankingEvaluator";
internal const string LoadName = "RankingEvaluator";

public const string Ndcg = "NDCG";
public const string Dcg = "DCG";
public const string MaxDcg = "MaxDCG";

/// <summary>
/// <value>
/// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
/// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one
/// group in the scored data.
/// </summary>
/// </value>
public const string GroupSummary = "GroupSummary";

private const string GroupId = "GroupId";
Expand Down Expand Up @@ -234,6 +234,40 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A
};
}

/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="groupId">The name of the groupId column.</param>
/// <param name="score">The name of the predicted score column.</param>
/// <returns>The evaluation metrics for these outputs.</returns>
public Result Evaluate(IDataView data, string label, string groupId, string score)
{
    Host.CheckValue(data, nameof(data));
    Host.CheckNonEmpty(label, nameof(label));
    // groupId is bound to the Group role below, so validate it the same way as the other column names.
    Host.CheckNonEmpty(groupId, nameof(groupId));
    Host.CheckNonEmpty(score, nameof(score));
    var roles = new RoleMappedData(data, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(label),
        RoleMappedSchema.ColumnRole.Group.Bind(groupId),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score));

    var resultDict = Evaluate(roles);
    Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
    var overall = resultDict[MetricKinds.OverallMetrics];

    // The overall metrics view is expected to hold exactly one row; the asserts below verify that.
    Result result;
    using (var cursor = overall.GetRowCursor(i => true))
    {
        var moved = cursor.MoveNext();
        Host.Assert(moved);
        result = new Result(Host, cursor);
        moved = cursor.MoveNext();
        Host.Assert(!moved);
    }
    return result;
}

public sealed class Aggregator : AggregatorBase
{
public sealed class Counters
Expand Down Expand Up @@ -509,6 +543,48 @@ public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
slotNames = new VBuffer<ReadOnlyMemory<char>>(UnweightedCounters.TruncationLevel, values);
}
}

public sealed class Result
{
/// <summary>
/// Normalized Discounted Cumulative Gain
/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/ndcg.png"></a>
/// </summary>
public double[] Ndcg { get; }

/// <summary>
/// <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain">Discounted Cumulative gain</a>
/// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1.
/// Note that unlike the Wikipedia article, ML.NET uses the natural logarithm.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unlike

/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/dcg.png"></a>
/// </summary>
public double[] Dcg { get; }

/// <summary>
/// MaxDcgs is the value of <see cref="Dcg"/> when the documents are ordered in the ideal order from most relevant to least relevant.
/// In case there are ties in scores, metrics are computed in a pessimistic fashion. In other words, if two or more results get the same score,
/// for the purpose of computing DCG and NDCG they are ordered from least relevant to most relevant.
/// </summary>
public double[] MaxDcg { get; }

// Reads the value of the column called 'name' from 'row'; throws through 'ectx' when the
// schema has no column with that name.
private static T Fetch<T>(IExceptionContext ectx, IRow row, string name)
{
    bool found = row.Schema.TryGetColumnIndex(name, out int columnIndex);
    if (!found)
    {
        throw ectx.Except($"Could not find column '{name}'");
    }

    var getter = row.GetGetter<T>(columnIndex);
    T value = default;
    getter(ref value);
    return value;
}

internal Result(IExceptionContext ectx, IRow overallResult)
{
VBuffer<double> Fetch(string name) => Fetch<VBuffer<double>>(ectx, overallResult, name);

Dcg = Fetch(RankerEvaluator.Dcg).Values;
Ndcg = Fetch(RankerEvaluator.Ndcg).Values;
// MaxDcg = Fetch(RankerEvaluator.MaxDcg).Values;
Copy link
Member Author

@sfilipi sfilipi Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// MaxDcg = Fetch(RankerEvaluator.MaxDcg).Values [](start = 16, length = 48)

Uncomment, and retrieve correctly. #Resolved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed from metrics, since it is a dataset characteristic.


In reply to: 220800890 [](ancestors = 220800890)

}
}
}

public sealed class RankerPerInstanceTransform : IDataTransform
Expand Down
64 changes: 64 additions & 0 deletions src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -403,5 +403,69 @@ public ImplScore(MulticlassClassifier<TVal> rec) : base(rec, rec.Inputs) { }
}
}

/// <summary>
/// A reconciler for ranking capable of handling the most common cases for ranking.
/// </summary>
public sealed class Ranker<TVal> : TrainerEstimatorReconciler
{
    /// <summary>
    /// The delegate to create the ranking trainer instance.
    /// </summary>
    /// <remarks>
    /// The parameter order mirrors the order of the reconciler inputs (label, features, groupId, weights),
    /// which is the order in which <see cref="ReconcileCore"/> supplies the column names.
    /// </remarks>
    /// <param name="env">The environment with which to create the estimator.</param>
    /// <param name="label">The label column name.</param>
    /// <param name="features">The features column name.</param>
    /// <param name="groupId">The groupId column name.</param>
    /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights.</param>
    /// <returns>An estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/>.</returns>
    public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string groupId, string weights);

    private readonly EstimatorFactory _estFact;

    /// <summary>
    /// The output score column for ranking. This will have this instance as its reconciler.
    /// </summary>
    public Scalar<float> Score { get; }

    protected override IEnumerable<PipelineColumn> Outputs => Enumerable.Repeat(Score, 1);

    private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score };

    /// <summary>
    /// Constructs a new general ranker reconciler.
    /// </summary>
    /// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
    /// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
    /// <param name="label">The input label column.</param>
    /// <param name="features">The input features column.</param>
    /// <param name="groupId">The input groupId column.</param>
    /// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
    public Ranker(EstimatorFactory estimatorFactory, Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights)
        : base(MakeInputs(Contracts.CheckRef(label, nameof(label)),
                Contracts.CheckRef(features, nameof(features)),
                Contracts.CheckRef(groupId, nameof(groupId)),
                weights),
            _fixedOutputNames)
    {
        Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
        _estFact = estimatorFactory;
        // Three inputs when no weights were given, four otherwise (see MakeInputs).
        Contracts.Assert(Inputs.Length == 3 || Inputs.Length == 4);
        Score = new Impl(this);
    }

    // Builds the input array in the fixed order label, features, groupId and, only when present, weights.
    private static PipelineColumn[] MakeInputs(Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights)
        => weights == null ? new PipelineColumn[] { label, features, groupId } : new PipelineColumn[] { label, features, groupId, weights };

    protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
    {
        Contracts.AssertValue(env);
        env.Assert(Utils.Size(inputNames) == Inputs.Length);

        // inputNames follows the MakeInputs order; the weights name exists only when weights were supplied.
        string labelName = inputNames[0];
        string featuresName = inputNames[1];
        string groupIdName = inputNames[2];
        string weightsName = inputNames.Length > 3 ? inputNames[3] : null;
        return _estFact(env, labelName, featuresName, groupIdName, weightsName);
    }

    private sealed class Impl : Scalar<float>
    {
        public Impl(Ranker<TVal> rec) : base(rec, rec.Inputs) { }
    }
}
}
}
44 changes: 44 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -451,4 +451,48 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
}

/// <summary>
/// The central context for ranking trainers.
/// </summary>
public sealed class RankerContext : TrainContextBase
{
    /// <summary>
    /// For trainers for performing ranking.
    /// </summary>
    public RankerTrainers Trainers { get; }

    /// <summary>
    /// Creates the ranking context on top of the given environment.
    /// </summary>
    /// <param name="env">The host environment.</param>
    public RankerContext(IHostEnvironment env)
        : base(env, nameof(RankerContext))
    {
        Trainers = new RankerTrainers(this);
    }

    /// <summary>
    /// The type of <see cref="Trainers"/>; instances are only created by <see cref="RankerContext"/> itself.
    /// </summary>
    public sealed class RankerTrainers : ContextInstantiatorBase
    {
        internal RankerTrainers(RankerContext ctx)
            : base(ctx)
        {
        }
    }

    /// <summary>
    /// Evaluates scored ranking data.
    /// </summary>
    /// <param name="data">The scored data.</param>
    /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
    /// <param name="groupId">The name of the groupId column in <paramref name="data"/>.</param>
    /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
    /// <returns>The ranking evaluation results for these outputs.</returns>
    public RankerEvaluator.Result Evaluate(IDataView data, string label, string groupId, string score = DefaultColumnNames.Score)
    {
        Host.CheckValue(data, nameof(data));
        Host.CheckNonEmpty(label, nameof(label));
        Host.CheckNonEmpty(score, nameof(score));
        Host.CheckNonEmpty(groupId, nameof(groupId));

        // The evaluator is created with default arguments.
        var eval = new RankerEvaluator(Host, new RankerEvaluator.Arguments());
        return eval.Evaluate(data, label, groupId, score);
    }
}
}
42 changes: 42 additions & 0 deletions src/Microsoft.ML.FastTree/FastTreeStatic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,48 @@ public static (Scalar<float> score, Scalar<float> probability, Scalar<bool> pred
return rec.Output;
}

/// <summary>
/// FastTree <see cref="RankerContext"/>.
Copy link
Contributor

@TomFinley TomFinley Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FastTree . [](start = 12, length = 37)

Is this intended to be a placeholder? I'm not certain it's terribly helpful.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line appears to still be here. I like the additional line you've added though.


In reply to: 221019258 [](ancestors = 221019258)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duh.. let me remove that first one.


In reply to: 221379761 [](ancestors = 221379761,221019258)

/// </summary>
/// <param name="ctx">The <see cref="RankerContext.RankerTrainers"/>.</param>
/// <param name="label">The label column.</param>
/// <param name="features">The features column.</param>
/// <param name="groupId">The groupId column.</param>
Copy link
Contributor

@TomFinley TomFinley Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name of the groupId column [](start = 34, length = 30)

It's not really a name is it? The point of pigsty is you don't use names, you just pass in objects. #Closed

/// <param name="weights">The weights column.</param>
Copy link
Contributor

@TomFinley TomFinley Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weights [](start = 38, length = 7)

Maybe while you're at it, "the optional weights column" could make clear it's just fine to leave it null... though maybe people form this expectation already given that its default value is null. I don't know. #Closed

/// <param name="numLeaves">The number of leaves to use.</param>
Copy link
Contributor

@TomFinley TomFinley Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to use [](start = 57, length = 6)

"Leaves to use" is kind of weird. Maybe "maximum number of leaves per decision tree" would be better. I might also prefer to have number of trees above number of leaves (that is, reverse the two args). #Closed

/// <param name="numTrees">Total number of decision trees to create in the ensemble.</param>
/// <param name="minDocumentsInLeafs">The minimal number of documents allowed in a leaf of a regression tree, out of the subsampled data.</param>
Copy link
Contributor

@TomFinley TomFinley Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

documents [](start = 68, length = 9)

The term "documents" is a little unfortunate. It is a legacy of this package's roots in ranking documents for the web ranker, but nowhere else in ML.NET do we call the datapoints we have "documents", as far as I am aware. #Closed

/// <param name="learningRate">The learning rate.</param>
/// <param name="advancedSettings">Algorithm advanced settings.</param>
/// <param name="onFit">A delegate that is called every time the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}.Fit(DataView{TTupleInShape})"/> method is called on the
/// <see cref="Estimator{TTupleInShape, TTupleOutShape, TTransformer}"/> instance created out of this. This delegate will receive
/// the model that was trained. Note that this action cannot change the result in any way;
/// it is only a way for the caller to be informed about what was learnt.</param>
/// <returns>The Score output column indicating the predicted value.</returns>
public static Scalar<float> FastTree<TVal>(this RankerContext.RankerTrainers ctx,
    Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights = null,
    int numLeaves = Defaults.NumLeaves,
    int numTrees = Defaults.NumTrees,
    int minDocumentsInLeafs = Defaults.MinDocumentsInLeafs,
    double learningRate = Defaults.LearningRates,
    Action<FastTreeRankingTrainer.Arguments> advancedSettings = null,
    Action<FastTreeRankingPredictor> onFit = null)
{
    // NOTE(review): numLeaves, numTrees, minDocumentsInLeafs and learningRate are handed to CheckUserValues
    // but are not passed to the trainer constructor below - confirm whether they should be forwarded.
    CheckUserValues(label, features, weights, numLeaves, numTrees, minDocumentsInLeafs, learningRate, advancedSettings, onFit);

    var reconciler = new TrainerEstimatorReconciler.Ranker<TVal>(
        (host, labelColumn, featuresColumn, groupIdColumn, weightsColumn) =>
        {
            var rankingTrainer = new FastTreeRankingTrainer(host, labelColumn, featuresColumn, groupIdColumn, weightsColumn, advancedSettings);

            // Without a callback the trainer itself is the estimator; otherwise wrap it so that
            // onFit is handed the trained model.
            if (onFit == null)
                return rankingTrainer;
            return rankingTrainer.WithOnFitDelegate(fitted => onFit(fitted.Model));
        },
        label, features, groupId, weights);

    return reconciler.Score;
}

private static void CheckUserValues(PipelineColumn label, Vector<float> features, Scalar<float> weights,
int numLeaves,
int numTrees,
Expand Down
41 changes: 41 additions & 0 deletions test/Microsoft.ML.StaticPipelineTesting/Training.cs
Original file line number Diff line number Diff line change
Expand Up @@ -451,5 +451,46 @@ public void LightGbmRegression()
Assert.Equal(metrics.Rms * metrics.Rms, metrics.L2, 5);
Assert.InRange(metrics.LossFn, 0, double.PositiveInfinity);
}

[Fact]
public void FastTreeRanking()
{
var env = new ConsoleEnvironment(seed: 0);
var dataPath = GetDataPath(TestDatasets.adultRanking.trainFilename);
var dataSource = new MultiFileSource(dataPath);

var ctx = new RankerContext(env);

var reader = TextLoader.CreateReader(env,
c => (label: c.LoadFloat(0), features: c.LoadFloat(9, 14), groupId: c.LoadText(1)),
separator: '\t', hasHeader: true);

FastTreeRankingPredictor pred = null;

var est = reader.MakeNewEstimator()
.Append(r => (r.label, r.features, groupId: r.groupId.ToKey()))
.Append(r => (r.label, r.groupId, score: ctx.Trainers.FastTree(r.label, r.features, r.groupId, onFit: (p) => { pred = p; })));

var pipe = reader.Append(est);

Assert.Null(pred);
var model = pipe.Fit(dataSource);
Assert.NotNull(pred);

var data = model.Read(dataSource);

var metrics = ctx.Evaluate(data, r => r.label, r => r.groupId, r => r.score);
Assert.NotNull(metrics);

Assert.True(metrics.Ndcg.Length == metrics.Dcg.Length && metrics.Dcg.Length == 3);

Assert.InRange(metrics.Dcg[0], 1.4, 1.6);
Assert.InRange(metrics.Dcg[1], 1.4, 1.8);
Assert.InRange(metrics.Dcg[2], 1.4, 1.8);

Assert.InRange(metrics.Ndcg[0], 36.5, 37);
Assert.InRange(metrics.Ndcg[1], 36.5, 37);
Assert.InRange(metrics.Ndcg[2], 36.5, 37);
Copy link
Contributor

@TomFinley TomFinley Sep 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. That's weird that they're all so close. Are we sure about this? They're not all identical, are they? If they are, that might be a sign that the group-id is different for each line.
#Closed

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values are:
Dcg = [ 1.5972695, 1.770648, 1.79641082 ]
Ndcg = [ 36.90476, 36.7512, 36.8729 ]


In reply to: 221056237 [](ancestors = 221056237)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm OK! A bit weird but we'll go with that.


In reply to: 221145695 [](ancestors = 221145695,221056237)

}
}
}