|
| 1 | +using System; |
| 2 | +using System.IO; |
| 3 | +using System.Linq; |
| 4 | +using Microsoft.ML.Auto; |
| 5 | +using Microsoft.ML.Data; |
| 6 | + |
| 7 | +namespace Microsoft.ML.AutoML.Samples |
| 8 | +{ |
| 9 | + public static class MulticlassClassificationExperiment |
| 10 | + { |
| 11 | + private static string TrainDataPath = "<Path to your train dataset goes here>"; |
| 12 | + private static string TestDataPath = "<Path to your test dataset goes here>"; |
| 13 | + private static string ModelPath = @"<Desired model output directory goes here>\OptDigitsModel.zip"; |
| 14 | + private static string LabelColumnName = "Number"; |
| 15 | + private static uint ExperimentTime = 60; |
| 16 | + |
| 17 | + public static void Run() |
| 18 | + { |
| 19 | + MLContext mlContext = new MLContext(); |
| 20 | + |
| 21 | + // STEP 1: Load data |
| 22 | + IDataView trainDataView = mlContext.Data.LoadFromTextFile<PixelData>(TrainDataPath, separatorChar: ','); |
| 23 | + IDataView testDataView = mlContext.Data.LoadFromTextFile<PixelData>(TestDataPath, separatorChar: ','); |
| 24 | + |
| 25 | + // STEP 2: Run AutoML experiment |
| 26 | + Console.WriteLine($"Running AutoML multiclass classification experiment for {ExperimentTime} seconds..."); |
| 27 | + ExperimentResult<MulticlassClassificationMetrics> experimentResult = mlContext.Auto() |
| 28 | + .CreateMulticlassClassificationExperiment(ExperimentTime) |
| 29 | + .Execute(trainDataView, LabelColumnName); |
| 30 | + |
| 31 | + // STEP 3: Print metric from the best model |
| 32 | + RunDetail<MulticlassClassificationMetrics> bestRun = experimentResult.BestRun; |
| 33 | + Console.WriteLine($"Total models produced: {experimentResult.RunDetails.Count()}"); |
| 34 | + Console.WriteLine($"Best model's trainer: {bestRun.TrainerName}"); |
| 35 | + Console.WriteLine($"Metrics of best model from validation data --"); |
| 36 | + PrintMetrics(bestRun.ValidationMetrics); |
| 37 | + |
| 38 | + // STEP 4: Evaluate test data |
| 39 | + IDataView testDataViewWithBestScore = bestRun.Model.Transform(testDataView); |
| 40 | + MulticlassClassificationMetrics testMetrics = mlContext.MulticlassClassification.Evaluate(testDataViewWithBestScore, labelColumnName: LabelColumnName); |
| 41 | + Console.WriteLine($"Metrics of best model on test data --"); |
| 42 | + PrintMetrics(testMetrics); |
| 43 | + |
| 44 | + // STEP 5: Save the best model for later deployment and inferencing |
| 45 | + using (FileStream fs = File.Create(ModelPath)) |
| 46 | + mlContext.Model.Save(bestRun.Model, trainDataView.Schema, fs); |
| 47 | + |
| 48 | + // STEP 6: Create prediction engine from the best trained model |
| 49 | + var predictionEngine = mlContext.Model.CreatePredictionEngine<PixelData, PixelPrediction>(bestRun.Model); |
| 50 | + |
| 51 | + // STEP 7: Initialize new pixel data, and get the predicted number |
| 52 | + var testPixelData = new PixelData |
| 53 | + { |
| 54 | + PixelValues = new float[] { 0, 0, 1, 8, 15, 10, 0, 0, 0, 3, 13, 15, 14, 14, 0, 0, 0, 5, 10, 0, 10, 12, 0, 0, 0, 0, 3, 5, 15, 10, 2, 0, 0, 0, 16, 16, 16, 16, 12, 0, 0, 1, 8, 12, 14, 8, 3, 0, 0, 0, 0, 10, 13, 0, 0, 0, 0, 0, 0, 11, 9, 0, 0, 0 } |
| 55 | + }; |
| 56 | + var prediction = predictionEngine.Predict(testPixelData); |
| 57 | + Console.WriteLine($"Predicted number for test pixels: {prediction.Prediction}"); |
| 58 | + |
| 59 | + Console.WriteLine("Press any key to continue..."); |
| 60 | + Console.ReadKey(); |
| 61 | + } |
| 62 | + |
| 63 | + private static void PrintMetrics(MulticlassClassificationMetrics metrics) |
| 64 | + { |
| 65 | + Console.WriteLine($"LogLoss: {metrics.LogLoss}"); |
| 66 | + Console.WriteLine($"LogLossReduction: {metrics.LogLossReduction}"); |
| 67 | + Console.WriteLine($"MacroAccuracy: {metrics.MacroAccuracy}"); |
| 68 | + Console.WriteLine($"MicroAccuracy: {metrics.MicroAccuracy}"); |
| 69 | + } |
| 70 | + } |
| 71 | +} |
0 commit comments