Skip to content

Commit 0f94a3b

Browse files
authored
Pass fold index to cross validation metrics. (#575)
RowTag in metrics
1 parent 0e0f702 commit 0f94a3b

File tree

4 files changed

+45
-2
lines changed

4 files changed

+45
-2
lines changed

src/Microsoft.ML/Models/BinaryClassificationMetrics.cs

+12-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML.Runtime.Data;
88
using System;
99
using System.Collections.Generic;
10+
using static Microsoft.ML.Runtime.Data.MetricKinds;
1011

1112
namespace Microsoft.ML.Models
1213
{
@@ -35,7 +36,7 @@ internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment e
3536
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
3637

3738
int index = 0;
38-
foreach(var metric in metricsEnumerable)
39+
foreach (var metric in metricsEnumerable)
3940
{
4041

4142
if (index++ >= confusionMatriceStartIndex && !confusionMatrices.MoveNext())
@@ -57,6 +58,7 @@ internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment e
5758
Entropy = metric.Entropy,
5859
F1Score = metric.F1Score,
5960
Auprc = metric.Auprc,
61+
RowTag = metric.RowTag,
6062
ConfusionMatrix = confusionMatrices.Current,
6163
});
6264

@@ -162,6 +164,12 @@ internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment e
162164
/// </summary>
163165
public ConfusionMatrix ConfusionMatrix { get; private set; }
164166

167+
/// <summary>
168+
/// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation.
169+
/// For non-CV scenarios, this is equal to null.
170+
/// </summary>
171+
public string RowTag { get; private set; }
172+
165173
/// <summary>
166174
/// This class contains the public fields necessary to deserialize from IDataView.
167175
/// </summary>
@@ -200,6 +208,9 @@ private sealed class SerializationClass
200208

201209
[ColumnName(BinaryClassifierEvaluator.AuPrc)]
202210
public Double Auprc;
211+
212+
[ColumnName(ColumnNames.FoldIndex)]
213+
public string RowTag;
203214
#pragma warning restore 649 // never assigned
204215
}
205216
}

src/Microsoft.ML/Models/ClassificationMetrics.cs

+12-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.Api;
77
using Microsoft.ML.Runtime.Data;
88
using System.Collections.Generic;
9+
using static Microsoft.ML.Runtime.Data.MetricKinds;
910

1011
namespace Microsoft.ML.Models
1112
{
@@ -51,7 +52,8 @@ internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, ID
5152
LogLossReduction = metric.LogLossReduction,
5253
TopKAccuracy = metric.TopKAccuracy,
5354
PerClassLogLoss = metric.PerClassLogLoss,
54-
ConfusionMatrix = confusionMatrices.Current
55+
ConfusionMatrix = confusionMatrices.Current,
56+
RowTag = metric.RowTag,
5557
});
5658

5759
}
@@ -127,6 +129,12 @@ internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, ID
127129
/// </remarks>
128130
public double[] PerClassLogLoss { get; private set; }
129131

132+
/// <summary>
133+
/// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation.
134+
/// For non-CV scenarios, this is equal to null.
135+
/// </summary>
136+
public string RowTag { get; private set; }
137+
130138
/// <summary>
131139
/// Gets the confusion matrix, or error matrix, of the classifier.
132140
/// </summary>
@@ -155,6 +163,9 @@ private sealed class SerializationClass
155163

156164
[ColumnName(MultiClassClassifierEvaluator.PerClassLogLoss)]
157165
public double[] PerClassLogLoss;
166+
167+
[ColumnName(ColumnNames.FoldIndex)]
168+
public string RowTag;
158169
#pragma warning restore 649 // never assigned
159170
}
160171
}

src/Microsoft.ML/Models/ClusterMetrics.cs

+10
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML.Runtime.Data;
88
using System;
99
using System.Collections.Generic;
10+
using static Microsoft.ML.Runtime.Data.MetricKinds;
1011

1112
namespace Microsoft.ML.Models
1213
{
@@ -38,6 +39,7 @@ internal static List<ClusterMetrics> FromOverallMetrics(IHostEnvironment env, ID
3839
AvgMinScore = metric.AvgMinScore,
3940
Nmi = metric.Nmi,
4041
Dbi = metric.Dbi,
42+
RowTag = metric.RowTag,
4143
});
4244
}
4345

@@ -73,6 +75,12 @@ internal static List<ClusterMetrics> FromOverallMetrics(IHostEnvironment env, ID
7375
/// </remarks>
7476
public double AvgMinScore { get; private set; }
7577

78+
/// <summary>
79+
/// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation.
80+
/// For non-CV scenarios, this is equal to null.
81+
/// </summary>
82+
public string RowTag { get; private set; }
83+
7684
/// <summary>
7785
/// This class contains the public fields necessary to deserialize from IDataView.
7886
/// </summary>
@@ -88,6 +96,8 @@ private sealed class SerializationClass
8896
[ColumnName(Runtime.Data.ClusteringEvaluator.AvgMinScore)]
8997
public Double AvgMinScore;
9098

99+
[ColumnName(ColumnNames.FoldIndex)]
100+
public string RowTag;
91101
#pragma warning restore 649 // never assigned
92102
}
93103
}

src/Microsoft.ML/Models/RegressionMetrics.cs

+11
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML.Runtime.Data;
88
using System;
99
using System.Collections.Generic;
10+
using static Microsoft.ML.Runtime.Data.MetricKinds;
1011

1112
namespace Microsoft.ML.Models
1213
{
@@ -40,6 +41,7 @@ internal static List<RegressionMetrics> FromOverallMetrics(IHostEnvironment env,
4041
Rms = metric.Rms,
4142
LossFn = metric.LossFn,
4243
RSquared = metric.RSquared,
44+
RowTag = metric.RowTag,
4345
});
4446
}
4547

@@ -90,6 +92,12 @@ internal static List<RegressionMetrics> FromOverallMetrics(IHostEnvironment env,
9092
/// </summary>
9193
public double RSquared { get; private set; }
9294

95+
/// <summary>
96+
/// For cross-validation, this is equal to "Fold N" for per-fold metric rows, "Overall" for the average metrics and "STD" for standard deviation.
97+
/// For non-CV scenarios, this is equal to null.
98+
/// </summary>
99+
public string RowTag { get; private set; }
100+
93101
/// <summary>
94102
/// This class contains the public fields necessary to deserialize from IDataView.
95103
/// </summary>
@@ -110,6 +118,9 @@ private sealed class SerializationClass
110118

111119
[ColumnName(Runtime.Data.RegressionEvaluator.RSquared)]
112120
public Double RSquared;
121+
122+
[ColumnName(ColumnNames.FoldIndex)]
123+
public string RowTag;
113124
#pragma warning restore 649 // never assigned
114125
}
115126
}

0 commit comments

Comments
 (0)