Skip to content

Commit 8f8f8ca

Browse files
author
Ivan Matantsev
committed
Lockdown Microsoft.ML.Recommender public surface
1 parent d01bd83 commit 8f8f8ca

File tree

3 files changed

+20
-12
lines changed

3 files changed

+20
-12
lines changed

src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs

+11-11
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ public PredictionKind PredictionKind
6464
get { return PredictionKind.Recommendation; }
6565
}
6666

67-
public ColumnType OutputType { get { return NumberType.Float; } }
67+
private ColumnType OutputType { get { return NumberType.Float; } }
6868

69-
public ColumnType MatrixColumnIndexType { get; }
70-
public ColumnType MatrixRowIndexType { get; }
69+
internal ColumnType MatrixColumnIndexType { get; }
70+
internal ColumnType MatrixRowIndexType { get; }
7171

7272
internal MatrixFactorizationPredictor(IHostEnvironment env, SafeTrainingAndModelBuffer buffer, KeyType matrixColumnIndexType, KeyType matrixRowIndexType)
7373
{
@@ -131,7 +131,7 @@ private MatrixFactorizationPredictor(IHostEnvironment env, ModelLoadContext ctx)
131131
/// <summary>
132132
/// Load model from the given context
133133
/// </summary>
134-
public static MatrixFactorizationPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
134+
private static MatrixFactorizationPredictor Create(IHostEnvironment env, ModelLoadContext ctx)
135135
{
136136
Contracts.CheckValue(env, nameof(env));
137137
env.CheckValue(ctx, nameof(ctx));
@@ -377,11 +377,11 @@ public Row GetRow(Row input, Func<int, bool> active)
377377

378378
public sealed class MatrixFactorizationPredictionTransformer : PredictionTransformerBase<MatrixFactorizationPredictor, GenericScorer>, ICanSaveModel
379379
{
380-
public const string LoaderSignature = "MaFactPredXf";
381-
public string MatrixColumnIndexColumnName { get; }
382-
public string MatrixRowIndexColumnName { get; }
383-
public ColumnType MatrixColumnIndexColumnType { get; }
384-
public ColumnType MatrixRowIndexColumnType { get; }
380+
internal const string LoaderSignature = "MaFactPredXf";
381+
internal string MatrixColumnIndexColumnName { get; }
382+
internal string MatrixRowIndexColumnName { get; }
383+
internal ColumnType MatrixColumnIndexColumnType { get; }
384+
internal ColumnType MatrixRowIndexColumnType { get; }
385385
protected override GenericScorer Scorer { get; set; }
386386

387387
/// <summary>
@@ -396,7 +396,7 @@ public sealed class MatrixFactorizationPredictionTransformer : PredictionTransfo
396396
/// <param name="matrixColumnIndexColumnName">The name of the column used as role <see cref="RecommenderUtils.MatrixColumnIndexKind"/> in matrix factorization world</param>
397397
/// <param name="matrixRowIndexColumnName">The name of the column used as role <see cref="RecommenderUtils.MatrixRowIndexKind"/> in matrix factorization world</param>
398398
/// <param name="scoreColumnNameSuffix">A string attached to the output column name of this transformer</param>
399-
public MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationPredictor model, Schema trainSchema,
399+
internal MatrixFactorizationPredictionTransformer(IHostEnvironment env, MatrixFactorizationPredictor model, Schema trainSchema,
400400
string matrixColumnIndexColumnName, string matrixRowIndexColumnName, string scoreColumnNameSuffix = "")
401401
: base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MatrixFactorizationPredictionTransformer)), model, trainSchema)
402402
{
@@ -433,7 +433,7 @@ private RoleMappedSchema GetSchema()
433433
/// The counter constructor of re-creating <see cref="MatrixFactorizationPredictionTransformer"/> from the context where
434434
/// the original transform is saved.
435435
/// </summary>
436-
public MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
436+
private MatrixFactorizationPredictionTransformer(IHostEnvironment host, ModelLoadContext ctx)
437437
: base(Contracts.CheckRef(host, nameof(host)).Register(nameof(MatrixFactorizationPredictionTransformer)), ctx)
438438
{
439439
// *** Binary format ***

src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ public sealed class Options
197197
private readonly bool _doNmf;
198198

199199
public override PredictionKind PredictionKind => PredictionKind.Recommendation;
200-
public const string LoadNameValue = "MatrixFactorization";
200+
internal const string LoadNameValue = "MatrixFactorization";
201201

202202
/// <summary>
203203
/// The row index, column index, and label columns needed to specify the training matrix. This trainer uses tuples of (row index, column index, label value) to specify a matrix.
@@ -427,6 +427,10 @@ public MatrixFactorizationPredictionTransformer Train(IDataView trainData, IData
427427
/// <param name="input">The training data set.</param>
428428
public MatrixFactorizationPredictionTransformer Fit(IDataView input) => Train(input);
429429

430+
/// <summary>
431+
/// Schema propagation for transformers. Returns the output schema of the data, if
432+
/// the input schema is like the one provided.
433+
/// </summary>
430434
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
431435
{
432436
Host.CheckValue(inputSchema, nameof(inputSchema));

src/Microsoft.ML.StandardLearners/FactorizationMachine/FactorizationMachineTrainer.cs

+4
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,10 @@ public FieldAwareFactorizationMachinePredictionTransformer Train(IDataView train
511511

512512
public FieldAwareFactorizationMachinePredictionTransformer Fit(IDataView input) => Train(input);
513513

514+
/// <summary>
515+
/// Schema propagation for transformers. Returns the output schema of the data, if
516+
/// the input schema is like the one provided.
517+
/// </summary>
514518
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
515519
{
516520

0 commit comments

Comments
 (0)