diff --git a/Microsoft.ML.AutoML.sln b/Microsoft.ML.AutoML.sln new file mode 100644 index 0000000000..280cef5704 --- /dev/null +++ b/Microsoft.ML.AutoML.sln @@ -0,0 +1,91 @@ + +Microsoft Visual Studio Solution File, Format Version 12.00 +# Visual Studio 15 +VisualStudioVersion = 15.0.28010.2050 +MinimumVisualStudioVersion = 10.0.40219.1 +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Auto", "src\Microsoft.ML.Auto\Microsoft.ML.Auto.csproj", "{B3727729-3DF8-47E0-8710-9B41DAF55817}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Tests", "test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj", "{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "mlnet", "src\mlnet\mlnet.csproj", "{ED714FA5-6F89-401B-9E7F-CADF1373C553}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "mlnet.Tests", "test\mlnet.Tests\mlnet.Tests.csproj", "{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}" +EndProject +Global + GlobalSection(SolutionConfigurationPlatforms) = preSolution + Debug|Any CPU = Debug|Any CPU + Debug-Intrinsics|Any CPU = Debug-Intrinsics|Any CPU + Debug-netfx|Any CPU = Debug-netfx|Any CPU + Release|Any CPU = Release|Any CPU + Release-Intrinsics|Any CPU = Release-Intrinsics|Any CPU + Release-netfx|Any CPU = Release-netfx|Any CPU + EndGlobalSection + GlobalSection(ProjectConfigurationPlatforms) = postSolution + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug|Any CPU.Build.0 = Debug|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + 
{B3727729-3DF8-47E0-8710-9B41DAF55817}.Release|Any CPU.ActiveCfg = Release|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release|Any CPU.Build.0 = Release|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug|Any CPU.Build.0 = Debug|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release|Any CPU.ActiveCfg = Release|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release|Any CPU.Build.0 = Release|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug|Any CPU.Build.0 = Debug|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + 
{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release|Any CPU.ActiveCfg = Release|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release|Any CPU.Build.0 = Release|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug|Any CPU.Build.0 = Debug|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release|Any CPU.ActiveCfg = Release|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release|Any CPU.Build.0 = Release|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + 
{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release|Any CPU.Build.0 = Release|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU + {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU + EndGlobalSection + GlobalSection(SolutionProperties) = preSolution + HideSolutionNode = FALSE + EndGlobalSection + GlobalSection(ExtensibilityGlobals) = postSolution + SolutionGuid = {8C1BC26C-B87E-47CD-928E-00EFE4353B40} + EndGlobalSection +EndGlobal diff --git a/build.proj b/build.proj index 15fea4e309..3dd5edd0e7 100644 --- a/build.proj +++ b/build.proj @@ -22,6 +22,7 @@ + diff --git a/src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs b/src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs new file mode 100644 index 0000000000..adf04111f2 --- /dev/null +++ b/src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs @@ -0,0 +1,79 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + public sealed class AutoMLCatalog + { + private readonly MLContext _context; + + internal AutoMLCatalog(MLContext context) + { + _context = context; + } + + public RegressionExperiment CreateRegressionExperiment(uint maxExperimentTimeInSeconds) + { + return new RegressionExperiment(_context, new RegressionExperimentSettings() + { + MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds + }); + } + + public RegressionExperiment CreateRegressionExperiment(RegressionExperimentSettings experimentSettings) + { + return new RegressionExperiment(_context, experimentSettings); + } + + public BinaryClassificationExperiment CreateBinaryClassificationExperiment(uint maxExperimentTimeInSeconds) + { + return new BinaryClassificationExperiment(_context, new BinaryExperimentSettings() + { + MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds + }); + } + + public BinaryClassificationExperiment CreateBinaryClassificationExperiment(BinaryExperimentSettings experimentSettings) + { + return new BinaryClassificationExperiment(_context, experimentSettings); + } + + public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(uint maxExperimentTimeInSeconds) + { + return new MulticlassClassificationExperiment(_context, new MulticlassExperimentSettings() + { + MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds + }); + } + + public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(MulticlassExperimentSettings experimentSettings) + { + return new MulticlassClassificationExperiment(_context, experimentSettings); + } + + public ColumnInferenceResults InferColumns(string path, string labelColumn = DefaultColumnNames.Label, char? separatorChar = null, bool? allowQuotedStrings = null, + bool? 
supportSparse = null, bool trimWhitespace = false, bool groupColumns = true) + { + UserInputValidationUtil.ValidateInferColumnsArgs(path, labelColumn); + return ColumnInferenceApi.InferColumns(_context, path, labelColumn, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns); + } + + public ColumnInferenceResults InferColumns(string path, ColumnInformation columnInformation, char? separatorChar = null, bool? allowQuotedStrings = null, + bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true) + { + columnInformation = columnInformation ?? new ColumnInformation(); + UserInputValidationUtil.ValidateInferColumnsArgs(path, columnInformation); + return ColumnInferenceApi.InferColumns(_context, path, columnInformation, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns); + } + + public ColumnInferenceResults InferColumns(string path, uint labelColumnIndex, bool hasHeader = false, char? separatorChar = null, + bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true) + { + UserInputValidationUtil.ValidateInferColumnsArgs(path); + return ColumnInferenceApi.InferColumns(_context, path, labelColumnIndex, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns); + } + } +} diff --git a/src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs new file mode 100644 index 0000000000..3ab9dbb7a1 --- /dev/null +++ b/src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs @@ -0,0 +1,73 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + public sealed class BinaryExperimentSettings : ExperimentSettings + { + public BinaryClassificationMetric OptimizingMetric { get; set; } = BinaryClassificationMetric.Accuracy; + public ICollection Trainers { get; } = + Enum.GetValues(typeof(BinaryClassificationTrainer)).OfType().ToList(); + } + + public enum BinaryClassificationMetric + { + Accuracy, + AreaUnderRocCurve, + AreaUnderPrecisionRecallCurve, + F1Score, + PositivePrecision, + PositiveRecall, + NegativePrecision, + NegativeRecall, + } + + public enum BinaryClassificationTrainer + { + AveragedPerceptron, + FastForest, + FastTree, + LightGbm, + LinearSupportVectorMachines, + LbfgsLogisticRegression, + SdcaLogisticRegression, + SgdCalibrated, + SymbolicSgdLogisticRegression, + } + + public sealed class BinaryClassificationExperiment : ExperimentBase + { + internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings) + : base(context, + new BinaryMetricsAgent(context, settings.OptimizingMetric), + new OptimizingMetricInfo(settings.OptimizingMetric), + settings, + TaskKind.BinaryClassification, + TrainerExtensionUtil.GetTrainerNames(settings.Trainers)) + { + } + } + + public static class BinaryExperimentResultExtensions + { + public static RunDetail Best(this IEnumerable> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy) + { + var metricsAgent = new BinaryMetricsAgent(null, metric); + var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; + return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); + } + + public static CrossValidationRunDetail Best(this IEnumerable> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy) + { + var metricsAgent = new BinaryMetricsAgent(null, metric); + var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; 
+ return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); + } + } +} diff --git a/src/Microsoft.ML.Auto/API/ColumnInference.cs b/src/Microsoft.ML.Auto/API/ColumnInference.cs new file mode 100644 index 0000000000..588116897d --- /dev/null +++ b/src/Microsoft.ML.Auto/API/ColumnInference.cs @@ -0,0 +1,27 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Collections.ObjectModel; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + public sealed class ColumnInferenceResults + { + public TextLoader.Options TextLoaderOptions { get; internal set; } = new TextLoader.Options(); + public ColumnInformation ColumnInformation { get; internal set; } = new ColumnInformation(); + } + + public sealed class ColumnInformation + { + public string LabelColumnName { get; set; } = DefaultColumnNames.Label; + public string ExampleWeightColumnName { get; set; } + public string SamplingKeyColumnName { get; set; } + public ICollection CategoricalColumnNames { get; } = new Collection(); + public ICollection NumericColumnNames { get; } = new Collection(); + public ICollection TextColumnNames { get; } = new Collection(); + public ICollection IgnoredColumnNames { get; } = new Collection(); + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/API/ExperimentBase.cs b/src/Microsoft.ML.Auto/API/ExperimentBase.cs new file mode 100644 index 0000000000..381196c54c --- /dev/null +++ b/src/Microsoft.ML.Auto/API/ExperimentBase.cs @@ -0,0 +1,202 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Auto +{ + public abstract class ExperimentBase where TMetrics : class + { + protected readonly MLContext Context; + + private readonly IMetricsAgent _metricsAgent; + private readonly OptimizingMetricInfo _optimizingMetricInfo; + private readonly ExperimentSettings _settings; + private readonly TaskKind _task; + private readonly IEnumerable _trainerWhitelist; + + internal ExperimentBase(MLContext context, + IMetricsAgent metricsAgent, + OptimizingMetricInfo optimizingMetricInfo, + ExperimentSettings settings, + TaskKind task, + IEnumerable trainerWhitelist) + { + Context = context; + _metricsAgent = metricsAgent; + _optimizingMetricInfo = optimizingMetricInfo; + _settings = settings; + _task = task; + _trainerWhitelist = trainerWhitelist; + } + + public IEnumerable> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label, + string samplingKeyColumn = null, IEstimator preFeaturizers = null, IProgress> progressHandler = null) + { + var columnInformation = new ColumnInformation() + { + LabelColumnName = labelColumn, + SamplingKeyColumnName = samplingKeyColumn + }; + return Execute(trainData, columnInformation, preFeaturizers, progressHandler); + } + + public IEnumerable> Execute(IDataView trainData, ColumnInformation columnInformation, + IEstimator preFeaturizer = null, IProgress> progressHandler = null) + { + // Cross val threshold for # of dataset rows -- + // If dataset has < threshold # of rows, use cross val. + // Else, run experiment using train-validate split. 
+ const int crossValRowCountThreshold = 15000; + + var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold); + + if (rowCount < crossValRowCountThreshold) + { + const int numCrossValFolds = 10; + var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumnName); + return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler); + } + else + { + var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName); + return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler); + } + } + + public IEnumerable> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator preFeaturizer = null, IProgress> progressHandler = null) + { + var columnInformation = new ColumnInformation() { LabelColumnName = labelColumn }; + return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler); + } + + public IEnumerable> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator preFeaturizer = null, IProgress> progressHandler = null) + { + if (validationData == null) + { + var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName); + trainData = splitResult.trainData; + validationData = splitResult.validationData; + } + return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler); + } + + public IEnumerable> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator preFeaturizer = null, IProgress> progressHandler = null) + { + UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds); + var splitResult = SplitUtil.CrossValSplit(Context, 
trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName); + return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler); + } + + public IEnumerable> Execute(IDataView trainData, + uint numberOfCVFolds, string labelColumn = DefaultColumnNames.Label, + string samplingKeyColumn = null, IEstimator preFeaturizer = null, + Progress> progressHandler = null) + { + var columnInformation = new ColumnInformation() + { + LabelColumnName = labelColumn, + SamplingKeyColumnName = samplingKeyColumn + }; + return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizer, progressHandler); + } + + private IEnumerable> ExecuteTrainValidate(IDataView trainData, + ColumnInformation columnInfo, + IDataView validationData, + IEstimator preFeaturizer, + IProgress> progressHandler) + { + columnInfo = columnInfo ?? new ColumnInformation(); + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData); + + // Apply pre-featurizer + ITransformer preprocessorTransform = null; + if (preFeaturizer != null) + { + preprocessorTransform = preFeaturizer.Fit(trainData); + trainData = preprocessorTransform.Transform(trainData); + validationData = preprocessorTransform.Transform(validationData); + } + + var runner = new TrainValidateRunner(Context, trainData, validationData, columnInfo.LabelColumnName, _metricsAgent, + preFeaturizer, preprocessorTransform, _settings.DebugLogger); + var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo); + return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner); + } + + private IEnumerable> ExecuteCrossVal(IDataView[] trainDatasets, + ColumnInformation columnInfo, + IDataView[] validationDatasets, + IEstimator preFeaturizer, + IProgress> progressHandler) + { + columnInfo = columnInfo ?? 
new ColumnInformation(); + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]); + + // Apply pre-featurizer + ITransformer[] preprocessorTransforms = null; + (trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer); + + var runner = new CrossValRunner(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer, + preprocessorTransforms, columnInfo.LabelColumnName, _settings.DebugLogger); + var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo); + return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner); + } + + private IEnumerable> ExecuteCrossValSummary(IDataView[] trainDatasets, + ColumnInformation columnInfo, + IDataView[] validationDatasets, + IEstimator preFeaturizer, + IProgress> progressHandler) + { + columnInfo = columnInfo ?? new ColumnInformation(); + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]); + + // Apply pre-featurizer + ITransformer[] preprocessorTransforms = null; + (trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer); + + var runner = new CrossValSummaryRunner(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer, + preprocessorTransforms, columnInfo.LabelColumnName, _optimizingMetricInfo, _settings.DebugLogger); + var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo); + return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner); + } + + private IEnumerable Execute(ColumnInformation columnInfo, + DatasetColumnInfo[] columns, + IEstimator preFeaturizer, + IProgress progressHandler, + IRunner runner) + where TRunDetail : RunDetail + { + // Execute experiment & get all pipelines run + var experiment = new 
Experiment(Context, _task, _optimizingMetricInfo, progressHandler, + _settings, _metricsAgent, _trainerWhitelist, columns, runner); + + return experiment.Execute(); + } + + private static (IDataView[] trainDatasets, IDataView[] validDatasets, ITransformer[] preprocessorTransforms) + ApplyPreFeaturizerCrossVal(IDataView[] trainDatasets, IDataView[] validDatasets, IEstimator preFeaturizer) + { + if (preFeaturizer == null) + { + return (trainDatasets, validDatasets, null); + } + + var preprocessorTransforms = new ITransformer[trainDatasets.Length]; + for (var i = 0; i < trainDatasets.Length; i++) + { + // Preprocess train and validation data + preprocessorTransforms[i] = preFeaturizer.Fit(trainDatasets[i]); + trainDatasets[i] = preprocessorTransforms[i].Transform(trainDatasets[i]); + validDatasets[i] = preprocessorTransforms[i].Transform(validDatasets[i]); + } + + return (trainDatasets, validDatasets, preprocessorTransforms); + } + } +} diff --git a/src/Microsoft.ML.Auto/API/ExperimentSettings.cs b/src/Microsoft.ML.Auto/API/ExperimentSettings.cs new file mode 100644 index 0000000000..43c6c8befe --- /dev/null +++ b/src/Microsoft.ML.Auto/API/ExperimentSettings.cs @@ -0,0 +1,33 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using System.Threading; + +namespace Microsoft.ML.Auto +{ + public class ExperimentSettings + { + public uint MaxExperimentTimeInSeconds { get; set; } = 24 * 60 * 60; + public CancellationToken CancellationToken { get; set; } = default; + + /// + /// This is a pointer to a directory where all models trained during the AutoML experiment will be saved. + /// If null, models will be kept in memory instead of written to disk. + /// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in + /// memory could cause a system to run out of memory.) 
+ /// + public DirectoryInfo CacheDirectory { get; set; } = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.Auto")); + + /// + /// This setting controls whether or not an AutoML experiment will make use of ML.NET-provided caching. + /// If set to true, caching will be forced on for all pipelines. If set to false, caching will be forced off. + /// If set to null (default value), AutoML will decide whether to enable caching for each model. + /// + public bool? CacheBeforeTrainer = null; + + internal int MaxModels = int.MaxValue; + internal IDebugLogger DebugLogger; + } +} diff --git a/src/Microsoft.ML.Auto/API/InferenceException.cs b/src/Microsoft.ML.Auto/API/InferenceException.cs new file mode 100644 index 0000000000..423c4ae3ce --- /dev/null +++ b/src/Microsoft.ML.Auto/API/InferenceException.cs @@ -0,0 +1,31 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Auto +{ + public enum InferenceType + { + ColumnDataKind, + ColumnSplit, + Label, + } + + public sealed class InferenceException : Exception + { + public InferenceType InferenceType; + + public InferenceException(InferenceType inferenceType, string message) + : base(message) + { + } + + public InferenceException(InferenceType inferenceType, string message, Exception inner) + : base(message, inner) + { + } + } + +} diff --git a/src/Microsoft.ML.Auto/API/MLContextExtension.cs b/src/Microsoft.ML.Auto/API/MLContextExtension.cs new file mode 100644 index 0000000000..9287fe827c --- /dev/null +++ b/src/Microsoft.ML.Auto/API/MLContextExtension.cs @@ -0,0 +1,14 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +namespace Microsoft.ML.Auto +{ + public static class MLContextExtension + { + public static AutoMLCatalog Auto(this MLContext mlContext) + { + return new AutoMLCatalog(mlContext); + } + } +} diff --git a/src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs new file mode 100644 index 0000000000..f7f5a856cb --- /dev/null +++ b/src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs @@ -0,0 +1,71 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + public sealed class MulticlassExperimentSettings : ExperimentSettings + { + public MulticlassClassificationMetric OptimizingMetric { get; set; } = MulticlassClassificationMetric.MicroAccuracy; + public ICollection Trainers { get; } = + Enum.GetValues(typeof(MulticlassClassificationTrainer)).OfType().ToList(); + } + + public enum MulticlassClassificationMetric + { + MicroAccuracy, + MacroAccuracy, + LogLoss, + LogLossReduction, + TopKAccuracy, + } + + public enum MulticlassClassificationTrainer + { + AveragedPerceptronOVA, + FastForestOVA, + FastTreeOVA, + LightGbm, + LinearSupportVectorMachinesOVA, + LbfgsMaximumEntropy, + LbfgsLogisticRegressionOVA, + SdcaMaximumEntropy, + SgdCalibratedOVA, + SymbolicSgdLogisticRegressionOVA, + } + + public sealed class MulticlassClassificationExperiment : ExperimentBase + { + internal MulticlassClassificationExperiment(MLContext context, MulticlassExperimentSettings settings) + : base(context, + new MultiMetricsAgent(context, settings.OptimizingMetric), + new OptimizingMetricInfo(settings.OptimizingMetric), + settings, + TaskKind.MulticlassClassification, + TrainerExtensionUtil.GetTrainerNames(settings.Trainers)) + { + } + } + + public 
static class MulticlassExperimentResultExtensions + { + public static RunDetail Best(this IEnumerable> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.MicroAccuracy) + { + var metricsAgent = new MultiMetricsAgent(null, metric); + var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; + return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); + } + + public static CrossValidationRunDetail Best(this IEnumerable> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.MicroAccuracy) + { + var metricsAgent = new MultiMetricsAgent(null, metric); + var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; + return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/API/Pipeline.cs b/src/Microsoft.ML.Auto/API/Pipeline.cs new file mode 100644 index 0000000000..864d82037b --- /dev/null +++ b/src/Microsoft.ML.Auto/API/Pipeline.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; + +namespace Microsoft.ML.Auto +{ + internal class Pipeline + { + public PipelineNode[] Nodes { get; set; } + public bool CacheBeforeTrainer { get; set; } + + public Pipeline(PipelineNode[] nodes, bool cacheBeforeTrainer = false) + { + Nodes = nodes; + CacheBeforeTrainer = cacheBeforeTrainer; + } + + // (used by Newtonsoft) + internal Pipeline() + { + } + + public IEstimator ToEstimator(MLContext context) + { + var inferredPipeline = SuggestedPipeline.FromPipeline(context, this); + return inferredPipeline.ToEstimator(); + } + } + + internal class PipelineNode + { + public string Name { get; set; } + public PipelineNodeType NodeType { get; set; } + public string[] InColumns { get; set; } + public string[] OutColumns { get; set; } + public IDictionary Properties { get; set; } + + public PipelineNode(string name, PipelineNodeType nodeType, + string[] inColumns, string[] outColumns, + IDictionary properties = null) + { + Name = name; + NodeType = nodeType; + InColumns = inColumns; + OutColumns = outColumns; + Properties = properties ?? 
new Dictionary(); + } + + public PipelineNode(string name, PipelineNodeType nodeType, + string inColumn, string outColumn, IDictionary properties = null) : + this(name, nodeType, new string[] { inColumn }, new string[] { outColumn }, properties) + { + } + + public PipelineNode(string name, PipelineNodeType nodeType, + string[] inColumns, string outColumn, IDictionary properties = null) : + this(name, nodeType, inColumns, new string[] { outColumn }, properties) + { + } + + // (used by Newtonsoft) + internal PipelineNode() + { + } + } + + internal enum PipelineNodeType + { + Transform, + Trainer + } + + internal class CustomProperty + { + public string Name { get; set; } + public IDictionary Properties { get; set; } + + public CustomProperty(string name, IDictionary properties) + { + Name = name; + Properties = properties; + } + + internal CustomProperty() + { + } + } + + internal class PipelineScore + { + public readonly double Score; + + /// + /// This setting is true if the pipeline run succeeded and ran to completion. + /// Else, it is false if some exception was thrown before the run could complete. + /// + public readonly bool RunSucceded; + + internal readonly Pipeline Pipeline; + + internal PipelineScore(Pipeline pipeline, double score, bool runSucceeded) + { + Pipeline = pipeline; + Score = score; + RunSucceded = runSucceeded; + } + } +} diff --git a/src/Microsoft.ML.Auto/API/RegressionExperiment.cs b/src/Microsoft.ML.Auto/API/RegressionExperiment.cs new file mode 100644 index 0000000000..51f5988f64 --- /dev/null +++ b/src/Microsoft.ML.Auto/API/RegressionExperiment.cs @@ -0,0 +1,68 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + public sealed class RegressionExperimentSettings : ExperimentSettings + { + public RegressionMetric OptimizingMetric { get; set; } = RegressionMetric.RSquared; + public ICollection Trainers { get; } = + Enum.GetValues(typeof(RegressionTrainer)).OfType().ToList(); + } + + public enum RegressionMetric + { + MeanAbsoluteError, + MeanSquaredError, + RootMeanSquaredError, + RSquared + } + + public enum RegressionTrainer + { + FastForest, + FastTree, + FastTreeTweedie, + LightGbm, + OnlineGradientDescent, + Ols, + LbfgsPoissonRegression, + StochasticDualCoordinateAscent, + } + + public sealed class RegressionExperiment : ExperimentBase + { + internal RegressionExperiment(MLContext context, RegressionExperimentSettings settings) + : base(context, + new RegressionMetricsAgent(context, settings.OptimizingMetric), + new OptimizingMetricInfo(settings.OptimizingMetric), + settings, + TaskKind.Regression, + TrainerExtensionUtil.GetTrainerNames(settings.Trainers)) + { + } + } + + public static class RegressionExperimentResultExtensions + { + public static RunDetail Best(this IEnumerable> results, RegressionMetric metric = RegressionMetric.RSquared) + { + var metricsAgent = new RegressionMetricsAgent(null, metric); + var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; + return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); + } + + public static CrossValidationRunDetail Best(this IEnumerable> results, RegressionMetric metric = RegressionMetric.RSquared) + { + var metricsAgent = new RegressionMetricsAgent(null, metric); + var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; + return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing); + } + } +} diff --git a/src/Microsoft.ML.Auto/API/RunDetails/CrossValidationRunDetail.cs 
b/src/Microsoft.ML.Auto/API/RunDetails/CrossValidationRunDetail.cs new file mode 100644 index 0000000000..713c820a99 --- /dev/null +++ b/src/Microsoft.ML.Auto/API/RunDetails/CrossValidationRunDetail.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Auto +{ + public sealed class CrossValidationRunDetail : RunDetail + { + public IEnumerable> Results { get; private set; } + + internal CrossValidationRunDetail(string trainerName, + IEstimator estimator, + Pipeline pipeline, + IEnumerable> results) : base(trainerName, estimator, pipeline) + { + Results = results; + } + } + + public sealed class TrainResult + { + public TMetrics ValidationMetrics { get; private set; } + public ITransformer Model { get { return _modelContainer.GetModel(); } } + public Exception Exception { get; private set; } + + private readonly ModelContainer _modelContainer; + + internal TrainResult(ModelContainer modelContainer, + TMetrics metrics, + Exception exception) + { + _modelContainer = modelContainer; + ValidationMetrics = metrics; + Exception = exception; + } + } + +} diff --git a/src/Microsoft.ML.Auto/API/RunDetails/RunDetail.cs b/src/Microsoft.ML.Auto/API/RunDetails/RunDetail.cs new file mode 100644 index 0000000000..a83670986d --- /dev/null +++ b/src/Microsoft.ML.Auto/API/RunDetails/RunDetail.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; + +namespace Microsoft.ML.Auto +{ + public sealed class RunDetail : RunDetail + { + public TMetrics ValidationMetrics { get; private set; } + public ITransformer Model { get { return _modelContainer.GetModel(); } } + public Exception Exception { get; private set; } + + private readonly ModelContainer _modelContainer; + + internal RunDetail(string trainerName, + IEstimator estimator, + Pipeline pipeline, + ModelContainer modelContainer, + TMetrics metrics, + Exception exception) : base(trainerName, estimator, pipeline) + { + _modelContainer = modelContainer; + ValidationMetrics = metrics; + Exception = exception; + } + } + + public abstract class RunDetail + { + public string TrainerName { get; private set; } + public double RuntimeInSeconds { get; internal set; } + public IEstimator Estimator { get; private set; } + + internal Pipeline Pipeline { get; private set; } + internal double PipelineInferenceTimeInSeconds { get; set; } + + internal RunDetail(string trainerName, + IEstimator estimator, + Pipeline pipeline) + { + TrainerName = trainerName; + Estimator = estimator; + Pipeline = pipeline; + } + } +} diff --git a/src/Microsoft.ML.Auto/Assembly.cs b/src/Microsoft.ML.Auto/Assembly.cs new file mode 100644 index 0000000000..4a3999e45a --- /dev/null +++ b/src/Microsoft.ML.Auto/Assembly.cs @@ -0,0 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.ML.AutoML.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] +[assembly: InternalsVisibleTo("mlnet, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] +[assembly: InternalsVisibleTo("mlnet.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] +[assembly: InternalsVisibleTo("Benchmark, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")] diff --git a/src/Microsoft.ML.Auto/AutoMlUtils.cs b/src/Microsoft.ML.Auto/AutoMlUtils.cs new file mode 100644 index 0000000000..1ee5570ee3 --- /dev/null +++ b/src/Microsoft.ML.Auto/AutoMlUtils.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Threading; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class AutoMlUtils + { + public static readonly ThreadLocal random = new ThreadLocal(() => new Random()); + + public static void Assert(bool boolVal, string message = null) + { + if (!boolVal) + { + message = message ?? "Assertion failed"; + throw new InvalidOperationException(message); + } + } + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnGroupingInference.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnGroupingInference.cs new file mode 100644 index 0000000000..c4535ba7d0 --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnGroupingInference.cs @@ -0,0 +1,151 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.ML.Data; +using static Microsoft.ML.Data.TextLoader; + +namespace Microsoft.ML.Auto +{ + /// + /// This class incapsulates logic for grouping together the inferred columns of the text file based on their type + /// and purpose, and generating column names. + /// + internal static class ColumnGroupingInference + { + /// + /// This is effectively a merger of and a + /// with support for vector-value columns. 
+ /// + public class GroupingColumn + { + public string SuggestedName; + public DataKind ItemKind; + public ColumnPurpose Purpose; + public Range[] Ranges; + + public GroupingColumn(string name, DataKind kind, ColumnPurpose purpose, Range[] ranges) + { + SuggestedName = name; + ItemKind = kind; + Purpose = purpose; + Ranges = ranges; + } + + public TextLoader.Column GenerateTextLoaderColumn() + { + return new TextLoader.Column(SuggestedName, ItemKind, Ranges); + } + } + + /// + /// Group together the single-valued columns with the same type and purpose and generate column names. + /// + /// The host environment to use. + /// Whether the original file had a header. + /// If yes, the fields are used to generate the column + /// names, otherwise they are ignored. + /// The (detected) column types. + /// The (detected) column purposes. Must be parallel to . + /// The struct containing an array of grouped columns specifications. + public static GroupingColumn[] InferGroupingAndNames(MLContext env, bool hasHeader, ColumnTypeInference.Column[] types, PurposeInference.Column[] purposes) + { + var result = new List(); + var tuples = types.Zip(purposes, Tuple.Create).ToList(); + var grouped = + from t in tuples + group t by + new + { + t.Item1.ItemType, + t.Item2.Purpose, + purposeGroupId = GetPurposeGroupId(t.Item1.ColumnIndex, t.Item2.Purpose) + } + into g + select g; + + foreach (var g in grouped) + { + string name = (hasHeader && g.Count() == 1) + ? 
g.First().Item1.SuggestedName + : GetName(g.Key.ItemType.GetRawKind(), g.Key.Purpose, result); + + var ranges = GetRanges(g.Select(t => t.Item1.ColumnIndex).ToArray()); + result.Add(new GroupingColumn(name, g.Key.ItemType.GetRawKind(), g.Key.Purpose, ranges)); + } + + return result.ToArray(); + } + + private static int GetPurposeGroupId(int columnIndex, ColumnPurpose purpose) + { + if (purpose == ColumnPurpose.CategoricalFeature || + purpose == ColumnPurpose.TextFeature || + purpose == ColumnPurpose.Ignore) + return columnIndex; + return 0; + } + + private static string GetName(DataKind itemKind, ColumnPurpose purpose, List previousColumns) + { + string prefix = GetPurposeName(purpose, itemKind); + int i = 0; + string name = prefix; + while (previousColumns.Any(x => x.SuggestedName == name)) + { + i++; + name = string.Format("{0}{1:00}", prefix, i); + } + + return name; + } + + private static string GetPurposeName(ColumnPurpose purpose, DataKind itemKind) + { + switch (purpose) + { + case ColumnPurpose.NumericFeature: + if (itemKind == DataKind.Boolean) + { + return "BooleanFeatures"; + } + else + { + return "Features"; + } + case ColumnPurpose.CategoricalFeature: + return "Cat"; + default: + return Enum.GetName(typeof(ColumnPurpose), purpose); + } + } + + /// + /// Generates a collection of Ranges from indices. 
+ /// + private static Range[] GetRanges(int[] indices) + { + Array.Sort(indices); + var allRanges = new List(); + var currRange = new Range(indices[0]); + for (int i = 1; i < indices.Length; i++) + { + if (indices[i] == currRange.Max + 1) + { + currRange.Max++; + } + else + { + allRanges.Add(currRange); + currRange = new Range(indices[i]); + } + } + allRanges.Add(currRange); + return allRanges.ToArray(); + } + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs new file mode 100644 index 0000000000..6db0fab782 --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs @@ -0,0 +1,150 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class ColumnInferenceApi + { + public static ColumnInferenceResults InferColumns(MLContext context, string path, uint labelColumnIndex, + bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? 
supportSparse, bool trimWhitespace, bool groupColumns) + { + var sample = TextFileSample.CreateFromFullFile(path); + var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse); + var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null); + + // if no column is named label, + // rename label column to default ML.NET label column name + if (!typeInference.Columns.Any(c => c.SuggestedName == DefaultColumnNames.Label)) + { + typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label; + } + + var columnInfo = new ColumnInformation() { LabelColumnName = typeInference.Columns[labelColumnIndex].SuggestedName }; + + return InferColumns(context, path, columnInfo, hasHeader, splitInference, typeInference, trimWhitespace, groupColumns); + } + + public static ColumnInferenceResults InferColumns(MLContext context, string path, string labelColumn, + char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns) + { + var columnInfo = new ColumnInformation() { LabelColumnName = labelColumn }; + return InferColumns(context, path, columnInfo, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns); + } + + public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, + char? separatorChar, bool? allowQuotedStrings, bool? 
supportSparse, bool trimWhitespace, bool groupColumns) + { + var sample = TextFileSample.CreateFromFullFile(path); + var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse); + var typeInference = InferColumnTypes(context, sample, splitInference, true, null, columnInfo.LabelColumnName); + return InferColumns(context, path, columnInfo, true, splitInference, typeInference, trimWhitespace, groupColumns); + } + + public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, bool hasHeader, + TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference, + bool trimWhitespace, bool groupColumns) + { + var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns); + var typedLoaderOptions = new TextLoader.Options + { + Columns = loaderColumns, + Separators = new[] { splitInference.Separator.Value }, + AllowSparse = splitInference.AllowSparse, + AllowQuoting = splitInference.AllowQuote, + HasHeader = hasHeader, + TrimWhitespace = trimWhitespace + }; + var textLoader = context.Data.CreateTextLoader(typedLoaderOptions); + var dataView = textLoader.Load(path); + + var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo); + + // start building result objects + IEnumerable columnResults = null; + IEnumerable<(string, ColumnPurpose)> purposeResults = null; + + // infer column grouping and generate column names + if (groupColumns) + { + var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader, + typeInference.Columns, purposeInferenceResult); + + columnResults = groupingResult.Select(c => c.GenerateTextLoaderColumn()); + purposeResults = groupingResult.Select(c => (c.SuggestedName, c.Purpose)); + } + else + { + columnResults = loaderColumns; + purposeResults = purposeInferenceResult.Select(p => (dataView.Schema[p.ColumnIndex].Name, p.Purpose)); + } + + var 
textLoaderOptions = new TextLoader.Options() + { + Columns = columnResults.ToArray(), + AllowQuoting = splitInference.AllowQuote, + AllowSparse = splitInference.AllowSparse, + Separators = new char[] { splitInference.Separator.Value }, + HasHeader = hasHeader, + TrimWhitespace = trimWhitespace + }; + + return new ColumnInferenceResults() + { + TextLoaderOptions = textLoaderOptions, + ColumnInformation = ColumnInformationUtil.BuildColumnInfo(purposeResults) + }; + } + + private static TextFileContents.ColumnSplitResult InferSplit(MLContext context, TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse) + { + var separatorCandidates = separatorChar == null ? TextFileContents.DefaultSeparators : new char[] { separatorChar.Value }; + var splitInference = TextFileContents.TrySplitColumns(context, sample, separatorCandidates); + + // respect passed-in overrides + if (allowQuotedStrings != null) + { + splitInference.AllowQuote = allowQuotedStrings.Value; + } + if (supportSparse != null) + { + splitInference.AllowSparse = supportSparse.Value; + } + + if (!splitInference.IsSuccess) + { + throw new InferenceException(InferenceType.ColumnSplit, "Unable to split the file provided into multiple, consistent columns."); + } + + return splitInference; + } + + private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext context, TextFileSample sample, + TextFileContents.ColumnSplitResult splitInference, bool hasHeader, uint? 
labelColumnIndex, string label) + { + // infer column types + var typeInferenceResult = ColumnTypeInference.InferTextFileColumnTypes(context, sample, + new ColumnTypeInference.Arguments + { + ColumnCount = splitInference.ColumnCount, + Separator = splitInference.Separator.Value, + AllowSparse = splitInference.AllowSparse, + AllowQuote = splitInference.AllowQuote, + HasHeader = hasHeader, + LabelColumnIndex = labelColumnIndex, + Label = label + }); + + if (!typeInferenceResult.IsSuccess) + { + throw new InferenceException(InferenceType.ColumnDataKind, "Unable to infer column types of the file provided."); + } + + return typeInferenceResult; + } + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs new file mode 100644 index 0000000000..c567730f6f --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs @@ -0,0 +1,93 @@ +// Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class ColumnInformationUtil + { + internal static ColumnPurpose? 
GetColumnPurpose(this ColumnInformation columnInfo, string columnName) + { + if (columnName == columnInfo.LabelColumnName) + { + return ColumnPurpose.Label; + } + + if (columnName == columnInfo.ExampleWeightColumnName) + { + return ColumnPurpose.Weight; + } + + if (columnName == columnInfo.SamplingKeyColumnName) + { + return ColumnPurpose.SamplingKey; + } + + if (columnInfo.CategoricalColumnNames.Contains(columnName)) + { + return ColumnPurpose.CategoricalFeature; + } + + if (columnInfo.NumericColumnNames.Contains(columnName)) + { + return ColumnPurpose.NumericFeature; + } + + if (columnInfo.TextColumnNames.Contains(columnName)) + { + return ColumnPurpose.TextFeature; + } + + if (columnInfo.IgnoredColumnNames.Contains(columnName)) + { + return ColumnPurpose.Ignore; + } + + return null; + } + + internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, ColumnPurpose purpose)> columnPurposes) + { + var columnInfo = new ColumnInformation(); + + foreach (var column in columnPurposes) + { + switch (column.purpose) + { + case ColumnPurpose.Label: + columnInfo.LabelColumnName = column.name; + break; + case ColumnPurpose.Weight: + columnInfo.ExampleWeightColumnName = column.name; + break; + case ColumnPurpose.SamplingKey: + columnInfo.SamplingKeyColumnName = column.name; + break; + case ColumnPurpose.CategoricalFeature: + columnInfo.CategoricalColumnNames.Add(column.name); + break; + case ColumnPurpose.Ignore: + columnInfo.IgnoredColumnNames.Add(column.name); + break; + case ColumnPurpose.NumericFeature: + columnInfo.NumericColumnNames.Add(column.name); + break; + case ColumnPurpose.TextFeature: + columnInfo.TextColumnNames.Add(column.name); + break; + } + } + + return columnInfo; + } + + public static ColumnInformation BuildColumnInfo(IEnumerable columns) + { + return BuildColumnInfo(columns.Select(c => (c.Name, c.Purpose))); + } + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs 
b/src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs new file mode 100644 index 0000000000..45bf787396 --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Auto +{ + internal enum ColumnPurpose + { + Ignore = 0, + Label = 1, + NumericFeature = 2, + CategoricalFeature = 3, + TextFeature = 4, + Weight = 5, + ImagePath = 6, + SamplingKey = 7 + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs new file mode 100644 index 0000000000..760ea0bc4a --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs @@ -0,0 +1,413 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text.RegularExpressions; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + /// + /// This class incapsulates logic for automatic inference of column types for the text file. + /// It also attempts to guess whether there is a header row. + /// + internal static class ColumnTypeInference + { + // Maximum number of columns to invoke type inference. + // REVIEW: revisit this requirement. Either work for arbitrary number of columns, + // or have a 'dumb' inference that would quickly figure everything out. + private const int SmartColumnsLim = 10000; + + internal sealed class Arguments + { + public char Separator; + public bool AllowSparse; + public bool AllowQuote; + public int ColumnCount; + public bool HasHeader; + public int MaxRowsToRead; + public uint? 
LabelColumnIndex;
            public string Label;

            public Arguments()
            {
                MaxRowsToRead = 10000;
            }
        }

        /// <summary>
        /// A column being inferred: its raw string values plus the type/header
        /// verdicts accumulated by the experts.
        /// </summary>
        private class IntermediateColumn
        {
            private readonly ReadOnlyMemory<char>[] _data;
            private readonly int _columnId;
            private PrimitiveDataViewType _suggestedType;
            private bool? _hasHeader;

            public int ColumnId
            {
                get { return _columnId; }
            }

            // Type suggested by the experts; null until some expert decides.
            public PrimitiveDataViewType SuggestedType
            {
                get { return _suggestedType; }
                set { _suggestedType = value; }
            }

            // Tri-state header verdict: true/false once decided, null when unknown.
            public bool? HasHeader
            {
                get { return _hasHeader; }
                set { _hasHeader = value; }
            }

            public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
            {
                _data = data;
                _columnId = columnId;
            }

            public ReadOnlyMemory<char>[] RawData { get { return _data; } }

            public string Name { get; set; }

            /// <summary>
            /// True when every value after the (possible) header row parses as a Boolean.
            /// </summary>
            public bool HasAllBooleanValues()
            {
                // (note: Conversions.TryParse parses an empty string as a Boolean,
                // hence the explicit empty check)
                return RawData.Skip(1)
                    .All(x =>
                    {
                        bool value;
                        return !string.IsNullOrEmpty(x.ToString()) &&
                            Conversions.TryParse(in x, out value);
                    });
            }
        }

        public class Column
        {
            public readonly int ColumnIndex;

            public PrimitiveDataViewType ItemType;
            public string SuggestedName;

            public Column(int columnIndex, string suggestedName, PrimitiveDataViewType itemType)
            {
                ColumnIndex = columnIndex;
                SuggestedName = suggestedName;
                ItemType = itemType;
            }
        }

        public readonly struct InferenceResult
        {
            public readonly Column[] Columns;
            public readonly bool HasHeader;
            public readonly bool IsSuccess;
            public readonly ReadOnlyMemory<char>[][] Data;

            private InferenceResult(bool isSuccess, Column[] columns, bool hasHeader, ReadOnlyMemory<char>[][] data)
            {
                IsSuccess = isSuccess;
                Columns = columns;
                HasHeader = hasHeader;
                Data = data;
            }

            public static InferenceResult Success(Column[] columns, bool hasHeader, ReadOnlyMemory<char>[][] data)
            {
                return new InferenceResult(true, columns, hasHeader, data);
            }

            public static InferenceResult Fail()
            {
                return new InferenceResult(false, null, false, null);
            }
        }

        private interface ITypeInferenceExpert
        {
            void Apply(IntermediateColumn[] columns);
        }

        /// <summary>
        /// Current design is as follows: there's a sequence of 'experts' that each look at all the columns.
        /// Every expert may or may not assign the 'answer' (suggested type) to a column. If the expert needs
        /// some information about the column (for example, the column values), this information is lazily calculated
        /// by the column object, not the expert itself, to allow the reuse of the same information by another
        /// expert.
        /// </summary>
        private static class Experts
        {
            internal sealed class BooleanValues : ITypeInferenceExpert
            {
                public void Apply(IntermediateColumn[] columns)
                {
                    foreach (var col in columns)
                    {
                        // skip columns that already have a suggested type,
                        // or that don't have all Boolean values
                        if (col.SuggestedType != null ||
                            !col.HasAllBooleanValues())
                        {
                            continue;
                        }

                        col.SuggestedType = BooleanDataViewType.Instance;
                        bool first;

                        // If the first row does not parse as a Boolean, it is likely a header.
                        col.HasHeader = !Conversions.TryParse(in col.RawData[0], out first);
                    }
                }
            }

            internal sealed class AllNumericValues : ITypeInferenceExpert
            {
                public void Apply(IntermediateColumn[] columns)
                {
                    foreach (var col in columns)
                    {
                        if (!col.RawData.Skip(1)
                            .All(x =>
                            {
                                float value;
                                return Conversions.TryParse(in x, out value);
                            }))
                        {
                            continue;
                        }

                        col.SuggestedType = NumberDataViewType.Single;

                        // If the first row is not numeric, it is likely a header.
                        var headerStr = col.RawData[0].ToString();
                        col.HasHeader = !double.TryParse(headerStr, out var doubleVal);
                    }
                }
            }

            internal sealed class EverythingText : ITypeInferenceExpert
            {
                public void Apply(IntermediateColumn[] columns)
                {
                    foreach (var col in columns)
                    {
                        if (col.SuggestedType != null)
                            continue;

                        col.SuggestedType = TextDataViewType.Instance;
                        col.HasHeader = IsLookLikeHeader(col.RawData[0]);
                    }
                }

                // Heuristic: short values matching common header prefixes look like headers.
                // Returns null (unknown) when no signal either way.
                private bool? IsLookLikeHeader(ReadOnlyMemory<char> value)
                {
                    var v = value.ToString();
                    if (v.Length > 100)
                        return false;
                    var headerCandidates = new[] { "^Label", "^Feature", "^Market", "^m_", "^Weight" };
                    foreach (var candidate in headerCandidates)
                    {
                        if (Regex.IsMatch(v, candidate, RegexOptions.IgnoreCase))
                            return true;
                    }

                    return null;
                }
            }
        }

        private static IEnumerable<ITypeInferenceExpert> GetExperts()
        {
            // Current logic is pretty primitive: if every value (except the first) of a column
            // parses as numeric then it's numeric. Else if it parses as a Boolean, it's Boolean. Otherwise, it is text.
            yield return new Experts.AllNumericValues();
            yield return new Experts.BooleanValues();
            yield return new Experts.EverythingText();
        }

        /// <summary>
        /// Auto-detect column types of the file.
        /// </summary>
        public static InferenceResult InferTextFileColumnTypes(MLContext context, IMultiStreamSource fileSource, Arguments args)
        {
            return InferTextFileColumnTypesCore(context, fileSource, args);
        }

        private static InferenceResult InferTextFileColumnTypesCore(MLContext context, IMultiStreamSource fileSource, Arguments args)
        {
            if (args.ColumnCount == 0)
            {
                // too many empty columns for automatic inference
                return InferenceResult.Fail();
            }

            if (args.ColumnCount >= SmartColumnsLim)
            {
                // too many columns for automatic inference
                return InferenceResult.Fail();
            }

            // read the file as the specified number of text columns
            var textLoaderOptions = new TextLoader.Options
            {
                Columns = new[] { new TextLoader.Column("C", DataKind.String, 0, args.ColumnCount - 1) },
                Separators = new[] { args.Separator },
                AllowSparse = args.AllowSparse,
                AllowQuoting = args.AllowQuote,
            };
            var textLoader = context.Data.CreateTextLoader(textLoaderOptions);
            var idv = textLoader.Load(fileSource);
            idv = context.Data.TakeRows(idv, args.MaxRowsToRead);

            // read all the data into memory.
            // list items are rows of the dataset.
            var data = new List<ReadOnlyMemory<char>[]>();
            using (var cursor = idv.GetRowCursor(idv.Schema))
            {
                var column = cursor.Schema.GetColumnOrNull("C").Value;
                var colType = column.Type;
                ValueGetter<VBuffer<ReadOnlyMemory<char>>> vecGetter = null;
                ValueGetter<ReadOnlyMemory<char>> oneGetter = null;
                bool isVector = colType.IsVector();
                if (isVector)
                {
                    vecGetter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(column);
                }
                else
                {
                    oneGetter = cursor.GetGetter<ReadOnlyMemory<char>>(column);
                }

                VBuffer<ReadOnlyMemory<char>> line = default;
                ReadOnlyMemory<char> tsValue = default;
                while (cursor.MoveNext())
                {
                    if (isVector)
                    {
                        vecGetter(ref line);
                        var values = new ReadOnlyMemory<char>[args.ColumnCount];
                        line.CopyTo(values);
                        data.Add(values);
                    }
                    else
                    {
                        oneGetter(ref tsValue);
                        var values = new[] { tsValue };
                        data.Add(values);
                    }
                }
            }

            if (data.Count < 2)
            {
                // too few rows for automatic inference
                return InferenceResult.Fail();
            }

            var cols = new IntermediateColumn[args.ColumnCount];
            for (int i = 0; i < args.ColumnCount; i++)
            {
                cols[i] = new IntermediateColumn(data.Select(x => x[i]).ToArray(), i);
            }

            foreach (var expert in GetExperts())
            {
                expert.Apply(cols);
            }

            // Aggregating header signals.
            int suspect = 0;
            var usedNames = new HashSet<string>();
            for (int i = 0; i < args.ColumnCount; i++)
            {
                if (cols[i].HasHeader == true)
                {
                    if (usedNames.Add(cols[i].RawData[0].ToString()))
                        suspect++;
                    else
                    {
                        // duplicate value in the first column is a strong signal that this is not a header
                        suspect -= args.ColumnCount;
                    }
                }
                else if (cols[i].HasHeader == false)
                    suspect--;
            }

            // suggest names, de-duplicating with a numeric suffix
            usedNames.Clear();
            foreach (var col in cols)
            {
                string name0;
                string name;
                name0 = name = SuggestName(col, args.HasHeader);
                int i = 0;
                while (!usedNames.Add(name))
                {
                    name = string.Format("{0}_{1:00}", name0, i++);
                }
                col.Name = name;
            }

            // validate & retrieve label column
            var labelColumn = GetAndValidateLabelColumn(args, cols);

            // if label column has all Boolean values, set its type as Boolean
            if (labelColumn.HasAllBooleanValues())
            {
                labelColumn.SuggestedType = BooleanDataViewType.Instance;
            }

            var outCols = cols.Select(x => new Column(x.ColumnId, x.Name, x.SuggestedType)).ToArray();

            return InferenceResult.Success(outCols, args.HasHeader, cols.Select(col => col.RawData).ToArray());
        }

        private static string SuggestName(IntermediateColumn column, bool hasHeader)
        {
            var header = column.RawData[0].ToString();
            return (hasHeader && !string.IsNullOrWhiteSpace(header)) ? header : string.Format("col{0}", column.ColumnId);
        }

        /// <summary>
        /// Resolves the label column from either an explicit index or a column name.
        /// Throws when the index is out of range or the named column does not exist.
        /// </summary>
        private static IntermediateColumn GetAndValidateLabelColumn(Arguments args, IntermediateColumn[] cols)
        {
            IntermediateColumn labelColumn = null;
            if (args.LabelColumnIndex != null)
            {
                // if label column index > inferred # of columns, throw error
                // (use Length instead of LINQ Count() on an array)
                if (args.LabelColumnIndex >= cols.Length)
                {
                    throw new ArgumentOutOfRangeException(nameof(args.LabelColumnIndex),
                        $"Label column index ({args.LabelColumnIndex}) is >= than # of inferred columns ({cols.Length}).");
                }

                labelColumn = cols[args.LabelColumnIndex.Value];
            }
            else
            {
                labelColumn = cols.FirstOrDefault(c => c.Name == args.Label);
                if (labelColumn == null)
                {
                    throw new ArgumentException($"Specified label column '{args.Label}' was not found.");
                }
            }

            return labelColumn;
        }

        /// <summary>
        /// Converts inferred columns into <see cref="TextLoader.Column"/> definitions.
        /// </summary>
        public static TextLoader.Column[] GenerateLoaderColumns(Column[] columns)
        {
            var loaderColumns = new List<TextLoader.Column>();
            foreach (var col in columns)
            {
                var loaderColumn = new TextLoader.Column(col.SuggestedName, col.ItemType.GetRawKind(), col.ColumnIndex);
                loaderColumns.Add(loaderColumn);
            }
            return loaderColumns.ToArray();
        }
    }
}

// file: src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Automatic inference of column purposes for the data view.
    /// This is used in the context of text import wizard, but can be used outside as well.
+ /// + internal static class PurposeInference + { + public const int MaxRowsToRead = 1000; + + public class Column + { + public readonly int ColumnIndex; + public readonly ColumnPurpose Purpose; + + public Column(int columnIndex, ColumnPurpose purpose) + { + ColumnIndex = columnIndex; + Purpose = purpose; + } + } + + /// + /// The design is the same as for : there's a sequence of 'experts' + /// that each look at all the columns. Every expert may or may not assign the 'answer' (suggested purpose) + /// to a column. If the expert needs some information about the column (for example, the column values), + /// this information is lazily calculated by the column object, not the expert itself, to allow the reuse + /// of the same information by another expert. + /// + private interface IPurposeInferenceExpert + { + void Apply(IntermediateColumn[] columns); + } + + private class IntermediateColumn + { + private readonly IDataView _data; + private readonly int _columnId; + private ColumnPurpose _suggestedPurpose; + private readonly Lazy _type; + private readonly Lazy _columnName; + private IReadOnlyList> _cachedData; + + public bool IsPurposeSuggested { get; private set; } + + public ColumnPurpose SuggestedPurpose + { + get { return _suggestedPurpose; } + set + { + _suggestedPurpose = value; + IsPurposeSuggested = true; + } + } + + public DataViewType Type { get { return _type.Value; } } + + public string ColumnName { get { return _columnName.Value; } } + + public IntermediateColumn(IDataView data, int columnId, ColumnPurpose suggestedPurpose = ColumnPurpose.Ignore) + { + _data = data; + _columnId = columnId; + _type = new Lazy(() => _data.Schema[_columnId].Type); + _columnName = new Lazy(() => _data.Schema[_columnId].Name); + _suggestedPurpose = suggestedPurpose; + } + + public Column GetColumn() + { + return new Column(_columnId, _suggestedPurpose); + } + + public IReadOnlyList> GetColumnData() + { + if (_cachedData != null) + return _cachedData; + + var results = new 
List>(); + var column = _data.Schema[_columnId]; + + using (var cursor = _data.GetRowCursor(new[] { column })) + { + var getter = cursor.GetGetter>(column); + while (cursor.MoveNext()) + { + var value = default(ReadOnlyMemory); + getter(ref value); + + var copy = new ReadOnlyMemory(value.ToArray()); + + results.Add(copy); + } + } + + _cachedData = results; + + return results; + } + } + + private static class Experts + { + internal sealed class TextClassification : IPurposeInferenceExpert + { + public void Apply(IntermediateColumn[] columns) + { + string[] commonImageExtensions = { ".bmp", ".dib", ".rle", ".jpg", ".jpeg", ".jpe", ".jfif", ".gif", ".tif", ".tiff", ".png" }; + foreach (var column in columns) + { + if (column.IsPurposeSuggested || !column.Type.IsText()) + continue; + + var data = column.GetColumnData(); + + long sumLength = 0; + int sumSpaces = 0; + var seen = new HashSet(); + int imagePathCount = 0; + foreach (var span in data) + { + sumLength += span.Length; + seen.Add(span.ToString()); + string spanStr = span.ToString(); + sumSpaces += spanStr.Count(x => x == ' '); + + foreach (var ext in commonImageExtensions) + { + if (spanStr.EndsWith(ext, StringComparison.OrdinalIgnoreCase)) + { + imagePathCount++; + break; + } + } + } + + if (imagePathCount < data.Count - 1) + { + Double avgLength = 1.0 * sumLength / data.Count; + Double cardinalityRatio = 1.0 * seen.Count / data.Count; + Double avgSpaces = 1.0 * sumSpaces / data.Count; + if (cardinalityRatio < 0.7) + column.SuggestedPurpose = ColumnPurpose.CategoricalFeature; + // (note: the columns.Count() == 1 condition below, in case a dataset has only + // a 'name' and a 'label' column, forces what would be an 'ignore' column to become a text feature) + else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1 || columns.Count() == 1)) + column.SuggestedPurpose = ColumnPurpose.TextFeature; + else if (cardinalityRatio >= 0.9) + column.SuggestedPurpose = ColumnPurpose.Ignore; + } + else + 
column.SuggestedPurpose = ColumnPurpose.ImagePath; + } + } + } + + internal sealed class NumericAreFeatures : IPurposeInferenceExpert + { + public void Apply(IntermediateColumn[] columns) + { + foreach (var column in columns) + { + if (column.IsPurposeSuggested) + continue; + if (column.Type.GetItemType().IsNumber()) + column.SuggestedPurpose = ColumnPurpose.NumericFeature; + } + } + } + + internal sealed class BooleanProcessing : IPurposeInferenceExpert + { + public void Apply(IntermediateColumn[] columns) + { + foreach (var column in columns) + { + if (column.IsPurposeSuggested) + continue; + if (column.Type.GetItemType().IsBool()) + column.SuggestedPurpose = ColumnPurpose.NumericFeature; + } + } + } + + internal sealed class TextArraysAreText : IPurposeInferenceExpert + { + public void Apply(IntermediateColumn[] columns) + { + foreach (var column in columns) + { + if (column.IsPurposeSuggested) + continue; + if (column.Type.IsVector() && column.Type.GetItemType().IsText()) + column.SuggestedPurpose = ColumnPurpose.TextFeature; + } + } + } + + internal sealed class IgnoreEverythingElse : IPurposeInferenceExpert + { + public void Apply(IntermediateColumn[] columns) + { + foreach (var column in columns) + { + if (!column.IsPurposeSuggested) + column.SuggestedPurpose = ColumnPurpose.Ignore; + } + } + } + } + + private static IEnumerable GetExperts() + { + // Each of the experts respects the decisions of all the experts above. + + // Single-value text columns may be category, name, text or ignore. + yield return new Experts.TextClassification(); + // Vector-value text columns are always treated as text. + // REVIEW: could be improved. + yield return new Experts.TextArraysAreText(); + // Check column on boolean only values. + yield return new Experts.BooleanProcessing(); + // All numeric columns are features. + yield return new Experts.NumericAreFeatures(); + // Everything else is ignored. 
+ yield return new Experts.IgnoreEverythingElse(); + } + + /// + /// Auto-detect purpose for the data view columns. + /// + public static PurposeInference.Column[] InferPurposes(MLContext context, IDataView data, + ColumnInformation columnInfo) + { + data = context.Data.TakeRows(data, MaxRowsToRead); + + var allColumns = new List(); + var columnsToInfer = new List(); + + for (var i = 0; i < data.Schema.Count; i++) + { + var column = data.Schema[i]; + IntermediateColumn intermediateCol; + + if (column.IsHidden) + { + intermediateCol = new IntermediateColumn(data, i, ColumnPurpose.Ignore); + allColumns.Add(intermediateCol); + continue; + } + + var columnPurpose = columnInfo.GetColumnPurpose(column.Name); + if (columnPurpose == null) + { + intermediateCol = new IntermediateColumn(data, i); + columnsToInfer.Add(intermediateCol); + } + else + { + intermediateCol = new IntermediateColumn(data, i, columnPurpose.Value); + } + + allColumns.Add(intermediateCol); + } + + foreach (var expert in GetExperts()) + { + expert.Apply(columnsToInfer.ToArray()); + } + + return allColumns.Select(c => c.GetColumn()).ToArray(); + } + } +} diff --git a/src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs b/src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs new file mode 100644 index 0000000000..9fed93d27a --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs @@ -0,0 +1,124 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + /// + /// Utilities for various heuristics against text files. + /// Currently, separator inference and column count detection. 
+ /// + internal static class TextFileContents + { + public class ColumnSplitResult + { + public readonly bool IsSuccess; + public readonly char? Separator; + public readonly int ColumnCount; + + public bool AllowQuote { get; set; } + public bool AllowSparse { get; set; } + + public ColumnSplitResult(bool isSuccess, char? separator, bool allowQuote, bool allowSparse, int columnCount) + { + IsSuccess = isSuccess; + Separator = separator; + AllowQuote = allowQuote; + AllowSparse = allowSparse; + ColumnCount = columnCount; + } + } + + // If the fraction of lines having the same number of columns exceeds this, we consider the column count to be known. + private const Double UniformColumnCountThreshold = 0.98; + + public static readonly char[] DefaultSeparators = { '\t', ',', ' ', ';' }; + + /// + /// Attempt to detect text loader arguments. + /// The algorithm selects the first 'acceptable' set: the one that recognizes the same number of columns in at + /// least of the sample's lines, + /// and this number of columns is more than 1. + /// We sweep on separator, allow sparse and allow quote parameter. + /// + public static ColumnSplitResult TrySplitColumns(MLContext context, IMultiStreamSource source, char[] separatorCandidates) + { + var sparse = new[] { false, true }; + var quote = new[] { true, false }; + var foundAny = false; + var result = default(ColumnSplitResult); + foreach (var perm in (from _allowSparse in sparse + from _allowQuote in quote + from _sep in separatorCandidates + select new { _allowSparse, _allowQuote, _sep })) + { + var options = new TextLoader.Options + { + Columns = new[] { new TextLoader.Column() { + Name = "C", + DataKind = DataKind.String, + Source = new[] { new TextLoader.Range(0, null) } + } }, + Separators = new[] { perm._sep }, + AllowQuoting = perm._allowQuote, + AllowSparse = perm._allowSparse + }; + + if (TryParseFile(context, options, source, out result)) + { + foundAny = true; + break; + } + } + return foundAny ? 
result : new ColumnSplitResult(false, null, true, true, 0); + } + + private static bool TryParseFile(MLContext context, TextLoader.Options options, IMultiStreamSource source, + out ColumnSplitResult result) + { + result = null; + // try to instantiate data view with swept arguments + try + { + var textLoader = context.Data.CreateTextLoader(options, source); + var idv = context.Data.TakeRows(textLoader.Load(source), 1000); + var columnCounts = new List(); + var column = idv.Schema["C"]; + + using (var cursor = idv.GetRowCursor(new[] { column })) + { + var getter = cursor.GetGetter>>(column); + + VBuffer> line = default; + while (cursor.MoveNext()) + { + getter(ref line); + columnCounts.Add(line.Length); + } + } + + var mostCommon = columnCounts.GroupBy(x => x).OrderByDescending(x => x.Count()).First(); + if (mostCommon.Count() < UniformColumnCountThreshold * columnCounts.Count) + { + return false; + } + + // disallow single-column case + if (mostCommon.Key <= 1) { return false; } + + result = new ColumnSplitResult(true, options.Separators.First(), options.AllowQuoting, options.AllowSparse, mostCommon.Key); + return true; + } + // fail gracefully if unable to instantiate data view with swept arguments + catch(Exception) + { + return false; + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/ColumnInference/TextFileSample.cs b/src/Microsoft.ML.Auto/ColumnInference/TextFileSample.cs new file mode 100644 index 0000000000..f528d5a8b8 --- /dev/null +++ b/src/Microsoft.ML.Auto/ColumnInference/TextFileSample.cs @@ -0,0 +1,304 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// This class holds an in-memory sample of the text file, and serves as a proxy to it.
    /// </summary>
    internal sealed class TextFileSample : IMultiStreamSource
    {
        // REVIEW: consider including multiple files via IMultiStreamSource.

        // REVIEW: right now, it expects 0x0A being the trailing character of line break.
        // Consider a more general implementation.

        private const int BufferSizeMb = 4;
        private const int FirstChunkSizeMb = 1;
        private const int LinesPerChunk = 20;
        private const Double OversamplingRate = 1.1;

        private readonly byte[] _buffer;
        private readonly long? _fullFileSize;
        private readonly long? _approximateRowCount;

        private TextFileSample(byte[] buffer, long? fullFileSize, long? lineCount)
        {
            _buffer = buffer;
            _fullFileSize = fullFileSize;
            _approximateRowCount = lineCount;
        }

        public int Count
        {
            get { return 1; }
        }

        // Full file size, if known, otherwise, null.
        public long? FullFileSize
        {
            get { return _fullFileSize; }
        }

        public int SampleSize
        {
            get { return _buffer.Length; }
        }

        public string GetPathOrNull(int index)
        {
            //Contracts.Check(index == 0, "Index must be 0");
            return null;
        }

        public Stream Open(int index)
        {
            //Contracts.Check(index == 0, "Index must be 0");
            return new MemoryStream(_buffer);
        }

        public TextReader OpenTextReader(int index)
        {
            return new StreamReader(Open(index));
        }

        public long? ApproximateRowCount => _approximateRowCount;

        public static TextFileSample CreateFromFullFile(string path)
        {
            using (var fs = new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read))
            {
                return CreateFromFullStream(fs);
            }
        }

        /// <summary>
        /// Create a <see cref="TextFileSample"/> by reading multiple chunks from the file (or other source) and
        /// then stitching them together. The algorithm is as follows:
        /// 0. If the source is not seekable, revert to <see cref="CreateFromHead"/>.
        /// 1. If the file length is less than 2 * <see cref="BufferSizeMb"/>, revert to <see cref="CreateFromHead"/>.
        /// 2. Read first <see cref="FirstChunkSizeMb"/> MB chunk. Determine average line length in the chunk.
        /// 3. Determine how large one chunk should be, and how many chunks there should be, to end up
        /// with <see cref="BufferSizeMb"/> * <see cref="OversamplingRate"/> MB worth of lines.
        /// 4. Determine seek locations and read the chunks.
        /// 5. Stitch and return a <see cref="TextFileSample"/>.
        /// </summary>
        public static TextFileSample CreateFromFullStream(Stream stream)
        {
            if (!stream.CanSeek)
            {
                return CreateFromHead(stream);
            }
            var fileSize = stream.Length;

            if (fileSize <= 2 * BufferSizeMb * (1 << 20))
            {
                return CreateFromHead(stream);
            }

            var firstChunk = new byte[FirstChunkSizeMb * (1 << 20)];
            int count = stream.Read(firstChunk, 0, firstChunk.Length);
            if (!IsEncodingOkForSampling(firstChunk))
                return CreateFromHead(stream);
            // REVIEW: CreateFromHead still truncates the file before the last 0x0A byte. For multi-byte encoding,
            // this might cause an unfinished string to be present in the buffer. Right now this is considered an acceptable
            // price to pay for parse-free processing.

            var lineCount = firstChunk.Count(x => x == '\n');
            if (lineCount == 0)
            {
                // (typo fixed: "Counldn't" -> "Couldn't")
                throw new ArgumentException("Couldn't identify line breaks. Provided file is not text?");
            }

            long approximateRowCount = (long)(lineCount * fileSize * 1.0 / firstChunk.Length);
            var firstNewline = Array.FindIndex(firstChunk, x => x == '\n');

            // First line may be header, so we exclude it. The remaining lineCount-1 line breaks are
            // splitting the text into lineCount lines, and the last line is actually half-size.
            Double averageLineLength = 2.0 * (firstChunk.Length - firstNewline) / (lineCount * 2 - 1);
            averageLineLength = Math.Max(averageLineLength, 3);

            int usefulChunkSize = (int)(averageLineLength * LinesPerChunk);
            int chunkSize = (int)(usefulChunkSize + averageLineLength); // assuming that 1 line worth will be trimmed out

            int chunkCount = (int)Math.Ceiling((BufferSizeMb * OversamplingRate - FirstChunkSizeMb) * (1 << 20) / usefulChunkSize);
            int maxChunkCount = (int)Math.Floor((double)(fileSize - firstChunk.Length) / chunkSize);
            chunkCount = Math.Min(chunkCount, maxChunkCount);

            var chunks = new List<byte[]>();
            chunks.Add(firstChunk);

            // determine the start of each remaining chunk
            long fileSizeRemaining = fileSize - firstChunk.Length - ((long)chunkSize) * chunkCount;

            var chunkStartIndices = Enumerable.Range(0, chunkCount)
                .Select(x => AutoMlUtils.random.Value.NextDouble() * fileSizeRemaining)
                .OrderBy(x => x)
                .Select((spot, i) => (long)(spot + firstChunk.Length + i * chunkSize))
                .ToArray();

            foreach (var chunkStartIndex in chunkStartIndices)
            {
                stream.Seek(chunkStartIndex, SeekOrigin.Begin);
                byte[] chunk = new byte[chunkSize];
                int readCount = stream.Read(chunk, 0, chunkSize);
                // BUGFIX: trim the chunk to the bytes actually read. The original resized to
                // chunkSize — a no-op that left zero padding after a short read near EOF.
                Array.Resize(ref chunk, readCount);
                chunks.Add(chunk);
            }

            return new TextFileSample(StitchChunks(false, chunks.ToArray()), fileSize, approximateRowCount);
        }

        /// <summary>
        /// Create a <see cref="TextFileSample"/> by reading one chunk from the beginning.
        /// </summary>
        private static TextFileSample CreateFromHead(Stream stream)
        {
            var buf = new byte[BufferSizeMb * (1 << 20)];
            int readCount = stream.Read(buf, 0, buf.Length);
            Array.Resize(ref buf, readCount);
            // BUGFIX: stream.Length throws on non-seekable streams, and dividing by an
            // empty buffer would throw; guard both before computing the multiplier.
            long? multiplier = stream.CanSeek && buf.Length > 0 ? (long?)(stream.Length / buf.Length) : null;
            bool wholeFile = stream.CanSeek && readCount == stream.Length;
            return new TextFileSample(StitchChunks(wholeFile, buf),
                stream.CanSeek ? (long?)stream.Length : null,
                multiplier.HasValue ? buf.Count(x => x == '\n') * multiplier : null);
        }

        /// <summary>
        /// Given an array of chunks of the text file, of which the first chunk is the head,
        /// this method trims incomplete lines from the beginning and end of each chunk
        /// (except that it doesn't trim the beginning of the first chunk and end of last chunk if we read whole file),
        /// then joins the rest together to form a final byte buffer and returns it.
        /// </summary>
        /// <param name="wholeFile">did we read whole file</param>
        /// <param name="chunks">chunks of data</param>
        private static byte[] StitchChunks(bool wholeFile, params byte[][] chunks)
        {
            using (var resultStream = new MemoryStream(BufferSizeMb * (1 << 20)))
            {
                for (int i = 0; i < chunks.Length; i++)
                {
                    int iMin = (i == 0) ? 0 : Array.FindIndex(chunks[i], x => x == '\n') + 1;
                    int iLim = (wholeFile && i == chunks.Length - 1)
                        ? chunks[i].Length
                        : Array.FindLastIndex(chunks[i], x => x == '\n') + 1;

                    if (iLim == 0)
                    {
                        //entire buffer is one string, skip
                        continue;
                    }

                    resultStream.Write(chunks[i], iMin, iLim - iMin);
                }

                var resultBuffer = resultStream.ToArray();
                if (resultBuffer.Length == 0)
                {
                    throw new ArgumentException("File is not text, or couldn't detect line breaks");
                }

                return resultBuffer;
            }
        }

        /// <summary>
        /// Detect whether we can auto-detect EOL characters without parsing.
        /// If we do, we can cheaply sample from different file locations and trim the partial strings.
        /// The encodings that pass the test are UTF8 and all single-byte encodings.
+ /// + private static bool IsEncodingOkForSampling(byte[] buffer) + { + // First check if a BOM/signature exists (sourced from https://www.unicode.org/faq/utf_bom.html#bom4) + if (buffer.Length >= 4 && buffer[0] == 0x00 && buffer[1] == 0x00 && buffer[2] == 0xFE && buffer[3] == 0xFF) + { + // UTF-32, big-endian + return false; + } + if (buffer.Length >= 4 && buffer[0] == 0xFF && buffer[1] == 0xFE && buffer[2] == 0x00 && buffer[3] == 0x00) + { + // UTF-32, little-endian + return false; + } + if (buffer.Length >= 2 && buffer[0] == 0xFE && buffer[1] == 0xFF) + { + // UTF-16, big-endian + return false; + } + if (buffer.Length >= 2 && buffer[0] == 0xFF && buffer[1] == 0xFE) + { + // UTF-16, little-endian + return false; + } + if (buffer.Length >= 3 && buffer[0] == 0xEF && buffer[1] == 0xBB && buffer[2] == 0xBF) + { + // UTF-8 + return true; + } + if (buffer.Length >= 3 && buffer[0] == 0x2b && buffer[1] == 0x2f && buffer[2] == 0x76) + { + // UTF-7 + return true; + } + + // No BOM/signature was found, so now we need to 'sniff' the file to see if can manually discover the encoding. + int sniffLim = Math.Min(1000, buffer.Length); + + // Some text files are encoded in UTF8, but have no BOM/signature. Hence the below manually checks for a UTF8 pattern. This code is based off + // the top answer at: https://stackoverflow.com/questions/6555015/check-for-invalid-utf8 . 
+ int i = 0; + bool utf8 = false; + while (i < sniffLim - 4) + { + if (buffer[i] <= 0x7F) + { + i += 1; + continue; + } + if (buffer[i] >= 0xC2 && buffer[i] <= 0xDF && buffer[i + 1] >= 0x80 && buffer[i + 1] < 0xC0) + { + i += 2; + utf8 = true; + continue; + } + if (buffer[i] >= 0xE0 && buffer[i] <= 0xF0 && buffer[i + 1] >= 0x80 && buffer[i + 1] < 0xC0 && + buffer[i + 2] >= 0x80 && buffer[i + 2] < 0xC0) + { + i += 3; + utf8 = true; + continue; + } + if (buffer[i] >= 0xF0 && buffer[i] <= 0xF4 && buffer[i + 1] >= 0x80 && buffer[i + 1] < 0xC0 && + buffer[i + 2] >= 0x80 && buffer[i + 2] < 0xC0 && buffer[i + 3] >= 0x80 && buffer[i + 3] < 0xC0) + { + i += 4; + utf8 = true; + continue; + } + utf8 = false; + break; + } + if (utf8) + { + return true; + } + + if (buffer.Take(sniffLim).Any(x => x == 0)) + { + // likely a UTF-16 or UTF-32 without a BOM. + return false; + } + + // If all else failed, the file is likely in a local 1-byte encoding. + return true; + } + } +} diff --git a/src/Microsoft.ML.Auto/DatasetDimensions/ColumnDimensions.cs b/src/Microsoft.ML.Auto/DatasetDimensions/ColumnDimensions.cs new file mode 100644 index 0000000000..78283ac6c5 --- /dev/null +++ b/src/Microsoft.ML.Auto/DatasetDimensions/ColumnDimensions.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Auto +{ + internal class ColumnDimensions + { + public int? Cardinality; + public bool? HasMissing; + + public ColumnDimensions(int? cardinality, bool? 
hasMissing) + { + Cardinality = cardinality; + HasMissing = hasMissing; + } + } +} diff --git a/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs new file mode 100644 index 0000000000..8d18b5057b --- /dev/null +++ b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal class DatasetDimensionsApi + { + private const long MaxRowsToRead = 1000; + + public static ColumnDimensions[] CalcColumnDimensions(MLContext context, IDataView data, PurposeInference.Column[] purposes) + { + data = context.Data.TakeRows(data, MaxRowsToRead); + + var colDimensions = new ColumnDimensions[data.Schema.Count]; + + for (var i = 0; i < data.Schema.Count; i++) + { + var column = data.Schema[i]; + var purpose = purposes[i]; + + // default column dimensions + int? cardinality = null; + bool? hasMissing = null; + + var itemType = column.Type.GetItemType(); + + // If categorical text feature, calculate cardinality + if (itemType.IsText() && purpose.Purpose == ColumnPurpose.CategoricalFeature) + { + cardinality = DatasetDimensionsUtil.GetTextColumnCardinality(data, column); + } + + // If numeric feature, discover missing values + if (itemType == NumberDataViewType.Single) + { + hasMissing = column.Type.IsVector() ? 
+ DatasetDimensionsUtil.HasMissingNumericVector(data, column) : + DatasetDimensionsUtil.HasMissingNumericSingleValue(data, column); + } + + colDimensions[i] = new ColumnDimensions(cardinality, hasMissing); + } + + return colDimensions; + } + } +} diff --git a/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs new file mode 100644 index 0000000000..c0dea14fbb --- /dev/null +++ b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class DatasetDimensionsUtil + { + public static int GetTextColumnCardinality(IDataView data, DataViewSchema.Column column) + { + var seen = new HashSet(); + using (var cursor = data.GetRowCursor(new[] { column })) + { + var getter = cursor.GetGetter>(column); + while (cursor.MoveNext()) + { + var value = default(ReadOnlyMemory); + getter(ref value); + var valueStr = value.ToString(); + seen.Add(valueStr); + } + } + return seen.Count; + } + + public static bool HasMissingNumericSingleValue(IDataView data, DataViewSchema.Column column) + { + using (var cursor = data.GetRowCursor(new[] { column })) + { + var getter = cursor.GetGetter(column); + var value = default(Single); + while (cursor.MoveNext()) + { + getter(ref value); + if (Single.IsNaN(value)) + { + return true; + } + } + return false; + } + } + + public static bool HasMissingNumericVector(IDataView data, DataViewSchema.Column column) + { + using (var cursor = data.GetRowCursor(new[] { column })) + { + var getter = cursor.GetGetter>(column); + var value = default(VBuffer); + while (cursor.MoveNext()) + { + getter(ref value); + if 
/// <summary>Names of all transform estimators the AutoML pipeline suggester can emit.</summary>
internal enum EstimatorName
{
    ColumnConcatenating,
    ColumnCopying,
    KeyToValueMapping,
    MissingValueIndicating,
    MissingValueReplacing,
    Normalizing,
    OneHotEncoding,
    OneHotHashEncoding,
    TextFeaturizing,
    TypeConverting,
    ValueToKeyMapping
}

internal class EstimatorExtensionCatalog
{
    // Maps each estimator name to the IEstimatorExtension implementation that can build it.
    private static readonly IDictionary<EstimatorName, Type> _namesToExtensionTypes =
        new Dictionary<EstimatorName, Type>()
        {
            { EstimatorName.ColumnConcatenating, typeof(ColumnConcatenatingExtension) },
            { EstimatorName.ColumnCopying, typeof(ColumnCopyingExtension) },
            { EstimatorName.KeyToValueMapping, typeof(KeyToValueMappingExtension) },
            { EstimatorName.MissingValueIndicating, typeof(MissingValueIndicatingExtension) },
            { EstimatorName.MissingValueReplacing, typeof(MissingValueReplacingExtension) },
            { EstimatorName.Normalizing, typeof(NormalizingExtension) },
            { EstimatorName.OneHotEncoding, typeof(OneHotEncodingExtension) },
            { EstimatorName.OneHotHashEncoding, typeof(OneHotHashEncodingExtension) },
            { EstimatorName.TextFeaturizing, typeof(TextFeaturizingExtension) },
            { EstimatorName.TypeConverting, typeof(TypeConvertingExtension) },
            { EstimatorName.ValueToKeyMapping, typeof(ValueToKeyMappingExtension) },
        };

    /// <summary>Instantiates the extension registered for <paramref name="estimatorName"/>.</summary>
    public static IEstimatorExtension GetExtension(EstimatorName estimatorName)
    {
        var extensionType = _namesToExtensionTypes[estimatorName];
        return (IEstimatorExtension)Activator.CreateInstance(extensionType);
    }
}
/// <summary>Builds a column-concatenation estimator (many inputs, one output vector).</summary>
internal class ColumnConcatenatingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns[0]);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string outColumn)
    {
        var node = new PipelineNode(EstimatorName.ColumnConcatenating.ToString(),
            PipelineNodeType.Transform, inColumns, outColumn);
        var estimator = CreateInstance(context, inColumns, outColumn);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string[] inColumns, string outColumn)
        => context.Transforms.Concatenate(outColumn, inColumns);
}

/// <summary>Builds a column-copy estimator (one input, one output).</summary>
internal class ColumnCopyingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn)
    {
        var node = new PipelineNode(EstimatorName.ColumnCopying.ToString(),
            PipelineNodeType.Transform, inColumn, outColumn);
        var estimator = CreateInstance(context, inColumn, outColumn);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string inColumn, string outColumn)
        => context.Transforms.CopyColumns(outColumn, inColumn);
}

/// <summary>Builds a key-to-value mapping estimator.</summary>
internal class KeyToValueMappingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn)
    {
        var node = new PipelineNode(EstimatorName.KeyToValueMapping.ToString(),
            PipelineNodeType.Transform, inColumn, outColumn);
        var estimator = CreateInstance(context, inColumn, outColumn);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string inColumn, string outColumn)
        => context.Transforms.Conversion.MapKeyToValue(outColumn, inColumn);
}

/// <summary>Builds a missing-value indicator estimator over parallel in/out column arrays.</summary>
internal class MissingValueIndicatingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns)
    {
        var node = new PipelineNode(EstimatorName.MissingValueIndicating.ToString(),
            PipelineNodeType.Transform, inColumns, outColumns);
        var estimator = CreateInstance(context, inColumns, outColumns);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string[] inColumns, string[] outColumns)
    {
        var pairs = new InputOutputColumnPair[inColumns.Length];
        for (var idx = 0; idx < pairs.Length; idx++)
        {
            pairs[idx] = new InputOutputColumnPair(outColumns[idx], inColumns[idx]);
        }
        return context.Transforms.IndicateMissingValues(pairs);
    }
}

/// <summary>Builds a missing-value replacement estimator over parallel in/out column arrays.</summary>
internal class MissingValueReplacingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns)
    {
        var node = new PipelineNode(EstimatorName.MissingValueReplacing.ToString(),
            PipelineNodeType.Transform, inColumns, outColumns);
        var estimator = CreateInstance(context, inColumns, outColumns);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string[] inColumns, string[] outColumns)
    {
        var pairs = new InputOutputColumnPair[inColumns.Length];
        for (var idx = 0; idx < pairs.Length; idx++)
        {
            pairs[idx] = new InputOutputColumnPair(outColumns[idx], inColumns[idx]);
        }
        return context.Transforms.ReplaceMissingValues(pairs);
    }
}

/// <summary>Builds a min-max normalization estimator.</summary>
internal class NormalizingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn)
    {
        var node = new PipelineNode(EstimatorName.Normalizing.ToString(),
            PipelineNodeType.Transform, inColumn, outColumn);
        var estimator = CreateInstance(context, inColumn, outColumn);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string inColumn, string outColumn)
        => context.Transforms.NormalizeMinMax(outColumn, inColumn);
}

/// <summary>Builds a one-hot encoding estimator over parallel in/out column arrays.</summary>
internal class OneHotEncodingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns)
    {
        var node = new PipelineNode(EstimatorName.OneHotEncoding.ToString(),
            PipelineNodeType.Transform, inColumns, outColumns);
        var estimator = CreateInstance(context, inColumns, outColumns);
        return new SuggestedTransform(node, estimator);
    }

    // NOTE: public (unlike sibling classes) — kept that way to preserve the existing surface.
    public static IEstimator<ITransformer> CreateInstance(MLContext context, string[] inColumns, string[] outColumns)
    {
        var pairs = new InputOutputColumnPair[inColumns.Length];
        for (var idx = 0; idx < pairs.Length; idx++)
        {
            pairs[idx] = new InputOutputColumnPair(outColumns[idx], inColumns[idx]);
        }
        return context.Transforms.Categorical.OneHotEncoding(pairs);
    }
}

/// <summary>Builds a one-hot hash encoding estimator over parallel in/out column arrays.</summary>
internal class OneHotHashEncodingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);

    // Single-column convenience overload; delegates to the array form.
    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn)
        => CreateSuggestedTransform(context, new[] { inColumn }, new[] { outColumn });

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns)
    {
        var node = new PipelineNode(EstimatorName.OneHotHashEncoding.ToString(),
            PipelineNodeType.Transform, inColumns, outColumns);
        var estimator = CreateInstance(context, inColumns, outColumns);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string[] inColumns, string[] outColumns)
    {
        var pairs = new InputOutputColumnPair[inColumns.Length];
        for (var idx = 0; idx < pairs.Length; idx++)
        {
            pairs[idx] = new InputOutputColumnPair(outColumns[idx], inColumns[idx]);
        }
        return context.Transforms.Categorical.OneHotHashEncoding(pairs);
    }
}

/// <summary>Builds a text featurization estimator.</summary>
internal class TextFeaturizingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn)
    {
        var node = new PipelineNode(EstimatorName.TextFeaturizing.ToString(),
            PipelineNodeType.Transform, inColumn, outColumn);
        var estimator = CreateInstance(context, inColumn, outColumn);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string inColumn, string outColumn)
        => context.Transforms.Text.FeaturizeText(outColumn, inColumn);
}

/// <summary>Builds a type-conversion estimator over parallel in/out column arrays.</summary>
internal class TypeConvertingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns)
    {
        var node = new PipelineNode(EstimatorName.TypeConverting.ToString(),
            PipelineNodeType.Transform, inColumns, outColumns);
        var estimator = CreateInstance(context, inColumns, outColumns);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string[] inColumns, string[] outColumns)
    {
        var pairs = new InputOutputColumnPair[inColumns.Length];
        for (var idx = 0; idx < pairs.Length; idx++)
        {
            pairs[idx] = new InputOutputColumnPair(outColumns[idx], inColumns[idx]);
        }
        return context.Transforms.Conversion.ConvertType(pairs);
    }
}

/// <summary>Builds a value-to-key mapping estimator.</summary>
internal class ValueToKeyMappingExtension : IEstimatorExtension
{
    public IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode)
        => CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);

    public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn)
    {
        var node = new PipelineNode(EstimatorName.ValueToKeyMapping.ToString(),
            PipelineNodeType.Transform, inColumn, outColumn);
        var estimator = CreateInstance(context, inColumn, outColumn);
        return new SuggestedTransform(node, estimator);
    }

    private static IEstimator<ITransformer> CreateInstance(MLContext context, string inColumn, string outColumn)
        => context.Transforms.Conversion.MapValueToKey(outColumn, inColumn);
}
/// <summary>
/// Contract for factories that materialize a concrete ML.NET estimator from a
/// serialized <c>PipelineNode</c> description.
/// </summary>
internal interface IEstimatorExtension
{
    /// <summary>Creates the estimator described by <paramref name="pipelineNode"/>.</summary>
    IEstimator<ITransformer> CreateInstance(MLContext context, PipelineNode pipelineNode);
}
/// <summary>
/// Drives one AutoML experiment: repeatedly asks the pipeline suggester for the next
/// candidate pipeline, trains/evaluates it via the supplied runner, and stops on
/// budget exhaustion, cancellation, a perfect model, or no remaining candidates.
/// </summary>
internal class Experiment<TRunDetail, TMetrics> where TRunDetail : RunDetail
{
    private readonly MLContext _context;
    private readonly OptimizingMetricInfo _optimizingMetricInfo;
    private readonly TaskKind _task;
    private readonly IProgress<TRunDetail> _progressCallback;
    private readonly ExperimentSettings _experimentSettings;
    private readonly IMetricsAgent<TMetrics> _metricsAgent;
    private readonly IEnumerable<TrainerName> _trainerWhitelist;
    private readonly DirectoryInfo _modelDirectory;
    private readonly DatasetColumnInfo[] _datasetColumnInfo;
    private readonly IRunner<TRunDetail> _runner;
    // Record of every pipeline tried so far; also feeds the suggester.
    private readonly IList<SuggestedPipelineRunDetail> _history = new List<SuggestedPipelineRunDetail>();

    public Experiment(MLContext context,
        TaskKind task,
        OptimizingMetricInfo metricInfo,
        IProgress<TRunDetail> progressCallback,
        ExperimentSettings experimentSettings,
        IMetricsAgent<TMetrics> metricsAgent,
        IEnumerable<TrainerName> trainerWhitelist,
        DatasetColumnInfo[] datasetColumnInfo,
        IRunner<TRunDetail> runner)
    {
        _context = context;
        _optimizingMetricInfo = metricInfo;
        _task = task;
        _progressCallback = progressCallback;
        _experimentSettings = experimentSettings;
        _metricsAgent = metricsAgent;
        _trainerWhitelist = trainerWhitelist;
        _modelDirectory = GetModelDirectory(_experimentSettings.CacheDirectory);
        _datasetColumnInfo = datasetColumnInfo;
        _runner = runner;
    }

    /// <summary>Runs iterations until a stopping condition is met; returns all run details.</summary>
    public IList<TRunDetail> Execute()
    {
        var experimentStopwatch = Stopwatch.StartNew();
        var iterationResults = new List<TRunDetail>();

        do
        {
            var iterationStopwatch = Stopwatch.StartNew();

            // Ask the suggester for the next pipeline. Stop the stopwatch immediately:
            // reading Elapsed later would also count training/evaluation time.
            var pipelineStopwatch = Stopwatch.StartNew();
            var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task,
                _optimizingMetricInfo.IsMaximizing, _trainerWhitelist, _experimentSettings.CacheBeforeTrainer);
            pipelineStopwatch.Stop();

            // No candidates returned: no valid pipeline remains, end the experiment.
            if (pipeline == null)
            {
                break;
            }

            // Evaluate the candidate pipeline.
            Log(LogSeverity.Debug, $"Evaluating pipeline {pipeline.ToString()}");
            (SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
                = _runner.Run(pipeline, _modelDirectory, _history.Count + 1);
            _history.Add(suggestedPipelineRunDetail);
            WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);

            runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
            // BUGFIX: previously this re-read a never-stopped stopwatch AFTER training,
            // so the reported pipeline-inference time included the whole train/eval time.
            runDetail.PipelineInferenceTimeInSeconds = pipelineStopwatch.Elapsed.TotalSeconds;

            ReportProgress(runDetail);
            iterationResults.Add(runDetail);

            // A perfect score cannot be improved on; stop early.
            if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
            {
                break;
            }

        } while (_history.Count < _experimentSettings.MaxModels &&
                !_experimentSettings.CancellationToken.IsCancellationRequested &&
                experimentStopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds);

        return iterationResults;
    }

    /// <summary>
    /// Picks the first unused "experimentN" subdirectory under <paramref name="rootDir"/>
    /// and creates it. Returns null when model caching is disabled (no root directory).
    /// </summary>
    private static DirectoryInfo GetModelDirectory(DirectoryInfo rootDir)
    {
        if (rootDir == null)
        {
            return null;
        }
        var subdirs = rootDir.Exists ?
            new HashSet<string>(rootDir.EnumerateDirectories().Select(d => d.Name)) :
            new HashSet<string>();
        string experimentDir;
        for (var i = 0; ; i++)
        {
            experimentDir = $"experiment{i}";
            if (!subdirs.Contains(experimentDir))
            {
                break;
            }
        }
        var experimentDirFullPath = Path.Combine(rootDir.FullName, experimentDir);
        var experimentDirInfo = new DirectoryInfo(experimentDirFullPath);
        if (!experimentDirInfo.Exists)
        {
            experimentDirInfo.Create();
        }
        return experimentDirInfo;
    }

    // User-supplied callbacks must not be able to crash the experiment loop.
    private void ReportProgress(TRunDetail iterationResult)
    {
        try
        {
            _progressCallback?.Report(iterationResult);
        }
        catch (Exception ex)
        {
            Log(LogSeverity.Error, $"Progress report callback reported exception {ex}");
        }
    }

    private void WriteIterationLog(SuggestedPipeline pipeline, SuggestedPipelineRunDetail runResult, Stopwatch stopwatch)
    {
        Log(LogSeverity.Debug, $"{_history.Count}\t{runResult.Score}\t{stopwatch.Elapsed}\t{pipeline.ToString()}");
    }

    // Logging is a no-op unless the caller wired up a debug logger.
    private void Log(LogSeverity severity, string message)
    {
        if (_experimentSettings?.DebugLogger == null)
        {
            return;
        }

        _experimentSettings.DebugLogger.Log(severity, message);
    }
}
/// <summary>
/// Extracts, evaluates, and interprets binary-classification metrics for the
/// experiment loop, relative to the user's chosen optimizing metric.
/// </summary>
internal class BinaryMetricsAgent : IMetricsAgent<BinaryClassificationMetrics>
{
    private readonly MLContext _mlContext;
    private readonly BinaryClassificationMetric _optimizingMetric;

    public BinaryMetricsAgent(MLContext mlContext,
        BinaryClassificationMetric optimizingMetric)
    {
        _mlContext = mlContext;
        _optimizingMetric = optimizingMetric;
    }

    /// <summary>Pulls the optimizing metric's value out of <paramref name="metrics"/>; NaN when null.</summary>
    public double GetScore(BinaryClassificationMetrics metrics)
    {
        if (metrics == null)
        {
            return double.NaN;
        }

        switch (_optimizingMetric)
        {
            case BinaryClassificationMetric.Accuracy:
                return metrics.Accuracy;
            case BinaryClassificationMetric.AreaUnderRocCurve:
                return metrics.AreaUnderRocCurve;
            case BinaryClassificationMetric.AreaUnderPrecisionRecallCurve:
                return metrics.AreaUnderPrecisionRecallCurve;
            case BinaryClassificationMetric.F1Score:
                return metrics.F1Score;
            case BinaryClassificationMetric.NegativePrecision:
                return metrics.NegativePrecision;
            case BinaryClassificationMetric.NegativeRecall:
                return metrics.NegativeRecall;
            case BinaryClassificationMetric.PositivePrecision:
                return metrics.PositivePrecision;
            case BinaryClassificationMetric.PositiveRecall:
                return metrics.PositiveRecall;
            default:
                throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
        }
    }

    /// <summary>True when the score cannot be improved on. All supported binary metrics peak at 1.</summary>
    public bool IsModelPerfect(double score)
    {
        if (double.IsNaN(score))
        {
            return false;
        }

        switch (_optimizingMetric)
        {
            case BinaryClassificationMetric.Accuracy:
            case BinaryClassificationMetric.AreaUnderRocCurve:
            case BinaryClassificationMetric.AreaUnderPrecisionRecallCurve:
            case BinaryClassificationMetric.F1Score:
            case BinaryClassificationMetric.NegativePrecision:
            case BinaryClassificationMetric.NegativeRecall:
            case BinaryClassificationMetric.PositivePrecision:
            case BinaryClassificationMetric.PositiveRecall:
                return score == 1;
            default:
                throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
        }
    }

    /// <summary>Scores <paramref name="data"/> with ML.NET's uncalibrated binary evaluator.</summary>
    public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
        => _mlContext.BinaryClassification.EvaluateNonCalibrated(data, labelColumn);
}
/// <summary>
/// Task-specific adapter over an ML.NET metrics object: extracts the optimizing
/// metric's score, detects unimprovable models, and runs evaluation.
/// </summary>
internal interface IMetricsAgent<T>
{
    /// <summary>Value of the optimizing metric inside <paramref name="metrics"/>.</summary>
    double GetScore(T metrics);

    /// <summary>True when <paramref name="score"/> is the best value the metric can take.</summary>
    bool IsModelPerfect(double score);

    /// <summary>Evaluates scored <paramref name="data"/> against <paramref name="labelColumn"/>.</summary>
    T EvaluateMetrics(IDataView data, string labelColumn);
}

internal static class MetricsAgentUtil
{
    /// <summary>Builds the standard exception for a metric the agent does not understand.</summary>
    public static NotSupportedException BuildMetricNotSupportedException<T>(T optimizingMetric)
        => new NotSupportedException($"{optimizingMetric} is not a supported sweep metric");
}
/// <summary>
/// Extracts, evaluates, and interprets multiclass-classification metrics for the
/// experiment loop, relative to the user's chosen optimizing metric.
/// </summary>
internal class MultiMetricsAgent : IMetricsAgent<MulticlassClassificationMetrics>
{
    private readonly MLContext _mlContext;
    private readonly MulticlassClassificationMetric _optimizingMetric;

    public MultiMetricsAgent(MLContext mlContext,
        MulticlassClassificationMetric optimizingMetric)
    {
        _mlContext = mlContext;
        _optimizingMetric = optimizingMetric;
    }

    /// <summary>Pulls the optimizing metric's value out of <paramref name="metrics"/>; NaN when null.</summary>
    public double GetScore(MulticlassClassificationMetrics metrics)
    {
        if (metrics == null)
        {
            return double.NaN;
        }

        switch (_optimizingMetric)
        {
            case MulticlassClassificationMetric.MacroAccuracy:
                return metrics.MacroAccuracy;
            case MulticlassClassificationMetric.MicroAccuracy:
                return metrics.MicroAccuracy;
            case MulticlassClassificationMetric.LogLoss:
                return metrics.LogLoss;
            case MulticlassClassificationMetric.LogLossReduction:
                return metrics.LogLossReduction;
            case MulticlassClassificationMetric.TopKAccuracy:
                return metrics.TopKAccuracy;
            default:
                throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
        }
    }

    /// <summary>True when the score cannot be improved on (1 for accuracies, 0 for log-loss).</summary>
    public bool IsModelPerfect(double score)
    {
        if (double.IsNaN(score))
        {
            return false;
        }

        switch (_optimizingMetric)
        {
            case MulticlassClassificationMetric.MacroAccuracy:
            case MulticlassClassificationMetric.MicroAccuracy:
            case MulticlassClassificationMetric.LogLossReduction:
            case MulticlassClassificationMetric.TopKAccuracy:
                return score == 1;
            case MulticlassClassificationMetric.LogLoss:
                return score == 0;
            default:
                throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
        }
    }

    /// <summary>Scores <paramref name="data"/> with ML.NET's multiclass evaluator.</summary>
    public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn)
        => _mlContext.MulticlassClassification.Evaluate(data, labelColumn);
}
/// <summary>
/// Extracts, evaluates, and interprets regression metrics for the experiment loop,
/// relative to the user's chosen optimizing metric.
/// </summary>
internal class RegressionMetricsAgent : IMetricsAgent<RegressionMetrics>
{
    private readonly MLContext _mlContext;
    private readonly RegressionMetric _optimizingMetric;

    public RegressionMetricsAgent(MLContext mlContext, RegressionMetric optimizingMetric)
    {
        _mlContext = mlContext;
        _optimizingMetric = optimizingMetric;
    }

    /// <summary>Pulls the optimizing metric's value out of <paramref name="metrics"/>; NaN when null.</summary>
    public double GetScore(RegressionMetrics metrics)
    {
        if (metrics == null)
        {
            return double.NaN;
        }

        switch (_optimizingMetric)
        {
            case RegressionMetric.MeanAbsoluteError:
                return metrics.MeanAbsoluteError;
            case RegressionMetric.MeanSquaredError:
                return metrics.MeanSquaredError;
            case RegressionMetric.RootMeanSquaredError:
                return metrics.RootMeanSquaredError;
            case RegressionMetric.RSquared:
                return metrics.RSquared;
            default:
                throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
        }
    }

    /// <summary>True when the score cannot be improved on (0 for errors, 1 for R²).</summary>
    public bool IsModelPerfect(double score)
    {
        if (double.IsNaN(score))
        {
            return false;
        }

        switch (_optimizingMetric)
        {
            case RegressionMetric.MeanAbsoluteError:
            case RegressionMetric.MeanSquaredError:
            case RegressionMetric.RootMeanSquaredError:
                return score == 0;
            case RegressionMetric.RSquared:
                return score == 1;
            default:
                throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric);
        }
    }

    /// <summary>Scores <paramref name="data"/> with ML.NET's regression evaluator.</summary>
    public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn)
        => _mlContext.Regression.Evaluate(data, labelColumn);
}
/// <summary>
/// Holds a trained model either in memory or serialized on disk, hiding which
/// storage was chosen from the caller.
/// </summary>
internal class ModelContainer
{
    private readonly MLContext _mlContext;
    private readonly FileInfo _fileInfo;
    private readonly ITransformer _model;

    /// <summary>Keeps the model in memory only.</summary>
    internal ModelContainer(MLContext mlContext, ITransformer model)
    {
        _mlContext = mlContext;
        _model = model;
    }

    /// <summary>
    /// Persists the model to <paramref name="fileInfo"/> immediately; the in-memory
    /// reference is deliberately not retained, so later access goes through disk.
    /// </summary>
    internal ModelContainer(MLContext mlContext, FileInfo fileInfo, ITransformer model, DataViewSchema modelInputSchema)
    {
        _mlContext = mlContext;
        _fileInfo = fileInfo;

        using (var stream = File.Create(fileInfo.FullName))
        {
            _mlContext.Model.Save(model, modelInputSchema, stream);
        }
    }

    /// <summary>
    /// Returns the model. NOTE(review): the disk-backed path deserializes on every
    /// call — no caching; callers appear to tolerate this, but confirm if hot.
    /// </summary>
    public ITransformer GetModel()
    {
        // In-memory container: hand back the live reference.
        if (_model != null)
        {
            return _model;
        }

        // Disk-backed container: reload from the serialized file.
        using (var stream = new FileStream(_fileInfo.FullName, FileMode.Open, FileAccess.Read, FileShare.Read))
        {
            return _mlContext.Model.Load(stream, out var modelInputSchema);
        }
    }
}
/// <summary>
/// Records whether the experiment's optimizing metric should be maximized or
/// minimized, derived from per-task lists of "lower is better" metrics.
/// </summary>
internal sealed class OptimizingMetricInfo
{
    /// <summary>True when larger metric values are better.</summary>
    public bool IsMaximizing { get; }

    // FIX: lookup tables marked readonly — they are never reassigned and a mutable
    // static field would invite accidental replacement.
    private static readonly RegressionMetric[] _minimizingRegressionMetrics = new RegressionMetric[]
    {
        RegressionMetric.MeanAbsoluteError,
        RegressionMetric.MeanSquaredError,
        RegressionMetric.RootMeanSquaredError
    };

    // All currently supported binary metrics are maximizing.
    private static readonly BinaryClassificationMetric[] _minimizingBinaryMetrics = new BinaryClassificationMetric[]
    {
    };

    private static readonly MulticlassClassificationMetric[] _minimizingMulticlassMetrics = new MulticlassClassificationMetric[]
    {
        MulticlassClassificationMetric.LogLoss,
    };

    public OptimizingMetricInfo(RegressionMetric regressionMetric)
    {
        IsMaximizing = !_minimizingRegressionMetrics.Contains(regressionMetric);
    }

    public OptimizingMetricInfo(BinaryClassificationMetric binaryMetric)
    {
        IsMaximizing = !_minimizingBinaryMetrics.Contains(binaryMetric);
    }

    public OptimizingMetricInfo(MulticlassClassificationMetric multiMetric)
    {
        IsMaximizing = !_minimizingMulticlassMetrics.Contains(multiMetric);
    }
}
internal static class RecipeInference
{
    /// <summary>
    /// Given a predictor task, returns every permissible trainer (with its sweeper
    /// params, if defined), honoring the optional trainer whitelist.
    /// </summary>
    /// <returns>Array of viable learners.</returns>
    public static IEnumerable<SuggestedTrainer> AllowedTrainers(MLContext mlContext, TaskKind task,
        ColumnInformation columnInfo, IEnumerable<TrainerName> trainerWhitelist)
    {
        var extensions = TrainerExtensionCatalog.GetTrainers(task, trainerWhitelist);

        var suggested = new List<SuggestedTrainer>();
        foreach (var extension in extensions)
        {
            suggested.Add(new SuggestedTrainer(mlContext, extension, columnInfo));
        }
        return suggested.ToArray();
    }
}
preprocessorTransforms; + _labelColumn = labelColumn; + _logger = logger; + _modelInputSchema = trainDatasets[0].Schema; + } + + public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, CrossValidationRunDetail runDetail) + Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum) + { + var trainResults = new List>(); + + for (var i = 0; i < _trainDatasets.Length; i++) + { + var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1); + var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i], + _labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger); + trainResults.Add(new SuggestedPipelineTrainResult(trainResult.model, trainResult.metrics, trainResult.exception, trainResult.score)); + } + + var avgScore = CalcAverageScore(trainResults.Select(r => r.Score)); + var allRunsSucceeded = trainResults.All(r => r.Exception == null); + + var suggestedPipelineRunDetail = new SuggestedPipelineCrossValRunDetail(pipeline, avgScore, allRunsSucceeded, trainResults); + var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer); + return (suggestedPipelineRunDetail, runDetail); + } + + private static double CalcAverageScore(IEnumerable scores) + { + if (scores.Any(s => double.IsNaN(s))) + { + return double.NaN; + } + return scores.Average(); + } + } +} diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/CrossValSummaryRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/CrossValSummaryRunner.cs new file mode 100644 index 0000000000..701baa1663 --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/Runners/CrossValSummaryRunner.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + internal class CrossValSummaryRunner : IRunner> + where TMetrics : class + { + private readonly MLContext _context; + private readonly IDataView[] _trainDatasets; + private readonly IDataView[] _validDatasets; + private readonly IMetricsAgent _metricsAgent; + private readonly IEstimator _preFeaturizer; + private readonly ITransformer[] _preprocessorTransforms; + private readonly string _labelColumn; + private readonly OptimizingMetricInfo _optimizingMetricInfo; + private readonly IDebugLogger _logger; + private readonly DataViewSchema _modelInputSchema; + + public CrossValSummaryRunner(MLContext context, + IDataView[] trainDatasets, + IDataView[] validDatasets, + IMetricsAgent metricsAgent, + IEstimator preFeaturizer, + ITransformer[] preprocessorTransforms, + string labelColumn, + OptimizingMetricInfo optimizingMetricInfo, + IDebugLogger logger) + { + _context = context; + _trainDatasets = trainDatasets; + _validDatasets = validDatasets; + _metricsAgent = metricsAgent; + _preFeaturizer = preFeaturizer; + _preprocessorTransforms = preprocessorTransforms; + _labelColumn = labelColumn; + _optimizingMetricInfo = optimizingMetricInfo; + _logger = logger; + _modelInputSchema = trainDatasets[0].Schema; + } + + public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, RunDetail runDetail) + Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum) + { + var trainResults = new List<(ModelContainer model, TMetrics metrics, Exception exception, double score)>(); + + for (var i = 0; i < _trainDatasets.Length; i++) + { + var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1); + var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i], + _labelColumn, _metricsAgent, _preprocessorTransforms?.ElementAt(i), modelFileInfo, _modelInputSchema, + 
_logger); + trainResults.Add(trainResult); + } + + var allRunsSucceeded = trainResults.All(r => r.exception == null); + if (!allRunsSucceeded) + { + var firstException = trainResults.First(r => r.exception != null).exception; + var errorRunDetail = new SuggestedPipelineRunDetail(pipeline, double.NaN, false, null, null, firstException); + return (errorRunDetail, errorRunDetail.ToIterationResult(_preFeaturizer)); + } + + // Get the model from the best fold + var bestFoldIndex = BestResultUtil.GetIndexOfBestScore(trainResults.Select(r => r.score), _optimizingMetricInfo.IsMaximizing); + var bestModel = trainResults.ElementAt(bestFoldIndex).model; + + // Get the metrics from the fold whose score is closest to avg of all fold scores + var avgScore = trainResults.Average(r => r.score); + var indexClosestToAvg = GetIndexClosestToAverage(trainResults.Select(r => r.score), avgScore); + var metricsClosestToAvg = trainResults[indexClosestToAvg].metrics; + + // Build result objects + var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail(pipeline, avgScore, allRunsSucceeded, metricsClosestToAvg, bestModel, null); + var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer); + return (suggestedPipelineRunDetail, runDetail); + } + + private static int GetIndexClosestToAverage(IEnumerable values, double average) + { + int avgFoldIndex = -1; + var smallestDistFromAvg = double.PositiveInfinity; + for (var i = 0; i < values.Count(); i++) + { + var distFromAvg = Math.Abs(values.ElementAt(i) - average); + if (distFromAvg < smallestDistFromAvg) + { + smallestDistFromAvg = distFromAvg; + avgFoldIndex = i; + } + } + return avgFoldIndex; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/IRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/IRunner.cs new file mode 100644 index 0000000000..8bb56fc9d0 --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/Runners/IRunner.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET 
Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; + +namespace Microsoft.ML.Auto +{ + internal interface IRunner where TRunDetail : RunDetail + { + (SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail) + Run (SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum); + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/RunnerUtil.cs b/src/Microsoft.ML.Auto/Experiment/Runners/RunnerUtil.cs new file mode 100644 index 0000000000..88575a2280 --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/Runners/RunnerUtil.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; + +namespace Microsoft.ML.Auto +{ + internal static class RunnerUtil + { + public static (ModelContainer model, TMetrics metrics, Exception exception, double score) + TrainAndScorePipeline(MLContext context, + SuggestedPipeline pipeline, + IDataView trainData, + IDataView validData, + string labelColumn, + IMetricsAgent metricsAgent, + ITransformer preprocessorTransform, + FileInfo modelFileInfo, + DataViewSchema modelInputSchema, + IDebugLogger logger) where TMetrics : class + { + try + { + var estimator = pipeline.ToEstimator(); + var model = estimator.Fit(trainData); + + var scoredData = model.Transform(validData); + var metrics = metricsAgent.EvaluateMetrics(scoredData, labelColumn); + var score = metricsAgent.GetScore(metrics); + + if (preprocessorTransform != null) + { + model = preprocessorTransform.Append(model); + } + + // Build container for model + var modelContainer = modelFileInfo == null ? 
+ new ModelContainer(context, model) : + new ModelContainer(context, modelFileInfo, model, modelInputSchema); + + return (modelContainer, metrics, null, score); + } + catch (Exception ex) + { + logger?.Log(LogSeverity.Error, $"Pipeline crashed: {pipeline.ToString()} . Exception: {ex}"); + return (null, null, ex, double.NaN); + } + } + + public static FileInfo GetModelFileInfo(DirectoryInfo modelDirectory, int iterationNum, int foldNum) + { + return modelDirectory == null ? + null : + new FileInfo(Path.Combine(modelDirectory.FullName, $"Model{iterationNum}_{foldNum}.zip")); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/TrainValidateRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/TrainValidateRunner.cs new file mode 100644 index 0000000000..9226dcbeb0 --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/Runners/TrainValidateRunner.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.IO; + +namespace Microsoft.ML.Auto +{ + internal class TrainValidateRunner : IRunner> + where TMetrics : class + { + private readonly MLContext _context; + private readonly IDataView _trainData; + private readonly IDataView _validData; + private readonly string _labelColumn; + private readonly IMetricsAgent _metricsAgent; + private readonly IEstimator _preFeaturizer; + private readonly ITransformer _preprocessorTransform; + private readonly IDebugLogger _logger; + private readonly DataViewSchema _modelInputSchema; + + public TrainValidateRunner(MLContext context, + IDataView trainData, + IDataView validData, + string labelColumn, + IMetricsAgent metricsAgent, + IEstimator preFeaturizer, + ITransformer preprocessorTransform, + IDebugLogger logger) + { + _context = context; + _trainData = trainData; + _validData = validData; + _labelColumn = labelColumn; + _metricsAgent = metricsAgent; + _preFeaturizer = preFeaturizer; + _preprocessorTransform = preprocessorTransform; + _logger = logger; + _modelInputSchema = trainData.Schema; + } + + public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, RunDetail runDetail) + Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum) + { + var modelFileInfo = GetModelFileInfo(modelDirectory, iterationNum); + var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainData, _validData, + _labelColumn, _metricsAgent, _preprocessorTransform, modelFileInfo, _modelInputSchema, _logger); + var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail(pipeline, + trainResult.score, + trainResult.exception == null, + trainResult.metrics, + trainResult.model, + trainResult.exception); + var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer); + return (suggestedPipelineRunDetail, runDetail); + } + + private static FileInfo GetModelFileInfo(DirectoryInfo modelDirectory, int iterationNum) + { + return modelDirectory == null ? 
+ null : + new FileInfo(Path.Combine(modelDirectory.FullName, $"Model{iterationNum}.zip")); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs new file mode 100644 index 0000000000..6019c9c9ac --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs @@ -0,0 +1,144 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + /// + /// A runnable pipeline. Contains a learner and set of transforms, + /// along with a RunSummary if it has already been exectued. + /// + internal class SuggestedPipeline + { + public readonly IList Transforms; + public readonly SuggestedTrainer Trainer; + public readonly IList TransformsPostTrainer; + + private readonly MLContext _context; + private readonly bool _cacheBeforeTrainer; + + public SuggestedPipeline(IEnumerable transforms, + IEnumerable transformsPostTrainer, + SuggestedTrainer trainer, + MLContext context, + bool cacheBeforeTrainer) + { + Transforms = transforms.Select(t => t.Clone()).ToList(); + TransformsPostTrainer = transformsPostTrainer.Select(t => t.Clone()).ToList(); + Trainer = trainer.Clone(); + _context = context; + _cacheBeforeTrainer = cacheBeforeTrainer; + } + + public override string ToString() => $"{string.Join(" ", Transforms.Select(t => $"xf={t}"))} tr={this.Trainer} {string.Join(" ", TransformsPostTrainer.Select(t => $"xf={t}"))} cache={(_cacheBeforeTrainer ? 
"+" : "-")}"; + + public override bool Equals(object obj) + { + var pipeline = obj as SuggestedPipeline; + if(pipeline == null) + { + return false; + } + return pipeline.ToString() == this.ToString(); + } + + public override int GetHashCode() + { + return ToString().GetHashCode(); + } + + public Pipeline ToPipeline() + { + var pipelineElements = new List(); + foreach(var transform in Transforms) + { + pipelineElements.Add(transform.PipelineNode); + } + pipelineElements.Add(Trainer.ToPipelineNode()); + foreach (var transform in TransformsPostTrainer) + { + pipelineElements.Add(transform.PipelineNode); + } + return new Pipeline(pipelineElements.ToArray(), _cacheBeforeTrainer); + } + + public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipeline) + { + var transforms = new List(); + var transformsPostTrainer = new List(); + SuggestedTrainer trainer = null; + + var trainerEncountered = false; + foreach(var pipelineNode in pipeline.Nodes) + { + if(pipelineNode.NodeType == PipelineNodeType.Trainer) + { + var trainerName = (TrainerName)Enum.Parse(typeof(TrainerName), pipelineNode.Name); + var trainerExtension = TrainerExtensionCatalog.GetTrainerExtension(trainerName); + var hyperParamSet = TrainerExtensionUtil.BuildParameterSet(trainerName, pipelineNode.Properties); + var columnInfo = TrainerExtensionUtil.BuildColumnInfo(pipelineNode.Properties); + trainer = new SuggestedTrainer(context, trainerExtension, columnInfo, hyperParamSet); + trainerEncountered = true; + } + else if (pipelineNode.NodeType == PipelineNodeType.Transform) + { + var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), pipelineNode.Name); + var estimatorExtension = EstimatorExtensionCatalog.GetExtension(estimatorName); + var estimator = estimatorExtension.CreateInstance(context, pipelineNode); + var transform = new SuggestedTransform(pipelineNode, estimator); + if (!trainerEncountered) + { + transforms.Add(transform); + } + else + { + 
transformsPostTrainer.Add(transform); + } + } + } + + return new SuggestedPipeline(transforms, transformsPostTrainer, trainer, context, pipeline.CacheBeforeTrainer); + } + + public IEstimator ToEstimator() + { + IEstimator pipeline = new EstimatorChain(); + + // Append each transformer to the pipeline + foreach (var transform in Transforms) + { + if (transform.Estimator != null) + { + pipeline = pipeline.Append(transform.Estimator); + } + } + + // Get learner + var learner = Trainer.BuildTrainer(); + + if (_cacheBeforeTrainer) + { + pipeline = pipeline.AppendCacheCheckpoint(_context); + } + + // Append learner to pipeline + pipeline = pipeline.Append(learner); + + // Append each post-trainer transformer to the pipeline + foreach (var transform in TransformsPostTrainer) + { + if (transform.Estimator != null) + { + pipeline = pipeline.Append(transform.Estimator); + } + } + + return pipeline; + } + } +} diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineBuilder.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineBuilder.cs new file mode 100644 index 0000000000..a3fad88e0b --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineBuilder.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class SuggestedPipelineBuilder + { + public static SuggestedPipeline Build(MLContext context, + ICollection transforms, + ICollection transformsPostTrainer, + SuggestedTrainer trainer, + bool? 
enableCaching) + { + var trainerInfo = trainer.BuildTrainer().Info; + AddNormalizationTransforms(context, trainerInfo, transforms); + var cacheBeforeTrainer = ShouldCacheBeforeTrainer(trainerInfo, enableCaching); + return new SuggestedPipeline(transforms, transformsPostTrainer, trainer, context, cacheBeforeTrainer); + } + + private static void AddNormalizationTransforms(MLContext context, + TrainerInfo trainerInfo, + ICollection transforms) + { + // Only add normalization if trainer needs it + if (!trainerInfo.NeedNormalization) + { + return; + } + + var transform = NormalizingExtension.CreateSuggestedTransform(context, DefaultColumnNames.Features, DefaultColumnNames.Features); + transforms.Add(transform); + } + + private static bool ShouldCacheBeforeTrainer(TrainerInfo trainerInfo, bool? enableCaching) + { + return enableCaching == true || (enableCaching == null && trainerInfo.WantCaching); + } + } +} diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineCrossValRunDetail.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineCrossValRunDetail.cs new file mode 100644 index 0000000000..e21b55428a --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineCrossValRunDetail.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + internal sealed class SuggestedPipelineTrainResult + { + public readonly TMetrics ValidationMetrics; + public readonly ModelContainer ModelContainer; + public readonly Exception Exception; + public readonly double Score; + + internal SuggestedPipelineTrainResult(ModelContainer modelContainer, + TMetrics metrics, + Exception exception, + double score) + { + ModelContainer = modelContainer; + ValidationMetrics = metrics; + Exception = exception; + Score = score; + } + + public TrainResult ToTrainResult() + { + return new TrainResult(ModelContainer, ValidationMetrics, Exception); + } + } + + internal sealed class SuggestedPipelineCrossValRunDetail : SuggestedPipelineRunDetail + { + public readonly IEnumerable> Results; + + internal SuggestedPipelineCrossValRunDetail(SuggestedPipeline pipeline, + double score, + bool runSucceeded, + IEnumerable> results) : base(pipeline, score, runSucceeded) + { + Results = results; + } + + public CrossValidationRunDetail ToIterationResult(IEstimator preFeaturizer) + { + var estimator = SuggestedPipelineRunDetailUtil.PrependPreFeaturizer(Pipeline.ToEstimator(), preFeaturizer); + return new CrossValidationRunDetail(Pipeline.Trainer.TrainerName.ToString(), estimator, + Pipeline.ToPipeline(), Results.Select(r => r.ToTrainResult())); + } + } +} diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetail.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetail.cs new file mode 100644 index 0000000000..7cb76e1ab3 --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetail.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; + +namespace Microsoft.ML.Auto +{ + internal class SuggestedPipelineRunDetail + { + public readonly SuggestedPipeline Pipeline; + public readonly bool RunSucceded; + public readonly double Score; + + public SuggestedPipelineRunDetail(SuggestedPipeline pipeline, double score, bool runSucceeded) + { + Pipeline = pipeline; + Score = score; + RunSucceded = runSucceeded; + } + + public static SuggestedPipelineRunDetail FromPipelineRunResult(MLContext context, PipelineScore pipelineRunResult) + { + return new SuggestedPipelineRunDetail(SuggestedPipeline.FromPipeline(context, pipelineRunResult.Pipeline), pipelineRunResult.Score, pipelineRunResult.RunSucceded); + } + + public IRunResult ToRunResult(bool isMetricMaximizing) + { + return new RunResult(Pipeline.Trainer.HyperParamSet, Score, isMetricMaximizing); + } + } + + internal class SuggestedPipelineRunDetail : SuggestedPipelineRunDetail + { + public readonly TMetrics ValidationMetrics; + public readonly ModelContainer ModelContainer; + public readonly Exception Exception; + + internal SuggestedPipelineRunDetail(SuggestedPipeline pipeline, + double score, + bool runSucceeded, + TMetrics validationMetrics, + ModelContainer modelContainer, + Exception ex) : base(pipeline, score, runSucceeded) + { + ValidationMetrics = validationMetrics; + ModelContainer = modelContainer; + Exception = ex; + } + + public RunDetail ToIterationResult(IEstimator preFeaturizer) + { + var estimator = SuggestedPipelineRunDetailUtil.PrependPreFeaturizer(Pipeline.ToEstimator(), preFeaturizer); + return new RunDetail(Pipeline.Trainer.TrainerName.ToString(), estimator, + Pipeline.ToPipeline(), ModelContainer, ValidationMetrics, Exception); + } + } +} diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetailUtil.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetailUtil.cs new file mode 100644 index 0000000000..8fcb59b4d5 --- /dev/null +++ 
b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetailUtil.cs @@ -0,0 +1,18 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.Auto +{ + internal static class SuggestedPipelineRunDetailUtil + { + public static IEstimator PrependPreFeaturizer(IEstimator estimator, IEstimator preFeaturizer) + { + if (preFeaturizer == null) + { + return estimator; + } + return preFeaturizer.Append(estimator); + } + } +} diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs new file mode 100644 index 0000000000..90fae04eb9 --- /dev/null +++ b/src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs @@ -0,0 +1,92 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Trainers; + +namespace Microsoft.ML.Auto +{ + internal class SuggestedTrainer + { + public IEnumerable SweepParams { get; } + public TrainerName TrainerName { get; } + public ParameterSet HyperParamSet { get; set; } + + private readonly MLContext _mlContext; + private readonly ITrainerExtension _trainerExtension; + private readonly ColumnInformation _columnInfo; + + internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtension, + ColumnInformation columnInfo, + ParameterSet hyperParamSet = null) + { + _mlContext = mlContext; + _trainerExtension = trainerExtension; + _columnInfo = columnInfo; + SweepParams = _trainerExtension.GetHyperparamSweepRanges(); + TrainerName = TrainerExtensionCatalog.GetTrainerName(_trainerExtension); + SetHyperparamValues(hyperParamSet); + } + + public void SetHyperparamValues(ParameterSet hyperParamSet) + { + HyperParamSet = hyperParamSet; + PropagateParamSetValues(); + } + + public SuggestedTrainer Clone() + { + return new SuggestedTrainer(_mlContext, _trainerExtension, _columnInfo, HyperParamSet?.Clone()); + } + + public ITrainerEstimator, object> BuildTrainer() + { + IEnumerable sweepParams = null; + if (HyperParamSet != null) + { + sweepParams = SweepParams; + } + return _trainerExtension.CreateInstance(_mlContext, sweepParams, _columnInfo); + } + + public override string ToString() + { + var paramsStr = string.Empty; + if (SweepParams != null) + { + paramsStr = string.Join(", ", SweepParams.Where(p => p != null && p.RawValue != null).Select(p => $"{p.Name}:{p.ProcessedValue()}")); + } + return $"{TrainerName}{{{paramsStr}}}"; + } + + public PipelineNode ToPipelineNode() + { + var sweepParams = SweepParams.Where(p => p.RawValue != null); + return _trainerExtension.CreatePipelineNode(sweepParams, _columnInfo); + } + + /// + /// make sure sweep params and param set are consistent + /// + private void PropagateParamSetValues() + { + if 
(HyperParamSet == null) + { + return; + } + + var spMap = SweepParams.ToDictionary(sp => sp.Name); + + foreach (var hp in HyperParamSet) + { + if (spMap.ContainsKey(hp.Name)) + { + var sp = spMap[hp.Name]; + sp.SetUsingValueText(hp.ValueText); + } + } + } + } +} diff --git a/src/Microsoft.ML.Auto/Microsoft.ML.Auto.csproj b/src/Microsoft.ML.Auto/Microsoft.ML.Auto.csproj new file mode 100644 index 0000000000..cb8eb4fc72 --- /dev/null +++ b/src/Microsoft.ML.Auto/Microsoft.ML.Auto.csproj @@ -0,0 +1,34 @@ + + + netstandard2.0 + 7.3 + Microsoft.ML.Auto + + false + false + + + + + + + + + + + + Microsoft + LICENSE + https://dot.net/ml + https://aka.ms/mlnetlogo + https://aka.ms/mlnetreleasenotes + + ML.NET ML Machine Learning AutoML + Microsoft.ML.Auto + + + + + + + diff --git a/src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs b/src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs new file mode 100644 index 0000000000..ca834ce801 --- /dev/null +++ b/src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs @@ -0,0 +1,217 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class PipelineSuggester + { + private const int TopKTrainers = 3; + + public static Pipeline GetNextPipeline(MLContext context, + IEnumerable history, + DatasetColumnInfo[] columns, + TaskKind task, + bool isMaximizingMetric = true) + { + var inferredHistory = history.Select(r => SuggestedPipelineRunDetail.FromPipelineRunResult(context, r)); + var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric); + return nextInferredPipeline?.ToPipeline(); + } + + public static SuggestedPipeline GetNextInferredPipeline(MLContext context, + IEnumerable history, + DatasetColumnInfo[] columns, + TaskKind task, + bool isMaximizingMetric, + IEnumerable trainerWhitelist = null, + bool? _enableCaching = null) + { + var availableTrainers = RecipeInference.AllowedTrainers(context, task, + ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist); + var transforms = TransformInferenceApi.InferTransforms(context, task, columns).ToList(); + var transformsPostTrainer = TransformInferenceApi.InferTransformsPostTrainer(context, task, columns).ToList(); + + // if we haven't run all pipelines once + if (history.Count() < availableTrainers.Count()) + { + return GetNextFirstStagePipeline(context, history, availableTrainers, transforms, transformsPostTrainer, _enableCaching); + } + + // get top trainers from stage 1 runs + var topTrainers = GetTopTrainers(history, availableTrainers, isMaximizingMetric); + + // sort top trainers by # of times they've been run, from lowest to highest + var orderedTopTrainers = OrderTrainersByNumTrials(history, topTrainers); + + // keep as hashset of previously visited pipelines + var visitedPipelines = new HashSet(history.Select(h => h.Pipeline)); + + // iterate over top trainers (from least run to most run), + // to find next pipeline + foreach (var 
trainer in orderedTopTrainers) + { + var newTrainer = trainer.Clone(); + + // repeat until passes or runs out of chances + const int maxNumberAttempts = 10; + var count = 0; + do + { + // sample new hyperparameters for the learner + if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric)) + { + // if unable to sample new hyperparameters for the learner + // (ie SMAC returned 0 suggestions), break + break; + } + + var suggestedPipeline = SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, newTrainer, _enableCaching); + + // make sure we have not seen pipeline before + if (!visitedPipelines.Contains(suggestedPipeline)) + { + return suggestedPipeline; + } + } while (++count <= maxNumberAttempts); + } + + return null; + } + + /// + /// Get top trainers from first stage + /// + private static IEnumerable GetTopTrainers(IEnumerable history, + IEnumerable availableTrainers, + bool isMaximizingMetric) + { + // narrow history to first stage runs + history = history.Take(availableTrainers.Count()); + + history = history.GroupBy(r => r.Pipeline.Trainer.TrainerName).Select(g => g.First()); + IEnumerable sortedHistory = history.OrderBy(r => r.Score); + if(isMaximizingMetric) + { + sortedHistory = sortedHistory.Reverse(); + } + var topTrainers = sortedHistory.Take(TopKTrainers).Select(r => r.Pipeline.Trainer); + return topTrainers; + } + + private static IEnumerable OrderTrainersByNumTrials(IEnumerable history, + IEnumerable selectedTrainers) + { + var selectedTrainerNames = new HashSet(selectedTrainers.Select(t => t.TrainerName)); + return history.Where(h => selectedTrainerNames.Contains(h.Pipeline.Trainer.TrainerName)) + .GroupBy(h => h.Pipeline.Trainer.TrainerName) + .OrderBy(x => x.Count()) + .Select(x => x.First().Pipeline.Trainer); + } + + private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context, + IEnumerable history, + IEnumerable availableTrainers, + ICollection transforms, + ICollection 
transformsPostTrainer, + bool? _enableCaching) + { + var trainer = availableTrainers.ElementAt(history.Count()); + return SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, trainer, _enableCaching); + } + + private static IValueGenerator[] ConvertToValueGenerators(IEnumerable hps) + { + var results = new IValueGenerator[hps.Count()]; + + for (int i = 0; i < hps.Count(); i++) + { + switch (hps.ElementAt(i)) + { + case SweepableDiscreteParam dp: + var dpArgs = new DiscreteParamArguments() + { + Name = dp.Name, + Values = dp.Options.Select(o => o.ToString()).ToArray() + }; + results[i] = new DiscreteValueGenerator(dpArgs); + break; + + case SweepableFloatParam fp: + var fpArgs = new FloatParamArguments() + { + Name = fp.Name, + Min = fp.Min, + Max = fp.Max, + LogBase = fp.IsLogScale, + }; + if (fp.NumSteps.HasValue) + { + fpArgs.NumSteps = fp.NumSteps.Value; + } + if (fp.StepSize.HasValue) + { + fpArgs.StepSize = fp.StepSize.Value; + } + results[i] = new FloatValueGenerator(fpArgs); + break; + + case SweepableLongParam lp: + var lpArgs = new LongParamArguments() + { + Name = lp.Name, + Min = lp.Min, + Max = lp.Max, + LogBase = lp.IsLogScale + }; + if (lp.NumSteps.HasValue) + { + lpArgs.NumSteps = lp.NumSteps.Value; + } + if (lp.StepSize.HasValue) + { + lpArgs.StepSize = lp.StepSize.Value; + } + results[i] = new LongValueGenerator(lpArgs); + break; + } + } + return results; + } + + /// + /// Samples new hyperparameters for the trainer, and sets them. + /// Returns true if success (new hyperparams were suggested and set). Else, returns false. 
+ /// + private static bool SampleHyperparameters(MLContext context, SuggestedTrainer trainer, IEnumerable history, bool isMaximizingMetric) + { + var sps = ConvertToValueGenerators(trainer.SweepParams); + var sweeper = new SmacSweeper(context, + new SmacSweeper.Arguments + { + SweptParameters = sps + }); + + IEnumerable historyToUse = history + .Where(r => r.RunSucceded && r.Pipeline.Trainer.TrainerName == trainer.TrainerName && r.Pipeline.Trainer.HyperParamSet != null && r.Pipeline.Trainer.HyperParamSet.Any()); + + // get new set of hyperparameter values + var proposedParamSet = sweeper.ProposeSweeps(1, historyToUse.Select(h => h.ToRunResult(isMaximizingMetric))).First(); + if(!proposedParamSet.Any()) + { + return false; + } + + // associate proposed param set with trainer, so that smart hyperparam + // sweepers (like KDO) can map them back. + trainer.SetHyperparamValues(proposedParamSet); + + return true; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/RuleSet1.ruleset b/src/Microsoft.ML.Auto/RuleSet1.ruleset new file mode 100644 index 0000000000..81992ac3d4 --- /dev/null +++ b/src/Microsoft.ML.Auto/RuleSet1.ruleset @@ -0,0 +1,10 @@ + + + + + + + + + + \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/Sweepers/ISweeper.cs b/src/Microsoft.ML.Auto/Sweepers/ISweeper.cs new file mode 100644 index 0000000000..457ebd2645 --- /dev/null +++ b/src/Microsoft.ML.Auto/Sweepers/ISweeper.cs @@ -0,0 +1,272 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + /// + /// The main interface of the sweeper + /// + internal interface ISweeper + { + /// + /// Returns between 0 and maxSweeps configurations to run. 
+ /// It expects a list of previous runs such that it can generate configurations that were not already tried. + /// The list of runs can be null if there were no previous runs. + /// Some smart sweepers can take advantage of the metric(s) that the caller computes for previous runs. + /// + ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable previousRuns = null); + } + + /// + /// This is the interface that each type of parameter sweep needs to implement + /// + internal interface IValueGenerator + { + /// + /// Given a value in the [0,1] range, return a value for this parameter. + /// + IParameterValue CreateFromNormalized(Double normalizedValue); + + /// + /// Used mainly in grid sweepers, return the i-th distinct value for this parameter + /// + IParameterValue this[int i] { get; } + + /// + /// Used mainly in grid sweepers, return the count of distinct values for this parameter + /// + int Count { get; } + + /// + /// Returns the name of the generated parameter + /// + string Name { get; } + } + + /// + /// Parameter value generated from the sweeping. + /// The parameter values must be immutable. + /// Value is converted to string because the runner will usually want to construct a command line for TL. + /// Implementations of this interface must also override object.GetHashCode() and object.Equals(object) so they are consistent + /// with IEquatable.Equals(IParameterValue). + /// + internal interface IParameterValue : IEquatable + { + string Name { get; } + string ValueText { get; } + } + + /// + /// Type safe version of the IParameterValue interface. + /// + internal interface IParameterValue : IParameterValue + { + TValue Value { get; } + } + + /// + /// A set of parameter values. + /// The parameter set must be immutable. 
+ /// + internal sealed class ParameterSet : IEquatable, IEnumerable + { + private readonly Dictionary _parameterValues; + private readonly int _hash; + + public ParameterSet(IEnumerable parameters) + { + _parameterValues = new Dictionary(); + foreach (var parameter in parameters) + { + _parameterValues.Add(parameter.Name, parameter); + } + + var parameterNames = _parameterValues.Keys.ToList(); + parameterNames.Sort(); + _hash = 0; + foreach (var parameterName in parameterNames) + { + _hash = Hashing.CombineHash(_hash, _parameterValues[parameterName].GetHashCode()); + } + } + + public ParameterSet(Dictionary paramValues, int hash) + { + _parameterValues = paramValues; + _hash = hash; + } + + public IEnumerator GetEnumerator() + { + return _parameterValues.Values.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + public int Count + { + get { return _parameterValues.Count; } + } + + public IParameterValue this[string name] + { + get { return _parameterValues[name]; } + } + + private bool ContainsParamValue(IParameterValue parameterValue) + { + IParameterValue value; + return _parameterValues.TryGetValue(parameterValue.Name, out value) && + parameterValue.Equals(value); + } + + public bool Equals(ParameterSet other) + { + if (other == null || other._hash != _hash || other._parameterValues.Count != _parameterValues.Count) + return false; + return other._parameterValues.Values.All(pv => ContainsParamValue(pv)); + } + + public ParameterSet Clone() => + new ParameterSet(new Dictionary(_parameterValues), _hash); + + public override string ToString() + { + return string.Join(" ", _parameterValues.Select(kvp => string.Format("{0}={1}", kvp.Value.Name, kvp.Value.ValueText)).ToArray()); + } + + public override int GetHashCode() + { + return _hash; + } + } + + /// + /// The result of a run. + /// Contains the parameter set used, useful for the sweeper to not generate the same configuration multiple times. 
+ /// Also contains the result of a run and the metric value that is used by smart sweepers to generate new configurations + /// that try to maximize this metric. + /// + internal interface IRunResult : IComparable + { + ParameterSet ParameterSet { get; } + IComparable MetricValue { get; } + bool IsMetricMaximizing { get; } + } + + internal interface IRunResult : IRunResult + where T : IComparable + { + new T MetricValue { get; } + } + + /// + /// Simple implementation of IRunResult + /// + internal sealed class RunResult : IRunResult + { + private readonly ParameterSet _parameterSet; + private readonly Double? _metricValue; + private readonly bool _isMetricMaximizing; + + /// + /// This switch changes the behavior of the CompareTo function, switching the greater than / less than + /// behavior, depending on if it is set to True. + /// + public bool IsMetricMaximizing { get { return _isMetricMaximizing; } } + + public ParameterSet ParameterSet + { + get { return _parameterSet; } + } + + public RunResult(ParameterSet parameterSet, Double metricValue, bool isMetricMaximizing) + { + _parameterSet = parameterSet; + _metricValue = metricValue; + _isMetricMaximizing = isMetricMaximizing; + } + + public Double MetricValue + { + get + { + return _metricValue.Value; + } + } + + public int CompareTo(IRunResult other) + { + var otherTyped = other as RunResult; + //Contracts.Check(otherTyped != null); + if (_metricValue == otherTyped._metricValue) + return 0; + return _isMetricMaximizing ^ (_metricValue < otherTyped._metricValue) ? 1 : -1; + } + + public bool HasMetricValue + { + get + { + return _metricValue != null; + } + } + + IComparable IRunResult.MetricValue + { + get { return MetricValue; } + } + } + + /// + /// The metric class, used by smart sweeping algorithms. + /// Ideally we would like to move towards the new IDataView/ISchematized, this is + /// just a simple view instead, and it is decoupled from RunResult so we can move + /// in that direction in the future. 
+ /// + internal sealed class RunMetric + { + private readonly float _primaryMetric; + private readonly float[] _metricDistribution; + + public RunMetric(float primaryMetric, IEnumerable metricDistribution = null) + { + _primaryMetric = primaryMetric; + if (metricDistribution != null) + _metricDistribution = metricDistribution.ToArray(); + } + + /// + /// The primary metric to optimize. + /// This metric is usually an aggregate value for the run, for example, AUC, accuracy etc. + /// By default, smart sweeping algorithms will maximize this metric. + /// If you want to minimize, either negate this value or change the option in the arguments of the sweeper constructor. + /// + public float PrimaryMetric + { + get { return _primaryMetric; } + } + + /// + /// The (optional) distribution of the metric. + /// This distribution can be a secondary measure of how good a run was, e.g per-fold AUC, per-fold accuracy, (sampled) per-instance log loss etc. + /// + public float[] GetMetricDistribution() + { + if (_metricDistribution == null) + return null; + var result = new float[_metricDistribution.Length]; + Array.Copy(_metricDistribution, result, _metricDistribution.Length); + return result; + } + } +} diff --git a/src/Microsoft.ML.Auto/Sweepers/Parameters.cs b/src/Microsoft.ML.Auto/Sweepers/Parameters.cs new file mode 100644 index 0000000000..9c9ffabfd3 --- /dev/null +++ b/src/Microsoft.ML.Auto/Sweepers/Parameters.cs @@ -0,0 +1,473 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; + +namespace Microsoft.ML.Auto +{ + internal abstract class BaseParamArguments + { + //[Argument(ArgumentType.Required, HelpText = "Parameter name", ShortName = "n")] + public string Name; + } + + internal abstract class NumericParamArguments : BaseParamArguments + { + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of steps for grid runthrough.", ShortName = "steps")] + public int NumSteps = 100; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Amount of increment between steps (multiplicative if log).", ShortName = "inc")] + public Double? StepSize = null; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Log scale.", ShortName = "log")] + public bool LogBase = false; + } + + internal class FloatParamArguments : NumericParamArguments + { + //[Argument(ArgumentType.Required, HelpText = "Minimum value")] + public float Min; + + //[Argument(ArgumentType.Required, HelpText = "Maximum value")] + public float Max; + } + + internal class LongParamArguments : NumericParamArguments + { + //[Argument(ArgumentType.Required, HelpText = "Minimum value")] + public long Min; + + //[Argument(ArgumentType.Required, HelpText = "Maximum value")] + public long Max; + } + + internal class DiscreteParamArguments : BaseParamArguments + { + //[Argument(ArgumentType.Multiple, HelpText = "Values", ShortName = "v")] + public string[] Values = null; + } + + internal sealed class LongParameterValue : IParameterValue + { + private readonly string _name; + private readonly string _valueText; + private readonly long _value; + + public string Name + { + get { return _name; } + } + + public string ValueText + { + get { return _valueText; } + } + + public long Value + { + get { return _value; } + } + + public LongParameterValue(string name, long value) + { + _name = name; + _value = value; + _valueText = _value.ToString("D"); + } + + public bool Equals(IParameterValue other) + { + return Equals((object)other); 
+ } + + public override bool Equals(object obj) + { + var lpv = obj as LongParameterValue; + return lpv != null && Name == lpv.Name && _value == lpv._value; + } + + public override int GetHashCode() + { + return Hashing.CombinedHash(0, typeof(LongParameterValue), _name, _value); + } + } + + internal sealed class FloatParameterValue : IParameterValue + { + private readonly string _name; + private readonly string _valueText; + private readonly float _value; + + public string Name + { + get { return _name; } + } + + public string ValueText + { + get { return _valueText; } + } + + public float Value + { + get { return _value; } + } + + public FloatParameterValue(string name, float value) + { + AutoMlUtils.Assert(!float.IsNaN(value)); + _name = name; + _value = value; + _valueText = _value.ToString("R"); + } + + public bool Equals(IParameterValue other) + { + return Equals((object)other); + } + + public override bool Equals(object obj) + { + var fpv = obj as FloatParameterValue; + return fpv != null && Name == fpv.Name && _value == fpv._value; + } + + public override int GetHashCode() + { + return Hashing.CombinedHash(0, typeof(FloatParameterValue), _name, _value); + } + } + + internal sealed class StringParameterValue : IParameterValue + { + private readonly string _name; + private readonly string _value; + + public string Name + { + get { return _name; } + } + + public string ValueText + { + get { return _value; } + } + + public string Value + { + get { return _value; } + } + + public StringParameterValue(string name, string value) + { + _name = name; + _value = value; + } + + public bool Equals(IParameterValue other) + { + return Equals((object)other); + } + + public override bool Equals(object obj) + { + var spv = obj as StringParameterValue; + return spv != null && Name == spv.Name && ValueText == spv.ValueText; + } + + public override int GetHashCode() + { + return Hashing.CombinedHash(0, typeof(StringParameterValue), _name, _value); + } + } + + internal interface 
INumericValueGenerator : IValueGenerator + { + float NormalizeValue(IParameterValue value); + bool InRange(IParameterValue value); + } + + /// + /// The integer type parameter sweep. + /// + internal class LongValueGenerator : INumericValueGenerator + { + private readonly LongParamArguments _args; + private IParameterValue[] _gridValues; + + public string Name { get { return _args.Name; } } + + public LongValueGenerator(LongParamArguments args) + { + AutoMlUtils.Assert(args.Min < args.Max, "min must be less than max"); + // REVIEW: this condition can be relaxed if we change the math below to deal with it + AutoMlUtils.Assert(!args.LogBase || args.Min > 0, "min must be positive if log scale is used"); + AutoMlUtils.Assert(!args.LogBase || args.StepSize == null || args.StepSize > 1, "StepSize must be greater than 1 if log scale is used"); + AutoMlUtils.Assert(args.LogBase || args.StepSize == null || args.StepSize > 0, "StepSize must be greater than 0 if linear scale is used"); + _args = args; + } + + // REVIEW: Is Float accurate enough? + public IParameterValue CreateFromNormalized(Double normalizedValue) + { + long val; + if (_args.LogBase) + { + // REVIEW: review the math below, it only works for positive Min and Max + var logBase = !_args.StepSize.HasValue + ? 
Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1)) + : _args.StepSize.Value; + var logMax = Math.Log(_args.Max, logBase); + var logMin = Math.Log(_args.Min, logBase); + val = (long)(_args.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin))); + } + else + val = (long)(_args.Min + normalizedValue * (_args.Max - _args.Min)); + + return new LongParameterValue(_args.Name, val); + } + + private void EnsureParameterValues() + { + if (_gridValues != null) + return; + + var result = new List(); + if ((_args.StepSize == null && _args.NumSteps > (_args.Max - _args.Min)) || + (_args.StepSize != null && _args.StepSize <= 1)) + { + for (long i = _args.Min; i <= _args.Max; i++) + result.Add(new LongParameterValue(_args.Name, i)); + } + else + { + if (_args.LogBase) + { + // REVIEW: review the math below, it only works for positive Min and Max + var logBase = _args.StepSize ?? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1)); + + long prevValue = long.MinValue; + var maxPlusEpsilon = _args.Max * Math.Sqrt(logBase); + for (Double value = _args.Min; value <= maxPlusEpsilon; value *= logBase) + { + var longValue = (long)value; + if (longValue > prevValue) + result.Add(new LongParameterValue(_args.Name, longValue)); + prevValue = longValue; + } + } + else + { + var stepSize = _args.StepSize ?? 
(Double)(_args.Max - _args.Min) / (_args.NumSteps - 1); + long prevValue = long.MinValue; + var maxPlusEpsilon = _args.Max + stepSize / 2; + for (Double value = _args.Min; value <= maxPlusEpsilon; value += stepSize) + { + var longValue = (long)value; + if (longValue > prevValue) + result.Add(new LongParameterValue(_args.Name, longValue)); + prevValue = longValue; + } + } + } + _gridValues = result.ToArray(); + } + + public IParameterValue this[int i] + { + get + { + EnsureParameterValues(); + return _gridValues[i]; + } + } + + public int Count + { + get + { + EnsureParameterValues(); + return _gridValues.Length; + } + } + + public float NormalizeValue(IParameterValue value) + { + var valueTyped = value as LongParameterValue; + AutoMlUtils.Assert(valueTyped != null, "LongValueGenerator could not normalized parameter because it is not of the correct type"); + AutoMlUtils.Assert(_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max, "Value not in correct range"); + + if (_args.LogBase) + { + float logBase = (float)(_args.StepSize ?? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1))); + return (float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_args.Min, logBase)) / (Math.Log(_args.Max, logBase) - Math.Log(_args.Min, logBase))); + } + else + return (float)(valueTyped.Value - _args.Min) / (_args.Max - _args.Min); + } + + public bool InRange(IParameterValue value) + { + var valueTyped = value as LongParameterValue; + return (_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max); + } + } + + /// + /// The floating point type parameter sweep. 
+ /// + internal class FloatValueGenerator : INumericValueGenerator + { + private readonly FloatParamArguments _args; + private IParameterValue[] _gridValues; + + public string Name { get { return _args.Name; } } + + public FloatValueGenerator(FloatParamArguments args) + { + AutoMlUtils.Assert(args.Min < args.Max, "min must be less than max"); + // REVIEW: this condition can be relaxed if we change the math below to deal with it + AutoMlUtils.Assert(!args.LogBase || args.Min > 0, "min must be positive if log scale is used"); + AutoMlUtils.Assert(!args.LogBase || args.StepSize == null || args.StepSize > 1, "StepSize must be greater than 1 if log scale is used"); + AutoMlUtils.Assert(args.LogBase || args.StepSize == null || args.StepSize > 0, "StepSize must be greater than 0 if linear scale is used"); + _args = args; + } + + // REVIEW: Is Float accurate enough? + public IParameterValue CreateFromNormalized(Double normalizedValue) + { + float val; + if (_args.LogBase) + { + // REVIEW: review the math below, it only works for positive Min and Max + var logBase = !_args.StepSize.HasValue + ? Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1)) + : _args.StepSize.Value; + var logMax = Math.Log(_args.Max, logBase); + var logMin = Math.Log(_args.Min, logBase); + val = (float)(_args.Min * Math.Pow(logBase, normalizedValue * (logMax - logMin))); + } + else + val = (float)(_args.Min + normalizedValue * (_args.Max - _args.Min)); + + return new FloatParameterValue(_args.Name, val); + } + + private void EnsureParameterValues() + { + if (_gridValues != null) + return; + + var result = new List(); + if (_args.LogBase) + { + // REVIEW: review the math below, it only works for positive Min and Max + var logBase = _args.StepSize ?? 
Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1)); + + float prevValue = float.NegativeInfinity; + var maxPlusEpsilon = _args.Max * Math.Sqrt(logBase); + for (Double value = _args.Min; value <= maxPlusEpsilon; value *= logBase) + { + var floatValue = (float)value; + if (floatValue > prevValue) + result.Add(new FloatParameterValue(_args.Name, floatValue)); + prevValue = floatValue; + } + } + else + { + var stepSize = _args.StepSize ?? (Double)(_args.Max - _args.Min) / (_args.NumSteps - 1); + float prevValue = float.NegativeInfinity; + var maxPlusEpsilon = _args.Max + stepSize / 2; + for (Double value = _args.Min; value <= maxPlusEpsilon; value += stepSize) + { + var floatValue = (float)value; + if (floatValue > prevValue) + result.Add(new FloatParameterValue(_args.Name, floatValue)); + prevValue = floatValue; + } + } + + _gridValues = result.ToArray(); + } + + public IParameterValue this[int i] + { + get + { + EnsureParameterValues(); + return _gridValues[i]; + } + } + + public int Count + { + get + { + EnsureParameterValues(); + return _gridValues.Length; + } + } + + public float NormalizeValue(IParameterValue value) + { + var valueTyped = value as FloatParameterValue; + AutoMlUtils.Assert(valueTyped != null, "FloatValueGenerator could not normalized parameter because it is not of the correct type"); + AutoMlUtils.Assert(_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max, "Value not in correct range"); + + if (_args.LogBase) + { + float logBase = (float)(_args.StepSize ?? 
Math.Pow(1.0 * _args.Max / _args.Min, 1.0 / (_args.NumSteps - 1))); + return (float)((Math.Log(valueTyped.Value, logBase) - Math.Log(_args.Min, logBase)) / (Math.Log(_args.Max, logBase) - Math.Log(_args.Min, logBase))); + } + else + return (valueTyped.Value - _args.Min) / (_args.Max - _args.Min); + } + + public bool InRange(IParameterValue value) + { + var valueTyped = value as FloatParameterValue; + AutoMlUtils.Assert(valueTyped != null, "Parameter should be of type FloatParameterValue"); + return (_args.Min <= valueTyped.Value && valueTyped.Value <= _args.Max); + } + } + + /// + /// The discrete parameter sweep. + /// + internal class DiscreteValueGenerator : IValueGenerator + { + private readonly DiscreteParamArguments _args; + + public string Name { get { return _args.Name; } } + + public DiscreteValueGenerator(DiscreteParamArguments args) + { + _args = args; + } + + // REVIEW: Is Float accurate enough? + public IParameterValue CreateFromNormalized(Double normalizedValue) + { + return new StringParameterValue(_args.Name, _args.Values[(int)(_args.Values.Length * normalizedValue)]); + } + + public IParameterValue this[int i] + { + get + { + return new StringParameterValue(_args.Name, _args.Values[i]); + } + } + + public int Count + { + get + { + return _args.Values.Length; + } + } + } +} diff --git a/src/Microsoft.ML.Auto/Sweepers/Random.cs b/src/Microsoft.ML.Auto/Sweepers/Random.cs new file mode 100644 index 0000000000..36edcb8dca --- /dev/null +++ b/src/Microsoft.ML.Auto/Sweepers/Random.cs @@ -0,0 +1,29 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Linq; + +namespace Microsoft.ML.Auto +{ + /// + /// Random sweeper, it generates random values for each of the parameters. 
+ /// + internal sealed class UniformRandomSweeper : SweeperBase + { + public UniformRandomSweeper(ArgumentsBase args) + : base(args, "UniformRandom") + { + } + + public UniformRandomSweeper(ArgumentsBase args, IValueGenerator[] sweepParameters) + : base(args, sweepParameters, "UniformRandom") + { + } + + protected override ParameterSet CreateParamSet() + { + return new ParameterSet(SweepParameters.Select(sweepParameter => sweepParameter.CreateFromNormalized(AutoMlUtils.random.Value.NextDouble()))); + } + } +} diff --git a/src/Microsoft.ML.Auto/Sweepers/SmacSweeper.cs b/src/Microsoft.ML.Auto/Sweepers/SmacSweeper.cs new file mode 100644 index 0000000000..618cf74256 --- /dev/null +++ b/src/Microsoft.ML.Auto/Sweepers/SmacSweeper.cs @@ -0,0 +1,423 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers.FastTree; +using Float = System.Single; + +namespace Microsoft.ML.Auto +{ + //REVIEW: Figure out better way to do this. could introduce a base class for all smart sweepers, + //encapsulating common functionality. This seems like a good plan to persue. 
+ internal sealed class SmacSweeper : ISweeper + { + public sealed class Arguments + { + //[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))] + public IValueGenerator[] SweptParameters; + + //[Argument(ArgumentType.AtMostOnce, HelpText = "Seed for the random number generator for the first batch sweeper", ShortName = "seed")] + public int RandomSeed; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "If iteration point is outside parameter definitions, should it be projected?", ShortName = "project")] + public bool ProjectInBounds = true; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of regression trees in forest", ShortName = "numtrees")] + public int NumOfTrees = 10; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Minimum number of data points required to be in a node if it is to be split further", ShortName = "nmin")] + public int NMinForSplit = 2; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of points to use for random initialization", ShortName = "nip")] + public int NumberInitialPopulation = 20; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of search parents to use for local search in maximizing EI acquisition function", ShortName = "lsp")] + public int LocalSearchParentCount = 10; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of random configurations when maximizing EI acquisition function", ShortName = "nrcan")] + public int NumRandomEISearchConfigurations = 10000; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Fraction of eligible dimensions to split on (i.e., split ratio)", ShortName = "sr")] + public Float SplitRatio = (Float)0.8; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Epsilon threshold for ending local searches", ShortName = "eps")] + public Float Epsilon = (Float)0.00001; + + 
//[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of neighbors to sample for locally searching each numerical parameter", ShortName = "nnnp")] + public int NumNeighborsForNumericalParams = 4; + } + + private readonly ISweeper _randomSweeper; + private readonly Arguments _args; + private readonly MLContext _context; + + private readonly IValueGenerator[] _sweepParameters; + + public SmacSweeper(MLContext context, Arguments args) + { + _context = context; + _args = args; + _sweepParameters = args.SweptParameters; + _randomSweeper = new UniformRandomSweeper(new SweeperBase.ArgumentsBase(), _sweepParameters); + } + + public ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable previousRuns = null) + { + int numOfCandidates = maxSweeps; + + // Initialization: Will enter here on first iteration and use the default (random) + // sweeper to generate initial candidates. + int numRuns = previousRuns == null ? 0 : previousRuns.Count(); + if (numRuns < _args.NumberInitialPopulation) + return _randomSweeper.ProposeSweeps(Math.Min(numOfCandidates, _args.NumberInitialPopulation - numRuns), previousRuns); + + // Only retain viable runs + List viableRuns = new List(); + foreach (RunResult run in previousRuns) + { + if (run != null && run.HasMetricValue) + viableRuns.Add(run); + } + + // Fit Random Forest Model on previous run data. + var forestPredictor = FitModel(viableRuns); + + // Using acquisition function and current best, get candidate configuration(s). 
+ return GenerateCandidateConfigurations(numOfCandidates, viableRuns, forestPredictor); + } + + private FastForestRegressionModelParameters FitModel(IEnumerable previousRuns) + { + Single[] targets = new Single[previousRuns.Count()]; + Single[][] features = new Single[previousRuns.Count()][]; + + int i = 0; + foreach (RunResult r in previousRuns) + { + features[i] = SweeperProbabilityUtils.ParameterSetAsFloatArray(_sweepParameters, r.ParameterSet, true); + targets[i] = (Float)r.MetricValue; + i++; + } + + ArrayDataViewBuilder dvBuilder = new ArrayDataViewBuilder(_context); + dvBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single, targets); + dvBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single, features); + + IDataView data = dvBuilder.GetDataView(); + AutoMlUtils.Assert(data.GetRowCount() == targets.Length, "This data view will have as many rows as there have been evaluations"); + + // Set relevant random forest arguments. + // Train random forest. + var trainer = _context.Regression.Trainers.FastForest(new FastForestRegressionTrainer.Options() + { + FeatureFraction = _args.SplitRatio, + NumberOfTrees = _args.NumOfTrees, + MinimumExampleCountPerLeaf = _args.NMinForSplit + }); + var predictor = trainer.Fit(data).Model; + + // Return random forest predictor. + return predictor; + } + + /// + /// Generates a set of candidate configurations to sweep through, based on a combination of random and local + /// search, as outlined in Hutter et al - Sequential Model-Based Optimization for General Algorithm Configuration. + /// Makes use of class private members which determine how many candidates are returned. This number will include + /// random configurations interleaved (per the paper), and thus will be double the specified value. + /// + /// Number of candidate solutions to return. + /// History of previously evaluated points, with their emprical performance values. + /// Trained random forest ensemble. 
Used in evaluating the candidates. + /// An array of ParamaterSets which are the candidate configurations to sweep. + private ParameterSet[] GenerateCandidateConfigurations(int numOfCandidates, IEnumerable previousRuns, FastForestRegressionModelParameters forest) + { + // Get k best previous runs ParameterSets. + ParameterSet[] bestKParamSets = GetKBestConfigurations(previousRuns, _args.LocalSearchParentCount); + + // Perform local searches using the k best previous run configurations. + ParameterSet[] eiChallengers = GreedyPlusRandomSearch(bestKParamSets, forest, (int)Math.Ceiling(numOfCandidates / 2.0F), previousRuns); + + // Generate another set of random configurations to interleave. + ParameterSet[] randomChallengers = _randomSweeper.ProposeSweeps(numOfCandidates - eiChallengers.Length, previousRuns); + + // Return interleaved challenger candidates with random candidates. Since the number of candidates from either can be less than + // the number asked for, since we only generate unique candidates, and the number from either method may vary considerably. + ParameterSet[] configs = new ParameterSet[eiChallengers.Length + randomChallengers.Length]; + Array.Copy(eiChallengers, 0, configs, 0, eiChallengers.Length); + Array.Copy(randomChallengers, 0, configs, eiChallengers.Length, randomChallengers.Length); + + return configs; + } + + /// + /// Does a mix of greedy local search around best performing parameter sets, while throwing random parameter sets into the mix. + /// + /// Beginning locations for local greedy search. + /// Trained random forest, used later for evaluating parameters. + /// Number of candidate configurations returned by the method (top K). + /// Historical run results. + /// Array of parameter sets, which will then be evaluated. 
+ private ParameterSet[] GreedyPlusRandomSearch(ParameterSet[] parents, FastForestRegressionModelParameters forest, int numOfCandidates, IEnumerable previousRuns) + { + RunResult bestRun = (RunResult)previousRuns.Max(); + RunResult worstRun = (RunResult)previousRuns.Min(); + double bestVal = bestRun.MetricValue; + + HashSet> configurations = new HashSet>(); + + // Perform local search. + foreach (ParameterSet c in parents) + { + Tuple bestChildKvp = LocalSearch(c, forest, bestVal, _args.Epsilon, bestRun.IsMetricMaximizing); + configurations.Add(bestChildKvp); + } + + // Additional set of random configurations to choose from during local search. + ParameterSet[] randomConfigs = _randomSweeper.ProposeSweeps(_args.NumRandomEISearchConfigurations, previousRuns); + double[] randomEIs = EvaluateConfigurationsByEI(forest, bestVal, randomConfigs, bestRun.IsMetricMaximizing); + AutoMlUtils.Assert(randomConfigs.Length == randomEIs.Length); + + for (int i = 0; i < randomConfigs.Length; i++) + configurations.Add(new Tuple(randomEIs[i], randomConfigs[i])); + + IOrderedEnumerable> bestConfigurations = configurations.OrderByDescending(x => x.Item1); + + var retainedConfigs = new HashSet(bestConfigurations.Select(x => x.Item2)); + + // remove configurations matching previous run + foreach (var previousRun in previousRuns) + { + retainedConfigs.Remove(previousRun.ParameterSet); + } + + return retainedConfigs.Take(numOfCandidates).ToArray(); + } + + /// + /// Performs a local one-mutation neighborhood greedy search. + /// + /// Starting parameter set configuration. + /// Trained forest, for evaluation of points. + /// Best performance seen thus far. + /// Threshold for when to stop the local search. + /// Metric type - maximizing or minimizing. 
+ /// + private Tuple LocalSearch(ParameterSet parent, FastForestRegressionModelParameters forest, double bestVal, double epsilon, bool isMetricMaximizing) + { + try + { + double currentBestEI = EvaluateConfigurationsByEI(forest, bestVal, new ParameterSet[] { parent }, isMetricMaximizing)[0]; + ParameterSet currentBestConfig = parent; + + for (; ; ) + { + ParameterSet[] neighborhood = GetOneMutationNeighborhood(currentBestConfig); + double[] eis = EvaluateConfigurationsByEI(forest, bestVal, neighborhood, isMetricMaximizing); + int bestIndex = eis.ArgMax(); + if (eis[bestIndex] - currentBestEI < _args.Epsilon) + break; + else + { + currentBestConfig = neighborhood[bestIndex]; + currentBestEI = eis[bestIndex]; + } + } + + return new Tuple(currentBestEI, currentBestConfig); + } + catch (Exception e) + { + throw new InvalidOperationException("SMAC sweeper localSearch threw exception", e); + } + } + + /// + /// Computes a single-mutation neighborhood (one param at a time) for a given configuration. For + /// numeric parameters, samples K mutations (i.e., creates K neighbors based on that paramater). + /// + /// Starting configuration. + /// A set of configurations that each differ from parent in exactly one parameter. + private ParameterSet[] GetOneMutationNeighborhood(ParameterSet parent) + { + List neighbors = new List(); + SweeperProbabilityUtils spu = new SweeperProbabilityUtils(); + + for (int i = 0; i < _sweepParameters.Length; i++) + { + // This allows us to query possible values of this parameter. + IValueGenerator sweepParam = _sweepParameters[i]; + + // This holds the actual value for this parameter, chosen in this parameter set. + IParameterValue pset = parent[sweepParam.Name]; + + AutoMlUtils.Assert(pset != null); + + DiscreteValueGenerator parameterDiscrete = sweepParam as DiscreteValueGenerator; + if (parameterDiscrete != null) + { + // Create one neighbor for every discrete parameter. 
+ Float[] neighbor = SweeperProbabilityUtils.ParameterSetAsFloatArray(_sweepParameters, parent, false); + + int hotIndex = -1; + for (int j = 0; j < parameterDiscrete.Count; j++) + { + if (parameterDiscrete[j].Equals(pset)) + { + hotIndex = j; + break; + } + } + + AutoMlUtils.Assert(hotIndex >= 0); + + Random r = new Random(); + int randomIndex = r.Next(0, parameterDiscrete.Count - 1); + randomIndex += randomIndex >= hotIndex ? 1 : 0; + neighbor[i] = randomIndex; + neighbors.Add(SweeperProbabilityUtils.FloatArrayAsParameterSet(_sweepParameters, neighbor, false)); + } + else + { + INumericValueGenerator parameterNumeric = sweepParam as INumericValueGenerator; + AutoMlUtils.Assert(parameterNumeric != null, "SMAC sweeper can only sweep over discrete and numeric parameters"); + + // Create k neighbors (typically 4) for every numerical parameter. + for (int j = 0; j < _args.NumNeighborsForNumericalParams; j++) + { + Float[] neigh = SweeperProbabilityUtils.ParameterSetAsFloatArray(_sweepParameters, parent, false); + double newVal = spu.NormalRVs(1, neigh[i], 0.2)[0]; + while (newVal <= 0.0 || newVal >= 1.0) + newVal = spu.NormalRVs(1, neigh[i], 0.2)[0]; + neigh[i] = (Float)newVal; + ParameterSet neighbor = SweeperProbabilityUtils.FloatArrayAsParameterSet(_sweepParameters, neigh, false); + neighbors.Add(neighbor); + } + } + } + return neighbors.ToArray(); + } + + /// + /// Goes through forest to extract the set of leaf values associated with filtering each configuration. + /// + /// Trained forest predictor, used for filtering configs. + /// Parameter configurations. + /// 2D array where rows correspond to configurations, and columns to the predicted leaf values. 
+ private double[][] GetForestRegressionLeafValues(FastForestRegressionModelParameters forest, ParameterSet[] configs) + { + List datasetLeafValues = new List(); + foreach (ParameterSet config in configs) + { + List leafValues = new List(); + for (var treeId = 0; treeId < forest.TrainedTreeEnsemble.Trees.Count; treeId++) + { + Float[] transformedParams = SweeperProbabilityUtils.ParameterSetAsFloatArray(_sweepParameters, config, true); + VBuffer features = new VBuffer(transformedParams.Length, transformedParams); + var leafId = GetLeaf(forest, treeId, features); + var leafValue = GetLeafValue(forest, treeId, leafId); + leafValues.Add(leafValue); + } + datasetLeafValues.Add(leafValues.ToArray()); + } + return datasetLeafValues.ToArray(); + } + + // Todo: Remove the reflection below for TreeTreeEnsembleModelParameters methods GetLeaf and GetLeafValue. + // Long-term, replace with tree featurizer once it becomes available + // Tracking issue -- https://github.com/dotnet/machinelearning-automl/issues/342 + private static MethodInfo GetLeafMethod = typeof(TreeEnsembleModelParameters).GetMethod("GetLeaf", BindingFlags.NonPublic | BindingFlags.Instance); + private static MethodInfo GetLeafValueMethod = typeof(TreeEnsembleModelParameters).GetMethod("GetLeafValue", BindingFlags.NonPublic | BindingFlags.Instance); + + private static int GetLeaf(TreeEnsembleModelParameters model, int treeId, VBuffer features) + { + List path = null; + return (int)GetLeafMethod.Invoke(model, new object[] { treeId, features, path }); + } + + private static float GetLeafValue(TreeEnsembleModelParameters model, int treeId, int leafId) + { + return (float)GetLeafValueMethod.Invoke(model, new object[] { treeId, leafId }); + } + + /// + /// Computes the empirical means and standard deviations for trees in the forest for each configuration. + /// + /// The sets of leaf values from which the means and standard deviations are computed. 
+ /// A 2D array with one row per set of tree values, and the columns being mean and stddev, respectively. + private double[][] ComputeForestStats(double[][] leafValues) + { + // Computes the empirical mean and empirical std dev from the leaf prediction values. + double[][] meansAndStdDevs = new double[leafValues.Length][]; + for (int i = 0; i < leafValues.Length; i++) + { + double[] row = new double[2]; + row[0] = VectorUtils.GetMean(leafValues[i]); + row[1] = VectorUtils.GetStandardDeviation(leafValues[i]); + meansAndStdDevs[i] = row; + } + return meansAndStdDevs; + } + + private double[] EvaluateConfigurationsByEI(FastForestRegressionModelParameters forest, double bestVal, ParameterSet[] configs, bool isMetricMaximizing) + { + double[][] leafPredictions = GetForestRegressionLeafValues(forest, configs); + double[][] forestStatistics = ComputeForestStats(leafPredictions); + return ComputeEIs(bestVal, forestStatistics, isMetricMaximizing); + } + + private ParameterSet[] GetKBestConfigurations(IEnumerable previousRuns, int k = 10) + { + // NOTE: Should we change this to rank according to EI (using forest), instead of observed performance? + + SortedSet bestK = new SortedSet(); + + foreach (RunResult r in previousRuns) + { + RunResult worst = bestK.Min(); + + if (bestK.Count < k || r.CompareTo(worst) > 0) + bestK.Add(r); + + if (bestK.Count > k) + bestK.Remove(worst); + } + + // Extract the ParamaterSets and return. 
+ List outSet = new List(); + foreach (RunResult r in bestK) + outSet.Add(r.ParameterSet); + return outSet.ToArray(); + } + + private double ComputeEI(double bestVal, double[] forestStatistics, bool isMetricMaximizing) + { + double empMean = forestStatistics[0]; + double empStdDev = forestStatistics[1]; + double centered = empMean - bestVal; + if (!isMetricMaximizing) + { + centered *= -1; + } + if (empStdDev == 0) + { + return centered; + } + double ztrans = centered / empStdDev; + return centered * SweeperProbabilityUtils.StdNormalCdf(ztrans) + empStdDev * SweeperProbabilityUtils.StdNormalPdf(ztrans); + } + + private double[] ComputeEIs(double bestVal, double[][] forestStatistics, bool isMetricMaximizing) + { + double[] eis = new double[forestStatistics.Length]; + for (int i = 0; i < forestStatistics.Length; i++) + eis[i] = ComputeEI(bestVal, forestStatistics[i], isMetricMaximizing); + return eis; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/Sweepers/SweeperBase.cs b/src/Microsoft.ML.Auto/Sweepers/SweeperBase.cs new file mode 100644 index 0000000000..402e4db9f9 --- /dev/null +++ b/src/Microsoft.ML.Auto/Sweepers/SweeperBase.cs @@ -0,0 +1,74 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + /// + /// Signature for the GUI loaders of sweepers. + /// + internal delegate void SignatureSweeperFromParameterList(IValueGenerator[] sweepParameters); + + /// + /// Base sweeper that ensures the suggestions are different from each other and from the previous runs. 
+ /// + internal abstract class SweeperBase : ISweeper + { + internal class ArgumentsBase + { + //[Argument(ArgumentType.Multiple, HelpText = "Swept parameters", ShortName = "p", SignatureType = typeof(SignatureSweeperParameter))] + public IValueGenerator[] SweptParameters; + + //[Argument(ArgumentType.LastOccurenceWins, HelpText = "Number of tries to generate distinct parameter sets.", ShortName = "r")] + public int Retries = 10; + } + + private readonly ArgumentsBase _args; + protected readonly IValueGenerator[] SweepParameters; + + protected SweeperBase(ArgumentsBase args, string name) + { + _args = args; + + SweepParameters = args.SweptParameters.ToArray(); + } + + protected SweeperBase(ArgumentsBase args, IValueGenerator[] sweepParameters, string name) + { + _args = args; + SweepParameters = sweepParameters; + } + + public virtual ParameterSet[] ProposeSweeps(int maxSweeps, IEnumerable previousRuns = null) + { + var prevParamSets = new HashSet(previousRuns?.Select(r => r.ParameterSet).ToList() ?? new List()); + var result = new HashSet(); + for (int i = 0; i < maxSweeps; i++) + { + ParameterSet paramSet; + int retries = 0; + do + { + paramSet = CreateParamSet(); + ++retries; + } while (paramSet != null && retries < _args.Retries && + (AlreadyGenerated(paramSet, prevParamSets) || AlreadyGenerated(paramSet, result))); + + AutoMlUtils.Assert(paramSet != null); + result.Add(paramSet); + } + + return result.ToArray(); + } + + protected abstract ParameterSet CreateParamSet(); + + protected static bool AlreadyGenerated(ParameterSet paramSet, ISet previousRuns) + { + return previousRuns.Contains(paramSet); + } + } +} diff --git a/src/Microsoft.ML.Auto/Sweepers/SweeperProbabilityUtils.cs b/src/Microsoft.ML.Auto/Sweepers/SweeperProbabilityUtils.cs new file mode 100644 index 0000000000..646a7df869 --- /dev/null +++ b/src/Microsoft.ML.Auto/Sweepers/SweeperProbabilityUtils.cs @@ -0,0 +1,159 @@ +// Licensed to the .NET Foundation under one or more agreements. 
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;

namespace Microsoft.ML.Auto
{
    internal sealed class SweeperProbabilityUtils
    {
        /// <summary>Standard normal probability density function.</summary>
        public static double StdNormalPdf(double x)
        {
            return 1 / Math.Sqrt(2 * Math.PI) * Math.Exp(-Math.Pow(x, 2) / 2);
        }

        /// <summary>Standard normal cumulative distribution function, via the error function.</summary>
        public static double StdNormalCdf(double x)
        {
            return 0.5 * (1 + ProbabilityFunctions.Erf(x * 1 / Math.Sqrt(2)));
        }

        /// <summary>
        /// Samples from a Gaussian Normal with mean mu and std dev sigma, using the Box-Muller transform.
        /// </summary>
        /// <param name="numRVs">Number of samples</param>
        /// <param name="mu">mean</param>
        /// <param name="sigma">standard deviation</param>
        /// <returns>An array of <paramref name="numRVs"/> samples.</returns>
        public double[] NormalRVs(int numRVs, double mu, double sigma)
        {
            List<double> rvs = new List<double>();
            double u1;
            double u2;

            for (int i = 0; i < numRVs; i++)
            {
                u1 = AutoMlUtils.random.Value.NextDouble();
                u2 = AutoMlUtils.random.Value.NextDouble();
                rvs.Add(mu + sigma * Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Sin(2.0 * Math.PI * u2));
            }

            return rvs.ToArray();
        }

        /// <summary>
        /// Simple binary search method for finding smallest index in array where value
        /// meets or exceeds what you're looking for.
        /// </summary>
        /// <param name="a">Array to search</param>
        /// <param name="u">Value to search for</param>
        /// <param name="low">Left boundary of search</param>
        /// <param name="high">Right boundary of search</param>
        /// <returns>Smallest index in [low, high] with a[index] >= u.</returns>
        private int BinarySearch(double[] a, double u, int low, int high)
        {
            int diff = high - low;
            if (diff < 2)
                return a[low] >= u ? low : high;
            int mid = low + (diff / 2);
            return a[mid] >= u ? BinarySearch(a, u, low, mid) : BinarySearch(a, u, mid, high);
        }

        /// <summary>
        /// Encodes a parameter set as a float array: discrete parameters become either a one-hot
        /// block (when <paramref name="expandCategoricals"/>) or the hot index; numeric parameters
        /// are normalized to [0,1].
        /// </summary>
        public static float[] ParameterSetAsFloatArray(IValueGenerator[] sweepParams, ParameterSet ps, bool expandCategoricals = true)
        {
            AutoMlUtils.Assert(ps.Count == sweepParams.Length);

            var result = new List<float>();

            for (int i = 0; i < sweepParams.Length; i++)
            {
                // This allows us to query possible values of this parameter.
                var sweepParam = sweepParams[i];

                // This holds the actual value for this parameter, chosen in this parameter set.
                var pset = ps[sweepParam.Name];
                AutoMlUtils.Assert(pset != null);

                var parameterDiscrete = sweepParam as DiscreteValueGenerator;
                if (parameterDiscrete != null)
                {
                    int hotIndex = -1;
                    for (int j = 0; j < parameterDiscrete.Count; j++)
                    {
                        if (parameterDiscrete[j].Equals(pset))
                        {
                            hotIndex = j;
                            break;
                        }
                    }
                    AutoMlUtils.Assert(hotIndex >= 0);

                    if (expandCategoricals)
                        for (int j = 0; j < parameterDiscrete.Count; j++)
                            result.Add(j == hotIndex ? 1 : 0);
                    else
                        result.Add(hotIndex);
                }
                else if (sweepParam is LongValueGenerator lvg)
                {
                    // Normalizing all numeric parameters to [0,1] range.
                    result.Add(lvg.NormalizeValue(new LongParameterValue(pset.Name, long.Parse(pset.ValueText))));
                }
                else if (sweepParam is FloatValueGenerator fvg)
                {
                    // Normalizing all numeric parameters to [0,1] range.
                    result.Add(fvg.NormalizeValue(new FloatParameterValue(pset.Name, float.Parse(pset.ValueText))));
                }
                else
                {
                    throw new InvalidOperationException("Smart sweeper can only sweep over discrete and numeric parameters");
                }
            }

            return result.ToArray();
        }

        /// <summary>
        /// Inverse of <see cref="ParameterSetAsFloatArray"/>: decodes a float array back into a
        /// parameter set.
        /// </summary>
        public static ParameterSet FloatArrayAsParameterSet(IValueGenerator[] sweepParams, float[] array, bool expandedCategoricals = true)
        {
            // BUGFIX: when categoricals are one-hot expanded, the array is LONGER than sweepParams,
            // so the length invariant only holds in the non-expanded case.
            if (!expandedCategoricals)
                AutoMlUtils.Assert(array.Length == sweepParams.Length);

            List<IParameterValue> parameters = new List<IParameterValue>();
            int currentArrayIndex = 0;
            for (int i = 0; i < sweepParams.Length; i++)
            {
                var parameterDiscrete = sweepParams[i] as DiscreteValueGenerator;
                if (parameterDiscrete != null)
                {
                    if (expandedCategoricals)
                    {
                        // BUGFIX: scan the one-hot block at the running array offset, not at the
                        // parameter index i (they diverge as soon as any earlier categorical expands),
                        // and assert a hot entry was found (>= 0, not >= i).
                        int hotIndex = -1;
                        for (int j = 0; j < parameterDiscrete.Count; j++)
                        {
                            if (array[currentArrayIndex + j] > 0)
                            {
                                hotIndex = j;
                                break;
                            }
                        }
                        AutoMlUtils.Assert(hotIndex >= 0);
                        parameters.Add(new StringParameterValue(sweepParams[i].Name, parameterDiscrete[hotIndex].ValueText));
                        currentArrayIndex += parameterDiscrete.Count;
                    }
                    else
                    {
                        parameters.Add(new StringParameterValue(sweepParams[i].Name, parameterDiscrete[(int)array[currentArrayIndex]].ValueText));
                        currentArrayIndex++;
                    }
                }
                else
                {
                    parameters.Add(sweepParams[i].CreateFromNormalized(array[currentArrayIndex]));
                    currentArrayIndex++;
                }
            }

            return new ParameterSet(parameters);
        }
    }
}
diff --git a/src/Microsoft.ML.Auto/TaskKind.cs b/src/Microsoft.ML.Auto/TaskKind.cs
new file mode 100644
index 0000000000..93e0929f51
--- /dev/null
+++ b/src/Microsoft.ML.Auto/TaskKind.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Auto
{
    /// <summary>The machine-learning task an AutoML experiment is solving.</summary>
    internal enum TaskKind
    {
        BinaryClassification,
        MulticlassClassification,
        Regression,
    }
}
diff --git a/src/Microsoft.ML.Auto/Terminators/IterationBasedTerminator.cs b/src/Microsoft.ML.Auto/Terminators/IterationBasedTerminator.cs
new file mode 100644
index 0000000000..9be7a170ca
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Terminators/IterationBasedTerminator.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Terminates an experiment once a fixed total number of iterations has been reached.
    /// </summary>
    internal sealed class IterationBasedTerminator
    {
        // Total iteration budget, fixed at construction.
        private readonly int _budget;

        public IterationBasedTerminator(int numTotalIterations) => _budget = numTotalIterations;

        /// <summary>True once the completed-iteration count has consumed the whole budget.</summary>
        public bool ShouldTerminate(int numPreviousIterations) => numPreviousIterations >= _budget;

        /// <summary>How many iterations of the budget are still available (may be negative if overrun).</summary>
        public int RemainingIterations(int numPreviousIterations) => _budget - numPreviousIterations;
    }
}
diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/BinaryTrainerExtensions.cs b/src/Microsoft.ML.Auto/TrainerExtensions/BinaryTrainerExtensions.cs
new file mode 100644
index 0000000000..d2fd72a673
--- /dev/null
+++ b/src/Microsoft.ML.Auto/TrainerExtensions/BinaryTrainerExtensions.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Calibrators;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;

namespace Microsoft.ML.Auto
{
    // NOTE(review): all generic type arguments in this file were stripped by extraction and have been
    // reconstructed from the ML.NET trainer APIs — verify against TrainerExtensionUtil's signatures.
    using ITrainerEstimator = ITrainerEstimator<ISingleFeaturePredictionTransformer<object>, object>;

    /// <summary>Trainer extension for the averaged perceptron binary classifier.</summary>
    internal class AveragedPerceptronBinaryExtension : ITrainerExtension
    {
        private const int DefaultNumIterations = 10;

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildAveragePerceptronParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            AveragedPerceptronTrainer.Options options = null;
            if (sweepParams == null || !sweepParams.Any())
            {
                options = new AveragedPerceptronTrainer.Options();
                options.NumberOfIterations = DefaultNumIterations;
                options.LabelColumnName = columnInfo.LabelColumnName;
            }
            else
            {
                options = TrainerExtensionUtil.CreateOptions<AveragedPerceptronTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
                // Keep the non-default iteration count unless the sweep explicitly set one.
                if (!sweepParams.Any(p => p.Name == "NumberOfIterations"))
                {
                    options.NumberOfIterations = DefaultNumIterations;
                }
            }
            return mlContext.BinaryClassification.Trainers.AveragedPerceptron(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            Dictionary<string, object> additionalProperties = null;

            // NOTE(review): `p.Name != "NumberOfIterations"` means "no params OTHER than
            // NumberOfIterations exist", which looks inverted relative to CreateInstance's
            // `p.Name == "NumberOfIterations"` check — confirm the intended condition.
            if (sweepParams == null || !sweepParams.Any(p => p.Name != "NumberOfIterations"))
            {
                additionalProperties = new Dictionary<string, object>()
                {
                    { "NumberOfIterations", DefaultNumIterations }
                };
            }

            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, additionalProperties: additionalProperties);
        }
    }

    /// <summary>Trainer extension for the fast forest binary classifier.</summary>
    internal class FastForestBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildFastForestParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<FastForestBinaryTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;
            return mlContext.BinaryClassification.Trainers.FastForest(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }

    /// <summary>Trainer extension for the fast tree binary classifier.</summary>
    internal class FastTreeBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildFastTreeParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<FastTreeBinaryTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;
            return mlContext.BinaryClassification.Trainers.FastTree(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }

    /// <summary>Trainer extension for the LightGBM binary classifier.</summary>
    internal class LightGbmBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLightGbmParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            LightGbmBinaryTrainer.Options options = TrainerExtensionUtil.CreateLightGbmOptions<LightGbmBinaryTrainer.Options, float,
                BinaryPredictionTransformer<CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>,
                CalibratedModelParametersBase<LightGbmBinaryModelParameters, PlattCalibrator>>(sweepParams, columnInfo);
            return mlContext.BinaryClassification.Trainers.LightGbm(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildLightGbmPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }

    /// <summary>Trainer extension for the linear SVM binary classifier.</summary>
    internal class LinearSvmBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLinearSvmParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<LinearSvmTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            return mlContext.BinaryClassification.Trainers.LinearSvm(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName);
        }
    }

    /// <summary>Trainer extension for SDCA logistic regression.</summary>
    internal class SdcaLogisticRegressionBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSdcaParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<SdcaLogisticRegressionBinaryTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            return mlContext.BinaryClassification.Trainers.SdcaLogisticRegression(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName);
        }
    }

    /// <summary>Trainer extension for L-BFGS logistic regression.</summary>
    internal class LbfgsLogisticRegressionBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLbfgsLogisticRegressionParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<LbfgsLogisticRegressionBinaryTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;
            return mlContext.BinaryClassification.Trainers.LbfgsLogisticRegression(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }

    /// <summary>Trainer extension for calibrated SGD binary classification.</summary>
    internal class SgdCalibratedBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSgdParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<SgdCalibratedTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;
            return mlContext.BinaryClassification.Trainers.SgdCalibrated(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }

    /// <summary>Trainer extension for symbolic SGD logistic regression.</summary>
    internal class SymbolicSgdLogisticRegressionBinaryExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSymSgdLogisticRegressionParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<SymbolicSgdLogisticRegressionBinaryTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            return mlContext.BinaryClassification.Trainers.SymbolicSgdLogisticRegression(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName);
        }
    }
}
diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/ITrainerExtension.cs b/src/Microsoft.ML.Auto/TrainerExtensions/ITrainerExtension.cs
new file mode 100644
index 0000000000..790825f8ec
--- /dev/null
+++ b/src/Microsoft.ML.Auto/TrainerExtensions/ITrainerExtension.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.Auto
{
    using ITrainerEstimator = ITrainerEstimator<ISingleFeaturePredictionTransformer<object>, object>;

    /// <summary>
    /// Common surface for AutoML trainer wrappers: exposes sweepable hyperparameter ranges,
    /// creates a configured trainer instance, and emits the corresponding pipeline node.
    /// </summary>
    internal interface ITrainerExtension
    {
        IEnumerable<SweepableParam> GetHyperparamSweepRanges();

        ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo);

        PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo);
    }
}
diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/MultiTrainerExtensions.cs b/src/Microsoft.ML.Auto/TrainerExtensions/MultiTrainerExtensions.cs
new file mode 100644
index 0000000000..8d8bdcf728
--- /dev/null
+++ b/src/Microsoft.ML.Auto/TrainerExtensions/MultiTrainerExtensions.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
using Microsoft.ML.Trainers.FastTree;
using Microsoft.ML.Trainers.LightGbm;

namespace Microsoft.ML.Auto
{
    // NOTE(review): generic type arguments in this file were stripped by extraction and have been
    // reconstructed from the ML.NET trainer APIs — verify against TrainerExtensionUtil's signatures.
    using ITrainerEstimator = ITrainerEstimator<ISingleFeaturePredictionTransformer<object>, object>;

    /// <summary>One-versus-all multiclass wrapper over the averaged perceptron binary trainer.</summary>
    internal class AveragedPerceptronOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new AveragedPerceptronBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildAveragePerceptronParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as AveragedPerceptronTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>One-versus-all multiclass wrapper over the fast forest binary trainer.</summary>
    internal class FastForestOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new FastForestBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildFastForestParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as FastForestBinaryTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>Trainer extension for the native LightGBM multiclass trainer.</summary>
    internal class LightGbmMultiExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLightGbmParamsMulticlass();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            LightGbmMulticlassTrainer.Options options = TrainerExtensionUtil.CreateLightGbmOptions<LightGbmMulticlassTrainer.Options,
                VBuffer<float>, MulticlassPredictionTransformer<OneVersusAllModelParameters>, OneVersusAllModelParameters>(sweepParams, columnInfo);
            return mlContext.MulticlassClassification.Trainers.LightGbm(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildLightGbmPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }

    /// <summary>One-versus-all multiclass wrapper over the linear SVM binary trainer.</summary>
    internal class LinearSvmOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new LinearSvmBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLinearSvmParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as LinearSvmTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>Trainer extension for SDCA maximum-entropy multiclass classification.</summary>
    internal class SdcaMaximumEntropyMultiExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSdcaParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<SdcaMaximumEntropyMulticlassTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            return mlContext.MulticlassClassification.Trainers.SdcaMaximumEntropy(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName);
        }
    }

    /// <summary>One-versus-all multiclass wrapper over L-BFGS logistic regression.</summary>
    internal class LbfgsLogisticRegressionOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new LbfgsLogisticRegressionBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLbfgsLogisticRegressionParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as LbfgsLogisticRegressionBinaryTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>One-versus-all multiclass wrapper over calibrated SGD.</summary>
    internal class SgdCalibratedOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new SgdCalibratedBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildSgdParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as SgdCalibratedTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>One-versus-all multiclass wrapper over symbolic SGD logistic regression.</summary>
    internal class SymbolicSgdLogisticRegressionOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new SymbolicSgdLogisticRegressionBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return _binaryLearnerCatalogItem.GetHyperparamSweepRanges();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as SymbolicSgdLogisticRegressionBinaryTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>One-versus-all multiclass wrapper over the fast tree binary trainer.</summary>
    internal class FastTreeOvaExtension : ITrainerExtension
    {
        private static readonly ITrainerExtension _binaryLearnerCatalogItem = new FastTreeBinaryExtension();

        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildFastTreeParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as FastTreeBinaryTrainer;
            return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumnName: columnInfo.LabelColumnName);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildOvaPipelineNode(this, _binaryLearnerCatalogItem, sweepParams, columnInfo);
        }
    }

    /// <summary>Trainer extension for L-BFGS maximum-entropy multiclass classification.</summary>
    internal class LbfgsMaximumEntropyMultiExtension : ITrainerExtension
    {
        public IEnumerable<SweepableParam> GetHyperparamSweepRanges()
        {
            return SweepableParams.BuildLbfgsLogisticRegressionParams();
        }

        public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<SweepableParam> sweepParams,
            ColumnInformation columnInfo)
        {
            var options = TrainerExtensionUtil.CreateOptions<LbfgsMaximumEntropyMulticlassTrainer.Options>(sweepParams, columnInfo.LabelColumnName);
            options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName;
            return mlContext.MulticlassClassification.Trainers.LbfgsMaximumEntropy(options);
        }

        public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
        {
            return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams,
                columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName);
        }
    }
}
diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/RegressionTrainerExtensions.cs b/src/Microsoft.ML.Auto/TrainerExtensions/RegressionTrainerExtensions.cs
new file mode 100644
index 0000000000..5995e448a2
--- /dev/null
+++ b/src/Microsoft.ML.Auto/TrainerExtensions/RegressionTrainerExtensions.cs
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+ +using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.FastTree; +using Microsoft.ML.Trainers.LightGbm; + +namespace Microsoft.ML.Auto +{ + using ITrainerEstimator = ITrainerEstimator, object>; + + internal class FastForestRegressionExtension : ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildFastForestParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName; + return mlContext.Regression.Trainers.FastForest(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName); + } + } + + internal class FastTreeRegressionExtension : ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildFastTreeParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName; + return mlContext.Regression.Trainers.FastTree(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName); + } + } + + internal class FastTreeTweedieRegressionExtension : ITrainerExtension + { + public IEnumerable 
GetHyperparamSweepRanges() + { + return SweepableParams.BuildFastTreeTweedieParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName; + return mlContext.Regression.Trainers.FastTreeTweedie(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName); + } + } + + internal class LightGbmRegressionExtension : ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildLightGbmParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + LightGbmRegressionTrainer.Options options = TrainerExtensionUtil.CreateLightGbmOptions, LightGbmRegressionModelParameters>(sweepParams, columnInfo); + return mlContext.Regression.Trainers.LightGbm(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildLightGbmPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName); + } + } + + internal class OnlineGradientDescentRegressionExtension : ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildOnlineGradientDescentParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + return 
mlContext.Regression.Trainers.OnlineGradientDescent(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName); + } + } + + internal class OlsRegressionExtension : ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildOlsParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName; + return mlContext.Regression.Trainers.Ols(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName); + } + } + + internal class LbfgsPoissonRegressionExtension : ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildLbfgsPoissonRegressionParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName; + return mlContext.Regression.Trainers.LbfgsPoissonRegression(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName, columnInfo.ExampleWeightColumnName); + } + } + + internal class SdcaRegressionExtension : 
ITrainerExtension + { + public IEnumerable GetHyperparamSweepRanges() + { + return SweepableParams.BuildSdcaParams(); + } + + public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable sweepParams, + ColumnInformation columnInfo) + { + var options = TrainerExtensionUtil.CreateOptions(sweepParams, columnInfo.LabelColumnName); + return mlContext.Regression.Trainers.Sdca(options); + } + + public PipelineNode CreatePipelineNode(IEnumerable sweepParams, ColumnInformation columnInfo) + { + return TrainerExtensionUtil.BuildPipelineNode(TrainerExtensionCatalog.GetTrainerName(this), sweepParams, + columnInfo.LabelColumnName); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/SweepableParams.cs b/src/Microsoft.ML.Auto/TrainerExtensions/SweepableParams.cs new file mode 100644 index 0000000000..9d80ebe09a --- /dev/null +++ b/src/Microsoft.ML.Auto/TrainerExtensions/SweepableParams.cs @@ -0,0 +1,176 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + internal static class SweepableParams + { + private static IEnumerable BuildAveragedLinearArgsParams() + { + return new SweepableParam[] + { + new SweepableDiscreteParam("LearningRate", new object[] { 0.01f, 0.1f, 0.5f, 1.0f}), + new SweepableDiscreteParam("DecreaseLearningRate", new object[] { false, true }), + new SweepableFloatParam("L2Regularization", 0.0f, 0.4f), + }; + } + + private static IEnumerable BuildOnlineLinearArgsParams() + { + return new SweepableParam[] + { + new SweepableLongParam("NumberOfIterations", 1, 100, stepSize: 10, isLogScale: true), + new SweepableFloatParam("InitialWeightsDiameter", 0.0f, 1.0f, numSteps: 5), + new SweepableDiscreteParam("Shuffle", new object[] { false, true }), + }; + } + + private static IEnumerable BuildTreeArgsParams() + { + return new SweepableParam[] + { + new SweepableLongParam("NumberOfLeaves", 2, 128, isLogScale: true, stepSize: 4), + new SweepableDiscreteParam("MinimumExampleCountPerLeaf", new object[] { 1, 10, 50 }), + new SweepableDiscreteParam("NumberOfTrees", new object[] { 20, 100, 500 }), + }; + } + + private static IEnumerable BuildBoostedTreeArgsParams() + { + return BuildTreeArgsParams().Concat(new List() + { + new SweepableFloatParam("LearningRate", 0.025f, 0.4f, isLogScale: true), + new SweepableFloatParam("Shrinkage", 0.025f, 4f, isLogScale: true), + }); + } + + private static IEnumerable BuildLbfgsArgsParams() + { + return new SweepableParam[] { + new SweepableFloatParam("L2Regularization", 0.0f, 1.0f, numSteps: 4), + new SweepableFloatParam("L1Regularization", 0.0f, 1.0f, numSteps: 4), + new SweepableDiscreteParam("OptmizationTolerance", new object[] { 1e-4f, 1e-7f }), + new SweepableDiscreteParam("HistorySize", new object[] { 5, 20, 50 }), + new SweepableLongParam("MaximumNumberOfIterations", 1, int.MaxValue), + new SweepableFloatParam("InitialWeightsDiameter", 0.0f, 1.0f, numSteps: 5), + new 
SweepableDiscreteParam("DenseOptimizer", new object[] { false, true }), + }; + } + + public static IEnumerable BuildAveragePerceptronParams() + { + return BuildAveragedLinearArgsParams().Concat(BuildOnlineLinearArgsParams()); + } + + public static IEnumerable BuildFastForestParams() + { + return BuildTreeArgsParams(); + } + + public static IEnumerable BuildFastTreeParams() + { + return BuildBoostedTreeArgsParams(); + } + + public static IEnumerable BuildFastTreeTweedieParams() + { + return BuildBoostedTreeArgsParams(); + } + + public static IEnumerable BuildLightGbmParamsMulticlass() + { + return BuildLightGbmParams().Union(new SweepableParam[] + { + new SweepableDiscreteParam("UseSoftmax", new object[] { true, false }), + }); + } + + public static IEnumerable BuildLightGbmParams() + { + return new SweepableParam[] + { + new SweepableDiscreteParam("NumberOfIterations", new object[] { 10, 20, 50, 100, 150, 200 }), + new SweepableFloatParam("LearningRate", 0.025f, 0.4f, isLogScale: true), + new SweepableLongParam("NumberOfLeaves", 2, 128, isLogScale: true, stepSize: 4), + new SweepableDiscreteParam("MinimumExampleCountPerLeaf", new object[] { 1, 10, 20, 50 }), + new SweepableDiscreteParam("UseCategoricalSplit", new object[] { true, false }), + new SweepableDiscreteParam("HandleMissingValue", new object[] { true, false }), + new SweepableDiscreteParam("MinimumExampleCountPerGroup", new object[] { 10, 50, 100, 200 }), + new SweepableDiscreteParam("MaximumCategoricalSplitPointCount", new object[] { 8, 16, 32, 64 }), + new SweepableDiscreteParam("CategoricalSmoothing", new object[] { 1, 10, 20 }), + new SweepableDiscreteParam("L2CategoricalRegularization", new object[] { 0.1, 0.5, 1, 5, 10 }), + + // Booster params + new SweepableDiscreteParam("L2Regularization", new object[] { 0f, 0.5f, 1f }), + new SweepableDiscreteParam("L1Regularization", new object[] { 0f, 0.5f, 1f }) + }; + } + + public static IEnumerable BuildLinearSvmParams() + { + return new SweepableParam[] { + 
new SweepableFloatParam("Lambda", 0.00001f, 0.1f, 10, isLogScale: true), + new SweepableDiscreteParam("PerformProjection", null, isBool: true), + new SweepableDiscreteParam("NoBias", null, isBool: true) + }.Concat(BuildOnlineLinearArgsParams()); + } + + public static IEnumerable BuildLbfgsLogisticRegressionParams() + { + return BuildLbfgsArgsParams(); + } + + public static IEnumerable BuildOnlineGradientDescentParams() + { + return BuildAveragedLinearArgsParams(); + } + + public static IEnumerable BuildLbfgsPoissonRegressionParams() + { + return BuildLbfgsArgsParams(); + } + + public static IEnumerable BuildSdcaParams() + { + return new SweepableParam[] { + new SweepableDiscreteParam("L2Regularization", new object[] { "", 1e-7f, 1e-6f, 1e-5f, 1e-4f, 1e-3f, 1e-2f }), + new SweepableDiscreteParam("L1Regularization", new object[] { "", 0f, 0.25f, 0.5f, 0.75f, 1f }), + new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 0.001f, 0.01f, 0.1f, 0.2f }), + new SweepableDiscreteParam("MaximumNumberOfIterations", new object[] { "", 10, 20, 100 }), + new SweepableDiscreteParam("Shuffle", null, isBool: true), + new SweepableDiscreteParam("BiasLearningRate", new object[] { 0.0f, 0.01f, 0.1f, 1f }) + }; + } + + public static IEnumerable BuildOlsParams() + { + return new SweepableParam[] { + new SweepableDiscreteParam("L2Regularization", new object[] { 1e-6f, 0.1f, 1f }) + }; + } + + public static IEnumerable BuildSgdParams() + { + return new SweepableParam[] { + new SweepableDiscreteParam("L2Regularization", new object[] { 1e-7f, 5e-7f, 1e-6f, 5e-6f, 1e-5f }), + new SweepableDiscreteParam("ConvergenceTolerance", new object[] { 1e-2f, 1e-3f, 1e-4f, 1e-5f }), + new SweepableDiscreteParam("NumberOfIterations", new object[] { 1, 5, 10, 20 }), + new SweepableDiscreteParam("Shuffle", null, isBool: true), + }; + } + + public static IEnumerable BuildSymSgdLogisticRegressionParams() + { + return new SweepableParam[] { + new SweepableDiscreteParam("NumberOfIterations", new 
object[] { 1, 5, 10, 20, 30, 40, 50 }), + new SweepableDiscreteParam("LearningRate", new object[] { "", 1e1f, 1e0f, 1e-1f, 1e-2f, 1e-3f }), + new SweepableDiscreteParam("L2Regularization", new object[] { 0.0f, 1e-5f, 1e-5f, 1e-6f, 1e-7f }), + new SweepableDiscreteParam("UpdateFrequency", new object[] { "", 5, 20 }) + }; + } + } +} diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionCatalog.cs b/src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionCatalog.cs new file mode 100644 index 0000000000..13301f1e78 --- /dev/null +++ b/src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionCatalog.cs @@ -0,0 +1,138 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + internal class TrainerExtensionCatalog + { + private static readonly IDictionary _trainerNamesToExtensionTypes = + new Dictionary() + { + { TrainerName.AveragedPerceptronBinary, typeof(AveragedPerceptronBinaryExtension) }, + { TrainerName.AveragedPerceptronOva, typeof(AveragedPerceptronOvaExtension) }, + { TrainerName.FastForestBinary, typeof(FastForestBinaryExtension) }, + { TrainerName.FastForestOva, typeof(FastForestOvaExtension) }, + { TrainerName.FastForestRegression, typeof(FastForestRegressionExtension) }, + { TrainerName.FastTreeBinary, typeof(FastTreeBinaryExtension) }, + { TrainerName.FastTreeOva, typeof(FastTreeOvaExtension) }, + { TrainerName.FastTreeRegression, typeof(FastTreeRegressionExtension) }, + { TrainerName.FastTreeTweedieRegression, typeof(FastTreeTweedieRegressionExtension) }, + { TrainerName.LightGbmBinary, typeof(LightGbmBinaryExtension) }, + { TrainerName.LightGbmMulti, typeof(LightGbmMultiExtension) }, + { TrainerName.LightGbmRegression, typeof(LightGbmRegressionExtension) }, + { 
TrainerName.LinearSvmBinary, typeof(LinearSvmBinaryExtension) }, + { TrainerName.LinearSvmOva, typeof(LinearSvmOvaExtension) }, + { TrainerName.LbfgsLogisticRegressionBinary, typeof(LbfgsLogisticRegressionBinaryExtension) }, + { TrainerName.LbfgsMaximumEntropyMulti, typeof(LbfgsMaximumEntropyMultiExtension) }, + { TrainerName.LbfgsLogisticRegressionOva, typeof(LbfgsLogisticRegressionOvaExtension) }, + { TrainerName.OnlineGradientDescentRegression, typeof(OnlineGradientDescentRegressionExtension) }, + { TrainerName.OlsRegression, typeof(OlsRegressionExtension) }, + { TrainerName.LbfgsPoissonRegression, typeof(LbfgsPoissonRegressionExtension) }, + { TrainerName.SdcaLogisticRegressionBinary, typeof(SdcaLogisticRegressionBinaryExtension) }, + { TrainerName.SdcaMaximumEntropyMulti, typeof(SdcaMaximumEntropyMultiExtension) }, + { TrainerName.SdcaRegression, typeof(SdcaRegressionExtension) }, + { TrainerName.SgdCalibratedBinary, typeof(SgdCalibratedBinaryExtension) }, + { TrainerName.SgdCalibratedOva, typeof(SgdCalibratedOvaExtension) }, + { TrainerName.SymbolicSgdLogisticRegressionBinary, typeof(SymbolicSgdLogisticRegressionBinaryExtension) }, + { TrainerName.SymbolicSgdLogisticRegressionOva, typeof(SymbolicSgdLogisticRegressionOvaExtension) } + }; + + private static readonly IDictionary _extensionTypesToTrainerNames = + _trainerNamesToExtensionTypes.ToDictionary(kvp => kvp.Value, kvp => kvp.Key); + + public static TrainerName GetTrainerName(ITrainerExtension trainerExtension) + { + return _extensionTypesToTrainerNames[trainerExtension.GetType()]; + } + + public static ITrainerExtension GetTrainerExtension(TrainerName trainerName) + { + var trainerExtensionType = _trainerNamesToExtensionTypes[trainerName]; + return (ITrainerExtension)Activator.CreateInstance(trainerExtensionType); + } + + public static IEnumerable GetTrainers(TaskKind task, + IEnumerable whitelist) + { + IEnumerable trainers; + if (task == TaskKind.BinaryClassification) + { + trainers = 
GetBinaryLearners(); + } + else if (task == TaskKind.MulticlassClassification) + { + trainers = GetMultiLearners(); + } + else if (task == TaskKind.Regression) + { + trainers = GetRegressionLearners(); + } + else + { + // should not be possible to reach here + throw new NotSupportedException($"unsupported machine learning task type {task}"); + } + + if (whitelist != null) + { + whitelist = new HashSet(whitelist); + trainers = trainers.Where(t => whitelist.Contains(GetTrainerName(t))); + } + + return trainers; + } + + private static IEnumerable GetBinaryLearners() + { + return new ITrainerExtension[] + { + new AveragedPerceptronBinaryExtension(), + new SdcaLogisticRegressionBinaryExtension(), + new LightGbmBinaryExtension(), + new SymbolicSgdLogisticRegressionBinaryExtension(), + new LinearSvmBinaryExtension(), + new FastTreeBinaryExtension(), + new LbfgsLogisticRegressionBinaryExtension(), + new FastForestBinaryExtension(), + new SgdCalibratedBinaryExtension() + }; + } + + private static IEnumerable GetMultiLearners() + { + return new ITrainerExtension[] + { + new AveragedPerceptronOvaExtension(), + new SdcaMaximumEntropyMultiExtension(), + new LightGbmMultiExtension(), + new SymbolicSgdLogisticRegressionOvaExtension(), + new FastTreeOvaExtension(), + new LinearSvmOvaExtension(), + new LbfgsLogisticRegressionOvaExtension(), + new SgdCalibratedOvaExtension(), + new FastForestOvaExtension(), + new LbfgsMaximumEntropyMultiExtension() + }; + } + + private static IEnumerable GetRegressionLearners() + { + return new ITrainerExtension[] + { + new SdcaRegressionExtension(), + new LightGbmRegressionExtension(), + new FastTreeRegressionExtension(), + new FastTreeTweedieRegressionExtension(), + new FastForestRegressionExtension(), + new LbfgsPoissonRegressionExtension(), + new OnlineGradientDescentRegressionExtension(), + new OlsRegressionExtension(), + }; + } + } +} diff --git a/src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionUtil.cs 
b/src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionUtil.cs new file mode 100644 index 0000000000..213d555545 --- /dev/null +++ b/src/Microsoft.ML.Auto/TrainerExtensions/TrainerExtensionUtil.cs @@ -0,0 +1,381 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.ML.Calibrators; +using Microsoft.ML.Trainers; +using Microsoft.ML.Trainers.LightGbm; + +namespace Microsoft.ML.Auto +{ + internal enum TrainerName + { + AveragedPerceptronBinary, + AveragedPerceptronOva, + FastForestBinary, + FastForestOva, + FastForestRegression, + FastTreeBinary, + FastTreeOva, + FastTreeRegression, + FastTreeTweedieRegression, + LightGbmBinary, + LightGbmMulti, + LightGbmRegression, + LinearSvmBinary, + LinearSvmOva, + LbfgsLogisticRegressionBinary, + LbfgsLogisticRegressionOva, + LbfgsMaximumEntropyMulti, + OnlineGradientDescentRegression, + OlsRegression, + Ova, + LbfgsPoissonRegression, + SdcaLogisticRegressionBinary, + SdcaMaximumEntropyMulti, + SdcaRegression, + SgdCalibratedBinary, + SgdCalibratedOva, + SymbolicSgdLogisticRegressionBinary, + SymbolicSgdLogisticRegressionOva + } + + internal static class TrainerExtensionUtil + { + private const string WeightColumn = "ExampleWeightColumnName"; + private const string LabelColumn = "LabelColumnName"; + + public static T CreateOptions(IEnumerable sweepParams, string labelColumn) where T : TrainerInputBaseWithLabel + { + var options = Activator.CreateInstance(); + options.LabelColumnName = labelColumn; + if (sweepParams != null) + { + UpdateFields(options, sweepParams); + } + return options; + } + + private static string[] _lightGbmBoosterParamNames = new[] { "L2Regularization", "L1Regularization" }; + private const string LightGbmBoosterPropName = "Booster"; + + 
public static TOptions CreateLightGbmOptions(IEnumerable sweepParams, ColumnInformation columnInfo) + where TOptions : LightGbmTrainerBase.OptionsBase, new() + where TTransformer : ISingleFeaturePredictionTransformer + where TModel : class + { + var options = new TOptions(); + options.LabelColumnName = columnInfo.LabelColumnName; + options.ExampleWeightColumnName = columnInfo.ExampleWeightColumnName; + options.Booster = new GradientBooster.Options(); + if (sweepParams != null) + { + var boosterParams = sweepParams.Where(p => _lightGbmBoosterParamNames.Contains(p.Name)); + var parentArgParams = sweepParams.Except(boosterParams); + UpdateFields(options, parentArgParams); + UpdateFields(options.Booster, boosterParams); + } + return options; + } + + public static PipelineNode BuildOvaPipelineNode(ITrainerExtension multiExtension, ITrainerExtension binaryExtension, + IEnumerable sweepParams, ColumnInformation columnInfo) + { + var ovaNode = new PipelineNode() + { + Name = TrainerName.Ova.ToString(), + NodeType = PipelineNodeType.Trainer, + Properties = new Dictionary() + { + { LabelColumn, columnInfo.LabelColumnName } + } + }; + var binaryNode = binaryExtension.CreatePipelineNode(sweepParams, columnInfo); + ovaNode.Properties["BinaryTrainer"] = binaryNode; + return ovaNode; + } + + public static PipelineNode BuildPipelineNode(TrainerName trainerName, IEnumerable sweepParams, + string labelColumn, string weightColumn = null, IDictionary additionalProperties = null) + { + var properties = BuildBasePipelineNodeProps(sweepParams, labelColumn, weightColumn); + + if (additionalProperties != null) + { + foreach (var property in additionalProperties) + { + properties[property.Key] = property.Value; + } + } + + return new PipelineNode(trainerName.ToString(), PipelineNodeType.Trainer, DefaultColumnNames.Features, + DefaultColumnNames.Score, properties); + } + + public static PipelineNode BuildLightGbmPipelineNode(TrainerName trainerName, IEnumerable sweepParams, + string 
labelColumn, string weightColumn) + { + return new PipelineNode(trainerName.ToString(), PipelineNodeType.Trainer, DefaultColumnNames.Features, + DefaultColumnNames.Score, BuildLightGbmPipelineNodeProps(sweepParams, labelColumn, weightColumn)); + } + + private static IDictionary BuildBasePipelineNodeProps(IEnumerable sweepParams, + string labelColumn, string weightColumn) + { + var props = new Dictionary(); + if (sweepParams != null) + { + foreach (var sweepParam in sweepParams) + { + props[sweepParam.Name] = sweepParam.ProcessedValue(); + } + } + props[LabelColumn] = labelColumn; + if (weightColumn != null) + { + props[WeightColumn] = weightColumn; + } + return props; + } + + private static IDictionary BuildLightGbmPipelineNodeProps(IEnumerable sweepParams, + string labelColumn, string weightColumn) + { + Dictionary props = null; + if (sweepParams == null || !sweepParams.Any()) + { + props = new Dictionary(); + } + else + { + var boosterParams = sweepParams.Where(p => _lightGbmBoosterParamNames.Contains(p.Name)); + var parentArgParams = sweepParams.Except(boosterParams); + + var boosterProps = boosterParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue()); + var boosterCustomProp = new CustomProperty("GradientBooster.Options", boosterProps); + + props = parentArgParams.ToDictionary(p => p.Name, p => (object)p.ProcessedValue()); + props[LightGbmBoosterPropName] = boosterCustomProp; + } + + props[LabelColumn] = labelColumn; + if (weightColumn != null) + { + props[WeightColumn] = weightColumn; + } + + return props; + } + + public static ParameterSet BuildParameterSet(TrainerName trainerName, IDictionary props) + { + props = props.Where(p => p.Key != LabelColumn && p.Key != WeightColumn) + .ToDictionary(kvp => kvp.Key, kvp => kvp.Value); + + if (trainerName == TrainerName.LightGbmBinary || trainerName == TrainerName.LightGbmMulti || + trainerName == TrainerName.LightGbmRegression) + { + return BuildLightGbmParameterSet(props); + } + + var paramVals = 
props.Select(p => new StringParameterValue(p.Key, p.Value.ToString())); + return new ParameterSet(paramVals); + } + + public static ColumnInformation BuildColumnInfo(IDictionary props) + { + var columnInfo = new ColumnInformation(); + + columnInfo.LabelColumnName = props[LabelColumn] as string; + + props.TryGetValue(WeightColumn, out var weightColumn); + columnInfo.ExampleWeightColumnName = weightColumn as string; + + return columnInfo; + } + + private static ParameterSet BuildLightGbmParameterSet(IDictionary props) + { + IEnumerable parameters; + if (props == null || !props.Any()) + { + parameters = new List(); + } + else + { + var parentProps = props.Where(p => p.Key != LightGbmBoosterPropName); + var treeProps = ((CustomProperty)props[LightGbmBoosterPropName]).Properties; + var allProps = parentProps.Union(treeProps); + parameters = allProps.Select(p => new StringParameterValue(p.Key, p.Value.ToString())); + } + return new ParameterSet(parameters); + } + + private static void SetValue(FieldInfo fi, IComparable value, object obj, Type propertyType) + { + if (propertyType == value?.GetType()) + fi.SetValue(obj, value); + else if (propertyType == typeof(double) && value is float) + fi.SetValue(obj, Convert.ToDouble(value)); + else if (propertyType == typeof(int) && value is long) + fi.SetValue(obj, Convert.ToInt32(value)); + else if (propertyType == typeof(long) && value is int) + fi.SetValue(obj, Convert.ToInt64(value)); + } + + /// + /// Updates properties of object instance based on the values in sweepParams + /// + public static void UpdateFields(object obj, IEnumerable sweepParams) + { + foreach (var param in sweepParams) + { + try + { + // Only updates property if param.value isn't null and + // param has a name of property. + if (param.RawValue == null) + { + continue; + } + var fi = obj.GetType().GetField(param.Name); + var propType = Nullable.GetUnderlyingType(fi.FieldType) ?? 
fi.FieldType; + + if (param is SweepableDiscreteParam dp) + { + var optIndex = (int)dp.RawValue; + //Contracts.Assert(0 <= optIndex && optIndex < dp.Options.Length, $"Options index out of range: {optIndex}"); + var option = dp.Options[optIndex].ToString().ToLower(); + + // Handle string values in sweep params + if (option == "auto" || option == "" || option == "< auto >") + { + //Check if nullable type, in which case 'null' is the auto value. + if (Nullable.GetUnderlyingType(fi.FieldType) != null) + fi.SetValue(obj, null); + else if (fi.FieldType.IsEnum) + { + // Check if there is an enum option named Auto + var enumDict = fi.FieldType.GetEnumValues().Cast() + .ToDictionary(v => Enum.GetName(fi.FieldType, v), v => v); + if (enumDict.ContainsKey("Auto")) + fi.SetValue(obj, enumDict["Auto"]); + } + } + else + SetValue(fi, (IComparable)dp.Options[optIndex], obj, propType); + } + else + SetValue(fi, param.RawValue, obj, propType); + } + catch (Exception) + { + throw new InvalidOperationException($"Cannot set parameter {param.Name} for {obj.GetType()}"); + } + } + } + + public static TrainerName GetTrainerName(BinaryClassificationTrainer binaryTrainer) + { + switch (binaryTrainer) + { + case BinaryClassificationTrainer.AveragedPerceptron: + return TrainerName.AveragedPerceptronBinary; + case BinaryClassificationTrainer.FastForest: + return TrainerName.FastForestBinary; + case BinaryClassificationTrainer.FastTree: + return TrainerName.FastTreeBinary; + case BinaryClassificationTrainer.LightGbm: + return TrainerName.LightGbmBinary; + case BinaryClassificationTrainer.LinearSupportVectorMachines: + return TrainerName.LinearSvmBinary; + case BinaryClassificationTrainer.LbfgsLogisticRegression: + return TrainerName.LbfgsLogisticRegressionBinary; + case BinaryClassificationTrainer.SdcaLogisticRegression: + return TrainerName.SdcaLogisticRegressionBinary; + case BinaryClassificationTrainer.SgdCalibrated: + return TrainerName.SgdCalibratedBinary; + case 
BinaryClassificationTrainer.SymbolicSgdLogisticRegression: + return TrainerName.SymbolicSgdLogisticRegressionBinary; + } + + // never expected to reach here + throw new NotSupportedException($"{binaryTrainer} not supported"); + } + + public static TrainerName GetTrainerName(MulticlassClassificationTrainer multiTrainer) + { + switch (multiTrainer) + { + case MulticlassClassificationTrainer.AveragedPerceptronOVA: + return TrainerName.AveragedPerceptronOva; + case MulticlassClassificationTrainer.FastForestOVA: + return TrainerName.FastForestOva; + case MulticlassClassificationTrainer.FastTreeOVA: + return TrainerName.FastTreeOva; + case MulticlassClassificationTrainer.LightGbm: + return TrainerName.LightGbmMulti; + case MulticlassClassificationTrainer.LinearSupportVectorMachinesOVA: + return TrainerName.LinearSvmOva; + case MulticlassClassificationTrainer.LbfgsMaximumEntropy: + return TrainerName.LbfgsMaximumEntropyMulti; + case MulticlassClassificationTrainer.LbfgsLogisticRegressionOVA: + return TrainerName.LbfgsLogisticRegressionOva; + case MulticlassClassificationTrainer.SdcaMaximumEntropy: + return TrainerName.SdcaMaximumEntropyMulti; + case MulticlassClassificationTrainer.SgdCalibratedOVA: + return TrainerName.SgdCalibratedOva; + case MulticlassClassificationTrainer.SymbolicSgdLogisticRegressionOVA: + return TrainerName.SymbolicSgdLogisticRegressionOva; + } + + // never expected to reach here + throw new NotSupportedException($"{multiTrainer} not supported"); + } + + public static TrainerName GetTrainerName(RegressionTrainer regressionTrainer) + { + switch (regressionTrainer) + { + case RegressionTrainer.FastForest: + return TrainerName.FastForestRegression; + case RegressionTrainer.FastTree: + return TrainerName.FastTreeRegression; + case RegressionTrainer.FastTreeTweedie: + return TrainerName.FastTreeTweedieRegression; + case RegressionTrainer.LightGbm: + return TrainerName.LightGbmRegression; + case RegressionTrainer.OnlineGradientDescent: + return 
TrainerName.OnlineGradientDescentRegression; + case RegressionTrainer.Ols: + return TrainerName.OlsRegression; + case RegressionTrainer.LbfgsPoissonRegression: + return TrainerName.LbfgsPoissonRegression; + case RegressionTrainer.StochasticDualCoordinateAscent: + return TrainerName.SdcaRegression; + } + + // never expected to reach here + throw new NotSupportedException($"{regressionTrainer} not supported"); + } + + public static IEnumerable GetTrainerNames(IEnumerable binaryTrainers) + { + return binaryTrainers?.Select(t => GetTrainerName(t)); + } + + public static IEnumerable GetTrainerNames(IEnumerable multiTrainers) + { + return multiTrainers?.Select(t => GetTrainerName(t)); + } + + public static IEnumerable GetTrainerNames(IEnumerable regressionTrainers) + { + return regressionTrainers?.Select(t => GetTrainerName(t)); + } + } +} diff --git a/src/Microsoft.ML.Auto/TransformInference/TransformInference.cs b/src/Microsoft.ML.Auto/TransformInference/TransformInference.cs new file mode 100644 index 0000000000..73c86b029a --- /dev/null +++ b/src/Microsoft.ML.Auto/TransformInference/TransformInference.cs @@ -0,0 +1,419 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Pairing of an inferred pipeline step: the serializable description
    /// (<see cref="PipelineNode"/>) together with the runnable estimator built from it.
    /// </summary>
    internal class SuggestedTransform
    {
        public readonly IEstimator<ITransformer> Estimator;
        public readonly PipelineNode PipelineNode;

        public SuggestedTransform(PipelineNode pipelineNode, IEstimator<ITransformer> estimator)
        {
            PipelineNode = pipelineNode;
            Estimator = estimator;
        }

        /// <summary>
        /// Shallow copy: the clone shares the same <see cref="PipelineNode"/> and
        /// <see cref="Estimator"/> instances as this object.
        /// </summary>
        public SuggestedTransform Clone() => new SuggestedTransform(PipelineNode, Estimator);

        /// <summary>
        /// Renders the transform as "Name{ col=Out:In ... key=value ...}" for logging.
        /// </summary>
        public override string ToString()
        {
            var description = new StringBuilder();
            description.Append(PipelineNode.Name);
            description.Append("{");

            // With multiple outputs each mapping is listed pairwise; a single output
            // column is shown with all of its input columns joined by commas.
            if (PipelineNode.OutColumns.Length > 1)
            {
                for (var i = 0; i < PipelineNode.OutColumns.Length; i++)
                {
                    description.Append($" col={PipelineNode.OutColumns[i]}:{PipelineNode.InColumns[i]}");
                }
            }
            else
            {
                description.Append($" col={PipelineNode.OutColumns[0]}:{string.Join(",", PipelineNode.InColumns)}");
            }

            if (PipelineNode.Properties != null)
            {
                foreach (var property in PipelineNode.Properties)
                {
                    description.Append($" {property.Key}={property.Value}");
                }
            }

            description.Append("}");
            return description.ToString();
        }
    }

    /// <summary>
    /// Auto-generate set of transforms for the data view, given the purposes of specified columns.
    ///
    /// The design is the same as for purpose inference: there's a sequence of 'experts'
    /// that each look at all the columns. Every expert may or may not suggest additional transforms.
    /// If the expert needs some information about the column (for example, the column values),
    /// this information is lazily calculated by the column object, not the expert itself, to allow the reuse
    /// of the same information by another expert.
+ /// + internal static class TransformInference + { + internal class IntermediateColumn + { + public readonly string ColumnName; + public readonly DataViewType Type; + public readonly ColumnPurpose Purpose; + public readonly ColumnDimensions Dimensions; + + public IntermediateColumn(string name, DataViewType type, ColumnPurpose purpose, ColumnDimensions dimensions) + { + ColumnName = name; + Type = type; + Purpose = purpose; + Dimensions = dimensions; + } + } + + internal sealed class ColumnRoutingStructure : IEquatable + { + public struct AnnotatedName + { + public string Name { get; set; } + public bool IsNumeric { get; set; } + + public bool Equals(AnnotatedName an) + { + return an.Name == Name && + an.IsNumeric == IsNumeric; + } + + public override string ToString() => $"{Name}({IsNumeric})"; + } + + public AnnotatedName[] ColumnsConsumed { get; } + public AnnotatedName[] ColumnsProduced { get; } + + public ColumnRoutingStructure(AnnotatedName[] columnsConsumed, AnnotatedName[] columnsProduced) + { + ColumnsConsumed = columnsConsumed; + ColumnsProduced = columnsProduced; + } + + public bool Equals(ColumnRoutingStructure obj) + { + return obj != null && + obj.ColumnsConsumed.All(cc => ColumnsConsumed.Any(cc.Equals)) && + obj.ColumnsProduced.All(cp => ColumnsProduced.Any(cp.Equals)); + } + } + + internal interface ITransformInferenceExpert + { + IEnumerable Apply(IntermediateColumn[] columns, TaskKind task); + } + + public abstract class TransformInferenceExpertBase : ITransformInferenceExpert + { + public abstract IEnumerable Apply(IntermediateColumn[] columns, TaskKind task); + + protected readonly MLContext Context; + + public TransformInferenceExpertBase(MLContext context) + { + Context = context; + } + } + + private static IEnumerable GetExperts(MLContext context) + { + // The expert work independently of each other, the sequence is irrelevant + // (it only determines the sequence of resulting transforms). 
+ + // For multiclass tasks, convert label column to key + yield return new Experts.Label(context); + + // For boolean columns use convert transform + yield return new Experts.Boolean(context); + + // For categorical columns, use Cat transform. + yield return new Experts.Categorical(context); + + // For text columns, use TextTransform. + yield return new Experts.Text(context); + + // If numeric column has missing values, use Missing transform. + yield return new Experts.NumericMissing(context); + } + + internal static class Experts + { + internal sealed class Label : TransformInferenceExpertBase + { + public Label(MLContext context) : base(context) + { + } + + public override IEnumerable Apply(IntermediateColumn[] columns, TaskKind task) + { + if (task != TaskKind.MulticlassClassification) + { + yield break; + } + + var lastLabelColId = Array.FindLastIndex(columns, x => x.Purpose == ColumnPurpose.Label); + if (lastLabelColId < 0) + yield break; + + var col = columns[lastLabelColId]; + + if (!col.Type.IsKey()) + { + yield return ValueToKeyMappingExtension.CreateSuggestedTransform(Context, col.ColumnName, col.ColumnName); + } + } + } + + internal sealed class Categorical : TransformInferenceExpertBase + { + public Categorical(MLContext context) : base(context) + { + } + + public override IEnumerable Apply(IntermediateColumn[] columns, TaskKind task) + { + bool foundCat = false; + bool foundCatHash = false; + var catColumnsNew = new List(); + var catHashColumnsNew = new List(); + + foreach (var column in columns) + { + if (column.Purpose != ColumnPurpose.CategoricalFeature) + { + continue; + } + + if (column.Dimensions.Cardinality != null && column.Dimensions.Cardinality < 100) + { + foundCat = true; + catColumnsNew.Add(column.ColumnName); + } + else + { + foundCatHash = true; + catHashColumnsNew.Add(column.ColumnName); + } + } + + if (foundCat) + { + var catColumnsArr = catColumnsNew.ToArray(); + yield return OneHotEncodingExtension.CreateSuggestedTransform(Context, 
catColumnsArr, catColumnsArr); + } + + if (foundCatHash) + { + var catHashColumnsNewArr = catHashColumnsNew.ToArray(); + yield return OneHotHashEncodingExtension.CreateSuggestedTransform(Context, catHashColumnsNewArr, catHashColumnsNewArr); + } + + var transformedColumns = new List(); + transformedColumns.AddRange(catColumnsNew); + transformedColumns.AddRange(catHashColumnsNew); + } + } + + internal sealed class Boolean : TransformInferenceExpertBase + { + public Boolean(MLContext context) : base(context) + { + } + + public override IEnumerable Apply(IntermediateColumn[] columns, TaskKind task) + { + var newColumns = new List(); + + foreach (var column in columns) + { + if (!column.Type.GetItemType().IsBool() || column.Purpose != ColumnPurpose.NumericFeature) + { + continue; + } + + newColumns.Add(column.ColumnName); + } + + if (newColumns.Count() > 0) + { + var newColumnsArr = newColumns.ToArray(); + yield return TypeConvertingExtension.CreateSuggestedTransform(Context, newColumnsArr, newColumnsArr); + } + } + } + + internal sealed class Text : TransformInferenceExpertBase + { + public Text(MLContext context) : base(context) + { + } + + public override IEnumerable Apply(IntermediateColumn[] columns, TaskKind task) + { + var featureCols = new List(); + + foreach (var column in columns) + { + if (!column.Type.GetItemType().IsText() || column.Purpose != ColumnPurpose.TextFeature) + continue; + + var columnDestSuffix = "_tf"; + var columnNameSafe = column.ColumnName; + + string columnDestRenamed = $"{columnNameSafe}{columnDestSuffix}"; + + featureCols.Add(columnDestRenamed); + yield return TextFeaturizingExtension.CreateSuggestedTransform(Context, columnNameSafe, columnDestRenamed); + } + } + } + + internal sealed class NumericMissing : TransformInferenceExpertBase + { + public NumericMissing(MLContext context) : base(context) + { + } + + public override IEnumerable Apply(IntermediateColumn[] columns, TaskKind task) + { + var columnsWithMissing = new List(); + foreach 
(var column in columns) + { + if (column.Type.GetItemType() == NumberDataViewType.Single + && column.Purpose == ColumnPurpose.NumericFeature + && column.Dimensions.HasMissing == true) + { + columnsWithMissing.Add(column.ColumnName); + } + } + if (columnsWithMissing.Any()) + { + var columnsArr = columnsWithMissing.ToArray(); + var indicatorColNames = GetNewColumnNames(columnsArr.Select(c => $"{c}_MissingIndicator"), columns).ToArray(); + yield return MissingValueIndicatingExtension.CreateSuggestedTransform(Context, columnsArr, indicatorColNames); + yield return TypeConvertingExtension.CreateSuggestedTransform(Context, indicatorColNames, indicatorColNames); + yield return MissingValueReplacingExtension.CreateSuggestedTransform(Context, columnsArr, columnsArr); + } + } + } + } + + /// + /// Automatically infer transforms for the data view + /// + public static SuggestedTransform[] InferTransforms(MLContext context, TaskKind task, DatasetColumnInfo[] columns) + { + var intermediateCols = columns.Where(c => c.Purpose != ColumnPurpose.Ignore) + .Select(c => new IntermediateColumn(c.Name, c.Type, c.Purpose, c.Dimensions)) + .ToArray(); + + var suggestedTransforms = new List(); + foreach (var expert in GetExperts(context)) + { + SuggestedTransform[] suggestions = expert.Apply(intermediateCols, task).ToArray(); + suggestedTransforms.AddRange(suggestions); + } + + var finalFeaturesConcatTransform = BuildFinalFeaturesConcatTransform(context, suggestedTransforms, intermediateCols); + if (finalFeaturesConcatTransform != null) + { + suggestedTransforms.Add(finalFeaturesConcatTransform); + } + + return suggestedTransforms.ToArray(); + } + + /// + /// Build final features concat transform, using output of all suggested experts. + /// Take the output columns from all suggested experts (except for 'Label'), and concatenate them + /// into one final 'Features' column that a trainer will accept. 
+ /// + private static SuggestedTransform BuildFinalFeaturesConcatTransform(MLContext context, IEnumerable suggestedTransforms, + IEnumerable intermediateCols) + { + // get the output column names from all suggested transforms + var concatColNames = new List(); + foreach (var suggestedTransform in suggestedTransforms) + { + concatColNames.AddRange(suggestedTransform.PipelineNode.OutColumns); + } + + // include all numeric columns of type R4 + foreach(var intermediateCol in intermediateCols) + { + if (intermediateCol.Purpose == ColumnPurpose.NumericFeature && + intermediateCol.Type.GetItemType() == NumberDataViewType.Single) + { + concatColNames.Add(intermediateCol.ColumnName); + } + } + + // remove column with 'Label' purpose + var labelColumnName = intermediateCols.FirstOrDefault(c => c.Purpose == ColumnPurpose.Label)?.ColumnName; + concatColNames.Remove(labelColumnName); + + intermediateCols = intermediateCols.Where(c => c.Purpose == ColumnPurpose.NumericFeature || + c.Purpose == ColumnPurpose.CategoricalFeature || c.Purpose == ColumnPurpose.TextFeature); + + if (!concatColNames.Any() || (concatColNames.Count == 1 && + concatColNames[0] == DefaultColumnNames.Features && + intermediateCols.First().Type.IsVector())) + { + return null; + } + + if (concatColNames.Count() == 1 && + (intermediateCols.First().Type.IsVector() || + intermediateCols.First().Purpose == ColumnPurpose.CategoricalFeature || + intermediateCols.First().Purpose == ColumnPurpose.TextFeature)) + { + return ColumnCopyingExtension.CreateSuggestedTransform(context, concatColNames.First(), DefaultColumnNames.Features); + } + + return ColumnConcatenatingExtension.CreateSuggestedTransform(context, concatColNames.Distinct().ToArray(), DefaultColumnNames.Features); + } + + private static IEnumerable GetNewColumnNames(IEnumerable desiredColNames, IEnumerable columns) + { + var newColNames = new List(); + + var existingColNames = new HashSet(columns.Select(c => c.ColumnName)); + foreach (var desiredColName 
in desiredColNames) + { + if (!existingColNames.Contains(desiredColName)) + { + newColNames.Add(desiredColName); + continue; + } + + for(var i = 0; ; i++) + { + var newColName = $"{desiredColName}{i}"; + if (!existingColNames.Contains(newColName)) + { + newColNames.Add(newColName); + break; + } + } + } + + return newColNames; + } + } +} diff --git a/src/Microsoft.ML.Auto/TransformInference/TransformInferenceApi.cs b/src/Microsoft.ML.Auto/TransformInference/TransformInferenceApi.cs new file mode 100644 index 0000000000..384f4a6aa6 --- /dev/null +++ b/src/Microsoft.ML.Auto/TransformInference/TransformInferenceApi.cs @@ -0,0 +1,22 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto +{ + internal static class TransformInferenceApi + { + public static IEnumerable InferTransforms(MLContext context, TaskKind task, DatasetColumnInfo[] columns) + { + return TransformInference.InferTransforms(context, task, columns); + } + + public static IEnumerable InferTransformsPostTrainer(MLContext context, TaskKind task, DatasetColumnInfo[] columns) + { + return TransformPostTrainerInference.InferTransforms(context, task, columns); + } + } +} diff --git a/src/Microsoft.ML.Auto/TransformInference/TransformPostTrainerInference.cs b/src/Microsoft.ML.Auto/TransformInference/TransformPostTrainerInference.cs new file mode 100644 index 0000000000..aa2166d101 --- /dev/null +++ b/src/Microsoft.ML.Auto/TransformInference/TransformPostTrainerInference.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Infers transforms that must run *after* the trainer in the pipeline
    /// (currently only label-related key-to-value conversions).
    /// </summary>
    internal class TransformPostTrainerInference
    {
        public static IEnumerable<SuggestedTransform> InferTransforms(MLContext context, TaskKind task, DatasetColumnInfo[] columns)
        {
            var suggestions = new List<SuggestedTransform>();
            suggestions.AddRange(InferLabelTransforms(context, task, columns));
            return suggestions;
        }

        private static IEnumerable<SuggestedTransform> InferLabelTransforms(MLContext context, TaskKind task,
            DatasetColumnInfo[] columns)
        {
            var transforms = new List<SuggestedTransform>();

            // Only multiclass pipelines convert the label to a key type up front,
            // so only they need a post-trainer conversion back.
            if (task != TaskKind.MulticlassClassification)
            {
                return transforms;
            }

            // If the label column type wasn't originally a key type, convert the
            // predicted label column back from key to value. (A non-key label column
            // was converted to a key earlier, because multiclass trainers only accept
            // label columns that are key-typed.)
            var labelColumn = columns.First(c => c.Purpose == ColumnPurpose.Label);
            if (!labelColumn.Type.IsKey())
            {
                transforms.Add(KeyToValueMappingExtension.CreateSuggestedTransform(context,
                    DefaultColumnNames.PredictedLabel, DefaultColumnNames.PredictedLabel));
            }

            return transforms;
        }
    }
}
// (next file: src/Microsoft.ML.Auto/Utils/BestResultUtil.cs)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Helpers for selecting the best run(s) out of a set of AutoML run results
    /// according to a metric, which may be maximizing or minimizing.
    /// </summary>
    internal class BestResultUtil
    {
        /// <summary>
        /// Returns the run with the best score among runs that produced validation
        /// metrics, or null if no run did.
        /// </summary>
        public static RunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<RunDetail<TMetrics>> results,
            IMetricsAgent<TMetrics> metricsAgent, bool isMetricMaximizing)
        {
            // Materialize once: the source may be a deferred query, and it is
            // enumerated multiple times below (emptiness check + scoring + lookup).
            var validResults = results.Where(r => r.ValidationMetrics != null).ToList();
            if (validResults.Count == 0) { return null; }
            var scores = validResults.Select(r => metricsAgent.GetScore(r.ValidationMetrics));
            var indexOfBestScore = GetIndexOfBestScore(scores, isMetricMaximizing);
            return validResults[indexOfBestScore];
        }

        /// <summary>
        /// Returns the cross-validation run with the best average per-fold score,
        /// or null if no run produced any validation metrics.
        /// </summary>
        public static CrossValidationRunDetail<TMetrics> GetBestRun<TMetrics>(IEnumerable<CrossValidationRunDetail<TMetrics>> results,
            IMetricsAgent<TMetrics> metricsAgent, bool isMetricMaximizing)
        {
            var validResults = results.Where(r => r.Results != null && r.Results.Any(x => x.ValidationMetrics != null)).ToList();
            if (validResults.Count == 0) { return null; }
            // Score of a CV run is the average score across its folds.
            var scores = validResults.Select(r => r.Results.Average(x => metricsAgent.GetScore(x.ValidationMetrics)));
            var indexOfBestScore = GetIndexOfBestScore(scores, isMetricMaximizing);
            return validResults[indexOfBestScore];
        }

        /// <summary>
        /// Returns the top <paramref name="n"/> runs, each paired with its index in the
        /// filtered result sequence, ordered from best to worst; null if no run has metrics.
        /// </summary>
        public static IEnumerable<(RunDetail<TMetrics>, int)> GetTopNRunResults<TMetrics>(IEnumerable<RunDetail<TMetrics>> results,
            IMetricsAgent<TMetrics> metricsAgent, int n, bool isMetricMaximizing)
        {
            // Materialize to avoid re-running a deferred query for the emptiness
            // check and again for the ordering below.
            var validResults = results.Where(r => r.ValidationMetrics != null).ToList();
            if (validResults.Count == 0) { return null; }

            var indexedValues = validResults.Select((k, v) => (k, v));

            var orderedResults = isMetricMaximizing
                ? indexedValues.OrderByDescending(t => metricsAgent.GetScore(t.Item1.ValidationMetrics))
                : indexedValues.OrderBy(t => metricsAgent.GetScore(t.Item1.ValidationMetrics));

            return orderedResults.Take(n);
        }

        /// <summary>
        /// Index of the best score: max if the metric is maximizing, min otherwise.
        /// Returns -1 for an empty sequence.
        /// </summary>
        public static int GetIndexOfBestScore(IEnumerable<double> scores, bool isMetricMaximizing)
        {
            return isMetricMaximizing ? GetIndexOfMaxScore(scores) : GetIndexOfMinScore(scores);
        }

        private static int GetIndexOfMinScore(IEnumerable<double> scores)
        {
            // Single pass; the previous Count()/ElementAt(i) loop re-enumerated the
            // sequence on every iteration (O(n^2) for non-indexed sources).
            var minScore = double.PositiveInfinity;
            var minIndex = -1;
            var i = 0;
            foreach (var score in scores)
            {
                if (score < minScore)
                {
                    minScore = score;
                    minIndex = i;
                }
                i++;
            }
            return minIndex;
        }

        private static int GetIndexOfMaxScore(IEnumerable<double> scores)
        {
            // Single pass, mirroring GetIndexOfMinScore.
            var maxScore = double.NegativeInfinity;
            var maxIndex = -1;
            var i = 0;
            foreach (var score in scores)
            {
                if (score > maxScore)
                {
                    maxScore = score;
                    maxIndex = i;
                }
                i++;
            }
            return maxIndex;
        }
    }
}
// (next file: src/Microsoft.ML.Auto/Utils/DatasetColumnInfo.cs)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Snapshot of everything AutoML knows about one dataset column:
    /// its name, type, inferred purpose, and computed dimensions.
    /// </summary>
    internal class DatasetColumnInfo
    {
        public readonly string Name;
        public readonly DataViewType Type;
        public readonly ColumnPurpose Purpose;
        public readonly ColumnDimensions Dimensions;

        public DatasetColumnInfo(string name, DataViewType type, ColumnPurpose purpose, ColumnDimensions dimensions)
        {
            Name = name;
            Type = type;
            Purpose = purpose;
            Dimensions = dimensions;
        }
    }

    internal static class DatasetColumnInfoUtil
    {
        /// <summary>
        /// Builds one <see cref="DatasetColumnInfo"/> per schema column by running
        /// purpose inference and then column-dimension calculation over the data.
        /// </summary>
        public static DatasetColumnInfo[] GetDatasetColumnInfo(MLContext context, IDataView data, ColumnInformation columnInfo)
        {
            var purposes = PurposeInference.InferPurposes(context, data, columnInfo);
            var colDimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, purposes);
            var cols = new DatasetColumnInfo[data.Schema.Count];
            for (var i = 0; i < cols.Length; i++)
            {
                var schemaCol = data.Schema[i];
                cols[i] = new DatasetColumnInfo(schemaCol.Name, schemaCol.Type, purposes[i].Purpose, colDimensions[i]);
            }
            return cols;
        }
    }
}
// (next file: src/Microsoft.ML.Auto/Utils/MLNetUtils/AnnotationBuilderExtensions.cs)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    internal static class AnnotationBuilderExtensions
    {
        /// <summary>
        /// Add slot names annotation.
        /// </summary>
        /// <param name="builder">The annotations builder to which to add the slot names.</param>
        /// <param name="size">The size of the slot names vector.</param>
        /// <param name="getter">The getter delegate for the slot names.</param>
+ public static void AddSlotNames(this DataViewSchema.Annotations.Builder builder, int size, ValueGetter>> getter) + { + builder.Add("SlotNames", new VectorDataViewType(TextDataViewType.Instance, size), getter, null); + } + + /// + /// Add key values annotation. + /// + /// The value type of key values. + /// The to which to add the key values. + /// The size of key values vector. + /// The value type of key values. Its raw type must match . + /// The getter delegate for the key values. + public static void AddKeyValues(this DataViewSchema.Annotations.Builder builder, int size, PrimitiveDataViewType valueType, ValueGetter> getter) + { + builder.Add("KeyValues", new VectorDataViewType(valueType, size), getter, null); + } + } +} diff --git a/src/Microsoft.ML.Auto/Utils/MLNetUtils/ArrayDataViewBuilder.cs b/src/Microsoft.ML.Auto/Utils/MLNetUtils/ArrayDataViewBuilder.cs new file mode 100644 index 0000000000..e529bf4bb2 --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/MLNetUtils/ArrayDataViewBuilder.cs @@ -0,0 +1,468 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Auto +{ + using BitArray = System.Collections.BitArray; + + /// + /// This is a class for composing an in memory IDataView. + /// + internal sealed class ArrayDataViewBuilder + { + private readonly IHost _host; + private readonly List _columns; + private readonly List _names; + private readonly Dictionary>>> _getSlotNames; + private readonly Dictionary>>> _getKeyValues; + + private int? 
RowCount + { + get + { + if (_columns.Count == 0) + return null; + return _columns[0].Length; + } + } + + public ArrayDataViewBuilder(IHostEnvironment env) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register("ArrayDataViewBuilder"); + + _columns = new List(); + _names = new List(); + _getSlotNames = new Dictionary>>>(); + _getKeyValues = new Dictionary>>>(); + } + + /// + /// Verifies that the input array to one of the add routines is of the same length + /// as previously added arrays, assuming there were any. + /// + private void CheckLength(string name, T[] values) + { + _host.CheckValue(name, nameof(name)); + _host.CheckValue(values, nameof(values)); + if (_columns.Count > 0 && values.Length != _columns[0].Length) + throw _host.Except("Previous inputs were of length {0}, but new input is of length {1}", _columns[0].Length, values.Length); + } + + /// + /// Constructs a new column from an array where values are copied to output simply + /// by being assigned. Output values are returned simply by being assigned, so the + /// type should be a type where assigning to a different + /// value does not compromise the immutability of the source object (so, for example, + /// a scalar, string, or ReadOnlyMemory would be perfectly acceptable, but a + /// HashSet or VBuffer would not be). + /// + public void AddColumn(string name, PrimitiveDataViewType type, params T[] values) + { + _host.CheckParam(type != null && type.RawType == typeof(T), nameof(type)); + CheckLength(name, values); + _columns.Add(new AssignmentColumn(type, values)); + _names.Add(name); + } + + /// + /// Constructs a new key column from an array where values are copied to output simply + /// by being assigned. + /// + /// The name of the column. + /// The delegate that does a reverse lookup based upon the given key. This is for annotation creation + /// The count of unique keys specified in values + /// The values to add to the column. 
Note that since this is creating a column, the values will be offset by 1. + public void AddColumn(string name, ValueGetter>> getKeyValues, ulong keyCount, params T1[] values) + { + _host.CheckValue(getKeyValues, nameof(getKeyValues)); + _host.CheckParam(keyCount > 0, nameof(keyCount)); + CheckLength(name, values); + var elemType = values.GetType().GetElementType(); + _columns.Add(new AssignmentColumn(new KeyDataViewType(elemType, keyCount), values)); + _getKeyValues.Add(name, getKeyValues); + _names.Add(name); + } + + /// + /// Creates a column with slot names from arrays. The added column will be re-interpreted as a buffer. + /// + public void AddColumn(string name, ValueGetter>> getNames, PrimitiveDataViewType itemType, params T[][] values) + { + _host.CheckValue(getNames, nameof(getNames)); + _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType)); + CheckLength(name, values); + var col = new ArrayToVBufferColumn(itemType, values); + _columns.Add(col); + _getSlotNames.Add(name, getNames); + _names.Add(name); + } + + /// + /// Creates a column from arrays. The added column will be re-interpreted as a buffer. + /// + public void AddColumn(string name, PrimitiveDataViewType itemType, params T[][] values) + { + _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType)); + CheckLength(name, values); + _columns.Add(new ArrayToVBufferColumn(itemType, values)); + _names.Add(name); + } + + /// + /// Adds a VBuffer{T} valued column. + /// + public void AddColumn(string name, PrimitiveDataViewType itemType, params VBuffer[] values) + { + _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType)); + CheckLength(name, values); + _columns.Add(new VBufferColumn(itemType, values)); + _names.Add(name); + } + + /// + /// Adds a VBuffer{T} valued column. 
+ /// + public void AddColumn(string name, ValueGetter>> getNames, PrimitiveDataViewType itemType, params VBuffer[] values) + { + _host.CheckValue(getNames, nameof(getNames)); + _host.CheckParam(itemType != null && itemType.RawType == typeof(T), nameof(itemType)); + CheckLength(name, values); + _columns.Add(new VBufferColumn(itemType, values)); + _getSlotNames.Add(name, getNames); + _names.Add(name); + } + + /// + /// Adds a ReadOnlyMemory valued column from an array of strings. + /// + public void AddColumn(string name, params string[] values) + { + CheckLength(name, values); + _columns.Add(new StringToTextColumn(values)); + _names.Add(name); + } + + /// + /// Constructs a data view from the columns added so far. Note that it is perfectly acceptable + /// to continue adding columns to the builder, but these additions will not be reflected in the + /// returned dataview. + /// + /// + public IDataView GetDataView(int? rowCount = null) + { + if (rowCount.HasValue) + { + _host.Check(!RowCount.HasValue || RowCount.Value == rowCount.Value, "Specified row count incompatible with existing columns"); + return new DataView(_host, this, rowCount.Value); + } + _host.Check(_columns.Count > 0, "Cannot construct data-view with neither any columns nor a specified row count"); + return new DataView(_host, this, RowCount.Value); + } + + private sealed class DataView : IDataView + { + private readonly int _rowCount; + private readonly Column[] _columns; + private readonly DataViewSchema _schema; + private readonly IHost _host; + + public DataViewSchema Schema { get { return _schema; } } + + public long? 
GetRowCount() { return _rowCount; } + + public bool CanShuffle { get { return true; } } + + public DataView(IHostEnvironment env, ArrayDataViewBuilder builder, int rowCount) + { + _host = env.Register("ArrayDataView"); + + _columns = builder._columns.ToArray(); + + var schemaBuilder = new DataViewSchema.Builder(); + for (int i = 0; i < _columns.Length; i++) + { + var meta = new DataViewSchema.Annotations.Builder(); + + if (builder._getSlotNames.TryGetValue(builder._names[i], out var slotNamesGetter)) + meta.AddSlotNames(_columns[i].Type.GetVectorSize(), slotNamesGetter); + + if (builder._getKeyValues.TryGetValue(builder._names[i], out var keyValueGetter)) + meta.AddKeyValues(_columns[i].Type.GetKeyCountAsInt32(_host), TextDataViewType.Instance, keyValueGetter); + schemaBuilder.AddColumn(builder._names[i], _columns[i].Type, meta.ToAnnotations()); + } + + _schema = schemaBuilder.ToSchema(); + _rowCount = rowCount; + } + + public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) + { + var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema); + + return new Cursor(_host, this, predicate, rand); + } + + public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) + { + var predicate = RowCursorUtils.FromColumnsToPredicate(columnsNeeded, Schema); + + return new DataViewRowCursor[] { new Cursor(_host, this, predicate, rand) }; + } + + private sealed class Cursor : RootCursorBase + { + private readonly DataView _view; + private readonly BitArray _active; + private readonly int[] _indices; + + public override DataViewSchema Schema => _view.Schema; + + public override long Batch + { + // REVIEW: Implement cursor set support. 
+ get { return 0; } + } + + public Cursor(IChannelProvider provider, DataView view, Func predicate, Random rand) + : base(provider) + { + _view = view; + _active = new BitArray(view.Schema.Count); + if (predicate == null) + _active.SetAll(true); + else + { + for (int i = 0; i < view.Schema.Count; ++i) + _active[i] = predicate(i); + } + if (rand != null) + _indices = MLNetUtils.GetRandomPermutation(rand, view._rowCount); + } + + public override ValueGetter GetIdGetter() + { + if (_indices == null) + { + return + (ref DataViewRowId val) => + { + Ch.Check(IsGood, RowCursorUtils.FetchValueStateError); + val = new DataViewRowId((ulong)Position, 0); + }; + } + else + { + return + (ref DataViewRowId val) => + { + Ch.Check(IsGood, RowCursorUtils.FetchValueStateError); + val = new DataViewRowId((ulong)MappedIndex(), 0); + }; + } + } + + /// + /// Returns whether the given column is active in this row. + /// + public override bool IsColumnActive(DataViewSchema.Column column) + { + Ch.Check(column.Index < Schema.Count); + return _active[column.Index]; + } + + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. 
+ public override ValueGetter GetGetter(DataViewSchema.Column column) + { + Ch.Check(column.Index < Schema.Count); + Ch.Check(column.Index < _active.Length && _active[column.Index], "the requested column is not active"); + + var columnValue = _view._columns[column.Index] as Column; + if (columnValue == null) + throw Ch.Except("Invalid TValue: '{0}'", typeof(TValue)); + + return + (ref TValue value) => + { + Ch.Check(IsGood, RowCursorUtils.FetchValueStateError); + columnValue.CopyOut(MappedIndex(), ref value); + }; + } + + protected override bool MoveNextCore() + { + return 1 < _view._rowCount - Position; + } + + private int MappedIndex() + { + if (_indices == null) + return (int)Position; + return _indices[(int)Position]; + } + } + } + + #region Column implementations + + private abstract class Column + { + public readonly DataViewType Type; + + public abstract int Length { get; } + + public Column(DataViewType type) + { + Type = type; + } + } + + private abstract class Column : Column + { + /// + /// Produce the output value given the index. + /// + public abstract void CopyOut(int index, ref TOut value); + + public Column(DataViewType type) + : base(type) + { + } + } + + private abstract class Column : Column + { + private readonly TIn[] _values; + + public override int Length { get { return _values.Length; } } + + public Column(DataViewType type, TIn[] values) + : base(type) + { + _values = values; + } + + /// + /// Assigns dst in such a way that the caller has ownership of dst without + /// compromising this object's ownership of src. What that operation will be + /// will depend on the types. + /// + protected abstract void CopyOut(in TIn src, ref TOut dst); + + /// + /// Produce the output value given the index. This overload utilizes the CopyOut + /// helper function. 
+ /// + public override void CopyOut(int index, ref TOut value) + { + CopyOut(in _values[index], ref value); + } + } + + /// + /// A column where the input and output types are the same, and simple assignment does + /// not compromise ownership of the internal vlaues. + /// + private sealed class AssignmentColumn : Column + { + public AssignmentColumn(PrimitiveDataViewType type, T[] values) + : base(type, values) + { + } + + protected override void CopyOut(in T src, ref T dst) + { + dst = src; + } + } + + /// + /// A convenience column for converting strings into textspans. + /// + private sealed class StringToTextColumn : Column> + { + public StringToTextColumn(string[] values) + : base(TextDataViewType.Instance, values) + { + } + + protected override void CopyOut(in string src, ref ReadOnlyMemory dst) + { + dst = src.AsMemory(); + } + } + + private abstract class VectorColumn : Column> + { + public VectorColumn(PrimitiveDataViewType itemType, TIn[] values, Func lengthFunc) + : base(InferType(itemType, values, lengthFunc), values) + { + } + + /// + /// A utility function for subclasses that want to get the type with a dimension based + /// on the input value array and some length function over the input type. + /// + private static DataViewType InferType(PrimitiveDataViewType itemType, TIn[] values, Func lengthFunc) + { + int degree = 0; + if (MLNetUtils.Size(values) > 0) + { + degree = lengthFunc(values[0]); + for (int i = 1; i < values.Length; ++i) + { + if (degree != lengthFunc(values[i])) + { + degree = 0; + break; + } + } + } + return new VectorDataViewType(itemType, degree); + } + } + + /// + /// A column of buffers. 
+ /// + private sealed class VBufferColumn : VectorColumn, T> + { + public VBufferColumn(PrimitiveDataViewType itemType, VBuffer[] values) + : base(itemType, values, v => v.Length) + { + } + + protected override void CopyOut(in VBuffer src, ref VBuffer dst) + { + src.CopyTo(ref dst); + } + } + + private sealed class ArrayToVBufferColumn : VectorColumn + { + public ArrayToVBufferColumn(PrimitiveDataViewType itemType, T[][] values) + : base(itemType, values, MLNetUtils.Size) + { + } + + protected override void CopyOut(in T[] src, ref VBuffer dst) + { + VBuffer.Copy(src, 0, ref dst, MLNetUtils.Size(src)); + } + } + + #endregion + } +} diff --git a/src/Microsoft.ML.Auto/Utils/MLNetUtils/ColumnTypeExtensions.cs b/src/Microsoft.ML.Auto/Utils/MLNetUtils/ColumnTypeExtensions.cs new file mode 100644 index 0000000000..f520fc2fd9 --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/MLNetUtils/ColumnTypeExtensions.cs @@ -0,0 +1,109 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
using System;
using Microsoft.ML.Data;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Convenience extensions over <see cref="DataViewType"/> used throughout AutoML
    /// for schema inspection (kind tests, vector sizing, key counts).
    /// </summary>
    internal static class DataViewTypeExtensions
    {
        /// <summary>True when the type is a numeric scalar type.</summary>
        public static bool IsNumber(this DataViewType columnType)
        {
            return columnType is NumberDataViewType;
        }

        /// <summary>True when the type is the text type.</summary>
        public static bool IsText(this DataViewType columnType)
        {
            return columnType is TextDataViewType;
        }

        /// <summary>True when the type is the boolean type.</summary>
        public static bool IsBool(this DataViewType columnType)
        {
            return columnType is BooleanDataViewType;
        }

        /// <summary>True when the type is a vector type (of any size).</summary>
        public static bool IsVector(this DataViewType columnType)
        {
            return columnType is VectorDataViewType;
        }

        /// <summary>True when the type is a key type.</summary>
        public static bool IsKey(this DataViewType columnType)
        {
            return columnType is KeyDataViewType;
        }

        /// <summary>True only for vector types whose dimension is known (size &gt; 0).</summary>
        public static bool IsKnownSizeVector(this DataViewType columnType)
        {
            return columnType is VectorDataViewType vector && vector.Size > 0;
        }

        /// <summary>
        /// For a vector type, returns the item type; for any other type, returns the type itself.
        /// </summary>
        public static DataViewType GetItemType(this DataViewType columnType)
        {
            return (columnType as VectorDataViewType)?.ItemType ?? columnType;
        }

        /// <summary>
        /// Zero return means either it's not a vector or the size is unknown.
        /// </summary>
        public static int GetVectorSize(this DataViewType columnType)
        {
            return (columnType as VectorDataViewType)?.Size ?? 0;
        }

        /// <summary>
        /// Maps the raw CLR type to a <see cref="DataKind"/>. If the raw type has no
        /// corresponding kind, the default (zero) value is returned — callers that care
        /// should use <see cref="DataKindExtensions.TryGetDataKind"/> directly.
        /// </summary>
        public static DataKind GetRawKind(this DataViewType columnType)
        {
            columnType.RawType.TryGetDataKind(out var rawKind);
            return rawKind;
        }

        /// <summary>
        /// Zero return means it's not a key type.
        /// </summary>
        public static ulong GetKeyCount(this DataViewType columnType)
        {
            return (columnType as KeyDataViewType)?.Count ?? 0;
        }

        /// <summary>
        /// Sometimes it is necessary to cast the Count to an int. This performs the overflow check.
        /// Zero return means it's not a key type.
        /// </summary>
        /// <exception cref="OverflowException">When the key count does not fit in an Int32.</exception>
        public static int GetKeyCountAsInt32(this DataViewType columnType, IExceptionContext ectx = null)
        {
            ulong keyCount = columnType.GetKeyCount();
            // Fix: the documented overflow check was missing — the previous bare cast
            // silently truncated counts above Int32.MaxValue.
            if (keyCount > int.MaxValue)
                throw new OverflowException($"Key count {keyCount} exceeds Int32.MaxValue.");
            return (int)keyCount;
        }

        /// <summary>
        /// Equivalent to calling Equals(ColumnType) for non-vector types. For vector types,
        /// returns true if current and other vector types have the same size and item type.
        /// </summary>
        public static bool SameSizeAndItemType(this DataViewType columnType, DataViewType other)
        {
            if (other == null)
                return false;

            if (columnType.Equals(other))
                return true;

            // For vector types, we don't care about the factoring of the dimensions.
            if (!(columnType is VectorDataViewType vectorType) || !(other is VectorDataViewType otherVectorType))
                return false;
            if (!vectorType.ItemType.Equals(otherVectorType.ItemType))
                return false;
            return vectorType.Size == otherVectorType.Size;
        }
    }
}
using System;
using System.Diagnostics;
using System.Globalization;
using Microsoft.ML.Runtime;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Lightweight stand-in for ML.NET's internal Contracts helpers: argument/state
    /// checks plus exception factories. The factories return (rather than throw) the
    /// exception so call sites read <c>throw ctx.Except(...)</c>, keeping the actual
    /// throw statement — and therefore the stack-trace origin — at the caller.
    /// </summary>
    internal static class Contracts
    {
        /// <summary>Throws <see cref="InvalidOperationException"/> carrying <paramref name="msg"/> when <paramref name="f"/> is false.</summary>
        public static void Check(this IExceptionContext ctx, bool f, string msg)
        {
            if (!f)
                throw Except(ctx, msg);
        }

        /// <summary>Throws a plain <see cref="InvalidOperationException"/> when <paramref name="f"/> is false.</summary>
        public static void Check(this IExceptionContext ctx, bool f)
        {
            if (!f)
                throw new InvalidOperationException();
        }

        /// <summary>Throws <see cref="ArgumentNullException"/> when <paramref name="val"/> is null.</summary>
        public static void CheckValue<T>(T val, string paramName) where T : class
        {
            if (object.ReferenceEquals(val, null))
                throw new ArgumentNullException(paramName);
        }

        /// <summary>Context-flavored overload of <see cref="CheckValue{T}(T, string)"/>; the context itself is not consulted.</summary>
        public static void CheckValue<T>(this IExceptionContext ctx, T val, string paramName) where T : class
        {
            if (object.ReferenceEquals(val, null))
                throw new ArgumentNullException(paramName);
        }

        /// <summary>Throws <see cref="ArgumentOutOfRangeException"/> for <paramref name="paramName"/> when <paramref name="f"/> is false.</summary>
        public static void CheckParam(this IExceptionContext ctx, bool f, string paramName)
        {
            if (!f)
                throw ExceptParam(ctx, paramName);
        }

        /// <summary>Context-free variant of <see cref="CheckParam(IExceptionContext, bool, string)"/>.</summary>
        public static void CheckParam(bool f, string paramName)
        {
            if (!f)
                throw ExceptParam(paramName);
        }

        /// <summary>Debug-only assertion: fails fast in debug builds, no-op in release.</summary>
        public static void Assert(bool f, string msg)
        {
            if (!f)
                Debug.Fail(msg);
        }

        /// <summary>
        /// Creates an <see cref="InvalidOperationException"/> with a formatted message.
        /// Fix: this overload previously used a throw-expression body (`=> throw ...`),
        /// unlike every sibling factory that returns; it now returns the exception so
        /// the caller's <c>throw ctx.Except(...)</c> is the real throw site.
        /// </summary>
        public static Exception Except(this IExceptionContext ctx, string msg, params object[] args)
            => new InvalidOperationException(GetMsg(msg, args));

        /// <summary>Creates an <see cref="ArgumentOutOfRangeException"/> for the given parameter.</summary>
        public static Exception ExceptParam(this IExceptionContext ctx, string paramName)
            => new ArgumentOutOfRangeException(paramName);

        /// <summary>Creates an <see cref="InvalidOperationException"/> with the given message.</summary>
        public static Exception Except(string msg) => new InvalidOperationException(msg);

        /// <summary>Creates an <see cref="ArgumentOutOfRangeException"/> for the given parameter.</summary>
        public static Exception ExceptParam(string paramName)
            => new ArgumentOutOfRangeException(paramName);

        // Formats invariantly. A malformed format string only trips a debug assertion
        // and falls back to the raw template, so message construction never masks the
        // original failure being reported.
        private static string GetMsg(string msg, params object[] args)
        {
            try
            {
                msg = string.Format(CultureInfo.InvariantCulture, msg, args);
            }
            catch (FormatException ex)
            {
                Contracts.Assert(false, "Format string arg mismatch: " + ex.Message);
            }
            return msg;
        }
    }
}
// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/Conversions.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Globalization;

namespace Microsoft.ML.Auto
{
    using BL = Boolean;
    using R4 = Single;
    using TX = ReadOnlyMemory<char>;

    /// <summary>
    /// Text-to-primitive conversions mirroring ML.NET's loader semantics
    /// (NA handling, permissive boolean spellings).
    /// </summary>
    internal static class Conversions
    {
        /// <summary>
        /// Parses text into a float. Empty/whitespace text maps to NaN and succeeds.
        /// Returns false if the text is not parsable; on failure dst is set to NaN,
        /// except that standard missing spellings (?, NA, N/A, NaN) succeed as NaN.
        /// </summary>
        public static bool TryParse(in TX src, out R4 dst)
        {
            var span = src.Span;
            var str = span.ToString();
            if (string.IsNullOrWhiteSpace(str))
            {
                dst = R4.NaN;
                return true;
            }
            // Fix: parse invariantly. Data files are culture-neutral, and the previous
            // culture-sensitive float.TryParse would e.g. treat ',' as the decimal
            // separator under some locales. AllowThousands is retained to match the
            // default NumberStyles of the old call.
            if (float.TryParse(str, NumberStyles.Float | NumberStyles.AllowThousands,
                CultureInfo.InvariantCulture, out dst))
            {
                return true;
            }
            dst = R4.NaN;
            return IsStdMissing(ref span);
        }

        /// <summary>
        /// Return true if the span contains a standard text representation of NA
        /// other than the standard TX missing representation - callers should
        /// have already dealt with that case and the case of empty.
        /// The standard representations are any casing of:
        ///   ? NaN NA N/A
        /// </summary>
        private static bool IsStdMissing(ref ReadOnlySpan<char> span)
        {
            char ch;
            switch (span.Length)
            {
                default:
                    return false;

                case 1:
                    if (span[0] == '?')
                        return true;
                    return false;
                case 2:
                    if ((ch = span[0]) != 'N' && ch != 'n')
                        return false;
                    if ((ch = span[1]) != 'A' && ch != 'a')
                        return false;
                    return true;
                case 3:
                    if ((ch = span[0]) != 'N' && ch != 'n')
                        return false;
                    if ((ch = span[1]) == '/')
                    {
                        // Check for N/A.
                        if ((ch = span[2]) != 'A' && ch != 'a')
                            return false;
                    }
                    else
                    {
                        // Check for NaN.
                        if (ch != 'a' && ch != 'A')
                            return false;
                        if ((ch = span[2]) != 'N' && ch != 'n')
                            return false;
                    }
                    return true;
            }
        }

        /// <summary>
        /// Try parsing a TX to a BL. This returns false for NA (span.IsMissing).
        /// Otherwise, it trims the span, then succeeds on all casings of the strings:
        /// * false, f, no, n, 0, -1, - => false
        /// * true, t, yes, y, 1, +1, + => true
        /// Empty string (but not missing string) succeeds and maps to false.
        /// </summary>
        public static bool TryParse(in TX src, out BL dst)
        {
            var span = src.Span;

            char ch;
            switch (src.Length)
            {
                case 0:
                    // Empty succeeds and maps to false.
                    dst = false;
                    return true;

                case 1:
                    switch (span[0])
                    {
                        case 'T':
                        case 't':
                        case 'Y':
                        case 'y':
                        case '1':
                        case '+':
                            dst = true;
                            return true;
                        case 'F':
                        case 'f':
                        case 'N':
                        case 'n':
                        case '0':
                        case '-':
                            dst = false;
                            return true;
                    }
                    break;

                case 2:
                    switch (span[0])
                    {
                        case 'N':
                        case 'n':
                            if ((ch = span[1]) != 'O' && ch != 'o')
                                break;
                            dst = false;
                            return true;
                        case '+':
                            if ((ch = span[1]) != '1')
                                break;
                            dst = true;
                            return true;
                        case '-':
                            if ((ch = span[1]) != '1')
                                break;
                            dst = false;
                            return true;
                    }
                    break;

                case 3:
                    switch (span[0])
                    {
                        case 'Y':
                        case 'y':
                            if ((ch = span[1]) != 'E' && ch != 'e')
                                break;
                            if ((ch = span[2]) != 'S' && ch != 's')
                                break;
                            dst = true;
                            return true;
                    }
                    break;

                case 4:
                    switch (span[0])
                    {
                        case 'T':
                        case 't':
                            if ((ch = span[1]) != 'R' && ch != 'r')
                                break;
                            if ((ch = span[2]) != 'U' && ch != 'u')
                                break;
                            if ((ch = span[3]) != 'E' && ch != 'e')
                                break;
                            dst = true;
                            return true;
                    }
                    break;

                case 5:
                    switch (span[0])
                    {
                        case 'F':
                        case 'f':
                            if ((ch = span[1]) != 'A' && ch != 'a')
                                break;
                            if ((ch = span[2]) != 'L' && ch != 'l')
                                break;
                            if ((ch = span[3]) != 'S' && ch != 's')
                                break;
                            if ((ch = span[4]) != 'E' && ch != 'e')
                                break;
                            dst = false;
                            return true;
                    }
                    break;
            }

            dst = false;
            return false;
        }
    }
}

// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/DataKindExtensions.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    internal static class DataKindExtensions
    {
        /// <summary>
        /// Try to map a System.Type to a corresponding DataKind value.
        /// Returns false (with kind zero) when the type has no corresponding kind.
        /// </summary>
        public static bool TryGetDataKind(this Type type, out DataKind kind)
        {
            // Rewritten from decompiler output (goto IL_01ad chains) into a flat,
            // behavior-identical mapping chain.
            if (type == typeof(sbyte))
                kind = DataKind.SByte;
            else if (type == typeof(byte))
                kind = DataKind.Byte;
            else if (type == typeof(short))
                kind = DataKind.Int16;
            else if (type == typeof(ushort))
                kind = DataKind.UInt16;
            else if (type == typeof(int))
                kind = DataKind.Int32;
            else if (type == typeof(uint))
                kind = DataKind.UInt32;
            else if (type == typeof(long))
                kind = DataKind.Int64;
            else if (type == typeof(ulong))
                kind = DataKind.UInt64;
            else if (type == typeof(float))
                kind = DataKind.Single;
            else if (type == typeof(double))
                kind = DataKind.Double;
            else if (type == typeof(ReadOnlyMemory<char>) || type == typeof(string))
                kind = DataKind.String;
            else if (type == typeof(bool))
                kind = DataKind.Boolean;
            else if (type == typeof(TimeSpan))
                kind = DataKind.TimeSpan;
            else if (type == typeof(DateTime))
                kind = DataKind.DateTime;
            else if (type == typeof(DateTimeOffset))
                kind = DataKind.DateTimeOffset;
            else if (type == typeof(DataViewRowId))
                // NOTE(review): mapping RowId to UInt16 looks like a decompilation
                // artifact (DataViewRowId is a 16-byte value; ML.NET's internal kind
                // for it is UG). Behavior preserved here — confirm before changing.
                kind = DataKind.UInt16;
            else
            {
                kind = (DataKind)0;
                return false;
            }
            return true;
        }
    }
}

// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/DefaultColumnNames.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Canonical column names used by ML.NET conventions; AutoML reads/writes
    /// these names when no explicit column mapping is supplied.
    /// </summary>
    internal static class DefaultColumnNames
    {
        public const string Features = "Features";
        public const string Label = "Label";
        public const string GroupId = "GroupId";
        public const string Name = "Name";
        public const string Weight = "Weight";
        public const string Score = "Score";
        public const string Probability = "Probability";
        public const string PredictedLabel = "PredictedLabel";
        public const string RecommendedItems = "Recommended";
        public const string User = "User";
        public const string Item = "Item";
        public const string Date = "Date";
        public const string FeatureContributions = "FeatureContributions";
    }
}
namespace Microsoft.ML.Auto
{
    /// <summary>Simple non-cryptographic hash combinators.</summary>
    internal static class Hashing
    {
        /// <summary>Rotates u1 left by 7 bits and xors in u2.</summary>
        public static uint CombineHash(uint u1, uint u2)
        {
            return ((u1 << 7) | (u1 >> 25)) ^ u2;
        }

        /// <summary>Signed-int convenience wrapper over <see cref="CombineHash(uint, uint)"/>.</summary>
        public static int CombineHash(int n1, int n2)
        {
            return (int)CombineHash((uint)n1, (uint)n2);
        }

        /// <summary>
        /// Creates a combined hash of possibly heterogeneously typed values.
        /// </summary>
        /// <param name="startHash">The leading hash, incorporated into the final hash</param>
        /// <param name="os">A variable list of objects, where null is a valid value</param>
        /// <returns>The combined hash incorporating a starting hash, and the hash codes
        /// of all input values</returns>
        public static int CombinedHash(int startHash, params object[] os)
        {
            foreach (object o in os)
                startHash = CombineHash(startHash, o == null ? 0 : o.GetHashCode());
            return startHash;
        }
    }
}

// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/LinqExtensions.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Auto
{
    internal static class LinqExtensions
    {
        /// <summary>
        /// Index of the maximum element (first occurrence on ties).
        /// Throws <see cref="InvalidOperationException"/> on an empty sequence.
        /// </summary>
        public static int ArgMax<T>(this IEnumerable<T> e) where T : IComparable<T>
        {
            // Fix: the previous implementation enumerated the sequence twice
            // (e.First() followed by e.Skip(1)), which is wrong for one-shot
            // enumerables and wasteful otherwise. Single pass now.
            using (var it = e.GetEnumerator())
            {
                if (!it.MoveNext())
                    throw new InvalidOperationException("Sequence contains no elements");

                T max = it.Current;
                int argMax = 0;
                for (int i = 1; it.MoveNext(); i++)
                {
                    if (it.Current.CompareTo(max) > 0)
                    {
                        argMax = i;
                        max = it.Current;
                    }
                }
                return argMax;
            }
        }
    }
}

// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/MLNetUtils.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;

namespace Microsoft.ML.Auto
{
    /// <summary>Array/permutation helpers ported from ML.NET's internal Utils.</summary>
    internal static class MLNetUtils
    {
        /// <summary>Returns a uniformly random permutation of 0..size-1 (Fisher-Yates).</summary>
        public static int[] GetRandomPermutation(Random rand, int size)
        {
            var res = GetIdentityPermutation(size);
            Shuffle(rand, res);
            return res;
        }

        /// <summary>Returns the identity permutation [0, 1, ..., size-1].</summary>
        public static int[] GetIdentityPermutation(int size)
        {
            var res = new int[size];
            for (int i = 0; i < size; i++)
                res[i] = i;
            return res;
        }

        /// <summary>In-place Fisher-Yates shuffle of the span.</summary>
        public static void Shuffle<T>(Random rand, Span<T> rgv)
        {
            for (int iv = 0; iv < rgv.Length; iv++)
                Swap(ref rgv[iv], ref rgv[iv + rand.Next(rgv.Length - iv)]);
        }

        /// <summary>Swaps the two referenced values.</summary>
        public static void Swap<T>(ref T a, ref T b)
        {
            T temp = a;
            a = b;
            b = temp;
        }

        /// <summary>Null-safe array length: 0 for a null array.</summary>
        public static int Size<T>(T[] x)
        {
            return x == null ? 0 : x.Length;
        }
    }
}

// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/ProbabilityFunctions.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;

namespace Microsoft.ML.Auto
{
    internal static class ProbabilityFunctions
    {
        /// <summary>
        /// The approximate error function, using the Abramowitz &amp; Stegun 7.1.26
        /// rational polynomial (max absolute error ~1.5e-7), extended to negative
        /// inputs via the odd symmetry erf(-x) = -erf(x).
        /// </summary>
        /// <param name="x">The input parameter, of infinite range.</param>
        /// <returns>Evaluation of the function</returns>
        public static double Erf(double x)
        {
            if (Double.IsInfinity(x))
                return Double.IsPositiveInfinity(x) ? 1.0 : -1.0;

            const double p = 0.3275911;
            const double a1 = 0.254829592;
            const double a2 = -0.284496736;
            const double a3 = 1.421413741;
            const double a4 = -1.453152027;
            const double a5 = 1.061405429;
            double t = 1.0 / (1.0 + p * Math.Abs(x));
            double ev = 1.0 - ((((((((a5 * t) + a4) * t) + a3) * t) + a2) * t + a1) * t) * Math.Exp(-(x * x));
            return x >= 0 ? ev : -ev;
        }
    }
}

// ---- File: src/Microsoft.ML.Auto/Utils/MLNetUtils/RootCursorBase.cs ----
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Base class for root (non-wrapping) row cursors: tracks position, owns a
    /// channel, and funnels all advancement through <see cref="MoveNextCore"/>.
    /// Dispose/MoveNext interplay is intentionally kept byte-for-byte with the
    /// original: once MoveNextCore returns false the cursor disposes itself and
    /// every later MoveNext returns false.
    /// </summary>
    internal abstract class RootCursorBase : DataViewRowCursor
    {
        // Channel used for check/exception reporting; owned by this cursor.
        protected readonly IChannel Ch;

        private long _position;

        private bool _disposed;

        /// <summary>
        /// Zero-based position of the cursor; -1 before the first successful MoveNext.
        /// </summary>
        public sealed override long Position => _position;

        /// <summary>
        /// Convenience property for checking whether the current state of the cursor is one where data can be fetched.
        /// </summary>
        protected bool IsGood => _position >= 0;

        /// <summary>
        /// Creates an instance of the <see cref="RootCursorBase"/> class
        /// </summary>
        /// <param name="provider">Channel provider</param>
        protected RootCursorBase(IChannelProvider provider)
        {
            Contracts.CheckValue(provider, "provider");
            Ch = provider.Start("Cursor");
            _position = -1L;
        }

        protected override void Dispose(bool disposing)
        {
            if (!_disposed)
            {
                if (disposing)
                {
                    Ch.Dispose();
                    _position = -1L;
                }
                _disposed = true;
                base.Dispose(disposing);
            }
        }

        public sealed override bool MoveNext()
        {
            if (_disposed)
                return false;

            if (MoveNextCore())
            {
                _position += 1L;
                return true;
            }

            // End of data: the cursor self-disposes so Position resets and further
            // MoveNext calls short-circuit to false.
            base.Dispose();
            return false;
        }

        /// <summary>
        /// Core implementation of <see cref="MoveNext"/>, called if no prior call to this method
        /// has returned <see langword="false"/>.
        /// </summary>
        protected abstract bool MoveNextCore();
    }
}
+ internal static Func FromColumnsToPredicate(IEnumerable columnsNeeded, DataViewSchema sourceSchema) + { + Contracts.CheckValue(columnsNeeded, nameof(columnsNeeded)); + Contracts.CheckValue(sourceSchema, nameof(sourceSchema)); + + bool[] indicesRequested = new bool[sourceSchema.Count]; + + foreach (var col in columnsNeeded) + { + if (col.Index >= indicesRequested.Length) + throw Contracts.Except($"The requested column: {col} is not part of the {nameof(sourceSchema)}"); + + indicesRequested[col.Index] = true; + } + + return c => indicesRequested[c]; + } + + internal const string FetchValueStateError = "Values cannot be fetched at this time. This method was called either before the first call to " + + nameof(DataViewRowCursor.MoveNext) + ", or at any point after that method returned false."; + } +} diff --git a/src/Microsoft.ML.Auto/Utils/MLNetUtils/VBufferUtils.cs b/src/Microsoft.ML.Auto/Utils/MLNetUtils/VBufferUtils.cs new file mode 100644 index 0000000000..62132e42cf --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/MLNetUtils/VBufferUtils.cs @@ -0,0 +1,23 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using System; + +namespace Microsoft.ML.Auto +{ + internal class VBufferUtils + { + public static bool HasNaNs(in VBuffer buffer) + { + var values = buffer.GetValues(); + for (int i = 0; i < values.Length; i++) + { + if (Single.IsNaN(values[i])) + return true; + } + return false; + } + } +} diff --git a/src/Microsoft.ML.Auto/Utils/MLNetUtils/VectorUtils.cs b/src/Microsoft.ML.Auto/Utils/MLNetUtils/VectorUtils.cs new file mode 100644 index 0000000000..940097342d --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/MLNetUtils/VectorUtils.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Auto +{ + internal static class VectorUtils + { + public static double GetMean(double[] vector) + { + double sum = 0; + for (int i = 0; i < vector.Length; i++) + { + sum += vector[i]; + } + return sum / vector.Length; + } + + public static double GetStandardDeviation(double[] vector) + { + return GetStandardDeviation(vector, GetMean(vector)); + } + + private static double GetStandardDeviation(double[] vector, double mean) + { + double sum = 0; + int length = vector.Length; + double tmp; + for (int i = 0; i < length; i++) + { + tmp = vector[i] - mean; + sum += tmp * tmp; + } + return Math.Sqrt(sum / length); + } + } +} diff --git a/src/Microsoft.ML.Auto/Utils/SplitUtil.cs b/src/Microsoft.ML.Auto/Utils/SplitUtil.cs new file mode 100644 index 0000000000..34f5310807 --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/SplitUtil.cs @@ -0,0 +1,61 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Auto +{ + internal static class SplitUtil + { + private const string CrossValEmptyFoldErrorMsg = @"Cross validation split has 0 rows. 
Perhaps " + + "try increasing number of rows provided in training data, or lowering specified number of " + + "cross validation folds."; + + public static (IDataView[] trainDatasets, IDataView[] validationDatasets) CrossValSplit(MLContext context, + IDataView trainData, uint numFolds, string samplingKeyColumn) + { + var originalColumnNames = trainData.Schema.Select(c => c.Name); + var splits = context.Data.CrossValidationSplit(trainData, (int)numFolds, samplingKeyColumnName: samplingKeyColumn); + var trainDatasets = new IDataView[numFolds]; + var validationDatasets = new IDataView[numFolds]; + for (var i = 0; i < numFolds; i++) + { + var split = splits[i]; + trainDatasets[i] = DropAllColumnsExcept(context, split.TrainSet, originalColumnNames); + validationDatasets[i] = DropAllColumnsExcept(context, split.TestSet, originalColumnNames); + if (DatasetDimensionsUtil.IsDataViewEmpty(trainDatasets[i]) || DatasetDimensionsUtil.IsDataViewEmpty(validationDatasets[i])) + { + throw new InvalidOperationException(CrossValEmptyFoldErrorMsg); + } + } + return (trainDatasets, validationDatasets); + } + + /// + /// Split the data into a single train/test split. 
+ /// + public static (IDataView trainData, IDataView validationData) TrainValidateSplit(MLContext context, IDataView trainData, + string samplingKeyColumn) + { + var originalColumnNames = trainData.Schema.Select(c => c.Name); + var splitData = context.Data.TrainTestSplit(trainData, samplingKeyColumnName: samplingKeyColumn); + trainData = DropAllColumnsExcept(context, splitData.TrainSet, originalColumnNames); + var validationData = DropAllColumnsExcept(context, splitData.TestSet, originalColumnNames); + return (trainData, validationData); + } + + private static IDataView DropAllColumnsExcept(MLContext context, IDataView data, IEnumerable columnsToKeep) + { + var allColumns = data.Schema.Select(c => c.Name); + var columnsToDrop = allColumns.Except(columnsToKeep); + if (!columnsToDrop.Any()) + { + return data; + } + return context.Transforms.DropColumns(columnsToDrop.ToArray()).Fit(data).Transform(data); + } + } +} diff --git a/src/Microsoft.ML.Auto/Utils/SweepableParamAttributes.cs b/src/Microsoft.ML.Auto/Utils/SweepableParamAttributes.cs new file mode 100644 index 0000000000..d16f813ef7 --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/SweepableParamAttributes.cs @@ -0,0 +1,213 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using System.Text; + +namespace Microsoft.ML.Auto +{ + /// + /// Used to indicate suggested sweep ranges for parameter sweeping. + /// + internal abstract class SweepableParam + { + public string Name { get; set; } + private IComparable _rawValue; + public virtual IComparable RawValue + { + get => _rawValue; + set + { + if (!Frozen) + _rawValue = value; + } + } + + // The raw value will store an index for discrete parameters, + // but sometimes we want the text or numeric value itself, + // not the hot index. 
The processed value does that for discrete + // params. For other params, it just returns the raw value itself. + public virtual IComparable ProcessedValue() => _rawValue; + + // Allows for hyperparameter value freezing, so that sweeps + // will not alter the current value when true. + public bool Frozen { get; set; } + + // Allows the sweepable param to be set directly using the + // available ValueText attribute on IParameterValues (from + // the ParameterSets used in the old hyperparameter sweepers). + public abstract void SetUsingValueText(string valueText); + + public abstract SweepableParam Clone(); + } + + internal sealed class SweepableDiscreteParam : SweepableParam + { + public object[] Options { get; } + + public SweepableDiscreteParam(string name, object[] values, bool isBool = false) : this(values, isBool) + { + Name = name; + } + + public SweepableDiscreteParam(object[] values, bool isBool = false) + { + Options = isBool ? new object[] { false, true } : values; + } + + public override IComparable RawValue + { + get => base.RawValue; + set + { + var val = Convert.ToInt32(value); + if (!Frozen && 0 <= val && val < Options.Length) + base.RawValue = val; + } + } + + public override void SetUsingValueText(string valueText) + { + for (int i = 0; i < Options.Length; i++) + if (valueText == Options[i].ToString()) + RawValue = i; + } + + private static string TranslateOption(object o) + { + switch (o) + { + case float _: + case double _: + return $"{o}f"; + case long _: + case int _: + case byte _: + case short _: + return o.ToString(); + case bool _: + return o.ToString().ToLower(); + case Enum _: + var type = o.GetType(); + var defaultName = $"Enums.{type.Name}.{o.ToString()}"; + var name = type.FullName?.Replace("+", "."); + if (name == null) + return defaultName; + var index1 = name.LastIndexOf(".", StringComparison.Ordinal); + var index2 = name.Substring(0, index1).LastIndexOf(".", StringComparison.Ordinal) + 1; + if (index2 >= 0) + return 
$"{name.Substring(index2)}.{o.ToString()}"; + return defaultName; + default: + return $"\"{o}\""; + } + } + + public override SweepableParam Clone() => + new SweepableDiscreteParam(Name, Options) { RawValue = RawValue, Frozen = Frozen }; + + public override string ToString() + { + var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; + return $"[{GetType().Name}({name}new object[]{{{string.Join(", ", Options.Select(TranslateOption))}}})]"; + } + + public override IComparable ProcessedValue() => (IComparable)Options[(int)RawValue]; + } + + internal sealed class SweepableFloatParam : SweepableParam + { + public float Min { get; } + public float Max { get; } + public float? StepSize { get; } + public int? NumSteps { get; } + public bool IsLogScale { get; } + + public SweepableFloatParam(string name, float min, float max, float stepSize = -1, int numSteps = -1, + bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale) + { + Name = name; + } + + public SweepableFloatParam(float min, float max, float stepSize = -1, int numSteps = -1, bool isLogScale = false) + { + Min = min; + Max = max; + if (!stepSize.Equals(-1)) + StepSize = stepSize; + if (numSteps != -1) + NumSteps = numSteps; + IsLogScale = isLogScale; + } + + public override void SetUsingValueText(string valueText) + { + RawValue = float.Parse(valueText); + } + + public override SweepableParam Clone() => + new SweepableFloatParam(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen }; + + public override string ToString() + { + var optional = new StringBuilder(); + if (StepSize != null) + optional.Append($", stepSize:{StepSize}"); + if (NumSteps != null) + optional.Append($", numSteps:{NumSteps}"); + if (IsLogScale) + optional.Append($", isLogScale:true"); + var name = string.IsNullOrEmpty(Name) ? 
"" : $"\"{Name}\", "; + return $"[{GetType().Name}({name}{Min}f, {Max}f{optional})]"; + } + } + + internal sealed class SweepableLongParam : SweepableParam + { + public long Min { get; } + public long Max { get; } + public float? StepSize { get; } + public int? NumSteps { get; } + public bool IsLogScale { get; } + + public SweepableLongParam(string name, long min, long max, float stepSize = -1, int numSteps = -1, + bool isLogScale = false) : this(min, max, stepSize, numSteps, isLogScale) + { + Name = name; + } + + public SweepableLongParam(long min, long max, float stepSize = -1, int numSteps = -1, bool isLogScale = false) + { + Min = min; + Max = max; + if (!stepSize.Equals(-1)) + StepSize = stepSize; + if (numSteps != -1) + NumSteps = numSteps; + IsLogScale = isLogScale; + } + + public override void SetUsingValueText(string valueText) + { + RawValue = long.Parse(valueText); + } + + public override SweepableParam Clone() => + new SweepableLongParam(Name, Min, Max, StepSize ?? -1, NumSteps ?? -1, IsLogScale) { RawValue = RawValue, Frozen = Frozen }; + + public override string ToString() + { + var optional = new StringBuilder(); + if (StepSize != null) + optional.Append($", stepSize:{StepSize}"); + if (NumSteps != null) + optional.Append($", numSteps:{NumSteps}"); + if (IsLogScale) + optional.Append($", isLogScale:true"); + var name = string.IsNullOrEmpty(Name) ? "" : $"\"{Name}\", "; + return $"[{GetType().Name}({name}{Min}, {Max}{optional})]"; + } + } +} diff --git a/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs b/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs new file mode 100644 index 0000000000..eddd143b37 --- /dev/null +++ b/src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs @@ -0,0 +1,232 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
{
    /// <summary>
    /// Validates user-supplied arguments to the AutoML public API surface,
    /// throwing <see cref="ArgumentException"/> / <see cref="ArgumentNullException"/>
    /// with actionable messages before any expensive work starts.
    /// </summary>
    internal static class UserInputValidationUtil
    {
        // Column purpose names, used only to build readable error messages.
        private const string LabelColumnPurposeName = "label";
        private const string WeightColumnPurposeName = "weight";
        private const string NumericColumnPurposeName = "numeric";
        private const string CategoricalColumnPurposeName = "categorical";
        private const string TextColumnPurposeName = "text";
        private const string IgnoredColumnPurposeName = "ignored";
        private const string SamplingKeyColumnPurposeName = "sampling key";

        /// <summary>Validates the arguments of an experiment Execute(...) call.</summary>
        public static void ValidateExperimentExecuteArgs(IDataView trainData, ColumnInformation columnInformation,
            IDataView validationData)
        {
            ValidateTrainData(trainData);
            ValidateColumnInformation(trainData, columnInformation);
            ValidateValidationData(trainData, validationData);
        }

        /// <summary>Validates arguments to column inference keyed by full column information.</summary>
        public static void ValidateInferColumnsArgs(string path, ColumnInformation columnInformation)
        {
            ValidateColumnInformation(columnInformation);
            ValidatePath(path);
        }

        /// <summary>Validates arguments to column inference keyed by label column name.</summary>
        public static void ValidateInferColumnsArgs(string path, string labelColumn)
        {
            ValidateLabelColumn(labelColumn);
            ValidatePath(path);
        }

        /// <summary>Validates arguments to column inference with no label information.</summary>
        public static void ValidateInferColumnsArgs(string path)
        {
            ValidatePath(path);
        }

        /// <summary>Cross validation requires at least 2 folds.</summary>
        public static void ValidateNumberOfCVFoldsArg(uint numberOfCVFolds)
        {
            if (numberOfCVFolds <= 1)
            {
                throw new ArgumentException($"{nameof(numberOfCVFolds)} must be at least 2", nameof(numberOfCVFolds));
            }
        }

        private static void ValidateTrainData(IDataView trainData)
        {
            if (trainData == null)
            {
                throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
            }

            // If a pre-built Features column exists, its items must be Single,
            // since that is the only item type the trainers consume.
            var type = trainData.Schema.GetColumnOrNull(DefaultColumnNames.Features)?.Type.GetItemType();
            if (type != null && type != NumberDataViewType.Single)
            {
                throw new ArgumentException($"{DefaultColumnNames.Features} column must be of data type Single", nameof(trainData));
            }
        }

        private static void ValidateColumnInformation(IDataView trainData, ColumnInformation columnInformation)
        {
            ValidateColumnInformation(columnInformation);
            ValidateTrainDataColumn(trainData, columnInformation.LabelColumnName, LabelColumnPurposeName);
            ValidateTrainDataColumn(trainData, columnInformation.ExampleWeightColumnName, WeightColumnPurposeName);
            ValidateTrainDataColumn(trainData, columnInformation.SamplingKeyColumnName, SamplingKeyColumnPurposeName);
            ValidateTrainDataColumns(trainData, columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName,
                new DataViewType[] { NumberDataViewType.Single, TextDataViewType.Instance });
            ValidateTrainDataColumns(trainData, columnInformation.NumericColumnNames, NumericColumnPurposeName,
                new DataViewType[] { NumberDataViewType.Single, BooleanDataViewType.Instance });
            ValidateTrainDataColumns(trainData, columnInformation.TextColumnNames, TextColumnPurposeName,
                new DataViewType[] { TextDataViewType.Instance });
            ValidateTrainDataColumns(trainData, columnInformation.IgnoredColumnNames, IgnoredColumnPurposeName);
        }

        private static void ValidateColumnInformation(ColumnInformation columnInformation)
        {
            ValidateLabelColumn(columnInformation.LabelColumnName);

            ValidateColumnInfoEnumerationProperty(columnInformation.CategoricalColumnNames, CategoricalColumnPurposeName);
            ValidateColumnInfoEnumerationProperty(columnInformation.NumericColumnNames, NumericColumnPurposeName);
            ValidateColumnInfoEnumerationProperty(columnInformation.TextColumnNames, TextColumnPurposeName);
            ValidateColumnInfoEnumerationProperty(columnInformation.IgnoredColumnNames, IgnoredColumnPurposeName);

            // Keep a list of all columns, to detect duplicates. A column may only be
            // assigned a single purpose.
            var allColumns = new List<string>();
            allColumns.Add(columnInformation.LabelColumnName);
            if (columnInformation.ExampleWeightColumnName != null) { allColumns.Add(columnInformation.ExampleWeightColumnName); }
            if (columnInformation.CategoricalColumnNames != null) { allColumns.AddRange(columnInformation.CategoricalColumnNames); }
            if (columnInformation.NumericColumnNames != null) { allColumns.AddRange(columnInformation.NumericColumnNames); }
            if (columnInformation.TextColumnNames != null) { allColumns.AddRange(columnInformation.TextColumnNames); }
            if (columnInformation.IgnoredColumnNames != null) { allColumns.AddRange(columnInformation.IgnoredColumnNames); }

            var duplicateColName = FindFirstDuplicate(allColumns);
            if (duplicateColName != null)
            {
                throw new ArgumentException($"Duplicate column name {duplicateColName} is present in two or more distinct properties of provided column information", nameof(columnInformation));
            }
        }

        private static void ValidateColumnInfoEnumerationProperty(IEnumerable<string> columns, string columnPurpose)
        {
            if (columns?.Contains(null) == true)
            {
                throw new ArgumentException($"Null column string was specified as {columnPurpose} in column information");
            }
        }

        private static void ValidateLabelColumn(string labelColumn)
        {
            if (labelColumn == null)
            {
                throw new ArgumentException("Provided label column cannot be null");
            }
        }

        private static void ValidatePath(string path)
        {
            if (path == null)
            {
                throw new ArgumentNullException(nameof(path), "Provided path cannot be null");
            }

            var fileInfo = new FileInfo(path);

            if (!fileInfo.Exists)
            {
                throw new ArgumentException($"File '{path}' does not exist", nameof(path));
            }

            if (fileInfo.Length == 0)
            {
                throw new ArgumentException($"File at path '{path}' cannot be empty", nameof(path));
            }
        }

        private static void ValidateValidationData(IDataView trainData, IDataView validationData)
        {
            // Validation data is optional; when present its schema must match training data
            // column-for-column in both name and type.
            if (validationData == null)
            {
                return;
            }

            const string schemaMismatchError = "Training data and validation data schemas do not match.";

            if (trainData.Schema.Count != validationData.Schema.Count)
            {
                // BUGFIX: the original concatenation produced "columns,and" with no space.
                throw new ArgumentException($"{schemaMismatchError} Train data has '{trainData.Schema.Count}' columns, " +
                    $"and validation data has '{validationData.Schema.Count}' columns.", nameof(validationData));
            }

            foreach (var trainCol in trainData.Schema)
            {
                var validCol = validationData.Schema.GetColumnOrNull(trainCol.Name);
                if (validCol == null)
                {
                    // BUGFIX: corrected "exsits" typo in the error message.
                    throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' exists in train data, but not in validation data.", nameof(validationData));
                }

                if (trainCol.Type != validCol.Value.Type)
                {
                    throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' is of type {trainCol.Type} in train data, and type " +
                        $"{validCol.Value.Type} in validation data.", nameof(validationData));
                }
            }
        }

        private static void ValidateTrainDataColumns(IDataView trainData, IEnumerable<string> columnNames, string columnPurpose,
            IEnumerable<DataViewType> allowedTypes = null)
        {
            if (columnNames == null)
            {
                return;
            }

            foreach (var columnName in columnNames)
            {
                ValidateTrainDataColumn(trainData, columnName, columnPurpose, allowedTypes);
            }
        }

        private static void ValidateTrainDataColumn(IDataView trainData, string columnName, string columnPurpose, IEnumerable<DataViewType> allowedTypes = null)
        {
            if (columnName == null)
            {
                return;
            }

            var nullableColumn = trainData.Schema.GetColumnOrNull(columnName);
            if (nullableColumn == null)
            {
                throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' not found in training data.");
            }

            if (allowedTypes == null)
            {
                return;
            }
            var column = nullableColumn.Value;
            var itemType = column.Type.GetItemType();
            if (!allowedTypes.Contains(itemType))
            {
                if (allowedTypes.Count() == 1)
                {
                    throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
                        $"but only type {allowedTypes.First()} is allowed.");
                }
                else
                {
                    throw new ArgumentException($"Provided {columnPurpose} column '{columnName}' was of type {itemType}, " +
                        $"but only types {string.Join(", ", allowedTypes)} are allowed.");
                }
            }
        }

        /// <summary>Returns the first value that occurs more than once, or null if all values are distinct.</summary>
        private static string FindFirstDuplicate(IEnumerable<string> values)
        {
            var groups = values.GroupBy(v => v);
            return groups.FirstOrDefault(g => g.Count() > 1)?.Key;
        }
    }
}

// (src/mlnet/Assembly.cs)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Runtime.CompilerServices;

[assembly: InternalsVisibleTo("mlnet.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Auto;
using NLog;

namespace Microsoft.ML.CLI.AutoML
{
    // Routes AutoML SDK debug messages into the CLI's NLog pipeline.
    internal class AutoMLDebugLogger: IDebugLogger
    {
        // Shared singleton instance handed to experiment settings.
        public static AutoMLDebugLogger Instance = new AutoMLDebugLogger();

        private static Logger logger = LogManager.GetCurrentClassLogger();

        public void Log(LogSeverity severity, string message)
        {
            // NOTE(review): 'severity' is ignored — everything is emitted at Trace.
            // Confirm whether severity should be mapped onto NLog LogLevel.
            logger.Log(LogLevel.Trace, message);
        }
    }
}

// (src/mlnet/AutoML/AutoMLEngine.cs)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Auto;
using Microsoft.ML.CLI.AutoML;
using Microsoft.ML.CLI.Data;
using Microsoft.ML.CLI.ShellProgressBar;
using Microsoft.ML.CLI.Utilities;
using Microsoft.ML.Data;
using NLog;

namespace Microsoft.ML.CLI.CodeGenerator
{
    // Thin wrapper over the AutoML SDK: infers columns from the dataset and runs
    // the binary / multiclass / regression experiments configured by the CLI's
    // "new" command settings.
    internal class AutoMLEngine : IAutoMLEngine
    {
        private NewCommandSettings settings;
        private TaskKind taskKind;
        // Tri-state: null lets the SDK decide whether to cache before the trainer.
        private bool? enableCaching;
        private static Logger logger = LogManager.GetCurrentClassLogger();

        public AutoMLEngine(NewCommandSettings settings)
        {
            this.settings = settings;
            this.taskKind = Utils.GetTaskKind(settings.MlTask);
            this.enableCaching = Utils.GetCacheSettings(settings.Cache);
        }

        // Infers the dataset's columns, preferring the user-supplied label name
        // and falling back to the label column index otherwise.
        public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation)
        {
            // Check what overload method of InferColumns needs to be called.
            logger.Log(LogLevel.Trace, Strings.InferColumns);
            ColumnInferenceResults columnInference = null;
            var dataset = settings.Dataset.FullName;
            if (columnInformation.LabelColumnName != null)
            {
                columnInference = context.Auto().InferColumns(dataset, columnInformation, groupColumns: false);
            }
            else
            {
                // No label name provided: identify the label by its column index.
                columnInference = context.Auto().InferColumns(dataset, settings.LabelColumnIndex, hasHeader: settings.HasHeader, groupColumns: false);
            }

            return columnInference;
        }

        // NOTE(review): the generic type arguments on the three methods below were
        // lost in extraction and reconstructed as IEnumerable<RunResult<T>> —
        // verify against the Microsoft.ML.Auto version referenced by this project.

        IEnumerable<RunResult<BinaryClassificationMetrics>> IAutoMLEngine.ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar)
        {
            var progressReporter = new ProgressHandlers.BinaryClassificationHandler(optimizationMetric, progressBar);
            var result = context.Auto()
                .CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
                {
                    MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
                    CacheBeforeTrainer = this.enableCaching,
                    OptimizingMetric = optimizationMetric,
                    DebugLogger = AutoMLDebugLogger.Instance
                })
                .Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
            logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
            return result;
        }

        IEnumerable<RunResult<RegressionMetrics>> IAutoMLEngine.ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar)
        {
            var progressReporter = new ProgressHandlers.RegressionHandler(optimizationMetric, progressBar);
            var result = context.Auto()
                .CreateRegressionExperiment(new RegressionExperimentSettings()
                {
                    MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
                    OptimizingMetric = optimizationMetric,
                    CacheBeforeTrainer = this.enableCaching,
                    DebugLogger = AutoMLDebugLogger.Instance
                }).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
            logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
            return result;
        }

        IEnumerable<RunResult<MulticlassClassificationMetrics>> IAutoMLEngine.ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar)
        {
            var progressReporter = new ProgressHandlers.MulticlassClassificationHandler(optimizationMetric, progressBar);
            var result = context.Auto()
                .CreateMulticlassClassificationExperiment(new MulticlassExperimentSettings()
                {
                    MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
                    CacheBeforeTrainer = this.enableCaching,
                    OptimizingMetric = optimizationMetric,
                    DebugLogger = AutoMLDebugLogger.Instance
                }).Execute(trainData, validationData, columnInformation, progressHandler: progressReporter);
            logger.Log(LogLevel.Trace, Strings.RetrieveBestPipeline);
            return result;
        }

    }
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.ML.Auto;
using Microsoft.ML.CLI.ShellProgressBar;
using Microsoft.ML.Data;

namespace Microsoft.ML.CLI.CodeGenerator
{
    /// <summary>
    /// Abstraction over the AutoML SDK used by the CLI, so command handlers can be
    /// tested against a fake engine.
    /// NOTE(review): the generic return types below were lost in extraction and
    /// reconstructed as IEnumerable&lt;RunResult&lt;T&gt;&gt; — verify against the SDK version.
    /// </summary>
    internal interface IAutoMLEngine
    {
        /// <summary>Infers column types/purposes for the configured dataset.</summary>
        ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation);

        /// <summary>Runs a binary classification experiment and returns all run results.</summary>
        IEnumerable<RunResult<BinaryClassificationMetrics>> ExploreBinaryClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, BinaryClassificationMetric optimizationMetric, ProgressBar progressBar = null);

        /// <summary>Runs a multiclass classification experiment and returns all run results.</summary>
        IEnumerable<RunResult<MulticlassClassificationMetrics>> ExploreMultiClassificationModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, MulticlassClassificationMetric optimizationMetric, ProgressBar progressBar = null);

        /// <summary>Runs a regression experiment and returns all run results.</summary>
        IEnumerable<RunResult<RegressionMetrics>> ExploreRegressionModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation, RegressionMetric optimizationMetric, ProgressBar progressBar = null);

    }
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.ML.Auto;
using Microsoft.ML.CLI.Templates.Console;
using Microsoft.ML.CLI.Utilities;

namespace Microsoft.ML.CLI.CodeGenerator.CSharp
{
    /// <summary>
    /// Generates a C# solution (model project + console app project) from an AutoML
    /// pipeline and the column inference results the pipeline was trained with.
    /// NOTE(review): simple generic type arguments in this file were lost in
    /// extraction and reconstructed (List&lt;string&gt;, IEnumerable&lt;PipelineNode&gt;, ...).
    /// </summary>
    internal class CodeGenerator : IProjectGenerator
    {
        private readonly Pipeline pipeline;
        private readonly CodeGeneratorSettings settings;
        private readonly ColumnInferenceResults columnInferenceResult;

        internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
        {
            this.pipeline = pipeline;
            this.columnInferenceResult = columnInferenceResult;
            this.settings = settings;
        }

        /// <summary>
        /// Writes the generated model project, console app project and solution file
        /// under <c>settings.OutputBaseDir</c>.
        /// </summary>
        public void GenerateOutput()
        {
            // Get namespace and the C# type of the label column.
            var namespaceValue = Utils.Normalize(settings.OutputName);
            var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind;
            Type labelTypeCsharp = Utils.GetCSharpType(labelType);

            // Generate Model project contents.
            var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp);

            // Write model project files to disk.
            var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model");
            var dataModelsDir = Path.Combine(modelprojectDir, "DataModels");
            var modelProjectName = $"{settings.OutputName}.Model.csproj";

            Utils.WriteOutputToFiles(modelProjectContents.ObservationCSFileContent, "Observation.cs", dataModelsDir);
            Utils.WriteOutputToFiles(modelProjectContents.PredictionCSFileContent, "Prediction.cs", dataModelsDir);
            Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);

            // Generate ConsoleApp project contents.
            var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp);

            // Write console app project files to disk.
            var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp");
            var consoleAppProjectName = $"{settings.OutputName}.ConsoleApp.csproj";

            Utils.WriteOutputToFiles(consoleAppProjectContents.ConsoleAppProgramCSFileContent, "Program.cs", consoleAppProjectDir);
            Utils.WriteOutputToFiles(consoleAppProjectContents.modelBuilderCSFileContent, "ModelBuilder.cs", consoleAppProjectDir);
            Utils.WriteOutputToFiles(consoleAppProjectContents.ConsoleAppProjectFileContent, consoleAppProjectName, consoleAppProjectDir);

            // New solution file, then add both projects to it.
            Utils.CreateSolutionFile(settings.OutputName, settings.OutputBaseDir);

            var solutionPath = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.sln");
            Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
        }

        /// <summary>Generates Program.cs, the .csproj and ModelBuilder.cs for the console app.</summary>
        internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp)
        {
            var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
            predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);

            var predictProjectFileContent = GeneratePredictProjectFileContent(namespaceValue, true, true);

            var transformsAndTrainers = GenerateTransformsAndTrainers();
            var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
            modelBuilderCSFileContent = Utils.FormatCode(modelBuilderCSFileContent);

            return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
        }

        /// <summary>Generates Observation.cs, Prediction.cs and the .csproj for the model project.</summary>
        internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp)
        {
            var classLabels = this.GenerateClassLabels();
            var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
            observationCSFileContent = Utils.FormatCode(observationCSFileContent);
            var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
            predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
            var modelProjectFileContent = GenerateModelProjectFileContent();
            return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
        }

        /// <summary>
        /// Renders the pipeline as code: the trainer call, the pre/post-trainer
        /// transform calls, and the combined set of using statements they require.
        /// </summary>
        internal (string Usings, string TrainerMethod, List<string> PreTrainerTransforms, List<string> PostTrainerTransforms) GenerateTransformsAndTrainers()
        {
            var usings = new List<string>();

            // Pre-trainer transforms: the leading run of transform nodes.
            var nodes = pipeline.Nodes.TakeWhile(t => t.NodeType == PipelineNodeType.Transform);
            var preTrainerTransformsAndUsings = this.GenerateTransformsAndUsings(nodes);

            // Post-trainer transforms: transform nodes that follow the trainer.
            nodes = pipeline.Nodes.SkipWhile(t => t.NodeType == PipelineNodeType.Transform)
                .SkipWhile(t => t.NodeType == PipelineNodeType.Trainer) // skip the trainer
                .TakeWhile(t => t.NodeType == PipelineNodeType.Transform); // post trainer transforms
            var postTrainerTransformsAndUsings = this.GenerateTransformsAndUsings(nodes);

            // Get trainer code and its associated usings.
            (string trainerMethod, string[] trainerUsings) = this.GenerateTrainerAndUsings();
            if (trainerUsings != null)
            {
                usings.AddRange(trainerUsings);
            }

            // Get transforms code and the unique set of usings they require.
            var preTrainerTransforms = preTrainerTransformsAndUsings?.Select(t => t.Item1).ToList();
            var postTrainerTransforms = postTrainerTransformsAndUsings?.Select(t => t.Item1).ToList();
            usings.AddRange(preTrainerTransformsAndUsings.Where(t => t.Item2 != null).SelectMany(t => t.Item2));
            usings.AddRange(postTrainerTransformsAndUsings.Where(t => t.Item2 != null).SelectMany(t => t.Item2));
            usings = usings.Distinct().ToList();

            // Combine all using statements into actual text.
            var usingsBuilder = new StringBuilder();
            usings.ForEach(t =>
            {
                if (t != null)
                    usingsBuilder.Append(t);
            });

            return (usingsBuilder.ToString(), trainerMethod, preTrainerTransforms, postTrainerTransforms);
        }

        /// <summary>Renders each transform node as (code, usings[]) via its generator.</summary>
        internal IList<(string, string[])> GenerateTransformsAndUsings(IEnumerable<PipelineNode> nodes)
        {
            var results = new List<(string, string[])>();
            foreach (var node in nodes)
            {
                ITransformGenerator generator = TransformGeneratorFactory.GetInstance(node);
                results.Add((generator.GenerateTransformer(), generator.GenerateUsings()));
            }

            return results;
        }

        /// <summary>Renders the pipeline's single trainer node as (code, usings[]).</summary>
        internal (string, string[]) GenerateTrainerAndUsings()
        {
            if (pipeline == null)
                throw new ArgumentNullException(nameof(pipeline));
            // BUGFIX: the original used Where(...).First(), which throws
            // InvalidOperationException before the intended null check below could
            // ever run; FirstOrDefault makes the "trainer not found" error reachable.
            var node = pipeline.Nodes.FirstOrDefault(t => t.NodeType == PipelineNodeType.Trainer);
            if (node == null)
                throw new ArgumentException($"The trainer was not found.");

            ITrainerGenerator generator = TrainerGeneratorFactory.GetInstance(node);
            var trainerString = generator.GenerateTrainer();
            var trainerUsings = generator.GenerateUsings();
            return (trainerString, trainerUsings);
        }

        /// <summary>
        /// Builds the attribute + property declarations for the Observation class,
        /// one [ColumnName/LoadColumn(/VectorType)] attribute line per input column.
        /// </summary>
        internal IList<string> GenerateClassLabels()
        {
            IList<string> result = new List<string>();
            foreach (var column in columnInferenceResult.TextLoaderOptions.Columns)
            {
                StringBuilder sb = new StringBuilder();
                // A source range wider than one column becomes a vector (array) property.
                int range = (column.Source[0].Max - column.Source[0].Min).Value;
                sb.Append(Symbols.PublicSymbol);
                sb.Append(Symbols.Space);
                switch (column.DataKind)
                {
                    case Microsoft.ML.Data.DataKind.String:
                        sb.Append(Symbols.StringSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.Boolean:
                        sb.Append(Symbols.BoolSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.Single:
                        sb.Append(Symbols.FloatSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.Double:
                        sb.Append(Symbols.DoubleSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.Int32:
                        sb.Append(Symbols.IntSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.UInt32:
                        sb.Append(Symbols.UIntSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.Int64:
                        sb.Append(Symbols.LongSymbol);
                        break;
                    case Microsoft.ML.Data.DataKind.UInt64:
                        sb.Append(Symbols.UlongSymbol);
                        break;
                    default:
                        throw new ArgumentException($"The data type '{column.DataKind}' is not handled currently.");

                }

                if (range > 0)
                {
                    // BUGFIX: a comma was missing between LoadColumn(...) and
                    // VectorType(...), which produced uncompilable generated code.
                    result.Add($"[ColumnName(\"{column.Name}\"), LoadColumn({column.Source[0].Min}, {column.Source[0].Max}), VectorType({(range + 1)})]");
                    sb.Append("[]");
                }
                else
                {
                    result.Add($"[ColumnName(\"{column.Name}\"), LoadColumn({column.Source[0].Min})]");
                }
                sb.Append(" ");
                sb.Append(Utils.Normalize(column.Name));
                sb.Append("{get; set;}");
                result.Add(sb.ToString());
                result.Add("\r\n");
            }
            return result;
        }

        #region Model project
        // Renders the model project's .csproj from its T4 template.
        private static string GenerateModelProjectFileContent()
        {
            ModelProject modelProject = new ModelProject();
            return modelProject.TransformText();
        }

        // Renders Prediction.cs from its T4 template.
        private string GeneratePredictionCSFileContent(string predictionLabelType, string namespaceValue)
        {
            PredictionClass predictionClass = new PredictionClass() { TaskType = settings.MlTask.ToString(), PredictionLabelType = predictionLabelType, Namespace = namespaceValue };
            return predictionClass.TransformText();
        }

        // Renders Observation.cs from its T4 template.
        private string GenerateObservationCSFileContent(string namespaceValue, IList<string> classLabels)
        {
            ObservationClass observationClass = new ObservationClass() { Namespace = namespaceValue, ClassLabels = classLabels };
            return observationClass.TransformText();
        }
        #endregion

        #region Predict Project
        // Renders the console app's .csproj from its T4 template.
        // (Renamed from the original typo "GeneratPredictProjectFileContent";
        // private, all call sites are in this class.)
        private static string GeneratePredictProjectFileContent(string namespaceValue, bool includeMklComponentsPackage, bool includeLightGBMPackage)
        {
            var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGBMPackage };
            return predictProjectFileContent.TransformText();
        }

        // Renders Program.cs from its T4 template.
        private string GeneratePredictProgramCSFileContent(string namespaceValue)
        {
            PredictProgram predictProgram = new PredictProgram()
            {
                TaskType = settings.MlTask.ToString(),
                LabelName = settings.LabelName,
                Namespace = namespaceValue,
                TestDataPath = settings.TestDataset,
                TrainDataPath = settings.TrainDataset,
                HasHeader = columnInferenceResult.TextLoaderOptions.HasHeader,
                Separator = columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
                AllowQuoting = columnInferenceResult.TextLoaderOptions.AllowQuoting,
                AllowSparse = columnInferenceResult.TextLoaderOptions.AllowSparse,
            };
            return predictProgram.TransformText();
        }

        // Renders ModelBuilder.cs from its T4 template.
        private string GenerateModelBuilderCSFileContent(string usings,
            string trainerMethod,
            List<string> preTrainerTransforms,
            List<string> postTrainerTransforms,
            string namespaceValue,
            bool cacheBeforeTrainer,
            string predictionLabelType)
        {
            var modelBuilder = new ModelBuilder()
            {
                PreTrainerTransforms = preTrainerTransforms,
                PostTrainerTransforms = postTrainerTransforms,
                HasHeader = columnInferenceResult.TextLoaderOptions.HasHeader,
                Separator = columnInferenceResult.TextLoaderOptions.Separators.FirstOrDefault(),
                AllowQuoting = columnInferenceResult.TextLoaderOptions.AllowQuoting,
                AllowSparse = columnInferenceResult.TextLoaderOptions.AllowSparse,
                Trainer = trainerMethod,
                GeneratedUsings = usings,
                Path = settings.TrainDataset,
                TestPath = settings.TestDataset,
                TaskType = settings.MlTask.ToString(),
                Namespace = namespaceValue,
                LabelName = settings.LabelName,
                CacheBeforeTrainer = cacheBeforeTrainer,
            };

            return modelBuilder.TransformText();
        }
        #endregion

    }
}

// (src/mlnet/CodeGenerator/CSharp/CodeGeneratorSettings.cs)
using Microsoft.ML.Auto;

namespace Microsoft.ML.CLI.CodeGenerator.CSharp
{
    /// <summary>Settings bag that drives <see cref="CodeGenerator"/> output.</summary>
    internal class CodeGeneratorSettings
    {
        internal string LabelName { get; set; }

        internal string ModelPath { get; set; }

        internal string OutputName { get; set; }

        internal string OutputBaseDir { get; set; }

        internal string TrainDataset { get; set; }

        internal string TestDataset { get; set; }

        internal TaskKind MlTask { get; set; }

    }
}
namespace Microsoft.ML.CLI.CodeGenerator.CSharp
{
    // String fragments used when emitting C# property declarations in generated code.
    internal static class Symbols
    {
        internal static readonly string Space = " ";
        internal static readonly string StringSymbol = "string";
        internal static readonly string PublicSymbol = "public";
        internal static readonly string FloatSymbol = "float";
        internal static readonly string IntSymbol = "int";
        internal static readonly string UIntSymbol = "uint";
        internal static readonly string LongSymbol = "long";
        internal static readonly string UlongSymbol = "ulong";
        internal static readonly string BoolSymbol = "bool";
        internal static readonly string DoubleSymbol = "double";

    }
}

// (src/mlnet/CodeGenerator/CSharp/TrainerGeneratorBase.cs)
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.ML.Auto;

namespace Microsoft.ML.CLI.CodeGenerator.CSharp
{
    /// <summary>
    /// Supports generation of code for trainers (Binary,Multi,Regression)
    /// Ova is an exception though. Need to figure out how to tackle that.
    /// </summary>
    internal abstract class TrainerGeneratorBase : ITrainerGenerator
    {
        private PipelineNode node;
        // Maps rendered argument name -> rendered value text.
        // NOTE(review): generic type arguments in this class were lost in
        // extraction and reconstructed — verify against the original file.
        private Dictionary<string, object> arguments = new Dictionary<string, object>();
        // True when the node carries a property with no named-parameter mapping,
        // which forces use of the trainer's advanced Options object.
        private bool hasAdvancedSettings = false;
        // NOTE(review): assigned in Initialize but never read afterwards —
        // GenerateTrainer passes "=" / ":" literals directly. Candidate dead field.
        private string seperator = null;

        //abstract properties
        internal abstract string OptionsName { get; }
        internal abstract string MethodName { get; }
        internal abstract IDictionary<string, string> NamedParameters { get; }
        internal abstract string[] Usings { get; }

        /// <summary>
        /// Generates an instance of TrainerGenerator
        /// </summary>
        protected TrainerGeneratorBase(PipelineNode node)
        {
            Initialize(node);
        }

        // Converts the node's property values into argument text, mapping property
        // names through NamedParameters unless advanced (Options-object) settings
        // are required.
        private void Initialize(PipelineNode node)
        {
            this.node = node;
            if (NamedParameters != null)
            {
                hasAdvancedSettings = node.Properties.Keys.Any(t => !NamedParameters.ContainsKey(t));
            }
            seperator = hasAdvancedSettings ? "=" : ":";
            if (!node.Properties.ContainsKey("LabelColumnName"))
            {
                node.Properties.Add("LabelColumnName", "Label");
            }
            // NOTE(review): unlike LabelColumnName above, this Add is unguarded and
            // throws if the key already exists — confirm FeatureColumnName can
            // never be preset on the node.
            node.Properties.Add("FeatureColumnName", "Features");

            foreach (var kv in node.Properties)
            {
                object value = null;

                //For Nullable values.
                if (kv.Value == null)
                    continue;
                Type type = kv.Value.GetType();
                // Render each supported property type as C# literal text.
                if (type == typeof(bool))
                {
                    //True to true
                    value = ((bool)kv.Value).ToString().ToLowerInvariant();
                }
                if (type == typeof(float))
                {
                    //0.0 to 0.0f
                    // NOTE(review): culture-sensitive ToString — could emit ',' as
                    // the decimal separator on some locales; confirm.
                    value = ((float)kv.Value).ToString() + "f";
                }

                if (type == typeof(int) || type == typeof(double) || type == typeof(long))
                {
                    value = (kv.Value).ToString();
                }

                if (type == typeof(string))
                {
                    var val = kv.Value.ToString();
                    if (val == "")
                        continue; // This is temporary fix and needs to be fixed in AutoML SDK

                    // string to "string"
                    value = "\"" + val + "\"";
                }

                if (type == typeof(CustomProperty))
                {
                    value = kv.Value;
                }
                //more special cases to handle

                // Advanced settings use the raw property name (Options member names);
                // otherwise map to the trainer method's named parameter.
                if (NamedParameters != null)
                {
                    arguments.Add(hasAdvancedSettings ? kv.Key : NamedParameters[kv.Key], value);
                }
                else
                {
                    arguments.Add(kv.Key, value);
                }

            }
        }

        // Renders "new TOptions(){ a=1, b=2 }" style initializer text.
        internal static string BuildComplexParameter(string paramName, IDictionary<string, object> arguments, string seperator)
        {
            StringBuilder sb = new StringBuilder();
            sb.Append("new ");
            sb.Append(paramName);
            sb.Append("(){");
            sb.Append(AppendArguments(arguments, seperator));
            sb.Append("}");
            return sb.ToString();
        }

        // Renders "key<sep>value, key<sep>value" text; nested CustomProperty values
        // recurse through BuildComplexParameter.
        internal static string AppendArguments(IDictionary<string, object> arguments, string seperator)
        {
            if (arguments.Count == 0)
                return string.Empty;

            StringBuilder sb = new StringBuilder();
            foreach (var kv in arguments)
            {
                sb.Append(kv.Key);
                sb.Append(seperator);
                if (kv.Value.GetType() == typeof(CustomProperty))
                    sb.Append(BuildComplexParameter(((CustomProperty)kv.Value).Name, ((CustomProperty)kv.Value).Properties, "="));
                else
                    sb.Append(kv.Value.ToString());
                sb.Append(",");
            }
            sb.Remove(sb.Length - 1, 1); //remove the last ,
            return sb.ToString();
        }

        // Emits "MethodName(arg:val, ...)" or "MethodName(new Options(){ ... })"
        // depending on whether advanced settings are needed.
        public virtual string GenerateTrainer()
        {
            StringBuilder sb = new StringBuilder();
            sb.Append(MethodName);
            sb.Append("(");
            if (hasAdvancedSettings)
            {
                var paramString = BuildComplexParameter(OptionsName, arguments, "=");
                sb.Append(paramString);
            }
            else
            {
                sb.Append(AppendArguments(arguments, ":"));
            }
            sb.Append(")");
            return sb.ToString();
        }

        // Extra usings are only needed when the Options type is referenced.
        public virtual string[] GenerateUsings()
        {
            if (hasAdvancedSettings)
                return Usings;

            return null;
        }
    }
}
+ +using System; +using System.Linq; +using Microsoft.ML.Auto; +using static Microsoft.ML.CLI.CodeGenerator.CSharp.TrainerGenerators; + +namespace Microsoft.ML.CLI.CodeGenerator.CSharp +{ + internal interface ITrainerGenerator + { + string GenerateTrainer(); + + string[] GenerateUsings(); + } + + internal static class TrainerGeneratorFactory + { + internal static ITrainerGenerator GetInstance(PipelineNode node) + { + if (Enum.TryParse(node.Name, out TrainerName trainer)) + { + switch (trainer) + { + case TrainerName.LightGbmBinary: + return new LightGbmBinary(node); + case TrainerName.LightGbmMulti: + return new LightGbmMulti(node); + case TrainerName.LightGbmRegression: + return new LightGbmRegression(node); + case TrainerName.AveragedPerceptronBinary: + return new AveragedPerceptron(node); + case TrainerName.FastForestBinary: + return new FastForestClassification(node); + case TrainerName.FastForestRegression: + return new FastForestRegression(node); + case TrainerName.FastTreeBinary: + return new FastTreeClassification(node); + case TrainerName.FastTreeRegression: + return new FastTreeRegression(node); + case TrainerName.FastTreeTweedieRegression: + return new FastTreeTweedie(node); + case TrainerName.LinearSvmBinary: + return new LinearSvm(node); + case TrainerName.LbfgsLogisticRegressionBinary: + return new LbfgsLogisticRegressionBinary(node); + case TrainerName.LbfgsMaximumEntropyMulti: + return new LbfgsMaximumEntropyMulti(node); + case TrainerName.OnlineGradientDescentRegression: + return new OnlineGradientDescentRegression(node); + case TrainerName.OlsRegression: + return new OlsRegression(node); + case TrainerName.LbfgsPoissonRegression: + return new LbfgsPoissonRegression(node); + case TrainerName.SdcaLogisticRegressionBinary: + return new StochasticDualCoordinateAscentBinary(node); + case TrainerName.SdcaMaximumEntropyMulti: + return new StochasticDualCoordinateAscentMulti(node); + case TrainerName.SdcaRegression: + return new 
StochasticDualCoordinateAscentRegression(node); + case TrainerName.SgdCalibratedBinary: + return new SgdCalibratedBinary(node); + case TrainerName.SymbolicSgdLogisticRegressionBinary: + return new SymbolicSgdLogisticRegressionBinary(node); + case TrainerName.Ova: + return new OneVersusAll(node); + default: + throw new ArgumentException($"The trainer '{trainer}' is not handled currently."); + } + } + throw new ArgumentException($"The trainer '{node.Name}' is not handled currently."); + } + } +} diff --git a/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs b/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs new file mode 100644 index 0000000000..ee606a6cde --- /dev/null +++ b/src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs @@ -0,0 +1,562 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Text; +using Microsoft.ML.Auto; + +namespace Microsoft.ML.CLI.CodeGenerator.CSharp +{ + internal static class TrainerGenerators + { + internal abstract class LightGbmBase : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "LightGbm"; + + //The named parameters to the trainer. 
+ internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"NumberOfLeaves","numberOfLeaves" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"MinimumExampleCountPerLeaf","minimumExampleCountPerLeaf" }, + {"LearningRate","learningRate" }, + {"NumberOfIterations","numberOfIterations" }, + {"ExampleWeightColumnName","exampleWeightColumnName" } + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers.LightGbm;\r\n" }; + + public LightGbmBase(PipelineNode node) : base(node) + { + } + } + + internal class LightGbmBinary : LightGbmBase + { + internal override string OptionsName => "LightGbmBinaryTrainer.Options"; + + public LightGbmBinary(PipelineNode node) : base(node) + { + } + } + + internal class LightGbmMulti : LightGbmBase + { + internal override string OptionsName => "LightGbmMulticlassTrainer.Options"; + + public LightGbmMulti(PipelineNode node) : base(node) + { + } + } + + internal class LightGbmRegression : LightGbmBase + { + internal override string OptionsName => "LightGbmRegressionTrainer.Options"; + + public LightGbmRegression(PipelineNode node) : base(node) + { + } + } + + internal class AveragedPerceptron : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "AveragedPerceptron"; + + //ClassName of the options to trainer + internal override string OptionsName => "AveragedPerceptronTrainer.Options"; + + //The named parameters to the trainer. 
+ internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"LossFunction","lossFunction" }, + {"LearningRate","learningRate" }, + {"DecreaseLearningRate","decreaseLearningRate" }, + {"L2Regularization","l2Regularization" }, + {"NumberOfIterations","numberOfIterations" } + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n " }; + + public AveragedPerceptron(PipelineNode node) : base(node) + { + } + } + + #region FastTree + internal abstract class FastTreeBase : TrainerGeneratorBase + { + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers.FastTree;\r\n" }; + + //The named parameters to the trainer. + internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName","exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"LearningRate","learningRate" }, + {"NumberOfLeaves","numberOfLeaves" }, + {"NumberOfTrees","numberOfTrees" }, + {"MinimumExampleCountPerLeaf","minimumExampleCountPerLeaf" }, + }; + } + } + + public FastTreeBase(PipelineNode node) : base(node) + { + } + } + + internal class FastForestClassification : FastTreeBase + { + //ClassName of the trainer + internal override string MethodName => "FastForest"; + + //ClassName of the options to trainer + internal override string OptionsName => "FastForestClassification.Options"; + + public FastForestClassification(PipelineNode node) : base(node) + { + } + } + + internal class FastForestRegression : FastTreeBase + { + //ClassName of the trainer + internal override string MethodName => "FastForest"; + + //ClassName of the options to trainer + internal override string OptionsName => "FastForestRegression.Options"; + + public FastForestRegression(PipelineNode node) : base(node) + { + } + } + + 
internal class FastTreeClassification : FastTreeBase + { + //ClassName of the trainer + internal override string MethodName => "FastTree"; + + //ClassName of the options to trainer + internal override string OptionsName => "FastTreeBinaryTrainer.Options"; + + public FastTreeClassification(PipelineNode node) : base(node) + { + } + } + + internal class FastTreeRegression : FastTreeBase + { + //ClassName of the trainer + internal override string MethodName => "FastTree"; + + //ClassName of the options to trainer + internal override string OptionsName => "FastTreeRegressionTrainer.Options"; + + public FastTreeRegression(PipelineNode node) : base(node) + { + } + } + + internal class FastTreeTweedie : FastTreeBase + { + //ClassName of the trainer + internal override string MethodName => "FastTreeTweedie"; + + //ClassName of the options to trainer + internal override string OptionsName => "FastTreeTweedieTrainer.Options"; + + public FastTreeTweedie(PipelineNode node) : base(node) + { + } + } + #endregion + + internal class LinearSvm : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "LinearSvm"; + + //ClassName of the options to trainer + internal override string OptionsName => "LinearSvmTrainer.Options"; + + //The named parameters to the trainer. + internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName", "exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"NumberOfIterations","numIterations" }, + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n " }; + + public LinearSvm(PipelineNode node) : base(node) + { + } + } + + #region Logistic Regression + + internal abstract class LbfgsLogisticRegressionBase : TrainerGeneratorBase + { + //The named parameters to the trainer. 
+ internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName","exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"L1Regularization","l1Regularization" }, + {"L2Regularization","l2Regularization" }, + {"OptmizationTolerance","optimizationTolerance" }, + {"HistorySize","historySize" }, + {"EnforceNonNegativity","enforceNonNegativity" }, + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public LbfgsLogisticRegressionBase(PipelineNode node) : base(node) + { + } + } + internal class LbfgsLogisticRegressionBinary : LbfgsLogisticRegressionBase + { + internal override string MethodName => "LbfgsLogisticRegression"; + + //ClassName of the options to trainer + internal override string OptionsName => "LbfgsLogisticRegressionBinaryTrainer.Options"; + + public LbfgsLogisticRegressionBinary(PipelineNode node) : base(node) + { + } + } + + internal class LbfgsMaximumEntropyMulti : LbfgsLogisticRegressionBase + { + internal override string MethodName => "LbfgsMaximumEntropy"; + + //ClassName of the options to trainer + internal override string OptionsName => "LbfgsMaximumEntropyMulticlassTrainer.Options"; + + public LbfgsMaximumEntropyMulti(PipelineNode node) : base(node) + { + } + } + #endregion + + internal class OnlineGradientDescentRegression : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "OnlineGradientDescent"; + + //ClassName of the options to trainer + internal override string OptionsName => "OnlineGradientDescentTrainer.Options"; + + //The named parameters to the trainer. 
+ internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"LearningRate" , "learningRate" }, + {"DecreaseLearningRate" , "decreaseLearningRate" }, + {"L2Regularization" , "l2Regularization" }, + {"NumberOfIterations" , "numberOfIterations" }, + {"LabelColumnName" , "labelColumnName" }, + {"FeatureColumnName" , "featureColumnName" }, + {"LossFunction" ,"lossFunction" }, + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public OnlineGradientDescentRegression(PipelineNode node) : base(node) + { + } + } + + internal class OlsRegression : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "Ols"; + + //ClassName of the options to trainer + internal override string OptionsName => "OlsTrainer.Options"; + + //The named parameters to the trainer. + internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName","exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public OlsRegression(PipelineNode node) : base(node) + { + } + } + + internal class LbfgsPoissonRegression : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "LbfgsPoissonRegression"; + + //ClassName of the options to trainer + internal override string OptionsName => "LbfgsPoissonRegressionTrainer.Options"; + + //The named parameters to the trainer. 
+ internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName","exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"L1Regularization","l1Regularization" }, + {"L2Regularization","l2Regularization" }, + {"OptmizationTolerance","optimizationTolerance" }, + {"HistorySize","historySize" }, + {"EnforceNonNegativity","enforceNonNegativity" }, + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public LbfgsPoissonRegression(PipelineNode node) : base(node) + { + } + } + + #region SDCA + internal abstract class StochasticDualCoordinateAscentBase : TrainerGeneratorBase + { + //The named parameters to the trainer. + internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName","exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"Loss","loss" }, + {"L2Regularization","l2Regularization" }, + {"L1Regularization","l1Regularization" }, + {"MaximumNumberOfIterations","maximumNumberOfIterations" } + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public StochasticDualCoordinateAscentBase(PipelineNode node) : base(node) + { + } + } + + internal class StochasticDualCoordinateAscentBinary : StochasticDualCoordinateAscentBase + { + internal override string MethodName => "SdcaLogisticRegression"; + + //ClassName of the options to trainer + internal override string OptionsName => "SdcaLogisticRegressionBinaryTrainer.Options"; + + public StochasticDualCoordinateAscentBinary(PipelineNode node) : base(node) + { + } + } + + internal class StochasticDualCoordinateAscentMulti : StochasticDualCoordinateAscentBase + { + internal override string MethodName => "SdcaMaximumEntropy"; + + //ClassName of the options to trainer + 
internal override string OptionsName => "SdcaMaximumEntropyMulticlassTrainer.Options"; + + public StochasticDualCoordinateAscentMulti(PipelineNode node) : base(node) + { + } + } + + internal class StochasticDualCoordinateAscentRegression : StochasticDualCoordinateAscentBase + { + internal override string MethodName => "Sdca"; + + //ClassName of the options to trainer + internal override string OptionsName => "SdcaRegressionTrainer.Options"; + + public StochasticDualCoordinateAscentRegression(PipelineNode node) : base(node) + { + } + } + #endregion + + internal class SgdCalibratedBinary : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "SgdCalibrated"; + + //ClassName of the options to trainer + internal override string OptionsName => "SgdCalibratedTrainer.Options"; + + //The named parameters to the trainer. + internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"ExampleWeightColumnName","exampleWeightColumnName" }, + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"NumberOfIterations","numberOfIterations" }, + {"LearningRate","learningRate" }, + {"L2Regularization","l2Regularization" } + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public SgdCalibratedBinary(PipelineNode node) : base(node) + { + } + } + + internal class SymbolicSgdLogisticRegressionBinary : TrainerGeneratorBase + { + //ClassName of the trainer + internal override string MethodName => "SymbolicSgdLogisticRegression"; + + //ClassName of the options to trainer + internal override string OptionsName => "SymbolicSgdLogisticRegressionBinaryTrainer.Options"; + + //The named parameters to the trainer. 
+ internal override IDictionary NamedParameters + { + get + { + return + new Dictionary() + { + {"LabelColumnName","labelColumnName" }, + {"FeatureColumnName","featureColumnName" }, + {"NumberOfIterations","numberOfIterations" } + }; + } + } + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public SymbolicSgdLogisticRegressionBinary(PipelineNode node) : base(node) + { + + } + } + + internal class OneVersusAll : TrainerGeneratorBase + { + private PipelineNode node; + private string[] binaryTrainerUsings = null; + + //ClassName of the trainer + internal override string MethodName => "OneVersusAll"; + + //ClassName of the options to trainer + internal override string OptionsName => null; + + //The named parameters to the trainer. + internal override IDictionary NamedParameters => null; + + internal override string[] Usings => new string[] { "using Microsoft.ML.Trainers;\r\n" }; + + public OneVersusAll(PipelineNode node) : base(node) + { + this.node = node; + } + + public override string GenerateTrainer() + { + StringBuilder sb = new StringBuilder(); + sb.Append(MethodName); + sb.Append("("); + sb.Append("mlContext.BinaryClassification.Trainers."); // This is dependent on the name of the MLContext object in template. 
+ var trainerGenerator = TrainerGeneratorFactory.GetInstance((PipelineNode)this.node.Properties["BinaryTrainer"]); + binaryTrainerUsings = trainerGenerator.GenerateUsings(); + sb.Append(trainerGenerator.GenerateTrainer()); + sb.Append(","); + sb.Append("labelColumnName:"); + sb.Append("\""); + sb.Append(node.Properties["LabelColumnName"]); + sb.Append("\""); + sb.Append(")"); + return sb.ToString(); + } + + public override string[] GenerateUsings() + { + return binaryTrainerUsings; + } + + } + + } +} diff --git a/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs new file mode 100644 index 0000000000..eaeae72b9b --- /dev/null +++ b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorBase.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
 + /// Supports generation of code for transforms + /// Each concrete generator emits a single transform method call for the pipeline. + /// + internal abstract class TransformGeneratorBase : ITransformGenerator + { + //abstract properties + internal abstract string MethodName { get; } + + internal virtual string[] Usings => null; + + protected string[] inputColumns; + + protected string[] outputColumns; + + /// + /// Generates an instance of TransformGenerator + /// + /// + protected TransformGeneratorBase(PipelineNode node) + { + Initialize(node); + } + + private void Initialize(PipelineNode node) + { + inputColumns = new string[node.InColumns.Length]; + outputColumns = new string[node.OutColumns.Length]; + int i = 0; + foreach (var column in node.InColumns) + { + inputColumns[i++] = "\"" + column + "\""; + } + i = 0; + foreach (var column in node.OutColumns) + { + outputColumns[i++] = "\"" + column + "\""; + } + + } + + public abstract string GenerateTransformer(); + + public string[] GenerateUsings() + { + return Usings; + } + } +} diff --git a/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs new file mode 100644 index 0000000000..70500b091b --- /dev/null +++ b/src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using Microsoft.ML.Auto; + + +namespace Microsoft.ML.CLI.CodeGenerator.CSharp +{ + internal interface ITransformGenerator + { + string GenerateTransformer(); + + string[] GenerateUsings(); + } + + internal static class TransformGeneratorFactory + { + internal static ITransformGenerator GetInstance(PipelineNode node) + { + ITransformGenerator result = null; + if (Enum.TryParse(node.Name, out EstimatorName trainer)) + { + switch (trainer) + { + case EstimatorName.Normalizing: + result = new Normalizer(node); + break; + case EstimatorName.OneHotEncoding: + result = new OneHotEncoding(node); + break; + case EstimatorName.ColumnConcatenating: + result = new ColumnConcat(node); + break; + case EstimatorName.ColumnCopying: + result = new ColumnCopying(node); + break; + case EstimatorName.KeyToValueMapping: + result = new KeyToValueMapping(node); + break; + case EstimatorName.MissingValueIndicating: + result = new MissingValueIndicator(node); + break; + case EstimatorName.MissingValueReplacing: + result = new MissingValueReplacer(node); + break; + case EstimatorName.OneHotHashEncoding: + result = new OneHotHashEncoding(node); + break; + case EstimatorName.TextFeaturizing: + result = new TextFeaturizing(node); + break; + case EstimatorName.TypeConverting: + result = new TypeConverting(node); + break; + case EstimatorName.ValueToKeyMapping: + result = new ValueToKeyMapping(node); + break; + default: + return null; + + } + } + return result; + } + } +} diff --git a/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs b/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs new file mode 100644 index 0000000000..6f94d3e080 --- /dev/null +++ b/src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs @@ -0,0 +1,333 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Linq; +using System.Text; +using Microsoft.ML.Auto; + +namespace Microsoft.ML.CLI.CodeGenerator.CSharp +{ + internal class Normalizer : TransformGeneratorBase + { + public Normalizer(PipelineNode node) : base(node) + { + } + + internal override string MethodName => "NormalizeMinMax"; + + public override string GenerateTransformer() + { + StringBuilder sb = new StringBuilder(); + string inputColumn = inputColumns.Count() > 0 ? inputColumns[0] : "\"Features\""; + string outputColumn = outputColumns.Count() > 0 ? outputColumns[0] : throw new Exception($"output columns for the suggested transform: {MethodName} are null"); + sb.Append(MethodName); + sb.Append("("); + sb.Append(outputColumn); + sb.Append(","); + sb.Append(inputColumn); + sb.Append(")"); + return sb.ToString(); + } + } + + internal class OneHotEncoding : TransformGeneratorBase + { + public OneHotEncoding(PipelineNode node) : base(node) + { + } + + internal override string MethodName => "Categorical.OneHotEncoding"; + + private string ArgumentsName = "InputOutputColumnPair"; + + public override string GenerateTransformer() + { + StringBuilder sb = new StringBuilder(); + sb.Append(MethodName); + sb.Append("("); + sb.Append("new []{"); + for (int i = 0; i < inputColumns.Length; i++) + { + sb.Append("new "); + sb.Append(ArgumentsName); + sb.Append("("); + sb.Append(outputColumns[i]); + sb.Append(","); + sb.Append(inputColumns[i]); + sb.Append(")"); + sb.Append(","); + } + sb.Remove(sb.Length - 1, 1); // remove extra , + + sb.Append("}"); + sb.Append(")"); + return sb.ToString(); + } + } + + internal class ColumnConcat : TransformGeneratorBase + { + public ColumnConcat(PipelineNode node) : base(node) + { + } + + internal override string MethodName => "Concatenate"; + + public override string GenerateTransformer() + { + StringBuilder sb = new StringBuilder(); + string inputColumn = inputColumns.Count() > 0 ? 
inputColumns[0] : "\"Features\""; + string outputColumn = outputColumns.Count() > 0 ? outputColumns[0] : throw new Exception($"output columns for the suggested transform: {MethodName} are null"); + sb.Append(MethodName); + sb.Append("("); + sb.Append(outputColumn); + sb.Append(","); + sb.Append("new []{"); + foreach (var col in inputColumns) + { + sb.Append(col); + sb.Append(","); + } + sb.Remove(sb.Length - 1, 1); + sb.Append("}"); + sb.Append(")"); + return sb.ToString(); + } + } + + internal class ColumnCopying : TransformGeneratorBase + { + public ColumnCopying(PipelineNode node) : base(node) + { + } + + internal override string MethodName => "CopyColumns"; + + public override string GenerateTransformer() + { + StringBuilder sb = new StringBuilder(); + string inputColumn = inputColumns.Count() > 0 ? inputColumns[0] : "\"Features\""; + string outputColumn = outputColumns.Count() > 0 ? outputColumns[0] : throw new Exception($"output columns for the suggested transform: {MethodName} are null"); + sb.Append(MethodName); + sb.Append("("); + sb.Append(outputColumn); + sb.Append(","); + sb.Append(inputColumn); + sb.Append(")"); + return sb.ToString(); + } + } + + internal class KeyToValueMapping : TransformGeneratorBase + { + public KeyToValueMapping(PipelineNode node) : base(node) + { + } + + internal override string MethodName => "Conversion.MapKeyToValue"; + + public override string GenerateTransformer() + { + StringBuilder sb = new StringBuilder(); + string inputColumn = inputColumns.Count() > 0 ? inputColumns[0] : "\"Features\""; + string outputColumn = outputColumns.Count() > 0 ? 
// Code generators for AutoML-suggested data transforms (reconstructed from the
// collapsed patch text of src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs).

/// <summary>
/// Shared emitter for transforms whose generated call takes an array of
/// InputOutputColumnPair arguments, e.g.
/// <c>Method(new []{new InputOutputColumnPair(out0,in0), ...})</c>.
/// Extracted because four generators previously duplicated this loop verbatim.
/// </summary>
internal static class InputOutputColumnPairCodeGenerator
{
    /// <summary>Builds the call snippet for <paramref name="methodName"/>.</summary>
    /// <param name="methodName">Name of the MLContext transform method to emit.</param>
    /// <param name="argumentsName">Type name used for each generated pair argument.</param>
    /// <param name="inputColumns">Quoted input column names.</param>
    /// <param name="outputColumns">Quoted output column names, parallel to <paramref name="inputColumns"/>.</param>
    /// <exception cref="InvalidOperationException">Thrown when no output columns were suggested.</exception>
    internal static string Emit(string methodName, string argumentsName, string[] inputColumns, string[] outputColumns)
    {
        // Fail fast when the suggested transform carries no output columns.
        // (MissingValueIndicator previously performed this guard via a dead
        // local; it is now shared by all pair-based generators, which also
        // protects the sb.Remove call below from underflowing.)
        if (outputColumns == null || outputColumns.Length == 0)
        {
            throw new InvalidOperationException($"output columns for the suggested transform: {methodName} are null");
        }

        var sb = new StringBuilder();
        sb.Append(methodName);
        sb.Append("(");
        sb.Append("new []{");
        for (int i = 0; i < inputColumns.Length; i++)
        {
            sb.Append("new ");
            sb.Append(argumentsName);
            sb.Append("(");
            sb.Append(outputColumns[i]);
            sb.Append(",");
            sb.Append(inputColumns[i]);
            sb.Append(")");
            sb.Append(",");
        }
        sb.Remove(sb.Length - 1, 1); // drop the trailing comma
        sb.Append("}");
        sb.Append(")");
        return sb.ToString();
    }
}

/// <summary>Generates the code snippet for the IndicateMissingValues transform.</summary>
internal class MissingValueIndicator : TransformGeneratorBase
{
    public MissingValueIndicator(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "IndicateMissingValues";

    private const string ArgumentsName = "InputOutputColumnPair";

    public override string GenerateTransformer()
    {
        // BUG FIX: the original computed unused single-column locals here
        // (dead code); the output-column guard now lives in the shared emitter.
        return InputOutputColumnPairCodeGenerator.Emit(MethodName, ArgumentsName, inputColumns, outputColumns);
    }
}

/// <summary>Generates the code snippet for the ReplaceMissingValues transform.</summary>
internal class MissingValueReplacer : TransformGeneratorBase
{
    public MissingValueReplacer(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "ReplaceMissingValues";

    private const string ArgumentsName = "InputOutputColumnPair";

    public override string GenerateTransformer()
    {
        return InputOutputColumnPairCodeGenerator.Emit(MethodName, ArgumentsName, inputColumns, outputColumns);
    }
}

/// <summary>Generates the code snippet for the Categorical.OneHotHashEncoding transform.</summary>
internal class OneHotHashEncoding : TransformGeneratorBase
{
    public OneHotHashEncoding(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "Categorical.OneHotHashEncoding";

    private const string ArgumentsName = "InputOutputColumnPair";

    public override string GenerateTransformer()
    {
        return InputOutputColumnPairCodeGenerator.Emit(MethodName, ArgumentsName, inputColumns, outputColumns);
    }
}

/// <summary>Generates the code snippet for the Text.FeaturizeText transform (single column pair).</summary>
internal class TextFeaturizing : TransformGeneratorBase
{
    public TextFeaturizing(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "Text.FeaturizeText";

    public override string GenerateTransformer()
    {
        // Default to "Features" when no input column was suggested; an absent
        // output column is an error (nothing sensible to emit).
        string inputColumn = inputColumns.Length > 0 ? inputColumns[0] : "\"Features\"";
        string outputColumn = outputColumns.Length > 0
            ? outputColumns[0]
            : throw new InvalidOperationException($"output columns for the suggested transform: {MethodName} are null");

        var sb = new StringBuilder();
        sb.Append(MethodName);
        sb.Append("(");
        sb.Append(outputColumn);
        sb.Append(",");
        sb.Append(inputColumn);
        sb.Append(")");
        return sb.ToString();
    }
}

/// <summary>Generates the code snippet for the Conversion.ConvertType transform.</summary>
internal class TypeConverting : TransformGeneratorBase
{
    public TypeConverting(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "Conversion.ConvertType";

    private const string ArgumentsName = "InputOutputColumnPair";

    public override string GenerateTransformer()
    {
        return InputOutputColumnPairCodeGenerator.Emit(MethodName, ArgumentsName, inputColumns, outputColumns);
    }
}

/// <summary>Generates the code snippet for the Conversion.MapValueToKey transform (single column pair).</summary>
internal class ValueToKeyMapping : TransformGeneratorBase
{
    public ValueToKeyMapping(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "Conversion.MapValueToKey";

    public override string GenerateTransformer()
    {
        string inputColumn = inputColumns.Length > 0 ? inputColumns[0] : "\"Features\"";
        string outputColumn = outputColumns.Length > 0
            ? outputColumns[0]
            : throw new InvalidOperationException($"output columns for the suggested transform: {MethodName} are null");

        var sb = new StringBuilder();
        sb.Append(MethodName);
        sb.Append("(");
        sb.Append(outputColumn);
        sb.Append(",");
        sb.Append(inputColumn);
        sb.Append(")");
        return sb.ToString();
    }
}
/// <summary>
/// Drives the CLI "auto-train" flow end to end: column inference, data
/// loading, AutoML exploration, best-model selection, model persistence and
/// project generation. (Reconstructed from the collapsed patch text of
/// src/mlnet/CodeGenerator/CodeGenerationHelper.cs.)
/// </summary>
internal class CodeGenerationHelper
{
    private IAutoMLEngine automlEngine;
    private NewCommandSettings settings;
    private static Logger logger = LogManager.GetCurrentClassLogger();
    private TaskKind taskKind;

    public CodeGenerationHelper(IAutoMLEngine automlEngine, NewCommandSettings settings)
    {
        this.automlEngine = automlEngine;
        this.settings = settings;
        this.taskKind = Utils.GetTaskKind(settings.MlTask);
    }

    /// <summary>
    /// Runs the full pipeline; logs and returns early on any failure so the
    /// CLI exits gracefully instead of crashing.
    /// </summary>
    public void GenerateCode()
    {
        Stopwatch watch = Stopwatch.StartNew();
        var context = new MLContext();
        var verboseLevel = Utils.GetVerbosity(settings.Verbosity);

        // Infer columns.
        ColumnInferenceResults columnInference = null;
        try
        {
            var inputColumnInformation = new ColumnInformation();
            inputColumnInformation.LabelColumnName = settings.LabelColumnName;
            foreach (var value in settings.IgnoreColumns)
            {
                inputColumnInformation.IgnoredColumnNames.Add(value);
            }
            columnInference = automlEngine.InferColumns(context, inputColumnInformation);
        }
        catch (Exception e)
        {
            logger.Log(LogLevel.Error, $"{Strings.InferColumnError}");
            logger.Log(LogLevel.Error, e.Message);
            logger.Log(LogLevel.Trace, e.ToString());
            logger.Log(LogLevel.Error, Strings.Exiting);
            return;
        }

        var textLoaderOptions = columnInference.TextLoaderOptions;
        var columnInformation = columnInference.ColumnInformation;

        // Sanitization of input data (column names must be valid C# identifiers
        // for the generated code).
        Array.ForEach(textLoaderOptions.Columns, t => t.Name = Utils.Sanitize(t.Name));
        columnInformation = Utils.GetSanitizedColumnInformation(columnInformation);

        // Load data.
        (IDataView trainData, IDataView validationData) = LoadData(context, textLoaderOptions);

        // Explore the models. Three separate result variables are needed
        // because the AutoML API has no common class/interface covering all
        // three tasks.
        // NOTE(review): the generic arguments below were stripped from the
        // patch text; RunDetail<TMetrics> matches the AutoML API of this era —
        // confirm against the Microsoft.ML.Auto version referenced.
        IEnumerable<RunDetail<BinaryClassificationMetrics>> binaryRunDetails = default;
        IEnumerable<RunDetail<MulticlassClassificationMetrics>> multiRunDetails = default;
        IEnumerable<RunDetail<RegressionMetrics>> regressionRunDetails = default;

        if (verboseLevel > LogLevel.Trace)
        {
            Console.Write($"{Strings.ExplorePipeline}: ");
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine($"{settings.MlTask}");
            Console.ResetColor();
            Console.Write($"{Strings.FurtherLearning}: ");
            Console.ForegroundColor = ConsoleColor.Yellow;
            Console.WriteLine($"{ Strings.LearningHttpLink}");
            Console.ResetColor();
        }

        logger.Log(LogLevel.Trace, $"{Strings.ExplorePipeline}: {settings.MlTask}");
        logger.Log(LogLevel.Trace, $"{Strings.FurtherLearning}: {Strings.LearningHttpLink}");
        try
        {
            var options = new ProgressBarOptions
            {
                ForegroundColor = ConsoleColor.Yellow,
                ForegroundColorDone = ConsoleColor.Yellow,
                BackgroundColor = ConsoleColor.Gray,
                ProgressCharacter = '\u2593',
                BackgroundCharacter = '─',
            };
            var wait = TimeSpan.FromSeconds(settings.MaxExplorationTime);

            if (verboseLevel > LogLevel.Trace && !Console.IsOutputRedirected)
            {
                // Interactive console: run exploration on a worker thread so the
                // fixed-duration progress bar can animate on this one.
                using (var pbar = new FixedDurationBar(wait, "", options))
                {
                    pbar.Message = Strings.WaitingForFirstIteration;
                    Thread t = default;
                    switch (taskKind)
                    {
                        case TaskKind.BinaryClassification:
                            t = new Thread(() => binaryRunDetails = automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric, pbar));
                            break;
                        case TaskKind.Regression:
                            t = new Thread(() => regressionRunDetails = automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric, pbar));
                            break;
                        case TaskKind.MulticlassClassification:
                            t = new Thread(() => multiRunDetails = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric, pbar));
                            break;
                        default:
                            // BUG FIX: previously this case only logged and fell
                            // through to t.Start() with t still null, throwing a
                            // NullReferenceException. Bail out instead.
                            logger.Log(LogLevel.Error, Strings.UnsupportedMlTask);
                            return;
                    }
                    t.Start();

                    if (!pbar.CompletedHandle.WaitOne(wait))
                        pbar.Message = $"{nameof(FixedDurationBar)} did not signal {nameof(FixedDurationBar.CompletedHandle)} after {wait}";

                    if (t.IsAlive == true)
                    {
                        string waitingMessage = Strings.WaitingForLastIteration;
                        string originalMessage = pbar.Message;
                        pbar.Message = waitingMessage;
                        t.Join();
                        if (waitingMessage.Equals(pbar.Message))
                        {
                            // Corner case where the thread was alive but had
                            // already completed all iterations.
                            pbar.Message = originalMessage;
                        }
                    }
                }
            }
            else
            {
                // Non-interactive (quiet or redirected) path: run synchronously.
                switch (taskKind)
                {
                    case TaskKind.BinaryClassification:
                        binaryRunDetails = automlEngine.ExploreBinaryClassificationModels(context, trainData, validationData, columnInformation, new BinaryExperimentSettings().OptimizingMetric);
                        break;
                    case TaskKind.Regression:
                        regressionRunDetails = automlEngine.ExploreRegressionModels(context, trainData, validationData, columnInformation, new RegressionExperimentSettings().OptimizingMetric);
                        break;
                    case TaskKind.MulticlassClassification:
                        multiRunDetails = automlEngine.ExploreMultiClassificationModels(context, trainData, validationData, columnInformation, new MulticlassExperimentSettings().OptimizingMetric);
                        break;
                    default:
                        // Consistent with the interactive path: nothing useful
                        // can follow an unsupported task, so exit early.
                        logger.Log(LogLevel.Error, Strings.UnsupportedMlTask);
                        return;
                }
            }
        }
        catch (Exception e)
        {
            logger.Log(LogLevel.Error, $"{Strings.ExplorePipelineException}:");
            logger.Log(LogLevel.Error, e.Message);
            logger.Log(LogLevel.Debug, e.ToString());
            logger.Log(LogLevel.Info, Strings.LookIntoLogFile);
            logger.Log(LogLevel.Error, Strings.Exiting);
            return;
        }

        var elapsedTime = watch.Elapsed.TotalSeconds;

        // Get the best pipeline.
        Pipeline bestPipeline = null;
        ITransformer bestModel = null;
        try
        {
            switch (taskKind)
            {
                case TaskKind.BinaryClassification:
                    var bestBinaryIteration = binaryRunDetails.Best();
                    bestPipeline = bestBinaryIteration.Pipeline;
                    bestModel = bestBinaryIteration.Model;
                    ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), binaryRunDetails.Count());
                    ConsolePrinter.PrintIterationSummary(binaryRunDetails, new BinaryExperimentSettings().OptimizingMetric, 5);
                    break;
                case TaskKind.Regression:
                    var bestRegressionIteration = regressionRunDetails.Best();
                    bestPipeline = bestRegressionIteration.Pipeline;
                    bestModel = bestRegressionIteration.Model;
                    ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), regressionRunDetails.Count());
                    ConsolePrinter.PrintIterationSummary(regressionRunDetails, new RegressionExperimentSettings().OptimizingMetric, 5);
                    break;
                case TaskKind.MulticlassClassification:
                    var bestMultiIteration = multiRunDetails.Best();
                    bestPipeline = bestMultiIteration.Pipeline;
                    bestModel = bestMultiIteration.Model;
                    ConsolePrinter.ExperimentResultsHeader(LogLevel.Info, settings.MlTask, settings.Dataset.Name, columnInformation.LabelColumnName, elapsedTime.ToString("F2"), multiRunDetails.Count());
                    ConsolePrinter.PrintIterationSummary(multiRunDetails, new MulticlassExperimentSettings().OptimizingMetric, 5);
                    break;
            }
        }
        catch (Exception e)
        {
            logger.Log(LogLevel.Info, Strings.ErrorBestPipeline);
            logger.Log(LogLevel.Info, e.Message);
            logger.Log(LogLevel.Trace, e.ToString());
            logger.Log(LogLevel.Info, Strings.LookIntoLogFile);
            logger.Log(LogLevel.Error, Strings.Exiting);
            return;
        }

        // Save the model.
        var modelprojectDir = Path.Combine(settings.OutputPath.FullName, $"{settings.Name}.Model");
        var modelPath = new FileInfo(Path.Combine(modelprojectDir, "MLModel.zip"));
        Utils.SaveModel(bestModel, modelPath, context, trainData.Schema);
        Console.ForegroundColor = ConsoleColor.Yellow;
        logger.Log(LogLevel.Info, $"{Strings.SavingBestModel}: {modelPath}");

        // Generate the project.
        GenerateProject(columnInference, bestPipeline, columnInformation.LabelColumnName, modelPath);
        logger.Log(LogLevel.Info, $"{Strings.GenerateModelConsumption}: { Path.Combine(settings.OutputPath.FullName, $"{settings.Name}.ConsoleApp")}");
        logger.Log(LogLevel.Info, $"{Strings.SeeLogFileForMoreInfo}: {settings.LogFilePath}");
        Console.ResetColor();
    }

    /// <summary>Emits the generated C# projects for the chosen pipeline/model.</summary>
    internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline pipeline, string labelName, FileInfo modelPath)
    {
        var codeGenerator = new CodeGenerator.CSharp.CodeGenerator(
            pipeline,
            columnInference,
            new CodeGeneratorSettings()
            {
                TrainDataset = settings.Dataset.FullName,
                MlTask = taskKind,
                TestDataset = settings.TestDataset?.FullName,
                OutputName = settings.Name,
                OutputBaseDir = settings.OutputPath.FullName,
                LabelName = labelName,
                ModelPath = modelPath.FullName
            });
        codeGenerator.GenerateOutput();
    }

    /// <summary>
    /// Loads the training set and, if configured, the validation set.
    /// Returns (trainData, validationData); validationData is null when no
    /// validation dataset was supplied.
    /// </summary>
    internal (IDataView, IDataView) LoadData(MLContext context, TextLoader.Options textLoaderOptions)
    {
        logger.Log(LogLevel.Trace, Strings.CreateDataLoader);
        var textLoader = context.Data.CreateTextLoader(textLoaderOptions);

        logger.Log(LogLevel.Trace, Strings.LoadData);
        var trainData = textLoader.Load(settings.Dataset.FullName);
        var validationData = settings.ValidationDataset == null ? null : textLoader.Load(settings.ValidationDataset.FullName);

        return (trainData, validationData);
    }
}
/// <summary>Contract implemented by generators that emit a full project/solution.</summary>
internal interface IProjectGenerator
{
    void GenerateOutput();
}

/// <summary>
/// Declares the "auto-train" command-line surface: its options, aliases and
/// cross-option validation rules. (Reconstructed from the collapsed patch text
/// of src/mlnet/Commands/CommandDefinitions.cs.)
/// </summary>
internal static class CommandDefinitions
{
    internal static System.CommandLine.Command AutoTrain(ICommandHandler handler)
    {
        var newCommand = new System.CommandLine.Command("auto-train", "Create a new .NET project using ML.NET to train and run a model", handler: handler)
        {
            MlTask(),
            Dataset(),
            ValidationDataset(),
            TestDataset(),
            LabelName(),
            LabelColumnIndex(),
            HasHeader(),
            MaxExplorationTime(),
            Cache(),
            IgnoreColumns(),
            Verbosity(),
            Name(),
            OutputPath(),
        };

        // Validate combinations the parser cannot express declaratively.
        newCommand.Argument.AddValidator((sym) =>
        {
            if (!sym.Children.Contains("--dataset"))
            {
                return "Option required : --dataset";
            }
            if (!sym.Children.Contains("--ml-task"))
            {
                return "Option required : --ml-task";
            }
            if (!sym.Children.Contains("--label-column-name") && !sym.Children.Contains("--label-column-index"))
            {
                return "Option required : --label-column-name or --label-column-index";
            }
            if (sym.Children.Contains("--label-column-name") && sym.Children.Contains("--label-column-index"))
            {
                return "The following options are mutually exclusive please provide only one : --label-column-name, --label-column-index";
            }
            if (sym.Children.Contains("--label-column-index") && sym.Children["--ignore-columns"]?.Arguments.Count > 0)
            {
                return "Currently we don't support specifying --ignore-columns in conjunction with --label-column-index";
            }

            return null;
        });

        return newCommand;

        // NOTE(review): the generic arguments on Argument<T> below were
        // stripped from the patch text and have been reconstructed from each
        // option's defaults and usage — confirm against the System.CommandLine
        // experimental API version in use.
        Option Dataset() =>
            new Option(new List<string>() { "--dataset", "-d" }, "File path to either a single dataset or a training dataset for train/test split approaches.",
            new Argument<FileInfo>().ExistingOnly());

        Option ValidationDataset() =>
            new Option(new List<string>() { "--validation-dataset", "-v" }, "File path for the validation dataset in train/validation/test split approaches.",
            new Argument<FileInfo>(defaultValue: default(FileInfo)).ExistingOnly());

        Option TestDataset() =>
            new Option(new List<string>() { "--test-dataset", "-t" }, "File path for the test dataset in train/test approaches.",
            new Argument<FileInfo>(defaultValue: default(FileInfo)).ExistingOnly());

        Option MlTask() =>
            new Option(new List<string>() { "--ml-task", "--mltask", "--task", "-T" }, "Type of ML task to perform. Current supported tasks: regression, binary-classification, multiclass-classification.",
            new Argument<string>().FromAmong(GetMlTaskSuggestions()));

        Option LabelName() =>
            new Option(new List<string>() { "--label-column-name", "-n" }, "Name of the label (target) column to predict.",
            new Argument<string>());

        Option LabelColumnIndex() =>
            new Option(new List<string>() { "--label-column-index", "-i" }, "Index of the label (target) column to predict.",
            new Argument<uint>());

        Option MaxExplorationTime() =>
            new Option(new List<string>() { "--max-exploration-time", "-x" }, "Maximum time in seconds for exploring models with best configuration.",
            new Argument<uint>(defaultValue: 10));

        Option Verbosity() =>
            new Option(new List<string>() { "--verbosity", "-V" }, "Output verbosity choices: q[uiet], m[inimal] (by default) and diag[nostic].",
            new Argument<string>(defaultValue: "m").FromAmong(GetVerbositySuggestions()));

        Option Name() =>
            new Option(new List<string>() { "--name", "-N" }, "Name for the output project or solution to create. ",
            new Argument<string>());

        Option OutputPath() =>
            new Option(new List<string>() { "--output-path", "-o" }, "Location folder to place the generated output. The default is the current directory.",
            new Argument<DirectoryInfo>(defaultValue: new DirectoryInfo(".")));

        Option HasHeader() =>
            new Option(new List<string>() { "--has-header", "-h" }, "Specify true/false depending if the dataset file(s) have a header row.",
            new Argument<bool>(defaultValue: true));

        Option Cache() =>
            new Option(new List<string>() { "--cache", "-c" }, "Specify on/off/auto if you want cache to be turned on, off or auto determined.",
            new Argument<string>(defaultValue: "auto").FromAmong(GetCacheSuggestions()));

        // This is a temporary hack to work around having comma separated values
        // for an argument. This feature needs to be enabled in the parser itself.
        Option IgnoreColumns() =>
            new Option(new List<string>() { "--ignore-columns", "-I" }, "Specify the columns that needs to be ignored in the given dataset.",
            new Argument<List<string>>(symbolResult =>
            {
                try
                {
                    List<string> valuesList = new List<string>();
                    foreach (var argument in symbolResult.Arguments)
                    {
                        if (!string.IsNullOrWhiteSpace(argument))
                        {
                            var values = argument.Split(",", StringSplitOptions.RemoveEmptyEntries);
                            valuesList.AddRange(values);
                        }
                    }
                    if (valuesList.Count > 0)
                        return ArgumentResult.Success(valuesList);

                }
                catch (Exception)
                {
                    // BUG FIX: "occured" -> "occurred" in the user-facing message.
                    return ArgumentResult.Failure($"Unknown exception occurred while parsing argument for --ignore-columns :{string.Join(' ', symbolResult.Arguments.ToArray())}");
                }

                // This shouldn't be hit.
                return ArgumentResult.Failure("Unknown error while parsing argument for --ignore-columns");
            })
            {
                Arity = ArgumentArity.OneOrMore,
            });
    }

    private static string[] GetMlTaskSuggestions()
    {
        return new[] { "binary-classification", "multiclass-classification", "regression" };
    }

    private static string[] GetVerbositySuggestions()
    {
        return new[] { "q", "m", "diag" };
    }

    private static string[] GetCacheSuggestions()
    {
        return new[] { "on", "off", "auto" };
    }
}
+ +namespace Microsoft.ML.CLI.Commands +{ + internal interface ICommand + { + void Execute(); + } +} \ No newline at end of file diff --git a/src/mlnet/Commands/New/NewCommandHandler.cs b/src/mlnet/Commands/New/NewCommandHandler.cs new file mode 100644 index 0000000000..1be9219143 --- /dev/null +++ b/src/mlnet/Commands/New/NewCommandHandler.cs @@ -0,0 +1,30 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.DotNet.Cli.Telemetry; +using Microsoft.ML.CLI.CodeGenerator; +using Microsoft.ML.CLI.Data; + +namespace Microsoft.ML.CLI.Commands.New +{ + internal class NewCommand : ICommand + { + private readonly NewCommandSettings settings; + private readonly MlTelemetry telemetry; + + internal NewCommand(NewCommandSettings settings, MlTelemetry telemetry) + { + this.settings = settings; + this.telemetry = telemetry; + } + + public void Execute() + { + telemetry.LogAutoTrainMlCommand(settings.Dataset.FullName, settings.MlTask.ToString(), settings.Dataset.Length); + + CodeGenerationHelper codeGenerationHelper = new CodeGenerationHelper(new AutoMLEngine(settings), settings); // Needs to be improved. + codeGenerationHelper.GenerateCode(); + } + } +} diff --git a/src/mlnet/Commands/New/NewCommandSettings.cs b/src/mlnet/Commands/New/NewCommandSettings.cs new file mode 100644 index 0000000000..22fb7c19d4 --- /dev/null +++ b/src/mlnet/Commands/New/NewCommandSettings.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
/// <summary>
/// Settings bag bound from the "auto-train" command line
/// (src/mlnet/Commands/New/NewCommandSettings.cs in the patch).
/// </summary>
public class NewCommandSettings
{
    public string Name { get; set; }

    public FileInfo Dataset { get; set; }

    public FileInfo ValidationDataset { get; set; }

    public FileInfo TestDataset { get; set; }

    public string LabelColumnName { get; set; }

    public string Verbosity { get; set; }

    public uint LabelColumnIndex { get; set; }

    public string MlTask { get; set; }

    public uint MaxExplorationTime { get; set; }

    public DirectoryInfo OutputPath { get; set; }

    public bool HasHeader { get; set; }

    public string Cache { get; set; }

    // NOTE(review): the generic argument was stripped from the patch text;
    // List<string> is implied by CommandDefinitions.IgnoreColumns.
    public List<string> IgnoreColumns { get; set; } = new List<string>();

    public string LogFilePath { get; set; }
}

/// <summary>
/// CLI entry point (src/mlnet/Program.cs in the patch): wires the command
/// handler, parses arguments, records telemetry and invokes the command.
/// </summary>
class Program
{
    public static void Main(string[] args)
    {
        var telemetry = new MlTelemetry();

        // Create the handler outside the parser so the command line and the
        // handler are decoupled and testable.
        var handler = CommandHandler.Create<NewCommandSettings>(
            (options) =>
            {
                // Map the verbosity to internal levels.
                var verbosity = Utils.GetVerbosity(options.Verbosity);

                // Default the project name, then build the output path.
                // (The original duplicated the Path.Combine in both branches.)
                if (options.Name == null)
                {
                    options.Name = "Sample" + Utils.GetTaskKind(options.MlTask).ToString();
                }
                string outputBaseDir = Path.Combine(options.OutputPath.FullName, options.Name);

                // Override the output path.
                options.OutputPath = new DirectoryInfo(outputBaseDir);

                // Instantiate the command.
                var command = new NewCommand(options, telemetry);

                // Override the logger configuration: console rule honours the
                // requested verbosity; file target is redirected under the
                // output directory.
                var logconsole = LogManager.Configuration.FindTargetByName("logconsole");
                var logfile = (FileTarget)LogManager.Configuration.FindTargetByName("logfile");
                var logFilePath = Path.Combine(Path.Combine(outputBaseDir, "logs"), "debug_log.txt");
                logfile.FileName = logFilePath;
                options.LogFilePath = logFilePath;
                var config = LogManager.Configuration;
                config.AddRule(verbosity, LogLevel.Fatal, logconsole);

                // Execute the command.
                command.Execute();
            });

        var parser = new CommandLineBuilder()
            .AddCommand(CommandDefinitions.AutoTrain(handler))
            .UseDefaults()
            .Build();

        var parseResult = parser.Parse(args);

        // Record which options the user explicitly supplied (telemetry only;
        // implicit/defaulted options are excluded).
        if (parseResult.Errors.Count == 0)
        {
            if (parseResult.RootCommandResult.Children.Count > 0)
            {
                var command = parseResult.RootCommandResult.Children.First();
                var parsedArguments = command.Children;

                if (parsedArguments.Count > 0)
                {
                    var options = parsedArguments.ToList()
                        .Where(sr => sr is System.CommandLine.OptionResult)
                        .Cast<System.CommandLine.OptionResult>();

                    var explicitlySpecifiedOptions = options.Where(opt => !opt.IsImplicit).Select(opt => opt.Name);

                    telemetry.SetCommandAndParameters(command.Name, explicitlySpecifiedOptions);
                }
            }
        }

        parser.InvokeAsync(parseResult).Wait();
    }
}
a/src/mlnet/ProgressBar/ChildProgressBar.cs b/src/mlnet/ProgressBar/ChildProgressBar.cs new file mode 100644 index 0000000000..99476c77b7 --- /dev/null +++ b/src/mlnet/ProgressBar/ChildProgressBar.cs @@ -0,0 +1,52 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public class ChildProgressBar : ProgressBarBase, IProgressBar + { + private readonly Action _scheduleDraw; + private readonly Action _growth; + + public DateTime StartDate { get; } = DateTime.Now; + + protected override void DisplayProgress() => _scheduleDraw?.Invoke(); + + internal ChildProgressBar(int maxTicks, string message, Action scheduleDraw, ProgressBarOptions options = null, Action growth = null) + : base(maxTicks, message, options) + { + _scheduleDraw = scheduleDraw; + _growth = growth; + _growth?.Invoke(ProgressBarHeight.Increment); + } + + private bool _calledDone; + private readonly object _callOnce = new object(); + + protected override void OnDone() + { + if (_calledDone) return; + lock (_callOnce) + { + if (_calledDone) return; + + if (this.EndTime == null) + this.EndTime = DateTime.Now; + + if (this.Collapse) + _growth?.Invoke(ProgressBarHeight.Decrement); + + _calledDone = true; + } + } + + public void Dispose() + { + OnDone(); + foreach (var c in this.Children) c.Dispose(); + } + } +} diff --git a/src/mlnet/ProgressBar/FixedDurationBar.cs b/src/mlnet/ProgressBar/FixedDurationBar.cs new file mode 100644 index 0000000000..7b4879ef18 --- /dev/null +++ b/src/mlnet/ProgressBar/FixedDurationBar.cs @@ -0,0 +1,41 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Threading; + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public class FixedDurationBar : ProgressBar + { + public bool IsCompleted { get; private set; } + + private readonly ManualResetEvent _completedHandle = new ManualResetEvent(false); + public WaitHandle CompletedHandle => _completedHandle; + + public FixedDurationBar(TimeSpan duration, string message, ConsoleColor color) : this(duration, message, new ProgressBarOptions { ForegroundColor = color }) { } + + public FixedDurationBar(TimeSpan duration, string message, ProgressBarOptions options = null) : base((int)Math.Ceiling(duration.TotalSeconds), message, options) + { + if (!this.Options.DisplayTimeInRealTime) + throw new ArgumentException( + $"{nameof(ProgressBarOptions)}.{nameof(ProgressBarOptions.DisplayTimeInRealTime)} has to be true for {nameof(FixedDurationBar)}", nameof(options) + ); + } + + private long _seenTicks = 0; + protected override void OnTimerTick() + { + Interlocked.Increment(ref _seenTicks); + if (_seenTicks % 2 == 0) this.Tick(); + base.OnTimerTick(); + } + + protected override void OnDone() + { + this.IsCompleted = true; + this._completedHandle.Set(); + } + } +} diff --git a/src/mlnet/ProgressBar/IProgressBar.cs b/src/mlnet/ProgressBar/IProgressBar.cs new file mode 100644 index 0000000000..3fe778efcd --- /dev/null +++ b/src/mlnet/ProgressBar/IProgressBar.cs @@ -0,0 +1,22 @@ +using System; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public interface IProgressBar : IDisposable + { + ChildProgressBar Spawn(int maxTicks, string message, ProgressBarOptions options = null); + + void Tick(string message = null); + + int MaxTicks { get; set; } + string Message { get; set; } + + double Percentage { get; } + int CurrentTick { get; } + + ConsoleColor ForeGroundColor { get; } + } +} diff --git a/src/mlnet/ProgressBar/ProgressBar.cs b/src/mlnet/ProgressBar/ProgressBar.cs new file mode 100644 index 0000000000..5f8fac716e --- /dev/null +++ b/src/mlnet/ProgressBar/ProgressBar.cs @@ -0,0 +1,361 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public class ProgressBar : ProgressBarBase, IProgressBar + { + private static readonly bool IsWindows = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + + private readonly ConsoleColor _originalColor; + private readonly int _originalCursorTop; + private readonly int _originalWindowTop; + private int _isDisposed; + + private Timer _timer; + private int _visibleDescendants = 0; + private readonly AutoResetEvent _displayProgressEvent; + private readonly Task _displayProgress; + + public ProgressBar(int maxTicks, string message, ConsoleColor color) + : this(maxTicks, message, new ProgressBarOptions { ForegroundColor = color }) + { + } + + public ProgressBar(int maxTicks, string message, ProgressBarOptions options = null) + : base(maxTicks, message, options) + { + Console.WriteLine(); + Console.SetCursorPosition(Console.CursorLeft, Math.Max(0, Console.CursorTop - 1)); + _originalCursorTop = Console.CursorTop; + _originalWindowTop = 
Console.WindowTop; + _originalColor = Console.ForegroundColor; + + Console.CursorVisible = false; + + if (this.Options.EnableTaskBarProgress) + TaskbarProgress.SetState(TaskbarProgress.TaskbarStates.Normal); + + if (this.Options.DisplayTimeInRealTime) + _timer = new Timer((s) => OnTimerTick(), null, 500, 500); + else //draw once + _timer = new Timer((s) => + { + _timer.Dispose(); + DisplayProgress(); + }, null, 0, 1000); + + _displayProgressEvent = new AutoResetEvent(false); + _displayProgress = Task.Run(() => + { + while (_isDisposed == 0) + { + if (!_displayProgressEvent.WaitOne(TimeSpan.FromSeconds(10))) + continue; + try + { + UpdateProgress(); + } + catch + { + // don't want to crash background thread + } + } + }); + } + + protected virtual void OnTimerTick() + { + DisplayProgress(); + } + + protected override void Grow(ProgressBarHeight direction) + { + switch (direction) + { + case ProgressBarHeight.Increment: + Interlocked.Increment(ref _visibleDescendants); + break; + case ProgressBarHeight.Decrement: + Interlocked.Decrement(ref _visibleDescendants); + break; + } + } + + private struct Indentation + { + public Indentation(ConsoleColor color, bool lastChild) + { + this.ConsoleColor = color; + this.LastChild = lastChild; + } + + public string Glyph => !LastChild ? "├─" : "└─"; + + public readonly ConsoleColor ConsoleColor; + public readonly bool LastChild; + } + + private static void ProgressBarBottomHalf(double percentage, DateTime startDate, DateTime? endDate, string message, + Indentation[] indentation, bool progressBarOnBottom) + { + var depth = indentation.Length; + var maxCharacterWidth = Console.WindowWidth - (depth * 2) + 2; + var duration = ((endDate ?? 
DateTime.Now) - startDate); + string durationString = null; + if (duration.Days > 0) + durationString = $"{duration.Days:00}:{duration.Hours:00}:{duration.Minutes:00}:{duration.Seconds:00}"; + else + durationString = $"{duration.Hours:00}:{duration.Minutes:00}:{duration.Seconds:00}"; + + var column1Width = Console.WindowWidth - durationString.Length - (depth * 2) + 2; + var column2Width = durationString.Length; + + if (progressBarOnBottom) + DrawTopHalfPrefix(indentation, depth); + else + DrawBottomHalfPrefix(indentation, depth); + + var format = $"{{0, -{column1Width}}}{{1,{column2Width}}}"; + + var truncatedMessage = StringExtensions.Excerpt(message, column1Width); + var formatted = string.Format(format, truncatedMessage, durationString); + var m = formatted + new string(' ', Math.Max(0, maxCharacterWidth - formatted.Length)); + Console.Write(m); + } + + private static void DrawBottomHalfPrefix(Indentation[] indentation, int depth) + { + for (var i = 1; i < depth; i++) + { + var ind = indentation[i]; + Console.ForegroundColor = indentation[i - 1].ConsoleColor; + if (!ind.LastChild) + Console.Write(i == (depth - 1) ? ind.Glyph : "│ "); + else + Console.Write(i == (depth - 1) ? ind.Glyph : " "); + } + + Console.ForegroundColor = indentation[depth - 1].ConsoleColor; + } + + private static void ProgressBarTopHalf( + double percentage, + char progressCharacter, + char? progressBackgroundCharacter, + ConsoleColor? backgroundColor, + Indentation[] indentation, bool progressBarOnTop) + { + var depth = indentation.Length; + var width = Console.WindowWidth - (depth * 2) + 2; + + if (progressBarOnTop) + DrawBottomHalfPrefix(indentation, depth); + else + DrawTopHalfPrefix(indentation, depth); + + var newWidth = (int)((width * percentage) / 100d); + var progBar = new string(progressCharacter, newWidth); + Console.Write(progBar); + if (backgroundColor.HasValue) + { + Console.ForegroundColor = backgroundColor.Value; + Console.Write(new string(progressBackgroundCharacter ?? 
progressCharacter, width - newWidth)); + } + else Console.Write(new string(' ', width - newWidth)); + + Console.ForegroundColor = indentation[depth - 1].ConsoleColor; + } + + private static void DrawTopHalfPrefix(Indentation[] indentation, int depth) + { + for (var i = 1; i < depth; i++) + { + var ind = indentation[i]; + Console.ForegroundColor = indentation[i - 1].ConsoleColor; + if (ind.LastChild && i != (depth - 1)) + Console.Write(" "); + else + Console.Write("│ "); + } + + Console.ForegroundColor = indentation[depth - 1].ConsoleColor; + } + + protected override void DisplayProgress() + { + _displayProgressEvent.Set(); + } + + private void UpdateProgress() + { + Console.CursorVisible = false; + var indentation = new[] { new Indentation(this.ForeGroundColor, true) }; + var mainPercentage = this.Percentage; + var cursorTop = _originalCursorTop; + + Console.ForegroundColor = this.ForeGroundColor; + + void TopHalf() + { + ProgressBarTopHalf(mainPercentage, + this.Options.ProgressCharacter, + this.Options.BackgroundCharacter, + this.Options.BackgroundColor, + indentation, + this.Options.ProgressBarOnBottom + ); + } + + if (this.Options.ProgressBarOnBottom) + { + ProgressBarBottomHalf(mainPercentage, this._startDate, null, this.Message, indentation, this.Options.ProgressBarOnBottom); + Console.SetCursorPosition(0, ++cursorTop); + TopHalf(); + } + else + { + TopHalf(); + Console.SetCursorPosition(0, ++cursorTop); + ProgressBarBottomHalf(mainPercentage, this._startDate, null, this.Message, indentation, this.Options.ProgressBarOnBottom); + } + + if (this.Options.EnableTaskBarProgress) + TaskbarProgress.SetValue(mainPercentage, 100); + + DrawChildren(this.Children, indentation, ref cursorTop); + + ResetToBottom(ref cursorTop); + + Console.SetCursorPosition(0, _originalCursorTop); + Console.ForegroundColor = _originalColor; + + if (!(mainPercentage >= 100)) return; + _timer?.Dispose(); + _timer = null; + } + + private static void ResetToBottom(ref int cursorTop) + { + var 
resetString = new string(' ', Console.WindowWidth); + var windowHeight = Console.WindowHeight; + if (cursorTop >= (windowHeight - 1)) return; + do + { + Console.Write(resetString); + } while (++cursorTop < (windowHeight - 1)); + } + + private static void DrawChildren(IEnumerable children, Indentation[] indentation, ref int cursorTop) + { + var view = children.Where(c => !c.Collapse).Select((c, i) => new { c, i }).ToList(); + if (!view.Any()) return; + + var windowHeight = Console.WindowHeight; + var lastChild = view.Max(t => t.i); + foreach (var tuple in view) + { + //Dont bother drawing children that would fall off the screen + if (cursorTop >= (windowHeight - 2)) + return; + + var child = tuple.c; + var currentIndentation = new Indentation(child.ForeGroundColor, tuple.i == lastChild); + var childIndentation = NewIndentation(indentation, currentIndentation); + + var percentage = child.Percentage; + Console.ForegroundColor = child.ForeGroundColor; + + void TopHalf() + { + ProgressBarTopHalf(percentage, + child.Options.ProgressCharacter, + child.Options.BackgroundCharacter, + child.Options.BackgroundColor, + childIndentation, + child.Options.ProgressBarOnBottom + ); + } + + Console.SetCursorPosition(0, ++cursorTop); + + if (child.Options.ProgressBarOnBottom) + { + ProgressBarBottomHalf(percentage, child.StartDate, child.EndTime, child.Message, childIndentation, child.Options.ProgressBarOnBottom); + Console.SetCursorPosition(0, ++cursorTop); + TopHalf(); + } + else + { + TopHalf(); + Console.SetCursorPosition(0, ++cursorTop); + ProgressBarBottomHalf(percentage, child.StartDate, child.EndTime, child.Message, childIndentation, child.Options.ProgressBarOnBottom); + } + + DrawChildren(child.Children, childIndentation, ref cursorTop); + } + } + + private static Indentation[] NewIndentation(Indentation[] array, Indentation append) + { + var result = new Indentation[array.Length + 1]; + Array.Copy(array, result, array.Length); + result[array.Length] = append; + return 
result; + } + + public void Dispose() + { + if (Interlocked.CompareExchange(ref _isDisposed, 1, 0) != 0) + return; + + // make sure background task is stopped before we clean up + _displayProgressEvent.Set(); + _displayProgress.Wait(); + + // update one last time - needed because background task might have + // been already in progress before Dispose was called and it might + // have been running for a very long time due to poor performance + // of System.Console + UpdateProgress(); + + if (this.EndTime == null) this.EndTime = DateTime.Now; + var openDescendantsPadding = (_visibleDescendants * 2); + + if (this.Options.EnableTaskBarProgress) + TaskbarProgress.SetState(TaskbarProgress.TaskbarStates.NoProgress); + + try + { + var moveDown = 0; + var currentWindowTop = Console.WindowTop; + if (currentWindowTop != _originalWindowTop) + { + var x = Math.Max(0, Math.Min(2, currentWindowTop - _originalWindowTop)); + moveDown = _originalCursorTop + x; + } + else moveDown = _originalCursorTop + 2; + + Console.CursorVisible = true; + Console.SetCursorPosition(0, openDescendantsPadding + moveDown); + } + // This is bad and I should feel bad, but i rather eat pbar exceptions in productions then causing false negatives + catch + { + } + + Console.WriteLine(); + _timer?.Dispose(); + _timer = null; + foreach (var c in this.Children) c.Dispose(); + } + } +} diff --git a/src/mlnet/ProgressBar/ProgressBarBase.cs b/src/mlnet/ProgressBar/ProgressBarBase.cs new file mode 100644 index 0000000000..d47985f1e5 --- /dev/null +++ b/src/mlnet/ProgressBar/ProgressBarBase.cs @@ -0,0 +1,119 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Concurrent; +using System.Text; +using System.Threading; + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public abstract class ProgressBarBase + { + static ProgressBarBase() + { + Encoding.RegisterProvider(CodePagesEncodingProvider.Instance); + } + + protected readonly DateTime _startDate = DateTime.Now; + private int _maxTicks; + private int _currentTick; + private string _message; + + protected ProgressBarBase(int maxTicks, string message, ProgressBarOptions options) + { + this._maxTicks = Math.Max(0, maxTicks); + this._message = message; + this.Options = options ?? ProgressBarOptions.Default; + } + + internal ProgressBarOptions Options { get; } + internal ConcurrentBag Children { get; } = new ConcurrentBag(); + + protected abstract void DisplayProgress(); + + protected virtual void Grow(ProgressBarHeight direction) + { + } + + protected virtual void OnDone() + { + } + + public DateTime? EndTime { get; protected set; } + + public ConsoleColor ForeGroundColor => + EndTime.HasValue ? this.Options.ForegroundColorDone ?? 
this.Options.ForegroundColor : this.Options.ForegroundColor; + + public int CurrentTick => _currentTick; + + public int MaxTicks + { + get => _maxTicks; + set + { + Interlocked.Exchange(ref _maxTicks, value); + DisplayProgress(); + } + } + + public string Message + { + get => _message; + set + { + Interlocked.Exchange(ref _message, value); + DisplayProgress(); + } + } + + public double Percentage + { + get + { + var percentage = Math.Max(0, Math.Min(100, (100.0 / this._maxTicks) * this._currentTick)); + // Gracefully handle if the percentage is NaN due to division by 0 + if (double.IsNaN(percentage) || percentage < 0) percentage = 100; + return percentage; + } + } + + public bool Collapse => this.EndTime.HasValue && this.Options.CollapseWhenFinished; + + public ChildProgressBar Spawn(int maxTicks, string message, ProgressBarOptions options = null) + { + var pbar = new ChildProgressBar(maxTicks, message, DisplayProgress, options, this.Grow); + this.Children.Add(pbar); + DisplayProgress(); + return pbar; + } + + public void Tick(string message = null) + { + Interlocked.Increment(ref _currentTick); + + FinishTick(message); + } + + public void Tick(int newTickCount, string message = null) + { + Interlocked.Exchange(ref _currentTick, newTickCount); + + FinishTick(message); + } + + private void FinishTick(string message) + { + if (message != null) + Interlocked.Exchange(ref _message, message); + + if (_currentTick >= _maxTicks) + { + this.EndTime = DateTime.Now; + this.OnDone(); + } + DisplayProgress(); + } + } +} diff --git a/src/mlnet/ProgressBar/ProgressBarHeight.cs b/src/mlnet/ProgressBar/ProgressBarHeight.cs new file mode 100644 index 0000000000..67b683efe8 --- /dev/null +++ b/src/mlnet/ProgressBar/ProgressBarHeight.cs @@ -0,0 +1,11 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public enum ProgressBarHeight + { + Increment, Decrement + } +} \ No newline at end of file diff --git a/src/mlnet/ProgressBar/ProgressBarOptions.cs b/src/mlnet/ProgressBar/ProgressBarOptions.cs new file mode 100644 index 0000000000..e8de881af0 --- /dev/null +++ b/src/mlnet/ProgressBar/ProgressBarOptions.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + /// + /// Control the behaviour of your progressbar + /// + public class ProgressBarOptions + { + private bool _enableTaskBarProgress; + public static readonly ProgressBarOptions Default = new ProgressBarOptions(); + + /// The foreground color of the progress bar, message and time + public ConsoleColor ForegroundColor { get; set; } = ConsoleColor.Green; + + /// The foreground color the progressbar has reached a 100 percent + public ConsoleColor? ForegroundColorDone { get; set; } + + /// The background color of the remainder of the progressbar + public ConsoleColor? BackgroundColor { get; set; } + + /// The character to use to draw the progressbar + public char ProgressCharacter { get; set; } = '\u2588'; + + /// + /// The character to use for the background of the progress defaults to + /// + public char? BackgroundCharacter { get; set; } + + /// + /// When true will redraw the progressbar using a timer, otherwise only update when + /// is called. 
+ /// Defaults to true + /// + public bool DisplayTimeInRealTime { get; set; } = true; + + /// + /// Collapse the progressbar when done, very useful for child progressbars + /// Defaults to true + /// + public bool CollapseWhenFinished { get; set; } = true; + + /// + /// By default the text and time information is displayed at the bottom and the progress bar at the top. + /// This setting swaps their position + /// + public bool ProgressBarOnBottom { get; set; } + + /// + /// Use Windows' task bar to display progress. + /// + /// + /// This feature is available on the Windows platform. + /// + public bool EnableTaskBarProgress + { + get => _enableTaskBarProgress; + set + { + if (value && !RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + throw new NotSupportedException("Task bar progress only works on Windows"); + + _enableTaskBarProgress = value; + } + } + } +} diff --git a/src/mlnet/ProgressBar/StringExtensions.cs b/src/mlnet/ProgressBar/StringExtensions.cs new file mode 100644 index 0000000000..7d9bdc8015 --- /dev/null +++ b/src/mlnet/ProgressBar/StringExtensions.cs @@ -0,0 +1,16 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + internal static class StringExtensions + { + public static string Excerpt(string phrase, int length = 60) + { + if (string.IsNullOrEmpty(phrase) || phrase.Length < length) + return phrase; + return phrase.Substring(0, length - 3) + "..."; + } + } +} diff --git a/src/mlnet/ProgressBar/TaskbarProgress.cs b/src/mlnet/ProgressBar/TaskbarProgress.cs new file mode 100644 index 0000000000..19dbab3c70 --- /dev/null +++ b/src/mlnet/ProgressBar/TaskbarProgress.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; + +namespace Microsoft.ML.CLI.ShellProgressBar +{ + public static class TaskbarProgress + { + public enum TaskbarStates + { + NoProgress = 0, + Indeterminate = 0x1, + Normal = 0x2, + Error = 0x4, + Paused = 0x8 + } + + [ComImport()] + [Guid("ea1afb91-9e28-4b86-90e9-9e9f8a5eefaf")] + [InterfaceType(ComInterfaceType.InterfaceIsIUnknown)] + private interface ITaskbarList3 + { + // ITaskbarList + [PreserveSig] + void HrInit(); + + [PreserveSig] + void AddTab(IntPtr hwnd); + + [PreserveSig] + void DeleteTab(IntPtr hwnd); + + [PreserveSig] + void ActivateTab(IntPtr hwnd); + + [PreserveSig] + void SetActiveAlt(IntPtr hwnd); + + // ITaskbarList2 + [PreserveSig] + void MarkFullscreenWindow(IntPtr hwnd, [MarshalAs(UnmanagedType.Bool)] bool fFullscreen); + + // ITaskbarList3 + [PreserveSig] + void SetProgressValue(IntPtr hwnd, UInt64 ullCompleted, UInt64 ullTotal); + + [PreserveSig] + void SetProgressState(IntPtr hwnd, TaskbarStates state); + } + + [ComImport] + [Guid("56fdf344-fd6d-11d0-958a-006097c9a090")] + [ClassInterface(ClassInterfaceType.None)] + private class TaskbarInstance + { } + + [DllImport("kernel32.dll")] + static extern IntPtr GetConsoleWindow(); + + private static readonly ITaskbarList3 _taskbarInstance = (ITaskbarList3)new TaskbarInstance(); + private static readonly bool _taskbarSupported = RuntimeInformation.IsOSPlatform(OSPlatform.Windows); + + public static void SetState(TaskbarStates taskbarState) + { + if (_taskbarSupported) + _taskbarInstance.SetProgressState(GetConsoleWindow(), taskbarState); + } + + public static void SetValue(double progressValue, double progressMax) + { + if (_taskbarSupported) + _taskbarInstance.SetProgressValue(GetConsoleWindow(), (ulong)progressValue, (ulong)progressMax); + } + } +} diff --git a/src/mlnet/Strings.resx b/src/mlnet/Strings.resx new file mode 100644 index 0000000000..e5f203d5f5 --- /dev/null 
+++ b/src/mlnet/Strings.resx @@ -0,0 +1,195 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + text/microsoft-resx + + + 2.0 + + + System.Resources.ResXResourceReader, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + System.Resources.ResXResourceWriter, System.Windows.Forms, Version=4.0.0.0, Culture=neutral, PublicKeyToken=b77a5c561934e089 + + + Best pipeline + + + Creating Data loader ... + + + Exiting ... + + + Exploring multiple ML algorithms and settings to find you the best model for ML task + + + Exception occured while exploring pipelines + + + Generating a console project for the best pipeline at location + + + An Error occured during inferring columns + + + Inferring Columns ... + + + Loading data ... + + + Metrics for Binary Classification models + + + Metrics for regression models + + + Metrics for multi-class models + + + Retrieving best pipeline ... + + + Generated trained model for consumption + + + Unsupported ml-task + + + Generated log file + + + Generated C# code for model consumption + + + Generated C# code for model training + + + For further learning check + + + https://aka.ms/mlnet-cli + + + Waiting for the first iteration to complete ... + + + Waiting for the last iteration to complete ... + + + Error occured while retreiving best pipeline. + + + Please see the log file for more info. + + + Check out log file for more information + + \ No newline at end of file diff --git a/src/mlnet/Telemetry/DotNetAppInsights/BashPathUnderHomeDirectory.cs b/src/mlnet/Telemetry/DotNetAppInsights/BashPathUnderHomeDirectory.cs new file mode 100644 index 0000000000..58e8c3ea29 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/BashPathUnderHomeDirectory.cs @@ -0,0 +1,26 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +using System; + +namespace Microsoft.DotNet.Configurer +{ + public struct BashPathUnderHomeDirectory + { + private readonly string _fullHomeDirectoryPath; + private readonly string _pathRelativeToHome; + + public BashPathUnderHomeDirectory(string fullHomeDirectoryPath, string pathRelativeToHome) + { + _fullHomeDirectoryPath = + fullHomeDirectoryPath ?? throw new ArgumentNullException(nameof(fullHomeDirectoryPath)); + _pathRelativeToHome = pathRelativeToHome ?? throw new ArgumentNullException(nameof(pathRelativeToHome)); + } + + public string PathWithTilde => $"~/{_pathRelativeToHome}"; + + public string PathWithDollar => $"$HOME/{_pathRelativeToHome}"; + + public string Path => $"{_fullHomeDirectoryPath}/{_pathRelativeToHome}"; + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/CliFolderPathCalculator.cs b/src/mlnet/Telemetry/DotNetAppInsights/CliFolderPathCalculator.cs new file mode 100644 index 0000000000..727b49086b --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/CliFolderPathCalculator.cs @@ -0,0 +1,61 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.IO; +using System.Runtime.InteropServices; +using Microsoft.DotNet.Cli.Utils; + +namespace Microsoft.DotNet.Configurer +{ + public static class CliFolderPathCalculator + { + public const string DotnetHomeVariableName = "DOTNET_CLI_HOME"; + private const string DotnetProfileDirectoryName = ".dotnet"; + private const string ToolsShimFolderName = "tools"; + private const string ToolsResolverCacheFolderName = "toolResolverCache"; + + public static string CliFallbackFolderPath => + Environment.GetEnvironmentVariable("DOTNET_CLI_TEST_FALLBACKFOLDER") ?? 
+ Path.Combine(new DirectoryInfo(AppContext.BaseDirectory).Parent.FullName, "NuGetFallbackFolder"); + + public static string ToolsShimPath => Path.Combine(DotnetUserProfileFolderPath, ToolsShimFolderName); + + public static string ToolsPackagePath => ToolPackageFolderPathCalculator.GetToolPackageFolderPath(ToolsShimPath); + + public static BashPathUnderHomeDirectory ToolsShimPathInUnix => + new BashPathUnderHomeDirectory( + DotnetHomePath, + Path.Combine(DotnetProfileDirectoryName, ToolsShimFolderName)); + + public static string DotnetUserProfileFolderPath => + Path.Combine(DotnetHomePath, DotnetProfileDirectoryName); + + public static string ToolsResolverCachePath => Path.Combine(DotnetUserProfileFolderPath, ToolsResolverCacheFolderName); + + public static string PlatformHomeVariableName => + RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? "USERPROFILE" : "HOME"; + + public static string DotnetHomePath + { + get + { + var home = Environment.GetEnvironmentVariable(DotnetHomeVariableName); + if (string.IsNullOrEmpty(home)) + { + home = Environment.GetEnvironmentVariable(PlatformHomeVariableName); + if (string.IsNullOrEmpty(home)) + { + throw new ConfigurationException( + string.Format( + "The user's home directory could not be determined. Set the '{0}' environment variable to specify the directory to use.", + DotnetHomeVariableName)) + .DisplayAsError(); + } + } + + return home; + } + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/CollectionsExtensions.cs b/src/mlnet/Telemetry/DotNetAppInsights/CollectionsExtensions.cs new file mode 100644 index 0000000000..08779f65e0 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/CollectionsExtensions.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.DotNet.Cli.Utils +{ + public static class CollectionsExtensions + { + public static IEnumerable OrEmptyIfNull(this IEnumerable enumerable) + { + return enumerable == null + ? Enumerable.Empty() + : enumerable; + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/ConfigurationException.cs b/src/mlnet/Telemetry/DotNetAppInsights/ConfigurationException.cs new file mode 100644 index 0000000000..591203c4ab --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/ConfigurationException.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.DotNet.Configurer +{ + internal class ConfigurationException : Exception + { + public ConfigurationException() + { + } + + public ConfigurationException(string message) : base(message) + { + } + + public ConfigurationException(string message, Exception innerException) : base(message, innerException) + { + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/DirectoryWrapper.cs b/src/mlnet/Telemetry/DotNetAppInsights/DirectoryWrapper.cs new file mode 100644 index 0000000000..c09bd8dfa2 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/DirectoryWrapper.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +using System.Collections.Generic; +using System.IO; +using Microsoft.DotNet.InternalAbstractions; + +namespace Microsoft.Extensions.EnvironmentAbstractions +{ + internal class DirectoryWrapper: IDirectory + { + public bool Exists(string path) + { + return Directory.Exists(path); + } + + public ITemporaryDirectory CreateTemporaryDirectory() + { + return new TemporaryDirectory(); + } + + public IEnumerable EnumerateFiles(string path) + { + return Directory.EnumerateFiles(path); + } + + public IEnumerable EnumerateFileSystemEntries(string path) + { + return Directory.EnumerateFileSystemEntries(path); + } + + public string GetCurrentDirectory() + { + return Directory.GetCurrentDirectory(); + } + + public void CreateDirectory(string path) + { + Directory.CreateDirectory(path); + } + + public void Delete(string path, bool recursive) + { + Directory.Delete(path, recursive); + } + + public void Move(string source, string destination) + { + Directory.Move(source, destination); + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/Env.cs b/src/mlnet/Telemetry/DotNetAppInsights/Env.cs new file mode 100644 index 0000000000..6aad6f4c49 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/Env.cs @@ -0,0 +1,40 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +using System.Collections.Generic; + +namespace Microsoft.DotNet.Cli.Utils +{ + public static class Env + { + private static IEnvironmentProvider _environment = new EnvironmentProvider(); + + public static IEnumerable ExecutableExtensions + { + get + { + return _environment.ExecutableExtensions; + } + } + + public static string GetCommandPath(string commandName, params string[] extensions) + { + return _environment.GetCommandPath(commandName, extensions); + } + + public static string GetCommandPathFromRootPath(string rootPath, string commandName, params string[] extensions) + { + return _environment.GetCommandPathFromRootPath(rootPath, commandName, extensions); + } + + public static string GetCommandPathFromRootPath(string rootPath, string commandName, IEnumerable extensions) + { + return _environment.GetCommandPathFromRootPath(rootPath, commandName, extensions); + } + + public static bool GetEnvironmentVariableAsBool(string name, bool defaultValue = false) + { + return _environment.GetEnvironmentVariableAsBool(name, defaultValue); + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/EnvironmentProvider.cs b/src/mlnet/Telemetry/DotNetAppInsights/EnvironmentProvider.cs new file mode 100644 index 0000000000..f0a9fefbe9 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/EnvironmentProvider.cs @@ -0,0 +1,155 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.DotNet.PlatformAbstractions; + +namespace Microsoft.DotNet.Cli.Utils +{ + public class EnvironmentProvider : IEnvironmentProvider + { + private static char[] s_pathSeparator = new char[] { Path.PathSeparator }; + private static char[] s_quote = new char[] { '"' }; + private IEnumerable _searchPaths; + private readonly Lazy _userHomeDirectory = new Lazy(() => Environment.GetEnvironmentVariable("HOME") ?? string.Empty); + private IEnumerable _executableExtensions; + + public IEnumerable ExecutableExtensions + { + get + { + if (_executableExtensions == null) + { + + _executableExtensions = RuntimeEnvironment.OperatingSystemPlatform == Platform.Windows + ? Environment.GetEnvironmentVariable("PATHEXT") + .Split(';') + .Select(e => e.ToLower().Trim('"')) + : new [] { string.Empty }; + } + + return _executableExtensions; + } + } + + private IEnumerable SearchPaths + { + get + { + if (_searchPaths == null) + { + var searchPaths = new List { ApplicationEnvironment.ApplicationBasePath }; + + searchPaths.AddRange(Environment + .GetEnvironmentVariable("PATH") + .Split(s_pathSeparator) + .Select(p => p.Trim(s_quote)) + .Where(p => !string.IsNullOrWhiteSpace(p)) + .Select(p => ExpandTildeSlash(p))); + + _searchPaths = searchPaths; + } + + return _searchPaths; + } + } + + private string ExpandTildeSlash(string path) + { + const string tildeSlash = "~/"; + if (path.StartsWith(tildeSlash, StringComparison.Ordinal) && !string.IsNullOrEmpty(_userHomeDirectory.Value)) + { + return Path.Combine(_userHomeDirectory.Value, path.Substring(tildeSlash.Length)); + } + else + { + return path; + } + } + + public EnvironmentProvider( + IEnumerable extensionsOverride = null, + IEnumerable searchPathsOverride = null) + { + _executableExtensions = extensionsOverride; + _searchPaths = searchPathsOverride; + } + + public string GetCommandPath(string commandName, params string[] extensions) + { + 
if (!extensions.Any()) + { + extensions = ExecutableExtensions.ToArray(); + } + + var commandPath = SearchPaths.Join( + extensions, + p => true, s => true, + (p, s) => Path.Combine(p, commandName + s)) + .FirstOrDefault(File.Exists); + + return commandPath; + } + + public string GetCommandPathFromRootPath(string rootPath, string commandName, params string[] extensions) + { + if (!extensions.Any()) + { + extensions = ExecutableExtensions.ToArray(); + } + + var commandPath = extensions.Select(e => Path.Combine(rootPath, commandName + e)) + .FirstOrDefault(File.Exists); + + return commandPath; + } + + public string GetCommandPathFromRootPath(string rootPath, string commandName, IEnumerable extensions) + { + var extensionsArr = extensions.OrEmptyIfNull().ToArray(); + + return GetCommandPathFromRootPath(rootPath, commandName, extensionsArr); + } + + public string GetEnvironmentVariable(string name) + { + return Environment.GetEnvironmentVariable(name); + } + + public bool GetEnvironmentVariableAsBool(string name, bool defaultValue) + { + var str = Environment.GetEnvironmentVariable(name); + if (string.IsNullOrEmpty(str)) + { + return defaultValue; + } + + switch (str.ToLowerInvariant()) + { + case "true": + case "1": + case "yes": + return true; + case "false": + case "0": + case "no": + return false; + default: + return defaultValue; + } + } + + public string GetEnvironmentVariable(string variable, EnvironmentVariableTarget target) + { + return Environment.GetEnvironmentVariable(variable, target); + } + + public void SetEnvironmentVariable(string variable, string value, EnvironmentVariableTarget target) + { + Environment.SetEnvironmentVariable(variable, value, target); + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/ExceptionExtensions.cs b/src/mlnet/Telemetry/DotNetAppInsights/ExceptionExtensions.cs new file mode 100644 index 0000000000..03f984404c --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/ExceptionExtensions.cs @@ -0,0 +1,22 @@ +// 
Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.DotNet.Cli.Utils +{ + internal static class ExceptionExtensions + { + public static TException DisplayAsError(this TException exception) + where TException : Exception + { + exception.Data.Add(CLI_User_Displayed_Exception, true); + return exception; + } + + public static bool ShouldBeDisplayedAsError(this Exception e) => + e.Data.Contains(CLI_User_Displayed_Exception); + + internal const string CLI_User_Displayed_Exception = "CLI_User_Displayed_Exception"; + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/FileSystemWrapper.cs b/src/mlnet/Telemetry/DotNetAppInsights/FileSystemWrapper.cs new file mode 100644 index 0000000000..8818074310 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/FileSystemWrapper.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +namespace Microsoft.Extensions.EnvironmentAbstractions +{ + internal class FileSystemWrapper : IFileSystem + { + public static IFileSystem Default { get; } = new FileSystemWrapper(); + + public IFile File { get; } = new FileWrapper(); + + public IDirectory Directory { get; } = new DirectoryWrapper(); + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/FileWrapper.cs b/src/mlnet/Telemetry/DotNetAppInsights/FileWrapper.cs new file mode 100644 index 0000000000..c46f8e37ae --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/FileWrapper.cs @@ -0,0 +1,64 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
+ +using System; +using System.IO; + +namespace Microsoft.Extensions.EnvironmentAbstractions +{ + internal class FileWrapper: IFile + { + public bool Exists(string path) + { + return File.Exists(path); + } + + public string ReadAllText(string path) + { + return File.ReadAllText(path); + } + + public Stream OpenRead(string path) + { + return File.OpenRead(path); + } + + public Stream OpenFile( + string path, + FileMode fileMode, + FileAccess fileAccess, + FileShare fileShare, + int bufferSize, + FileOptions fileOptions) + { + return new FileStream(path, fileMode, fileAccess, fileShare, bufferSize, fileOptions); + } + + public void CreateEmptyFile(string path) + { + using (File.Create(path)) + { + } + } + + public void WriteAllText(string path, string content) + { + File.WriteAllText(path, content); + } + + public void Move(string source, string destination) + { + File.Move(source, destination); + } + + public void Copy(string sourceFileName, string destFileName) + { + File.Copy(sourceFileName, destFileName); + } + + public void Delete(string path) + { + File.Delete(path); + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/FirstTimeUseNoticeSentinel.cs b/src/mlnet/Telemetry/DotNetAppInsights/FirstTimeUseNoticeSentinel.cs new file mode 100644 index 0000000000..0ea9dfd9dd --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/FirstTimeUseNoticeSentinel.cs @@ -0,0 +1,58 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. 
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.IO;
using Microsoft.DotNet.AutoML;
using Microsoft.Extensions.EnvironmentAbstractions;

namespace Microsoft.DotNet.Configurer
{
    /// <summary>
    /// Sentinel file recording that the first-time-use notice has been shown.
    /// The file name embeds the product version so the notice reappears after upgrades.
    /// </summary>
    public class FirstTimeUseNoticeSentinel : IFirstTimeUseNoticeSentinel
    {
        public static readonly string SENTINEL = $"{Product.Version}.MLNET.dotnetFirstUseSentinel";

        private readonly IFile _file;
        private readonly IDirectory _directory;

        private string _dotnetUserProfileFolderPath;

        private string SentinelPath => Path.Combine(_dotnetUserProfileFolderPath, SENTINEL);

        public FirstTimeUseNoticeSentinel() :
            this(
                CliFolderPathCalculator.DotnetUserProfileFolderPath,
                FileSystemWrapper.Default.File,
                FileSystemWrapper.Default.Directory)
        {
        }

        // Test hook: lets callers substitute the profile path and file system.
        internal FirstTimeUseNoticeSentinel(string dotnetUserProfileFolderPath, IFile file, IDirectory directory)
        {
            _file = file;
            _directory = directory;
            _dotnetUserProfileFolderPath = dotnetUserProfileFolderPath;
        }

        public bool Exists() => _file.Exists(SentinelPath);

        public void CreateIfNotExists()
        {
            if (!Exists())
            {
                if (!_directory.Exists(_dotnetUserProfileFolderPath))
                {
                    _directory.CreateDirectory(_dotnetUserProfileFolderPath);
                }

                _file.CreateEmptyFile(SentinelPath);
            }
        }

        public void Dispose()
        {
            // No resources to release; Dispose is required by the interface.
        }
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Microsoft.Extensions.EnvironmentAbstractions
{
    /// <summary>Abstraction over directory operations (mirrors System.IO.Directory).</summary>
    internal interface IDirectory
    {
        bool Exists(string path);

        ITemporaryDirectory CreateTemporaryDirectory();

        IEnumerable<string> EnumerateFiles(string path);

        IEnumerable<string> EnumerateFileSystemEntries(string path);

        string GetCurrentDirectory();

        void CreateDirectory(string path);

        void Delete(string path, bool recursive);

        void Move(string source, string destination);
    }
}

namespace Microsoft.DotNet.Cli.Utils
{
    /// <summary>Abstraction over environment variables and PATH-based command lookup.</summary>
    public interface IEnvironmentProvider
    {
        IEnumerable<string> ExecutableExtensions { get; }

        string GetCommandPath(string commandName, params string[] extensions);

        string GetCommandPathFromRootPath(string rootPath, string commandName, params string[] extensions);

        string GetCommandPathFromRootPath(string rootPath, string commandName, IEnumerable<string> extensions);

        bool GetEnvironmentVariableAsBool(string name, bool defaultValue);

        string GetEnvironmentVariable(string name);

        string GetEnvironmentVariable(string variable, EnvironmentVariableTarget target);

        void SetEnvironmentVariable(string variable, string value, EnvironmentVariableTarget target);
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.IO;

namespace Microsoft.Extensions.EnvironmentAbstractions
{
    /// <summary>Abstraction over file operations (mirrors System.IO.File).</summary>
    internal interface IFile
    {
        bool Exists(string path);

        string ReadAllText(string path);

        Stream OpenRead(string path);

        Stream OpenFile(
            string path,
            FileMode fileMode,
            FileAccess fileAccess,
            FileShare fileShare,
            int bufferSize,
            FileOptions fileOptions);

        void CreateEmptyFile(string path);

        void WriteAllText(string path, string content);

        void Move(string source, string destination);

        void Copy(string source, string destination);

        void Delete(string path);
    }
}

namespace Microsoft.DotNet.Configurer
{
    /// <summary>A create-once marker file (e.g. a sentinel recording a one-time action).</summary>
    public interface IFileSentinel
    {
        bool Exists();

        void Create();
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;

namespace Microsoft.Extensions.EnvironmentAbstractions
{
    /// <summary>Pairs the file and directory abstractions into one file system.</summary>
    internal interface IFileSystem
    {
        IFile File { get; }
        IDirectory Directory { get; }
    }
}

namespace Microsoft.DotNet.Configurer
{
    /// <summary>Sentinel tracking whether the first-time-use notice was shown.</summary>
    public interface IFirstTimeUseNoticeSentinel : IDisposable
    {
        bool Exists();

        void CreateIfNotExists();
    }
}

namespace Microsoft.DotNet.Cli.Telemetry
{
    /// <summary>Telemetry sink for named events with string properties and numeric measurements.</summary>
    public interface ITelemetry
    {
        bool Enabled { get; }

        void TrackEvent(string eventName, IDictionary<string, string> properties, IDictionary<string, double> measurements);
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;

namespace Microsoft.Extensions.EnvironmentAbstractions
{
    /// <summary>A directory deleted when the instance is disposed.</summary>
    internal interface ITemporaryDirectory : IDisposable
    {
        string DirectoryPath { get; }
    }
}

namespace Microsoft.DotNet.Configurer
{
    /// <summary>Caches the result of an expensive computation in a per-user file.</summary>
    public interface IUserLevelCacheWriter
    {
        string RunWithCache(string cacheKey, Func<string> getValueToCache);
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using System.Diagnostics;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Net.NetworkInformation;
using System.ComponentModel;
using Microsoft.DotNet.Cli.Utils;

namespace Microsoft.DotNet.Cli.Telemetry
{
    /// <summary>
    /// Obtains a machine MAC address for the telemetry machine id, preferring
    /// shell tools (getmac / ifconfig / ip) and falling back to the
    /// NetworkInterface APIs when the tools are unavailable.
    /// </summary>
    internal static class MacAddressGetter
    {
        private const string MacRegex = @"(?:[a-z0-9]{2}[:\-]){5}[a-z0-9]{2}";
        private const string ZeroRegex = @"(?:00[:\-]){5}00";
        private const int ErrorFileNotFound = 0x2;

        /// <summary>Returns the first non-zero MAC address found, or null if none.</summary>
        public static string GetMacAddress()
        {
            try
            {
                var shellOutput = GetShellOutMacAddressOutput();
                if (shellOutput == null)
                {
                    return null;
                }

                return ParseMACAddress(shellOutput);
            }
            catch (Win32Exception e)
            {
                // The shell tool is not installed; fall back to the BCL.
                if (e.NativeErrorCode == ErrorFileNotFound)
                {
                    return GetMacAddressByNetworkInterface();
                }

                throw;
            }
        }

        /// <summary>First MAC-shaped token that is not the all-zero placeholder, or null.</summary>
        private static string ParseMACAddress(string shellOutput)
        {
            foreach (Match match in Regex.Matches(shellOutput, MacRegex, RegexOptions.IgnoreCase))
            {
                if (!Regex.IsMatch(match.Value, ZeroRegex))
                {
                    return match.Value;
                }
            }

            return null;
        }

        /// <summary>Output of `ip link`, or null when the command fails.</summary>
        private static string GetIpCommandOutput()
        {
            var ipResult = new ProcessStartInfo
            {
                FileName = "ip",
                Arguments = "link",
                UseShellExecute = false
            }.ExecuteAndCaptureOutput(out string ipStdOut, out string ipStdErr);

            return ipResult == 0 ? ipStdOut : null;
        }

        /// <summary>
        /// Shells out to the platform tool: getmac.exe on Windows, otherwise
        /// ifconfig -a with an `ip link` fallback. Null on failure.
        /// </summary>
        private static string GetShellOutMacAddressOutput()
        {
            if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
            {
                var result = new ProcessStartInfo
                {
                    FileName = "getmac.exe",
                    UseShellExecute = false
                }.ExecuteAndCaptureOutput(out string stdOut, out string stdErr);

                return result == 0 ? stdOut : null;
            }

            try
            {
                var ifconfigResult = new ProcessStartInfo
                {
                    FileName = "ifconfig",
                    Arguments = "-a",
                    UseShellExecute = false
                }.ExecuteAndCaptureOutput(out string ifconfigStdOut, out string ifconfigStdErr);

                if (ifconfigResult == 0)
                {
                    return ifconfigStdOut;
                }

                return GetIpCommandOutput();
            }
            catch (Win32Exception e)
            {
                if (e.NativeErrorCode == ErrorFileNotFound)
                {
                    return GetIpCommandOutput();
                }

                throw;
            }
        }

        private static string GetMacAddressByNetworkInterface()
        {
            return GetMacAddressesByNetworkInterface().FirstOrDefault();
        }

        /// <summary>Up to 10 adapter MAC addresses, "-"-separated hex; [""] when none exist.</summary>
        private static List<string> GetMacAddressesByNetworkInterface()
        {
            NetworkInterface[] nics = NetworkInterface.GetAllNetworkInterfaces();
            var macs = new List<string>();

            if (nics == null || nics.Length < 1)
            {
                macs.Add(string.Empty);
                return macs;
            }

            foreach (NetworkInterface adapter in nics)
            {
                // NOTE(review): result unused, but GetIPProperties can throw on
                // some platforms; kept to preserve the original behavior.
                IPInterfaceProperties properties = adapter.GetIPProperties();

                PhysicalAddress address = adapter.GetPhysicalAddress();
                byte[] bytes = address.GetAddressBytes();
                macs.Add(string.Join("-", bytes.Select(x => x.ToString("X2"))));
                if (macs.Count >= 10)
                {
                    break;
                }
            }

            return macs;
        }
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.ComponentModel;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Threading;
using Microsoft.DotNet.PlatformAbstractions;
using Microsoft.Win32.SafeHandles;

using RuntimeEnvironment = Microsoft.DotNet.PlatformAbstractions.RuntimeEnvironment;

namespace Microsoft.DotNet.Cli.Utils
{
    internal static class NativeMethods
    {
        internal static class Windows
        {
            internal enum JobObjectInfoClass : uint
            {
                JobObjectExtendedLimitInformation = 9,
            }

            [Flags]
            internal enum JobObjectLimitFlags : uint
            {
                JobObjectLimitKillOnJobClose = 0x2000,
            }

            [StructLayout(LayoutKind.Sequential)]
            internal struct JobObjectBasicLimitInformation
            {
                public Int64 PerProcessUserTimeLimit;
                public Int64 PerJobUserTimeLimit;
                public JobObjectLimitFlags LimitFlags;
                public UIntPtr MinimumWorkingSetSize;
                public UIntPtr MaximumWorkingSetSize;
                public UInt32 ActiveProcessLimit;
                public UIntPtr Affinity;
                public UInt32 PriorityClass;
                public UInt32 SchedulingClass;
            }

            [StructLayout(LayoutKind.Sequential)]
            internal struct IoCounters
            {
                public UInt64 ReadOperationCount;
                public UInt64 WriteOperationCount;
                public UInt64 OtherOperationCount;
                public UInt64 ReadTransferCount;
                public UInt64 WriteTransferCount;
                public UInt64 OtherTransferCount;
            }

            [StructLayout(LayoutKind.Sequential)]
            internal struct JobObjectExtendedLimitInformation
            {
                public JobObjectBasicLimitInformation BasicLimitInformation;
                public IoCounters IoInfo;
                public UIntPtr ProcessMemoryLimit;
                public UIntPtr JobMemoryLimit;
                public UIntPtr PeakProcessMemoryUsed;
                public UIntPtr PeakJobMemoryUsed;
            }

            [DllImport("kernel32.dll", CharSet = CharSet.Unicode, SetLastError = true)]
            internal static extern SafeWaitHandle CreateJobObjectW(IntPtr lpJobAttributes, string lpName);

            [DllImport("kernel32.dll", SetLastError = true)]
            internal static extern bool SetInformationJobObject(IntPtr hJob, JobObjectInfoClass jobObjectInformationClass, IntPtr lpJobObjectInformation, UInt32 cbJobObjectInformationLength);

            [DllImport("kernel32.dll", SetLastError = true)]
            internal static extern bool AssignProcessToJobObject(IntPtr hJob, IntPtr hProcess);
        }

        internal static class Posix
        {
            [DllImport("libc", SetLastError = true)]
            internal static extern int kill(int pid, int sig);

            internal const int SIGINT = 2;
            internal const int SIGTERM = 15;
        }
    }

    /// <summary>
    /// Responsible for reaping a target process if the current process terminates.
    ///
    /// On Windows, a job object ensures termination of the target process (and its
    /// tree) even if the current process is rudely terminated. On POSIX systems the
    /// reaper handles SIGTERM and forwards the signal to the target process only.
    /// The reaper also suppresses SIGINT in the current process so the target can
    /// handle the signal itself.
    /// </summary>
    internal class ProcessReaper : IDisposable
    {
        private Process _process;
        private SafeWaitHandle _job;
        private Mutex _shutdownMutex;

        /// <summary>
        /// Creates a new process reaper.
        /// </summary>
        /// <param name="process">The target process to reap if the current process terminates. The process should not yet be started.</param>
        public ProcessReaper(Process process)
        {
            _process = process;

            // Register handlers before spawning the child to avoid a race where the
            // child writes expected output before this process can handle signals.
            Console.CancelKeyPress += HandleCancelKeyPress;
            if (RuntimeEnvironment.OperatingSystemPlatform != Platform.Windows)
            {
                _shutdownMutex = new Mutex();
                AppDomain.CurrentDomain.ProcessExit += HandleProcessExit;
            }
        }

        /// <summary>
        /// Call to notify the reaper that the process has started.
        /// </summary>
        public void NotifyProcessStarted()
        {
            if (RuntimeEnvironment.OperatingSystemPlatform == Platform.Windows)
            {
                // Limit job objects to Windows 8/2012 or later, which support nested
                // jobs. Windows 8.1+ may report as Windows 8 (see
                // https://docs.microsoft.com/en-us/windows/desktop/sysinfo/operating-system-version),
                // but that is still sufficient for this check.
                if (Environment.OSVersion.Version.Major > 6 ||
                    (Environment.OSVersion.Version.Major == 6 && Environment.OSVersion.Version.Minor >= 2))
                {
                    _job = AssignProcessToJobObject(_process.Handle);
                }
            }
        }

        public void Dispose()
        {
            if (RuntimeEnvironment.OperatingSystemPlatform == Platform.Windows)
            {
                if (_job != null)
                {
                    // Clear kill-on-close because the child exited successfully; if
                    // this fails, remaining processes in the job are terminated.
                    SetKillOnJobClose(_job.DangerousGetHandle(), false);

                    _job.Dispose();
                    _job = null;
                }
            }
            else
            {
                AppDomain.CurrentDomain.ProcessExit -= HandleProcessExit;

                // If a shutdown began via the process-exit handler, block here so we
                // don't race with CLR shutdown from the signal handler.
                if (_shutdownMutex != null)
                {
                    _shutdownMutex.WaitOne();
                    _shutdownMutex.ReleaseMutex();
                    _shutdownMutex.Dispose();
                    _shutdownMutex = null;
                }
            }

            Console.CancelKeyPress -= HandleCancelKeyPress;
        }

        private static void HandleCancelKeyPress(object sender, ConsoleCancelEventArgs e)
        {
            // Ignore SIGINT/SIGQUIT so the child process can handle the signal.
            e.Cancel = true;
        }

        private static SafeWaitHandle AssignProcessToJobObject(IntPtr process)
        {
            var job = NativeMethods.Windows.CreateJobObjectW(IntPtr.Zero, null);
            if (job == null || job.IsInvalid)
            {
                return null;
            }

            if (!SetKillOnJobClose(job.DangerousGetHandle(), true))
            {
                job.Dispose();
                return null;
            }

            if (!NativeMethods.Windows.AssignProcessToJobObject(job.DangerousGetHandle(), process))
            {
                job.Dispose();
                return null;
            }

            return job;
        }

        private void HandleProcessExit(object sender, EventArgs args)
        {
            int processId;
            try
            {
                processId = _process.Id;
            }
            catch (InvalidOperationException)
            {
                // The process hasn't started yet; nothing to signal.
                return;
            }

            // Take ownership of the shutdown mutex so the other thread waiting on
            // process exit cannot complete CLR shutdown before this one does.
            // NOTE(review): the mutex is intentionally never released on this path.
            _shutdownMutex.WaitOne();

            if (!_process.WaitForExit(0) && NativeMethods.Posix.kill(processId, NativeMethods.Posix.SIGTERM) != 0)
            {
                // Couldn't send the signal; don't wait.
                return;
            }

            // If SIGTERM was ignored by the target, still wait for it.
            _process.WaitForExit();

            Environment.ExitCode = _process.ExitCode;
        }

        private static bool SetKillOnJobClose(IntPtr job, bool value)
        {
            var information = new NativeMethods.Windows.JobObjectExtendedLimitInformation
            {
                BasicLimitInformation = new NativeMethods.Windows.JobObjectBasicLimitInformation
                {
                    LimitFlags = (value ? NativeMethods.Windows.JobObjectLimitFlags.JobObjectLimitKillOnJobClose : 0)
                }
            };

            var length = Marshal.SizeOf(typeof(NativeMethods.Windows.JobObjectExtendedLimitInformation));
            var informationPtr = Marshal.AllocHGlobal(length);

            try
            {
                Marshal.StructureToPtr(information, informationPtr, false);

                return NativeMethods.Windows.SetInformationJobObject(
                    job,
                    NativeMethods.Windows.JobObjectInfoClass.JobObjectExtendedLimitInformation,
                    informationPtr,
                    (uint)length);
            }
            finally
            {
                Marshal.FreeHGlobal(informationPtr);
            }
        }
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Diagnostics;

namespace Microsoft.DotNet.Cli.Utils
{
    internal static class ProcessStartInfoExtensions
    {
        /// <summary>Runs the process to completion (reaped on host exit) and returns its exit code.</summary>
        public static int Execute(this ProcessStartInfo startInfo)
        {
            if (startInfo == null)
            {
                throw new ArgumentNullException(nameof(startInfo));
            }

            var process = new Process
            {
                StartInfo = startInfo
            };

            using (var reaper = new ProcessReaper(process))
            {
                process.Start();
                reaper.NotifyProcessStarted();
                process.WaitForExit();
            }

            return process.ExitCode;
        }

        /// <summary>
        /// Runs the process to completion, capturing stdout and stderr, and
        /// returns its exit code.
        /// </summary>
        public static int ExecuteAndCaptureOutput(this ProcessStartInfo startInfo, out string stdOut, out string stdErr)
        {
            var outStream = new StreamForwarder().Capture();
            var errStream = new StreamForwarder().Capture();

            startInfo.RedirectStandardOutput = true;
            startInfo.RedirectStandardError = true;

            var process = new Process
            {
                StartInfo = startInfo
            };

            process.EnableRaisingEvents = true;

            using (var reaper = new ProcessReaper(process))
            {
                process.Start();
                reaper.NotifyProcessStarted();

                var taskOut = outStream.BeginRead(process.StandardOutput);
                var taskErr = errStream.BeginRead(process.StandardError);

                process.WaitForExit();

                // Drain both streams fully before reporting captured output.
                taskOut.Wait();
                taskErr.Wait();

                stdOut = outStream.CapturedOutput;
                stdErr = errStream.CapturedOutput;
            }

            return process.ExitCode;
        }
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System.Security.Cryptography;
using System.Text;

namespace Microsoft.DotNet.Cli.Telemetry
{
    internal static class Sha256Hasher
    {
        /// <summary>
        /// Lower-case hex SHA-256 of the UTF-8 bytes of <paramref name="text"/>.
        /// The hashed MAC address must match the value produced by other sources
        /// (e.g. VS Code) for the same input, so the format must not change.
        /// </summary>
        public static string Hash(string text)
        {
            // Fix: dispose the SHA256 instance (the original leaked it).
            using (var sha256 = SHA256.Create())
            {
                return HashInFormat(sha256, text);
            }
        }

        /// <summary>Hash of the upper-invariant form, for case-insensitive identifiers.</summary>
        public static string HashWithNormalizedCasing(string text)
        {
            return Hash(text.ToUpperInvariant());
        }

        private static string HashInFormat(SHA256 sha256, string text)
        {
            byte[] bytes = Encoding.UTF8.GetBytes(text);
            byte[] hash = sha256.ComputeHash(bytes);
            StringBuilder hashString = new StringBuilder();
            foreach (byte x in hash)
            {
                hashString.AppendFormat("{0:x2}", x);
            }
            return hashString.ToString();
        }
    }
}
// Copyright (c) .NET Foundation and contributors. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.DotNet.AutoML;
using Microsoft.DotNet.Configurer;
using Microsoft.Extensions.EnvironmentAbstractions;

using RuntimeEnvironment = Microsoft.DotNet.PlatformAbstractions.RuntimeEnvironment;
using RuntimeInformation = System.Runtime.InteropServices.RuntimeInformation;

namespace Microsoft.DotNet.Cli.Utils
{
    /// <summary>
    /// Reads a TextReader line-by-line, optionally capturing the text and/or
    /// forwarding each completed line to a callback. '\r' is dropped; '\n' flushes.
    /// </summary>
    public sealed class StreamForwarder
    {
        private static readonly char[] s_ignoreCharacters = new char[] { '\r' };
        private static readonly char s_flushBuilderCharacter = '\n';

        private StringBuilder _builder;
        private StringWriter _capture;
        private Action<string> _writeLine;

        /// <summary>All captured text, or null when Capture() was never called.</summary>
        public string CapturedOutput => _capture?.GetStringBuilder()?.ToString();

        /// <summary>Enables capturing; may be called at most once.</summary>
        public StreamForwarder Capture()
        {
            ThrowIfCaptureSet();

            _capture = new StringWriter();

            return this;
        }

        /// <summary>Forwards each completed line to <paramref name="writeLine"/>; at most one forwarder.</summary>
        public StreamForwarder ForwardTo(Action<string> writeLine)
        {
            ThrowIfNull(writeLine);

            ThrowIfForwarderSet();

            _writeLine = writeLine;

            return this;
        }

        /// <summary>Starts reading on a background task; await/Wait to drain fully.</summary>
        public Task BeginRead(TextReader reader)
        {
            return Task.Run(() => Read(reader));
        }

        public void Read(TextReader reader)
        {
            var bufferSize = 1;

            int readCharacterCount;
            char currentCharacter;

            var buffer = new char[bufferSize];
            _builder = new StringBuilder();

            // Read with a 1-char buffer to avoid looping endlessly as Read()
            // with no buffer would.
            while ((readCharacterCount = reader.Read(buffer, 0, bufferSize)) > 0)
            {
                currentCharacter = buffer[0];

                if (currentCharacter == s_flushBuilderCharacter)
                {
                    WriteBuilder();
                }
                else if (!s_ignoreCharacters.Contains(currentCharacter))
                {
                    _builder.Append(currentCharacter);
                }
            }

            // Flush any trailing partial line once the stream closes
            // (happens when the producer used Console.Write without newline).
            WriteBuilder();
        }

        private void WriteBuilder()
        {
            if (_builder.Length == 0)
            {
                return;
            }

            WriteLine(_builder.ToString());
            _builder.Clear();
        }

        private void WriteLine(string str)
        {
            _capture?.WriteLine(str);
            _writeLine?.Invoke(str);
        }

        private void ThrowIfNull(object obj)
        {
            if (obj == null)
            {
                throw new ArgumentNullException(nameof(obj));
            }
        }

        private void ThrowIfForwarderSet()
        {
            if (_writeLine != null)
            {
                throw new InvalidOperationException("LocalizableStrings.WriteLineForwarderSetPreviously");
            }
        }

        private void ThrowIfCaptureSet()
        {
            if (_capture != null)
            {
                throw new InvalidOperationException("LocalizableStrings.AlreadyCapturingStream");
            }
        }
    }
}

namespace Microsoft.DotNet.Cli.Telemetry
{
    /// <summary>
    /// Computes the properties attached to every telemetry event: OS platform,
    /// product version, telemetry profile, and a hashed, cached machine id.
    /// </summary>
    internal class TelemetryCommonProperties
    {
        public TelemetryCommonProperties(
            Func<string> getCurrentDirectory = null,
            Func<string, string> hasher = null,
            Func<string> getMACAddress = null,
            IUserLevelCacheWriter userLevelCacheWriter = null)
        {
            _getCurrentDirectory = getCurrentDirectory ?? Directory.GetCurrentDirectory;
            _hasher = hasher ?? Sha256Hasher.Hash;
            _getMACAddress = getMACAddress ?? MacAddressGetter.GetMacAddress;
            _userLevelCacheWriter = userLevelCacheWriter ?? new UserLevelCacheWriter();
        }

        private Func<string> _getCurrentDirectory;
        private Func<string, string> _hasher;
        private Func<string> _getMACAddress;
        private IUserLevelCacheWriter _userLevelCacheWriter;

        private const string OSPlatform = "OS Platform";
        private const string ProductVersion = "Product Version";
        private const string TelemetryProfile = "Telemetry Profile";
        private const string MachineId = "Machine ID";

        private const string TelemetryProfileEnvironmentVariable = "DOTNET_CLI_TELEMETRY_PROFILE";
        private const string CannotFindMacAddress = "Unknown";

        private const string MachineIdCacheKey = "MLNET_MachineId";
        private const string IsDockerContainerCacheKey = "IsDockerContainer";

        public Dictionary<string, string> GetTelemetryCommonProperties()
        {
            return new Dictionary<string, string>
            {
                {OSPlatform, RuntimeEnvironment.OperatingSystemPlatform.ToString()},
                {ProductVersion, Product.Version},
                {TelemetryProfile, Environment.GetEnvironmentVariable(TelemetryProfileEnvironmentVariable)},
                {MachineId, _userLevelCacheWriter.RunWithCache(MachineIdCacheKey, GetMachineId)}
            };
        }

        // Hash of the MAC address when one is available; otherwise a random GUID
        // (cached by the caller, so it stays stable per user/version).
        private string GetMachineId()
        {
            var macAddress = _getMACAddress();
            if (macAddress != null)
            {
                return _hasher(macAddress);
            }

            return Guid.NewGuid().ToString();
        }

        /// <summary>
        /// Returns a string identifying the OS kernel.
        /// For Unix this currently comes from "uname -srv"; for Windows from
        /// RtlGetVersion(). Examples: "Linux 4.9.0-2-amd64 #1 SMP Debian ...",
        /// "Darwin 17.4.0 Darwin Kernel Version 17.4.0 ...",
        /// "Microsoft Windows 10.0.14393".
        /// </summary>
        private static string GetKernelVersion()
        {
            return RuntimeInformation.OSDescription;
        }
    }
}

namespace Microsoft.DotNet.InternalAbstractions
{
    /// <summary>A random temp-path directory removed (best-effort) on dispose.</summary>
    internal class TemporaryDirectory : ITemporaryDirectory
    {
        public string DirectoryPath { get; }

        public TemporaryDirectory()
        {
            DirectoryPath = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName());
            Directory.CreateDirectory(DirectoryPath);
        }

        public void Dispose()
        {
            try
            {
                Directory.Delete(DirectoryPath, true);
            }
            catch
            {
                // Ignore failures here; temp cleanup is best-effort.
            }
        }
    }
}
+ +using System.IO; + +namespace Microsoft.DotNet.Configurer +{ + public static class ToolPackageFolderPathCalculator + { + private const string NestedToolPackageFolderName = ".store"; + public static string GetToolPackageFolderPath(string toolsShimPath) + { + return Path.Combine(toolsShimPath, NestedToolPackageFolderName); + } + } +} diff --git a/src/mlnet/Telemetry/DotNetAppInsights/UserLevelCacheWriter.cs b/src/mlnet/Telemetry/DotNetAppInsights/UserLevelCacheWriter.cs new file mode 100644 index 0000000000..b674859078 --- /dev/null +++ b/src/mlnet/Telemetry/DotNetAppInsights/UserLevelCacheWriter.cs @@ -0,0 +1,73 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.IO; +using Microsoft.DotNet.AutoML; +using Microsoft.Extensions.EnvironmentAbstractions; + +namespace Microsoft.DotNet.Configurer +{ + public class UserLevelCacheWriter : IUserLevelCacheWriter + { + private readonly IFile _file; + private readonly IDirectory _directory; + private string _dotnetUserProfileFolderPath; + + public UserLevelCacheWriter() : + this( + CliFolderPathCalculator.DotnetUserProfileFolderPath, + FileSystemWrapper.Default.File, + FileSystemWrapper.Default.Directory) + { + } + + public string RunWithCache(string cacheKey, Func getValueToCache) + { + var cacheFilepath = GetCacheFilePath(cacheKey); + try + { + if (!_file.Exists(cacheFilepath)) + { + if (!_directory.Exists(_dotnetUserProfileFolderPath)) + { + _directory.CreateDirectory(_dotnetUserProfileFolderPath); + } + + var runResult = getValueToCache(); + + _file.WriteAllText(cacheFilepath, runResult); + return runResult; + } + else + { + return _file.ReadAllText(cacheFilepath); + } + } + catch (Exception ex) + { + if (ex is UnauthorizedAccessException + || ex is PathTooLongException + || ex is IOException) + { + return getValueToCache(); + } + + throw; + } + + } + + internal 
UserLevelCacheWriter(string dotnetUserProfileFolderPath, IFile file, IDirectory directory) + { + _file = file; + _directory = directory; + _dotnetUserProfileFolderPath = dotnetUserProfileFolderPath; + } + + private string GetCacheFilePath(string cacheKey) + { + return Path.Combine(_dotnetUserProfileFolderPath, $"{Product.Version}_{cacheKey}.dotnetUserLevelCache"); + } + } +} diff --git a/src/mlnet/Telemetry/MlTelemetry.cs b/src/mlnet/Telemetry/MlTelemetry.cs new file mode 100644 index 0000000000..2d65a74890 --- /dev/null +++ b/src/mlnet/Telemetry/MlTelemetry.cs @@ -0,0 +1,97 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using Microsoft.DotNet.Configurer; + +namespace Microsoft.DotNet.Cli.Telemetry +{ + public class MlTelemetry + { + private bool _firstTimeUse = false; + private bool _enabled = false; + private List _parameters = new List(); + private string _command; + + public void SetCommandAndParameters(string command, IEnumerable parameters) + { + if(parameters != null) + { + _parameters.AddRange(parameters); + } + + _command = command; + } + + public void LogAutoTrainMlCommand(string dataFileName, string task, long dataFileSize) + { + CheckFistTimeUse(); + + if(!_enabled) + { + return; + } + + var telemetry = new Telemetry(); + + var fileSizeBucket = Math.Pow(2, Math.Ceiling(Math.Log(dataFileSize, 2))); + + var fileNameHash = string.IsNullOrEmpty(dataFileName) ? 
string.Empty : Sha256Hasher.Hash(dataFileName); + + var paramString = string.Join(",", _parameters); + + var propertiesToLog = new Dictionary + { + { "Command", _command }, + { "FileSizeBucket", fileSizeBucket.ToString() }, + { "FileNameHash", fileNameHash }, + { "CommandLineParametersUsed", paramString }, + { "LearningTaskType", task } + }; + + telemetry.TrackEvent("mlnet-command", propertiesToLog, new Dictionary()); + } + + private void CheckFistTimeUse() + { + using (IFirstTimeUseNoticeSentinel firstTimeUseNoticeSentinel = new FirstTimeUseNoticeSentinel()) + { + // if we're in first time use invocation and there are repeat telemetry calls, don't send telemetry + if (_firstTimeUse) + { + return; + } + + _firstTimeUse = !firstTimeUseNoticeSentinel.Exists(); + + if (_firstTimeUse) + { + Console.WriteLine( +@"Welcome to the ML.NET CLI! +-------------------------- +Learn more about ML.NET CLI: https://aka.ms/mlnet-cli +Use 'dotnet ml --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs + +Telemetry +--------- +The ML.NET CLI tool collect usage data in order to help us improve your experience. +The data is anonymous and doesn't include personal information or data from your datasets. +You can opt-out of telemetry by setting the MLDOTNET_CLI_TELEMETRY_OPTOUT environment variable to '1' or 'true' using your favorite shell. + +Read more about ML.NET CLI Tool telemetry: https://aka.ms/mlnet-cli-telemetry +"); + + firstTimeUseNoticeSentinel.CreateIfNotExists(); + + // since the user didn't yet have a chance to read the above message and decide to opt out, + // don't log any telemetry on the first invocation. + + return; + } + + _enabled = true; + } + } + } +} \ No newline at end of file diff --git a/src/mlnet/Telemetry/ProductVersion.cs b/src/mlnet/Telemetry/ProductVersion.cs new file mode 100644 index 0000000000..1c3afa4130 --- /dev/null +++ b/src/mlnet/Telemetry/ProductVersion.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation and contributors. 
All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Reflection; + +namespace Microsoft.DotNet.AutoML +{ + public class Product + { + public static readonly string Version = GetProductVersion(); + + private static string GetProductVersion() + { + var version = typeof(Microsoft.ML.CLI.Program).GetTypeInfo().Assembly.GetCustomAttribute().Version; + + return version; + } + } +} \ No newline at end of file diff --git a/src/mlnet/Telemetry/Telemetry.cs b/src/mlnet/Telemetry/Telemetry.cs new file mode 100644 index 0000000000..71d5fda541 --- /dev/null +++ b/src/mlnet/Telemetry/Telemetry.cs @@ -0,0 +1,144 @@ +// Copyright (c) .NET Foundation and contributors. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading.Tasks; +using Microsoft.ApplicationInsights; +using Microsoft.DotNet.Cli.Utils; +using Microsoft.DotNet.PlatformAbstractions; + +namespace Microsoft.DotNet.Cli.Telemetry +{ + public class Telemetry : ITelemetry + { + private TelemetryClient _client = null; + private Dictionary _commonProperties = new Dictionary(); + private Task _trackEventTask = null; + + private const string InstrumentationKey = "c059917c-818d-489a-bfcb-351eaab73f2a"; + private const string MlTelemetryOptout = "MLDOTNET_CLI_TELEMETRY_OPTOUT"; + private const string MachineId = "MachineId"; + + public bool Enabled { get; } + + public Telemetry() + { + var optedOut = Env.GetEnvironmentVariableAsBool(MlTelemetryOptout, false); + + Enabled = !optedOut; + + if (!Enabled) + { + return; + } + + //initialize in task to offload to parallel thread + _trackEventTask = Task.Factory.StartNew(() => InitializeTelemetry()); + } + + public void TrackEvent( + string eventName, + IDictionary properties, + IDictionary measurements) + { + if (!Enabled) + { + 
return; + } + + //continue task in existing parallel thread + _trackEventTask = _trackEventTask.ContinueWith( + x => TrackEventTask(eventName, properties, measurements) + ); + } + + public void ThreadBlockingTrackEvent(string eventName, IDictionary properties, IDictionary measurements) + { + if (!Enabled) + { + return; + } + + TrackEventTask(eventName, properties, measurements); + } + + private void InitializeTelemetry() + { + try + { + _client = new TelemetryClient(); + _client.InstrumentationKey = InstrumentationKey; + _client.Context.Device.OperatingSystem = RuntimeEnvironment.OperatingSystem; + + // we don't want hostname etc to be sent in plain text. + // these need to be set to some non-empty values to override default behavior. + _client.Context.Cloud.RoleInstance = "-"; + _client.Context.Cloud.RoleName = "-"; + + _commonProperties = new TelemetryCommonProperties().GetTelemetryCommonProperties(); + } + catch (Exception e) + { + _client = null; + // we dont want to fail the tool if telemetry fails. 
+ Debug.Fail(e.ToString()); + } + } + + private void TrackEventTask( + string eventName, + IDictionary properties, + IDictionary measurements) + { + if (_client == null) + { + return; + } + + try + { + var eventProperties = GetEventProperties(properties); + var eventMeasurements = GetEventMeasures(measurements); + + _client.TrackEvent(eventName, eventProperties, eventMeasurements); + _client.Flush(); + } + catch (Exception e) + { + Debug.Fail(e.ToString()); + } + } + + private Dictionary GetEventMeasures(IDictionary measurements) + { + Dictionary eventMeasurements = new Dictionary(); + if (measurements != null) + { + foreach (KeyValuePair measurement in measurements) + { + eventMeasurements[measurement.Key] = measurement.Value; + } + } + return eventMeasurements; + } + + private Dictionary GetEventProperties(IDictionary properties) + { + if (properties != null) + { + var eventProperties = new Dictionary(_commonProperties); + foreach (KeyValuePair property in properties) + { + eventProperties[property.Key] = property.Value; + } + return eventProperties; + } + else + { + return _commonProperties; + } + } + } +} \ No newline at end of file diff --git a/src/mlnet/Templates/Console/ModelBuilder.cs b/src/mlnet/Templates/Console/ModelBuilder.cs new file mode 100644 index 0000000000..32795203f7 --- /dev/null +++ b/src/mlnet/Templates/Console/ModelBuilder.cs @@ -0,0 +1,669 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version: 15.0.0.0 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. 
+// +// ------------------------------------------------------------------------------ +namespace Microsoft.ML.CLI.Templates.Console +{ + using System.Linq; + using System.Text; + using System.Text.RegularExpressions; + using System.Collections.Generic; + using Microsoft.ML.CLI.Utilities; + using System; + + /// + /// Class to produce the template output + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public partial class ModelBuilder : ModelBuilderBase + { + /// + /// Create the template output + /// + public virtual string TransformText() + { + this.Write(@"//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; +using "); + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + this.Write(".Model.DataModels;\r\n"); + this.Write(this.ToStringHelper.ToStringWithCulture(GeneratedUsings)); + this.Write("\r\nnamespace "); + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + this.Write(".ConsoleApp\r\n{\r\n public static class ModelBuilder\r\n {\r\n private stat" + + "ic string TRAIN_DATA_FILEPATH = @\""); + this.Write(this.ToStringHelper.ToStringWithCulture(Path)); + this.Write("\";\r\n"); +if(!string.IsNullOrEmpty(TestPath)){ + this.Write(" private static string TEST_DATA_FILEPATH = @\""); + this.Write(this.ToStringHelper.ToStringWithCulture(TestPath)); + this.Write("\";\r\n"); + } + this.Write(" private static string MODEL_FILEPATH = @\"../../../../"); + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + this.Write(@".Model/MLModel.zip""; + + // Create MLContext to be shared across the 
model creation workflow objects + // Set a random seed for repeatable/deterministic results across multiple trainings. + private static MLContext mlContext = new MLContext(seed: 1); + + public static void CreateModel() + { + // Load Data + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + path: TRAIN_DATA_FILEPATH, + hasHeader : "); + this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); + this.Write(",\r\n separatorChar : \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString()))); + this.Write("\',\r\n allowQuoting : "); + this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant())); + this.Write(",\r\n allowSparse: "); + this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant())); + this.Write(");\r\n\r\n"); + if(!string.IsNullOrEmpty(TestPath)){ + this.Write(" IDataView testDataView = mlContext.Data.LoadFromTextFile(\r\n path: TEST_DATA_FILEPATH,\r\n" + + " hasHeader : "); + this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); + this.Write(",\r\n separatorChar : \'"); + this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString()))); + this.Write("\',\r\n allowQuoting : "); + this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant())); + this.Write(",\r\n allowSparse: "); + this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant())); + this.Write(");\r\n"); +} + this.Write(" // Build training pipeline\r\n IEstimator trai" + + "ningPipeline = BuildTrainingPipeline(mlContext);\r\n\r\n"); + if(string.IsNullOrEmpty(TestPath)){ + this.Write(" // Evaluate quality of Model\r\n Evaluate(mlContext, trainin" + + "gDataView, trainingPipeline);\r\n\r\n"); +} + this.Write(" // Train Model\r\n ITransformer mlModel = TrainModel(mlConte" + + "xt, trainingDataView, 
trainingPipeline);\r\n"); + if(!string.IsNullOrEmpty(TestPath)){ + this.Write("\r\n // Evaluate quality of Model\r\n EvaluateModel(mlContext, " + + "mlModel, testDataView);\r\n"); +} + this.Write("\r\n // Save model\r\n SaveModel(mlContext, mlModel, MODEL_FILE" + + "PATH, trainingDataView.Schema);\r\n }\r\n\r\n public static IEstimator BuildTrainingPipeline(MLContext mlContext)\r\n {\r\n"); + if(PreTrainerTransforms.Count >0 ) { + this.Write(" // Data process configuration with pipeline data transformations \r\n " + + " var dataProcessPipeline = "); + for(int i=0;i0) + { Write("\r\n .Append("); + } + Write("mlContext.Transforms."+PreTrainerTransforms[i]); + if(i>0) + { Write(")"); + } + } + if(CacheBeforeTrainer){ + Write("\r\n .AppendCacheCheckpoint(mlContext)"); + } + this.Write(";\r\n"); +} + this.Write("\r\n // Set the training algorithm \r\n var trainer = mlContext" + + "."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".Trainers."); + this.Write(this.ToStringHelper.ToStringWithCulture(Trainer)); + for(int i=0;i0 ) { + this.Write(" var trainingPipeline = dataProcessPipeline.Append(trainer);\r\n"); + } +else{ + this.Write(" var trainingPipeline = trainer;\r\n"); +} + this.Write(@" + return trainingPipeline; + } + + public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + Console.WriteLine(""=============== Training model ===============""); + + ITransformer model = trainingPipeline.Fit(trainingDataView); + + Console.WriteLine(""=============== End of training process ===============""); + return model; + } + +"); + if(!string.IsNullOrEmpty(TestPath)){ + this.Write(@" private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) + { + // Evaluate the model and show accuracy stats + Console.WriteLine(""===== Evaluating Model's accuracy with Test data =====""); + IDataView predictions = mlModel.Transform(testDataView); +"); 
+if("BinaryClassification".Equals(TaskType)){ + this.Write(" var metrics = mlContext."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".EvaluateNonCalibrated(predictions, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); + this.Write("\", \"Score\");\r\n PrintBinaryClassificationMetrics(metrics);\r\n"); +} if("MulticlassClassification".Equals(TaskType)){ + this.Write(" var metrics = mlContext."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".Evaluate(predictions, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); + this.Write("\", \"Score\");\r\n PrintMulticlassClassificationMetrics(metrics);\r\n"); +}if("Regression".Equals(TaskType)){ + this.Write(" var metrics = mlContext."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".Evaluate(predictions, \""); + this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); + this.Write("\", \"Score\");\r\n PrintRegressionMetrics(metrics);\r\n"); +} + this.Write(" }\r\n"); +}else{ + this.Write(@" private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate) + // in order to evaluate and get the model's accuracy metrics + Console.WriteLine(""=============== Cross-validating to get model's accuracy metrics ===============""); +"); +if("BinaryClassification".Equals(TaskType)){ + this.Write(" var crossValidationResults = mlContext."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numberOfFolds: "); + this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds)); + this.Write(", labelColumnName:\""); + this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); + this.Write("\");\r\n 
PrintBinaryClassificationFoldsAverageMetrics(crossValidationResu" + + "lts);\r\n"); +} +if("MulticlassClassification".Equals(TaskType)){ + this.Write(" var crossValidationResults = mlContext."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: "); + this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds)); + this.Write(", labelColumnName:\""); + this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); + this.Write("\");\r\n PrintMulticlassClassificationFoldsAverageMetrics(crossValidation" + + "Results);\r\n"); +} +if("Regression".Equals(TaskType)){ + this.Write(" var crossValidationResults = mlContext."); + this.Write(this.ToStringHelper.ToStringWithCulture(TaskType)); + this.Write(".CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: "); + this.Write(this.ToStringHelper.ToStringWithCulture(Kfolds)); + this.Write(", labelColumnName:\""); + this.Write(this.ToStringHelper.ToStringWithCulture(LabelName)); + this.Write("\");\r\n PrintRegressionFoldsAverageMetrics(crossValidationResults);\r\n"); +} + this.Write(" }\r\n"); +} + this.Write(@" private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema) + { + // Save/persist the trained model to a .ZIP file + Console.WriteLine($""=============== Saving the model ===============""); + using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) + mlContext.Model.Save(mlModel, modelInputSchema, fs); + + Console.WriteLine(""The model is saved to {0}"", GetAbsolutePath(modelRelativePath)); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + 
+"); +if("Regression".Equals(TaskType)){ + this.Write(" public static void PrintRegressionMetrics(RegressionMetrics metrics)\r\n " + + " {\r\n Console.WriteLine($\"****************************************" + + "*********\");\r\n Console.WriteLine($\"* Metrics for regression mod" + + "el \");\r\n Console.WriteLine($\"*----------------------------------" + + "--------------\");\r\n Console.WriteLine($\"* LossFn: {metri" + + "cs.LossFunction:0.##}\");\r\n Console.WriteLine($\"* R2 Score: " + + " {metrics.RSquared:0.##}\");\r\n Console.WriteLine($\"* Absolute lo" + + "ss: {metrics.MeanAbsoluteError:#.##}\");\r\n Console.WriteLine($\"* " + + " Squared loss: {metrics.MeanSquaredError:#.##}\");\r\n Console.WriteLin" + + "e($\"* RMS loss: {metrics.RootMeanSquaredError:#.##}\");\r\n C" + + "onsole.WriteLine($\"*************************************************\");\r\n " + + " }\r\n\r\n public static void PrintRegressionFoldsAverageMetrics(IEnumerable<" + + "TrainCatalogBase.CrossValidationResult> crossValidationResult" + + "s)\r\n {\r\n var L1 = crossValidationResults.Select(r => r.Metrics" + + ".MeanAbsoluteError);\r\n var L2 = crossValidationResults.Select(r => r." 
+ + "Metrics.MeanSquaredError);\r\n var RMS = crossValidationResults.Select(" + + "r => r.Metrics.MeanAbsoluteError);\r\n var lossFunction = crossValidati" + + "onResults.Select(r => r.Metrics.LossFunction);\r\n var R2 = crossValida" + + "tionResults.Select(r => r.Metrics.RSquared);\r\n\r\n Console.WriteLine($\"" + + "********************************************************************************" + + "*****************************\");\r\n Console.WriteLine($\"* Metric" + + "s for Regression model \");\r\n Console.WriteLine($\"*--------------" + + "--------------------------------------------------------------------------------" + + "--------------\");\r\n Console.WriteLine($\"* Average L1 Loss: {" + + "L1.Average():0.###} \");\r\n Console.WriteLine($\"* Average L2 Loss" + + ": {L2.Average():0.###} \");\r\n Console.WriteLine($\"* Average " + + "RMS: {RMS.Average():0.###} \");\r\n Console.WriteLine($\"* " + + " Average Loss Function: {lossFunction.Average():0.###} \");\r\n Consol" + + "e.WriteLine($\"* Average R-squared: {R2.Average():0.###} \");\r\n " + + "Console.WriteLine($\"************************************************************" + + "*************************************************\");\r\n }\r\n"); + } if("BinaryClassification".Equals(TaskType)){ + this.Write(" public static void PrintBinaryClassificationMetrics(BinaryClassificationM" + + "etrics metrics)\r\n {\r\n Console.WriteLine($\"********************" + + "****************************************\");\r\n Console.WriteLine($\"* " + + " Metrics for binary classification model \");\r\n Console.Write" + + "Line($\"*-----------------------------------------------------------\");\r\n " + + " Console.WriteLine($\"* Accuracy: {metrics.Accuracy:P2}\");\r\n " + + "Console.WriteLine($\"* Auc: {metrics.AreaUnderRocCurve:P2}\");\r\n " + + " Console.WriteLine($\"*******************************************************" + + "*****\");\r\n }\r\n\r\n\r\n public static void PrintBinaryClassificationFol" + + 
"dsAverageMetrics(IEnumerable> crossValResults)\r\n {\r\n var metricsInMultiple" + + "Folds = crossValResults.Select(r => r.Metrics);\r\n\r\n var AccuracyValue" + + "s = metricsInMultipleFolds.Select(m => m.Accuracy);\r\n var AccuracyAve" + + "rage = AccuracyValues.Average();\r\n var AccuraciesStdDeviation = Calcu" + + "lateStandardDeviation(AccuracyValues);\r\n var AccuraciesConfidenceInte" + + "rval95 = CalculateConfidenceInterval95(AccuracyValues);\r\n\r\n\r\n Console" + + ".WriteLine($\"*******************************************************************" + + "******************************************\");\r\n Console.WriteLine($\"*" + + " Metrics for Binary Classification model \");\r\n Console.Wri" + + "teLine($\"*----------------------------------------------------------------------" + + "--------------------------------------\");\r\n Console.WriteLine($\"* " + + " Average Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({Accurac" + + "iesStdDeviation:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterv" + + "al95:#.###})\");\r\n Console.WriteLine($\"*******************************" + + "******************************************************************************\")" + + ";\r\n }\r\n\r\n public static double CalculateStandardDeviation(IEnumera" + + "ble values)\r\n {\r\n double average = values.Average();\r\n" + + " double sumOfSquaresOfDifferences = values.Select(val => (val - avera" + + "ge) * (val - average)).Sum();\r\n double standardDeviation = Math.Sqrt(" + + "sumOfSquaresOfDifferences / (values.Count() - 1));\r\n return standardD" + + "eviation;\r\n }\r\n\r\n public static double CalculateConfidenceInterval" + + "95(IEnumerable values)\r\n {\r\n double confidenceInterval" + + "95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1))" + + ";\r\n return confidenceInterval95;\r\n }\r\n"); +} if("MulticlassClassification".Equals(TaskType)){ + this.Write(" public static void 
PrintMulticlassClassificationMetrics(MulticlassClassif" + + "icationMetrics metrics)\r\n {\r\n Console.WriteLine($\"************" + + "************************************************\");\r\n Console.WriteLi" + + "ne($\"* Metrics for multi-class classification model \");\r\n Consol" + + "e.WriteLine($\"*-----------------------------------------------------------\");\r\n " + + " Console.WriteLine($\" MacroAccuracy = {metrics.MacroAccuracy:0.####" + + "}, a value between 0 and 1, the closer to 1, the better\");\r\n Console." + + "WriteLine($\" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between " + + "0 and 1, the closer to 1, the better\");\r\n Console.WriteLine($\" Log" + + "Loss = {metrics.LogLoss:0.####}, the closer to 0, the better\");\r\n for" + + " (int i = 0; i < metrics.PerClassLogLoss.Count; i++)\r\n {\r\n " + + " Console.WriteLine($\" LogLoss for class {i + 1} = {metrics.PerClassLogLos" + + "s[i]:0.####}, the closer to 0, the better\");\r\n }\r\n Console" + + ".WriteLine($\"************************************************************\");\r\n " + + " }\r\n\r\n public static void PrintMulticlassClassificationFoldsAverageM" + + "etrics(IEnumerable> crossValResults)\r\n {\r\n var metricsInMultipleFolds " + + "= crossValResults.Select(r => r.Metrics);\r\n\r\n var microAccuracyValues" + + " = metricsInMultipleFolds.Select(m => m.MicroAccuracy);\r\n var microAc" + + "curacyAverage = microAccuracyValues.Average();\r\n var microAccuraciesS" + + "tdDeviation = CalculateStandardDeviation(microAccuracyValues);\r\n var " + + "microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccurac" + + "yValues);\r\n\r\n var macroAccuracyValues = metricsInMultipleFolds.Select" + + "(m => m.MacroAccuracy);\r\n var macroAccuracyAverage = macroAccuracyVal" + + "ues.Average();\r\n var macroAccuraciesStdDeviation = CalculateStandardD" + + "eviation(macroAccuracyValues);\r\n var macroAccuraciesConfidenceInterva" + + "l95 = 
CalculateConfidenceInterval95(macroAccuracyValues);\r\n\r\n var log" + + "LossValues = metricsInMultipleFolds.Select(m => m.LogLoss);\r\n var log" + + "LossAverage = logLossValues.Average();\r\n var logLossStdDeviation = Ca" + + "lculateStandardDeviation(logLossValues);\r\n var logLossConfidenceInter" + + "val95 = CalculateConfidenceInterval95(logLossValues);\r\n\r\n var logLoss" + + "ReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction);\r\n " + + " var logLossReductionAverage = logLossReductionValues.Average();\r\n " + + " var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReducti" + + "onValues);\r\n var logLossReductionConfidenceInterval95 = CalculateConf" + + "idenceInterval95(logLossReductionValues);\r\n\r\n Console.WriteLine($\"***" + + "********************************************************************************" + + "**************************\");\r\n Console.WriteLine($\"* Metrics f" + + "or Multi-class Classification model \");\r\n Console.WriteLine($\"*-" + + "--------------------------------------------------------------------------------" + + "---------------------------\");\r\n Console.WriteLine($\"* Average " + + "MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: ({microAcc" + + "uraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfide" + + "nceInterval95:#.###})\");\r\n Console.WriteLine($\"* Average MacroA" + + "ccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuracie" + + "sStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInt" + + "erval95:#.###})\");\r\n Console.WriteLine($\"* Average LogLoss: " + + " {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}" + + ") - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})\");\r\n " + + " Console.WriteLine($\"* Average LogLossReduction: {logLossReductionAvera" + + "ge:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confi" + + 
"dence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})\");\r\n " + + " Console.WriteLine($\"*********************************************************" + + "****************************************************\");\r\n\r\n }\r\n\r\n " + + "public static double CalculateStandardDeviation(IEnumerable values)\r\n " + + " {\r\n double average = values.Average();\r\n double sumOf" + + "SquaresOfDifferences = values.Select(val => (val - average) * (val - average)).S" + + "um();\r\n double standardDeviation = Math.Sqrt(sumOfSquaresOfDifference" + + "s / (values.Count() - 1));\r\n return standardDeviation;\r\n }\r\n\r\n" + + " public static double CalculateConfidenceInterval95(IEnumerable v" + + "alues)\r\n {\r\n double confidenceInterval95 = 1.96 * CalculateSta" + + "ndardDeviation(values) / Math.Sqrt((values.Count() - 1));\r\n return co" + + "nfidenceInterval95;\r\n }\r\n"); +} + this.Write(" }\r\n}\r\n"); + return this.GenerationEnvironment.ToString(); + } + +public string Path {get;set;} +public string TestPath {get;set;} +public bool HasHeader {get;set;} +public char Separator {get;set;} +public IList PreTrainerTransforms {get;set;} +public string Trainer {get;set;} +public string TaskType {get;set;} +public string GeneratedUsings {get;set;} +public bool AllowQuoting {get;set;} +public bool AllowSparse {get;set;} +public int Kfolds {get;set;} = 5; +public string Namespace {get;set;} +public string LabelName {get;set;} +public bool CacheBeforeTrainer {get;set;} +public IList PostTrainerTransforms {get;set;} + + } + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public class ModelBuilderBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List 
indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + protected System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + // If we're starting off, or if the previous text ended with a newline, + // we have to append the current indent first. 
+ if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + // Check if the current text ends with a newline + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + // This is an optimization. If the current indent is "", then we don't have to do any + // of the more complex stuff further down. + if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + // Everywhere there is a newline in the text, add an indent after it + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + // If the text ends with a newline, then we should strip off the indent added at the very end + // because the appropriate indent will be added when the next time Write() is called + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void 
Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. 
+ /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/src/mlnet/Templates/Console/ModelBuilder.tt b/src/mlnet/Templates/Console/ModelBuilder.tt new file mode 100644 index 0000000000..e77d8aa34f --- /dev/null +++ b/src/mlnet/Templates/Console/ModelBuilder.tt @@ -0,0 +1,332 @@ +<#@ template language="C#" linePragmas="false" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Text.RegularExpressions" #> +<#@ import namespace="System.Collections.Generic" #> +<#@ import namespace="Microsoft.ML.CLI.Utilities" #> +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; +using <#= Namespace #>.Model.DataModels; +<#= GeneratedUsings #> +namespace <#= Namespace #>.ConsoleApp +{ + public static class ModelBuilder + { + private static string TRAIN_DATA_FILEPATH = @"<#= Path #>"; +<#if(!string.IsNullOrEmpty(TestPath)){ #> + private static string TEST_DATA_FILEPATH = @"<#= TestPath #>"; +<# } #> + private static string MODEL_FILEPATH = @"../../../../<#= Namespace #>.Model/MLModel.zip"; + + // Create MLContext to be shared across the model creation workflow objects + // Set a random seed for repeatable/deterministic results across multiple trainings. + private static MLContext mlContext = new MLContext(seed: 1); + + public static void CreateModel() + { + // Load Data + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + path: TRAIN_DATA_FILEPATH, + hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>, + separatorChar : '<#= Regex.Escape(Separator.ToString()) #>', + allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>, + allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>); + +<# if(!string.IsNullOrEmpty(TestPath)){ #> + IDataView testDataView = mlContext.Data.LoadFromTextFile( + path: TEST_DATA_FILEPATH, + hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>, + separatorChar : '<#= Regex.Escape(Separator.ToString()) #>', + allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>, + allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>); +<#}#> + // Build training pipeline + IEstimator trainingPipeline = BuildTrainingPipeline(mlContext); + +<# if(string.IsNullOrEmpty(TestPath)){ #> + // Evaluate quality of Model + Evaluate(mlContext, trainingDataView, trainingPipeline); + +<#}#> + // Train Model + ITransformer mlModel = TrainModel(mlContext, 
trainingDataView, trainingPipeline); +<# if(!string.IsNullOrEmpty(TestPath)){ #> + + // Evaluate quality of Model + EvaluateModel(mlContext, mlModel, testDataView); +<#}#> + + // Save model + SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema); + } + + public static IEstimator BuildTrainingPipeline(MLContext mlContext) + { +<# if(PreTrainerTransforms.Count >0 ) {#> + // Data process configuration with pipeline data transformations + var dataProcessPipeline = <# for(int i=0;i0) + { Write("\r\n .Append("); + } + Write("mlContext.Transforms."+PreTrainerTransforms[i]); + if(i>0) + { Write(")"); + } + } + if(CacheBeforeTrainer){ + Write("\r\n .AppendCacheCheckpoint(mlContext)"); + } #>; +<#}#> + + // Set the training algorithm + var trainer = mlContext.<#= TaskType #>.Trainers.<#= Trainer #><# for(int i=0;i; +<# if(PreTrainerTransforms.Count >0 ) {#> + var trainingPipeline = dataProcessPipeline.Append(trainer); +<# } +else{#> + var trainingPipeline = trainer; +<#}#> + + return trainingPipeline; + } + + public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + Console.WriteLine("=============== Training model ==============="); + + ITransformer model = trainingPipeline.Fit(trainingDataView); + + Console.WriteLine("=============== End of training process ==============="); + return model; + } + +<# if(!string.IsNullOrEmpty(TestPath)){ #> + private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) + { + // Evaluate the model and show accuracy stats + Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); + IDataView predictions = mlModel.Transform(testDataView); +<#if("BinaryClassification".Equals(TaskType)){ #> + var metrics = mlContext.<#= TaskType #>.EvaluateNonCalibrated(predictions, "<#= LabelName #>", "Score"); + PrintBinaryClassificationMetrics(metrics); +<#} if("MulticlassClassification".Equals(TaskType)){ #> + var 
metrics = mlContext.<#= TaskType #>.Evaluate(predictions, "<#= LabelName #>", "Score"); + PrintMulticlassClassificationMetrics(metrics); +<#}if("Regression".Equals(TaskType)){ #> + var metrics = mlContext.<#= TaskType #>.Evaluate(predictions, "<#= LabelName #>", "Score"); + PrintRegressionMetrics(metrics); +<#} #> + } +<#}else{#> + private static void Evaluate(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + // Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate) + // in order to evaluate and get the model's accuracy metrics + Console.WriteLine("=============== Cross-validating to get model's accuracy metrics ==============="); +<#if("BinaryClassification".Equals(TaskType)){ #> + var crossValidationResults = mlContext.<#= TaskType #>.CrossValidateNonCalibrated(trainingDataView, trainingPipeline, numberOfFolds: <#= Kfolds #>, labelColumnName:"<#= LabelName #>"); + PrintBinaryClassificationFoldsAverageMetrics(crossValidationResults); +<#}#><#if("MulticlassClassification".Equals(TaskType)){ #> + var crossValidationResults = mlContext.<#= TaskType #>.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: <#= Kfolds #>, labelColumnName:"<#= LabelName #>"); + PrintMulticlassClassificationFoldsAverageMetrics(crossValidationResults); +<#}#><#if("Regression".Equals(TaskType)){ #> + var crossValidationResults = mlContext.<#= TaskType #>.CrossValidate(trainingDataView, trainingPipeline, numberOfFolds: <#= Kfolds #>, labelColumnName:"<#= LabelName #>"); + PrintRegressionFoldsAverageMetrics(crossValidationResults); +<#}#> + } +<#}#> + private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema) + { + // Save/persist the trained model to a .ZIP file + Console.WriteLine($"=============== Saving the model ==============="); + using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, 
FileAccess.Write, FileShare.Write)) + mlContext.Model.Save(mlModel, modelInputSchema, fs); + + Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + +<#if("Regression".Equals(TaskType)){ #> + public static void PrintRegressionMetrics(RegressionMetrics metrics) + { + Console.WriteLine($"*************************************************"); + Console.WriteLine($"* Metrics for regression model "); + Console.WriteLine($"*------------------------------------------------"); + Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}"); + Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}"); + Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}"); + Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}"); + Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}"); + Console.WriteLine($"*************************************************"); + } + + public static void PrintRegressionFoldsAverageMetrics(IEnumerable> crossValidationResults) + { + var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError); + var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError); + var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError); + var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction); + var R2 = crossValidationResults.Select(r => r.Metrics.RSquared); + + Console.WriteLine($"*************************************************************************************************************"); + Console.WriteLine($"* Metrics for Regression model "); + 
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); + Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} "); + Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} "); + Console.WriteLine($"* Average RMS: {RMS.Average():0.###} "); + Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} "); + Console.WriteLine($"* Average R-squared: {R2.Average():0.###} "); + Console.WriteLine($"*************************************************************************************************************"); + } +<# } if("BinaryClassification".Equals(TaskType)){ #> + public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics) + { + Console.WriteLine($"************************************************************"); + Console.WriteLine($"* Metrics for binary classification model "); + Console.WriteLine($"*-----------------------------------------------------------"); + Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}"); + Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}"); + Console.WriteLine($"************************************************************"); + } + + + public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable> crossValResults) + { + var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); + + var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy); + var AccuracyAverage = AccuracyValues.Average(); + var AccuraciesStdDeviation = CalculateStandardDeviation(AccuracyValues); + var AccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(AccuracyValues); + + + Console.WriteLine($"*************************************************************************************************************"); + Console.WriteLine($"* Metrics for Binary Classification model "); + 
Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); + Console.WriteLine($"* Average Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterval95:#.###})"); + Console.WriteLine($"*************************************************************************************************************"); + } + + public static double CalculateStandardDeviation(IEnumerable values) + { + double average = values.Average(); + double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum(); + double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1)); + return standardDeviation; + } + + public static double CalculateConfidenceInterval95(IEnumerable values) + { + double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1)); + return confidenceInterval95; + } +<#} if("MulticlassClassification".Equals(TaskType)){#> + public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics) + { + Console.WriteLine($"************************************************************"); + Console.WriteLine($"* Metrics for multi-class classification model "); + Console.WriteLine($"*-----------------------------------------------------------"); + Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better"); + for (int i = 0; i < metrics.PerClassLogLoss.Count; i++) + { + Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better"); + } + 
Console.WriteLine($"************************************************************"); + } + + public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable> crossValResults) + { + var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); + + var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy); + var microAccuracyAverage = microAccuracyValues.Average(); + var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues); + var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues); + + var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy); + var macroAccuracyAverage = macroAccuracyValues.Average(); + var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues); + var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues); + + var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss); + var logLossAverage = logLossValues.Average(); + var logLossStdDeviation = CalculateStandardDeviation(logLossValues); + var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues); + + var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction); + var logLossReductionAverage = logLossReductionValues.Average(); + var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues); + var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues); + + Console.WriteLine($"*************************************************************************************************************"); + Console.WriteLine($"* Metrics for Multi-class Classification model "); + Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); + Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: 
({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})"); + Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})"); + Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})"); + Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})"); + Console.WriteLine($"*************************************************************************************************************"); + + } + + public static double CalculateStandardDeviation(IEnumerable values) + { + double average = values.Average(); + double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum(); + double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1)); + return standardDeviation; + } + + public static double CalculateConfidenceInterval95(IEnumerable values) + { + double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1)); + return confidenceInterval95; + } +<#}#> + } +} +<#+ +public string Path {get;set;} +public string TestPath {get;set;} +public bool HasHeader {get;set;} +public char Separator {get;set;} +public IList PreTrainerTransforms {get;set;} +public string Trainer {get;set;} +public string TaskType {get;set;} +public string GeneratedUsings {get;set;} +public bool AllowQuoting {get;set;} +public bool AllowSparse {get;set;} +public int Kfolds {get;set;} = 5; +public string Namespace {get;set;} +public string LabelName {get;set;} +public bool CacheBeforeTrainer {get;set;} +public IList PostTrainerTransforms 
{get;set;} +#> diff --git a/src/mlnet/Templates/Console/ModelProject.cs b/src/mlnet/Templates/Console/ModelProject.cs new file mode 100644 index 0000000000..5a9b788408 --- /dev/null +++ b/src/mlnet/Templates/Console/ModelProject.cs @@ -0,0 +1,331 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version: 15.0.0.0 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +// ------------------------------------------------------------------------------ +namespace Microsoft.ML.CLI.Templates.Console +{ + using System.Linq; + using System.Text; + using System.Collections.Generic; + using System; + + /// + /// Class to produce the template output + /// + + #line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ModelProject.tt" + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public partial class ModelProject : ModelProjectBase + { +#line hidden + /// + /// Create the template output + /// + public virtual string TransformText() + { + this.Write(@" + + + netcoreapp2.1 + + + + https://api.nuget.org/v3/index.json; + + + + + + + + + PreserveNewest + + + + +"); + return this.GenerationEnvironment.ToString(); + } + } + + #line default + #line hidden + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public class ModelProjectBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + 
#region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + protected System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + // If we're starting off, or if the previous text ended with a newline, + // we have to append the current indent first. 
+ if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + // Check if the current text ends with a newline + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + // This is an optimization. If the current indent is "", then we don't have to do any + // of the more complex stuff further down. + if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + // Everywhere there is a newline in the text, add an indent after it + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + // If the text ends with a newline, then we should strip off the indent added at the very end + // because the appropriate indent will be added when the next time Write() is called + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void 
Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. 
+ /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/src/mlnet/Templates/Console/ModelProject.tt b/src/mlnet/Templates/Console/ModelProject.tt new file mode 100644 index 0000000000..7ca417d9d1 --- /dev/null +++ b/src/mlnet/Templates/Console/ModelProject.tt @@ -0,0 +1,26 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> + + + + netcoreapp2.1 + + + + https://api.nuget.org/v3/index.json; + + + + + + + + + PreserveNewest + + + + diff --git a/src/mlnet/Templates/Console/ObservationClass.cs b/src/mlnet/Templates/Console/ObservationClass.cs new file mode 100644 index 0000000000..62b660bd39 --- /dev/null +++ b/src/mlnet/Templates/Console/ObservationClass.cs @@ -0,0 +1,355 @@ +// 
------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version: 15.0.0.0 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +// ------------------------------------------------------------------------------ +namespace Microsoft.ML.CLI.Templates.Console +{ + using System.Linq; + using System.Text; + using System.Collections.Generic; + using System; + + /// + /// Class to produce the template output + /// + + #line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ObservationClass.tt" + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public partial class ObservationClass : ObservationClassBase + { +#line hidden + /// + /// Create the template output + /// + public virtual string TransformText() + { + this.Write(@"//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using Microsoft.ML.Data; + +namespace "); + + #line 14 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ObservationClass.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + + #line default + #line hidden + this.Write(".Model.DataModels\r\n{\r\n public class SampleObservation\r\n {\r\n"); + + #line 18 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ObservationClass.tt" +foreach(var label in ClassLabels){ + + #line default + #line hidden + this.Write(" "); + + #line 19 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ObservationClass.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(label)); + + #line default + #line hidden + this.Write("\r\n"); + + #line 20 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ObservationClass.tt" +} + + #line default + #line hidden + this.Write("}\r\n}\r\n"); + return this.GenerationEnvironment.ToString(); + } + + #line 23 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\ObservationClass.tt" + +public IList ClassLabels {get;set;} +public string Namespace {get;set;} + + + #line default + #line hidden + } + + #line default + #line hidden + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public class ObservationClassBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated 
output + /// + protected System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + // If we're starting off, or if the previous text ended with a newline, + // we have to append the current indent first. 
+ if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + // Check if the current text ends with a newline + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + // This is an optimization. If the current indent is "", then we don't have to do any + // of the more complex stuff further down. + if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + // Everywhere there is a newline in the text, add an indent after it + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + // If the text ends with a newline, then we should strip off the indent added at the very end + // because the appropriate indent will be added when the next time Write() is called + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void 
Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. 
+ /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/src/mlnet/Templates/Console/ObservationClass.tt b/src/mlnet/Templates/Console/ObservationClass.tt new file mode 100644 index 0000000000..07da8f56cc --- /dev/null +++ b/src/mlnet/Templates/Console/ObservationClass.tt @@ -0,0 +1,26 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using Microsoft.ML.Data; + +namespace <#= Namespace #>.Model.DataModels +{ + public class SampleObservation + { +<#foreach(var label in ClassLabels){#> + <#=label#> +<#}#> +} +} +<#+ +public IList ClassLabels {get;set;} +public string Namespace {get;set;} +#> \ No newline at end of file diff --git a/src/mlnet/Templates/Console/PredictProgram.cs b/src/mlnet/Templates/Console/PredictProgram.cs new file mode 100644 index 0000000000..1bc67d8265 --- /dev/null +++ b/src/mlnet/Templates/Console/PredictProgram.cs @@ -0,0 +1,507 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version: 15.0.0.0 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +// ------------------------------------------------------------------------------ +namespace Microsoft.ML.CLI.Templates.Console +{ + using System.Linq; + using System.Text; + using System.Text.RegularExpressions; + using System.Collections.Generic; + using Microsoft.ML.CLI.Utilities; + using System; + + /// + /// Class to produce the template output + /// + + #line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public partial class PredictProgram : PredictProgramBase + { +#line hidden + /// + /// Create the template output + /// + public virtual string TransformText() + { + this.Write(@"//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using System; +using System.IO; +using System.Linq; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using "); + + #line 20 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + + #line default + #line hidden + this.Write(".Model.DataModels;\r\n\r\n\r\nnamespace "); + + #line 23 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + + #line default + #line hidden + this.Write(".ConsoleApp\r\n{\r\n class Program\r\n {\r\n //Machine Learning model to loa" + + "d and use for predictions\r\n private const string MODEL_FILEPATH = @\"MLMod" + + "el.zip\";\r\n\r\n //Dataset to use for predictions \r\n"); + + #line 31 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" +if(string.IsNullOrEmpty(TestDataPath)){ + + #line default + #line hidden + this.Write(" private const string DATA_FILEPATH = @\""); + + #line 32 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(TrainDataPath)); + + #line default + #line hidden + this.Write("\";\r\n"); + + #line 33 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + } else{ + + #line default + #line hidden + this.Write(" private const string DATA_FILEPATH = @\""); + + #line 34 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(TestDataPath)); + + #line default + #line hidden + this.Write("\";\r\n"); + + #line 35 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + } + + #line default + #line hidden + this.Write(@" + static void Main(string[] args) + { + MLContext mlContext = new 
MLContext(); + + // Training code used by ML.NET CLI and AutoML to generate the model + //ModelBuilder.CreateModel(); + + ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema); + var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); + + // Create sample data to do a single prediction with it + SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); + + // Try a single prediction + SamplePrediction predictionResult = predEngine.Predict(sampleData); + +"); + + #line 53 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" +if("BinaryClassification".Equals(TaskType)){ + + #line default + #line hidden + this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData."); + + #line 54 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName))); + + #line default + #line hidden + this.Write("} | Predicted value: {predictionResult.Prediction}\");\r\n"); + + #line 55 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" +}else if("Regression".Equals(TaskType)){ + + #line default + #line hidden + this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData."); + + #line 56 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName))); + + #line default + #line hidden + this.Write("} | Predicted value: {predictionResult.Score}\");\r\n"); + + #line 57 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" +} else if("MulticlassClassification".Equals(TaskType)){ + + #line default + #line hidden + this.Write(" Console.WriteLine($\"Single Prediction --> Actual value: {sampleData."); + + #line 58 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + 
this.Write(this.ToStringHelper.ToStringWithCulture(Utils.Normalize(LabelName))); + + #line default + #line hidden + this.Write("} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.J" + + "oin(\",\", predictionResult.Score)}]\");\r\n"); + + #line 59 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" +} + + #line default + #line hidden + this.Write(@" + Console.WriteLine(""=============== End of process, hit any key to finish ===============""); + Console.ReadKey(); + } + + // Method to load single row of data to try a single prediction + // You can change this code and create your own sample data here (Hardcoded or from any source) + private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath) + { + // Read dataset to get a single row for trying a prediction + IDataView dataView = mlContext.Data.LoadFromTextFile( + path: dataFilePath, + hasHeader : "); + + #line 72 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(HasHeader.ToString().ToLowerInvariant())); + + #line default + #line hidden + this.Write(",\r\n separatorChar : \'"); + + #line 73 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Regex.Escape(Separator.ToString()))); + + #line default + #line hidden + this.Write("\',\r\n allowQuoting : "); + + #line 74 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(AllowQuoting.ToString().ToLowerInvariant())); + + #line default + #line hidden + this.Write(",\r\n allowSparse: "); + + #line 75 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(AllowSparse.ToString().ToLowerInvariant())); + + #line default + #line hidden + this.Write(@"); + + // Here 
(SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. + SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) + .First(); + return sampleForPrediction; + } + } +} +"); + return this.GenerationEnvironment.ToString(); + } + + #line 84 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictProgram.tt" + +public string TaskType {get;set;} +public string Namespace {get;set;} +public string LabelName {get;set;} +public string TestDataPath {get;set;} +public string TrainDataPath {get;set;} +public char Separator {get;set;} +public bool AllowQuoting {get;set;} +public bool AllowSparse {get;set;} +public bool HasHeader {get;set;} + + + #line default + #line hidden + } + + #line default + #line hidden + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public class PredictProgramBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + protected System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public 
System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + // If we're starting off, or if the previous text ended with a newline, + // we have to append the current indent first. + if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + // Check if the current text ends with a newline + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + // This is an optimization. If the current indent is "", then we don't have to do any + // of the more complex stuff further down. 
+ if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + // Everywhere there is a newline in the text, add an indent after it + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + // If the text ends with a newline, then we should strip off the indent added at the very end + // because the appropriate indent will be added when the next time Write() is called + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string 
indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. 
+ /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/src/mlnet/Templates/Console/PredictProgram.tt b/src/mlnet/Templates/Console/PredictProgram.tt new file mode 100644 index 0000000000..fc9ed43172 --- /dev/null +++ b/src/mlnet/Templates/Console/PredictProgram.tt @@ -0,0 +1,94 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Text.RegularExpressions" #> +<#@ import namespace="System.Collections.Generic" #> +<#@ import namespace="Microsoft.ML.CLI.Utilities" #> +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using System; +using System.IO; +using System.Linq; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using <#= Namespace #>.Model.DataModels; + + +namespace <#= Namespace #>.ConsoleApp +{ + class Program + { + //Machine Learning model to load and use for predictions + private const string MODEL_FILEPATH = @"MLModel.zip"; + + //Dataset to use for predictions +<#if(string.IsNullOrEmpty(TestDataPath)){ #> + private const string DATA_FILEPATH = @"<#= TrainDataPath #>"; +<# } else{ #> + private const string DATA_FILEPATH = @"<#= TestDataPath #>"; +<# } #> + + static void Main(string[] args) + { + MLContext mlContext = new MLContext(); + + // Training code used by ML.NET CLI and AutoML to generate the model + //ModelBuilder.CreateModel(); + + ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema); + var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); + + // Create sample data to do a single prediction with it + SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); + + // Try a single prediction + SamplePrediction predictionResult = predEngine.Predict(sampleData); + +<#if("BinaryClassification".Equals(TaskType)){ #> + Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Prediction}"); +<#}else if("Regression".Equals(TaskType)){#> + Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Score}"); +<#} else if("MulticlassClassification".Equals(TaskType)){#> + Console.WriteLine($"Single Prediction --> Actual value: {sampleData.<#= Utils.Normalize(LabelName) #>} | Predicted value: {predictionResult.Prediction} | Predicted scores: [{String.Join(",", predictionResult.Score)}]"); +<#}#> + + 
Console.WriteLine("=============== End of process, hit any key to finish ==============="); + Console.ReadKey(); + } + + // Method to load single row of data to try a single prediction + // You can change this code and create your own sample data here (Hardcoded or from any source) + private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath) + { + // Read dataset to get a single row for trying a prediction + IDataView dataView = mlContext.Data.LoadFromTextFile( + path: dataFilePath, + hasHeader : <#= HasHeader.ToString().ToLowerInvariant() #>, + separatorChar : '<#= Regex.Escape(Separator.ToString()) #>', + allowQuoting : <#= AllowQuoting.ToString().ToLowerInvariant() #>, + allowSparse: <#= AllowSparse.ToString().ToLowerInvariant() #>); + + // Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. + SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) + .First(); + return sampleForPrediction; + } + } +} +<#+ +public string TaskType {get;set;} +public string Namespace {get;set;} +public string LabelName {get;set;} +public string TestDataPath {get;set;} +public string TrainDataPath {get;set;} +public char Separator {get;set;} +public bool AllowQuoting {get;set;} +public bool AllowSparse {get;set;} +public bool HasHeader {get;set;} +#> diff --git a/src/mlnet/Templates/Console/PredictProject.cs b/src/mlnet/Templates/Console/PredictProject.cs new file mode 100644 index 0000000000..58e93a664b --- /dev/null +++ b/src/mlnet/Templates/Console/PredictProject.cs @@ -0,0 +1,326 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version: 15.0.0.0 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. 
+// +// ------------------------------------------------------------------------------ +namespace Microsoft.ML.CLI.Templates.Console +{ + using System.Linq; + using System.Text; + using System.Text.RegularExpressions; + using System.Collections.Generic; + using System; + + /// + /// Class to produce the template output + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public partial class PredictProject : PredictProjectBase + { + /// + /// Create the template output + /// + public virtual string TransformText() + { + this.Write("\r\n\r\n \r\n Exe\r\n netcoreapp2.1\r\n \r\n \r\n \r\n"); + if(IncludeLightGBMPackage){ + this.Write(" \r" + + "\n"); +} + if(IncludeMklComponentsPackage){ + this.Write(" \r\n"); +} + this.Write(" \r\n \r\n \r\n \r\n\r\n"); + return this.GenerationEnvironment.ToString(); + } + +public string Namespace {get;set;} +public bool IncludeLightGBMPackage {get;set;} +public bool IncludeMklComponentsPackage {get;set;} + + } + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public class PredictProjectBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + protected System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return 
this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + // If we're starting off, or if the previous text ended with a newline, + // we have to append the current indent first. + if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + // Check if the current text ends with a newline + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + // This is an optimization. If the current indent is "", then we don't have to do any + // of the more complex stuff further down. 
+ if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + // Everywhere there is a newline in the text, add an indent after it + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + // If the text ends with a newline, then we should strip off the indent added at the very end + // because the appropriate indent will be added when the next time Write() is called + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string 
indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. 
+ /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/src/mlnet/Templates/Console/PredictProject.tt b/src/mlnet/Templates/Console/PredictProject.tt new file mode 100644 index 0000000000..2427250293 --- /dev/null +++ b/src/mlnet/Templates/Console/PredictProject.tt @@ -0,0 +1,30 @@ +<#@ template language="C#" linePragmas="false" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Text.RegularExpressions" #> +<#@ import namespace="System.Collections.Generic" #> + + + + Exe + netcoreapp2.1 + + + +<# if(IncludeLightGBMPackage){ #> + +<#}#> +<# if(IncludeMklComponentsPackage){ #> + +<#}#> + + + + + +<#+ +public string Namespace {get;set;} +public bool IncludeLightGBMPackage {get;set;} +public bool IncludeMklComponentsPackage {get;set;} +#> diff --git 
a/src/mlnet/Templates/Console/PredictionClass.cs b/src/mlnet/Templates/Console/PredictionClass.cs new file mode 100644 index 0000000000..11d5303d2c --- /dev/null +++ b/src/mlnet/Templates/Console/PredictionClass.cs @@ -0,0 +1,388 @@ +// ------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version: 15.0.0.0 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +// ------------------------------------------------------------------------------ +namespace Microsoft.ML.CLI.Templates.Console +{ + using System.Linq; + using System.Text; + using System.Collections.Generic; + using System; + + /// + /// Class to produce the template output + /// + + #line 1 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public partial class PredictionClass : PredictionClassBase + { +#line hidden + /// + /// Create the template output + /// + public virtual string TransformText() + { + this.Write(@"//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using System; +using Microsoft.ML.Data; + +namespace "); + + #line 15 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(Namespace)); + + #line default + #line hidden + this.Write(".Model.DataModels\r\n{\r\n public class SamplePrediction\r\n {\r\n"); + + #line 19 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" +if("BinaryClassification".Equals(TaskType)){ + + #line default + #line hidden + this.Write(" // ColumnName attribute is used to change the column name from\r\n /" + + "/ its default value, which is the name of the field.\r\n [ColumnName(\"Predi" + + "ctedLabel\")]\r\n public bool Prediction { get; set; }\r\n\r\n"); + + #line 25 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" + } if("MulticlassClassification".Equals(TaskType)){ + + #line default + #line hidden + this.Write(" // ColumnName attribute is used to change the column name from\r\n /" + + "/ its default value, which is the name of the field.\r\n [ColumnName(\"Predi" + + "ctedLabel\")]\r\n public "); + + #line 29 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" + this.Write(this.ToStringHelper.ToStringWithCulture(PredictionLabelType)); + + #line default + #line hidden + this.Write(" Prediction { get; set; }\r\n"); + + #line 30 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" + } + + #line default + #line hidden + + #line 31 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" +if("MulticlassClassification".Equals(TaskType)){ + + #line default + #line hidden + this.Write(" public float[] Score { get; set; }\r\n"); + + #line 33 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" +}else{ + + #line default + #line hidden + this.Write(" public float 
Score { get; set; }\r\n"); + + #line 35 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" +} + + #line default + #line hidden + this.Write(" }\r\n}\r\n"); + return this.GenerationEnvironment.ToString(); + } + + #line 38 "E:\src\machinelearning-automl\src\mlnet\Templates\Console\PredictionClass.tt" + +public string TaskType {get;set;} +public string PredictionLabelType {get;set;} +public string Namespace {get;set;} + + + #line default + #line hidden + } + + #line default + #line hidden + #region Base class + /// + /// Base class for this transformation + /// + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("Microsoft.VisualStudio.TextTemplating", "15.0.0.0")] + public class PredictionClassBase + { + #region Fields + private global::System.Text.StringBuilder generationEnvironmentField; + private global::System.CodeDom.Compiler.CompilerErrorCollection errorsField; + private global::System.Collections.Generic.List indentLengthsField; + private string currentIndentField = ""; + private bool endsWithNewline; + private global::System.Collections.Generic.IDictionary sessionField; + #endregion + #region Properties + /// + /// The string builder that generation-time code is using to assemble generated output + /// + protected System.Text.StringBuilder GenerationEnvironment + { + get + { + if ((this.generationEnvironmentField == null)) + { + this.generationEnvironmentField = new global::System.Text.StringBuilder(); + } + return this.generationEnvironmentField; + } + set + { + this.generationEnvironmentField = value; + } + } + /// + /// The error collection for the generation process + /// + public System.CodeDom.Compiler.CompilerErrorCollection Errors + { + get + { + if ((this.errorsField == null)) + { + this.errorsField = new global::System.CodeDom.Compiler.CompilerErrorCollection(); + } + return this.errorsField; + } + } + /// + /// A list of the lengths of each indent that was added with PushIndent + /// + private 
System.Collections.Generic.List indentLengths + { + get + { + if ((this.indentLengthsField == null)) + { + this.indentLengthsField = new global::System.Collections.Generic.List(); + } + return this.indentLengthsField; + } + } + /// + /// Gets the current indent we use when adding lines to the output + /// + public string CurrentIndent + { + get + { + return this.currentIndentField; + } + } + /// + /// Current transformation session + /// + public virtual global::System.Collections.Generic.IDictionary Session + { + get + { + return this.sessionField; + } + set + { + this.sessionField = value; + } + } + #endregion + #region Transform-time helpers + /// + /// Write text directly into the generated output + /// + public void Write(string textToAppend) + { + if (string.IsNullOrEmpty(textToAppend)) + { + return; + } + // If we're starting off, or if the previous text ended with a newline, + // we have to append the current indent first. + if (((this.GenerationEnvironment.Length == 0) + || this.endsWithNewline)) + { + this.GenerationEnvironment.Append(this.currentIndentField); + this.endsWithNewline = false; + } + // Check if the current text ends with a newline + if (textToAppend.EndsWith(global::System.Environment.NewLine, global::System.StringComparison.CurrentCulture)) + { + this.endsWithNewline = true; + } + // This is an optimization. If the current indent is "", then we don't have to do any + // of the more complex stuff further down. 
+ if ((this.currentIndentField.Length == 0)) + { + this.GenerationEnvironment.Append(textToAppend); + return; + } + // Everywhere there is a newline in the text, add an indent after it + textToAppend = textToAppend.Replace(global::System.Environment.NewLine, (global::System.Environment.NewLine + this.currentIndentField)); + // If the text ends with a newline, then we should strip off the indent added at the very end + // because the appropriate indent will be added when the next time Write() is called + if (this.endsWithNewline) + { + this.GenerationEnvironment.Append(textToAppend, 0, (textToAppend.Length - this.currentIndentField.Length)); + } + else + { + this.GenerationEnvironment.Append(textToAppend); + } + } + /// + /// Write text directly into the generated output + /// + public void WriteLine(string textToAppend) + { + this.Write(textToAppend); + this.GenerationEnvironment.AppendLine(); + this.endsWithNewline = true; + } + /// + /// Write formatted text directly into the generated output + /// + public void Write(string format, params object[] args) + { + this.Write(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Write formatted text directly into the generated output + /// + public void WriteLine(string format, params object[] args) + { + this.WriteLine(string.Format(global::System.Globalization.CultureInfo.CurrentCulture, format, args)); + } + /// + /// Raise an error + /// + public void Error(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + this.Errors.Add(error); + } + /// + /// Raise a warning + /// + public void Warning(string message) + { + System.CodeDom.Compiler.CompilerError error = new global::System.CodeDom.Compiler.CompilerError(); + error.ErrorText = message; + error.IsWarning = true; + this.Errors.Add(error); + } + /// + /// Increase the indent + /// + public void PushIndent(string 
indent) + { + if ((indent == null)) + { + throw new global::System.ArgumentNullException("indent"); + } + this.currentIndentField = (this.currentIndentField + indent); + this.indentLengths.Add(indent.Length); + } + /// + /// Remove the last indent that was added with PushIndent + /// + public string PopIndent() + { + string returnValue = ""; + if ((this.indentLengths.Count > 0)) + { + int indentLength = this.indentLengths[(this.indentLengths.Count - 1)]; + this.indentLengths.RemoveAt((this.indentLengths.Count - 1)); + if ((indentLength > 0)) + { + returnValue = this.currentIndentField.Substring((this.currentIndentField.Length - indentLength)); + this.currentIndentField = this.currentIndentField.Remove((this.currentIndentField.Length - indentLength)); + } + } + return returnValue; + } + /// + /// Remove any indentation + /// + public void ClearIndent() + { + this.indentLengths.Clear(); + this.currentIndentField = ""; + } + #endregion + #region ToString Helpers + /// + /// Utility class to produce culture-oriented representation of an object as a string. + /// + public class ToStringInstanceHelper + { + private System.IFormatProvider formatProviderField = global::System.Globalization.CultureInfo.InvariantCulture; + /// + /// Gets or sets format provider to be used by ToStringWithCulture method. 
+ /// + public System.IFormatProvider FormatProvider + { + get + { + return this.formatProviderField ; + } + set + { + if ((value != null)) + { + this.formatProviderField = value; + } + } + } + /// + /// This is called from the compile/run appdomain to convert objects within an expression block to a string + /// + public string ToStringWithCulture(object objectToConvert) + { + if ((objectToConvert == null)) + { + throw new global::System.ArgumentNullException("objectToConvert"); + } + System.Type t = objectToConvert.GetType(); + System.Reflection.MethodInfo method = t.GetMethod("ToString", new System.Type[] { + typeof(System.IFormatProvider)}); + if ((method == null)) + { + return objectToConvert.ToString(); + } + else + { + return ((string)(method.Invoke(objectToConvert, new object[] { + this.formatProviderField }))); + } + } + } + private ToStringInstanceHelper toStringHelperField = new ToStringInstanceHelper(); + /// + /// Helper to produce culture-oriented representation of an object as a string + /// + public ToStringInstanceHelper ToStringHelper + { + get + { + return this.toStringHelperField; + } + } + #endregion + } + #endregion +} diff --git a/src/mlnet/Templates/Console/PredictionClass.tt b/src/mlnet/Templates/Console/PredictionClass.tt new file mode 100644 index 0000000000..2497a440a8 --- /dev/null +++ b/src/mlnet/Templates/Console/PredictionClass.tt @@ -0,0 +1,42 @@ +<#@ template language="C#" #> +<#@ assembly name="System.Core" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using System; +using Microsoft.ML.Data; + +namespace <#= Namespace #>.Model.DataModels +{ + public class SamplePrediction + { +<#if("BinaryClassification".Equals(TaskType)){ #> + // ColumnName attribute is used to change the column name from + // its default value, which is the name of the field. + [ColumnName("PredictedLabel")] + public bool Prediction { get; set; } + +<# } if("MulticlassClassification".Equals(TaskType)){ #> + // ColumnName attribute is used to change the column name from + // its default value, which is the name of the field. + [ColumnName("PredictedLabel")] + public <#= PredictionLabelType#> Prediction { get; set; } +<# }#> +<#if("MulticlassClassification".Equals(TaskType)){ #> + public float[] Score { get; set; } +<#}else{ #> + public float Score { get; set; } +<#}#> + } +} +<#+ +public string TaskType {get;set;} +public string PredictionLabelType {get;set;} +public string Namespace {get;set;} +#> diff --git a/src/mlnet/Utilities/ConsolePrinter.cs b/src/mlnet/Utilities/ConsolePrinter.cs new file mode 100644 index 0000000000..a024f3c647 --- /dev/null +++ b/src/mlnet/Utilities/ConsolePrinter.cs @@ -0,0 +1,153 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Auto; +using Microsoft.ML.Data; +using NLog; + +namespace Microsoft.ML.CLI.Utilities +{ + internal class ConsolePrinter + { + private const int Width = 114; + private static NLog.Logger logger = NLog.LogManager.GetCurrentClassLogger(); + internal static readonly string TABLESEPERATOR = "------------------------------------------------------------------------------------------------------------------"; + + internal static void PrintMetrics(int iteration, string trainerName, BinaryClassificationMetrics metrics, double bestMetric, double? runtimeInSeconds, LogLevel logLevel, int iterationNumber = -1) + { + logger.Log(logLevel, CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.Accuracy ?? double.NaN,9:F4} {metrics?.AreaUnderRocCurve ?? double.NaN,8:F4} {metrics?.AreaUnderPrecisionRecallCurve ?? double.NaN,8:F4} {metrics?.F1Score ?? double.NaN,9:F4} {runtimeInSeconds.Value,9:F1} {iterationNumber + 1,10}", Width)); + } + + internal static void PrintMetrics(int iteration, string trainerName, MulticlassClassificationMetrics metrics, double bestMetric, double? runtimeInSeconds, LogLevel logLevel, int iterationNumber = -1) + { + logger.Log(logLevel, CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.MicroAccuracy ?? double.NaN,14:F4} {metrics?.MacroAccuracy ?? double.NaN,14:F4} {runtimeInSeconds.Value,9:F1} {iterationNumber + 1,10}", Width)); + } + + internal static void PrintMetrics(int iteration, string trainerName, RegressionMetrics metrics, double bestMetric, double? runtimeInSeconds, LogLevel logLevel, int iterationNumber = -1) + { + logger.Log(logLevel, CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? 
double.NaN,8:F2} {runtimeInSeconds.Value,9:F1} {iterationNumber + 1,10}", Width)); + } + + internal static void PrintBinaryClassificationMetricsHeader(LogLevel logLevel) + { + logger.Log(logLevel, CreateRow($"{"",-4} {"Trainer",-35} {"Accuracy",9} {"AUC",8} {"AUPRC",8} {"F1-score",9} {"Duration",9} {"#Iteration",10}", Width)); + } + + internal static void PrintMulticlassClassificationMetricsHeader(LogLevel logLevel) + { + logger.Log(logLevel, CreateRow($"{"",-4} {"Trainer",-35} {"MicroAccuracy",14} {"MacroAccuracy",14} {"Duration",9} {"#Iteration",10}", Width)); + } + + internal static void PrintRegressionMetricsHeader(LogLevel logLevel) + { + logger.Log(logLevel, CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9} {"#Iteration",10}", Width)); + } + + internal static void ExperimentResultsHeader(LogLevel logLevel, string mltask, string datasetName, string labelName, string time, int numModelsExplored) + { + logger.Log(logLevel, string.Empty); + logger.Log(logLevel, $"===============================================Experiment Results================================================="); + logger.Log(logLevel, TABLESEPERATOR); + var header = "Summary"; + logger.Log(logLevel, CreateRow(header.PadLeft((Width / 2) + header.Length / 2), Width)); + logger.Log(logLevel, TABLESEPERATOR); + logger.Log(logLevel, CreateRow($"{"ML Task",-7}: {mltask,-20}", Width)); + logger.Log(logLevel, CreateRow($"{"Dataset",-7}: {datasetName,-25}", Width)); + logger.Log(logLevel, CreateRow($"{"Label",-6}: {labelName,-25}", Width)); + logger.Log(logLevel, CreateRow($"{"Total experiment time",-22}: {time} Secs", Width)); + logger.Log(logLevel, CreateRow($"{"Total number of models explored",-30}: {numModelsExplored}", Width)); + logger.Log(logLevel, TABLESEPERATOR); + } + + internal static string CreateRow(string message, int width) + { + return "|" + message.PadRight(width - 2) + "|"; + } + + internal static void 
PrintIterationSummary(IEnumerable> results, BinaryClassificationMetric optimizationMetric, int count) + { + var metricsAgent = new BinaryMetricsAgent(null, optimizationMetric); + var topNResults = BestResultUtil.GetTopNRunResults(results, metricsAgent, count, new OptimizingMetricInfo(optimizationMetric).IsMaximizing); + var header = $"Top {topNResults?.Count()} models explored"; + logger.Log(LogLevel.Info, CreateRow(header.PadLeft((Width / 2) + header.Length / 2), Width)); + logger.Log(LogLevel.Info, TABLESEPERATOR); + + PrintBinaryClassificationMetricsHeader(LogLevel.Info); + int i = 0; + foreach (var pair in topNResults) + { + var result = pair.Item1; + if (i == 0) + { + // Print top iteration colored. + Console.ForegroundColor = ConsoleColor.Yellow; + PrintMetrics(++i, result?.TrainerName, result?.ValidationMetrics, metricsAgent.GetScore(result?.ValidationMetrics), result?.RuntimeInSeconds, LogLevel.Info, pair.Item2); + Console.ResetColor(); + continue; + } + PrintMetrics(++i, result?.TrainerName, result?.ValidationMetrics, metricsAgent.GetScore(result?.ValidationMetrics), result?.RuntimeInSeconds, LogLevel.Info, pair.Item2); + } + logger.Log(LogLevel.Info, TABLESEPERATOR); + } + + internal static void PrintIterationSummary(IEnumerable> results, RegressionMetric optimizationMetric, int count) + { + var metricsAgent = new RegressionMetricsAgent(null, optimizationMetric); + var topNResults = BestResultUtil.GetTopNRunResults(results, metricsAgent, count, new OptimizingMetricInfo(optimizationMetric).IsMaximizing); + var header = $"Top {topNResults?.Count()} models explored"; + logger.Log(LogLevel.Info, CreateRow(header.PadLeft((Width / 2) + header.Length / 2), Width)); + logger.Log(LogLevel.Info, TABLESEPERATOR); + + PrintRegressionMetricsHeader(LogLevel.Info); + int i = 0; + foreach (var pair in topNResults) + { + var result = pair.Item1; + if (i == 0) + { + // Print top iteration colored. 
+ Console.ForegroundColor = ConsoleColor.Yellow; + PrintMetrics(++i, result?.TrainerName, result?.ValidationMetrics, metricsAgent.GetScore(result?.ValidationMetrics), result?.RuntimeInSeconds, LogLevel.Info, pair.Item2); + Console.ResetColor(); + continue; + } + PrintMetrics(++i, result?.TrainerName, result?.ValidationMetrics, metricsAgent.GetScore(result?.ValidationMetrics), result?.RuntimeInSeconds, LogLevel.Info, pair.Item2); + } + logger.Log(LogLevel.Info, TABLESEPERATOR); + } + + internal static void PrintIterationSummary(IEnumerable> results, MulticlassClassificationMetric optimizationMetric, int count) + { + var metricsAgent = new MultiMetricsAgent(null, optimizationMetric); + var topNResults = BestResultUtil.GetTopNRunResults(results, metricsAgent, count, new OptimizingMetricInfo(optimizationMetric).IsMaximizing); + var header = $"Top {topNResults?.Count()} models explored"; + logger.Log(LogLevel.Info, CreateRow(header.PadLeft((Width / 2) + header.Length / 2), Width)); + logger.Log(LogLevel.Info, TABLESEPERATOR); + PrintMulticlassClassificationMetricsHeader(LogLevel.Info); + int i = 0; + foreach (var pair in topNResults) + { + var result = pair.Item1; + if (i == 0) + { + // Print top iteration colored. 
+ Console.ForegroundColor = ConsoleColor.Yellow; + PrintMetrics(++i, result?.TrainerName, result?.ValidationMetrics, metricsAgent.GetScore(result?.ValidationMetrics), result?.RuntimeInSeconds, LogLevel.Info, pair.Item2); + Console.ResetColor(); + continue; + } + PrintMetrics(++i, result?.TrainerName, result?.ValidationMetrics, metricsAgent.GetScore(result?.ValidationMetrics), result?.RuntimeInSeconds, LogLevel.Info, pair.Item2); + } + logger.Log(LogLevel.Info, TABLESEPERATOR); + } + internal static void PrintException(Exception e, LogLevel logLevel) + { + logger.Log(logLevel, e.ToString()); + } + } +} + diff --git a/src/mlnet/Utilities/ProgressHandlers.cs b/src/mlnet/Utilities/ProgressHandlers.cs new file mode 100644 index 0000000000..98519c210c --- /dev/null +++ b/src/mlnet/Utilities/ProgressHandlers.cs @@ -0,0 +1,161 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Auto; +using Microsoft.ML.CLI.ShellProgressBar; +using Microsoft.ML.Data; +using NLog; + +namespace Microsoft.ML.CLI.Utilities +{ + internal class ProgressHandlers + { + private static int MetricComparator(double a, double b, bool isMaximizing) + { + return (isMaximizing ? 
a.CompareTo(b) : -a.CompareTo(b)); + } + + internal class RegressionHandler : IProgress> + { + private readonly bool isMaximizing; + private readonly Func, double> GetScore; + private RunDetail bestResult; + private int iterationIndex; + private ProgressBar progressBar; + private string optimizationMetric = string.Empty; + + public RegressionHandler(RegressionMetric optimizationMetric, ShellProgressBar.ProgressBar progressBar) + { + this.isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing; + this.optimizationMetric = optimizationMetric.ToString(); + this.progressBar = progressBar; + GetScore = (RunDetail result) => new RegressionMetricsAgent(null, optimizationMetric).GetScore(result?.ValidationMetrics); + ConsolePrinter.PrintRegressionMetricsHeader(LogLevel.Trace); + } + + public void Report(RunDetail iterationResult) + { + iterationIndex++; + UpdateBestResult(iterationResult); + if (progressBar != null) + progressBar.Message = $"Best {this.optimizationMetric}: {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); + if (iterationResult.Exception != null) + { + ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + } + } + + private void UpdateBestResult(RunDetail iterationResult) + { + if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0) + { + bestResult = iterationResult; + } + } + } + + internal class BinaryClassificationHandler : IProgress> + { + private readonly bool isMaximizing; + private readonly Func, double> GetScore; + private RunDetail bestResult; + private int iterationIndex; + private ProgressBar progressBar; + private BinaryClassificationMetric optimizationMetric; + + public BinaryClassificationHandler(BinaryClassificationMetric 
optimizationMetric, ProgressBar progressBar) + { + this.isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing; + this.optimizationMetric = optimizationMetric; + this.progressBar = progressBar; + GetScore = (RunDetail result) => new BinaryMetricsAgent(null, optimizationMetric).GetScore(result?.ValidationMetrics); + ConsolePrinter.PrintBinaryClassificationMetricsHeader(LogLevel.Trace); + } + + public void Report(RunDetail iterationResult) + { + iterationIndex++; + UpdateBestResult(iterationResult); + if (progressBar != null) + progressBar.Message = GetProgressBarMessage(iterationResult); + ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); + if (iterationResult.Exception != null) + { + ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + } + } + + private string GetProgressBarMessage(RunDetail iterationResult) + { + if (optimizationMetric == BinaryClassificationMetric.Accuracy) + { + return $"Best Accuracy: {GetScore(bestResult) * 100:F2}%, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + } + + return $"Best {this.optimizationMetric}: {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + } + + private void UpdateBestResult(RunDetail iterationResult) + { + if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0) + { + bestResult = iterationResult; + } + } + } + + internal class MulticlassClassificationHandler : IProgress> + { + private readonly bool isMaximizing; + private readonly Func, double> GetScore; + private RunDetail bestResult; + private int iterationIndex; + private ProgressBar progressBar; + private MulticlassClassificationMetric optimizationMetric; + + public MulticlassClassificationHandler(MulticlassClassificationMetric optimizationMetric, 
ProgressBar progressBar) + { + this.isMaximizing = new OptimizingMetricInfo(optimizationMetric).IsMaximizing; + this.optimizationMetric = optimizationMetric; + this.progressBar = progressBar; + GetScore = (RunDetail result) => new MultiMetricsAgent(null, optimizationMetric).GetScore(result?.ValidationMetrics); + ConsolePrinter.PrintMulticlassClassificationMetricsHeader(LogLevel.Trace); + } + + public void Report(RunDetail iterationResult) + { + iterationIndex++; + UpdateBestResult(iterationResult); + if (progressBar != null) + progressBar.Message = GetProgressBarMessage(iterationResult); + ConsolePrinter.PrintMetrics(iterationIndex, iterationResult?.TrainerName, iterationResult?.ValidationMetrics, GetScore(bestResult), iterationResult?.RuntimeInSeconds, LogLevel.Trace); + if (iterationResult.Exception != null) + { + ConsolePrinter.PrintException(iterationResult.Exception, LogLevel.Trace); + } + } + + private void UpdateBestResult(RunDetail iterationResult) + { + if (MetricComparator(GetScore(iterationResult), GetScore(bestResult), isMaximizing) > 0) + { + bestResult = iterationResult; + } + } + + private string GetProgressBarMessage(RunDetail iterationResult) + { + if (optimizationMetric == MulticlassClassificationMetric.MicroAccuracy) + { + return $"Best Accuracy: {GetScore(bestResult) * 100:F2}%, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + } + + return $"Best {this.optimizationMetric}: {GetScore(bestResult):F4}, Best Algorithm: {bestResult?.TrainerName}, Last Algorithm: {iterationResult?.TrainerName}"; + } + } + + } +} \ No newline at end of file diff --git a/src/mlnet/Utilities/Utils.cs b/src/mlnet/Utilities/Utils.cs new file mode 100644 index 0000000000..a8ee940e84 --- /dev/null +++ b/src/mlnet/Utilities/Utils.cs @@ -0,0 +1,222 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Formatting; +using Microsoft.ML.Auto; +using Microsoft.ML.Data; +using NLog; + +namespace Microsoft.ML.CLI.Utilities +{ + internal class Utils + { + internal static LogLevel GetVerbosity(string verbosity) + { + switch (verbosity) + { + case "q": + return LogLevel.Warn; + case "m": + return LogLevel.Info; + case "diag": + return LogLevel.Trace; + default: + return LogLevel.Info; + } + } + + internal static void SaveModel(ITransformer model, FileInfo modelPath, MLContext mlContext, + DataViewSchema modelInputSchema) + { + + if (!Directory.Exists(modelPath.Directory.FullName)) + { + Directory.CreateDirectory(modelPath.Directory.FullName); + } + + using (var fs = File.Create(modelPath.FullName)) + mlContext.Model.Save(model, modelInputSchema, fs); + } + + internal static string Sanitize(string name) + { + return string.Join("", name.Select(x => Char.IsLetterOrDigit(x) ? x : '_')); + } + + internal static TaskKind GetTaskKind(string mlTask) + { + switch (mlTask) + { + case "binary-classification": + return TaskKind.BinaryClassification; + case "multiclass-classification": + return TaskKind.MulticlassClassification; + case "regression": + return TaskKind.Regression; + default: // this should never be hit because the validation is done on command-line-api. 
+ throw new NotImplementedException($"{Strings.UnsupportedMlTask} : {mlTask}"); + } + } + + internal static string Normalize(string input) + { + //check if first character is int + if (!string.IsNullOrEmpty(input) && int.TryParse(input.Substring(0, 1), out int val)) + { + input = "Col" + input; + return input; + } + switch (input) + { + case null: throw new ArgumentNullException(nameof(input)); + case "": throw new ArgumentException($"{nameof(input)} cannot be empty", nameof(input)); + default: + var sanitizedInput = Sanitize(input); + return sanitizedInput.First().ToString().ToUpper() + input.Substring(1); + } + } + + internal static Type GetCSharpType(DataKind labelType) + { + switch (labelType) + { + case Microsoft.ML.Data.DataKind.String: + return typeof(string); + case Microsoft.ML.Data.DataKind.Boolean: + return typeof(bool); + case Microsoft.ML.Data.DataKind.Single: + return typeof(float); + case Microsoft.ML.Data.DataKind.Double: + return typeof(double); + case Microsoft.ML.Data.DataKind.Int32: + return typeof(int); + case Microsoft.ML.Data.DataKind.UInt32: + return typeof(uint); + case Microsoft.ML.Data.DataKind.Int64: + return typeof(long); + case Microsoft.ML.Data.DataKind.UInt64: + return typeof(ulong); + default: + throw new ArgumentException($"The data type '{labelType}' is not handled currently."); + } + } + + internal static bool? 
GetCacheSettings(string input) + { + switch (input) + { + case "on": return true; + case "off": return false; + case "auto": return null; + default: + throw new ArgumentException($"{nameof(input)} is invalid", nameof(input)); + } + } + + internal static ColumnInformation GetSanitizedColumnInformation(ColumnInformation columnInformation) + { + var result = new ColumnInformation(); + + result.LabelColumnName = Sanitize(columnInformation.LabelColumnName); + + if (!string.IsNullOrEmpty(columnInformation.ExampleWeightColumnName)) + result.ExampleWeightColumnName = Sanitize(columnInformation.ExampleWeightColumnName); + + if (!string.IsNullOrEmpty(columnInformation.SamplingKeyColumnName)) + result.SamplingKeyColumnName = Sanitize(columnInformation.SamplingKeyColumnName); + + foreach (var value in columnInformation.CategoricalColumnNames) + { + result.CategoricalColumnNames.Add(Sanitize(value)); + } + + foreach (var value in columnInformation.IgnoredColumnNames) + { + result.IgnoredColumnNames.Add(Sanitize(value)); + } + + foreach (var value in columnInformation.NumericColumnNames) + { + result.NumericColumnNames.Add(Sanitize(value)); + } + + foreach (var value in columnInformation.TextColumnNames) + { + result.TextColumnNames.Add(Sanitize(value)); + } + return result; + } + + internal static void WriteOutputToFiles(string fileContent, string fileName, string outPutBaseDir) + { + if (!Directory.Exists(outPutBaseDir)) + { + Directory.CreateDirectory(outPutBaseDir); + } + File.WriteAllText($"{outPutBaseDir}/{fileName}", fileContent); + } + + internal static string FormatCode(string trainProgramCSFileContent) + { + //Format + var tree = CSharpSyntaxTree.ParseText(trainProgramCSFileContent); + var syntaxNode = tree.GetRoot(); + trainProgramCSFileContent = Formatter.Format(syntaxNode, new AdhocWorkspace()).ToFullString(); + return trainProgramCSFileContent; + } + + + internal static int AddProjectsToSolution(string modelprojectDir, + string modelProjectName, + string 
consoleAppProjectDir, + string consoleAppProjectName, + string solutionPath) + { + // TODO make this method generic : (string solutionpath, string[] projects) + var proc = new System.Diagnostics.Process(); + try + { + proc.StartInfo.FileName = @"dotnet"; + proc.StartInfo.Arguments = $"sln \"{solutionPath}\" add \"{Path.Combine(consoleAppProjectDir, consoleAppProjectName)}\" \"{Path.Combine(modelprojectDir, modelProjectName)}\""; + proc.StartInfo.UseShellExecute = false; + proc.StartInfo.RedirectStandardOutput = true; + proc.Start(); + string outPut = proc.StandardOutput.ReadToEnd(); + proc.WaitForExit(); + var exitCode = proc.ExitCode; + return exitCode; + } + finally + { + proc.Close(); + } + } + + internal static int CreateSolutionFile(string solutionFile, string outputPath) + { + var proc = new System.Diagnostics.Process(); + try + { + proc.StartInfo.FileName = @"dotnet"; + proc.StartInfo.Arguments = $"new sln --name \"{solutionFile}\" --output \"{outputPath}\" --force"; + proc.StartInfo.UseShellExecute = false; + proc.StartInfo.RedirectStandardOutput = true; + proc.Start(); + string outPut = proc.StandardOutput.ReadToEnd(); + proc.WaitForExit(); + var exitCode = proc.ExitCode; + return exitCode; + } + finally + { + proc.Close(); + } + } + } +} diff --git a/src/mlnet/mlnet.csproj b/src/mlnet/mlnet.csproj new file mode 100644 index 0000000000..ea5fe98666 --- /dev/null +++ b/src/mlnet/mlnet.csproj @@ -0,0 +1,118 @@ + + + + Exe + netcoreapp2.1 + true + Microsoft.ML.CLI + mlnet + mlnet + mlnet + + false + false + + + + + + + + + + + + + + + + mscorlib + + + System + + + System.Core + + + + + + + + + + True + True + Strings.resx + + + True + True + ModelProject.tt + + + True + True + ObservationClass.tt + + + True + True + PredictionClass.tt + + + True + True + PredictProgram.tt + + + True + True + PredictProject.tt + + + True + True + ModelBuilder.tt + + + + + + ResXFileCodeGenerator + Strings.Designer.cs + + + + + + Always + + + TextTemplatingFilePreprocessor + 
ModelProject.cs + + + TextTemplatingFilePreprocessor + ObservationClass.cs + + + TextTemplatingFilePreprocessor + PredictionClass.cs + + + TextTemplatingFilePreprocessor + PredictProgram.cs + + + TextTemplatingFilePreprocessor + PredictProject.cs + + + TextTemplatingFilePreprocessor + ModelBuilder.cs + + + + diff --git a/src/mlnet/strings.Designer.cs b/src/mlnet/strings.Designer.cs new file mode 100644 index 0000000000..85e622ca2d --- /dev/null +++ b/src/mlnet/strings.Designer.cs @@ -0,0 +1,288 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by a tool. +// Runtime Version:4.0.30319.42000 +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + +namespace Microsoft.ML.CLI { + using System; + + + /// + /// A strongly-typed resource class, for looking up localized strings, etc. + /// + // This class was auto-generated by the StronglyTypedResourceBuilder + // class via a tool like ResGen or Visual Studio. + // To add or remove a member, edit your .ResX file then rerun ResGen + // with the /str option, or rebuild your VS project. + [global::System.CodeDom.Compiler.GeneratedCodeAttribute("System.Resources.Tools.StronglyTypedResourceBuilder", "15.0.0.0")] + [global::System.Diagnostics.DebuggerNonUserCodeAttribute()] + [global::System.Runtime.CompilerServices.CompilerGeneratedAttribute()] + internal class Strings { + + private static global::System.Resources.ResourceManager resourceMan; + + private static global::System.Globalization.CultureInfo resourceCulture; + + [global::System.Diagnostics.CodeAnalysis.SuppressMessageAttribute("Microsoft.Performance", "CA1811:AvoidUncalledPrivateCode")] + internal Strings() { + } + + /// + /// Returns the cached ResourceManager instance used by this class. 
+ /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Resources.ResourceManager ResourceManager { + get { + if (object.ReferenceEquals(resourceMan, null)) { + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.ML.CLI.Strings", typeof(Strings).Assembly); + resourceMan = temp; + } + return resourceMan; + } + } + + /// + /// Overrides the current thread's CurrentUICulture property for all + /// resource lookups using this strongly typed resource class. + /// + [global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Advanced)] + internal static global::System.Globalization.CultureInfo Culture { + get { + return resourceCulture; + } + set { + resourceCulture = value; + } + } + + /// + /// Looks up a localized string similar to Best pipeline. + /// + internal static string BestPipeline { + get { + return ResourceManager.GetString("BestPipeline", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Creating Data loader .... + /// + internal static string CreateDataLoader { + get { + return ResourceManager.GetString("CreateDataLoader", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Error occured while retreiving best pipeline.. + /// + internal static string ErrorBestPipeline { + get { + return ResourceManager.GetString("ErrorBestPipeline", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Exiting .... + /// + internal static string Exiting { + get { + return ResourceManager.GetString("Exiting", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Exploring multiple ML algorithms and settings to find you the best model for ML task. 
+ /// + internal static string ExplorePipeline { + get { + return ResourceManager.GetString("ExplorePipeline", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Exception occured while exploring pipelines. + /// + internal static string ExplorePipelineException { + get { + return ResourceManager.GetString("ExplorePipelineException", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to For further learning check. + /// + internal static string FurtherLearning { + get { + return ResourceManager.GetString("FurtherLearning", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Generated log file . + /// + internal static string GenerateLogFile { + get { + return ResourceManager.GetString("GenerateLogFile", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Generated C# code for model consumption. + /// + internal static string GenerateModelConsumption { + get { + return ResourceManager.GetString("GenerateModelConsumption", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Generated C# code for model training. + /// + internal static string GenerateModelTraining { + get { + return ResourceManager.GetString("GenerateModelTraining", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Generating a console project for the best pipeline at location . + /// + internal static string GenerateProject { + get { + return ResourceManager.GetString("GenerateProject", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to An Error occured during inferring columns. + /// + internal static string InferColumnError { + get { + return ResourceManager.GetString("InferColumnError", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Inferring Columns .... 
+ /// + internal static string InferColumns { + get { + return ResourceManager.GetString("InferColumns", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to https://aka.ms/mlnet-cli. + /// + internal static string LearningHttpLink { + get { + return ResourceManager.GetString("LearningHttpLink", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Loading data .... + /// + internal static string LoadData { + get { + return ResourceManager.GetString("LoadData", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Please see the log file for more info.. + /// + internal static string LookIntoLogFile { + get { + return ResourceManager.GetString("LookIntoLogFile", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Metrics for Binary Classification models. + /// + internal static string MetricsForBinaryClassModels { + get { + return ResourceManager.GetString("MetricsForBinaryClassModels", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Metrics for multi-class models. + /// + internal static string MetricsForMulticlassModels { + get { + return ResourceManager.GetString("MetricsForMulticlassModels", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Metrics for regression models. + /// + internal static string MetricsForRegressionModels { + get { + return ResourceManager.GetString("MetricsForRegressionModels", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Retrieving best pipeline .... + /// + internal static string RetrieveBestPipeline { + get { + return ResourceManager.GetString("RetrieveBestPipeline", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Generated trained model for consumption. 
+ /// + internal static string SavingBestModel { + get { + return ResourceManager.GetString("SavingBestModel", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Check out log file for more information. + /// + internal static string SeeLogFileForMoreInfo { + get { + return ResourceManager.GetString("SeeLogFileForMoreInfo", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Unsupported ml-task. + /// + internal static string UnsupportedMlTask { + get { + return ResourceManager.GetString("UnsupportedMlTask", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Waiting for the first iteration to complete .... + /// + internal static string WaitingForFirstIteration { + get { + return ResourceManager.GetString("WaitingForFirstIteration", resourceCulture); + } + } + + /// + /// Looks up a localized string similar to Waiting for the last iteration to complete .... + /// + internal static string WaitingForLastIteration { + get { + return ResourceManager.GetString("WaitingForLastIteration", resourceCulture); + } + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs new file mode 100644 index 0000000000..cef86d8557 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/AutoFitTests.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Linq; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class AutoFitTests + { + [TestMethod] + public void AutoFitBinaryTest() + { + var context = new MLContext(); + var dataPath = DatasetUtil.DownloadUciAdultDataset(); + var columnInference = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel); + var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); + var trainData = textLoader.Load(dataPath); + var results = context.Auto() + .CreateBinaryClassificationExperiment(0) + .Execute(trainData, new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }); + var best = results.Best(); + Assert.IsTrue(best.ValidationMetrics.Accuracy > 0.70); + Assert.IsNotNull(best.Estimator); + Assert.IsNotNull(best.Model); + Assert.IsNotNull(best.TrainerName); + } + + [TestMethod] + public void AutoFitMultiTest() + { + var context = new MLContext(); + var columnInference = context.Auto().InferColumns(DatasetUtil.TrivialMulticlassDatasetPath, DatasetUtil.TrivialMulticlassDatasetLabel); + var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); + var trainData = textLoader.Load(DatasetUtil.TrivialMulticlassDatasetPath); + var results = context.Auto() + .CreateMulticlassClassificationExperiment(0) + .Execute(trainData, 5, DatasetUtil.TrivialMulticlassDatasetLabel); + var best = results.Best(); + Assert.IsTrue(best.Results.First().ValidationMetrics.MicroAccuracy >= 0.7); + var scoredData = best.Results.First().Model.Transform(trainData); + Assert.AreEqual(NumberDataViewType.Single, scoredData.Schema[DefaultColumnNames.PredictedLabel].Type); + } + + [TestMethod] + public void AutoFitRegressionTest() + { + var context = new MLContext(); + var dataPath = DatasetUtil.DownloadMlNetGeneratedRegressionDataset(); + var columnInference = context.Auto().InferColumns(dataPath, 
DatasetUtil.MlNetGeneratedRegressionLabel); + var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderOptions); + var trainData = textLoader.Load(dataPath); + var validationData = context.Data.TakeRows(trainData, 20); + trainData = context.Data.SkipRows(trainData, 20); + var results = context.Auto() + .CreateRegressionExperiment(0) + .Execute(trainData, validationData, + new ColumnInformation() { LabelColumnName = DatasetUtil.MlNetGeneratedRegressionLabel }); + + Assert.IsTrue(results.Max(i => i.ValidationMetrics.RSquared > 0.9)); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/BestResultUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/BestResultUtilTests.cs new file mode 100644 index 0000000000..a0693e582f --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/BestResultUtilTests.cs @@ -0,0 +1,63 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class BestResultUtilTests + { + [TestMethod] + public void FindBestResultWithSomeNullMetrics() + { + var metrics1 = MetricsUtil.CreateRegressionMetrics(0.2, 0.2, 0.2, 0.2, 0.2); + var metrics2 = MetricsUtil.CreateRegressionMetrics(0.3, 0.3, 0.3, 0.3, 0.3); + var metrics3 = MetricsUtil.CreateRegressionMetrics(0.1, 0.1, 0.1, 0.1, 0.1); + + var runResults = new List>() + { + new RunDetail(null, null, null, null, null, null), + new RunDetail(null, null, null, null, metrics1, null), + new RunDetail(null, null, null, null, metrics2, null), + new RunDetail(null, null, null, null, metrics3, null), + }; + + var metricsAgent = new RegressionMetricsAgent(null, RegressionMetric.RSquared); + var bestResult = BestResultUtil.GetBestRun(runResults, metricsAgent, true); + Assert.AreEqual(0.3, bestResult.ValidationMetrics.RSquared); + } + + [TestMethod] + public void FindBestResultWithAllNullMetrics() + { + var runResults = new List>() + { + new RunDetail(null, null, null, null, null, null), + }; + + var metricsAgent = new RegressionMetricsAgent(null, RegressionMetric.RSquared); + var bestResult = BestResultUtil.GetBestRun(runResults, metricsAgent, true); + Assert.AreEqual(null, bestResult); + } + + [TestMethod] + public void GetIndexOfBestScoreMaximizingUtil() + { + var scores = new double[] { 0, 2, 5, 100, -100, -70 }; + var indexOfMaxScore = BestResultUtil.GetIndexOfBestScore(scores, true); + Assert.AreEqual(3, indexOfMaxScore); + } + + [TestMethod] + public void GetIndexOfBestScoreMinimizingUtil() + { + var scores = new double[] { 0, 2, 5, 100, -100, -70 }; + var indexOfMaxScore = BestResultUtil.GetIndexOfBestScore(scores, false); + Assert.AreEqual(4, indexOfMaxScore); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/ColumnInferenceTests.cs b/test/Microsoft.ML.AutoML.Tests/ColumnInferenceTests.cs 
new file mode 100644 index 0000000000..be6e7be4f6 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/ColumnInferenceTests.cs @@ -0,0 +1,151 @@ +using System; +using System.IO; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class ColumnInferenceTests + { + [TestMethod] + public void UnGroupReturnsMoreColumnsThanGroup() + { + var dataPath = DatasetUtil.DownloadUciAdultDataset(); + var context = new MLContext(); + var columnInferenceWithoutGrouping = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel, groupColumns: false); + foreach (var col in columnInferenceWithoutGrouping.TextLoaderOptions.Columns) + { + Assert.IsFalse(col.Source.Length > 1 || col.Source[0].Min != col.Source[0].Max); + } + + var columnInferenceWithGrouping = context.Auto().InferColumns(dataPath, DatasetUtil.UciAdultLabel, groupColumns: true); + Assert.IsTrue(columnInferenceWithGrouping.TextLoaderOptions.Columns.Count() < columnInferenceWithoutGrouping.TextLoaderOptions.Columns.Count()); + } + + [TestMethod] + public void IncorrectLabelColumnThrows() + { + var dataPath = DatasetUtil.DownloadUciAdultDataset(); + var context = new MLContext(); + Assert.ThrowsException(new System.Action(() => context.Auto().InferColumns(dataPath, "Junk", groupColumns: false))); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentOutOfRangeException))] + public void LabelIndexOutOfBoundsThrows() + { + new MLContext().Auto().InferColumns(DatasetUtil.DownloadUciAdultDataset(), 100); + } + + [TestMethod] + public void IdentifyLabelColumnThroughIndexWithHeader() + { + var result = new MLContext().Auto().InferColumns(DatasetUtil.DownloadUciAdultDataset(), 14, hasHeader: true); + Assert.AreEqual(true, result.TextLoaderOptions.HasHeader); + var labelCol = result.TextLoaderOptions.Columns.First(c => c.Source[0].Min == 14 && c.Source[0].Max == 14); + Assert.AreEqual("hours-per-week", labelCol.Name); 
+ Assert.AreEqual("hours-per-week", result.ColumnInformation.LabelColumnName); + } + + [TestMethod] + public void IdentifyLabelColumnThroughIndexWithoutHeader() + { + var result = new MLContext().Auto().InferColumns(DatasetUtil.DownloadIrisDataset(), DatasetUtil.IrisDatasetLabelColIndex); + Assert.AreEqual(false, result.TextLoaderOptions.HasHeader); + var labelCol = result.TextLoaderOptions.Columns.First(c => c.Source[0].Min == DatasetUtil.IrisDatasetLabelColIndex && + c.Source[0].Max == DatasetUtil.IrisDatasetLabelColIndex); + Assert.AreEqual(DefaultColumnNames.Label, labelCol.Name); + Assert.AreEqual(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName); + } + + [TestMethod] + public void DatasetWithEmptyColumn() + { + var result = new MLContext().Auto().InferColumns(Path.Combine("TestData", "DatasetWithEmptyColumn.txt"), DefaultColumnNames.Label, groupColumns: false); + var emptyColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Empty"); + Assert.AreEqual(DataKind.Single, emptyColumn.DataKind); + } + + [TestMethod] + public void DatasetWithBoolColumn() + { + var result = new MLContext().Auto().InferColumns(Path.Combine("TestData", "BinaryDatasetWithBoolColumn.txt"), DefaultColumnNames.Label); + Assert.AreEqual(2, result.TextLoaderOptions.Columns.Count()); + + var boolColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Bool"); + var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == DefaultColumnNames.Label); + // ensure non-label Boolean column is detected as R4 + Assert.AreEqual(DataKind.Single, boolColumn.DataKind); + Assert.AreEqual(DataKind.Boolean, labelColumn.DataKind); + + // ensure non-label Boolean column is detected as R4 + Assert.AreEqual(1, result.ColumnInformation.NumericColumnNames.Count()); + Assert.AreEqual("Bool", result.ColumnInformation.NumericColumnNames.First()); + Assert.AreEqual(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName); + } + + [TestMethod] + public void 
WhereNameColumnIsOnlyFeature() + { + var result = new MLContext().Auto().InferColumns(Path.Combine("TestData", "NameColumnIsOnlyFeatureDataset.txt"), DefaultColumnNames.Label); + Assert.AreEqual(2, result.TextLoaderOptions.Columns.Count()); + + var nameColumn = result.TextLoaderOptions.Columns.First(c => c.Name == "Username"); + var labelColumn = result.TextLoaderOptions.Columns.First(c => c.Name == DefaultColumnNames.Label); + Assert.AreEqual(DataKind.String, nameColumn.DataKind); + Assert.AreEqual(DataKind.Boolean, labelColumn.DataKind); + + Assert.AreEqual(1, result.ColumnInformation.TextColumnNames.Count()); + Assert.AreEqual("Username", result.ColumnInformation.TextColumnNames.First()); + Assert.AreEqual(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName); + } + + [TestMethod] + public void DefaultColumnNamesInferredCorrectly() + { + var result = new MLContext().Auto().InferColumns(Path.Combine("TestData", "DatasetWithDefaultColumnNames.txt"), + new ColumnInformation() + { + LabelColumnName = DefaultColumnNames.Label, + ExampleWeightColumnName = DefaultColumnNames.Weight, + }, + groupColumns : false); + + Assert.AreEqual(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName); + Assert.AreEqual(DefaultColumnNames.Weight, result.ColumnInformation.ExampleWeightColumnName); + Assert.AreEqual(result.ColumnInformation.NumericColumnNames.Count(), 3); + } + + [TestMethod] + public void DefaultColumnNamesNoGrouping() + { + var result = new MLContext().Auto().InferColumns(Path.Combine("TestData", "DatasetWithDefaultColumnNames.txt"), + new ColumnInformation() + { + LabelColumnName = DefaultColumnNames.Label, + ExampleWeightColumnName = DefaultColumnNames.Weight, + }); + + Assert.AreEqual(DefaultColumnNames.Label, result.ColumnInformation.LabelColumnName); + Assert.AreEqual(DefaultColumnNames.Weight, result.ColumnInformation.ExampleWeightColumnName); + Assert.AreEqual(1, result.ColumnInformation.NumericColumnNames.Count()); + 
Assert.AreEqual(DefaultColumnNames.Features, result.ColumnInformation.NumericColumnNames.First()); + } + + [TestMethod] + public void InferColumnsColumnInfoParam() + { + var columnInfo = new ColumnInformation() { LabelColumnName = DatasetUtil.MlNetGeneratedRegressionLabel }; + var result = new MLContext().Auto().InferColumns(DatasetUtil.DownloadMlNetGeneratedRegressionDataset(), + columnInfo); + var labelCol = result.TextLoaderOptions.Columns.First(c => c.Name == DatasetUtil.MlNetGeneratedRegressionLabel); + Assert.AreEqual(DataKind.Single, labelCol.DataKind); + Assert.AreEqual(DatasetUtil.MlNetGeneratedRegressionLabel, result.ColumnInformation.LabelColumnName); + Assert.AreEqual(1, result.ColumnInformation.NumericColumnNames.Count()); + Assert.AreEqual(DefaultColumnNames.Features, result.ColumnInformation.NumericColumnNames.First()); + Assert.AreEqual(null, result.ColumnInformation.ExampleWeightColumnName); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs b/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs new file mode 100644 index 0000000000..a3631768da --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/ColumnInformationUtilTests.cs @@ -0,0 +1,36 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class ColumnInformationUtilTests + { + [TestMethod] + public void GetColumnPurpose() + { + var columnInfo = new ColumnInformation() + { + LabelColumnName = "Label", + ExampleWeightColumnName = "Weight", + SamplingKeyColumnName = "SamplingKey", + }; + columnInfo.CategoricalColumnNames.Add("Cat"); + columnInfo.NumericColumnNames.Add("Num"); + columnInfo.TextColumnNames.Add("Text"); + columnInfo.IgnoredColumnNames.Add("Ignored"); + + Assert.AreEqual(ColumnPurpose.Label, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Label")); + Assert.AreEqual(ColumnPurpose.Weight, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Weight")); + Assert.AreEqual(ColumnPurpose.SamplingKey, ColumnInformationUtil.GetColumnPurpose(columnInfo, "SamplingKey")); + Assert.AreEqual(ColumnPurpose.CategoricalFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Cat")); + Assert.AreEqual(ColumnPurpose.NumericFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Num")); + Assert.AreEqual(ColumnPurpose.TextFeature, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Text")); + Assert.AreEqual(ColumnPurpose.Ignore, ColumnInformationUtil.GetColumnPurpose(columnInfo, "Ignored")); + Assert.AreEqual(null, ColumnInformationUtil.GetColumnPurpose(columnInfo, "NonExistent")); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs b/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs new file mode 100644 index 0000000000..e9e522aa88 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/ConversionTests.cs @@ -0,0 +1,85 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class ConversionTests + { + [TestMethod] + public void ConvertFloatMissingValues() + { + var missingValues = new string[] + { + "", + "?", " ", + "na", "n/a", "nan", + "NA", "N/A", "NaN", "NAN" + }; + + foreach(var missingValue in missingValues) + { + float value; + var success = Conversions.TryParse(missingValue.AsMemory(), out value); + Assert.IsTrue(success); + Assert.AreEqual(value, float.NaN); + } + } + + [TestMethod] + public void ConvertFloatParseFailure() + { + var values = new string[] + { + "a", "aa", "nb", "aaa", "naa", "nba", "n/b" + }; + + foreach (var value in values) + { + var success = Conversions.TryParse(value.AsMemory(), out float _); + Assert.IsFalse(success); + } + } + + [TestMethod] + public void ConvertBoolMissingValues() + { + var missingValues = new string[] + { + "", + "no", "NO", "+1", "-1", + "yes", "YES", + "true", "TRUE", + "false", "FALSE" + }; + + foreach (var missingValue in missingValues) + { + var success = Conversions.TryParse(missingValue.AsMemory(), out bool _); + Assert.IsTrue(success); + } + } + + [TestMethod] + public void ConvertBoolParseFailure() + { + var values = new string[] + { + "aa", "na", "+a", "-a", + "aaa", "yaa", "yea", + "aaaa", "taaa", "traa", "trua", + "aaaaa", "fbbbb", "faaaa", "falaa", "falsa" + }; + + foreach (var value in values) + { + var success = Conversions.TryParse(value.AsMemory(), out bool _); + Assert.IsFalse(success); + } + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/DatasetDimensionsTests.cs b/test/Microsoft.ML.AutoML.Tests/DatasetDimensionsTests.cs new file mode 100644 index 0000000000..72dc4d9d68 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/DatasetDimensionsTests.cs @@ -0,0 +1,86 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class DatasetDimensionsTests + { + public object DatasetDimensionUtil { get; private set; } + + [TestMethod] + public void TextColumnDimensionsTest() + { + var context = new MLContext(); + var dataBuilder = new ArrayDataViewBuilder(context); + dataBuilder.AddColumn("categorical", new string[] { "0", "1", "0", "1", "0", "1", "2", "2", "0", "1" }); + dataBuilder.AddColumn("text", new string[] { "0", "1", "0", "1", "0", "1", "2", "2", "0", "1" }); + var data = dataBuilder.GetDataView(); + var dimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, new[] { + new PurposeInference.Column(0, ColumnPurpose.CategoricalFeature), + new PurposeInference.Column(0, ColumnPurpose.TextFeature), + }); + Assert.IsNotNull(dimensions); + Assert.AreEqual(2, dimensions.Length); + Assert.AreEqual(3, dimensions[0].Cardinality); + Assert.AreEqual(null, dimensions[1].Cardinality); + Assert.IsNull(dimensions[0].HasMissing); + Assert.IsNull(dimensions[1].HasMissing); + } + + [TestMethod] + public void FloatColumnDimensionsTest() + { + var context = new MLContext(); + var dataBuilder = new ArrayDataViewBuilder(context); + dataBuilder.AddColumn("NoNan", NumberDataViewType.Single, new float[] { 0, 1, 0, 1, 0 }); + dataBuilder.AddColumn("Nan", NumberDataViewType.Single, new float[] { 0, 1, 0, 1, float.NaN }); + var data = dataBuilder.GetDataView(); + var dimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, new[] { + new PurposeInference.Column(0, ColumnPurpose.NumericFeature), + new PurposeInference.Column(1, ColumnPurpose.NumericFeature), + }); + Assert.IsNotNull(dimensions); + Assert.AreEqual(2, dimensions.Length); + Assert.AreEqual(null, dimensions[0].Cardinality); + Assert.AreEqual(null, dimensions[1].Cardinality); + Assert.AreEqual(false, 
dimensions[0].HasMissing); + Assert.AreEqual(true, dimensions[1].HasMissing); + } + + [TestMethod] + public void FloatVectorColumnHasNanTest() + { + var context = new MLContext(); + var dataBuilder = new ArrayDataViewBuilder(context); + var slotNames = new[] { "Col1", "Col2" }; + var colValues = new float[][] + { + new float[] { 0, 0 }, + new float[] { 1, 1 }, + }; + dataBuilder.AddColumn("NoNan", Util.GetKeyValueGetter(slotNames), NumberDataViewType.Single, colValues); + colValues = new float[][] + { + new float[] { 0, 0 }, + new float[] { 1, float.NaN }, + }; + dataBuilder.AddColumn("Nan", Util.GetKeyValueGetter(slotNames), NumberDataViewType.Single, colValues); + var data = dataBuilder.GetDataView(); + var dimensions = DatasetDimensionsApi.CalcColumnDimensions(context, data, new[] { + new PurposeInference.Column(0, ColumnPurpose.NumericFeature), + new PurposeInference.Column(1, ColumnPurpose.NumericFeature), + }); + Assert.IsNotNull(dimensions); + Assert.AreEqual(2, dimensions.Length); + Assert.AreEqual(null, dimensions[0].Cardinality); + Assert.AreEqual(null, dimensions[1].Cardinality); + Assert.AreEqual(false, dimensions[0].HasMissing); + Assert.AreEqual(true, dimensions[1].HasMissing); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/DatasetUtil.cs b/test/Microsoft.ML.AutoML.Tests/DatasetUtil.cs new file mode 100644 index 0000000000..9f76ebb083 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/DatasetUtil.cs @@ -0,0 +1,60 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.IO; +using System.Net; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto.Test +{ + internal static class DatasetUtil + { + public const string UciAdultLabel = DefaultColumnNames.Label; + public const string TrivialMulticlassDatasetLabel = "Target"; + public const string MlNetGeneratedRegressionLabel = "target"; + public const int IrisDatasetLabelColIndex = 0; + + public static string TrivialMulticlassDatasetPath = Path.Combine("TestData", "TrivialMulticlassDataset.txt"); + + private static IDataView _uciAdultDataView; + + public static IDataView GetUciAdultDataView() + { + if(_uciAdultDataView == null) + { + var context = new MLContext(); + var uciAdultDataFile = DownloadUciAdultDataset(); + var columnInferenceResult = context.Auto().InferColumns(uciAdultDataFile, UciAdultLabel); + var textLoader = context.Data.CreateTextLoader(columnInferenceResult.TextLoaderOptions); + _uciAdultDataView = textLoader.Load(uciAdultDataFile); + } + return _uciAdultDataView; + } + + // downloads the UCI Adult dataset from the ML.Net repo + public static string DownloadUciAdultDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/f0e639af5ffdc839aae8e65d19b5a9a1f0db634a/test/data/adult.tiny.with-schema.txt", "uciadult.dataset"); + + public static string DownloadMlNetGeneratedRegressionDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/e78971ea6fd736038b4c355b840e5cbabae8cb55/test/data/generated_regression_dataset.csv", "mlnet_generated_regression.dataset"); + + public static string DownloadIrisDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/54596ac/test/data/iris.txt", "iris.dataset"); + + private static string DownloadIfNotExists(string baseGitPath, string dataFile) + { + // if file doesn't already exist, download it + if(!File.Exists(dataFile)) + { + using (var client = new WebClient()) + { + client.DownloadFile(new 
Uri($"{baseGitPath}"), dataFile); + } + } + + return dataFile; + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/Directory.Build.props b/test/Microsoft.ML.AutoML.Tests/Directory.Build.props new file mode 100644 index 0000000000..e161d1461b --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/Directory.Build.props @@ -0,0 +1,9 @@ + + + + + trx + $(OutputPath) + + + \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/EstimatorExtensionTests.cs b/test/Microsoft.ML.AutoML.Tests/EstimatorExtensionTests.cs new file mode 100644 index 0000000000..d778429f93 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/EstimatorExtensionTests.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class EstimatorExtensionTests + { + [TestMethod] + public void EstimatorExtensionInstanceTests() + { + var context = new MLContext(); + var pipelineNode = new PipelineNode() + { + InColumns = new string[] { "Input" }, + OutColumns = new string[] { "Output" } + }; + + var estimatorNames = Enum.GetValues(typeof(EstimatorName)).Cast(); + foreach (var estimatorName in estimatorNames) + { + var extension = EstimatorExtensionCatalog.GetExtension(estimatorName); + var instance = extension.CreateInstance(context, pipelineNode); + Assert.IsNotNull(instance); + } + } + + [TestMethod] + public void EstimatorExtensionStaticTests() + { + var context = new MLContext(); + var inCol = "Input"; + var outCol = "Output"; + var inCols = new string[] { inCol }; + var outCols = new string[] { outCol }; + Assert.IsNotNull(ColumnConcatenatingExtension.CreateSuggestedTransform(context, inCols, outCol)); + Assert.IsNotNull(ColumnCopyingExtension.CreateSuggestedTransform(context, 
inCol, outCol)); + Assert.IsNotNull(MissingValueIndicatingExtension.CreateSuggestedTransform(context, inCols, outCols)); + Assert.IsNotNull(MissingValueReplacingExtension.CreateSuggestedTransform(context, inCols, outCols)); + Assert.IsNotNull(NormalizingExtension.CreateSuggestedTransform(context, inCol, outCol)); + Assert.IsNotNull(OneHotEncodingExtension.CreateSuggestedTransform(context, inCols, outCols)); + Assert.IsNotNull(OneHotHashEncodingExtension.CreateSuggestedTransform(context, inCols, outCols)); + Assert.IsNotNull(TextFeaturizingExtension.CreateSuggestedTransform(context, inCol, outCol)); + Assert.IsNotNull(TypeConvertingExtension.CreateSuggestedTransform(context, inCols, outCols)); + Assert.IsNotNull(ValueToKeyMappingExtension.CreateSuggestedTransform(context, inCol, outCol)); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs b/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs new file mode 100644 index 0000000000..023bfc2085 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/GetNextPipelineTests.cs @@ -0,0 +1,84 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class GetNextPipelineTests + { + [TestMethod] + public void GetNextPipeline() + { + var context = new MLContext(); + var uciAdult = DatasetUtil.GetUciAdultDataView(); + var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(context, uciAdult, new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }); + + // get next pipeline + var pipeline = PipelineSuggester.GetNextPipeline(context, new List(), columns, TaskKind.BinaryClassification); + + // serialize & deserialize pipeline + var serialized = JsonConvert.SerializeObject(pipeline); + Console.WriteLine(serialized); + var deserialized = JsonConvert.DeserializeObject(serialized); + + // run pipeline + var estimator = deserialized.ToEstimator(context); + var scoredData = estimator.Fit(uciAdult).Transform(uciAdult); + var score = context.BinaryClassification.EvaluateNonCalibrated(scoredData).Accuracy; + var result = new PipelineScore(deserialized, score, true); + + Assert.IsNotNull(result); + } + + [TestMethod] + public void GetNextPipelineMock() + { + var context = new MLContext(); + var uciAdult = DatasetUtil.GetUciAdultDataView(); + var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(context, uciAdult, new ColumnInformation() { LabelColumnName = DatasetUtil.UciAdultLabel }); + + // Get next pipeline loop + var history = new List(); + var task = TaskKind.BinaryClassification; + var maxIterations = 60; + for (var i = 0; i < maxIterations; i++) + { + // Get next pipeline + var pipeline = PipelineSuggester.GetNextPipeline(context, history, columns, task); + if (pipeline == null) + { + break; + } + + var result = new PipelineScore(pipeline, AutoMlUtils.random.Value.NextDouble(), true); + history.Add(result); + } + + Assert.AreEqual(maxIterations, history.Count); + + // Get all 'Stage 1' and 
'Stage 2' runs from Pipeline Suggester + var allAvailableTrainers = RecipeInference.AllowedTrainers(context, task, new ColumnInformation(), null); + var stage1Runs = history.Take(allAvailableTrainers.Count()); + var stage2Runs = history.Skip(allAvailableTrainers.Count()); + + // Get the trainer names from top 3 Stage 1 runs + var topStage1Runs = stage1Runs.OrderByDescending(r => r.Score).Take(3); + var topStage1TrainerNames = topStage1Runs.Select(r => r.Pipeline.Nodes.Last().Name); + + // Get unique trainer names from Stage 2 runs + var stage2TrainerNames = stage2Runs.Select(r => r.Pipeline.Nodes.Last().Name).Distinct(); + + // Assert that are only 3 unique trainers used in stage 2 + Assert.AreEqual(3, stage2TrainerNames.Count()); + // Assert that all trainers in stage 2 were the top trainers from stage 1 + Assert.IsFalse(topStage1TrainerNames.Except(stage2TrainerNames).Any()); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/InferredPipelineTests.cs b/test/Microsoft.ML.AutoML.Tests/InferredPipelineTests.cs new file mode 100644 index 0000000000..9a931b6810 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/InferredPipelineTests.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class InferredPipelineTests + { + [TestMethod] + public void InferredPipelinesHashTest() + { + var context = new MLContext(); + var columnInfo = new ColumnInformation(); + + // test same learners with no hyperparams have the same hash code + var trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo); + var trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo); + var transforms1 = new List(); + var transforms2 = new List(); + var inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, false); + var inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); + Assert.AreEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode()); + + // test same learners with hyperparams set vs empty hyperparams have different hash codes + var hyperparams1 = new ParameterSet(new List() { new LongParameterValue("NumberOfLeaves", 2) }); + trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams1); + trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo); + inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, false); + inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); + Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode()); + + // same learners with different hyperparams + hyperparams1 = new ParameterSet(new List() { new LongParameterValue("NumberOfLeaves", 2) }); + var hyperparams2 = new ParameterSet(new List() { new LongParameterValue("NumberOfLeaves", 6) }); + trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams1); + trainer2 = new SuggestedTrainer(context, new 
LightGbmBinaryExtension(), columnInfo, hyperparams2); + inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, false); + inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); + Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode()); + + // same learners with same transforms + trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo); + trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo); + transforms1 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + transforms2 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, false); + inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); + Assert.AreEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode()); + + // same transforms with different learners + trainer1 = new SuggestedTrainer(context, new SdcaLogisticRegressionBinaryExtension(), columnInfo); + trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo); + transforms1 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + transforms2 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, false); + inferredPipeline2 = new SuggestedPipeline(transforms2, new List(), trainer2, context, false); + Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode()); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs new file mode 100644 index 0000000000..9baaea540a --- 
/dev/null +++ b/test/Microsoft.ML.AutoML.Tests/MetricsAgentsTests.cs @@ -0,0 +1,165 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class MetricsAgentsTests + { + [TestMethod] + public void BinaryMetricsGetScoreTest() + { + var metrics = MetricsUtil.CreateBinaryClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8); + Assert.AreEqual(0.1, GetScore(metrics, BinaryClassificationMetric.AreaUnderRocCurve)); + Assert.AreEqual(0.2, GetScore(metrics, BinaryClassificationMetric.Accuracy)); + Assert.AreEqual(0.3, GetScore(metrics, BinaryClassificationMetric.PositivePrecision)); + Assert.AreEqual(0.4, GetScore(metrics, BinaryClassificationMetric.PositiveRecall)); + Assert.AreEqual(0.5, GetScore(metrics, BinaryClassificationMetric.NegativePrecision)); + Assert.AreEqual(0.6, GetScore(metrics, BinaryClassificationMetric.NegativeRecall)); + Assert.AreEqual(0.7, GetScore(metrics, BinaryClassificationMetric.F1Score)); + Assert.AreEqual(0.8, GetScore(metrics, BinaryClassificationMetric.AreaUnderPrecisionRecallCurve)); + } + + [TestMethod] + public void BinaryMetricsNonPerfectTest() + { + var metrics = MetricsUtil.CreateBinaryClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy)); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderRocCurve)); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderPrecisionRecallCurve)); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.F1Score)); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.NegativePrecision)); + 
Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.NegativeRecall)); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.PositivePrecision)); + Assert.AreEqual(false, IsPerfectModel(metrics, BinaryClassificationMetric.PositiveRecall)); + } + + [TestMethod] + public void BinaryMetricsPerfectTest() + { + var metrics = MetricsUtil.CreateBinaryClassificationMetrics(1, 1, 1, 1, 1, 1, 1, 1); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.Accuracy)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderRocCurve)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.AreaUnderPrecisionRecallCurve)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.F1Score)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.NegativePrecision)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.NegativeRecall)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.PositivePrecision)); + Assert.AreEqual(true, IsPerfectModel(metrics, BinaryClassificationMetric.PositiveRecall)); + } + + [TestMethod] + public void MulticlassMetricsGetScoreTest() + { + var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] {}); + Assert.AreEqual(0.1, GetScore(metrics, MulticlassClassificationMetric.MicroAccuracy)); + Assert.AreEqual(0.2, GetScore(metrics, MulticlassClassificationMetric.MacroAccuracy)); + Assert.AreEqual(0.3, GetScore(metrics, MulticlassClassificationMetric.LogLoss)); + Assert.AreEqual(0.4, GetScore(metrics, MulticlassClassificationMetric.LogLossReduction)); + Assert.AreEqual(0.5, GetScore(metrics, MulticlassClassificationMetric.TopKAccuracy)); + } + + [TestMethod] + public void MulticlassMetricsNonPerfectTest() + { + var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(0.1, 0.2, 0.3, 0.4, 0, 0.5, new double[] { 
}); + Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.MacroAccuracy)); + Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.MicroAccuracy)); + Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss)); + Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLossReduction)); + Assert.AreEqual(false, IsPerfectModel(metrics, MulticlassClassificationMetric.TopKAccuracy)); + } + + [TestMethod] + public void MulticlassMetricsPerfectTest() + { + var metrics = MetricsUtil.CreateMulticlassClassificationMetrics(1, 1, 0, 1, 0, 1, new double[] { }); + Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.MicroAccuracy)); + Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.MacroAccuracy)); + Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLoss)); + Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.LogLossReduction)); + Assert.AreEqual(true, IsPerfectModel(metrics, MulticlassClassificationMetric.TopKAccuracy)); + } + + [TestMethod] + public void RegressionMetricsGetScoreTest() + { + var metrics = MetricsUtil.CreateRegressionMetrics(0.2, 0.3, 0.4, 0.5, 0.6); + Assert.AreEqual(0.2, GetScore(metrics, RegressionMetric.MeanAbsoluteError)); + Assert.AreEqual(0.3, GetScore(metrics, RegressionMetric.MeanSquaredError)); + Assert.AreEqual(0.4, GetScore(metrics, RegressionMetric.RootMeanSquaredError)); + Assert.AreEqual(0.6, GetScore(metrics, RegressionMetric.RSquared)); + } + + [TestMethod] + public void RegressionMetricsNonPerfectTest() + { + var metrics = MetricsUtil.CreateRegressionMetrics(0.2, 0.3, 0.4, 0.5, 0.6); + Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.MeanAbsoluteError)); + Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.MeanSquaredError)); + Assert.AreEqual(false, IsPerfectModel(metrics, 
RegressionMetric.RootMeanSquaredError)); + Assert.AreEqual(false, IsPerfectModel(metrics, RegressionMetric.RSquared)); + } + + [TestMethod] + public void RegressionMetricsPerfectTest() + { + var metrics = MetricsUtil.CreateRegressionMetrics(0, 0, 0, 0, 1); + Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.MeanAbsoluteError)); + Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.MeanSquaredError)); + Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.RootMeanSquaredError)); + Assert.AreEqual(true, IsPerfectModel(metrics, RegressionMetric.RSquared)); + } + + [TestMethod] + [ExpectedException(typeof(NotSupportedException))] + public void ThrowNotSupportedMetricException() + { + throw MetricsAgentUtil.BuildMetricNotSupportedException(BinaryClassificationMetric.Accuracy); + } + + private static double GetScore(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric) + { + return new BinaryMetricsAgent(null, metric).GetScore(metrics); + } + + private static double GetScore(MulticlassClassificationMetrics metrics, MulticlassClassificationMetric metric) + { + return new MultiMetricsAgent(null, metric).GetScore(metrics); + } + + private static double GetScore(RegressionMetrics metrics, RegressionMetric metric) + { + return new RegressionMetricsAgent(null, metric).GetScore(metrics); + } + + private static bool IsPerfectModel(BinaryClassificationMetrics metrics, BinaryClassificationMetric metric) + { + var metricsAgent = new BinaryMetricsAgent(null, metric); + return IsPerfectModel(metricsAgent, metrics); + } + + private static bool IsPerfectModel(MulticlassClassificationMetrics metrics, MulticlassClassificationMetric metric) + { + var metricsAgent = new MultiMetricsAgent(null, metric); + return IsPerfectModel(metricsAgent, metrics); + } + + private static bool IsPerfectModel(RegressionMetrics metrics, RegressionMetric metric) + { + var metricsAgent = new RegressionMetricsAgent(null, metric); + return 
IsPerfectModel(metricsAgent, metrics); + } + + private static bool IsPerfectModel(IMetricsAgent metricsAgent, TMetrics metrics) + { + var score = metricsAgent.GetScore(metrics); + return metricsAgent.IsModelPerfect(score); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs b/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs new file mode 100644 index 0000000000..89dc0cc51f --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/MetricsUtil.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Reflection; +using Microsoft.ML.Data; + +namespace Microsoft.ML.Auto.Test +{ + internal static class MetricsUtil + { + public static BinaryClassificationMetrics CreateBinaryClassificationMetrics( + double auc, double accuracy, double positivePrecision, + double positiveRecall, double negativePrecision, + double negativeRecall, double f1Score, double auprc) + { + return CreateInstance(auc, accuracy, + positivePrecision, positiveRecall, negativePrecision, + negativeRecall, f1Score, auprc); + } + + public static MulticlassClassificationMetrics CreateMulticlassClassificationMetrics( + double accuracyMicro, double accuracyMacro, double logLoss, + double logLossReduction, int topK, double topKAccuracy, + double[] perClassLogLoss) + { + return CreateInstance(accuracyMicro, + accuracyMacro, logLoss, logLossReduction, topK, + topKAccuracy, perClassLogLoss); + } + + public static RegressionMetrics CreateRegressionMetrics(double l1, + double l2, double rms, double lossFn, double rSquared) + { + return CreateInstance(l1, l2, + rms, lossFn, rSquared); + } + + private static T CreateInstance(params object[] args) + { + var type = typeof(T); + var instance = type.Assembly.CreateInstance( + type.FullName, false, + BindingFlags.Instance | BindingFlags.NonPublic, + null, args, null, null); + return 
(T)instance; + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj new file mode 100644 index 0000000000..5e5547f422 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/Microsoft.ML.AutoML.Tests.csproj @@ -0,0 +1,43 @@ + + + + netcoreapp2.1 + + false + + + false + false + + Microsoft.ML.Auto.Test + + + + + + + + + + + + + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + PreserveNewest + + + + diff --git a/test/Microsoft.ML.AutoML.Tests/PurposeInferenceTests.cs b/test/Microsoft.ML.AutoML.Tests/PurposeInferenceTests.cs new file mode 100644 index 0000000000..3865ca75a1 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/PurposeInferenceTests.cs @@ -0,0 +1,38 @@ +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class PurposeInferenceTests + { + [TestMethod] + public void PurposeInferenceHiddenColumnsTest() + { + var context = new MLContext(); + + // build basic data view + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumn(DefaultColumnNames.Label, BooleanDataViewType.Instance); + schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single); + var schema = schemaBuilder.ToSchema(); + IDataView data = new EmptyDataView(context, schema); + + // normalize 'Features' column. 
this has the effect of creating 2 columns named + // 'Features' in the data view, the first of which gets marked as 'Hidden' + var normalizer = context.Transforms.NormalizeMinMax(DefaultColumnNames.Features); + data = normalizer.Fit(data).Transform(data); + + // infer purposes + var purposes = PurposeInference.InferPurposes(context, data, new ColumnInformation()); + + Assert.AreEqual(3, purposes.Count()); + Assert.AreEqual(ColumnPurpose.Label, purposes[0].Purpose); + // assert first 'Features' purpose (hidden column) is Ignore + Assert.AreEqual(ColumnPurpose.Ignore, purposes[1].Purpose); + // assert second 'Features' purpose is NumericFeature + Assert.AreEqual(ColumnPurpose.NumericFeature, purposes[2].Purpose); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/SuggestedPipelineBuilderTests.cs b/test/Microsoft.ML.AutoML.Tests/SuggestedPipelineBuilderTests.cs new file mode 100644 index 0000000000..e59c3fccea --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/SuggestedPipelineBuilderTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class SuggestedPipelineBuilderTests + { + private static MLContext _context = new MLContext(); + + [TestMethod] + public void TrainerWantsCaching() + { + TestPipelineBuilderCaching(BuildAveragedPerceptronTrainer(), + new bool?[] { true, false, null }, + new[] { true, false, true }); + } + + [TestMethod] + public void TrainerDoesntWantCaching() + { + TestPipelineBuilderCaching(BuildLightGbmTrainer(), + new bool?[] { true, false, null }, + new[] { true, false, false }); + } + + [TestMethod] + public void TrainerNeedsNormalization() + { + var pipeline = BuildSuggestedPipeline(BuildAveragedPerceptronTrainer()); + Assert.AreEqual(EstimatorName.Normalizing.ToString(), + pipeline.Transforms[0].PipelineNode.Name); + } + + [TestMethod] + public void TrainerNotNeedNormalization() + { + var pipeline = BuildSuggestedPipeline(BuildLightGbmTrainer()); + Assert.AreEqual(0, pipeline.Transforms.Count); + } + + private static void TestPipelineBuilderCaching( + SuggestedTrainer trainer, + bool?[] enableCachingOptions, + bool[] resultShouldHaveCaching) + { + for (var i = 0; i < enableCachingOptions.Length; i++) + { + var suggestedPipeline = BuildSuggestedPipeline(trainer, + enableCachingOptions[i]); + Assert.AreEqual(resultShouldHaveCaching[i], + suggestedPipeline.ToPipeline().CacheBeforeTrainer); + } + } + + private static SuggestedTrainer BuildAveragedPerceptronTrainer() + { + return new SuggestedTrainer(_context, + new AveragedPerceptronBinaryExtension(), + new ColumnInformation()); + } + + private static SuggestedTrainer BuildLightGbmTrainer() + { + return new SuggestedTrainer(_context, + new LightGbmBinaryExtension(), + new ColumnInformation()); + } + + private static SuggestedPipeline BuildSuggestedPipeline(SuggestedTrainer trainer, + bool? 
enableCaching = null) + { + return SuggestedPipelineBuilder.Build(_context, + new List(), + new List(), + trainer, enableCaching); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/SweeperTests.cs b/test/Microsoft.ML.AutoML.Tests/SweeperTests.cs new file mode 100644 index 0000000000..c2e60922cf --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/SweeperTests.cs @@ -0,0 +1,173 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class SweeperTests + { + [TestMethod] + public void SmacQuickRunTest() + { + var numInitialPopulation = 10; + + var floatValueGenerator = new FloatValueGenerator(new FloatParamArguments() { Name = "float", Min = 1, Max = 1000 }); + var floatLogValueGenerator = new FloatValueGenerator(new FloatParamArguments() { Name = "floatLog", Min = 1, Max = 1000, LogBase = true }); + var longValueGenerator = new LongValueGenerator(new LongParamArguments() { Name = "long", Min = 1, Max = 1000 }); + var longLogValueGenerator = new LongValueGenerator(new LongParamArguments() { Name = "longLog", Min = 1, Max = 1000, LogBase = true }); + var discreteValueGeneator = new DiscreteValueGenerator(new DiscreteParamArguments() { Name = "discrete", Values = new[] { "200", "400", "600", "800" } }); + + var sweeper = new SmacSweeper(new MLContext(), new SmacSweeper.Arguments() + { + SweptParameters = new IValueGenerator[] { + floatValueGenerator, + floatLogValueGenerator, + longValueGenerator, + longLogValueGenerator, + discreteValueGeneator + }, + NumberInitialPopulation = numInitialPopulation + }); + + // sanity check grid + Assert.IsNotNull(floatValueGenerator[0].ValueText); + Assert.IsNotNull(floatLogValueGenerator[0].ValueText); + 
Assert.IsNotNull(longValueGenerator[0].ValueText); + Assert.IsNotNull(longLogValueGenerator[0].ValueText); + Assert.IsNotNull(discreteValueGeneator[0].ValueText); + + List results = new List(); + + RunResult bestResult = null; + for (var i = 0; i < numInitialPopulation + 1; i++) + { + ParameterSet[] pars = sweeper.ProposeSweeps(1, results); + + foreach (ParameterSet p in pars) + { + float x1 = float.Parse(p["float"].ValueText); + float x2 = float.Parse(p["floatLog"].ValueText); + long x3 = long.Parse(p["long"].ValueText); + long x4 = long.Parse(p["longLog"].ValueText); + int x5 = int.Parse(p["discrete"].ValueText); + + double metric = x1 + x2 + x3 + x4 + x5; + + RunResult result = new RunResult(p, metric, true); + if (bestResult == null || bestResult.MetricValue < metric) + { + bestResult = result; + } + results.Add(result); + + Console.WriteLine($"{metric}\t{x1},{x2}"); + } + + } + + Console.WriteLine($"Best: {bestResult.MetricValue}"); + + Assert.IsNotNull(bestResult); + Assert.IsTrue(bestResult.MetricValue > 0); + } + + + [Ignore] + [TestMethod] + public void Smac4ParamsConvergenceTest() + { + var sweeper = new SmacSweeper(new MLContext(), new SmacSweeper.Arguments() + { + SweptParameters = new INumericValueGenerator[] { + new FloatValueGenerator(new FloatParamArguments() { Name = "x1", Min = 1, Max = 1000}), + new FloatValueGenerator(new FloatParamArguments() { Name = "x2", Min = 1, Max = 1000}), + new FloatValueGenerator(new FloatParamArguments() { Name = "x3", Min = 1, Max = 1000}), + new FloatValueGenerator(new FloatParamArguments() { Name = "x4", Min = 1, Max = 1000}), + }, + }); + + List results = new List(); + + RunResult bestResult = null; + for (var i = 0; i < 300; i++) + { + ParameterSet[] pars = sweeper.ProposeSweeps(1, results); + + // if run converged, break + if (pars == null) + { + break; + } + + foreach (ParameterSet p in pars) + { + float x1 = (p["x1"] as FloatParameterValue).Value; + float x2 = (p["x2"] as FloatParameterValue).Value; + float x3 
= (p["x3"] as FloatParameterValue).Value; + float x4 = (p["x4"] as FloatParameterValue).Value; + + double metric = -200 * (Math.Abs(100 - x1) + + Math.Abs(300 - x2) + + Math.Abs(500 - x3) + + Math.Abs(700 - x4)); + + RunResult result = new RunResult(p, metric, true); + if (bestResult == null || bestResult.MetricValue < metric) + { + bestResult = result; + } + results.Add(result); + + Console.WriteLine($"{metric}\t{x1},{x2},{x3},{x4}"); + } + + } + + Console.WriteLine($"Best: {bestResult.MetricValue}"); + } + + [Ignore] + [TestMethod] + public void Smac2ParamsConvergenceTest() + { + var sweeper = new SmacSweeper(new MLContext(), new SmacSweeper.Arguments() + { + SweptParameters = new INumericValueGenerator[] { + new FloatValueGenerator(new FloatParamArguments() { Name = "foo", Min = 1, Max = 5}), + new LongValueGenerator(new LongParamArguments() { Name = "bar", Min = 1, Max = 1000, LogBase = true }) + }, + }); + + Random rand = new Random(0); + List results = new List(); + + int count = 0; + while (true) + { + ParameterSet[] pars = sweeper.ProposeSweeps(1, results); + if(pars == null) + { + break; + } + foreach (ParameterSet p in pars) + { + float foo = 0; + long bar = 0; + + foo = (p["foo"] as FloatParameterValue).Value; + bar = (p["bar"] as LongParameterValue).Value; + + double metric = ((5 - Math.Abs(4 - foo)) * 200) + (1001 - Math.Abs(33 - bar)) + rand.Next(1, 20); + results.Add(new RunResult(p, metric, true)); + count++; + Console.WriteLine("{0}--{1}--{2}--{3}", count, foo, bar, metric); + } + } + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/TestData/BinaryDatasetWithBoolColumn.txt b/test/Microsoft.ML.AutoML.Tests/TestData/BinaryDatasetWithBoolColumn.txt new file mode 100644 index 0000000000..7fc6e787df --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TestData/BinaryDatasetWithBoolColumn.txt @@ -0,0 +1,5 @@ +Label,Bool +0,1 +0,0 +1,1 +1,0 \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/TestData/DatasetWithDefaultColumnNames.txt 
b/test/Microsoft.ML.AutoML.Tests/TestData/DatasetWithDefaultColumnNames.txt new file mode 100644 index 0000000000..26aa3a2102 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TestData/DatasetWithDefaultColumnNames.txt @@ -0,0 +1,4 @@ +Label,Weight,Name,Features,FeatureContributions,Feature1 +0,1,GUID1,1,1,1 +0,1,GUID2,1,1,1 +1,1,GUID3,1,1,1 \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/TestData/DatasetWithEmptyColumn.txt b/test/Microsoft.ML.AutoML.Tests/TestData/DatasetWithEmptyColumn.txt new file mode 100644 index 0000000000..7033743b5b --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TestData/DatasetWithEmptyColumn.txt @@ -0,0 +1,4 @@ +Label,Feature1,Empty +0,2, +0,4, +1,1, \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/TestData/NameColumnIsOnlyFeatureDataset.txt b/test/Microsoft.ML.AutoML.Tests/TestData/NameColumnIsOnlyFeatureDataset.txt new file mode 100644 index 0000000000..3e436a9ae6 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TestData/NameColumnIsOnlyFeatureDataset.txt @@ -0,0 +1,103 @@ +Label,Username +0,a0 +0,a1 +0,a2 +0,a3 +0,a4 +0,a5 +0,a6 +0,a7 +0,a8 +0,a9 +0,a10 +0,a11 +0,a12 +0,a13 +0,a14 +0,a15 +0,a16 +0,a17 +0,a18 +0,a19 +0,a20 +0,a21 +0,a22 +0,a23 +0,a24 +0,a25 +0,a26 +0,a27 +0,a28 +0,a29 +0,a30 +0,a31 +0,a32 +0,a33 +0,a34 +0,a35 +0,a36 +0,a37 +0,a38 +0,a39 +0,a40 +0,a41 +0,a42 +0,a43 +0,a44 +0,a45 +0,a46 +0,a47 +0,a48 +0,a49 +0,a50 +1,b0 +1,b1 +1,b2 +1,b3 +1,b4 +1,b5 +1,b6 +1,b7 +1,b8 +1,b9 +1,b10 +1,b11 +1,b12 +1,b13 +1,b14 +1,b15 +1,b16 +1,b17 +1,b18 +1,b19 +1,b20 +1,b21 +1,b22 +1,b23 +1,b24 +1,b25 +1,b26 +1,b27 +1,b28 +1,b29 +1,b30 +1,b31 +1,b32 +1,b33 +1,b34 +1,b35 +1,b36 +1,b37 +1,b38 +1,b39 +1,b40 +1,b41 +1,b42 +1,b43 +1,b44 +1,b45 +1,b46 +1,b47 +1,b48 +1,b49 +1,b50 \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/TestData/TrivialMulticlassDataset.txt b/test/Microsoft.ML.AutoML.Tests/TestData/TrivialMulticlassDataset.txt new file mode 100644 index 
0000000000..c9566415b6 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TestData/TrivialMulticlassDataset.txt @@ -0,0 +1,181 @@ +Target Row Column +1 14 20 +1 19 26 +3 17 4 +1 10 20 +1 3 5 +1 7 5 +1 18 36 +2 1 36 +2 1 38 +3 17 1 +2 6 26 +2 9 30 +3 13 8 +2 7 33 +2 8 30 +3 10 1 +1 18 25 +1 13 12 +1 3 2 +2 8 28 +1 11 24 +2 3 28 +2 1 16 +1 9 7 +1 15 16 +3 19 4 +1 1 8 +1 8 0 +1 10 34 +1 18 37 +2 1 17 +2 8 39 +1 17 30 +2 1 27 +2 0 38 +1 11 16 +3 19 3 +1 7 8 +1 13 13 +1 19 31 +3 16 1 +1 5 1 +2 6 11 +1 9 5 +3 10 6 +1 1 2 +2 6 30 +2 7 15 +1 17 21 +1 18 23 +3 10 7 +2 5 39 +2 2 27 +3 12 6 +3 11 4 +1 9 3 +1 12 22 +2 8 19 +2 1 14 +1 11 11 +1 10 36 +3 12 4 +1 15 21 +1 17 37 +1 6 3 +2 3 18 +1 10 10 +1 11 33 +1 18 19 +2 7 35 +3 10 2 +1 12 30 +1 12 26 +2 1 31 +2 5 21 +2 1 11 +1 7 3 +2 8 36 +3 10 4 +1 18 26 +2 8 10 +1 10 22 +1 15 14 +3 16 0 +2 0 30 +2 3 34 +3 13 9 +1 0 2 +1 15 36 +1 15 23 +1 10 30 +2 6 20 +2 9 24 +2 9 35 +1 7 6 +2 7 39 +2 5 20 +3 12 8 +2 9 12 +1 17 25 +1 12 33 +2 6 19 +1 17 10 +2 4 35 +1 15 31 +3 12 7 +1 17 16 +2 1 19 +2 3 25 +1 16 30 +1 19 30 +1 5 4 +2 6 10 +1 18 20 +1 13 26 +2 3 39 +2 2 20 +1 4 7 +2 3 33 +1 16 20 +2 1 21 +3 15 2 +3 19 2 +1 12 10 +2 5 37 +2 1 32 +3 18 6 +1 2 1 +1 16 21 +2 1 23 +1 17 33 +2 5 11 +2 3 14 +1 11 12 +1 13 20 +1 19 38 +1 15 10 +2 8 11 +3 11 0 +1 18 10 +1 19 24 +1 13 11 +2 4 23 +1 16 26 +1 7 7 +1 17 29 +1 18 30 +1 13 10 +2 6 21 +1 19 32 +2 7 12 +1 12 28 +2 2 11 +1 12 15 +2 8 32 +3 15 9 +3 16 5 +1 9 1 +1 19 28 +3 16 3 +1 15 17 +2 7 38 +1 16 38 +1 14 26 +1 10 26 +1 10 37 +3 18 5 +2 5 27 +2 2 22 +1 11 39 +1 16 36 +1 0 9 +2 5 19 +1 18 28 +1 12 13 +1 17 17 +1 8 1 +2 6 15 +3 14 4 +1 1 4 \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/TextFileSampleTests.cs b/test/Microsoft.ML.AutoML.Tests/TextFileSampleTests.cs new file mode 100644 index 0000000000..b9cf90f39a --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TextFileSampleTests.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using System.Text; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class TextFileSampleTests + { + [TestMethod] + public void CanParseLargeRandomStream() + { + using (var stream = new MemoryStream()) + { + const int numRows = 100000; + const int rowSize = 100; + + for (var i = 0; i < numRows; i++) + { + var row = new byte[rowSize]; + AutoMlUtils.random.Value.NextBytes(row); + + // ensure byte array has no 0s, so text file sampler doesn't + // think file is encoded with UTF-16 or UTF-32 without a BOM + for (var k = 0; k < row.Length; k++) + { + if(row[k] == 0) + { + row[k] = 1; + } + } + stream.Write(row); + stream.Write(Encoding.UTF8.GetBytes("\r\n")); + } + + stream.Seek(0, SeekOrigin.Begin); + + var sample = TextFileSample.CreateFromFullStream(stream); + Assert.IsNotNull(sample); + Assert.IsTrue(sample.FullFileSize > 0); + } + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/TrainerExtensionsTests.cs b/test/Microsoft.ML.AutoML.Tests/TrainerExtensionsTests.cs new file mode 100644 index 0000000000..a3ca10fddd --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TrainerExtensionsTests.cs @@ -0,0 +1,311 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class TrainerExtensionsTests + { + [TestMethod] + public void TrainerExtensionInstanceTests() + { + var context = new MLContext(); + var columnInfo = new ColumnInformation(); + var trainerNames = Enum.GetValues(typeof(TrainerName)).Cast() + .Except(new[] { TrainerName.Ova }); + foreach (var trainerName in trainerNames) + { + var extension = TrainerExtensionCatalog.GetTrainerExtension(trainerName); + var sweepParams = extension.GetHyperparamSweepRanges(); + Assert.IsNotNull(sweepParams); + foreach (var sweepParam in sweepParams) + { + sweepParam.RawValue = 1; + } + var instance = extension.CreateInstance(context, sweepParams, columnInfo); + Assert.IsNotNull(instance); + var pipelineNode = extension.CreatePipelineNode(null, columnInfo); + Assert.IsNotNull(pipelineNode); + } + } + + [TestMethod] + public void BuildLightGbmPipelineNode() + { + var sweepParams = SweepableParams.BuildLightGbmParams(); + foreach (var sweepParam in sweepParams) + { + sweepParam.RawValue = 1; + } + + var pipelineNode = new LightGbmBinaryExtension().CreatePipelineNode(sweepParams, new ColumnInformation()); + + var expectedJson = @"{ + ""Name"": ""LightGbmBinary"", + ""NodeType"": ""Trainer"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Score"" + ], + ""Properties"": { + ""NumberOfIterations"": 20, + ""LearningRate"": 1, + ""NumberOfLeaves"": 1, + ""MinimumExampleCountPerLeaf"": 10, + ""UseCategoricalSplit"": false, + ""HandleMissingValue"": false, + ""MinimumExampleCountPerGroup"": 50, + ""MaximumCategoricalSplitPointCount"": 16, + ""CategoricalSmoothing"": 10, + ""L2CategoricalRegularization"": 0.5, + ""Booster"": { + ""Name"": ""GradientBooster.Options"", + ""Properties"": { + ""L2Regularization"": 0.5, + ""L1Regularization"": 0.5 + } + }, + ""LabelColumnName"": ""Label"" + } +}"; + 
Util.AssertObjectMatchesJson(expectedJson, pipelineNode); + } + + [TestMethod] + public void BuildSdcaPipelineNode() + { + var sweepParams = SweepableParams.BuildSdcaParams(); + foreach (var sweepParam in sweepParams) + { + sweepParam.RawValue = 1; + } + + var pipelineNode = new SdcaLogisticRegressionBinaryExtension().CreatePipelineNode(sweepParams, new ColumnInformation()); + var expectedJson = @"{ + ""Name"": ""SdcaLogisticRegressionBinary"", + ""NodeType"": ""Trainer"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Score"" + ], + ""Properties"": { + ""L2Regularization"": 1E-07, + ""L1Regularization"": 0.0, + ""ConvergenceTolerance"": 0.01, + ""MaximumNumberOfIterations"": 10, + ""Shuffle"": true, + ""BiasLearningRate"": 0.01, + ""LabelColumnName"": ""Label"" + } +}"; + Util.AssertObjectMatchesJson(expectedJson, pipelineNode); + } + + [TestMethod] + public void BuildLightGbmPipelineNodeDefaultParams() + { + var pipelineNode = new LightGbmBinaryExtension().CreatePipelineNode( + new List(), + new ColumnInformation()); + var expectedJson = @"{ + ""Name"": ""LightGbmBinary"", + ""NodeType"": ""Trainer"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Score"" + ], + ""Properties"": { + ""LabelColumnName"": ""Label"" + } +}"; + Util.AssertObjectMatchesJson(expectedJson, pipelineNode); + } + + [TestMethod] + public void BuildPipelineNodeWithCustomColumns() + { + var columnInfo = new ColumnInformation() + { + LabelColumnName = "L", + ExampleWeightColumnName = "W" + }; + var sweepParams = SweepableParams.BuildFastForestParams(); + foreach (var sweepParam in sweepParams) + { + sweepParam.RawValue = 1; + } + + var pipelineNode = new FastForestBinaryExtension().CreatePipelineNode(sweepParams, columnInfo); + var expectedJson = @"{ + ""Name"": ""FastForestBinary"", + ""NodeType"": ""Trainer"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Score"" + ], + ""Properties"": { + ""NumberOfLeaves"": 1, + 
""MinimumExampleCountPerLeaf"": 10, + ""NumberOfTrees"": 100, + ""LabelColumnName"": ""L"", + ""ExampleWeightColumnName"": ""W"" + } +}"; + Util.AssertObjectMatchesJson(expectedJson, pipelineNode); + } + + [TestMethod] + public void BuildDefaultAveragedPerceptronPipelineNode() + { + var pipelineNode = new AveragedPerceptronBinaryExtension().CreatePipelineNode(null, new ColumnInformation() { LabelColumnName = "L" }); + var expectedJson = @"{ + ""Name"": ""AveragedPerceptronBinary"", + ""NodeType"": ""Trainer"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Score"" + ], + ""Properties"": { + ""LabelColumnName"": ""L"", + ""NumberOfIterations"": 10 + } +}"; + Util.AssertObjectMatchesJson(expectedJson, pipelineNode); + } + + [TestMethod] + public void BuildOvaPipelineNode() + { + var pipelineNode = new FastForestOvaExtension().CreatePipelineNode(null, new ColumnInformation()); + var expectedJson = @"{ + ""Name"": ""Ova"", + ""NodeType"": ""Trainer"", + ""InColumns"": null, + ""OutColumns"": null, + ""Properties"": { + ""LabelColumnName"": ""Label"", + ""BinaryTrainer"": { + ""Name"": ""FastForestBinary"", + ""NodeType"": ""Trainer"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Score"" + ], + ""Properties"": { + ""LabelColumnName"": ""Label"" + } + } + } +}"; + Util.AssertObjectMatchesJson(expectedJson, pipelineNode); + } + + [TestMethod] + public void BuildParameterSetLightGbm() + { + var props = new Dictionary() + { + {"NumberOfIterations", 1 }, + {"LearningRate", 1 }, + {"Booster", new CustomProperty() { + Name = "GradientBooster.Options", + Properties = new Dictionary() + { + {"L2Regularization", 1 }, + {"L1Regularization", 1 }, + } + } }, + }; + var binaryParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmBinary, props); + var multiParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmMulti, props); + var regressionParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.LightGbmRegression, 
props); + + foreach (var paramSet in new ParameterSet[] { binaryParams, multiParams, regressionParams }) + { + Assert.AreEqual(4, paramSet.Count); + Assert.AreEqual("1", paramSet["NumberOfIterations"].ValueText); + Assert.AreEqual("1", paramSet["LearningRate"].ValueText); + Assert.AreEqual("1", paramSet["L2Regularization"].ValueText); + Assert.AreEqual("1", paramSet["L1Regularization"].ValueText); + } + } + + [TestMethod] + public void BuildParameterSetSdca() + { + var props = new Dictionary() + { + {"LearningRate", 1 }, + }; + + var sdcaParams = TrainerExtensionUtil.BuildParameterSet(TrainerName.SdcaLogisticRegressionBinary, props); + + Assert.AreEqual(1, sdcaParams.Count); + Assert.AreEqual("1", sdcaParams["LearningRate"].ValueText); + } + + [TestMethod] + public void PublicToPrivateTrainerNamesBinaryTest() + { + var publicNames = Enum.GetValues(typeof(BinaryClassificationTrainer)).Cast(); + var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames); + Assert.AreEqual(publicNames.Distinct().Count(), internalNames.Distinct().Count()); + } + + [TestMethod] + public void PublicToPrivateTrainerNamesMultiTest() + { + var publicNames = Enum.GetValues(typeof(MulticlassClassificationTrainer)).Cast(); + var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames); + Assert.AreEqual(publicNames.Distinct().Count(), internalNames.Distinct().Count()); + } + + [TestMethod] + public void PublicToPrivateTrainerNamesRegressionTest() + { + var publicNames = Enum.GetValues(typeof(RegressionTrainer)).Cast(); + var internalNames = TrainerExtensionUtil.GetTrainerNames(publicNames); + Assert.AreEqual(publicNames.Distinct().Count(), internalNames.Distinct().Count()); + } + + [TestMethod] + public void PublicToPrivateTrainerNamesNullTest() + { + var internalNames = TrainerExtensionUtil.GetTrainerNames(null as IEnumerable); + Assert.AreEqual(null, internalNames); + } + + [TestMethod] + public void AllowedTrainersWhitelistNullTest() + { + var trainers = 
RecipeInference.AllowedTrainers(new MLContext(), TaskKind.BinaryClassification, new ColumnInformation(), null); + Assert.IsTrue(trainers.Any()); + } + + [TestMethod] + public void AllowedTrainersWhitelistTest() + { + var whitelist = new[] { TrainerName.AveragedPerceptronBinary, TrainerName.FastForestBinary }; + var trainers = RecipeInference.AllowedTrainers(new MLContext(), TaskKind.BinaryClassification, new ColumnInformation(), whitelist); + Assert.AreEqual(whitelist.Count(), trainers.Count()); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/TransformInferenceTests.cs b/test/Microsoft.ML.AutoML.Tests/TransformInferenceTests.cs new file mode 100644 index 0000000000..0a4461cf79 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TransformInferenceTests.cs @@ -0,0 +1,788 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class TransformInferenceTests + { + [TestMethod] + public void TransformInferenceNumAndCatCols() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Numeric1", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Categorical1", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(7, null)), + new DatasetColumnInfo("Categorical2", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(7, null)), + new DatasetColumnInfo("LargeCat1", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(500, null)), + new DatasetColumnInfo("LargeCat2", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(500, null)), + }, @"[ 
+ { + ""Name"": ""OneHotEncoding"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Categorical1"", + ""Categorical2"" + ], + ""OutColumns"": [ + ""Categorical1"", + ""Categorical2"" + ], + ""Properties"": {} + }, + { + ""Name"": ""OneHotHashEncoding"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""LargeCat1"", + ""LargeCat2"" + ], + ""OutColumns"": [ + ""LargeCat1"", + ""LargeCat2"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Categorical1"", + ""Categorical2"", + ""LargeCat1"", + ""LargeCat2"", + ""Numeric1"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceNumCatAndFeatCols() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Numeric1", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Categorical1", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(7, null)), + new DatasetColumnInfo("Categorical2", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(7, null)), + new DatasetColumnInfo("LargeCat1", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(500, null)), + new DatasetColumnInfo("LargeCat2", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(500, null)), + }, @"[ + { + ""Name"": ""OneHotEncoding"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Categorical1"", + ""Categorical2"" + ], + ""OutColumns"": [ + ""Categorical1"", + ""Categorical2"" + ], + ""Properties"": {} + }, + { + ""Name"": ""OneHotHashEncoding"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""LargeCat1"", + ""LargeCat2"" + ], + ""OutColumns"": [ + 
""LargeCat1"", + ""LargeCat2"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Categorical1"", + ""Categorical2"", + ""LargeCat1"", + ""LargeCat2"", + ""Features"", + ""Numeric1"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceCatAndFeatCols() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Categorical1", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(7, null)), + new DatasetColumnInfo("LargeCat1", TextDataViewType.Instance, ColumnPurpose.CategoricalFeature, new ColumnDimensions(500, null)), + }, @"[ + { + ""Name"": ""OneHotEncoding"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Categorical1"" + ], + ""OutColumns"": [ + ""Categorical1"" + ], + ""Properties"": {} + }, + { + ""Name"": ""OneHotHashEncoding"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""LargeCat1"" + ], + ""OutColumns"": [ + ""LargeCat1"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Categorical1"", + ""LargeCat1"", + ""Features"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceNumericCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Numeric", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, + @"[ + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceNumericCols() + { + TransformInferenceTestCore(new[] + { + new 
DatasetColumnInfo("Numeric1", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Numeric2", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Numeric1"", + ""Numeric2"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceFeatColScalar() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Features"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceFeatColVector() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, new VectorDataViewType(NumberDataViewType.Single), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[]"); + } + + [TestMethod] + public void NumericAndFeatCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Numeric", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Features"", + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void NumericScalarCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Numeric", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), 
+ }, @"[ + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void NumericVectorCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Numeric", new VectorDataViewType(NumberDataViewType.Single), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""ColumnCopying"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceTextCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Text", TextDataViewType.Instance, ColumnPurpose.TextFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""TextFeaturizing"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Text"" + ], + ""OutColumns"": [ + ""Text_tf"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnCopying"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Text_tf"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceTextAndFeatCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Text", TextDataViewType.Instance, ColumnPurpose.TextFeature, new ColumnDimensions(null, null)), + }, + @"[ + { + ""Name"": ""TextFeaturizing"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Text"" + ], + ""OutColumns"": [ + ""Text_tf"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Text_tf"", + ""Features"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void 
TransformInferenceBoolCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Bool", BooleanDataViewType.Instance, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""TypeConverting"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Bool"" + ], + ""OutColumns"": [ + ""Bool"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Bool"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceBoolAndNumCols() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Numeric", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Bool", BooleanDataViewType.Instance, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""TypeConverting"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Bool"" + ], + ""OutColumns"": [ + ""Bool"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Bool"", + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceBoolAndFeatCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Bool", BooleanDataViewType.Instance, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""TypeConverting"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Bool"" + ], + ""OutColumns"": [ + ""Bool"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Bool"", + ""Features"" + ], + ""OutColumns"": [ + ""Features"" + 
], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceNumericMissingCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Missing", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, true)), + new DatasetColumnInfo("Numeric", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, false)), + }, @"[ + { + ""Name"": ""MissingValueIndicating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing"" + ], + ""OutColumns"": [ + ""Missing_MissingIndicator"" + ], + ""Properties"": {} + }, + { + ""Name"": ""TypeConverting"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing_MissingIndicator"" + ], + ""OutColumns"": [ + ""Missing_MissingIndicator"" + ], + ""Properties"": {} + }, + { + ""Name"": ""MissingValueReplacing"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing"" + ], + ""OutColumns"": [ + ""Missing"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing_MissingIndicator"", + ""Missing"", + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceNumericMissingCols() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Missing1", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, true)), + new DatasetColumnInfo("Missing2", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, true)), + new DatasetColumnInfo("Numeric", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, false)), + }, @"[ + { + ""Name"": ""MissingValueIndicating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing1"", + ""Missing2"" + ], + ""OutColumns"": [ + ""Missing1_MissingIndicator"", + ""Missing2_MissingIndicator"" + ], + ""Properties"": {} + }, + { + ""Name"": 
""TypeConverting"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing1_MissingIndicator"", + ""Missing2_MissingIndicator"" + ], + ""OutColumns"": [ + ""Missing1_MissingIndicator"", + ""Missing2_MissingIndicator"" + ], + ""Properties"": {} + }, + { + ""Name"": ""MissingValueReplacing"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing1"", + ""Missing2"" + ], + ""OutColumns"": [ + ""Missing1"", + ""Missing2"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing1_MissingIndicator"", + ""Missing2_MissingIndicator"", + ""Missing1"", + ""Missing2"", + ""Numeric"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceIgnoreCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Numeric1", NumberDataViewType.Single, ColumnPurpose.Ignore, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Numeric2", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Numeric2"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformInferenceDefaultLabelCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, new VectorDataViewType(NumberDataViewType.Single), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo(DefaultColumnNames.Label, NumberDataViewType.Single, ColumnPurpose.Label, new ColumnDimensions(null, null)), + }, @"[]"); + } + + [TestMethod] + public void TransformInferenceCustomLabelCol() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, new VectorDataViewType(NumberDataViewType.Single), ColumnPurpose.NumericFeature, new ColumnDimensions(null, 
null)), + new DatasetColumnInfo("CustomLabel", NumberDataViewType.Single, ColumnPurpose.Label, new ColumnDimensions(null, null)), + }, @"[]"); + } + + [TestMethod] + public void TransformInferenceCustomTextLabelColMulticlass() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo(DefaultColumnNames.Features, new VectorDataViewType(NumberDataViewType.Single), ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("CustomLabel", TextDataViewType.Instance, ColumnPurpose.Label, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""ValueToKeyMapping"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""CustomLabel"" + ], + ""OutColumns"": [ + ""CustomLabel"" + ], + ""Properties"": {} + } +]", TaskKind.MulticlassClassification); + } + + [TestMethod] + public void TransformInferenceMissingNameCollision() + { + TransformInferenceTestCore(new[] + { + new DatasetColumnInfo("Missing", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, true)), + new DatasetColumnInfo("Missing_MissingIndicator", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, false)), + new DatasetColumnInfo("Missing_MissingIndicator0", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, false)), + }, @"[ + { + ""Name"": ""MissingValueIndicating"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing"" + ], + ""OutColumns"": [ + ""Missing_MissingIndicator1"" + ], + ""Properties"": {} + }, + { + ""Name"": ""TypeConverting"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing_MissingIndicator1"" + ], + ""OutColumns"": [ + ""Missing_MissingIndicator1"" + ], + ""Properties"": {} + }, + { + ""Name"": ""MissingValueReplacing"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""Missing"" + ], + ""OutColumns"": [ + ""Missing"" + ], + ""Properties"": {} + }, + { + ""Name"": ""ColumnConcatenating"", + ""NodeType"": ""Transform"", + 
""InColumns"": [ + ""Missing_MissingIndicator1"", + ""Missing"", + ""Missing_MissingIndicator"", + ""Missing_MissingIndicator0"" + ], + ""OutColumns"": [ + ""Features"" + ], + ""Properties"": {} + } +]"); + } + + private static void TransformInferenceTestCore( + DatasetColumnInfo[] columns, + string expectedJson, + TaskKind task = TaskKind.BinaryClassification) + { + var transforms = TransformInferenceApi.InferTransforms(new MLContext(), task, columns); + TestApplyTransformsToRealDataView(transforms, columns); + var pipelineNodes = transforms.Select(t => t.PipelineNode); + Util.AssertObjectMatchesJson(expectedJson, pipelineNodes); + } + + private static void TestApplyTransformsToRealDataView(IEnumerable transforms, + IEnumerable columns) + { + // create a dummy data view from input columns + var data = BuildDummyDataView(columns); + + // iterate thru suggested transforms and apply it to a real data view + foreach (var transform in transforms.Select(t => t.Estimator)) + { + data = transform.Fit(data).Transform(data); + } + + // assert Features column of type 'R4' exists + var featuresCol = data.Schema.GetColumnOrNull(DefaultColumnNames.Features); + Assert.IsNotNull(featuresCol); + Assert.AreEqual(true, featuresCol.Value.Type.IsVector()); + Assert.AreEqual(NumberDataViewType.Single, featuresCol.Value.Type.GetItemType()); + } + + private static IDataView BuildDummyDataView(IEnumerable columns) + { + return BuildDummyDataView(columns.Select(c => (c.Name, c.Type))); + } + + private static IDataView BuildDummyDataView(IEnumerable<(string name, DataViewType type)> columns) + { + var dataBuilder = new ArrayDataViewBuilder(new MLContext()); + foreach(var column in columns) + { + if (column.type == NumberDataViewType.Single) + { + dataBuilder.AddColumn(column.name, NumberDataViewType.Single, new float[] { 0 }); + } + else if (column.type == BooleanDataViewType.Instance) + { + dataBuilder.AddColumn(column.name, BooleanDataViewType.Instance, new bool[] { false }); + } + else 
if (column.type == TextDataViewType.Instance) + { + dataBuilder.AddColumn(column.name, new string[] { "a" }); + } + else if (column.type.IsVector() && column.type.GetItemType() == NumberDataViewType.Single) + { + dataBuilder.AddColumn(column.name, Util.GetKeyValueGetter(new[] { "1", "2" }), + NumberDataViewType.Single, new float[] { 0, 0 }); + } + } + return dataBuilder.GetDataView(); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/TransformPostTrainerInferenceTests.cs b/test/Microsoft.ML.AutoML.Tests/TransformPostTrainerInferenceTests.cs new file mode 100644 index 0000000000..1099d58f58 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/TransformPostTrainerInferenceTests.cs @@ -0,0 +1,70 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class TransformPostTrainerInferenceTests + { + [TestMethod] + public void TransformPostTrainerMulticlassNonKeyLabel() + { + TransformPostTrainerInferenceTestCore(TaskKind.MulticlassClassification, + new[] + { + new DatasetColumnInfo("Numeric1", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Label", NumberDataViewType.Single, ColumnPurpose.Label, new ColumnDimensions(null, null)), + }, @"[ + { + ""Name"": ""KeyToValueMapping"", + ""NodeType"": ""Transform"", + ""InColumns"": [ + ""PredictedLabel"" + ], + ""OutColumns"": [ + ""PredictedLabel"" + ], + ""Properties"": {} + } +]"); + } + + [TestMethod] + public void TransformPostTrainerBinaryLabel() + { + TransformPostTrainerInferenceTestCore(TaskKind.BinaryClassification, + new[] + { + new DatasetColumnInfo("Numeric1", NumberDataViewType.Single, 
ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Label", NumberDataViewType.Single, ColumnPurpose.Label, new ColumnDimensions(null, null)), + }, @"[]"); + } + + [TestMethod] + public void TransformPostTrainerMulticlassKeyLabel() + { + TransformPostTrainerInferenceTestCore(TaskKind.MulticlassClassification, + new[] + { + new DatasetColumnInfo("Numeric1", NumberDataViewType.Single, ColumnPurpose.NumericFeature, new ColumnDimensions(null, null)), + new DatasetColumnInfo("Label", new KeyDataViewType(typeof(uint), 3), ColumnPurpose.Label, new ColumnDimensions(null, null)), + }, @"[]"); + } + + private static void TransformPostTrainerInferenceTestCore( + TaskKind task, + DatasetColumnInfo[] columns, + string expectedJson) + { + var transforms = TransformInferenceApi.InferTransformsPostTrainer(new MLContext(), task, columns); + var pipelineNodes = transforms.Select(t => t.PipelineNode); + Util.AssertObjectMatchesJson(expectedJson, pipelineNodes); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs new file mode 100644 index 0000000000..e8962c484c --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/UserInputValidationTests.cs @@ -0,0 +1,187 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System; +using System.IO; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Microsoft.ML.Auto.Test +{ + [TestClass] + public class UserInputValidationTests + { + private static readonly IDataView Data = DatasetUtil.GetUciAdultDataView(); + + [TestMethod] + [ExpectedException(typeof(ArgumentNullException))] + public void ValidateExperimentExecuteNullTrainData() + { + UserInputValidationUtil.ValidateExperimentExecuteArgs(null, new ColumnInformation(), null); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteNullLabel() + { + UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, + new ColumnInformation() { LabelColumnName = null }, null); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteLabelNotInTrain() + { + UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, + new ColumnInformation() { LabelColumnName = "L" }, null); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteNumericColNotInTrain() + { + var columnInfo = new ColumnInformation(); + columnInfo.NumericColumnNames.Add("N"); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteNullNumericCol() + { + var columnInfo = new ColumnInformation(); + columnInfo.NumericColumnNames.Add(null); + UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteDuplicateCol() + { + var columnInfo = new ColumnInformation(); + columnInfo.NumericColumnNames.Add(DefaultColumnNames.Label); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(Data, columnInfo, null); + } + + [TestMethod] + 
[ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteArgsTrainValidColCountMismatch() + { + var context = new MLContext(); + + var trainDataBuilder = new ArrayDataViewBuilder(context); + trainDataBuilder.AddColumn("0", new string[] { "0" }); + trainDataBuilder.AddColumn("1", new string[] { "1" }); + var trainData = trainDataBuilder.GetDataView(); + + var validDataBuilder = new ArrayDataViewBuilder(context); + validDataBuilder.AddColumn("0", new string[] { "0" }); + var validData = validDataBuilder.GetDataView(); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, + new ColumnInformation() { LabelColumnName = "0" }, validData); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteArgsTrainValidColNamesMismatch() + { + var context = new MLContext(); + + var trainDataBuilder = new ArrayDataViewBuilder(context); + trainDataBuilder.AddColumn("0", new string[] { "0" }); + trainDataBuilder.AddColumn("1", new string[] { "1" }); + var trainData = trainDataBuilder.GetDataView(); + + var validDataBuilder = new ArrayDataViewBuilder(context); + validDataBuilder.AddColumn("0", new string[] { "0" }); + validDataBuilder.AddColumn("2", new string[] { "2" }); + var validData = validDataBuilder.GetDataView(); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, + new ColumnInformation() { LabelColumnName = "0" }, validData); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateExperimentExecuteArgsTrainValidColTypeMismatch() + { + var context = new MLContext(); + + var trainDataBuilder = new ArrayDataViewBuilder(context); + trainDataBuilder.AddColumn("0", new string[] { "0" }); + trainDataBuilder.AddColumn("1", new string[] { "1" }); + var trainData = trainDataBuilder.GetDataView(); + + var validDataBuilder = new ArrayDataViewBuilder(context); + validDataBuilder.AddColumn("0", new string[] { "0" }); + 
validDataBuilder.AddColumn("1", NumberDataViewType.Single, new float[] { 1 }); + var validData = validDataBuilder.GetDataView(); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, + new ColumnInformation() { LabelColumnName = "0" }, validData); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentNullException))] + public void ValidateInferColumnsArgsNullPath() + { + UserInputValidationUtil.ValidateInferColumnsArgs(null, "Label"); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateInferColumnsArgsPathNotExist() + { + UserInputValidationUtil.ValidateInferColumnsArgs("idontexist", "Label"); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateInferColumnsArgsEmptyFile() + { + const string emptyFilePath = "empty"; + File.Create(emptyFilePath).Dispose(); + UserInputValidationUtil.ValidateInferColumnsArgs(emptyFilePath, "Label"); + } + + [TestMethod] + public void ValidateInferColsPath() + { + UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.DownloadUciAdultDataset()); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateFeaturesColInvalidType() + { + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Double); + schemaBuilder.AddColumn(DefaultColumnNames.Label, NumberDataViewType.Single); + var schema = schemaBuilder.ToSchema(); + var dataView = new EmptyDataView(new MLContext(), schema); + UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, new ColumnInformation(), null); + } + + [TestMethod] + [ExpectedException(typeof(ArgumentException))] + public void ValidateTextColumnNotText() + { + const string TextPurposeColName = "TextColumn"; + var schemaBuilder = new DataViewSchema.Builder(); + schemaBuilder.AddColumn(DefaultColumnNames.Features, NumberDataViewType.Single); + schemaBuilder.AddColumn(DefaultColumnNames.Label, 
NumberDataViewType.Single); + schemaBuilder.AddColumn(TextPurposeColName, NumberDataViewType.Double); + var schema = schemaBuilder.ToSchema(); + var dataView = new EmptyDataView(new MLContext(), schema); + + var columnInfo = new ColumnInformation(); + columnInfo.NumericColumnNames.Add(TextPurposeColName); + + UserInputValidationUtil.ValidateExperimentExecuteArgs(dataView, columnInfo, null); + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/Util.cs b/test/Microsoft.ML.AutoML.Tests/Util.cs new file mode 100644 index 0000000000..3a652c48ca --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/Util.cs @@ -0,0 +1,37 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; + +namespace Microsoft.ML.Auto.Test +{ + internal static class Util + { + public static void AssertObjectMatchesJson(string expectedJson, T obj) + { + var actualJson = JsonConvert.SerializeObject(obj, + Formatting.Indented, new JsonConverter[] { new StringEnumConverter() }); + Assert.AreEqual(expectedJson, actualJson); + } + + public static ValueGetter>> GetKeyValueGetter(IEnumerable colNames) + { + return (ref VBuffer> dst) => + { + var editor = VBufferEditor.Create(ref dst, colNames.Count()); + for (int i = 0; i < colNames.Count(); i++) + { + editor.Values[i] = colNames.ElementAt(i).AsMemory(); + } + dst = editor.Commit(); + }; + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/Utils/MLNetUtils/EmptyDataView.cs b/test/Microsoft.ML.AutoML.Tests/Utils/MLNetUtils/EmptyDataView.cs new file mode 100644 index 0000000000..bc9bd3cca6 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/Utils/MLNetUtils/EmptyDataView.cs @@ -0,0 +1,81 @@ +// Licensed 
to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Data; +using Microsoft.ML.Runtime; + +namespace Microsoft.ML.Auto.Test +{ + /// + /// This implements a data view that has a schema, but no rows. + /// + internal sealed class EmptyDataView : IDataView + { + private readonly IHost _host; + + public bool CanShuffle => true; + public DataViewSchema Schema { get; } + + public EmptyDataView(IHostEnvironment env, DataViewSchema schema) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(EmptyDataView)); + _host.CheckValue(schema, nameof(schema)); + Schema = schema; + } + + public long? GetRowCount() => 0; + + public DataViewRowCursor GetRowCursor(IEnumerable columnsNeeded, Random rand = null) + { + return new Cursor(_host, Schema, columnsNeeded); + } + + public DataViewRowCursor[] GetRowCursorSet(IEnumerable columnsNeeded, int n, Random rand = null) + { + return new[] { new Cursor(_host, Schema, columnsNeeded) }; + } + + private sealed class Cursor : RootCursorBase + { + private readonly bool[] _active; + + public override DataViewSchema Schema { get; } + public override long Batch => 0; + + public Cursor(IChannelProvider provider, DataViewSchema schema, IEnumerable columnsNeeded) + : base(provider) + { + Schema = schema; + _active = MLNetUtils.BuildArray(Schema.Count, columnsNeeded); + } + + public override ValueGetter GetIdGetter() + { + return (ref DataViewRowId val) => throw Ch.Except(RowCursorUtils.FetchValueStateError); + } + + protected override bool MoveNextCore() => false; + + /// + /// Returns whether the given column is active in this row. 
+ /// + public override bool IsColumnActive(DataViewSchema.Column column) => column.Index < _active.Length && _active[column.Index]; + + /// + /// Returns a value getter delegate to fetch the value of column with the given columnIndex, from the row. + /// This throws if the column is not active in this row, or if the type + /// differs from this column's type. + /// + /// is the column's content type. + /// is the output column whose getter should be returned. + public override ValueGetter GetGetter(DataViewSchema.Column column) + { + return (ref TValue value) => throw Ch.Except(RowCursorUtils.FetchValueStateError); + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.ML.AutoML.Tests/Utils/MLNetUtils/MLNetUtils.cs b/test/Microsoft.ML.AutoML.Tests/Utils/MLNetUtils/MLNetUtils.cs new file mode 100644 index 0000000000..c39f8ae195 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/Utils/MLNetUtils/MLNetUtils.cs @@ -0,0 +1,25 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Generic; + +namespace Microsoft.ML.Auto.Test +{ + internal static class MLNetUtils + { + public static bool[] BuildArray(int length, IEnumerable columnsNeeded) + { + Contracts.CheckParam(length >= 0, nameof(length)); + + var result = new bool[length]; + foreach (var col in columnsNeeded) + { + if (col.Index < result.Length) + result[col.Index] = true; + } + + return result; + } + } +} diff --git a/test/Microsoft.ML.AutoML.Tests/run-tests.proj b/test/Microsoft.ML.AutoML.Tests/run-tests.proj new file mode 100644 index 0000000000..dd2433b3c5 --- /dev/null +++ b/test/Microsoft.ML.AutoML.Tests/run-tests.proj @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt new file mode 100644 index 0000000000..e25d5e46a3 --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentBinaryTest.approved.txt @@ -0,0 +1,149 @@ +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. 
* +//* * +//***************************************************************************************** + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; +using TestNamespace.Model.DataModels; + +namespace TestNamespace.ConsoleApp +{ + public static class ModelBuilder + { + private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv"; + private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv"; + private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip"; + + // Create MLContext to be shared across the model creation workflow objects + // Set a random seed for repeatable/deterministic results across multiple trainings. + private static MLContext mlContext = new MLContext(seed: 1); + + public static void CreateModel() + { + // Load Data + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + path: TRAIN_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + + IDataView testDataView = mlContext.Data.LoadFromTextFile( + path: TEST_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + // Build training pipeline + IEstimator trainingPipeline = BuildTrainingPipeline(mlContext); + + // Train Model + ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline); + + // Evaluate quality of Model + EvaluateModel(mlContext, mlModel, testDataView); + + // Save model + SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema); + } + + public static IEstimator BuildTrainingPipeline(MLContext mlContext) + { + // Data process configuration with pipeline data transformations + var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" }) + .AppendCacheCheckpoint(mlContext); + + // Set the training algorithm + var trainer = mlContext.BinaryClassification.Trainers.LightGbm(labelColumnName: 
"Label", featureColumnName: "Features"); + var trainingPipeline = dataProcessPipeline.Append(trainer); + + return trainingPipeline; + } + + public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + Console.WriteLine("=============== Training model ==============="); + + ITransformer model = trainingPipeline.Fit(trainingDataView); + + Console.WriteLine("=============== End of training process ==============="); + return model; + } + + private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) + { + // Evaluate the model and show accuracy stats + Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); + IDataView predictions = mlModel.Transform(testDataView); + var metrics = mlContext.BinaryClassification.EvaluateNonCalibrated(predictions, "Label", "Score"); + PrintBinaryClassificationMetrics(metrics); + } + private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema) + { + // Save/persist the trained model to a .ZIP file + Console.WriteLine($"=============== Saving the model ==============="); + using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) + mlContext.Model.Save(mlModel, modelInputSchema, fs); + + Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + + public static void PrintBinaryClassificationMetrics(BinaryClassificationMetrics metrics) + { + Console.WriteLine($"************************************************************"); + Console.WriteLine($"* Metrics for binary 
classification model "); + Console.WriteLine($"*-----------------------------------------------------------"); + Console.WriteLine($"* Accuracy: {metrics.Accuracy:P2}"); + Console.WriteLine($"* Auc: {metrics.AreaUnderRocCurve:P2}"); + Console.WriteLine($"************************************************************"); + } + + + public static void PrintBinaryClassificationFoldsAverageMetrics(IEnumerable> crossValResults) + { + var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); + + var AccuracyValues = metricsInMultipleFolds.Select(m => m.Accuracy); + var AccuracyAverage = AccuracyValues.Average(); + var AccuraciesStdDeviation = CalculateStandardDeviation(AccuracyValues); + var AccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(AccuracyValues); + + + Console.WriteLine($"*************************************************************************************************************"); + Console.WriteLine($"* Metrics for Binary Classification model "); + Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); + Console.WriteLine($"* Average Accuracy: {AccuracyAverage:0.###} - Standard deviation: ({AccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({AccuraciesConfidenceInterval95:#.###})"); + Console.WriteLine($"*************************************************************************************************************"); + } + + public static double CalculateStandardDeviation(IEnumerable values) + { + double average = values.Average(); + double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum(); + double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1)); + return standardDeviation; + } + + public static double CalculateConfidenceInterval95(IEnumerable values) + { + double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1)); + return 
confidenceInterval95; + } + } +} diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt new file mode 100644 index 0000000000..36b7deff19 --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentOvaTest.approved.txt @@ -0,0 +1,171 @@ +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; +using TestNamespace.Model.DataModels; + +namespace TestNamespace.ConsoleApp +{ + public static class ModelBuilder + { + private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv"; + private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv"; + private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip"; + + // Create MLContext to be shared across the model creation workflow objects + // Set a random seed for repeatable/deterministic results across multiple trainings. 
+ private static MLContext mlContext = new MLContext(seed: 1); + + public static void CreateModel() + { + // Load Data + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + path: TRAIN_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + + IDataView testDataView = mlContext.Data.LoadFromTextFile( + path: TEST_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + // Build training pipeline + IEstimator trainingPipeline = BuildTrainingPipeline(mlContext); + + // Train Model + ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline); + + // Evaluate quality of Model + EvaluateModel(mlContext, mlModel, testDataView); + + // Save model + SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema); + } + + public static IEstimator BuildTrainingPipeline(MLContext mlContext) + { + // Data process configuration with pipeline data transformations + var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" }) + .AppendCacheCheckpoint(mlContext); + + // Set the training algorithm + var trainer = mlContext.MulticlassClassification.Trainers.OneVersusAll(mlContext.BinaryClassification.Trainers.FastForest(labelColumnName: "Label", featureColumnName: "Features"), labelColumnName: "Label"); + var trainingPipeline = dataProcessPipeline.Append(trainer); + + return trainingPipeline; + } + + public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + Console.WriteLine("=============== Training model ==============="); + + ITransformer model = trainingPipeline.Fit(trainingDataView); + + Console.WriteLine("=============== End of training process ==============="); + return model; + } + + private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) + { + // Evaluate the model and show accuracy stats + Console.WriteLine("===== 
Evaluating Model's accuracy with Test data ====="); + IDataView predictions = mlModel.Transform(testDataView); + var metrics = mlContext.MulticlassClassification.Evaluate(predictions, "Label", "Score"); + PrintMulticlassClassificationMetrics(metrics); + } + private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema) + { + // Save/persist the trained model to a .ZIP file + Console.WriteLine($"=============== Saving the model ==============="); + using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) + mlContext.Model.Save(mlModel, modelInputSchema, fs); + + Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + + public static void PrintMulticlassClassificationMetrics(MulticlassClassificationMetrics metrics) + { + Console.WriteLine($"************************************************************"); + Console.WriteLine($"* Metrics for multi-class classification model "); + Console.WriteLine($"*-----------------------------------------------------------"); + Console.WriteLine($" MacroAccuracy = {metrics.MacroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" MicroAccuracy = {metrics.MicroAccuracy:0.####}, a value between 0 and 1, the closer to 1, the better"); + Console.WriteLine($" LogLoss = {metrics.LogLoss:0.####}, the closer to 0, the better"); + for (int i = 0; i < metrics.PerClassLogLoss.Count; i++) + { + Console.WriteLine($" LogLoss for class {i + 1} = {metrics.PerClassLogLoss[i]:0.####}, the closer to 0, the better"); + } + 
Console.WriteLine($"************************************************************"); + } + + public static void PrintMulticlassClassificationFoldsAverageMetrics(IEnumerable> crossValResults) + { + var metricsInMultipleFolds = crossValResults.Select(r => r.Metrics); + + var microAccuracyValues = metricsInMultipleFolds.Select(m => m.MicroAccuracy); + var microAccuracyAverage = microAccuracyValues.Average(); + var microAccuraciesStdDeviation = CalculateStandardDeviation(microAccuracyValues); + var microAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(microAccuracyValues); + + var macroAccuracyValues = metricsInMultipleFolds.Select(m => m.MacroAccuracy); + var macroAccuracyAverage = macroAccuracyValues.Average(); + var macroAccuraciesStdDeviation = CalculateStandardDeviation(macroAccuracyValues); + var macroAccuraciesConfidenceInterval95 = CalculateConfidenceInterval95(macroAccuracyValues); + + var logLossValues = metricsInMultipleFolds.Select(m => m.LogLoss); + var logLossAverage = logLossValues.Average(); + var logLossStdDeviation = CalculateStandardDeviation(logLossValues); + var logLossConfidenceInterval95 = CalculateConfidenceInterval95(logLossValues); + + var logLossReductionValues = metricsInMultipleFolds.Select(m => m.LogLossReduction); + var logLossReductionAverage = logLossReductionValues.Average(); + var logLossReductionStdDeviation = CalculateStandardDeviation(logLossReductionValues); + var logLossReductionConfidenceInterval95 = CalculateConfidenceInterval95(logLossReductionValues); + + Console.WriteLine($"*************************************************************************************************************"); + Console.WriteLine($"* Metrics for Multi-class Classification model "); + Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); + Console.WriteLine($"* Average MicroAccuracy: {microAccuracyAverage:0.###} - Standard deviation: 
({microAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({microAccuraciesConfidenceInterval95:#.###})"); + Console.WriteLine($"* Average MacroAccuracy: {macroAccuracyAverage:0.###} - Standard deviation: ({macroAccuraciesStdDeviation:#.###}) - Confidence Interval 95%: ({macroAccuraciesConfidenceInterval95:#.###})"); + Console.WriteLine($"* Average LogLoss: {logLossAverage:#.###} - Standard deviation: ({logLossStdDeviation:#.###}) - Confidence Interval 95%: ({logLossConfidenceInterval95:#.###})"); + Console.WriteLine($"* Average LogLossReduction: {logLossReductionAverage:#.###} - Standard deviation: ({logLossReductionStdDeviation:#.###}) - Confidence Interval 95%: ({logLossReductionConfidenceInterval95:#.###})"); + Console.WriteLine($"*************************************************************************************************************"); + + } + + public static double CalculateStandardDeviation(IEnumerable values) + { + double average = values.Average(); + double sumOfSquaresOfDifferences = values.Select(val => (val - average) * (val - average)).Sum(); + double standardDeviation = Math.Sqrt(sumOfSquaresOfDifferences / (values.Count() - 1)); + return standardDeviation; + } + + public static double CalculateConfidenceInterval95(IEnumerable values) + { + double confidenceInterval95 = 1.96 * CalculateStandardDeviation(values) / Math.Sqrt((values.Count() - 1)); + return confidenceInterval95; + } + } +} diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt new file mode 100644 index 0000000000..122634ff08 --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppModelBuilderCSFileContentRegressionTest.approved.txt @@ -0,0 +1,139 @@ +//***************************************************************************************** +//* * 
+//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Data; +using TestNamespace.Model.DataModels; + +namespace TestNamespace.ConsoleApp +{ + public static class ModelBuilder + { + private static string TRAIN_DATA_FILEPATH = @"x:\dummypath\dummy_train.csv"; + private static string TEST_DATA_FILEPATH = @"x:\dummypath\dummy_test.csv"; + private static string MODEL_FILEPATH = @"../../../../TestNamespace.Model/MLModel.zip"; + + // Create MLContext to be shared across the model creation workflow objects + // Set a random seed for repeatable/deterministic results across multiple trainings. + private static MLContext mlContext = new MLContext(seed: 1); + + public static void CreateModel() + { + // Load Data + IDataView trainingDataView = mlContext.Data.LoadFromTextFile( + path: TRAIN_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + + IDataView testDataView = mlContext.Data.LoadFromTextFile( + path: TEST_DATA_FILEPATH, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + // Build training pipeline + IEstimator trainingPipeline = BuildTrainingPipeline(mlContext); + + // Train Model + ITransformer mlModel = TrainModel(mlContext, trainingDataView, trainingPipeline); + + // Evaluate quality of Model + EvaluateModel(mlContext, mlModel, testDataView); + + // Save model + SaveModel(mlContext, mlModel, MODEL_FILEPATH, trainingDataView.Schema); + } + + public static IEstimator BuildTrainingPipeline(MLContext mlContext) + { + // Data process configuration with pipeline data transformations + var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" }) + .AppendCacheCheckpoint(mlContext); + + // Set the training 
algorithm + var trainer = mlContext.Regression.Trainers.LightGbm(labelColumnName: "Label", featureColumnName: "Features"); + var trainingPipeline = dataProcessPipeline.Append(trainer); + + return trainingPipeline; + } + + public static ITransformer TrainModel(MLContext mlContext, IDataView trainingDataView, IEstimator trainingPipeline) + { + Console.WriteLine("=============== Training model ==============="); + + ITransformer model = trainingPipeline.Fit(trainingDataView); + + Console.WriteLine("=============== End of training process ==============="); + return model; + } + + private static void EvaluateModel(MLContext mlContext, ITransformer mlModel, IDataView testDataView) + { + // Evaluate the model and show accuracy stats + Console.WriteLine("===== Evaluating Model's accuracy with Test data ====="); + IDataView predictions = mlModel.Transform(testDataView); + var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score"); + PrintRegressionMetrics(metrics); + } + private static void SaveModel(MLContext mlContext, ITransformer mlModel, string modelRelativePath, DataViewSchema modelInputSchema) + { + // Save/persist the trained model to a .ZIP file + Console.WriteLine($"=============== Saving the model ==============="); + using (var fs = new FileStream(GetAbsolutePath(modelRelativePath), FileMode.Create, FileAccess.Write, FileShare.Write)) + mlContext.Model.Save(mlModel, modelInputSchema, fs); + + Console.WriteLine("The model is saved to {0}", GetAbsolutePath(modelRelativePath)); + } + + public static string GetAbsolutePath(string relativePath) + { + FileInfo _dataRoot = new FileInfo(typeof(Program).Assembly.Location); + string assemblyFolderPath = _dataRoot.Directory.FullName; + + string fullPath = Path.Combine(assemblyFolderPath, relativePath); + + return fullPath; + } + + public static void PrintRegressionMetrics(RegressionMetrics metrics) + { + Console.WriteLine($"*************************************************"); + Console.WriteLine($"* 
Metrics for regression model "); + Console.WriteLine($"*------------------------------------------------"); + Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}"); + Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}"); + Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}"); + Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}"); + Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}"); + Console.WriteLine($"*************************************************"); + } + + public static void PrintRegressionFoldsAverageMetrics(IEnumerable> crossValidationResults) + { + var L1 = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError); + var L2 = crossValidationResults.Select(r => r.Metrics.MeanSquaredError); + var RMS = crossValidationResults.Select(r => r.Metrics.MeanAbsoluteError); + var lossFunction = crossValidationResults.Select(r => r.Metrics.LossFunction); + var R2 = crossValidationResults.Select(r => r.Metrics.RSquared); + + Console.WriteLine($"*************************************************************************************************************"); + Console.WriteLine($"* Metrics for Regression model "); + Console.WriteLine($"*------------------------------------------------------------------------------------------------------------"); + Console.WriteLine($"* Average L1 Loss: {L1.Average():0.###} "); + Console.WriteLine($"* Average L2 Loss: {L2.Average():0.###} "); + Console.WriteLine($"* Average RMS: {RMS.Average():0.###} "); + Console.WriteLine($"* Average Loss Function: {lossFunction.Average():0.###} "); + Console.WriteLine($"* Average R-squared: {R2.Average():0.###} "); + Console.WriteLine($"*************************************************************************************************************"); + } + } +} diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt 
b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt new file mode 100644 index 0000000000..34ce3713fa --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProgramCSFileContentTest.approved.txt @@ -0,0 +1,66 @@ +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using System; +using System.IO; +using System.Linq; +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Data; +using TestNamespace.Model.DataModels; + + +namespace TestNamespace.ConsoleApp +{ + class Program + { + //Machine Learning model to load and use for predictions + private const string MODEL_FILEPATH = @"MLModel.zip"; + + //Dataset to use for predictions + private const string DATA_FILEPATH = @"x:\dummypath\dummy_test.csv"; + + static void Main(string[] args) + { + MLContext mlContext = new MLContext(); + + // Training code used by ML.NET CLI and AutoML to generate the model + //ModelBuilder.CreateModel(); + + ITransformer mlModel = mlContext.Model.Load(MODEL_FILEPATH, out DataViewSchema inputSchema); + var predEngine = mlContext.Model.CreatePredictionEngine(mlModel); + + // Create sample data to do a single prediction with it + SampleObservation sampleData = CreateSingleDataSample(mlContext, DATA_FILEPATH); + + // Try a single prediction + SamplePrediction predictionResult = predEngine.Predict(sampleData); + + Console.WriteLine($"Single Prediction --> Actual value: {sampleData.Label} | Predicted value: {predictionResult.Prediction}"); + + Console.WriteLine("=============== End of process, hit any key to finish ==============="); + Console.ReadKey(); + } + + // Method to load single row of data to try a single prediction + // You can change this code 
and create your own sample data here (Hardcoded or from any source) + private static SampleObservation CreateSingleDataSample(MLContext mlContext, string dataFilePath) + { + // Read dataset to get a single row for trying a prediction + IDataView dataView = mlContext.Data.LoadFromTextFile( + path: dataFilePath, + hasHeader: true, + separatorChar: ',', + allowQuoting: true, + allowSparse: true); + + // Here (SampleObservation object) you could provide new test data, hardcoded or from the end-user application, instead of the row from the file. + SampleObservation sampleForPrediction = mlContext.Data.CreateEnumerable(dataView, false) + .First(); + return sampleForPrediction; + } + } +} diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProjectFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProjectFileContentTest.approved.txt new file mode 100644 index 0000000000..789c3638c0 --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ConsoleAppProjectFileContentTest.approved.txt @@ -0,0 +1,15 @@ + + + + Exe + netcoreapp2.1 + + + + + + + + + + diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt new file mode 100644 index 0000000000..8f1acbadb8 --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ModelProjectFileContentTest.approved.txt @@ -0,0 +1,21 @@ + + + + netcoreapp2.1 + + + + https://api.nuget.org/v3/index.json; + + + + + + + + + PreserveNewest + + + + diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt new file mode 100644 index 0000000000..12f935ee2a --- /dev/null +++ 
b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.ObservationCSFileContentTest.approved.txt @@ -0,0 +1,38 @@ +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using Microsoft.ML.Data; + +namespace TestNamespace.Model.DataModels +{ + public class SampleObservation + { + [ColumnName("Label"), LoadColumn(0)] + public bool Label { get; set; } + + + [ColumnName("col1"), LoadColumn(1)] + public float Col1 { get; set; } + + + [ColumnName("col2"), LoadColumn(0)] + public float Col2 { get; set; } + + + [ColumnName("col3"), LoadColumn(0)] + public string Col3 { get; set; } + + + [ColumnName("col4"), LoadColumn(0)] + public int Col4 { get; set; } + + + [ColumnName("col5"), LoadColumn(0)] + public uint Col5 { get; set; } + + + } +} diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt new file mode 100644 index 0000000000..4e0a7e5b9c --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.PredictionCSFileContentTest.approved.txt @@ -0,0 +1,21 @@ +//***************************************************************************************** +//* * +//* This is an auto-generated file by Microsoft ML.NET CLI (Command-Line Interface) tool. * +//* * +//***************************************************************************************** + +using System; +using Microsoft.ML.Data; + +namespace TestNamespace.Model.DataModels +{ + public class SamplePrediction + { + // ColumnName attribute is used to change the column name from + // its default value, which is the name of the field. 
+ [ColumnName("PredictedLabel")] + public bool Prediction { get; set; } + + public float Score { get; set; } + } +} diff --git a/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs new file mode 100644 index 0000000000..d902ff9a64 --- /dev/null +++ b/test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs @@ -0,0 +1,324 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using ApprovalTests; +using ApprovalTests.Reporters; +using Microsoft.ML; +using Microsoft.ML.Auto; +using Microsoft.ML.CLI.CodeGenerator.CSharp; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace mlnet.Tests +{ + [TestClass] + [UseReporter(typeof(DiffReporter))] + public class ConsoleCodeGeneratorTests + { + private Pipeline mockedPipeline; + private Pipeline mockedOvaPipeline; + private ColumnInferenceResults columnInference = default; + private string namespaceValue = "TestNamespace"; + + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ConsoleAppModelBuilderCSFileContentOvaTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedOvaPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.MulticlassClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + + 
Approvals.Verify(result.modelBuilderCSFileContent); + } + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ConsoleAppModelBuilderCSFileContentBinaryTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedBinaryPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.BinaryClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.modelBuilderCSFileContent); + } + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ConsoleAppModelBuilderCSFileContentRegressionTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedRegressionPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.Regression, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.modelBuilderCSFileContent); + } + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ModelProjectFileContentTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedBinaryPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + 
MlTask = TaskKind.BinaryClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.ModelProjectFileContent); + } + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ObservationCSFileContentTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedBinaryPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.BinaryClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.ObservationCSFileContent); + } + + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void PredictionCSFileContentTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedBinaryPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.BinaryClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.PredictionCSFileContent); + } + + [TestMethod] + 
[UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ConsoleAppProgramCSFileContentTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedBinaryPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.BinaryClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.ConsoleAppProgramCSFileContent); + } + + [TestMethod] + [UseReporter(typeof(DiffReporter))] + [MethodImpl(MethodImplOptions.NoInlining)] + public void ConsoleAppProjectFileContentTest() + { + (Pipeline pipeline, + ColumnInferenceResults columnInference) = GetMockedBinaryPipelineAndInference(); + + var consoleCodeGen = new CodeGenerator(pipeline, columnInference, new CodeGeneratorSettings() + { + MlTask = TaskKind.BinaryClassification, + OutputBaseDir = null, + OutputName = "MyNamespace", + TrainDataset = "x:\\dummypath\\dummy_train.csv", + TestDataset = "x:\\dummypath\\dummy_test.csv", + LabelName = "Label", + ModelPath = "x:\\models\\model.zip" + }); + var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float)); + + Approvals.Verify(result.ConsoleAppProjectFileContent); + } + + private (Pipeline, ColumnInferenceResults) GetMockedBinaryPipelineAndInference() + { + if (mockedPipeline == null) + { + MLContext context = new MLContext(); + // same learners with different hyperparams + var hyperparams1 = new Microsoft.ML.Auto.ParameterSet(new List() { new LongParameterValue("NumLeaves", 2) }); + var trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), new ColumnInformation(), hyperparams1); + var 
transforms1 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + var inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, true); + + this.mockedPipeline = inferredPipeline1.ToPipeline(); + var textLoaderArgs = new TextLoader.Options() + { + Columns = new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("col1", DataKind.Single, 1), + new TextLoader.Column("col2", DataKind.Single, 0), + new TextLoader.Column("col3", DataKind.String, 0), + new TextLoader.Column("col4", DataKind.Int32, 0), + new TextLoader.Column("col5", DataKind.UInt32, 0), + }, + AllowQuoting = true, + AllowSparse = true, + HasHeader = true, + Separators = new[] { ',' } + }; + + this.columnInference = new ColumnInferenceResults() + { + TextLoaderOptions = textLoaderArgs, + ColumnInformation = new ColumnInformation() { LabelColumnName = "Label" } + }; + } + return (mockedPipeline, columnInference); + } + + private (Pipeline, ColumnInferenceResults) GetMockedRegressionPipelineAndInference() + { + if (mockedPipeline == null) + { + MLContext context = new MLContext(); + // same learners with different hyperparams + var hyperparams1 = new Microsoft.ML.Auto.ParameterSet(new List() { new LongParameterValue("NumLeaves", 2) }); + var trainer1 = new SuggestedTrainer(context, new LightGbmRegressionExtension(), new ColumnInformation(), hyperparams1); + var transforms1 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + var inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, true); + + this.mockedPipeline = inferredPipeline1.ToPipeline(); + var textLoaderArgs = new TextLoader.Options() + { + Columns = new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("col1", DataKind.Single, 1), + new TextLoader.Column("col2", DataKind.Single, 0), + new TextLoader.Column("col3", 
DataKind.String, 0), + new TextLoader.Column("col4", DataKind.Int32, 0), + new TextLoader.Column("col5", DataKind.UInt32, 0), + }, + AllowQuoting = true, + AllowSparse = true, + HasHeader = true, + Separators = new[] { ',' } + }; + + this.columnInference = new ColumnInferenceResults() + { + TextLoaderOptions = textLoaderArgs, + ColumnInformation = new ColumnInformation() { LabelColumnName = "Label" } + }; + } + return (mockedPipeline, columnInference); + } + private (Pipeline, ColumnInferenceResults) GetMockedOvaPipelineAndInference() + { + if (mockedOvaPipeline == null) + { + MLContext context = new MLContext(); + // same learners with different hyperparams + var hyperparams1 = new Microsoft.ML.Auto.ParameterSet(new List() { new LongParameterValue("NumLeaves", 2) }); + var trainer1 = new SuggestedTrainer(context, new FastForestOvaExtension(), new ColumnInformation(), hyperparams1); + var transforms1 = new List() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") }; + var inferredPipeline1 = new SuggestedPipeline(transforms1, new List(), trainer1, context, true); + + this.mockedOvaPipeline = inferredPipeline1.ToPipeline(); + var textLoaderArgs = new TextLoader.Options() + { + Columns = new[] { + new TextLoader.Column("Label", DataKind.Boolean, 0), + new TextLoader.Column("col1", DataKind.Single, 1), + new TextLoader.Column("col2", DataKind.Single, 0), + new TextLoader.Column("col3", DataKind.String, 0), + new TextLoader.Column("col4", DataKind.Int32, 0), + new TextLoader.Column("col5", DataKind.UInt32, 0), + }, + AllowQuoting = true, + AllowSparse = true, + HasHeader = true, + Separators = new[] { ',' } + }; + + + this.columnInference = new ColumnInferenceResults() + { + TextLoaderOptions = textLoaderArgs, + ColumnInformation = new ColumnInformation() { LabelColumnName = "Label" } + }; + + } + return (mockedOvaPipeline, columnInference); + } + } +} diff --git a/test/mlnet.Tests/CodeGenTests.cs b/test/mlnet.Tests/CodeGenTests.cs 
new file mode 100644 index 0000000000..2ab1dfdd64 --- /dev/null +++ b/test/mlnet.Tests/CodeGenTests.cs @@ -0,0 +1,137 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML; +using Microsoft.ML.Auto; +using Microsoft.ML.CLI.CodeGenerator.CSharp; +using Microsoft.ML.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace mlnet.Tests +{ + [TestClass] + public class CodeGeneratorTests + { + [TestMethod] + public void TrainerGeneratorBasicNamedParameterTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"LearningRate", 0.1f }, + {"NumLeaves", 1 }, + }; + PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expected = "LightGbm(new LightGbmBinaryTrainer.Options(){LearningRate=0.1f,NumLeaves=1,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expected, actual.Item1); + Assert.AreEqual(1, actual.Item2.Count()); + Assert.AreEqual("using Microsoft.ML.Trainers.LightGbm;\r\n", actual.Item2.First()); + } + + [TestMethod] + public void TrainerGeneratorBasicAdvancedParameterTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"LearningRate", 0.1f }, + {"NumLeaves", 1 }, + {"UseSoftmax", true } + }; + PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new 
CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainer = "LightGbm(new LightGbmBinaryTrainer.Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + string expectedUsing = "using Microsoft.ML.Trainers.LightGbm;\r\n"; + Assert.AreEqual(expectedTrainer, actual.Item1); + Assert.AreEqual(expectedUsing, actual.Item2[0]); + } + + [TestMethod] + public void TransformGeneratorBasicTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("Normalizing", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new List() { node }); + string expected = "NormalizeMinMax(\"Label\",\"Label\")"; + Assert.AreEqual(expected, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void TransformGeneratorUsingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("OneHotEncoding", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new List() { node }); + string expectedTransform = "Categorical.OneHotEncoding(new []{new InputOutputColumnPair(\"Label\",\"Label\")})"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void ClassLabelGenerationBasicTest() + { + var columns = new TextLoader.Column[] + { + new TextLoader.Column(){ Name = 
DefaultColumnNames.Label, Source = new TextLoader.Range[]{new TextLoader.Range(0) }, DataKind = DataKind.Boolean }, + }; + + var result = new ColumnInferenceResults() + { + TextLoaderOptions = new TextLoader.Options() + { + Columns = columns, + AllowQuoting = false, + AllowSparse = false, + Separators = new[] { ',' }, + HasHeader = true, + TrimWhitespace = true + }, + ColumnInformation = new ColumnInformation() + }; + + CodeGenerator codeGenerator = new CodeGenerator(null, result, null); + var actual = codeGenerator.GenerateClassLabels(); + var expected1 = "[ColumnName(\"Label\"), LoadColumn(0)]"; + var expected2 = "public bool Label{get; set;}"; + + Assert.AreEqual(expected1, actual[0]); + Assert.AreEqual(expected2, actual[1]); + } + + [TestMethod] + public void TrainerComplexParameterTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Booster", new CustomProperty(){Properties= new Dictionary(), Name = "TreeBooster"} }, + }; + PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, new string[] { "Label" }, default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainer = "LightGbm(new LightGbmBinaryTrainer.Options(){Booster=new TreeBooster(){},LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + var expectedUsings = "using Microsoft.ML.Trainers.LightGbm;\r\n"; + Assert.AreEqual(expectedTrainer, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + } + } +} diff --git a/test/mlnet.Tests/CommandLineTests.cs b/test/mlnet.Tests/CommandLineTests.cs new file mode 100644 index 0000000000..3530f23271 --- /dev/null +++ b/test/mlnet.Tests/CommandLineTests.cs @@ -0,0 +1,302 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.CommandLine.Builder; +using System.CommandLine.Invocation; +using System.IO; +using System.Linq; +using Microsoft.ML.CLI.Commands; +using Microsoft.ML.CLI.Data; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace mlnet.Tests +{ + [TestClass] + public class CommandLineTests + { + [TestMethod] + public void TestMinimumCommandLineArgs() + { + bool parsingSuccessful = false; + + // Create handler outside so that commandline and the handler is decoupled and testable. + var handler = CommandHandler.Create( + (opt) => + { + parsingSuccessful = true; + }); + + var parser = new CommandLineBuilder() + // Parser + .AddCommand(CommandDefinitions.AutoTrain(handler)) + .UseDefaults() + .UseExceptionHandler((e, ctx) => + { + Console.WriteLine(e.ToString()); + }) + .Build(); + + var trainDataset = Path.GetTempFileName(); + var testDataset = Path.GetTempFileName(); + string[] args = new[] { "auto-train", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", "Label" }; + parser.InvokeAsync(args).Wait(); + File.Delete(trainDataset); + File.Delete(testDataset); + Assert.IsTrue(parsingSuccessful); + } + + + [TestMethod] + public void TestCommandLineArgsFailTest() + { + bool parsingSuccessful = false; + + // Create handler outside so that commandline and the handler is decoupled and testable. 
+ var handler = CommandHandler.Create( + (opt) => + { + parsingSuccessful = true; + }); + + var parser = new CommandLineBuilder() + // parser + .AddCommand(CommandDefinitions.AutoTrain(handler)) + .UseDefaults() + .UseExceptionHandler((e, ctx) => + { + Console.WriteLine(e.ToString()); + }) + .Build(); + + // Incorrect mltask test + var trainDataset = Path.GetTempFileName(); + var testDataset = Path.GetTempFileName(); + + //wrong value to ml-task + string[] args = new[] { "auto-train", "--ml-task", "bad-value", "--train-dataset", trainDataset, "--label-column-name", "Label" }; + parser.InvokeAsync(args).Wait(); + Assert.IsFalse(parsingSuccessful); + + // Incorrect invocation + args = new[] { "auto-train", "binary-classification", "--train-dataset", trainDataset, "--label-column-name", "Label" }; + parser.InvokeAsync(args).Wait(); + Assert.IsFalse(parsingSuccessful); + + // Non-existent file test + args = new[] { "auto-train", "--ml-task", "binary-classification", "--train-dataset", "nonexistentfile.csv", "--label-column-name", "Label" }; + parser.InvokeAsync(args).Wait(); + Assert.IsFalse(parsingSuccessful); + + // No label column or index test + args = new[] { "auto-train", "--ml-task", "binary-classification", "--train-dataset", trainDataset, "--test-dataset", testDataset }; + parser.InvokeAsync(args).Wait(); + File.Delete(trainDataset); + File.Delete(testDataset); + Assert.IsFalse(parsingSuccessful); + } + + [TestMethod] + public void TestCommandLineArgsValuesTest() + { + bool parsingSuccessful = false; + var trainDataset = Path.GetTempFileName(); + var testDataset = Path.GetTempFileName(); + var validDataset = Path.GetTempFileName(); + var labelName = "Label"; + var name = "testname"; + var outputPath = Path.GetTempPath(); + var falseString = "false"; + + // Create handler outside so that commandline and the handler is decoupled and testable. 
// Valid full argument set test
+ var handler = CommandHandler.Create( + (opt) => + { + parsingSuccessful = true; + }); + + var parser = new CommandLineBuilder() + // Parser + .AddCommand(CommandDefinitions.AutoTrain(handler)) + .UseDefaults() + .UseExceptionHandler((e, ctx) => + { + Console.WriteLine(e.ToString()); + }) + .Build(); + + // Incorrect arguments : specifying dataset and train-dataset + string[] args = new[] { "auto-train", "--ml-task", "BinaryClassification", "--dataset", dataset, "--train-dataset", trainDataset, "--label-column-name", labelName, "--test-dataset", testDataset, "--max-exploration-time", "5" }; + parser.InvokeAsync(args).Wait(); + Assert.IsFalse(parsingSuccessful); + + // Incorrect arguments : specifying train-dataset and not specifying test-dataset + args = new[] { "auto-train", "--ml-task", "BinaryClassification", "--train-dataset", trainDataset, "--label-column-name", labelName, "--max-exploration-time", "5" }; + parser.InvokeAsync(args).Wait(); + Assert.IsFalse(parsingSuccessful); + + // Incorrect arguments : specifying label column name and index + args = new[] { "auto-train", "--ml-task", "BinaryClassification", "--train-dataset", trainDataset, "--label-column-name", labelName, "--label-column-index", "0", "--test-dataset", testDataset, "--max-exploration-time", "5" }; + parser.InvokeAsync(args).Wait(); + File.Delete(trainDataset); + File.Delete(testDataset); + File.Delete(dataset); + Assert.IsFalse(parsingSuccessful); + + } + + [TestMethod] + public void CacheArgumentTest() + { + bool parsingSuccessful = false; + var trainDataset = Path.GetTempFileName(); + var testDataset = Path.GetTempFileName(); + var labelName = "Label"; + var cache = "on"; + + // Create handler outside so that commandline and the handler is decoupled and testable. 
+ var handler = CommandHandler.Create( + (opt) => + { + Assert.AreEqual(opt.MlTask, "binary-classification"); + Assert.AreEqual(opt.Dataset.FullName, trainDataset); + Assert.AreEqual(opt.LabelColumnName, labelName); + Assert.AreEqual(opt.Cache, cache); + parsingSuccessful = true; + }); + + var parser = new CommandLineBuilder() + // Parser + .AddCommand(CommandDefinitions.AutoTrain(handler)) + .UseDefaults() + .UseExceptionHandler((e, ctx) => + { + Console.WriteLine(e.ToString()); + }) + .Build(); + + // valid cache test + string[] args = new[] { "auto-train", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", cache }; + parser.InvokeAsync(args).Wait(); + Assert.IsTrue(parsingSuccessful); + + parsingSuccessful = false; + + cache = "off"; + // valid cache test + args = new[] { "auto-train", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", cache }; + parser.InvokeAsync(args).Wait(); + Assert.IsTrue(parsingSuccessful); + + parsingSuccessful = false; + + cache = "auto"; + // valid cache test + args = new[] { "auto-train", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", cache }; + parser.InvokeAsync(args).Wait(); + Assert.IsTrue(parsingSuccessful); + + parsingSuccessful = false; + + // invalid cache test + args = new[] { "auto-train", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", "blah" }; + parser.InvokeAsync(args).Wait(); + Assert.IsFalse(parsingSuccessful); + + File.Delete(trainDataset); + File.Delete(testDataset); + } + + [TestMethod] + public void IgnoreColumnsArgumentTest() + { + bool parsingSuccessful = false; + var trainDataset = Path.GetTempFileName(); + var testDataset = Path.GetTempFileName(); + var labelName = "Label"; + var ignoreColumns = "a,b,c"; + + // Create handler outside so that commandline and the handler is 
// valid ignore-columns test
+ +using System; +using System.IO; +using System.Net; +using Microsoft.ML; +using Microsoft.ML.Auto; + +namespace mlnet.Tests +{ + internal static class DatasetUtil + { + public const string UciAdultLabel = DefaultColumnNames.Label; + public const string TrivialDatasetLabel = DefaultColumnNames.Label; + public const string MlNetGeneratedRegressionLabel = "target"; + public const int IrisDatasetLabelColIndex = 0; + + private static IDataView _uciAdultDataView; + + public static IDataView GetUciAdultDataView() + { + if (_uciAdultDataView == null) + { + var context = new MLContext(); + var uciAdultDataFile = DownloadUciAdultDataset(); + var columnInferenceResult = context.Auto().InferColumns(uciAdultDataFile, UciAdultLabel); + var textLoader = context.Data.CreateTextLoader(columnInferenceResult.TextLoaderOptions); + _uciAdultDataView = textLoader.Load(uciAdultDataFile); + } + return _uciAdultDataView; + } + + // downloads the UCI Adult dataset from the ML.Net repo + public static string DownloadUciAdultDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/f0e639af5ffdc839aae8e65d19b5a9a1f0db634a/test/data/adult.tiny.with-schema.txt", "uciadult.dataset"); + + public static string DownloadTrivialDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/eae76959e6714af44caa212e102a5f06f0110e72/test/data/trivial-train.tsv", "trivial.dataset"); + + public static string DownloadMlNetGeneratedRegressionDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/e78971ea6fd736038b4c355b840e5cbabae8cb55/test/data/generated_regression_dataset.csv", "mlnet_generated_regression.dataset"); + + public static string DownloadIrisDataset() => + DownloadIfNotExists("https://raw.githubusercontent.com/dotnet/machinelearning/54596ac/test/data/iris.txt", "iris.dataset"); + + private static string DownloadIfNotExists(string baseGitPath, string dataFile) + { + // if file doesn't already 
exist, download it + if (!File.Exists(dataFile)) + { + using (var client = new WebClient()) + { + client.DownloadFile(new Uri($"{baseGitPath}"), dataFile); + } + } + + return dataFile; + } + } +} diff --git a/test/mlnet.Tests/Directory.Build.props b/test/mlnet.Tests/Directory.Build.props new file mode 100644 index 0000000000..e161d1461b --- /dev/null +++ b/test/mlnet.Tests/Directory.Build.props @@ -0,0 +1,9 @@ + + + + + trx + $(OutputPath) + + + \ No newline at end of file diff --git a/test/mlnet.Tests/TrainerGeneratorTests.cs b/test/mlnet.Tests/TrainerGeneratorTests.cs new file mode 100644 index 0000000000..122e8c64ae --- /dev/null +++ b/test/mlnet.Tests/TrainerGeneratorTests.cs @@ -0,0 +1,694 @@ +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Auto; +using Microsoft.ML.CLI.CodeGenerator.CSharp; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace mlnet.Tests +{ + /**************************** + * TODO : Add all trainer tests : + * **************************/ + [TestClass] + public class TrainerGeneratorTests + { + [TestMethod] + public void LightGbmBinaryBasicTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"LearningRate", 0.1f }, + {"NumberOfLeaves", 1 }, + }; + PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "LightGbm(learningRate:0.1f,numberOfLeaves:1,labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void LightGbmBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + 
{"LearningRate", 0.1f }, + {"NumLeaves", 1 }, + {"UseSoftmax", true } + }; + PipelineNode node = new PipelineNode("LightGbmBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "LightGbm(new LightGbmBinaryTrainer.Options(){LearningRate=0.1f,NumLeaves=1,UseSoftmax=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + string expectedUsings = "using Microsoft.ML.Trainers.LightGbm;\r\n"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void SymbolicSgdLogisticRegressionBinaryBasicTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("SymbolicSgdLogisticRegressionBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "SymbolicSgdLogisticRegression(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void SymbolicSgdLogisticRegressionBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"LearningRate", 0.1f }, + }; + PipelineNode node = new PipelineNode("SymbolicSgdLogisticRegressionBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new 
CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "SymbolicSgdLogisticRegression(new SymbolicSgdLogisticRegressionBinaryTrainer.Options(){LearningRate=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void SgdCalibratedBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("SgdCalibratedBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "SgdCalibrated(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void SgdCalibratedBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shuffle", true }, + }; + PipelineNode node = new PipelineNode("SgdCalibratedBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "SgdCalibrated(new SgdCalibratedTrainer.Options(){Shuffle=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, 
actual.Item2[0]); + + } + + [TestMethod] + public void SdcaLogisticRegressionBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("SdcaLogisticRegressionBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "SdcaLogisticRegression(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void SdcaLogisticRegressionBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"BiasLearningRate", 0.1f }, + }; + PipelineNode node = new PipelineNode("SdcaLogisticRegressionBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "SdcaLogisticRegression(new SdcaLogisticRegressionBinaryTrainer.Options(){BiasLearningRate=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void SdcaMaximumEntropyMultiBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("SdcaMaximumEntropyMulti", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new 
PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "SdcaMaximumEntropy(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void SdcaMaximumEntropyMultiAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"BiasLearningRate", 0.1f }, + }; + PipelineNode node = new PipelineNode("SdcaMaximumEntropyMulti", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "SdcaMaximumEntropy(new SdcaMaximumEntropyMulticlassTrainer.Options(){BiasLearningRate=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void SdcaRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("SdcaRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "Sdca(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void 
SdcaRegressionAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"BiasLearningRate", 0.1f }, + }; + PipelineNode node = new PipelineNode("SdcaRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "Sdca(new SdcaRegressionTrainer.Options(){BiasLearningRate=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void LbfgsPoissonRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("LbfgsPoissonRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "LbfgsPoissonRegression(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void LbfgsPoissonRegressionAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"MaximumNumberOfIterations", 1 }, + }; + PipelineNode node = new PipelineNode("LbfgsPoissonRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new 
CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "LbfgsPoissonRegression(new LbfgsPoissonRegressionTrainer.Options(){MaximumNumberOfIterations=1,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void OlsRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("OlsRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "Ols(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void OlsRegressionAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"L2Regularization", 0.1f }, + }; + PipelineNode node = new PipelineNode("OlsRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "Ols(new OlsTrainer.Options(){L2Regularization=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void 
OnlineGradientDescentRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("OnlineGradientDescentRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "OnlineGradientDescent(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void OnlineGradientDescentRegressionAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"RecencyGainMulti", true }, + }; + PipelineNode node = new PipelineNode("OnlineGradientDescentRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "OnlineGradientDescent(new OnlineGradientDescentTrainer.Options(){RecencyGainMulti=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void LbfgsLogisticRegressionBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("LbfgsLogisticRegressionBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator 
codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "LbfgsLogisticRegression(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void LbfgsLogisticRegressionBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"DenseOptimizer", true }, + }; + PipelineNode node = new PipelineNode("LbfgsLogisticRegressionBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "LbfgsLogisticRegression(new LbfgsLogisticRegressionBinaryTrainer.Options(){DenseOptimizer=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void LbfgsMaximumEntropyMultiMultiBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("LbfgsMaximumEntropyMulti", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "LbfgsMaximumEntropy(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void 
LbfgsMaximumEntropyMultiAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"DenseOptimizer", true }, + }; + PipelineNode node = new PipelineNode("LbfgsMaximumEntropyMulti", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "LbfgsMaximumEntropy(new LbfgsMaximumEntropyMulticlassTrainer.Options(){DenseOptimizer=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + [TestMethod] + public void LinearSvmBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("LinearSvmBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "LinearSvm(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void LinearSvmBinaryParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"NoBias", true }, + }; + PipelineNode node = new PipelineNode("LinearSvmBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); 
+ var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n "; + string expectedTrainerString = "LinearSvm(new LinearSvmTrainer.Options(){NoBias=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + + [TestMethod] + public void FastTreeTweedieRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("FastTreeTweedieRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "FastTreeTweedie(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void FastTreeTweedieRegressionAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shrinkage", 0.1f }, + }; + PipelineNode node = new PipelineNode("OnlineGradientDescentRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n"; + string expectedTrainerString = "OnlineGradientDescent(new OnlineGradientDescentTrainer.Options(){Shrinkage=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + + [TestMethod] + 
public void FastTreeRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("FastTreeRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "FastTree(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void FastTreeRegressionAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shrinkage", 0.1f }, + }; + PipelineNode node = new PipelineNode("FastTreeRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; + string expectedTrainerString = "FastTree(new FastTreeRegressionTrainer.Options(){Shrinkage=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + + [TestMethod] + public void FastTreeBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("FastTreeBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = 
codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "FastTree(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void FastTreeBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shrinkage", 0.1f }, + }; + PipelineNode node = new PipelineNode("FastTreeBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; + string expectedTrainerString = "FastTree(new FastTreeBinaryTrainer.Options(){Shrinkage=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + + [TestMethod] + public void FastForestRegressionBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("FastForestRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "FastForest(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void FastForestRegressionAdvancedParameterTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shrinkage", 0.1f }, + }; + PipelineNode node = new 
PipelineNode("FastForestRegression", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; + string expectedTrainerString = "FastForest(new FastForestRegression.Options(){Shrinkage=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + + [TestMethod] + public void FastForestBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("FastForestBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "FastForest(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void FastForestBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shrinkage", 0.1f }, + }; + PipelineNode node = new PipelineNode("FastForestBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers.FastTree;\r\n"; + string expectedTrainerString = "FastForest(new 
FastForestClassification.Options(){Shrinkage=0.1f,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + + + [TestMethod] + public void AveragedPerceptronBinaryBasicTest() + { + var context = new MLContext(); + + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("AveragedPerceptronBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + string expectedTrainerString = "AveragedPerceptron(labelColumnName:\"Label\",featureColumnName:\"Features\")"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.IsNull(actual.Item2); + + } + + [TestMethod] + public void AveragedPerceptronBinaryAdvancedParameterTest() + { + + var context = new MLContext(); + + var elementProperties = new Dictionary() + { + {"Shuffle", true }, + }; + PipelineNode node = new PipelineNode("AveragedPerceptronBinary", PipelineNodeType.Trainer, default(string[]), default(string), elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTrainerAndUsings(); + var expectedUsings = "using Microsoft.ML.Trainers;\r\n "; + string expectedTrainerString = "AveragedPerceptron(new AveragedPerceptronTrainer.Options(){Shuffle=true,LabelColumnName=\"Label\",FeatureColumnName=\"Features\"})"; + Assert.AreEqual(expectedTrainerString, actual.Item1); + Assert.AreEqual(expectedUsings, actual.Item2[0]); + + } + } +} diff --git a/test/mlnet.Tests/TransformGeneratorTests.cs b/test/mlnet.Tests/TransformGeneratorTests.cs new file mode 100644 index 0000000000..07469b960f --- /dev/null +++ 
b/test/mlnet.Tests/TransformGeneratorTests.cs @@ -0,0 +1,172 @@ +using System.Collections.Generic; +using Microsoft.ML; +using Microsoft.ML.Auto; +using Microsoft.ML.CLI.CodeGenerator.CSharp; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace mlnet.Tests +{ + [TestClass] + public class TransformGeneratorTests + { + [TestMethod] + public void MissingValueReplacingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary();//categorical + PipelineNode node = new PipelineNode("MissingValueReplacing", PipelineNodeType.Transform, new string[] { "categorical_column_1" }, new string[] { "categorical_column_1" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + var expectedTransform = "ReplaceMissingValues(new []{new InputOutputColumnPair(\"categorical_column_1\",\"categorical_column_1\")})"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void OneHotEncodingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary();//categorical + PipelineNode node = new PipelineNode("OneHotEncoding", PipelineNodeType.Transform, new string[] { "categorical_column_1" }, new string[] { "categorical_column_1" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Categorical.OneHotEncoding(new []{new InputOutputColumnPair(\"categorical_column_1\",\"categorical_column_1\")})"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void NormalizingTest() + { + var context = new 
MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("Normalizing", PipelineNodeType.Transform, new string[] { "numeric_column_1" }, new string[] { "numeric_column_1_copy" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "NormalizeMinMax(\"numeric_column_1_copy\",\"numeric_column_1\")"; + string expectedUsings = null; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.AreEqual(expectedUsings, actual[0].Item2); + } + + [TestMethod] + public void ColumnConcatenatingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("ColumnConcatenating", PipelineNodeType.Transform, new string[] { "numeric_column_1", "numeric_column_2" }, new string[] { "Features" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Concatenate(\"Features\",new []{\"numeric_column_1\",\"numeric_column_2\"})"; + string expectedUsings = null; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.AreEqual(expectedUsings, actual[0].Item2); + } + + [TestMethod] + public void ColumnCopyingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary();//nume to num feature 2 + PipelineNode node = new PipelineNode("ColumnCopying", PipelineNodeType.Transform, new string[] { "numeric_column_1" }, new string[] { "numeric_column_2" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = 
codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "CopyColumns(\"numeric_column_2\",\"numeric_column_1\")"; + string expectedUsings = null; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.AreEqual(expectedUsings, actual[0].Item2); + } + + [TestMethod] + public void KeyToValueMappingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("KeyToValueMapping", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Conversion.MapKeyToValue(\"Label\",\"Label\")"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void MissingValueIndicatingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary();//numeric feature + PipelineNode node = new PipelineNode("MissingValueIndicating", PipelineNodeType.Transform, new string[] { "numeric_column_1" }, new string[] { "numeric_column_1" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "IndicateMissingValues(new []{new InputOutputColumnPair(\"numeric_column_1\",\"numeric_column_1\")})"; + string expectedUsings = null; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.AreEqual(expectedUsings, actual[0].Item2); + } + + [TestMethod] + public void OneHotHashEncodingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = 
new PipelineNode("OneHotHashEncoding", PipelineNodeType.Transform, new string[] { "Categorical_column_1" }, new string[] { "Categorical_column_1" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Categorical.OneHotHashEncoding(new []{new InputOutputColumnPair(\"Categorical_column_1\",\"Categorical_column_1\")})"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void TextFeaturizingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("TextFeaturizing", PipelineNodeType.Transform, new string[] { "Text_column_1" }, new string[] { "Text_column_1" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Text.FeaturizeText(\"Text_column_1\",\"Text_column_1\")"; + string expectedUsings = null; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.AreEqual(expectedUsings, actual[0].Item2); + } + + [TestMethod] + public void TypeConvertingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("TypeConverting", PipelineNodeType.Transform, new string[] { "I4_column_1" }, new string[] { "R4_column_1" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Conversion.ConvertType(new []{new 
InputOutputColumnPair(\"R4_column_1\",\"I4_column_1\")})"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + [TestMethod] + public void ValueToKeyMappingTest() + { + var context = new MLContext(); + var elementProperties = new Dictionary(); + PipelineNode node = new PipelineNode("ValueToKeyMapping", PipelineNodeType.Transform, new string[] { "Label" }, new string[] { "Label" }, elementProperties); + Pipeline pipeline = new Pipeline(new PipelineNode[] { node }); + CodeGenerator codeGenerator = new CodeGenerator(pipeline, null, null); + var actual = codeGenerator.GenerateTransformsAndUsings(new PipelineNode[] { node }); + string expectedTransform = "Conversion.MapValueToKey(\"Label\",\"Label\")"; + Assert.AreEqual(expectedTransform, actual[0].Item1); + Assert.IsNull(actual[0].Item2); + } + + } +} diff --git a/test/mlnet.Tests/mlnet.Tests.csproj b/test/mlnet.Tests/mlnet.Tests.csproj new file mode 100644 index 0000000000..107f8e983b --- /dev/null +++ b/test/mlnet.Tests/mlnet.Tests.csproj @@ -0,0 +1,23 @@ + + + + netcoreapp2.1 + false + + false + false + + + + + + + + + + + + + + + diff --git a/test/mlnet.Tests/run-tests.proj b/test/mlnet.Tests/run-tests.proj new file mode 100644 index 0000000000..dd2433b3c5 --- /dev/null +++ b/test/mlnet.Tests/run-tests.proj @@ -0,0 +1,19 @@ + + + + + + + + + + + + + + + + + \ No newline at end of file