Skip to content

Commit eb26489

Browse files
Zruty0 authored and TomFinley committed
Extended contexts to regression and multiclass, added FFM pigstension
1 parent b88cc09 commit eb26489

File tree

10 files changed

+432
-128
lines changed

10 files changed

+432
-128
lines changed

src/Microsoft.ML.Data/Evaluators/EvaluatorStaticExtensions.cs

+86-2
Original file line number | Diff line number | Diff line change
@@ -51,13 +51,13 @@ public static BinaryClassifierEvaluator.CalibratedResult Evaluate<T>(
5151
}
5252

5353
/// <summary>
54-
/// Evaluates scored binary classification data.
54+
/// Evaluates scored binary classification data, if the predictions are not calibrated.
5555
/// </summary>
5656
/// <typeparam name="T">The shape type for the input data.</typeparam>
5757
/// <param name="ctx">The binary classification context.</param>
5858
/// <param name="data">The data to evaluate.</param>
5959
/// <param name="label">The index delegate for the label column.</param>
60-
/// <param name="pred">The index delegate for columns from calibrated prediction of a binary classifier.
60+
/// <param name="pred">The index delegate for columns from uncalibrated prediction of a binary classifier.
6161
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
6262
/// <returns>The evaluation results for these uncalibrated outputs.</returns>
6363
public static BinaryClassifierEvaluator.Result Evaluate<T>(
@@ -83,5 +83,89 @@ public static BinaryClassifierEvaluator.Result Evaluate<T>(
8383
var eval = new BinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { });
8484
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName);
8585
}
86+
87+
/// <summary>
88+
/// Evaluates scored multiclass classification data.
89+
/// </summary>
90+
/// <typeparam name="T">The shape type for the input data.</typeparam>
91+
/// <typeparam name="TKey">The value type for the key label.</typeparam>
92+
/// <param name="ctx">The multiclass classification context.</param>
93+
/// <param name="data">The data to evaluate.</param>
94+
/// <param name="label">The index delegate for the label column.</param>
95+
/// <param name="pred">The index delegate for columns from the prediction of a multiclass classifier.
96+
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
97+
/// <param name="topK">If given a positive value, the <see cref="MultiClassClassifierEvaluator.Result.TopKAccuracy"/> will be filled with
98+
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
99+
/// the top-K values as being scored "correctly."</param>
100+
/// <returns>The evaluation metrics.</returns>
101+
public static MultiClassClassifierEvaluator.Result Evaluate<T, TKey>(
102+
this MulticlassClassificationContext ctx,
103+
DataView<T> data,
104+
Func<T, Key<uint, TKey>> label,
105+
Func<T, (Vector<float> score, Key<uint, TKey> predictedLabel)> pred,
106+
int topK = 0)
107+
{
108+
Contracts.CheckValue(data, nameof(data));
109+
var env = StaticPipeUtils.GetEnvironment(data);
110+
Contracts.AssertValue(env);
111+
env.CheckValue(label, nameof(label));
112+
env.CheckValue(pred, nameof(pred));
113+
env.CheckParam(topK >= 0, nameof(topK), "Must not be negative.");
114+
115+
var indexer = StaticPipeUtils.GetIndexer(data);
116+
string labelName = indexer.Get(label(indexer.Indices));
117+
(var scoreCol, var predCol) = pred(indexer.Indices);
118+
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
119+
Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
120+
string scoreName = indexer.Get(scoreCol);
121+
string predName = indexer.Get(predCol);
122+
123+
var args = new MultiClassClassifierEvaluator.Arguments() { };
124+
if (topK > 0)
125+
args.OutputTopKAcc = topK;
126+
127+
var eval = new MultiClassClassifierEvaluator(env, args);
128+
return eval.Evaluate(data.AsDynamic, labelName, scoreName, predName);
129+
}
130+
131+
private sealed class TrivialRegressionLossFactory : ISupportRegressionLossFactory
132+
{
133+
private readonly IRegressionLoss _loss;
134+
public TrivialRegressionLossFactory(IRegressionLoss loss) => _loss = loss;
135+
public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss;
136+
}
137+
138+
/// <summary>
139+
/// Evaluates scored regression data.
140+
/// </summary>
141+
/// <typeparam name="T">The shape type for the input data.</typeparam>
142+
/// <param name="ctx">The regression context.</param>
143+
/// <param name="data">The data to evaluate.</param>
144+
/// <param name="label">The index delegate for the label column.</param>
145+
/// <param name="score">The index delegate for the predicted score column.</param>
146+
/// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param>
147+
/// <returns>The evaluation metrics.</returns>
148+
public static RegressionEvaluator.Result Evaluate<T>(
149+
this RegressionContext ctx,
150+
DataView<T> data,
151+
Func<T, Scalar<float>> label,
152+
Func<T, Scalar<float>> score,
153+
IRegressionLoss loss = null)
154+
{
155+
Contracts.CheckValue(data, nameof(data));
156+
var env = StaticPipeUtils.GetEnvironment(data);
157+
Contracts.AssertValue(env);
158+
env.CheckValue(label, nameof(label));
159+
env.CheckValue(score, nameof(score));
160+
161+
var indexer = StaticPipeUtils.GetIndexer(data);
162+
string labelName = indexer.Get(label(indexer.Indices));
163+
string scoreName = indexer.Get(score(indexer.Indices));
164+
165+
var args = new RegressionEvaluator.Arguments() { };
166+
if (loss != null)
167+
args.LossFunction = new TrivialRegressionLossFactory(loss);
168+
return new RegressionEvaluator(env, args).Evaluate(data.AsDynamic, labelName, scoreName);
169+
}
86170
}
87171
}

src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs

+21-46
Original file line number | Diff line number | Diff line change
@@ -598,62 +598,37 @@ internal Result(IExceptionContext ectx, IRow overallResult, int topK)
598598
}
599599

600600
/// <summary>
601-
/// Evaluates scored regression data.
601+
/// Evaluates scored multiclass classification data.
602602
/// </summary>
603-
/// <typeparam name="T">The shape type for the input data.</typeparam>
604-
/// <typeparam name="TKey">The value type for the key label.</typeparam>
605-
/// <param name="data">The data to evaluate.</param>
606-
/// <param name="label">The index delegate for the label column.</param>
607-
/// <param name="pred">The index delegate for columns from prediction of a multi-class classifier.
608-
/// Under typical scenarios, this will just be the same tuple of results returned from the trainer.</param>
609-
/// <param name="topK">If given a positive value, the <see cref="Result.TopKAccuracy"/> will be filled with
610-
/// the top-K accuracy, that is, the accuracy assuming we consider an example with the correct class within
611-
/// the top-K values as being stored "correctly."</param>
603+
/// <param name="data">The scored data.</param>
604+
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
605+
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
606+
/// <param name="predictedLabel">The name of the predicted label column in <paramref name="data"/>.</param>
612607
/// <returns>The evaluation results for these outputs.</returns>
613-
public static Result Evaluate<T, TKey>(
614-
DataView<T> data,
615-
Func<T, Key<uint, TKey>> label,
616-
Func<T, (Vector<float> score, Key<uint, TKey> predictedLabel)> pred,
617-
int topK = 0)
608+
public Result Evaluate(IDataView data, string label, string score, string predictedLabel)
618609
{
619-
Contracts.CheckValue(data, nameof(data));
620-
var env = StaticPipeUtils.GetEnvironment(data);
621-
Contracts.AssertValue(env);
622-
env.CheckValue(label, nameof(label));
623-
env.CheckValue(pred, nameof(pred));
624-
env.CheckParam(topK >= 0, nameof(topK), "Must not be negative.");
625-
626-
var indexer = StaticPipeUtils.GetIndexer(data);
627-
string labelName = indexer.Get(label(indexer.Indices));
628-
(var scoreCol, var predCol) = pred(indexer.Indices);
629-
Contracts.CheckParam(scoreCol != null, nameof(pred), "Indexing delegate resulted in null score column.");
630-
Contracts.CheckParam(predCol != null, nameof(pred), "Indexing delegate resulted in null predicted label column.");
631-
string scoreName = indexer.Get(scoreCol);
632-
string predName = indexer.Get(predCol);
633-
634-
var args = new Arguments() { };
635-
if (topK > 0)
636-
args.OutputTopKAcc = topK;
637-
638-
var eval = new MultiClassClassifierEvaluator(env, args);
639-
640-
var roles = new RoleMappedData(data.AsDynamic, opt: false,
641-
RoleMappedSchema.ColumnRole.Label.Bind(labelName),
642-
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName),
643-
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predName));
644-
645-
var resultDict = eval.Evaluate(roles);
646-
env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
610+
Host.CheckValue(data, nameof(data));
611+
Host.CheckNonEmpty(label, nameof(label));
612+
Host.CheckNonEmpty(score, nameof(score));
613+
Host.CheckNonEmpty(predictedLabel, nameof(predictedLabel));
614+
615+
var roles = new RoleMappedData(data, opt: false,
616+
RoleMappedSchema.ColumnRole.Label.Bind(label),
617+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score),
618+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.PredictedLabel, predictedLabel));
619+
620+
var resultDict = Evaluate(roles);
621+
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
647622
var overall = resultDict[MetricKinds.OverallMetrics];
648623

649624
Result result;
650625
using (var cursor = overall.GetRowCursor(i => true))
651626
{
652627
var moved = cursor.MoveNext();
653-
env.Assert(moved);
654-
result = new Result(env, cursor, topK);
628+
Host.Assert(moved);
629+
result = new Result(Host, cursor, _outputTopKAcc ?? 0);
655630
moved = cursor.MoveNext();
656-
env.Assert(!moved);
631+
Host.Assert(!moved);
657632
}
658633
return result;
659634
}

src/Microsoft.ML.Data/Evaluators/RegressionEvaluator.cs

+20-42
Original file line number | Diff line number | Diff line change
@@ -219,65 +219,43 @@ internal Result(IExceptionContext ectx, IRow overallResult)
219219
double Fetch(string name) => Fetch<double>(ectx, overallResult, name);
220220
L1 = Fetch(RegressionEvaluator.L1);
221221
L2 = Fetch(RegressionEvaluator.L2);
222-
Rms= Fetch(RegressionEvaluator.Rms);
222+
Rms = Fetch(RegressionEvaluator.Rms);
223223
LossFn = Fetch(RegressionEvaluator.Loss);
224224
RSquared = Fetch(RegressionEvaluator.RSquared);
225225
}
226226
}
227227

228-
private sealed class TrivialLossFactory : ISupportRegressionLossFactory
229-
{
230-
private readonly IRegressionLoss _loss;
231-
public TrivialLossFactory(IRegressionLoss loss) => _loss = loss;
232-
public IRegressionLoss CreateComponent(IHostEnvironment env) => _loss;
233-
}
234-
235228
/// <summary>
236229
/// Evaluates scored regression data.
237230
/// </summary>
238-
/// <typeparam name="T">The shape type for the input data.</typeparam>
239231
/// <param name="data">The data to evaluate.</param>
240-
/// <param name="label">The index delegate for the label column.</param>
241-
/// <param name="score">The index delegate for the predicted score column.</param>
242-
/// <param name="loss">Potentially custom loss function. If left unspecified defaults to <see cref="SquaredLoss"/>.</param>
243-
/// <returns>The evaluation results for these outputs.</returns>
244-
public static Result Evaluate<T>(
245-
DataView<T> data,
246-
Func<T, Scalar<float>> label,
247-
Func<T, Scalar<float>> score,
248-
IRegressionLoss loss = null)
232+
/// <param name="label">The name of the label column.</param>
233+
/// <param name="score">The name of the predicted score column.</param>
234+
/// <returns>The evaluation metrics for these outputs.</returns>
235+
public Result Evaluate(
236+
IDataView data,
237+
string label,
238+
string score)
249239
{
250-
Contracts.CheckValue(data, nameof(data));
251-
var env = StaticPipeUtils.GetEnvironment(data);
252-
Contracts.AssertValue(env);
253-
env.CheckValue(label, nameof(label));
254-
env.CheckValue(score, nameof(score));
255-
256-
var indexer = StaticPipeUtils.GetIndexer(data);
257-
string labelName = indexer.Get(label(indexer.Indices));
258-
string scoreName = indexer.Get(score(indexer.Indices));
259-
260-
var args = new Arguments() { };
261-
if (loss != null)
262-
args.LossFunction = new TrivialLossFactory(loss);
263-
var eval = new RegressionEvaluator(env, args);
264-
265-
var roles = new RoleMappedData(data.AsDynamic, opt: false,
266-
RoleMappedSchema.ColumnRole.Label.Bind(labelName),
267-
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, scoreName));
268-
269-
var resultDict = eval.Evaluate(roles);
270-
env.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
240+
Host.CheckValue(data, nameof(data));
241+
Host.CheckNonEmpty(label, nameof(label));
242+
Host.CheckNonEmpty(score, nameof(score));
243+
var roles = new RoleMappedData(data, opt: false,
244+
RoleMappedSchema.ColumnRole.Label.Bind(label),
245+
RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, score));
246+
247+
var resultDict = Evaluate(roles);
248+
Host.Assert(resultDict.ContainsKey(MetricKinds.OverallMetrics));
271249
var overall = resultDict[MetricKinds.OverallMetrics];
272250

273251
Result result;
274252
using (var cursor = overall.GetRowCursor(i => true))
275253
{
276254
var moved = cursor.MoveNext();
277-
env.Assert(moved);
278-
result = new Result(env, cursor);
255+
Host.Assert(moved);
256+
result = new Result(Host, cursor);
279257
moved = cursor.MoveNext();
280-
env.Assert(!moved);
258+
Host.Assert(!moved);
281259
}
282260
return result;
283261
}

0 commit comments

Comments
 (0)