Skip to content

Ranker train context and FastTree ranking extensions #1068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 29, 2018
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file added docs/images/DCG.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added docs/images/NDCG.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 35 additions & 0 deletions src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,5 +167,40 @@ public static RegressionEvaluator.Result Evaluate<T>(
args.LossFunction = new TrivialRegressionLossFactory(loss);
return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName);
}

/// <summary>
/// Computes ranking metrics over scored, statically-typed data.
/// </summary>
/// <typeparam name="T">The shape type for the input data.</typeparam>
/// <typeparam name="TVal">The type of the group id data, before being converted to a key.</typeparam>
/// <param name="ctx">The ranking context.</param>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The index delegate selecting the label column.</param>
/// <param name="groupId">The index delegate selecting the groupId column.</param>
/// <param name="score">The index delegate selecting the predicted score column.</param>
/// <returns>The ranking evaluation metrics.</returns>
public static RankerEvaluator.Result Evaluate<T, TVal>(
    this RankerContext ctx,
    DataView<T> data,
    Func<T, Scalar<float>> label,
    Func<T, Key<uint, TVal>> groupId,
    Func<T, Scalar<float>> score)
{
    Contracts.CheckValue(data, nameof(data));
    var environment = StaticPipeUtils.GetEnvironment(data);
    Contracts.AssertValue(environment);
    environment.CheckValue(label, nameof(label));
    environment.CheckValue(groupId, nameof(groupId));
    environment.CheckValue(score, nameof(score));

    // Turn each column selector into the concrete column name used by the dynamic API.
    var columns = StaticPipeUtils.GetIndexer(data);
    string labelColumn = columns.Get(label(columns.Indices));
    string scoreColumn = columns.Get(score(columns.Indices));
    string groupColumn = columns.Get(groupId(columns.Indices));

    var evaluator = new RankerEvaluator(environment, new RankerEvaluator.Arguments());
    return evaluator.Evaluate(data.AsDynamic, labelColumn, groupColumn, scoreColumn);
}
}
}
74 changes: 71 additions & 3 deletions src/Microsoft.ML.Data/Evaluators/RankerEvaluator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,17 @@ public sealed class Arguments
public bool OutputGroupSummary;
}

public const string LoadName = "RankingEvaluator";
internal const string LoadName = "RankingEvaluator";

public const string Ndcg = "NDCG";
public const string Dcg = "DCG";
public const string MaxDcg = "MaxDCG";

/// <summary>
/// <value>
/// The ranking evaluator outputs a data view by this name, which contains metrics aggregated per group.
/// It contains four columns: GroupId, NDCG, DCG and MaxDCG. Each row in the data view corresponds to one
/// group in the scored data.
/// </summary>
/// </value>
public const string GroupSummary = "GroupSummary";

private const string GroupId = "GroupId";
Expand Down Expand Up @@ -234,6 +234,40 @@ protected override void GetAggregatorConsolidationFuncs(Aggregator aggregator, A
};
}

/// <summary>
/// Evaluates scored ranking data.
/// </summary>
/// <param name="data">The data to evaluate.</param>
/// <param name="label">The name of the label column.</param>
/// <param name="groupId">The name of the groupId column.</param>
/// <param name="score">The name of the predicted score column.</param>
/// <returns>The evaluation metrics for these outputs.</returns>
public Result Evaluate(IDataView data, string label, string groupId, string score)
{
    Host.CheckValue(data, nameof(data));
    Host.CheckNonEmpty(label, nameof(label));
    Host.CheckNonEmpty(score, nameof(score));
    // The group column is bound below just like label and score, so validate it up front as well.
    Host.CheckNonEmpty(groupId, nameof(groupId));

    var roles = new RoleMappedData(data, opt: false,
        RoleMappedSchema.ColumnRole.Label.Bind(label),
        RoleMappedSchema.ColumnRole.Group.Bind(groupId),
        RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score));

    var resultDict = Evaluate(roles);
    Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
    var overall = resultDict[MetricKinds.OverallMetrics];

    // The overall metrics view is expected to hold exactly one row; both asserts below check that.
    Result result;
    using (var cursor = overall.GetRowCursor(i => true))
    {
        var moved = cursor.MoveNext();
        Host.Assert(moved);
        result = new Result(Host, cursor);
        moved = cursor.MoveNext();
        Host.Assert(!moved);
    }
    return result;
}

public sealed class Aggregator : AggregatorBase
{
public sealed class Counters
Expand Down Expand Up @@ -509,6 +543,40 @@ public void GetSlotNames(ref VBuffer<ReadOnlyMemory<char>> slotNames)
slotNames = new VBuffer<ReadOnlyMemory<char>>(UnweightedCounters.TruncationLevel, values);
}
}

public sealed class Result
{
/// <summary>
/// Normalized Discounted Cumulative Gain
/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/NDCG.png">NDCG formula</a>
/// </summary>
public double[] Ndcg { get; }

/// <summary>
/// <a href="https://en.wikipedia.org/wiki/Discounted_cumulative_gain">Discounted Cumulative gain</a>
/// is the sum of the gains, for all the instances i, normalized by the natural logarithm of the instance + 1.
/// Note that unlike the Wikipedia article, ML.Net uses the natural logarithm.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unlike

/// <a href="https://github.com/dotnet/machinelearning/tree/master/docs/images/DCG.png">DCG formula</a>
/// </summary>
public double[] Dcg { get; }

/// <summary>
/// Reads the value of the column called <paramref name="name"/> from <paramref name="row"/>.
/// </summary>
/// <typeparam name="T">The type of value requested from the column getter.</typeparam>
/// <param name="ectx">The exception context used to create the exception when the column is missing.</param>
/// <param name="row">The row to read from.</param>
/// <param name="name">The name of the column to read.</param>
/// <returns>The value produced by the column getter.</returns>
private static T Fetch<T>(IExceptionContext ectx, IRow row, string name)
{
    bool found = row.Schema.TryGetColumnIndex(name, out int columnIndex);
    if (!found)
    {
        throw ectx.Except($"Could not find column '{name}'");
    }

    var getter = row.GetGetter<T>(columnIndex);
    T value = default;
    getter(ref value);
    return value;
}

/// <summary>
/// Populates the metrics from a row of overall evaluation results.
/// </summary>
/// <param name="ectx">The exception context used when a metric column cannot be found.</param>
/// <param name="overallResult">The row holding the DCG and NDCG columns.</param>
internal Result(IExceptionContext ectx, IRow overallResult)
{
    // Each metric is read as a vector of doubles; the backing values array is exposed directly.
    double[] ReadMetric(string column)
    {
        VBuffer<double> buffer = Fetch<VBuffer<double>>(ectx, overallResult, column);
        return buffer.Values;
    }

    Dcg = ReadMetric(RankerEvaluator.Dcg);
    Ndcg = ReadMetric(RankerEvaluator.Ndcg);
}
}
}

public sealed class RankerPerInstanceTransform : IDataTransform
Expand Down
64 changes: 64 additions & 0 deletions src/Microsoft.ML.Data/StaticPipe/TrainerEstimatorReconciler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -403,5 +403,69 @@ public ImplScore(MulticlassClassifier<TVal> rec) : base(rec, rec.Inputs) { }
}
}

/// <summary>
/// A reconciler for ranking capable of handling the most common cases for ranking.
/// </summary>
public sealed class Ranker<TVal> : TrainerEstimatorReconciler
{
    /// <summary>
    /// The delegate to create the ranking trainer instance.
    /// </summary>
    /// <remarks>
    /// The parameter order mirrors the order in which the reconciler passes the resolved column names:
    /// label, features, groupId and finally the optional weights.
    /// </remarks>
    /// <param name="env">The environment with which to create the estimator</param>
    /// <param name="label">The label column name</param>
    /// <param name="features">The features column name</param>
    /// <param name="groupId">The groupId column name.</param>
    /// <param name="weights">The weights column name, or <c>null</c> if the reconciler was constructed with <c>null</c> weights</param>
    /// <returns>A estimator producing columns with the fixed name <see cref="DefaultColumnNames.Score"/>.</returns>
    public delegate IEstimator<ITransformer> EstimatorFactory(IHostEnvironment env, string label, string features, string groupId, string weights);

    private readonly EstimatorFactory _estFact;

    /// <summary>
    /// The output score column for ranking. This will have this instance as its reconciler.
    /// </summary>
    public Scalar<float> Score { get; }

    protected override IEnumerable<PipelineColumn> Outputs => Enumerable.Repeat(Score, 1);

    private static readonly string[] _fixedOutputNames = new[] { DefaultColumnNames.Score };

    /// <summary>
    /// Constructs a new general ranker reconciler.
    /// </summary>
    /// <param name="estimatorFactory">The delegate to create the training estimator. It is assumed that this estimator
    /// will produce a single new scalar <see cref="float"/> column named <see cref="DefaultColumnNames.Score"/>.</param>
    /// <param name="label">The input label column.</param>
    /// <param name="features">The input features column.</param>
    /// <param name="groupId">The input groupId column.</param>
    /// <param name="weights">The input weights column, or <c>null</c> if there are no weights.</param>
    public Ranker(EstimatorFactory estimatorFactory, Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights)
        : base(MakeInputs(Contracts.CheckRef(label, nameof(label)),
                          Contracts.CheckRef(features, nameof(features)),
                          Contracts.CheckRef(groupId, nameof(groupId)),
                          weights),
               _fixedOutputNames)
    {
        Contracts.CheckValue(estimatorFactory, nameof(estimatorFactory));
        _estFact = estimatorFactory;
        // Three inputs without weights, four with weights (see MakeInputs).
        Contracts.Assert(Inputs.Length == 3 || Inputs.Length == 4);
        Score = new Impl(this);
    }

    /// <summary>
    /// Builds the input column array as { label, features, groupId }, appending weights only when supplied.
    /// </summary>
    private static PipelineColumn[] MakeInputs(Scalar<float> label, Vector<float> features, Key<uint, TVal> groupId, Scalar<float> weights)
        => weights == null ? new PipelineColumn[] { label, features, groupId } : new PipelineColumn[] { label, features, groupId, weights };

    /// <summary>
    /// Invokes the estimator factory with the resolved input column names.
    /// </summary>
    /// <param name="env">The environment passed through to the factory.</param>
    /// <param name="inputNames">The resolved input names, in the same order as the inputs built by MakeInputs.</param>
    /// <returns>The estimator created by the factory.</returns>
    protected override IEstimator<ITransformer> ReconcileCore(IHostEnvironment env, string[] inputNames)
    {
        Contracts.AssertValue(env);
        env.Assert(Utils.Size(inputNames) == Inputs.Length);

        // inputNames follows the MakeInputs layout: [0] label, [1] features, [2] groupId, [3] weights (optional).
        string weightsName = inputNames.Length > 3 ? inputNames[3] : null;
        return _estFact(env, inputNames[0], inputNames[1], inputNames[2], weightsName);
    }

    private sealed class Impl : Scalar<float>
    {
        public Impl(Ranker<TVal> rec) : base(rec, rec.Inputs) { }
    }
}
}
}
44 changes: 44 additions & 0 deletions src/Microsoft.ML.Data/Training/TrainContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -451,4 +451,48 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label, string
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
}
}

/// <summary>
/// The central context for ranking trainers.
/// </summary>
public sealed class RankerContext : TrainContextBase
{
    /// <summary>
    /// For trainers for performing ranking.
    /// </summary>
    public RankerTrainers Trainers { get; }

    /// <summary>
    /// Creates the ranking context and its trainer catalog.
    /// </summary>
    /// <param name="env">The host environment passed to the base context.</param>
    public RankerContext(IHostEnvironment env)
        : base(env, nameof(RankerContext))
    {
        Trainers = new RankerTrainers(this);
    }

    /// <summary>
    /// The trainer catalog type exposed through <see cref="Trainers"/>.
    /// </summary>
    public sealed class RankerTrainers : ContextInstantiatorBase
    {
        internal RankerTrainers(RankerContext ctx)
            : base(ctx)
        {
        }
    }

    /// <summary>
    /// Evaluates scored ranking data.
    /// </summary>
    /// <param name="data">The scored data.</param>
    /// <param name="label">The name of the label column in <paramref name="data"/>.</param>
    /// <param name="groupId">The name of the groupId column in <paramref name="data"/>.</param>
    /// <param name="score">The name of the score column in <paramref name="data"/>.</param>
    /// <returns>The evaluation results for these outputs.</returns>
    public RankerEvaluator.Result Evaluate(IDataView data, string label, string groupId, string score = DefaultColumnNames.Score)
    {
        Host.CheckValue(data, nameof(data));
        Host.CheckNonEmpty(label, nameof(label));
        Host.CheckNonEmpty(score, nameof(score));
        Host.CheckNonEmpty(groupId, nameof(groupId));

        // Default evaluator arguments are used; the evaluator does the actual metric computation.
        var eval = new RankerEvaluator(Host, new RankerEvaluator.Arguments());
        return eval.Evaluate(data, label, groupId, score);
    }
}
}
Loading