Skip to content

Commit 1c44ea2

Browse files
committed
renaming the NgramCountingEstimator and the NgramExtractingEstimator to swap names.
Adding the Recommendation context.
1 parent 1f4b01d commit 1c44ea2

File tree

23 files changed

+258
-184
lines changed

23 files changed

+258
-184
lines changed

docs/code/MlNetCookBook.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,7 @@ var learningPipeline = reader.MakeNewEstimator()
11321132
BagOfBigrams: r.Message.NormalizeText().ToBagofHashedWords(ngramLength: 2, allLengths: false),
11331133

11341134
// NLP pipeline 3: bag of tri-character sequences with TF-IDF weighting.
1135-
BagOfTrichar: r.Message.TokenizeIntoCharacters().ToNgrams(ngramLength: 3, weighting: NgramCountingEstimator.WeightingCriteria.TfIdf),
1135+
BagOfTrichar: r.Message.TokenizeIntoCharacters().ToNgrams(ngramLength: 3, weighting: NgramExtractingEstimator.WeightingCriteria.TfIdf),
11361136

11371137
// NLP pipeline 4: word embeddings.
11381138
Embeddings: r.Message.NormalizeText().TokenizeText().WordEmbeddings(WordEmbeddingsExtractorTransformer.PretrainedModelKind.GloVeTwitter25D)
@@ -1186,8 +1186,8 @@ var dynamicPipeline =
11861186

11871187
// NLP pipeline 3: bag of tri-character sequences with TF-IDF weighting.
11881188
.Append(mlContext.Transforms.Text.TokenizeCharacters("Message", "MessageChars"))
1189-
.Append(new NgramCountingEstimator(mlContext, "MessageChars", "BagOfTrichar",
1190-
ngramLength: 3, weighting: NgramTokenizingTransformer.WeightingCriteria.TfIdf))
1189+
.Append(new NgramExtractingEstimator(mlContext, "MessageChars", "BagOfTrichar",
1190+
ngramLength: 3, weighting: NgramExtractingEstimator.WeightingCriteria.TfIdf))
11911191

11921192
// NLP pipeline 4: word embeddings.
11931193
.Append(mlContext.Transforms.Text.TokenizeWords("NormalizedMessage", "TokenizedMessage"))

docs/samples/Microsoft.ML.Samples/Dynamic/MatrixFactorization.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ public static void MatrixFactorizationInMemoryData()
6464
// Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the
6565
// matrix's column index, and "MatrixRowIndex" as the matrix's row index. Here nameof(...) is used to extract field
6666
// names in the MatrixElement class.
67-
var pipeline = new MatrixFactorizationTrainer(mlContext,
67+
var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(
6868
nameof(MatrixElement.MatrixColumnIndex),
6969
nameof(MatrixElement.MatrixRowIndex),
7070
nameof(MatrixElement.Value),
@@ -82,7 +82,7 @@ public static void MatrixFactorizationInMemoryData()
8282
var prediction = model.Transform(dataView);
8383

8484
// Calculate regression metrics for the prediction result.
85-
var metrics = mlContext.Regression.Evaluate(prediction,
85+
var metrics = mlContext.Recommendation().Evaluate(prediction,
8686
label: nameof(MatrixElement.Value), score: nameof(MatrixElementForScore.Score));
8787

8888
// Print out some metrics for checking the model's quality.

docs/samples/Microsoft.ML.Samples/Microsoft.ML.Samples.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
<NativeAssemblyReference Include="CpuMathNative" />
1818
<NativeAssemblyReference Include="FastTreeNative" />
19+
<NativeAssemblyReference Include="MatrixFactorizationNative" />
1920

2021
<ProjectReference Include="..\..\..\src\Microsoft.ML.Analyzer\Microsoft.ML.Analyzer.csproj">
2122
<ReferenceOutputAssembly>false</ReferenceOutputAssembly>

docs/samples/Microsoft.ML.Samples/Program.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ internal static class Program
66
{
77
static void Main(string[] args)
88
{
9-
NormalizerExample.Normalizer();
9+
MatrixFactorizationExample.MatrixFactorizationInMemoryData();
1010
}
1111
}
1212
}

src/Microsoft.ML.Data/MLContext.cs

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ public sealed class MLContext : IHostEnvironment
3535
/// Trainers and tasks specific to clustering problems.
3636
/// </summary>
3737
public ClusteringContext Clustering { get; }
38+
3839
/// <summary>
3940
/// Trainers and tasks specific to ranking problems.
4041
/// </summary>

src/Microsoft.ML.Data/TrainContext.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ namespace Microsoft.ML
2121
public abstract class TrainContextBase
2222
{
2323
protected readonly IHost Host;
24+
25+
[BestFriend]
2426
internal IHostEnvironment Environment => Host;
2527

2628
/// <summary>
@@ -162,6 +164,7 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific
162164
/// </summary>
163165
public abstract class ContextInstantiatorBase
164166
{
167+
[BestFriend]
165168
internal TrainContextBase Owner { get; }
166169

167170
protected ContextInstantiatorBase(TrainContextBase ctx)
@@ -498,7 +501,7 @@ public RegressionEvaluator.Result Evaluate(IDataView data, string label = Defaul
498501
}
499502

500503
/// <summary>
501-
/// The central context for regression trainers.
504+
/// The central context for ranking trainers.
502505
/// </summary>
503506
public sealed class RankingContext : TrainContextBase
504507
{

src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,11 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent(this
7171
/// </format>
7272
/// </example>
7373
public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, string inputColumn, string outputColumn = null,
74-
WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind,
75-
float eps = VectorWhiteningTransformer.Defaults.Eps,
76-
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
77-
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
78-
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, kind, eps, maxRows, pcaNum);
74+
WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind,
75+
float eps = VectorWhiteningTransformer.Defaults.Eps,
76+
int maxRows = VectorWhiteningTransformer.Defaults.MaxRows,
77+
int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum)
78+
=> new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, kind, eps, maxRows, pcaNum);
7979

8080
/// <summary>
8181
/// Takes columns filled with a vector of random variables with a known covariance matrix into a set of new variables whose covariance is the identity matrix,

src/Microsoft.ML.Legacy/CSharpApi.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -15421,15 +15421,15 @@ public sealed class Output
1542115421

1542215422
namespace Legacy.Transforms
1542315423
{
15424-
public enum NgramCountingEstimatorWeightingCriteria
15424+
public enum NgramExtractingEstimatorWeightingCriteria
1542515425
{
1542615426
Tf = 0,
1542715427
Idf = 1,
1542815428
TfIdf = 2
1542915429
}
1543015430

1543115431

15432-
public sealed partial class NgramCountingTransformerColumn : OneToOneColumn<NgramCountingTransformerColumn>, IOneToOneColumn
15432+
public sealed partial class NgramExtractingTransformerColumn : OneToOneColumn<NgramExtractingTransformerColumn>, IOneToOneColumn
1543315433
{
1543415434
/// <summary>
1543515435
/// Maximum ngram length
@@ -15454,7 +15454,7 @@ public sealed partial class NgramCountingTransformerColumn : OneToOneColumn<Ngra
1545415454
/// <summary>
1545515455
/// Statistical measure used to evaluate how important a word is to a document in a corpus
1545615456
/// </summary>
15457-
public NgramCountingEstimatorWeightingCriteria? Weighting { get; set; }
15457+
public NgramExtractingEstimatorWeightingCriteria? Weighting { get; set; }
1545815458

1545915459
/// <summary>
1546015460
/// Name of the new column
@@ -15500,23 +15500,23 @@ public NGramTranslator(params (string inputColumn, string outputColumn)[] inputO
1550015500

1550115501
public void AddColumn(string inputColumn)
1550215502
{
15503-
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.NgramCountingTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.NgramCountingTransformerColumn>(Column);
15504-
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.NgramCountingTransformerColumn>.Create(inputColumn));
15503+
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.NgramExtractingTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.NgramExtractingTransformerColumn>(Column);
15504+
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.NgramExtractingTransformerColumn>.Create(inputColumn));
1550515505
Column = list.ToArray();
1550615506
}
1550715507

1550815508
public void AddColumn(string outputColumn, string inputColumn)
1550915509
{
15510-
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.NgramCountingTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.NgramCountingTransformerColumn>(Column);
15511-
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.NgramCountingTransformerColumn>.Create(outputColumn, inputColumn));
15510+
var list = Column == null ? new List<Microsoft.ML.Legacy.Transforms.NgramExtractingTransformerColumn>() : new List<Microsoft.ML.Legacy.Transforms.NgramExtractingTransformerColumn>(Column);
15511+
list.Add(OneToOneColumn<Microsoft.ML.Legacy.Transforms.NgramExtractingTransformerColumn>.Create(outputColumn, inputColumn));
1551215512
Column = list.ToArray();
1551315513
}
1551415514

1551515515

1551615516
/// <summary>
1551715517
/// New column definition(s) (optional form: name:src)
1551815518
/// </summary>
15519-
public NgramCountingTransformerColumn[] Column { get; set; }
15519+
public NgramExtractingTransformerColumn[] Column { get; set; }
1552015520

1552115521
/// <summary>
1552215522
/// Maximum ngram length
@@ -15541,7 +15541,7 @@ public void AddColumn(string outputColumn, string inputColumn)
1554115541
/// <summary>
1554215542
/// The weighting criteria
1554315543
/// </summary>
15544-
public NgramCountingEstimatorWeightingCriteria Weighting { get; set; } = NgramCountingEstimatorWeightingCriteria.Tf;
15544+
public NgramExtractingEstimatorWeightingCriteria Weighting { get; set; } = NgramExtractingEstimatorWeightingCriteria.Tf;
1554515545

1554615546
/// <summary>
1554715547
/// Input dataset
@@ -20138,7 +20138,7 @@ public sealed class NGramNgramExtractor : NgramExtractor
2013820138
/// <summary>
2013920139
/// The weighting criteria
2014020140
/// </summary>
20141-
public Microsoft.ML.Legacy.Transforms.NgramCountingEstimatorWeightingCriteria Weighting { get; set; } = Microsoft.ML.Legacy.Transforms.NgramCountingEstimatorWeightingCriteria.Tf;
20141+
public Microsoft.ML.Legacy.Transforms.NgramExtractingEstimatorWeightingCriteria Weighting { get; set; } = Microsoft.ML.Legacy.Transforms.NgramExtractingEstimatorWeightingCriteria.Tf;
2014220142

2014320143
internal override string ComponentName => "NGram";
2014420144
}

src/Microsoft.ML.Recommender/RecommenderCatalog.cs

+83-15
Original file line numberDiff line numberDiff line change
@@ -2,34 +2,102 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Core.Data;
56
using Microsoft.ML.Core.Prediction;
67
using Microsoft.ML.Runtime;
78
using Microsoft.ML.Runtime.Data;
89
using Microsoft.ML.Trainers;
910
using System;
11+
using System.Linq;
1012

1113
namespace Microsoft.ML
1214
{
1315
public static class RecommenderCatalog
1416
{
17+
18+
/// <summary>
19+
/// Trainers and tasks specific to recommendation problems.
20+
/// </summary>
21+
public static RecommendationContext Recommendation(this MLContext ctx) => new RecommendationContext(ctx);
22+
}
23+
24+
/// <summary>
25+
/// The central context for recommendation trainers.
26+
/// </summary>
27+
public sealed class RecommendationContext : TrainContextBase
28+
{
29+
/// <summary>
30+
/// The trainers available for recommendation problems.
31+
/// </summary>
32+
public RecommendationTrainers Trainers { get; }
33+
34+
public RecommendationContext(IHostEnvironment env)
35+
: base(env, nameof(RecommendationContext))
36+
{
37+
Trainers = new RecommendationTrainers(this);
38+
}
39+
40+
public sealed class RecommendationTrainers : ContextInstantiatorBase
41+
{
42+
internal RecommendationTrainers(RecommendationContext ctx)
43+
: base(ctx)
44+
{
45+
}
46+
47+
/// <summary>
48+
/// Initializing a new instance of <see cref="MatrixFactorizationTrainer"/>.
49+
/// </summary>
50+
/// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param>
51+
/// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param>
52+
/// <param name="labelColumn">The name of the label column.</param>
53+
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
54+
/// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
55+
public MatrixFactorizationTrainer MatrixFactorization(
56+
string matrixColumnIndexColumnName,
57+
string matrixRowIndexColumnName,
58+
string labelColumn = DefaultColumnNames.Label,
59+
TrainerEstimatorContext context = null,
60+
Action<MatrixFactorizationTrainer.Arguments> advancedSettings = null)
61+
=> new MatrixFactorizationTrainer(Owner.Environment, matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn, context, advancedSettings);
62+
}
63+
64+
/// <summary>
65+
/// Evaluates scored recommendation data.
66+
/// </summary>
67+
/// <param name="data">The scored data.</param>
68+
/// <param name="label">The name of the label column in <paramref name="data"/>.</param>
69+
/// <param name="score">The name of the score column in <paramref name="data"/>.</param>
70+
/// <returns>The regression evaluation results for the scored data.</returns>
71+
public RegressionEvaluator.Result Evaluate(IDataView data, string label = DefaultColumnNames.Label, string score = DefaultColumnNames.Score)
72+
{
73+
Host.CheckValue(data, nameof(data));
74+
Host.CheckNonEmpty(label, nameof(label));
75+
Host.CheckNonEmpty(score, nameof(score));
76+
77+
var eval = new RegressionEvaluator(Host, new RegressionEvaluator.Arguments() { });
78+
return eval.Evaluate(data, label, score);
79+
}
80+
1581
/// <summary>
16-
/// Initializing a new instance of <see cref="MatrixFactorizationTrainer"/>.
82+
/// Run cross-validation over <paramref name="numFolds"/> folds of <paramref name="data"/>, by fitting <paramref name="estimator"/>,
83+
/// and respecting <paramref name="stratificationColumn"/> if provided.
84+
/// Then evaluate each sub-model against <paramref name="labelColumn"/> and return metrics.
1785
/// </summary>
18-
/// <param name="ctx">The <see cref="RegressionContext.RegressionTrainers"/> instance.</param>
19-
/// <param name="matrixColumnIndexColumnName">The name of the column hosting the matrix's column IDs.</param>
20-
/// <param name="matrixRowIndexColumnName">The name of the column hosting the matrix's row IDs.</param>
21-
/// <param name="labelColumn">The name of the label column.</param>
22-
/// <param name="advancedSettings">A delegate to apply all the advanced arguments to the algorithm.</param>
23-
/// <param name="context">The <see cref="TrainerEstimatorContext"/> for additional input data to training.</param>
24-
public static MatrixFactorizationTrainer MatrixFactorization(this RegressionContext.RegressionTrainers ctx,
25-
string matrixColumnIndexColumnName,
26-
string matrixRowIndexColumnName,
27-
string labelColumn = DefaultColumnNames.Label,
28-
TrainerEstimatorContext context = null,
29-
Action<MatrixFactorizationTrainer.Arguments> advancedSettings = null)
86+
/// <param name="data">The data to run cross-validation on.</param>
87+
/// <param name="estimator">The estimator to fit.</param>
88+
/// <param name="numFolds">Number of cross-validation folds.</param>
89+
/// <param name="labelColumn">The label column (for evaluation).</param>
90+
/// <param name="stratificationColumn">Optional stratification column.</param>
91+
/// <remarks>If two examples share the same value of the <paramref name="stratificationColumn"/> (if provided),
92+
/// they are guaranteed to appear in the same subset (train or test). Use this to make sure there is no label leakage from
93+
/// train to the test set.</remarks>
94+
/// <returns>Per-fold results: metrics, models, scored datasets.</returns>
95+
public (RegressionEvaluator.Result metrics, ITransformer model, IDataView scoredTestData)[] CrossValidate(
96+
IDataView data, IEstimator<ITransformer> estimator, int numFolds = 5, string labelColumn = DefaultColumnNames.Label, string stratificationColumn = null)
3097
{
31-
Contracts.CheckValue(ctx, nameof(ctx));
32-
return new MatrixFactorizationTrainer(CatalogUtils.GetEnvironment(ctx), matrixColumnIndexColumnName, matrixRowIndexColumnName, labelColumn, context, advancedSettings);
98+
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
99+
var result = CrossValidateTrain(data, estimator, numFolds, stratificationColumn);
100+
return result.Select(x => (Evaluate(x.scoredTestSet, labelColumn), x.model, x.scoredTestSet)).ToArray();
33101
}
34102
}
35103
}

src/Microsoft.ML.Transforms/EntryPoints/TextAnalytics.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ public static CommonOutputs.TransformOutput DelimitedTokenizeTransform(IHostEnvi
5252
}
5353

5454
[TlcModule.EntryPoint(Name = "Transforms.NGramTranslator",
55-
Desc = NgramCountingTransformer.Summary,
56-
UserName = NgramCountingTransformer.UserName,
57-
ShortName = NgramCountingTransformer.LoaderSignature,
55+
Desc = NgramExtractingTransformer.Summary,
56+
UserName = NgramExtractingTransformer.UserName,
57+
ShortName = NgramExtractingTransformer.LoaderSignature,
5858
XmlInclude = new[] { @"<include file='../Microsoft.ML.Transforms/Text/doc.xml' path='doc/members/member[@name=""NgramTranslator""]/*' />" })]
59-
public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, NgramCountingTransformer.Arguments input)
59+
public static CommonOutputs.TransformOutput NGramTransform(IHostEnvironment env, NgramExtractingTransformer.Arguments input)
6060
{
6161
var h = EntryPointUtils.CheckArgsAndCreateHost(env, "NGramTransform", input);
62-
var xf = NgramCountingTransformer.Create(h, input, input.Data);
62+
var xf = NgramExtractingTransformer.Create(h, input, input.Data);
6363
return new CommonOutputs.TransformOutput()
6464
{
6565
Model = new TransformModel(h, xf, input.Data),

0 commit comments

Comments
 (0)