@@ -18,7 +18,7 @@ namespace Microsoft.ML.Tests.TrainerEstimators
18
18
public partial class TrainerEstimators : TestDataPipeBase
19
19
{
20
20
[ ConditionalFact ( typeof ( Environment ) , nameof ( Environment . Is64BitProcess ) ) ] // This test is being fixed as part of issue #1441.
21
- public void MatrixFactorization_Estimator ( )
21
+ public void MatrixFactorizationEstimator ( )
22
22
{
23
23
string labelColumnName = "Label" ;
24
24
string matrixColumnIndexColumnName = "Col" ;
@@ -39,11 +39,11 @@ public void MatrixFactorization_Estimator()
39
39
MatrixRowIndexColumnName = matrixRowIndexColumnName ,
40
40
LabelColumnName = labelColumnName ,
41
41
NumIterations = 3 ,
42
- NumThreads = 1 ,
43
- K = 4 ,
42
+ NumThreads = 1 ,
43
+ ApproximationRank = 4 ,
44
44
} ;
45
45
46
- var est = new MatrixFactorizationTrainer ( Env , options ) ;
46
+ var est = ML . Recommendation ( ) . Trainers . MatrixFactorization ( options ) ;
47
47
48
48
TestEstimatorCore ( est , data , invalidInput : invalidData ) ;
49
49
@@ -68,20 +68,28 @@ public void MatrixFactorizationSimpleTrainAndPredict()
68
68
var data = reader . Read ( new MultiFileSource ( GetDataPath ( TestDatasets . trivialMatrixFactorization . trainFilename ) ) ) ;
69
69
70
70
// Create a pipeline with a single operator.
71
- var options = new MatrixFactorizationTrainer . Options {
71
+ var options = new MatrixFactorizationTrainer . Options
72
+ {
72
73
MatrixColumnIndexColumnName = userColumnName ,
73
74
MatrixRowIndexColumnName = itemColumnName ,
74
75
LabelColumnName = labelColumnName ,
75
76
NumIterations = 3 ,
76
77
NumThreads = 1 , // To eliminate randomness, # of threads must be 1.
77
- K = 7 ,
78
+ ApproximationRank = 7 ,
78
79
} ;
79
80
80
81
var pipeline = mlContext . Recommendation ( ) . Trainers . MatrixFactorization ( options ) ;
81
82
82
83
// Train a matrix factorization model.
83
84
var model = pipeline . Fit ( data ) ;
84
85
86
+ // Les's validate content of the model.
87
+ Assert . Equal ( model . Model . ApproximationRank , options . ApproximationRank ) ;
88
+ var leftMatrix = model . Model . GetLeftFactorMatrix ( ) ;
89
+ var rightMatrix = model . Model . GetRightFactorMatrix ( ) ;
90
+ Assert . Equal ( leftMatrix . Length , model . Model . NumberOfRows * model . Model . ApproximationRank ) ;
91
+ Assert . Equal ( rightMatrix . Length , model . Model . NumberOfColumns * model . Model . ApproximationRank ) ;
92
+
85
93
// Read the test data set as an IDataView
86
94
var testData = reader . Read ( new MultiFileSource ( GetDataPath ( TestDatasets . trivialMatrixFactorization . testFilename ) ) ) ;
87
95
@@ -197,7 +205,7 @@ public void MatrixFactorizationInMemoryData()
197
205
LabelColumnName = nameof ( MatrixElement . Value ) ,
198
206
NumIterations = 10 ,
199
207
NumThreads = 1 , // To eliminate randomness, # of threads must be 1.
200
- K = 32 ,
208
+ ApproximationRank = 32 ,
201
209
} ;
202
210
203
211
var pipeline = mlContext . Recommendation ( ) . Trainers . MatrixFactorization ( options ) ;
@@ -287,7 +295,7 @@ public void MatrixFactorizationInMemoryDataZeroBaseIndex()
287
295
LabelColumnName = nameof ( MatrixElement . Value ) ,
288
296
NumIterations = 100 ,
289
297
NumThreads = 1 , // To eliminate randomness, # of threads must be 1.
290
- K = 32 ,
298
+ ApproximationRank = 32 ,
291
299
LearningRate = 0.5 ,
292
300
} ;
293
301
@@ -409,7 +417,7 @@ public void OneClassMatrixFactorizationInMemoryDataZeroBaseIndex()
409
417
NumIterations = 100 ,
410
418
NumThreads = 1 , // To eliminate randomness, # of threads must be 1.
411
419
Lambda = 0.025 , // Let's test non-default regularization coefficient.
412
- K = 16 ,
420
+ ApproximationRank = 16 ,
413
421
Alpha = 0.01 , // Importance coefficient of loss function over matrix elements not specified in the input matrix.
414
422
C = 0.15 , // Desired value for matrix elements not specified in the input matrix.
415
423
} ;
0 commit comments