|
| 1 | +//***************************************************************************************** |
| 2 | +//* * |
| 3 | +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * |
| 4 | +//* * |
| 5 | +//***************************************************************************************** |
| 6 | + |
| 7 | +using System; |
| 8 | +using System.IO; |
| 9 | +using System.Linq; |
| 10 | +using Microsoft.ML; |
| 11 | +using Microsoft.ML.Data; |
| 12 | +using Microsoft.Data.DataView; |
| 13 | +using Microsoft.ML.LightGBM; |
| 14 | + |
| 15 | + |
| 16 | +namespace MyNamespace |
| 17 | +{ |
| 18 | + class Program |
| 19 | + { |
| 20 | + private static string TrainDataPath = @"x:\dummypath\dummy_train.csv"; |
| 21 | + private static string TestDataPath = @"x:\dummypath\dummy_test.csv"; |
| 22 | + private static string ModelPath = @"x:\models\model.zip"; |
| 23 | + |
| 24 | + static void Main(string[] args) |
| 25 | + { |
| 26 | + // Create MLContext to be shared across the model creation workflow objects |
| 27 | + var mlContext = new MLContext(); |
| 28 | + |
| 29 | + var command = Command.Predict; // Your desired action here |
| 30 | + |
| 31 | + if (command == Command.Predict) |
| 32 | + { |
| 33 | + Predict(mlContext); |
| 34 | + ConsoleHelper.ConsoleWriteHeader("=============== If you also want to train a model use Command.TrainAndPredict ==============="); |
| 35 | + } |
| 36 | + |
| 37 | + if (command == Command.TrainAndPredict) |
| 38 | + { |
| 39 | + TrainEvaluateAndSaveModel(mlContext); |
| 40 | + Predict(mlContext); |
| 41 | + } |
| 42 | + |
| 43 | + Console.WriteLine("=============== End of process, hit any key to finish ==============="); |
| 44 | + Console.ReadKey(); |
| 45 | + } |
| 46 | + |
| 47 | + private enum Command |
| 48 | + { |
| 49 | + Predict, |
| 50 | + TrainAndPredict |
| 51 | + } |
| 52 | + |
| 53 | + private static ITransformer TrainEvaluateAndSaveModel(MLContext mlContext) |
| 54 | + { |
| 55 | + // Load data |
| 56 | + Console.WriteLine("=============== Loading data ==============="); |
| 57 | + IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>( |
| 58 | + path: TrainDataPath, |
| 59 | + hasHeader: true, |
| 60 | + separatorChar: ',', |
| 61 | + allowQuoting: true, |
| 62 | + allowSparse: true); |
| 63 | + IDataView testDataView = mlContext.Data.LoadFromTextFile<SampleObservation>( |
| 64 | + path: TestDataPath, |
| 65 | + hasHeader: true, |
| 66 | + separatorChar: ',', |
| 67 | + allowQuoting: true, |
| 68 | + allowSparse: true); |
| 69 | + |
| 70 | + // Common data process configuration with pipeline data transformations |
| 71 | + var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" }) |
| 72 | + .AppendCacheCheckpoint(mlContext); |
| 73 | + |
| 74 | + // Set the training algorithm, then create and config the modelBuilder |
| 75 | + var trainer = mlContext.MulticlassClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Options() { }, LabelColumn = "Label", FeatureColumn = "Features" }); |
| 76 | + var trainingPipeline = dataProcessPipeline.Append(trainer); |
| 77 | + |
| 78 | + // Train the model fitting to the DataSet |
| 79 | + Console.WriteLine("=============== Training the model ==============="); |
| 80 | + var trainedModel = trainingPipeline.Fit(trainingDataView); |
| 81 | + |
| 82 | + // Evaluate the model and show accuracy stats |
| 83 | + Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); |
| 84 | + var predictions = trainedModel.Transform(testDataView); |
| 85 | + |
| 86 | + // Save/persist the trained model to a .ZIP file |
| 87 | + Console.WriteLine($"=============== Saving the model ==============="); |
| 88 | + using (var fs = new FileStream(ModelPath, FileMode.Create, FileAccess.Write, FileShare.Write)) |
| 89 | + mlContext.Model.Save(trainedModel, fs); |
| 90 | + |
| 91 | + Console.WriteLine("The model is saved to {0}", ModelPath); |
| 92 | + Console.WriteLine("=============== End of training process ==============="); |
| 93 | + |
| 94 | + return trainedModel; |
| 95 | + } |
| 96 | + |
| 97 | + // Try/test a single prediction by loading the model from the file, first. |
| 98 | + private static void Predict(MLContext mlContext) |
| 99 | + { |
| 100 | + //Load data to test. Could be any test data. For demonstration purpose train data is used here. |
| 101 | + IDataView trainingDataView = mlContext.Data.LoadFromTextFile<SampleObservation>( |
| 102 | + path: TestDataPath, |
| 103 | + hasHeader: true, |
| 104 | + separatorChar: ',', |
| 105 | + allowQuoting: true, |
| 106 | + allowSparse: true); |
| 107 | + |
| 108 | + var sample = mlContext.Data.CreateEnumerable<SampleObservation>(trainingDataView, false).First(); |
| 109 | + |
| 110 | + ITransformer trainedModel; |
| 111 | + using (var stream = new FileStream(ModelPath, FileMode.Open, FileAccess.Read, FileShare.Read)) |
| 112 | + { |
| 113 | + trainedModel = mlContext.Model.Load(stream); |
| 114 | + } |
| 115 | + |
| 116 | + // Create prediction engine related to the loaded trained model |
| 117 | + var predEngine = trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext); |
| 118 | + |
| 119 | + //Score |
| 120 | + var resultprediction = predEngine.Predict(sample); |
| 121 | + |
| 122 | + Console.WriteLine($"=============== Single Prediction ==============="); |
| 123 | + Console.WriteLine($"Actual value: {sample.Label} | Predicted value: {resultprediction.Prediction} | Predicted scores: [{String.Join(", ", resultprediction.Score)}]"); |
| 124 | + Console.WriteLine($"=================================================="); |
| 125 | + } |
| 126 | + |
| 127 | + } |
| 128 | + |
| 129 | + public class SampleObservation |
| 130 | + { |
| 131 | + [ColumnName("Label"), LoadColumn(0)] |
| 132 | + public bool Label { get; set; } |
| 133 | + |
| 134 | + |
| 135 | + [ColumnName("col1"), LoadColumn(1)] |
| 136 | + public float Col1 { get; set; } |
| 137 | + |
| 138 | + |
| 139 | + [ColumnName("col2"), LoadColumn(0)] |
| 140 | + public float Col2 { get; set; } |
| 141 | + |
| 142 | + |
| 143 | + [ColumnName("col3"), LoadColumn(0)] |
| 144 | + public string Col3 { get; set; } |
| 145 | + |
| 146 | + |
| 147 | + [ColumnName("col4"), LoadColumn(0)] |
| 148 | + public int Col4 { get; set; } |
| 149 | + |
| 150 | + |
| 151 | + [ColumnName("col5"), LoadColumn(0)] |
| 152 | + public uint Col5 { get; set; } |
| 153 | + |
| 154 | + |
| 155 | + } |
| 156 | + |
| 157 | + public class SamplePrediction |
| 158 | + { |
| 159 | + // ColumnName attribute is used to change the column name from |
| 160 | + // its default value, which is the name of the field. |
| 161 | + [ColumnName("PredictedLabel")] |
| 162 | + public Boolean Prediction { get; set; } |
| 163 | + public float[] Score { get; set; } |
| 164 | + } |
| 165 | + |
| 166 | +} |
0 commit comments