Skip to content

Commit 6fc0fd7

Browse files
yaeldMS authored and eerhardt committed
Fix CV macro to output the warnings data view properly. (dotnet#385)
1 parent a651016 commit 6fc0fd7

File tree

4 files changed

+90
-4
lines changed

4 files changed

+90
-4
lines changed

src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs

+3 -2
Original file line number | Diff line number | Diff line change
@@ -256,6 +256,8 @@ public Double MacroAvgAccuracy
256256
{
257257
get
258258
{
259+
if (_numInstances == 0)
260+
return 0;
259261
Double macroAvgAccuracy = 0;
260262
int countOfNonEmptyClasses = 0;
261263
for (int i = 0; i < _numClasses; ++i)
@@ -267,8 +269,7 @@ public Double MacroAvgAccuracy
267269
}
268270
}
269271

270-
Contracts.Assert(countOfNonEmptyClasses > 0);
271-
return macroAvgAccuracy / countOfNonEmptyClasses;
272+
return countOfNonEmptyClasses > 0 ? macroAvgAccuracy / countOfNonEmptyClasses : 0;
272273
}
273274
}
274275

src/Microsoft.ML.StandardLearners/Standard/OlsLinearRegression.cs

-2
Original file line number | Diff line number | Diff line change
@@ -28,8 +28,6 @@
2828
"OLS Linear Regression Executor",
2929
OlsLinearRegressionPredictor.LoaderSignature)]
3030

31-
[assembly: LoadableClass(typeof(void), typeof(OlsLinearRegressionTrainer), null, typeof(SignatureEntryPointModule), OlsLinearRegressionTrainer.LoadNameValue)]
32-
3331
namespace Microsoft.ML.Runtime.Learners
3432
{
3533
public sealed class OlsLinearRegressionTrainer : TrainerBase<RoleMappedData, OlsLinearRegressionPredictor>

src/Microsoft.ML/Runtime/EntryPoints/CrossValidationMacro.cs

+4
Original file line number | Diff line number | Diff line change
@@ -381,6 +381,10 @@ public static CommonOutputs.MacroOutput<Output> CrossValidate(
381381
// Set the input bindings for the CombineMetrics entry point.
382382
var combineInputBindingMap = new Dictionary<string, List<ParameterBinding>>();
383383
var combineInputMap = new Dictionary<ParameterBinding, VariableBinding>();
384+
385+
var warningsArray = new SimpleParameterBinding(nameof(combineArgs.Warnings));
386+
combineInputBindingMap.Add(nameof(combineArgs.Warnings), new List<ParameterBinding> { warningsArray });
387+
combineInputMap.Add(warningsArray, new SimpleVariableBinding(warningsOutput.OutputData.VarName));
384388
var overallArray = new SimpleParameterBinding(nameof(combineArgs.OverallMetrics));
385389
combineInputBindingMap.Add(nameof(combineArgs.OverallMetrics), new List<ParameterBinding> { overallArray });
386390
combineInputMap.Add(overallArray, new SimpleVariableBinding(overallMetricsOutput.OutputData.VarName));

test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs

+83
Original file line number | Diff line number | Diff line change
@@ -528,6 +528,89 @@ public void TestCrossValidationMacroWithMultiClass()
528528
}
529529
Assert.Equal(0, rowCount);
530530
}
531+
532+
var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
533+
using (var cursor = warnings.GetRowCursor(col => true))
534+
Assert.False(cursor.MoveNext());
535+
}
536+
}
537+
538+
[Fact]
539+
public void TestCrossValidationMacroMultiClassWithWarnings()
540+
{
541+
var dataPath = GetDataPath(@"Train-Tiny-28x28.txt");
542+
using (var env = new TlcEnvironment(42))
543+
{
544+
var subGraph = env.CreateExperiment();
545+
546+
var nop = new ML.Transforms.NoOperation();
547+
var nopOutput = subGraph.Add(nop);
548+
549+
var learnerInput = new ML.Trainers.LogisticRegressionClassifier
550+
{
551+
TrainingData = nopOutput.OutputData,
552+
NumThreads = 1
553+
};
554+
var learnerOutput = subGraph.Add(learnerInput);
555+
556+
var experiment = env.CreateExperiment();
557+
var importInput = new ML.Data.TextLoader(dataPath);
558+
var importOutput = experiment.Add(importInput);
559+
560+
var filter = new ML.Transforms.RowRangeFilter();
561+
filter.Data = importOutput.Data;
562+
filter.Column = "Label";
563+
filter.Min = 0;
564+
filter.Max = 5;
565+
var filterOutput = experiment.Add(filter);
566+
567+
var term = new ML.Transforms.TextToKeyConverter();
568+
term.Column = new[]
569+
{
570+
new ML.Transforms.TermTransformColumn()
571+
{
572+
Source = "Label", Name = "Strat", Sort = ML.Transforms.TermTransformSortOrder.Value
573+
}
574+
};
575+
term.Data = filterOutput.OutputData;
576+
var termOutput = experiment.Add(term);
577+
578+
var crossValidate = new ML.Models.CrossValidator
579+
{
580+
Data = termOutput.OutputData,
581+
Nodes = subGraph,
582+
Kind = ML.Models.MacroUtilsTrainerKinds.SignatureMultiClassClassifierTrainer,
583+
TransformModel = null,
584+
StratificationColumn = "Strat"
585+
};
586+
crossValidate.Inputs.Data = nop.Data;
587+
crossValidate.Outputs.PredictorModel = learnerOutput.PredictorModel;
588+
var crossValidateOutput = experiment.Add(crossValidate);
589+
590+
experiment.Compile();
591+
importInput.SetInput(env, experiment);
592+
experiment.Run();
593+
var warnings = experiment.GetOutput(crossValidateOutput.Warnings);
594+
595+
var schema = warnings.Schema;
596+
var b = schema.TryGetColumnIndex("WarningText", out int warningCol);
597+
Assert.True(b);
598+
using (var cursor = warnings.GetRowCursor(col => col == warningCol))
599+
{
600+
var getter = cursor.GetGetter<DvText>(warningCol);
601+
602+
b = cursor.MoveNext();
603+
Assert.True(b);
604+
var warning = default(DvText);
605+
getter(ref warning);
606+
Assert.Contains("test instances with class values not seen in the training set.", warning.ToString());
607+
b = cursor.MoveNext();
608+
Assert.True(b);
609+
getter(ref warning);
610+
Assert.Contains("Detected columns of variable length: SortedScores, SortedClasses", warning.ToString());
611+
b = cursor.MoveNext();
612+
Assert.False(b);
613+
}
531614
}
532615
}
533616

0 commit comments

Comments
 (0)