Skip to content

Commit 0ef9ce1

Browse files
committed
Merge PR# 207.
1 parent 363cad8 commit 0ef9ce1

20 files changed

+223
-118
lines changed

ZBaselines/Common/EntryPoints/core_ep-list.tsv

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ Data.DataViewReference Pass dataview from memory to experiment Microsoft.ML.Runt
33
Data.IDataViewArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIDataViewOutput
44
Data.PredictorModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayIPredictorModelOutput
55
Data.TextLoader Import a dataset from a text file Microsoft.ML.Runtime.EntryPoints.ImportTextData TextLoader Microsoft.ML.Runtime.EntryPoints.ImportTextData+LoaderInput Microsoft.ML.Runtime.EntryPoints.ImportTextData+Output
6+
Data.TransformModelArrayConverter Create and array variable Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro MakeArray Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelInput Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+ArrayITransformModelOutput
67
Models.AnomalyDetectionEvaluator Evaluates an anomaly detection scored dataset. Microsoft.ML.Runtime.Data.Evaluate AnomalyDetection Microsoft.ML.Runtime.Data.AnomalyDetectionMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+CommonEvaluateOutput
78
Models.BinaryClassificationEvaluator Evaluates a binary classification scored dataset. Microsoft.ML.Runtime.Data.Evaluate Binary Microsoft.ML.Runtime.Data.BinaryClassifierMamlEvaluator+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+ClassificationEvaluateOutput
89
Models.BinaryCrossValidator Cross validation for binary classification Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro CrossValidateBinary Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+MacroOutput`1[Microsoft.ML.Runtime.EntryPoints.CrossValidationBinaryMacro+Output]

ZBaselines/Common/EntryPoints/core_manifest.json

+84-4
Original file line numberDiff line numberDiff line change
@@ -469,6 +469,35 @@
469469
"ILearningPipelineLoader"
470470
]
471471
},
472+
{
473+
"Name": "Data.TransformModelArrayConverter",
474+
"Desc": "Create and array variable",
475+
"FriendlyName": null,
476+
"ShortName": null,
477+
"Inputs": [
478+
{
479+
"Name": "TransformModel",
480+
"Type": {
481+
"Kind": "Array",
482+
"ItemType": "TransformModel"
483+
},
484+
"Desc": "The models",
485+
"Required": true,
486+
"SortOrder": 1.0,
487+
"IsNullable": false
488+
}
489+
],
490+
"Outputs": [
491+
{
492+
"Name": "OutputModel",
493+
"Type": {
494+
"Kind": "Array",
495+
"ItemType": "TransformModel"
496+
},
497+
"Desc": "The model array"
498+
}
499+
]
500+
},
472501
{
473502
"Name": "Models.AnomalyDetectionEvaluator",
474503
"Desc": "Evaluates an anomaly detection scored dataset.",
@@ -1411,9 +1440,28 @@
14111440
"Name": "Model",
14121441
"Type": "PredictorModel",
14131442
"Desc": "The model",
1414-
"Required": true,
1443+
"Required": false,
14151444
"SortOrder": 1.0,
1416-
"IsNullable": false
1445+
"IsNullable": false,
1446+
"Default": null
1447+
},
1448+
{
1449+
"Name": "TransformModel",
1450+
"Type": "TransformModel",
1451+
"Desc": "The transform model",
1452+
"Required": false,
1453+
"SortOrder": 2.0,
1454+
"IsNullable": false,
1455+
"Default": null
1456+
},
1457+
{
1458+
"Name": "UseTransformModel",
1459+
"Type": "Bool",
1460+
"Desc": "Indicates to use transform model instead of predictor model.",
1461+
"Required": false,
1462+
"SortOrder": 3.0,
1463+
"IsNullable": false,
1464+
"Default": false
14171465
}
14181466
]
14191467
},
@@ -1476,6 +1524,14 @@
14761524
},
14771525
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
14781526
},
1527+
{
1528+
"Name": "TransformModel",
1529+
"Type": {
1530+
"Kind": "Array",
1531+
"ItemType": "TransformModel"
1532+
},
1533+
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
1534+
},
14791535
{
14801536
"Name": "Warnings",
14811537
"Type": "DataView",
@@ -3002,9 +3058,28 @@
30023058
"Name": "Model",
30033059
"Type": "PredictorModel",
30043060
"Desc": "The model",
3005-
"Required": true,
3061+
"Required": false,
30063062
"SortOrder": 1.0,
3007-
"IsNullable": false
3063+
"IsNullable": false,
3064+
"Default": null
3065+
},
3066+
{
3067+
"Name": "TransformModel",
3068+
"Type": "TransformModel",
3069+
"Desc": "Transform model",
3070+
"Required": false,
3071+
"SortOrder": 2.0,
3072+
"IsNullable": false,
3073+
"Default": null
3074+
},
3075+
{
3076+
"Name": "UseTransformModel",
3077+
"Type": "Bool",
3078+
"Desc": "Indicates to use transform model instead of predictor model.",
3079+
"Required": false,
3080+
"SortOrder": 3.0,
3081+
"IsNullable": false,
3082+
"Default": false
30083083
}
30093084
]
30103085
},
@@ -3058,6 +3133,11 @@
30583133
"Type": "PredictorModel",
30593134
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
30603135
},
3136+
{
3137+
"Name": "TransformModel",
3138+
"Type": "TransformModel",
3139+
"Desc": "The final model including the trained predictor model and the model from the transforms, provided as the Input.TransformModel."
3140+
},
30613141
{
30623142
"Name": "Warnings",
30633143
"Type": "DataView",

src/Microsoft.ML.Core/Data/ITransformModel.cs

-2
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ public interface ITransformModel
2525
/// </summary>
2626
ISchema InputSchema { get; }
2727

28-
IDataView Data { get; }
29-
3028
/// <summary>
3129
/// Apply the transform(s) in the model to the given input data.
3230
/// </summary>

src/Microsoft.ML.Data/EntryPoints/TransformModel.cs

-5
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,6 @@ public ISchema InputSchema
4444
get { return _schemaRoot; }
4545
}
4646

47-
public IDataView Data
48-
{
49-
get { return _chain; }
50-
}
51-
5247
/// <summary>
5348
/// Create a TransformModel containing the transforms from "result" back to "input".
5449
/// </summary>

src/Microsoft.ML/CSharpApi.cs

-5
Original file line numberDiff line numberDiff line change
@@ -11582,11 +11582,6 @@ public sealed class Output
1158211582
/// </summary>
1158311583
public Var<Microsoft.ML.Runtime.EntryPoints.ITransformModel> OutputModel { get; set; } = new Var<Microsoft.ML.Runtime.EntryPoints.ITransformModel>();
1158411584

11585-
/// <summary>
11586-
/// Data
11587-
/// </summary>
11588-
public Var<Microsoft.ML.Runtime.Data.IDataView> Data { get; set; } = new Var<Microsoft.ML.Runtime.Data.IDataView>();
11589-
1159011585
}
1159111586
}
1159211587
}

src/Microsoft.ML/Models/BinaryClassificationEvaluator.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.Data;
77
using Microsoft.ML.Runtime.EntryPoints;
88
using Microsoft.ML.Transforms;
9+
using System.Collections.Generic;
910

1011
namespace Microsoft.ML.Models
1112
{
@@ -23,7 +24,7 @@ public sealed partial class BinaryClassificationEvaluator
2324
/// <returns>
2425
/// A BinaryClassificationMetrics instance that describes how well the model performed against the test data.
2526
/// </returns>
26-
public BinaryClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
27+
public List<BinaryClassificationMetrics> Evaluate(PredictionModel model, ILearningPipelineLoader testData)
2728
{
2829
using (var environment = new TlcEnvironment())
2930
{

src/Microsoft.ML/Models/BinaryClassificationMetrics.cs

+32-22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Microsoft.ML.Runtime.Api;
77
using Microsoft.ML.Runtime.Data;
88
using System;
9+
using System.Collections.Generic;
910

1011
namespace Microsoft.ML.Models
1112
{
@@ -18,7 +19,7 @@ private BinaryClassificationMetrics()
1819
{
1920
}
2021

21-
internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
22+
internal static List<BinaryClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
2223
{
2324
Contracts.AssertValue(env);
2425
env.AssertValue(overallMetrics);
@@ -31,28 +32,37 @@ internal static BinaryClassificationMetrics FromMetrics(IHostEnvironment env, ID
3132
throw env.Except("The overall RegressionMetrics didn't have any rows.");
3233
}
3334

34-
SerializationClass metrics = enumerator.Current;
35-
36-
if (enumerator.MoveNext())
37-
{
38-
throw env.Except("The overall RegressionMetrics contained more than 1 row.");
39-
}
40-
41-
return new BinaryClassificationMetrics()
35+
List<BinaryClassificationMetrics> metrics = new List<BinaryClassificationMetrics>();
36+
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
37+
do
4238
{
43-
Auc = metrics.Auc,
44-
Accuracy = metrics.Accuracy,
45-
PositivePrecision = metrics.PositivePrecision,
46-
PositiveRecall = metrics.PositiveRecall,
47-
NegativePrecision = metrics.NegativePrecision,
48-
NegativeRecall = metrics.NegativeRecall,
49-
LogLoss = metrics.LogLoss,
50-
LogLossReduction = metrics.LogLossReduction,
51-
Entropy = metrics.Entropy,
52-
F1Score = metrics.F1Score,
53-
Auprc = metrics.Auprc,
54-
ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix),
55-
};
39+
SerializationClass metric = enumerator.Current;
40+
41+
if (!confusionMatrices.MoveNext())
42+
{
43+
throw env.Except("Confusion matrices didn't have enough matrices.");
44+
}
45+
46+
metrics.Add(
47+
new BinaryClassificationMetrics()
48+
{
49+
Auc = metric.Auc,
50+
Accuracy = metric.Accuracy,
51+
PositivePrecision = metric.PositivePrecision,
52+
PositiveRecall = metric.PositiveRecall,
53+
NegativePrecision = metric.NegativePrecision,
54+
NegativeRecall = metric.NegativeRecall,
55+
LogLoss = metric.LogLoss,
56+
LogLossReduction = metric.LogLossReduction,
57+
Entropy = metric.Entropy,
58+
F1Score = metric.F1Score,
59+
Auprc = metric.Auprc,
60+
ConfusionMatrix = confusionMatrices.Current,
61+
});
62+
63+
} while (enumerator.MoveNext());
64+
65+
return metrics;
5666
}
5767

5868
/// <summary>

src/Microsoft.ML/Models/ClassificationEvaluator.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Data;
77
using Microsoft.ML.Transforms;
8+
using System.Collections.Generic;
89

910
namespace Microsoft.ML.Models
1011
{
@@ -23,7 +24,7 @@ public sealed partial class ClassificationEvaluator
2324
/// <returns>
2425
/// A ClassificationMetrics instance that describes how well the model performed against the test data.
2526
/// </returns>
26-
public ClassificationMetrics Evaluate(PredictionModel model, ILearningPipelineLoader testData)
27+
public List<ClassificationMetrics> Evaluate(PredictionModel model, ILearningPipelineLoader testData)
2728
{
2829
using (var environment = new TlcEnvironment())
2930
{

src/Microsoft.ML/Models/ClassificationMetrics.cs

+27-18
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Microsoft.ML.Runtime;
66
using Microsoft.ML.Runtime.Api;
77
using Microsoft.ML.Runtime.Data;
8+
using System.Collections.Generic;
89

910
namespace Microsoft.ML.Models
1011
{
@@ -17,7 +18,7 @@ private ClassificationMetrics()
1718
{
1819
}
1920

20-
internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
21+
internal static List<ClassificationMetrics> FromMetrics(IHostEnvironment env, IDataView overallMetrics, IDataView confusionMatrix)
2122
{
2223
Contracts.AssertValue(env);
2324
env.AssertValue(overallMetrics);
@@ -29,24 +30,32 @@ internal static ClassificationMetrics FromMetrics(IHostEnvironment env, IDataVie
2930
{
3031
throw env.Except("The overall RegressionMetrics didn't have any rows.");
3132
}
32-
33-
SerializationClass metrics = enumerator.Current;
34-
35-
if (enumerator.MoveNext())
36-
{
37-
throw env.Except("The overall RegressionMetrics contained more than 1 row.");
38-
}
39-
40-
return new ClassificationMetrics()
33+
34+
List<ClassificationMetrics> metrics = new List<ClassificationMetrics>();
35+
var confusionMatrices = ConfusionMatrix.Create(env, confusionMatrix).GetEnumerator();
36+
do
4137
{
42-
AccuracyMicro = metrics.AccuracyMicro,
43-
AccuracyMacro = metrics.AccuracyMacro,
44-
LogLoss = metrics.LogLoss,
45-
LogLossReduction = metrics.LogLossReduction,
46-
TopKAccuracy = metrics.TopKAccuracy,
47-
PerClassLogLoss = metrics.PerClassLogLoss,
48-
ConfusionMatrix = ConfusionMatrix.Create(env, confusionMatrix)
49-
};
38+
if (!confusionMatrices.MoveNext())
39+
{
40+
throw env.Except("Confusion matrices didn't have enough matrices.");
41+
}
42+
43+
SerializationClass metric = enumerator.Current;
44+
metrics.Add(
45+
new ClassificationMetrics()
46+
{
47+
AccuracyMicro = metric.AccuracyMicro,
48+
AccuracyMacro = metric.AccuracyMacro,
49+
LogLoss = metric.LogLoss,
50+
LogLossReduction = metric.LogLossReduction,
51+
TopKAccuracy = metric.TopKAccuracy,
52+
PerClassLogLoss = metric.PerClassLogLoss,
53+
ConfusionMatrix = confusionMatrices.Current
54+
});
55+
56+
} while (enumerator.MoveNext());
57+
58+
return metrics;
5059
}
5160

5261
/// <summary>

0 commit comments

Comments
 (0)