
WithOnFitDelegate test coverage and code samples #3733

Closed
PeterPann23 opened this issue May 16, 2019 · 3 comments
Labels
need info This issue needs more info before triage

Comments

@PeterPann23

System information

  • OS version/distro:
    Windows 10
  • .NET Version (e.g., dotnet --info):
    Release 1.0.0.0

Issue

  • Related
    Missing code sample #3732

  • What did you do?
    Looked for samples or tests showing how to use WithOnFitDelegate. I found none: no unit tests, no samples, and little documentation, even though it looks as if all networks call OnFit in one way or another.

  • What did you expect?
    I expected to find some documentation on the method, as it allows looking at the progress of the model being built.

@yaeldekel

Hi @PeterPann23,
PR #3738 added a sample for WithOnFitDelegate. Does this help?
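For reference, the estimator returned by ML.NET's `WithOnFitDelegate` fits the wrapped estimator and then invokes the delegate once with the fitted transformer. The sketch below mimics that wrapper pattern in plain C# so it runs without the ML.NET package; `ISketchEstimator`, `MeanEstimator`, `OnFitEstimator` and `WithOnFit` are hypothetical stand-ins, not ML.NET types.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical stand-ins for ML.NET's ITransformer / IEstimator<T> (NOT the real interfaces).
public interface ISketchTransformer { }

public interface ISketchEstimator<T> where T : ISketchTransformer
{
    T Fit(IEnumerable<float> data);
}

// A trivial "model": remembers the mean of the training data.
public sealed class MeanTransformer : ISketchTransformer
{
    public float Mean { get; }
    public MeanTransformer(float mean) => Mean = mean;
}

public sealed class MeanEstimator : ISketchEstimator<MeanTransformer>
{
    public MeanTransformer Fit(IEnumerable<float> data)
    {
        float sum = 0; int n = 0;
        foreach (var x in data) { sum += x; n++; }
        return new MeanTransformer(n == 0 ? 0 : sum / n);
    }
}

// Wrapper pattern: fit the inner estimator, hand the result to the callback, return it unchanged.
public sealed class OnFitEstimator<T> : ISketchEstimator<T> where T : class, ISketchTransformer
{
    private readonly ISketchEstimator<T> _inner;
    private readonly Action<T> _onFit;

    public OnFitEstimator(ISketchEstimator<T> inner, Action<T> onFit)
    {
        _inner = inner;
        _onFit = onFit;
    }

    public T Fit(IEnumerable<float> data)
    {
        T fitted = _inner.Fit(data);
        _onFit(fitted); // exactly one call per Fit, after training has finished
        return fitted;
    }
}

public static class OnFitSketch
{
    // Hypothetical analogue of the WithOnFitDelegate extension method.
    public static ISketchEstimator<T> WithOnFit<T>(this ISketchEstimator<T> estimator, Action<T> onFit)
        where T : class, ISketchTransformer
        => new OnFitEstimator<T>(estimator, onFit);

    // Fits once and reports how often the callback fired and what it captured.
    public static (int Calls, float CapturedMean) Run()
    {
        int calls = 0;
        MeanTransformer captured = null;
        var estimator = new MeanEstimator().WithOnFit(t => { calls++; captured = t; });
        estimator.Fit(new[] { 1f, 2f, 3f });
        return (calls, captured.Mean);
    }
}
```

The main use of such a callback is to grab a reference to an intermediate, already-fitted stage (for example a normalizer) so its learned parameters can be inspected after the whole pipeline has been trained.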

@yaeldekel added the need info label (This issue needs more info before triage) May 22, 2019
@PeterPann23
Author

Not really. If I understand it correctly, WithOnFitDelegate should be called several times (with networks that support it) and allow users to "gauge" the progress or save an intermediate result, a kind of "model per epoch".
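Since the wrapper invokes its delegate once per call to `Fit`, a "model per epoch" series would, as far as I can tell, have to come from several `Fit` calls, for example retraining with increasing iteration budgets (LightGBM's trainer options expose `NumberOfIterations`). The sketch below shows that loop with hypothetical `BudgetedTrainer` / `BudgetedModel` stand-ins; it does no real training and is not an ML.NET API.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical model type: only records how many iterations it was trained for.
public sealed class BudgetedModel
{
    public int Iterations { get; }
    public BudgetedModel(int iterations) => Iterations = iterations;
}

// Hypothetical trainer with an iteration budget and an on-fit callback (NOT an ML.NET type).
public sealed class BudgetedTrainer
{
    private readonly int _iterations;
    private readonly Action<BudgetedModel> _onFit;

    public BudgetedTrainer(int iterations, Action<BudgetedModel> onFit)
    {
        _iterations = iterations;
        _onFit = onFit;
    }

    public BudgetedModel Fit()
    {
        var model = new BudgetedModel(_iterations); // placeholder for real training
        _onFit(model);                               // fires once, after training finishes
        return model;
    }
}

public static class SnapshotSketch
{
    // One Fit per budget, so the callback produces one snapshot per budget.
    public static List<int> Run(IEnumerable<int> budgets)
    {
        var snapshots = new List<int>();
        foreach (int budget in budgets)
        {
            var trainer = new BudgetedTrainer(budget, m => snapshots.Add(m.Iterations));
            trainer.Fit();
        }
        return snapshots;
    }
}
```

Each snapshot is trained from scratch, so this costs far more than a genuine per-iteration hook would; it only approximates a progress curve.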

I was thinking it would be nice to see some tests that check it is called correctly, and a sample for those who would like to use it in a business-like setting, along the lines of this mock sample:

using System;
using System.IO;
using System.Linq;
using System.Text;

namespace Microsoft.ML.Samples.WithOnFitDelegate
{
    //ml namespaces used (LightGbm needs the Microsoft.ML.LightGbm package)
    using Microsoft.ML;
    using Microsoft.ML.Data;
    using Microsoft.ML.Trainers;
    using Microsoft.ML.Trainers.LightGbm;
    using Microsoft.ML.Transforms;

    // Column names of the iris data set; the loader below maps them by index.
    class IrisData
    {
        public float Label;
        public float SepalLength;
        public float SepalWidth;
        public float PetalLength;
        public float PetalWidth;
    }

    class Program
    {
        const string RawFeatures = "RawFeatures";
        const string KeyColumn = "KeyColumn";
        const string Features = "Features";
        const string Label = "Label";
        const string PredictedLabelIndex = "PredictedLabelIndex";
        private static string BaseDatasetsRelativePath = @"../../../../Data";
        private static string DataSetRelativePath = $"{BaseDatasetsRelativePath}/iris-full.txt";
        private static string DataPath = GetAbsolutePath(DataSetRelativePath);
        static readonly MLContext mlContext = new MLContext();
        static DataOperationsCatalog.TrainTestData horizonDataset;
        static int iteration = default;

        static void Main(string[] args)
        {
            IDataView fullData = mlContext.Data.LoadFromTextFile(path: DataPath,
                columns: new[]
                {
                    new TextLoader.Column(Label, DataKind.Single, 0),
                    new TextLoader.Column(nameof(IrisData.SepalLength), DataKind.Single, 1),
                    new TextLoader.Column(nameof(IrisData.SepalWidth), DataKind.Single, 2),
                    new TextLoader.Column(nameof(IrisData.PetalLength), DataKind.Single, 3),
                    new TextLoader.Column(nameof(IrisData.PetalWidth), DataKind.Single, 4),
                },
                hasHeader: true,
                separatorChar: '\t');

            horizonDataset = mlContext.Data.TrainTestSplit(fullData);

            //map columns for normalization
            var featureColumns = new[] { nameof(IrisData.SepalLength), nameof(IrisData.SepalWidth), nameof(IrisData.PetalLength), nameof(IrisData.PetalWidth) };

            //build pipeline; typed as IEstimator<ITransformer> so that either trainer can be appended below
            IEstimator<ITransformer> pipeline = mlContext.Transforms.Concatenate(outputColumnName: RawFeatures, inputColumnNames: featureColumns)
                    .Append(mlContext.Transforms.Conversion.MapValueToKey(outputColumnName: KeyColumn, inputColumnName: Label))
                    .Append(mlContext.Transforms.NormalizeMinMax(outputColumnName: Features, inputColumnName: RawFeatures))
                    .AppendCacheCheckpoint(mlContext);

            //generate trainer based on arguments, if empty use LightGbm.
            //Append returns a NEW estimator chain, so the result must be assigned back to pipeline.
            //Multiclass trainers need a key-typed label, hence KeyColumn instead of Label.
            if (args is null || args.Length == 0 || args.Contains("LightGbm"))
            {
                pipeline = pipeline.Append(mlContext.MulticlassClassification.Trainers.LightGbm(options: new LightGbmMulticlassTrainer.Options()
                {
                    EarlyStoppingRound = 5,
                    LabelColumnName = KeyColumn,
                    FeatureColumnName = Features
                }));
            }
            else if (args.Contains("LbfgsMaximumEntropy"))
            {
                // the other supported trainer
                pipeline = pipeline.Append(mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(options: new LbfgsMaximumEntropyMulticlassTrainer.Options()
                {
                    LabelColumnName = KeyColumn,
                    FeatureColumnName = Features,
                    ShowTrainingStatistics = true
                }));
            }
            else
            {
                Console.WriteLine("Trainer not supported.");
                return;
            }

            //copy the predicted key (the trainer's PredictedLabel output) into its own column
            pipeline = pipeline.Append(mlContext.Transforms.CopyColumns(outputColumnName: PredictedLabelIndex, inputColumnName: "PredictedLabel"));

            //WithOnFitDelegate also returns a NEW estimator. The delegate runs once, after Fit on the
            //whole pipeline has finished, and receives the fitted pipeline; it is not a per-epoch callback.
            pipeline = pipeline.WithOnFitDelegate(fit => OnFit(fit));

            //train the model (OnFit is invoked at the end of this call)
            var model = pipeline.Fit(horizonDataset.TrainSet);

            //save the model (output path is illustrative)
            mlContext.Model.Save(model, horizonDataset.TrainSet.Schema, "model.zip");
        }


        /// <summary>
        /// Evaluate the fitted pipeline on the test set and store the model plus a metrics report for later analysis and use.
        /// </summary>
        /// <param name="transformer">The fitted pipeline passed in by WithOnFitDelegate.</param>
        static void OnFit(ITransformer transformer)
        {
            try
            {
                iteration++;
                var data = transformer.Transform(horizonDataset.TestSet);
                var metrics = mlContext.MulticlassClassification.Evaluate(data, labelColumnName: KeyColumn);
                Console.WriteLine($"{DateTime.Now:dd.MM.yyyy HH:mm:ss} - Iteration {iteration} Micro= {metrics.MicroAccuracy:P4} Macro= {metrics.MacroAccuracy:P4} LogLoss= {metrics.LogLoss:F4}");

                var sb = new StringBuilder();
                sb.AppendLine($"Iteration {iteration}");
                sb.AppendLine($"    Micro = {metrics.MicroAccuracy:P6}");
                sb.AppendLine($"    Macro = {metrics.MacroAccuracy:P6}");
                sb.AppendLine($"    LogLoss = {metrics.LogLoss:F6}");
                sb.AppendLine($"    LogLossReduction = {metrics.LogLossReduction:P6}");
                sb.AppendLine();
                sb.AppendLine(metrics.ConfusionMatrix.GetFormattedConfusionTable());

                //one output folder per call; it must exist before files are written into it
                var dir = Path.Combine(new DirectoryInfo(Path.Combine("..", "Data")).FullName, "Epoch", iteration.ToString());
                Directory.CreateDirectory(dir);
                using (var fs = File.Create(Path.Combine(dir, "model.zip")))
                {
                    mlContext.Model.Save(transformer, horizonDataset.TestSet.Schema, fs);
                }

                File.WriteAllText(Path.Combine(dir, "report.txt"), sb.ToString());
            }
            catch (Exception e)
            {
                Console.WriteLine("Error in OnFit: " + e.Message);
            }
        }


        public static string GetAbsolutePath(string relativePath)
        {
            FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location);
            string assemblyFolderPath = _dataRoot.Directory.FullName;

            string fullPath = Path.Combine(assemblyFolderPath, relativePath);

            return fullPath;
        }
    }
}
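A pitfall worth covering in tests for such a pipeline is that ML.NET estimator chains are immutable: `Append` and `WithOnFitDelegate` return a new estimator instead of modifying the one they are called on, so calling them without assigning the result silently drops the stage. The hypothetical `SketchChain` type below reproduces that behaviour in plain C# (it is not an ML.NET type).

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical immutable chain: Append returns a NEW chain and leaves the receiver unchanged.
public sealed class SketchChain
{
    private readonly string[] _stages;

    public SketchChain(params string[] stages) => _stages = stages;

    public IReadOnlyList<string> Stages => _stages;

    public SketchChain Append(string stage) => new SketchChain(_stages.Concat(new[] { stage }).ToArray());
}

public static class ChainSketch
{
    // Compares discarding the result of Append with assigning it back.
    public static (int Discarded, int Reassigned) Run()
    {
        var discarded = new SketchChain("Concatenate", "Normalize");
        discarded.Append("Trainer");               // result thrown away: the trainer stage is lost

        var reassigned = new SketchChain("Concatenate", "Normalize");
        reassigned = reassigned.Append("Trainer"); // keep the returned chain

        return (discarded.Stages.Count, reassigned.Stages.Count);
    }
}
```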

@PeterPann23
Author

I gave an option for better user documentation than the current one, which covers no real-world use case.

@ghost locked as resolved and limited conversation to collaborators Mar 21, 2022