|
21 | 21 | using Microsoft.ML.Trainers;
|
22 | 22 | using Microsoft.ML.Transforms;
|
23 | 23 | using Microsoft.ML.Transforms.Onnx;
|
| 24 | +using Microsoft.ML.Transforms.Text; |
24 | 25 | using Newtonsoft.Json;
|
25 | 26 | using Xunit;
|
26 | 27 | using Xunit.Abstractions;
|
27 | 28 | using static Microsoft.ML.Model.OnnxConverter.OnnxCSharpToProtoWrapper;
|
28 | 29 |
|
29 | 30 | namespace Microsoft.ML.Tests
|
30 | 31 | {
|
31 |
| - public class OnnxConversionTest : BaseTestBaseline |
| 32 | + |
| 33 | +public class OnnxConversionTest : BaseTestBaseline |
32 | 34 | {
|
| 35 | + |
| 36 | + private static IEnumerable<DataPoint2> GenerateRandomDataPoints(int count, |
| 37 | + int seed = 0) |
| 38 | + { |
| 39 | + var random = new Random(seed); |
| 40 | + for (int i = 0; i < count; i++) |
| 41 | + { |
| 42 | + float label = (float)random.NextDouble(); |
| 43 | + yield return new DataPoint2 |
| 44 | + { |
| 45 | + Label = label, |
| 46 | + // Create random features that are correlated with the label. |
| 47 | + Features = Enumerable.Repeat(label, 50).Select( |
| 48 | + x => x + (float)random.NextDouble()).ToArray() |
| 49 | + }; |
| 50 | + } |
| 51 | + } |
| 52 | + |
| 53 | + // Example with label and 50 feature values. A data set is a collection of |
| 54 | + // such examples. |
| 55 | + private class DataPoint2 |
| 56 | + { |
| 57 | + public float Label { get; set; } |
| 58 | + [VectorType(50)] |
| 59 | + public float[] Features { get; set; } |
| 60 | + } |
| 61 | + |
| 62 | + // Class used to capture predictions. |
| 63 | + private class Prediction |
| 64 | + { |
| 65 | + // Original label. |
| 66 | + public float Label { get; set; } |
| 67 | + // Predicted score from the trainer. |
| 68 | + public float Score { get; set; } |
| 69 | + } |
33 | 70 | private class AdultData
|
34 | 71 | {
|
35 | 72 | [LoadColumn(0, 10), ColumnName("FeatureVector")]
|
@@ -108,8 +145,7 @@ public void SimpleEndToEndOnnxConversionTest()
|
108 | 145 | private class BreastCancerFeatureVector
|
109 | 146 | {
|
110 | 147 | [LoadColumn(1, 9), VectorType(9)]
|
111 |
| - public float[] Features; |
112 |
| - } |
| 148 | + public float[] Features; } |
113 | 149 |
|
114 | 150 | private class BreastCancerCatFeatureExample
|
115 | 151 | {
|
@@ -187,7 +223,160 @@ public void KmeansOnnxConversionTest()
|
187 | 223 | Done();
|
188 | 224 | }
|
189 | 225 |
|
190 |
| - private class DataPoint |
| 226 | + [Fact] |
| 227 | + public void WordEmbeddingEstimatorOnnxConversionTest() //can't find the class - maybe |
| 228 | + { |
| 229 | + // Step 1: Create and train a ML.NET pipeline. |
| 230 | + var mlContext = new MLContext(seed: 1); |
| 231 | + string dataPath = GetDataPath(TestDatasets.Sentiment.trainFilename); |
| 232 | + var data = new TextLoader(ML, |
| 233 | + new TextLoader.Options() |
| 234 | + { |
| 235 | + Separator = "\t", |
| 236 | + HasHeader = true, |
| 237 | + Columns = new[] |
| 238 | + { |
| 239 | + new TextLoader.Column("Label", DataKind.Boolean, 0), |
| 240 | + new TextLoader.Column("SentimentText", DataKind.String, 1) |
| 241 | + } |
| 242 | + }).Load(GetDataPath(dataPath)); |
| 243 | + |
| 244 | + IEstimator<ITransformer>[] estimators = { }; |
| 245 | + var textPipeline = mlContext.Transforms.Text.NormalizeText("SentimentText") |
| 246 | + .Append(mlContext.Transforms.Text.TokenizeIntoWords("Tokens", |
| 247 | + "SentimentText")) |
| 248 | + .Append(mlContext.Transforms.Text.ApplyWordEmbedding("Features", |
| 249 | + "Tokens", WordEmbeddingEstimator.PretrainedModelKind |
| 250 | + .SentimentSpecificWordEmbedding)); |
| 251 | + var model = textPipeline.Fit(data); |
| 252 | + var transformedData = model.Transform(data); |
| 253 | + |
| 254 | + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data); |
| 255 | + // Compare results produced by ML.NET and ONNX's runtime. |
| 256 | + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) |
| 257 | + { |
| 258 | + var onnxFileName = "WordEmbeddingEstimator.onnx"; |
| 259 | + var onnxModelPath = GetOutputPath(onnxFileName); |
| 260 | + SaveOnnxModel(onnxModel, onnxModelPath, null); |
| 261 | + |
| 262 | + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. |
| 263 | + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 264 | + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 265 | + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); |
| 266 | + var onnxTransformer = onnxEstimator.Fit(data); |
| 267 | + var onnxResult = onnxTransformer.Transform(data); |
| 268 | + CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3); |
| 269 | + } |
| 270 | + Done(); |
| 271 | + } |
| 272 | + |
| 273 | + [Fact] |
| 274 | + // Conversion tests for regression |
| 275 | + public void regressionOnnxConversionTest() |
| 276 | + { |
| 277 | + /* |
| 278 | + var mlContext = new MLContext(seed: 1); |
| 279 | + string dataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename); |
| 280 | +
|
| 281 | + // Now read the file (remember though, readers are lazy, so the actual reading will happen when the data is accessed). |
| 282 | + var dataView = mlContext.Data.LoadFromTextFile<AdultData>(dataPath, |
| 283 | + separatorChar: ';', |
| 284 | + hasHeader: true); |
| 285 | + IEstimator<ITransformer>[] estimators = { |
| 286 | + //mlContext.Regression.Trainers.Ols(new OlsTrainer.Options() { |
| 287 | + // LabelColumnName = "Target", |
| 288 | + // FeatureColumnName = "FeatureVector", |
| 289 | + //}), |
| 290 | + //mlContext.Regression.Trainers.OnlineGradientDescent(new OnlineGradientDescentTrainer.Options(){ |
| 291 | + // LabelColumnName = "Target", |
| 292 | + // FeatureColumnName = "FeatureVector", |
| 293 | + //}), |
| 294 | + //mlContext.Transforms.DetectAnomalyBySrCnn("Target","FeatureVector"), // needs separate data |
| 295 | + mlContext.Regression.Trainers.FastForest("Target", "FeatureVector"), |
| 296 | + //mlContext.Regression.Trainers.FastTree("Target", "FeatureVector"), |
| 297 | + //mlContext.Regression.Trainers.FastTreeTweedie("Target", "FeatureVector"), |
| 298 | + //mlContext.Regression.Trainers.LightGbm("Target","FeatureVector"), |
| 299 | + //mlContext.Regression.Trainers.LbfgsPoissonRegression("Target", "FeatureVector"), |
| 300 | + }; |
| 301 | + */ |
| 302 | + // Create a new context for ML.NET operations. It can be used for |
| 303 | + // exception tracking and logging, as a catalog of available operations |
| 304 | + // and as the source of randomness. Setting the seed to a fixed number |
| 305 | + // in this example to make outputs deterministic. |
| 306 | + var mlContext = new MLContext(seed: 0); |
| 307 | + |
| 308 | + // Create a list of training data points. |
| 309 | + var dataPoints = GenerateRandomDataPoints(1000); |
| 310 | + |
| 311 | + // Convert the list of data points to an IDataView object, which is |
| 312 | + // consumable by ML.NET API. |
| 313 | + var trainingData = mlContext.Data.LoadFromEnumerable(dataPoints); |
| 314 | + |
| 315 | + // Define the trainer. |
| 316 | + var pipeline = mlContext.Regression.Trainers.FastTreeTweedie( |
| 317 | + labelColumnName: nameof(DataPoint2.Label), |
| 318 | + featureColumnName: nameof(DataPoint2.Features)); |
| 319 | + |
| 320 | + // Train the model. |
| 321 | + var model = pipeline.Fit(trainingData); |
| 322 | + |
| 323 | + // Create testing data. Use different random seed to make it different |
| 324 | + // from training data. |
| 325 | + var data = mlContext.Data.LoadFromEnumerable( |
| 326 | + GenerateRandomDataPoints(5, seed: 123)); |
| 327 | + |
| 328 | + // Run the model on test data set. |
| 329 | + var transformedTestData = model.Transform(data); |
| 330 | + // Convert IDataView object to a list. |
| 331 | + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data); |
| 332 | + // Convert IDataView object to a list. |
| 333 | + var predictions = mlContext.Data.CreateEnumerable<Prediction>( |
| 334 | + transformedTestData, reuseRowObject: false).ToList(); |
| 335 | + foreach (var p in predictions) |
| 336 | + System.Diagnostics.Debug.WriteLine($"Label: {p.Label:F3}, Prediction: {p.Score:F3}"); |
| 337 | + // Compare results produced by ML.NET and ONNX's runtime. |
| 338 | + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) |
| 339 | + { |
| 340 | + var onnxFileName = "test.onnx"; |
| 341 | + var onnxModelPath = GetOutputPath(onnxFileName); |
| 342 | + SaveOnnxModel(onnxModel, onnxModelPath, null); |
| 343 | + |
| 344 | + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. |
| 345 | + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 346 | + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 347 | + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); |
| 348 | + var onnxTransformer = onnxEstimator.Fit(data); |
| 349 | + var onnxResult = onnxTransformer.Transform(data); |
| 350 | + CompareSelectedR4ScalarColumns("Label", "Score0", data, onnxResult, 3); |
| 351 | + } |
| 352 | + Done(); |
| 353 | + /*var initialPipeline = mlContext.Transforms.NormalizeMinMax("FeatureVector"); |
| 354 | + foreach (var estimator in estimators) |
| 355 | + { |
| 356 | + //var pipeline = initialPipeline.Append(estimator); |
| 357 | + var pipeline = estimator; |
| 358 | +
|
| 359 | + var model = pipeline.Fit(dataView); |
| 360 | + var transformedData = model.Transform(dataView); |
| 361 | + var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, dataView); |
| 362 | + var onnxFileName = $"{estimator.ToString()}.onnx"; |
| 363 | + var onnxModelPath = GetOutputPath(onnxFileName); |
| 364 | + SaveOnnxModel(onnxModel, onnxModelPath, null); |
| 365 | + // Compare model scores produced by ML.NET and ONNX's runtime. |
| 366 | + if (IsOnnxRuntimeSupported()) |
| 367 | + { |
| 368 | + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. |
| 369 | + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 370 | + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 371 | + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); |
| 372 | + var onnxTransformer = onnxEstimator.Fit(dataView); |
| 373 | + var onnxResult = onnxTransformer.Transform(dataView); //switched to 2 vause |
| 374 | + CompareSelectedR4ScalarColumns(transformedData.Schema[2].Name, outputNames[2], transformedData, onnxResult, 0); // compare score results |
| 375 | + } |
| 376 | + } */ |
| 377 | + //Done(); |
| 378 | + } |
| 379 | + private class DataPoint |
191 | 380 | {
|
192 | 381 | [VectorType(3)]
|
193 | 382 | public float[] Features { get; set; }
|
@@ -380,8 +569,7 @@ public void LogisticRegressionOnnxConversionTest()
|
380 | 569 | var trainDataPath = GetDataPath(TestDatasets.generatedRegressionDataset.trainFilename);
|
381 | 570 | var mlContext = new MLContext(seed: 1);
|
382 | 571 | var data = mlContext.Data.LoadFromTextFile<AdultData>(trainDataPath,
|
383 |
| - separatorChar: ';' |
384 |
| -, |
| 572 | + separatorChar: ';', |
385 | 573 | hasHeader: true);
|
386 | 574 | var cachedTrainData = mlContext.Data.Cache(data);
|
387 | 575 | var dynamicPipeline =
|
@@ -658,15 +846,21 @@ public void WordEmbeddingsTest()
|
658 | 846 | var model = pipeline.Fit(data);
|
659 | 847 | var transformedData = model.Transform(data);
|
660 | 848 |
|
661 |
| - var subDir = Path.Combine("..", "..", "BaselineOutput", "Common", "Onnx", "Transforms", "Sentiment"); |
662 |
| - var onnxTextName = "SmallWordEmbed.txt"; |
663 |
| - var onnxFileName = "SmallWordEmbed.onnx"; |
664 |
| - var onnxTextPath = GetOutputPath(subDir, onnxTextName); |
665 |
| - var onnxFilePath = GetOutputPath(subDir, onnxFileName); |
666 | 849 | var onnxModel = mlContext.Model.ConvertToOnnxProtobuf(model, data);
|
667 |
| - SaveOnnxModel(onnxModel, onnxFilePath, onnxTextPath); |
| 850 | + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && Environment.Is64BitProcess) |
| 851 | + { |
| 852 | + var onnxFileName = "WordEmbeddingEstimator.onnx"; |
| 853 | + var onnxModelPath = GetOutputPath(onnxFileName); |
| 854 | + SaveOnnxModel(onnxModel, onnxModelPath, null); |
668 | 855 |
|
669 |
| - CheckEquality(subDir, onnxTextName, parseOption: NumberParseOption.UseSingle); |
| 856 | + // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. |
| 857 | + string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 858 | + string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); |
| 859 | + var onnxEstimator = mlContext.Transforms.ApplyOnnxModel(outputNames, inputNames, onnxModelPath); |
| 860 | + var onnxTransformer = onnxEstimator.Fit(data); |
| 861 | + var onnxResult = onnxTransformer.Transform(data); |
| 862 | + CompareSelectedR4VectorColumns("Embed", "Embed0", transformedData, onnxResult); |
| 863 | + } |
670 | 864 | Done();
|
671 | 865 | }
|
672 | 866 |
|
@@ -984,11 +1178,44 @@ private void CompareSelectedR4ScalarColumns(string leftColumnName, string rightC
|
984 | 1178 |
|
985 | 1179 | // Scalar such as R4 (float) is converted to [1, 1]-tensor in ONNX format for consitency of making batch prediction.
|
986 | 1180 | Assert.Equal(1, actual.Length);
|
987 |
| - Assert.Equal(expected, actual.GetItemOrDefault(0), precision); |
| 1181 | + //Assert.Equal(expected, actual.GetItemOrDefault(0), precision); |
| 1182 | + //Output.WriteLine(actual.GetItemOrDefault(0)); |
| 1183 | + System.Diagnostics.Debug.WriteLine("Actual: " + actual.GetItemOrDefault(0)); |
| 1184 | + System.Diagnostics.Debug.WriteLine("Expected: " + expected); |
988 | 1185 | }
|
989 | 1186 | }
|
990 | 1187 | }
|
991 | 1188 |
|
| 1189 | + private void CompareSelectedScalarColumns<T>(string leftColumnName, string rightColumnName, IDataView left, IDataView right) |
| 1190 | + { |
| 1191 | + var leftColumn = left.Schema[leftColumnName]; |
| 1192 | + var rightColumn = right.Schema[rightColumnName]; |
| 1193 | + |
| 1194 | + using (var expectedCursor = left.GetRowCursor(leftColumn)) |
| 1195 | + using (var actualCursor = right.GetRowCursor(rightColumn)) |
| 1196 | + { |
| 1197 | + T expected = default; |
| 1198 | + VBuffer<T> actual = default; |
| 1199 | + var expectedGetter = expectedCursor.GetGetter<T>(leftColumn); |
| 1200 | + var actualGetter = actualCursor.GetGetter<VBuffer<T>>(rightColumn); |
| 1201 | + while (expectedCursor.MoveNext() && actualCursor.MoveNext()) |
| 1202 | + { |
| 1203 | + expectedGetter(ref expected); |
| 1204 | + actualGetter(ref actual); |
| 1205 | + var actualVal = actual.GetItemOrDefault(0); |
| 1206 | + |
| 1207 | + Assert.Equal(1, actual.Length); |
| 1208 | + |
| 1209 | + if (typeof(T) == typeof(ReadOnlyMemory<Char>)) |
| 1210 | + Assert.Equal(expected.ToString(), actualVal.ToString()); |
| 1211 | + else |
| 1212 | + Assert.Equal(expected, actualVal); |
| 1213 | + } |
| 1214 | + } |
| 1215 | + } |
| 1216 | + |
| 1217 | + |
| 1218 | + |
992 | 1219 | private void SaveOnnxModel(ModelProto model, string binaryFormatPath, string textFormatPath)
|
993 | 1220 | {
|
994 | 1221 | DeleteOutputPath(binaryFormatPath); // Clean if such a file exists.
|
|
0 commit comments