Skip to content

Commit 9671001

Browse files
author
Rogan Carr
committed
Rebasing and fixing to reflect changes in master.
1 parent b48c1fb commit 9671001

File tree

1 file changed

+8
-20
lines changed

1 file changed

+8
-20
lines changed

test/Microsoft.ML.Functional.Tests/Training.cs

+8-20
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ public void CompareTrainerEvaluations()
3232
// Get the dataset.
3333
var data = mlContext.Data.LoadFromTextFile<TweetSentiment>(GetDataPath(TestDatasets.Sentiment.trainFilename),
3434
separatorChar: TestDatasets.Sentiment.fileSeparator,
35-
hasHeader: TestDatasets.Sentiment.fileHasHeader,
35+
hasHeader: TestDatasets.Sentiment.fileHasHeader,
3636
allowQuoting: TestDatasets.Sentiment.allowQuoting);
37-
var trainTestSplit = mlContext.BinaryClassification.TrainTestSplit(data);
37+
var trainTestSplit = mlContext.Data.TrainTestSplit(data);
3838
var trainData = trainTestSplit.TrainSet;
3939
var testData = trainTestSplit.TestSet;
4040

@@ -266,6 +266,7 @@ public void ContinueTrainingLogisticRegressionMulticlass()
266266

267267
// Create a training pipeline.
268268
var featurizationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
269+
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
269270
.AppendCacheCheckpoint(mlContext);
270271

271272
var trainer = mlContext.MulticlassClassification.Trainers.LogisticRegression(
@@ -467,8 +468,7 @@ public void MetacomponentsFunctionAsExpectedOva()
467468
var binaryClassificationPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
468469
.AppendCacheCheckpoint(mlContext)
469470
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
470-
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryclassificationTrainer))
471-
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
471+
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryclassificationTrainer));
472472

473473
// Fit the binary classification pipeline.
474474
var binaryClassificationModel = binaryClassificationPipeline.Fit(data);
@@ -503,40 +503,28 @@ public void MetacomponentsFunctionAsExpectedOva()
503503
// Create a model training an OVA trainer with a ranking trainer.
504504
var rankingTrainer = mlContext.Ranking.Trainers.FastTree(
505505
new FastTreeRankingTrainer.Options { NumberOfTrees = 2, NumberOfThreads = 1, });
506+
// Todo #2920: Make this fail somehow.
506507
var rankingPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
507508
.AppendCacheCheckpoint(mlContext)
508509
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
509510
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(rankingTrainer))
510511
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
511512

512513
// Fit the invalid pipeline.
513-
// Todo #2920: Make this fail somehow.
514-
var rankingModel = rankingPipeline.Fit(data);
515-
516-
// Transform the data
517-
var rankingPredictions = rankingModel.Transform(data);
518-
519-
// Evaluate the model.
520-
var rankingMetrics = mlContext.MulticlassClassification.Evaluate(rankingPredictions);
514+
Assert.Throws<ArgumentOutOfRangeException>(() => rankingPipeline.Fit(data));
521515

522516
// Create a model training an OVA trainer with a regressor.
523517
var regressionTrainer = mlContext.Regression.Trainers.PoissonRegression(
524518
new PoissonRegression.Options { NumberOfIterations = 10, NumberOfThreads = 1, });
519+
// Todo #2920: Make this fail somehow.
525520
var regressionPipeline = mlContext.Transforms.Concatenate("Features", Iris.Features)
526521
.AppendCacheCheckpoint(mlContext)
527522
.Append(mlContext.Transforms.Conversion.MapValueToKey("Label"))
528523
.Append(mlContext.MulticlassClassification.Trainers.OneVersusAll(regressionTrainer))
529524
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
530525

531526
// Fit the invalid pipeline.
532-
// Todo #2920: Make this fail somehow.
533-
var regressionModel = regressionPipeline.Fit(data);
534-
535-
// Transform the data
536-
var regressionPredictions = regressionModel.Transform(data);
537-
538-
// Evaluate the model.
539-
var regressionMetrics = mlContext.MulticlassClassification.Evaluate(regressionPredictions);
527+
Assert.Throws<ArgumentOutOfRangeException>(() => regressionPipeline.Fit(data));
540528
}
541529
}
542530
}

0 commit comments

Comments
 (0)