|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
| 5 | +using Microsoft.ML.Core.Data; |
5 | 6 | using Microsoft.ML.Core.Prediction;
|
6 | 7 | using Microsoft.ML.Runtime;
|
7 | 8 | using Microsoft.ML.Runtime.Data;
|
8 | 9 | using Microsoft.ML.Trainers;
|
9 | 10 | using System;
|
| 11 | +using System.Linq; |
10 | 12 |
|
11 | 13 | namespace Microsoft.ML
|
12 | 14 | {
|
13 | 15 | public static class RecommenderCatalog
|
14 | 16 | {
|
| 17 | + |
| 18 | + /// <summary> |
| 19 | + /// Trainers and tasks specific to recommendation problems. |
| 20 | + /// </summary> |
| 21 | + public static RecommendationContext Recommendation(this MLContext ctx) => new RecommendationContext(ctx);
| 22 | + }
| 23 | + |
| 24 | + /// <summary> |
| 25 | + /// The central context for recommendation trainers. |
| 26 | + /// </summary> |
| 27 | + public sealed class RecommendationContext : TrainContextBase |
| 28 | + { |
| 29 | + /// <summary> |
| 30 | + /// The catalog of trainers for recommendation tasks. |
| 31 | + /// </summary> |
| 32 | + public RecommendationTrainers Trainers { get; } |
| 33 | + |
| 34 | + public RecommendationContext(IHostEnvironment env)
| 35 | + : base(env, nameof(RecommendationContext))
| 36 | + {
| 37 | + Trainers = new RecommendationTrainers(this);
| 38 | + }
| 39 | + |
| 40 | + public sealed class RecommendationTrainers : ContextInstantiatorBase
| 41 | + {
| 42 | + internal RecommendationTrainers(RecommendationContext ctx)
| 43 | + : base(ctx)
| 44 | + {
| 45 | + }
| 46 | + |
| 47 | + /// <summary> |
| 48 | + /// Initializes a new instance of <see cref="MatrixFactorizationTrainer"/>. |
| 49 | + /// </summary> |
| 50 | + /// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param> |
| 51 | + /// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param> |
| 52 | + /// <param name="labelColumn">The name of the label column.</param> |
| 53 | + /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> |
| 54 | + /// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param> |
| 55 | + public MatrixFactorizationTrainer MatrixFactorization(
| 56 | + string matrixColumnIndexColumnName,
| 57 | + string matrixRowIndexColumnName,
| 58 | + string labelColumn = DefaultColumnNames.Label,
| 59 | + TrainerEstimatorContext context = null,
| 60 | + Action<MatrixFactorizationTrainer.Arguments> advancedSettings = null)
| 61 | + => new MatrixFactorizationTrainer(Owner.Environment, matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn, context, advancedSettings);
| 62 | + }
| 63 | + |
| 64 | + /// <summary> |
| 65 | + /// Evaluates scored recommendation data using a <see cref="RegressionEvaluator"/>. |
| 66 | + /// </summary> |
| 67 | + /// <param name="data">The scored data.</param> |
| 68 | + /// <param name="label">The name of the label column in <paramref name="data"/>.</param> |
| 69 | + /// <param name="score">The name of the score column in <paramref name="data"/>.</param> |
| 70 | + /// <returns>The regression evaluation results for the scored data.</returns> |
| 71 | + public RegressionEvaluator.Result Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score)
| 72 | + {
| 73 | + Host.CheckValue(data, nameof(data));
| 74 | + Host.CheckNonEmpty(label, nameof(label));
| 75 | + Host.CheckNonEmpty(score, nameof(score));
| 76 | + |
| 77 | + var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { });
| 78 | + return eval.Evaluate(data, label, score);
| 79 | + }
| 80 | + |
15 | 81 | /// <summary>
|
16 |
| - /// Initializing a new instance of <see cref="MatrixFactorizationTrainer"/>. |
| 82 | + /// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>, |
| 83 | + /// and respecting <paramref name="stratificationColumn"/> if provided. |
| 84 | + /// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics. |
17 | 85 | /// </summary>
|
18 |
| - /// <param name="ctx">The <see cref="RegressionContext.RegressionTrainers"/> instance.</param> |
19 |
| - /// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param> |
20 |
| - /// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param> |
21 |
| - /// <param name="labelColumn">The name of the label column.</param> |
22 |
| - /// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param> |
23 |
| - /// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param> |
24 |
| - public static MatrixFactorizationTrainer MatrixFactorization(this RegressionContext.RegressionTrainers ctx, |
25 |
| - string matrixColumnIndexColumnName, |
26 |
| - string matrixRowIndexColumnName, |
27 |
| - string labelColumn = DefaultColumnNames.Label, |
28 |
| - TrainerEstimatorContext context = null, |
29 |
| - Action<MatrixFactorizationTrainer.Arguments> advancedSettings = null) |
| 86 | + /// <param name="data">The data to run cross-validation on.</param> |
| 87 | + /// <param name="estimator">The estimator to fit.</param> |
| 88 | + /// <param name="numFolds">Number of cross-validation folds.</param> |
| 89 | + /// <param name="labelColumn">The label column (for evaluation).</param> |
| 90 | + /// <param name="stratificationColumn">Optional stratification column.</param> |
| 91 | + /// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided), |
| 92 | + /// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from |
| 93 | + /// train to the test set.</remarks> |
| 94 | + /// <returns>Per-fold results: metrics, models, scored datasets.</returns> |
| 95 | + public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
| 96 | + IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
30 | 97 | {
|
31 |
| - Contracts.CheckValue(ctx, nameof(ctx)); |
32 |
| - return new MatrixFactorizationTrainer(CatalogUtils.GetEnvironment(ctx), matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn, context, advancedSettings); |
| 98 | + Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
| 99 | + var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
| 100 | + return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
33 | 101 | }
|
34 | 102 | }
|
35 | 103 | }
|
0 commit comments