Skip to content

Commit dea940e

Browse files
committed
Merge branch 'master' of https://github.com/dotnet/machinelearning into cleanup
2 parents 5601f79 + 0233d71 commit dea940e

File tree

381 files changed

+293820
-615
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

381 files changed

+293820
-615
lines changed

Samples/UCI/readme.md

-3
This file was deleted.

src/Microsoft.ML.PipelineInference/PipelinePattern.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD
152152
},
153153
Outputs =
154154
{
155-
Model = finalOutput
155+
PredictorModel = finalOutput
156156
},
157157
PipelineId = UniqueId.ToString("N"),
158158
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind),
@@ -189,7 +189,7 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData,
189189
},
190190
Outputs =
191191
{
192-
Model = finalOutput
192+
PredictorModel = finalOutput
193193
},
194194
TrainingData = trainData,
195195
TestingData = testData,

src/Microsoft.ML/CSharpApi.cs

+787-326
Large diffs are not rendered by default.

src/Microsoft.ML/Data/CollectionDataSource.cs

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ public void SetInput(IHostEnvironment environment, Experiment experiment)
5252
experiment.SetInput(_dataViewEntryPoint.Data, _dataView);
5353
}
5454

55+
public Var<IDataView> GetInputData() => null;
56+
5557
public abstract IDataView GetDataView(IHostEnvironment environment);
5658
}
5759

src/Microsoft.ML/ILearningPipelineItem.cs

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ namespace Microsoft.ML
1414
public interface ILearningPipelineItem
1515
{
1616
ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment);
17+
18+
/// <summary>
19+
/// Returns the placeholder for the input IDataView object of this node in the execution graph.
20+
/// </summary>
21+
/// <returns></returns>
22+
Var<IDataView> GetInputData();
1723
}
1824

1925
/// <summary>

src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs

+5-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Data;
7-
using Microsoft.ML.Runtime.EntryPoints;
87
using Microsoft.ML.Transforms;
98

109
namespace Microsoft.ML.Models
@@ -66,7 +65,11 @@ public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipe
6665
throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
6766
}
6867

69-
return BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
68+
var metric = BinaryClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
69+
70+
Contracts.Check(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");
71+
72+
return metric[0];
7073
}
7174
}
7275
}

src/Microsoft.ML/Models/BinaryClassificationMetrics.cs

+32-22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.Api;
77
using Microsoft.ML.Runtime.Data;
88
using System;
9+
using System.Collections.Generic;
910

1011
namespace Microsoft.ML.Models
1112
{
@@ -18,41 +19,50 @@ private BinaryClassificationMetrics()
1819
{
1920
}
2021

21-
internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
22+
internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix, int confusionMatriceStartIndex = 0)
2223
{
2324
Contracts.AssertValue(env);
2425
env.AssertValue(overallMetrics);
2526
env.AssertValue(confusionMatrix);
2627

2728
var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
28-
var enumerator = metricsEnumerable.GetEnumerator();
29-
if (!enumerator.MoveNext())
29+
if (!metricsEnumerable.GetEnumerator().MoveNext())
3030
{
3131
throw env.Except("The overall RegressionMetrics didn't have any rows.");
3232
}
3333

34-
SerializationClass metrics = enumerator.Current;
34+
List<BinaryClassificationMetrics> metrics = new List<BinaryClassificationMetrics>();
35+
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
3536

36-
if (enumerator.MoveNext())
37+
int Index = 0;
38+
foreach(var metric in metricsEnumerable)
3739
{
38-
throw env.Except("The overall RegressionMetrics contained more than 1 row.");
40+
41+
if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
42+
{
43+
throw env.Except("Confusion matrices didn't have enough matrices.");
44+
}
45+
46+
metrics.Add(
47+
new BinaryClassificationMetrics()
48+
{
49+
Auc = metric.Auc,
50+
Accuracy = metric.Accuracy,
51+
PositivePrecision = metric.PositivePrecision,
52+
PositiveRecall = metric.PositiveRecall,
53+
NegativePrecision = metric.NegativePrecision,
54+
NegativeRecall = metric.NegativeRecall,
55+
LogLoss = metric.LogLoss,
56+
LogLossReduction = metric.LogLossReduction,
57+
Entropy = metric.Entropy,
58+
F1Score = metric.F1Score,
59+
Auprc = metric.Auprc,
60+
ConfusionMatrix = confusionMatrices.Current,
61+
});
62+
3963
}
4064

41-
return new BinaryClassificationMetrics()
42-
{
43-
Auc = metrics.Auc,
44-
Accuracy = metrics.Accuracy,
45-
PositivePrecision = metrics.PositivePrecision,
46-
PositiveRecall = metrics.PositiveRecall,
47-
NegativePrecision = metrics.NegativePrecision,
48-
NegativeRecall = metrics.NegativeRecall,
49-
LogLoss = metrics.LogLoss,
50-
LogLossReduction = metrics.LogLossReduction,
51-
Entropy = metrics.Entropy,
52-
F1Score = metrics.F1Score,
53-
Auprc = metrics.Auprc,
54-
ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
55-
};
65+
return metrics;
5666
}
5767

5868
/// <summary>
@@ -155,7 +165,7 @@ internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, ID
155165
/// <summary>
156166
/// This class contains the public fields necessary to deserialize from IDataView.
157167
/// </summary>
158-
private class SerializationClass
168+
private sealed class SerializationClass
159169
{
160170
#pragma warning disable 649 // never assigned
161171
[ColumnName(BinaryClassifierEvaluator.Auc)]

src/Microsoft.ML/Models/ClassificationEvaluator.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,11 @@ public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLo
6666
throw environment.Except($"Could not find ConfusionMatrix in the results returned in {nameof(BinaryClassificationEvaluator)} Evaluate.");
6767
}
6868

69-
return ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
69+
var metric = ClassificationMetrics.FromMetrics(environment, overallMetrics, confusionMatrix);
70+
71+
Contracts.Check(metric.Count == 1, $"Exactly one metric set was expected but found {metric.Count} metrics");
72+
73+
return metric[0];
7074
}
7175
}
7276
}

src/Microsoft.ML/Models/ClassificationMetrics.cs

+27-17
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Api;
77
using Microsoft.ML.Runtime.Data;
8+
using System.Collections.Generic;
89

910
namespace Microsoft.ML.Models
1011
{
@@ -17,36 +18,45 @@ private ClassificationMetrics()
1718
{
1819
}
1920

20-
internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
21+
internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix,
22+
int confusionMatriceStartIndex = 0)
2123
{
2224
Contracts.AssertValue(env);
2325
env.AssertValue(overallMetrics);
2426
env.AssertValue(confusionMatrix);
2527

2628
var metricsEnumerable = overallMetrics.AsEnumerable<SerializationClass>(env, true, ignoreMissingColumns: true);
27-
var enumerator = metricsEnumerable.GetEnumerator();
28-
if (!enumerator.MoveNext())
29+
if (!metricsEnumerable.GetEnumerator().MoveNext())
2930
{
3031
throw env.Except("The overall RegressionMetrics didn't have any rows.");
3132
}
3233

33-
SerializationClass metrics = enumerator.Current;
34+
List<ClassificationMetrics> metrics = new List<ClassificationMetrics>();
35+
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
3436

35-
if (enumerator.MoveNext())
37+
int Index = 0;
38+
foreach (var metric in metricsEnumerable)
3639
{
37-
throw env.Except("The overall RegressionMetrics contained more than 1 row.");
40+
if (Index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
41+
{
42+
throw env.Except("Confusion matrices didn't have enough matrices.");
43+
}
44+
45+
metrics.Add(
46+
new ClassificationMetrics()
47+
{
48+
AccuracyMicro = metric.AccuracyMicro,
49+
AccuracyMacro = metric.AccuracyMacro,
50+
LogLoss = metric.LogLoss,
51+
LogLossReduction = metric.LogLossReduction,
52+
TopKAccuracy = metric.TopKAccuracy,
53+
PerClassLogLoss = metric.PerClassLogLoss,
54+
ConfusionMatrix = confusionMatrices.Current
55+
});
56+
3857
}
3958

40-
return new ClassificationMetrics()
41-
{
42-
AccuracyMicro = metrics.AccuracyMicro,
43-
AccuracyMacro = metrics.AccuracyMacro,
44-
LogLoss = metrics.LogLoss,
45-
LogLossReduction = metrics.LogLossReduction,
46-
TopKAccuracy = metrics.TopKAccuracy,
47-
PerClassLogLoss = metrics.PerClassLogLoss,
48-
ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix)
49-
};
59+
return metrics;
5060
}
5161

5262
/// <summary>
@@ -125,7 +135,7 @@ internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataVie
125135
/// <summary>
126136
/// This class contains the public fields necessary to deserialize from IDataView.
127137
/// </summary>
128-
private class SerializationClass
138+
private sealed class SerializationClass
129139
{
130140
#pragma warning disable 649 // never assigned
131141
[ColumnName(MultiClassClassifierEvaluator.AccuracyMicro)]

src/Microsoft.ML/Models/ConfusionMatrix.cs

+22-14
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ private ConfusionMatrix(double[,] elements, string[] classNames)
4141
});
4242
}
4343

44-
internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusionMatrix)
44+
internal static List<ConfusionMatrix> Create(IHostEnvironment env, IDataView confusionMatrix)
4545
{
4646
Contracts.AssertValue(env);
4747
env.AssertValue(confusionMatrix);
@@ -51,36 +51,44 @@ internal static ConfusionMatrix Create(IHostEnvironment env, IDataView confusion
5151
env.Except($"ConfusionMatrix data view did not contain a {nameof(MetricKinds.ColumnNames.Count)} column.");
5252
}
5353

54+
IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn);
55+
var slots = default(VBuffer<DvText>);
56+
confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots);
57+
string[] classNames = new string[slots.Count];
58+
for (int i = 0; i < slots.Count; i++)
59+
{
60+
classNames[i] = slots.Values[i].ToString();
61+
}
62+
5463
ColumnType type = confusionMatrix.Schema.GetColumnType(countColumn);
5564
env.Assert(type.IsVector);
56-
57-
double[,] elements = new double[type.VectorSize, type.VectorSize];
58-
59-
IRowCursor cursor = confusionMatrix.GetRowCursor(col => col == countColumn);
6065
ValueGetter<VBuffer<double>> countGetter = cursor.GetGetter<VBuffer<double>>(countColumn);
6166
VBuffer<double> countValues = default;
62-
67+
List<ConfusionMatrix> confusionMatrices = new List<ConfusionMatrix>();
68+
6369
int valuesRowIndex = 0;
70+
double[,] elements = null;
6471
while (cursor.MoveNext())
6572
{
73+
if(valuesRowIndex == 0)
74+
elements = new double[type.VectorSize, type.VectorSize];
75+
6676
countGetter(ref countValues);
6777
for (int i = 0; i < countValues.Length; i++)
6878
{
6979
elements[valuesRowIndex, i] = countValues.Values[i];
7080
}
7181

7282
valuesRowIndex++;
73-
}
7483

75-
var slots = default(VBuffer<DvText>);
76-
confusionMatrix.Schema.GetMetadata(MetadataUtils.Kinds.SlotNames, countColumn, ref slots);
77-
string[] classNames = new string[slots.Count];
78-
for (int i = 0; i < slots.Count; i++)
79-
{
80-
classNames[i] = slots.Values[i].ToString();
84+
if(valuesRowIndex == type.VectorSize)
85+
{
86+
valuesRowIndex = 0;
87+
confusionMatrices.Add(new ConfusionMatrix(elements, classNames));
88+
}
8189
}
8290

83-
return new ConfusionMatrix(elements, classNames);
91+
return confusionMatrices;
8492
}
8593

8694
/// <summary>

0 commit comments

Comments
 (0)