Skip to content

Commit 2a92786

Browse files
authored
Fix an initial-value problem caused by unseen row/column (#2525)
* Fix an initial-value problem caused by unseen row/column * Add a test * Also test unseen column
1 parent 70830ed commit 2a92786

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

test/Microsoft.ML.Tests/TrainerEstimators/MatrixFactorizationTests.cs

+72
Original file line numberDiff line numberDiff line change
@@ -527,5 +527,77 @@ public void MatrixFactorizationBackCompat()
527527
// Negative example (i.e., examples can not be found in dataMatrix) is close to 0.15 (specified by s.C = 0.15 in the trainer).
528528
CompareNumbersWithTolerance(0.141411, testResults[1].Score, digitsOfPrecision: 5);
529529
}
530+
531+
[MatrixFactorizationFact]
532+
public void OneClassMatrixFactorizationWithUnseenColumnAndRow()
533+
{
534+
// Create an in-memory matrix as a list of tuples (column index, row index, value). For one-class matrix
535+
// factorization problem, unspecified matrix elements are all a constant provided by user. If that constant is 0.15,
536+
// the following list means a 3-by-2 training matrix with elements:
537+
// (0, 0, 1), (0, 1, 1), (1, 0, 0.15), (1, 1, 0.15), (0, 2, 0.15), (1, 2, 0.15).
538+
// because matrix elements at (1, 0), (1, 1), (0, 2), and (1, 2) are not specified. Below is a visualization of the training matrix.
539+
// [1, ?]
540+
// |1, ?| where ? will be set to 0.15 by user when creating the trainer.
541+
// [?, ?]
542+
// Note that the second column and the third row are called unseen because they contain no training element (i.e., all its values are "?"s).
543+
var dataMatrix = new List<OneClassMatrixElementZeroBased>();
544+
dataMatrix.Add(new OneClassMatrixElementZeroBased() { MatrixColumnIndex = 0, MatrixRowIndex = 0, Value = 1 });
545+
dataMatrix.Add(new OneClassMatrixElementZeroBased() { MatrixColumnIndex = 0, MatrixRowIndex = 1, Value = 1 });
546+
547+
// Convert the in-memory matrix into an IDataView so that ML.NET components can consume it.
548+
var dataView = ML.Data.ReadFromEnumerable(dataMatrix);
549+
550+
// Create a matrix factorization trainer which may consume "Value" as the training label, "MatrixColumnIndex" as the
551+
// matrix's column index, and "MatrixRowIndex" as the matrix's row index.
552+
var mlContext = new MLContext(seed: 1, conc: 1);
553+
554+
var options = new MatrixFactorizationTrainer.Options
555+
{
556+
MatrixColumnIndexColumnName = nameof(MatrixElement.MatrixColumnIndex),
557+
MatrixRowIndexColumnName = nameof(MatrixElement.MatrixRowIndex),
558+
LabelColumnName = nameof(MatrixElement.Value),
559+
LossFunction = MatrixFactorizationTrainer.LossFunctionType.SquareLossOneClass,
560+
NumIterations = 100,
561+
NumThreads = 1, // To eliminate randomness, # of threads must be 1.
562+
Lambda = 0.025, // Let's test non-default regularization coefficient.
563+
ApproximationRank = 16,
564+
Alpha = 0.01, // Importance coefficient of loss function over matrix elements not specified in the input matrix.
565+
C = 0.15, // Desired value for matrix elements not specified in the input matrix.
566+
};
567+
568+
var pipeline = mlContext.Recommendation().Trainers.MatrixFactorization(options);
569+
570+
// Train a matrix factorization model.
571+
var model = pipeline.Fit(dataView);
572+
573+
// Apply the trained model to the training set.
574+
var prediction = model.Transform(dataView);
575+
576+
// Calculate regression matrices for the prediction result.
577+
var metrics = mlContext.Recommendation().Evaluate(prediction, label: "Value", score: "Score");
578+
579+
// Make sure the prediction error is not too large.
580+
Assert.InRange(metrics.L2, 0, 0.0016);
581+
582+
// Create data for testing.
583+
var testDataMatrix = new List<OneClassMatrixElementZeroBasedForScore>();
584+
testDataMatrix.Add(new OneClassMatrixElementZeroBasedForScore() { MatrixColumnIndex = 0, MatrixRowIndex = 0, Value = 0, Score = 0 });
585+
testDataMatrix.Add(new OneClassMatrixElementZeroBasedForScore() { MatrixColumnIndex = 1, MatrixRowIndex = 0, Value = 0, Score = 0 });
586+
testDataMatrix.Add(new OneClassMatrixElementZeroBasedForScore() { MatrixColumnIndex = 1, MatrixRowIndex = 2, Value = 0, Score = 0 });
587+
588+
// Convert the in-memory matrix into an IDataView so that ML.NET components can consume it.
589+
var testDataView = ML.Data.ReadFromEnumerable(testDataMatrix);
590+
591+
// Apply the trained model to the test data.
592+
var testPrediction = model.Transform(testDataView);
593+
594+
var testResults = mlContext.CreateEnumerable<OneClassMatrixElementZeroBasedForScore>(testPrediction, false).ToList();
595+
// Positive example (i.e., examples can be found in dataMatrix) is close to 1.
596+
CompareNumbersWithTolerance(0.9823623, testResults[0].Score, digitsOfPrecision: 5);
597+
// Negative examples' scores (i.e., examples can not be found in dataMatrix) are closer
598+
// to 0.15 (specified by s.C = 0.15 in the trainer) than positive example's score.
599+
CompareNumbersWithTolerance(0.05511549, testResults[1].Score, digitsOfPrecision: 5);
600+
CompareNumbersWithTolerance(0.00316973357, testResults[2].Score, digitsOfPrecision: 5);
601+
}
530602
}
531603
}

0 commit comments

Comments
 (0)