Skip to content

Commit 73762a8

Browse files
authored
Fix zero based key input from C# classes for matrix factorization (#1507)
This PR makes 0-based key types load properly by modifying the getter for C# classes. Previously, the framework did not map a raw key value to its actual ordinal number; for example, 0/1/2 in a 0-based key should be mapped to 1/2/3 in ML.NET's type system, whereas the previous system mapped 0/1/2 to just 0/1/2 when they were read from C# classes. Note that out-of-range values should be mapped to 0. In addition, the matrix factorization module is improved for better code quality.
1 parent f920262 commit 73762a8

File tree

4 files changed

+212
-48
lines changed

4 files changed

+212
-48
lines changed

src/Microsoft.ML.Api/DataViewConstructionUtils.cs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,16 @@ private Delegate CreateGetter(ColumnType colType, InternalSchemaDefinition.Colum
207207
Host.Assert(colType.RawType == Nullable.GetUnderlyingType(outputType));
208208
else
209209
Host.Assert(colType.RawType == outputType);
210-
del = CreateDirectGetterDelegate<int>;
210+
211+
if (!colType.IsKey)
212+
del = CreateDirectGetterDelegate<int>;
213+
else
214+
{
215+
var keyRawType = colType.RawType;
216+
Host.Assert(colType.AsKey.Contiguous);
217+
Func<Delegate, ColumnType, Delegate> delForKey = CreateKeyGetterDelegate<uint>;
218+
return Utils.MarshalInvoke(delForKey, keyRawType, peek, colType);
219+
}
211220
}
212221
else
213222
{
@@ -288,6 +297,38 @@ private Delegate CreateDirectGetterDelegate<TDst>(Delegate peekDel)
288297
peek(GetCurrentRowObject(), Position, ref dst));
289298
}
290299

300+
/// <summary>
/// Creates a getter for a key-typed column whose raw value is read from the current row object.
/// A raw value inside [Min, Min + Count - 1] is mapped to (raw - Min + 1); any other raw value
/// is mapped to 0.
/// </summary>
/// <typeparam name="TDst">
/// Raw type of the key column. Must be <see cref="ulong"/>, <see cref="uint"/>, <see cref="ushort"/>,
/// <see cref="byte"/> or <see cref="bool"/> (bool is retained from the original check).
/// </typeparam>
/// <param name="peekDel">Delegate that must be a <c>Peek&lt;TRow, TDst&gt;</c>; it fetches the raw value.</param>
/// <param name="colType">Type of the column; must be a contiguous key type.</param>
/// <returns>A <c>ValueGetter&lt;TDst&gt;</c> that yields the mapped key value.</returns>
private Delegate CreateKeyGetterDelegate<TDst>(Delegate peekDel, ColumnType colType)
{
    // Make sure the function is dealing with key.
    Host.Check(colType.IsKey);
    // Following equations work only with contiguous key type.
    Host.Check(colType.AsKey.Contiguous);
    // Following equations work only with unsigned integers. The caller dispatches on the column's
    // raw type, so ushort has to be admitted alongside the other unsigned widths.
    Host.Check(typeof(TDst) == typeof(ulong) || typeof(TDst) == typeof(uint) ||
        typeof(TDst) == typeof(ushort) || typeof(TDst) == typeof(byte) ||
        typeof(TDst) == typeof(bool));

    // Convert delegate function to a function which can fetch the underlying value.
    var peek = peekDel as Peek<TRow, TDst>;
    Host.AssertValue(peek);

    // Inclusive range of valid raw key values.
    // NOTE(review): if AsKey.Count can be 0, "min + Count - 1" wraps around in unchecked
    // arithmetic -- confirm the intended handling of such key types.
    ulong min = colType.AsKey.Min;
    ulong max = min + (ulong)colType.AsKey.Count - 1;

    ValueGetter<TDst> getter = (ref TDst dst) =>
    {
        // Scratch values are local to each invocation so that no state is shared between calls.
        TDst rawKeyValue = default;
        peek(GetCurrentRowObject(), Position, ref rawKeyValue);

        // Work in ulong so the same arithmetic serves every supported raw type.
        ulong key = (ulong)Convert.ChangeType(rawKeyValue, typeof(ulong));
        ulong result = (min <= key && key <= max) ? key - min + 1 : 0;
        dst = (TDst)Convert.ChangeType(result, typeof(TDst));
    };
    return getter;
}
331+
291332
protected abstract TRow GetCurrentRowObject();
292333

293334
public bool IsColumnActive(int col)

src/Microsoft.ML.Recommender/MatrixFactorizationPredictor.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
using Microsoft.ML.Runtime.Model;
1414
using Microsoft.ML.Runtime.Recommender;
1515
using Microsoft.ML.Runtime.Recommender.Internal;
16-
using Microsoft.ML.Trainers;
1716
using Microsoft.ML.Trainers.Recommender;
1817

1918
[assembly: LoadableClass(typeof(MatrixFactorizationPredictor), null, typeof(SignatureLoadModel), "Matrix Factorization Predictor Executor", MatrixFactorizationPredictor.LoaderSignature)]
@@ -347,9 +346,12 @@ private Delegate[] CreateGetter(IRow input, bool[] active)
347346
var getters = new Delegate[1];
348347
if (active[0])
349348
{
349+
// First check if expected columns are ok and then create getters to access those columns' values.
350350
CheckInputSchema(input.Schema, _matrixColumnIndexColumnIndex, _matrixRowIndexCololumnIndex);
351-
var matrixColumnIndexGetter = input.GetGetter<uint>(_matrixColumnIndexColumnIndex);
352-
var matrixRowIndexGetter = input.GetGetter<uint>(_matrixRowIndexCololumnIndex);
351+
var matrixColumnIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, input, _matrixColumnIndexColumnIndex);
352+
var matrixRowIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, input, _matrixRowIndexCololumnIndex);
353+
354+
// Assign the getter of the prediction score. It maps a pair of matrix column index and matrix row index to a scalar.
353355
getters[0] = _parent.GetGetter(matrixColumnIndexGetter, matrixRowIndexGetter);
354356
}
355357
return getters;

src/Microsoft.ML.Recommender/MatrixFactorizationTrainer.cs

Lines changed: 63 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -89,44 +89,50 @@ public sealed class MatrixFactorizationTrainer : TrainerBase<MatrixFactorization
8989
{
9090
public sealed class Arguments
9191
{
92-
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularization parameter")]
92+
[Argument(ArgumentType.AtMostOnce, HelpText = "Regularization parameter. " +
93+
"It's the weight of factor matrices' norms in the objective function minimized by matrix factorization's algorithm. " +
94+
"A small value could cause over-fitting.")]
9395
[TGUI(SuggestedSweeps = "0.01,0.05,0.1,0.5,1")]
9496
[TlcModule.SweepableDiscreteParam("Lambda", new object[] { 0.01f, 0.05f, 0.1f, 0.5f, 1f })]
95-
public Double Lambda = 0.1;
97+
public double Lambda = 0.1;
9698

97-
[Argument(ArgumentType.AtMostOnce, HelpText = "Latent space dimension")]
99+
[Argument(ArgumentType.AtMostOnce, HelpText = "Latent space dimension (denoted by k). If the factorized matrix is m-by-n, " +
100+
"two factor matrices found by matrix factorization are m-by-k and k-by-n, respectively. " +
101+
"This value is also known as the rank of matrix factorization because k is generally much smaller than m and n.")]
98102
[TGUI(SuggestedSweeps = "8,16,64,128")]
99103
[TlcModule.SweepableDiscreteParam("K", new object[] { 8, 16, 64, 128 })]
100104
public int K = 8;
101105

102-
[Argument(ArgumentType.AtMostOnce, HelpText = "Training iterations", ShortName = "iter")]
106+
[Argument(ArgumentType.AtMostOnce, HelpText = "Training iterations; that is, the times that the training algorithm iterates through the whole training data once.", ShortName = "iter")]
103107
[TGUI(SuggestedSweeps = "10,20,40")]
104108
[TlcModule.SweepableDiscreteParam("NumIterations", new object[] { 10, 20, 40 })]
105109
public int NumIterations = 20;
106110

107-
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate")]
111+
[Argument(ArgumentType.AtMostOnce, HelpText = "Initial learning rate. It specifies the speed of the training algorithm. " +
112+
"Small value may increase the number of iterations needed to achieve a reasonable result. Large value may lead to numerical difficulty such as a infinity value.")]
108113
[TGUI(SuggestedSweeps = "0.001,0.01,0.1")]
109114
[TlcModule.SweepableDiscreteParam("Eta", new object[] { 0.001f, 0.01f, 0.1f })]
110-
public Double Eta = 0.1;
115+
public double Eta = 0.1;
111116

112-
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads", ShortName = "t")]
117+
[Argument(ArgumentType.AtMostOnce, HelpText = "Number of threads can be used in the training procedure.", ShortName = "t")]
113118
public int? NumThreads;
114119

115-
[Argument(ArgumentType.AtMostOnce, HelpText = "Suppress writing additional information to output")]
120+
[Argument(ArgumentType.AtMostOnce, HelpText = "Suppress writing additional information to output.")]
116121
public bool Quiet;
117122

118-
[Argument(ArgumentType.AtMostOnce, HelpText = "Force the matrix factorization P and Q to be non-negative", ShortName = "nn")]
123+
[Argument(ArgumentType.AtMostOnce, HelpText = "Force the factor matrices to be non-negative.", ShortName = "nn")]
119124
public bool NonNegative;
120125
};
121126

122127
internal const string Summary = "From pairs of row/column indices and a value of a matrix, this trains a predictor capable of filling in unknown entries of the matrix, "
123-
+ "utilizing a low-rank matrix factorization. This technique is often used in recommender system, where the row and column indices indicate users and items, "
124-
+ "and the value of the matrix is some rating. ";
128+
+ "using a low-rank matrix factorization. This technique is often used in recommender system, where the row and column indices indicate users and items, "
129+
+ "and the values of the matrix are ratings. ";
125130

126-
private readonly Double _lambda;
131+
// LIBMF's parameter
132+
private readonly double _lambda;
127133
private readonly int _k;
128134
private readonly int _iter;
129-
private readonly Double _eta;
135+
private readonly double _eta;
130136
private readonly int _threads;
131137
private readonly bool _quiet;
132138
private readonly bool _doNmf;
@@ -135,16 +141,28 @@ public sealed class Arguments
135141
public const string LoadNameValue = "MatrixFactorization";
136142

137143
/// <summary>
138-
/// The row, column, and label columns that the trainer expects. This module uses tuples of (row index, column index, label value) to specify a matrix.
144+
/// The row index, column index, and label columns needed to specify the training matrix. This trainer uses tuples of (row index, column index, label value) to specify a matrix.
139145
/// For example, a 2-by-2 matrix
140146
/// [9, 4]
141147
/// [8, 7]
142148
/// can be encoded as tuples (0, 0, 9), (0, 1, 4), (1, 0, 8), and (1, 1, 7). It means that the row/column/label column contains [0, 0, 1, 1]/
143149
/// [0, 1, 0, 1]/[9, 4, 8, 7].
144150
/// </summary>
145-
public readonly SchemaShape.Column MatrixColumnIndexColumn; // column indices of the training matrix
146-
public readonly SchemaShape.Column MatrixRowIndexColumn; // row indices of the training matrix
147-
public readonly SchemaShape.Column LabelColumn;
151+
152+
/// <summary>
153+
/// The name of variable (i.e., column in a <see cref="IDataView"/> type system) used as matrix's column index.
154+
/// </summary>
155+
public readonly string MatrixColumnIndexName;
156+
157+
/// <summary>
158+
/// The name of variable (i.e., column in a <see cref="IDataView"/> type system) used as matrix's row index.
159+
/// </summary>
160+
public readonly string MatrixRowIndexName;
161+
162+
/// <summary>
163+
/// The name of variable (i.e., column in a <see cref="IDataView"/> type system) used as matrix's element value.
164+
/// </summary>
165+
public readonly string LabelName;
148166

149167
/// <summary>
150168
/// The <see cref="TrainerInfo"/> contains general parameters for this trainer.
@@ -155,7 +173,7 @@ public sealed class Arguments
155173
/// Extra information the trainer can use. For example, its validation set (if not null) can be use to evaluate the
156174
/// training progress made at each training iteration.
157175
/// </summary>
158-
public readonly TrainerEstimatorContext Context;
176+
private readonly TrainerEstimatorContext _context;
159177

160178
/// <summary>
161179
/// Legacy constructor initializing a new instance of <see cref="MatrixFactorizationTrainer"/> through the legacy
@@ -209,11 +227,11 @@ public MatrixFactorizationTrainer(IHostEnvironment env, string labelColumn, stri
209227
_doNmf = args.NonNegative;
210228

211229
Info = new TrainerInfo(normalization: false, caching: false);
212-
Context = context;
230+
_context = context;
213231

214-
LabelColumn = new SchemaShape.Column(labelColumn, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
215-
MatrixColumnIndexColumn = new SchemaShape.Column(matrixColumnIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
216-
MatrixRowIndexColumn = new SchemaShape.Column(matrixRowIndexColumnName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
232+
LabelName = labelColumn;
233+
MatrixColumnIndexName = matrixColumnIndexColumnName;
234+
MatrixRowIndexName = matrixRowIndexColumnName;
217235
}
218236

219237
/// <summary>
@@ -270,22 +288,21 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
270288
int rowCount = matrixRowIndexColInfo.Type.KeyCount;
271289
ch.Assert(rowCount > 0);
272290
ch.Assert(colCount > 0);
273-
// Checks for equality on the validation set ensure it is correct here.
274291

292+
// Checks for equality on the validation set ensure it is correct here.
275293
using (var cursor = data.Data.GetRowCursor(c => c == matrixColumnIndexColInfo.Index || c == matrixRowIndexColInfo.Index || c == data.Schema.Label.Index))
276294
{
277295
// LibMF works only over single precision floats, but we want to be able to consume either.
278-
ValueGetter<Single> labGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, cursor, data.Schema.Label.Index);
279-
var matrixColumnIndexGetter = cursor.GetGetter<uint>(matrixColumnIndexColInfo.Index);
280-
var matrixRowIndexGetter = cursor.GetGetter<uint>(matrixRowIndexColInfo.Index);
296+
var labGetter = RowCursorUtils.GetGetterAs<float>(NumberType.R4, cursor, data.Schema.Label.Index);
297+
var matrixColumnIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, cursor, matrixColumnIndexColInfo.Index);
298+
var matrixRowIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, cursor, matrixRowIndexColInfo.Index);
281299

282300
if (validData == null)
283301
{
284302
// Have the trainer do its work.
285303
using (var buffer = PrepareBuffer())
286304
{
287-
buffer.Train(ch, rowCount, colCount,
288-
cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
305+
buffer.Train(ch, rowCount, colCount, cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter);
289306
predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
290307
}
291308
}
@@ -294,16 +311,16 @@ private MatrixFactorizationPredictor TrainCore(IChannel ch, RoleMappedData data,
294311
using (var validCursor = validData.Data.GetRowCursor(
295312
c => c == validMatrixColumnIndexColInfo.Index || c == validMatrixRowIndexColInfo.Index || c == validData.Schema.Label.Index))
296313
{
297-
ValueGetter<Single> validLabGetter = RowCursorUtils.GetGetterAs<Single>(NumberType.R4, validCursor, validData.Schema.Label.Index);
298-
var validXGetter = validCursor.GetGetter<uint>(validMatrixColumnIndexColInfo.Index);
299-
var validYGetter = validCursor.GetGetter<uint>(validMatrixRowIndexColInfo.Index);
314+
ValueGetter<float> validLabelGetter = RowCursorUtils.GetGetterAs<float>(NumberType.R4, validCursor, validData.Schema.Label.Index);
315+
var validMatrixColumnIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, validCursor, validMatrixColumnIndexColInfo.Index);
316+
var validMatrixRowIndexGetter = RowCursorUtils.GetGetterAs<uint>(NumberType.U4, validCursor, validMatrixRowIndexColInfo.Index);
300317

301318
// Have the trainer do its work.
302319
using (var buffer = PrepareBuffer())
303320
{
304321
buffer.TrainWithValidation(ch, rowCount, colCount,
305322
cursor, labGetter, matrixRowIndexGetter, matrixColumnIndexGetter,
306-
validCursor, validLabGetter, validYGetter, validXGetter);
323+
validCursor, validLabelGetter, validMatrixRowIndexGetter, validMatrixColumnIndexGetter);
307324
predictor = new MatrixFactorizationPredictor(Host, buffer, matrixColumnIndexColInfo.Type.AsKey, matrixRowIndexColInfo.Type.AsKey);
308325
}
309326
}
@@ -328,20 +345,20 @@ public MatrixFactorizationPredictionTransformer Fit(IDataView input)
328345
MatrixFactorizationPredictor model = null;
329346

330347
var roles = new List<KeyValuePair<RoleMappedSchema.ColumnRole, string>>();
331-
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, LabelColumn.Name));
332-
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexColumn.Name));
333-
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexColumn.Name));
348+
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RoleMappedSchema.ColumnRole.Label, LabelName));
349+
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixColumnIndexKind.Value, MatrixColumnIndexName));
350+
roles.Add(new KeyValuePair<RoleMappedSchema.ColumnRole, string>(RecommenderUtils.MatrixRowIndexKind.Value, MatrixRowIndexName));
334351

335352
var trainingData = new RoleMappedData(input, roles);
336-
var validData = Context == null ? null : new RoleMappedData(Context.ValidationSet, roles);
353+
var validData = _context == null ? null : new RoleMappedData(_context.ValidationSet, roles);
337354

338355
using (var ch = Host.Start("Training"))
339356
using (var pch = Host.StartProgressChannel("Training"))
340357
{
341358
model = TrainCore(ch, trainingData, validData);
342359
}
343360

344-
return new MatrixFactorizationPredictionTransformer(Host, model, input.Schema, MatrixColumnIndexColumn.Name, MatrixRowIndexColumn.Name);
361+
return new MatrixFactorizationPredictionTransformer(Host, model, input.Schema, MatrixColumnIndexName, MatrixRowIndexName);
345362
}
346363

347364
public SchemaShape GetOutputSchema(SchemaShape inputSchema)
@@ -357,13 +374,15 @@ void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string expectedColu
357374
throw Host.Except($"{expectedColumnName} column '{cachedColumn.Name}' is not compatible");
358375
}
359376

360-
// In prediction phase, no label column is expected.
361-
if (LabelColumn != null)
362-
CheckColumnsCompatible(LabelColumn, LabelColumn.Name);
377+
// Check if label column is good.
378+
var labelColumn = new SchemaShape.Column(LabelName, SchemaShape.Column.VectorKind.Scalar, NumberType.R4, false);
379+
CheckColumnsCompatible(labelColumn, LabelName);
363380

364-
// In both of training and prediction phases, we need columns of user ID and column ID.
365-
CheckColumnsCompatible(MatrixColumnIndexColumn, MatrixColumnIndexColumn.Name);
366-
CheckColumnsCompatible(MatrixRowIndexColumn, MatrixRowIndexColumn.Name);
381+
// Check if columns of matrix's row and column indexes are good. Note that column of IDataView and column of matrix are two different things.
382+
var matrixColumnIndexColumn = new SchemaShape.Column(MatrixColumnIndexName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
383+
var matrixRowIndexColumn = new SchemaShape.Column(MatrixRowIndexName, SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true);
384+
CheckColumnsCompatible(matrixColumnIndexColumn, MatrixColumnIndexName);
385+
CheckColumnsCompatible(matrixRowIndexColumn, MatrixRowIndexName);
367386

368387
// Input columns just pass through so that output column dictionary contains all input columns.
369388
var outColumns = inputSchema.Columns.ToDictionary(x => x.Name);
@@ -377,7 +396,7 @@ void CheckColumnsCompatible(SchemaShape.Column cachedColumn, string expectedColu
377396

378397
private SchemaShape.Column[] GetOutputColumnsCore(SchemaShape inputSchema)
379398
{
380-
bool success = inputSchema.TryFindColumn(LabelColumn.Name, out var labelCol);
399+
bool success = inputSchema.TryFindColumn(LabelName, out var labelCol);
381400
Contracts.Assert(success);
382401

383402
return new[]

0 commit comments

Comments
 (0)