diff --git a/Microsoft.ML.AutoML.sln b/Microsoft.ML.AutoML.sln
new file mode 100644
index 0000000000..280cef5704
--- /dev/null
+++ b/Microsoft.ML.AutoML.sln
@@ -0,0 +1,91 @@
+
+Microsoft Visual Studio Solution File, Format Version 12.00
+# Visual Studio 15
+VisualStudioVersion = 15.0.28010.2050
+MinimumVisualStudioVersion = 10.0.40219.1
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Auto", "src\Microsoft.ML.Auto\Microsoft.ML.Auto.csproj", "{B3727729-3DF8-47E0-8710-9B41DAF55817}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.AutoML.Tests", "test\Microsoft.ML.AutoML.Tests\Microsoft.ML.AutoML.Tests.csproj", "{55ACB7E2-053D-43BB-88E8-0E102FBD62F0}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "mlnet", "src\mlnet\mlnet.csproj", "{ED714FA5-6F89-401B-9E7F-CADF1373C553}"
+EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "mlnet.Tests", "test\mlnet.Tests\mlnet.Tests.csproj", "{AAC3E4E6-C146-44BB-8873-A1E61D563F2A}"
+EndProject
+Global
+ GlobalSection(SolutionConfigurationPlatforms) = preSolution
+ Debug|Any CPU = Debug|Any CPU
+ Debug-Intrinsics|Any CPU = Debug-Intrinsics|Any CPU
+ Debug-netfx|Any CPU = Debug-netfx|Any CPU
+ Release|Any CPU = Release|Any CPU
+ Release-Intrinsics|Any CPU = Release-Intrinsics|Any CPU
+ Release-netfx|Any CPU = Release-netfx|Any CPU
+ EndGlobalSection
+ GlobalSection(ProjectConfigurationPlatforms) = postSolution
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release|Any CPU.Build.0 = Release|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {B3727729-3DF8-47E0-8710-9B41DAF55817}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release|Any CPU.Build.0 = Release|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {64A7294E-A2C7-4499-8F0B-4BB074047C6B}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release|Any CPU.Build.0 = Release|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {55ACB7E2-053D-43BB-88E8-0E102FBD62F0}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release|Any CPU.Build.0 = Release|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {ED714FA5-6F89-401B-9E7F-CADF1373C553}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug-Intrinsics|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-Intrinsics|Any CPU.Build.0 = Debug-Intrinsics|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-netfx|Any CPU.ActiveCfg = Debug-netfx|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Debug-netfx|Any CPU.Build.0 = Debug-netfx|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release|Any CPU.Build.0 = Release|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-Intrinsics|Any CPU.ActiveCfg = Release-Intrinsics|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-Intrinsics|Any CPU.Build.0 = Release-Intrinsics|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-netfx|Any CPU.ActiveCfg = Release-netfx|Any CPU
+ {AAC3E4E6-C146-44BB-8873-A1E61D563F2A}.Release-netfx|Any CPU.Build.0 = Release-netfx|Any CPU
+ EndGlobalSection
+ GlobalSection(SolutionProperties) = preSolution
+ HideSolutionNode = FALSE
+ EndGlobalSection
+ GlobalSection(ExtensibilityGlobals) = postSolution
+ SolutionGuid = {8C1BC26C-B87E-47CD-928E-00EFE4353B40}
+ EndGlobalSection
+EndGlobal
diff --git a/build.proj b/build.proj
index 15fea4e309..3dd5edd0e7 100644
--- a/build.proj
+++ b/build.proj
@@ -22,6 +22,7 @@
+
diff --git a/src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs b/src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs
new file mode 100644
index 0000000000..adf04111f2
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs
@@ -0,0 +1,79 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Catalog of AutoML operations bound to a single <see cref="MLContext"/>:
+ /// experiment creation and column inference.
+ /// </summary>
+ public sealed class AutoMLCatalog
+ {
+ private readonly MLContext _context;
+
+ internal AutoMLCatalog(MLContext context) => _context = context;
+
+ /// <summary>
+ /// Creates a regression experiment from default settings, overriding only
+ /// <c>MaxExperimentTimeInSeconds</c> with the supplied value.
+ /// </summary>
+ public RegressionExperiment CreateRegressionExperiment(uint maxExperimentTimeInSeconds)
+ {
+ var experimentSettings = new RegressionExperimentSettings();
+ experimentSettings.MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds;
+ return new RegressionExperiment(_context, experimentSettings);
+ }
+
+ /// <summary>
+ /// Creates a regression experiment that uses the supplied settings.
+ /// </summary>
+ public RegressionExperiment CreateRegressionExperiment(RegressionExperimentSettings experimentSettings)
+ => new RegressionExperiment(_context, experimentSettings);
+
+ /// <summary>
+ /// Creates a binary classification experiment from default settings, overriding only
+ /// <c>MaxExperimentTimeInSeconds</c> with the supplied value.
+ /// </summary>
+ public BinaryClassificationExperiment CreateBinaryClassificationExperiment(uint maxExperimentTimeInSeconds)
+ {
+ var experimentSettings = new BinaryExperimentSettings();
+ experimentSettings.MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds;
+ return new BinaryClassificationExperiment(_context, experimentSettings);
+ }
+
+ /// <summary>
+ /// Creates a binary classification experiment that uses the supplied settings.
+ /// </summary>
+ public BinaryClassificationExperiment CreateBinaryClassificationExperiment(BinaryExperimentSettings experimentSettings)
+ => new BinaryClassificationExperiment(_context, experimentSettings);
+
+ /// <summary>
+ /// Creates a multiclass classification experiment from default settings, overriding only
+ /// <c>MaxExperimentTimeInSeconds</c> with the supplied value.
+ /// </summary>
+ public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(uint maxExperimentTimeInSeconds)
+ {
+ var experimentSettings = new MulticlassExperimentSettings();
+ experimentSettings.MaxExperimentTimeInSeconds = maxExperimentTimeInSeconds;
+ return new MulticlassClassificationExperiment(_context, experimentSettings);
+ }
+
+ /// <summary>
+ /// Creates a multiclass classification experiment that uses the supplied settings.
+ /// </summary>
+ public MulticlassClassificationExperiment CreateMulticlassClassificationExperiment(MulticlassExperimentSettings experimentSettings)
+ => new MulticlassClassificationExperiment(_context, experimentSettings);
+
+ /// <summary>
+ /// Validates <paramref name="path"/> and <paramref name="labelColumn"/>, then forwards every
+ /// argument to <c>ColumnInferenceApi.InferColumns</c> and returns its result.
+ /// </summary>
+ public ColumnInferenceResults InferColumns(string path, string labelColumn = DefaultColumnNames.Label, char? separatorChar = null, bool? allowQuotedStrings = null,
+ bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
+ {
+ UserInputValidationUtil.ValidateInferColumnsArgs(path, labelColumn);
+ var inferenceResults = ColumnInferenceApi.InferColumns(_context, path, labelColumn, separatorChar,
+ allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
+ return inferenceResults;
+ }
+
+ /// <summary>
+ /// Same as the label-name overload, but driven by a <see cref="ColumnInformation"/>.
+ /// A null <paramref name="columnInformation"/> is replaced by a default instance before validation.
+ /// </summary>
+ public ColumnInferenceResults InferColumns(string path, ColumnInformation columnInformation, char? separatorChar = null, bool? allowQuotedStrings = null,
+ bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
+ {
+ if (columnInformation == null)
+ {
+ columnInformation = new ColumnInformation();
+ }
+
+ UserInputValidationUtil.ValidateInferColumnsArgs(path, columnInformation);
+ var inferenceResults = ColumnInferenceApi.InferColumns(_context, path, columnInformation, separatorChar,
+ allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
+ return inferenceResults;
+ }
+
+ /// <summary>
+ /// Validates <paramref name="path"/>, then infers columns with the label identified by
+ /// <paramref name="labelColumnIndex"/> rather than by name.
+ /// </summary>
+ public ColumnInferenceResults InferColumns(string path, uint labelColumnIndex, bool hasHeader = false, char? separatorChar = null,
+ bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
+ {
+ UserInputValidationUtil.ValidateInferColumnsArgs(path);
+ var inferenceResults = ColumnInferenceApi.InferColumns(_context, path, labelColumnIndex, hasHeader, separatorChar,
+ allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
+ return inferenceResults;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs b/src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs
new file mode 100644
index 0000000000..3ab9dbb7a1
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/BinaryClassificationExperiment.cs
@@ -0,0 +1,73 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ public sealed class BinaryExperimentSettings : ExperimentSettings // Settings specific to binary classification experiments.
+ {
+ public BinaryClassificationMetric OptimizingMetric { get; set; } = BinaryClassificationMetric.Accuracy; // Metric handed to the experiment's metrics agent; defaults to Accuracy.
+ public ICollection Trainers { get; } =
+ Enum.GetValues(typeof(BinaryClassificationTrainer)).OfType().ToList(); // Starts out populated from every BinaryClassificationTrainer value; the collection itself is get-only.
+ }
+
+ public enum BinaryClassificationMetric // Metrics selectable through BinaryExperimentSettings.OptimizingMetric and the Best(...) extension methods.
+ {
+ Accuracy, // Default for BinaryExperimentSettings.OptimizingMetric and for the Best(...) extensions.
+ AreaUnderRocCurve,
+ AreaUnderPrecisionRecallCurve,
+ F1Score,
+ PositivePrecision,
+ PositiveRecall,
+ NegativePrecision,
+ NegativeRecall,
+ }
+
+ public enum BinaryClassificationTrainer // Trainer identifiers; every value is placed in BinaryExperimentSettings.Trainers by default.
+ {
+ AveragedPerceptron,
+ FastForest,
+ FastTree,
+ LightGbm,
+ LinearSupportVectorMachines,
+ LbfgsLogisticRegression,
+ SdcaLogisticRegression,
+ SgdCalibrated,
+ SymbolicSgdLogisticRegression,
+ }
+
+ public sealed class BinaryClassificationExperiment : ExperimentBase // Experiment type for TaskKind.BinaryClassification; all behavior comes from ExperimentBase.
+ {
+ internal BinaryClassificationExperiment(MLContext context, BinaryExperimentSettings settings) // NOTE(review): settings is dereferenced below without a null check.
+ : base(context,
+ new BinaryMetricsAgent(context, settings.OptimizingMetric), // metrics agent built for the metric being optimized
+ new OptimizingMetricInfo(settings.OptimizingMetric),
+ settings,
+ TaskKind.BinaryClassification,
+ TrainerExtensionUtil.GetTrainerNames(settings.Trainers)) // stored by the base class as its trainer whitelist
+ {
+ }
+ }
+
+ public static class BinaryExperimentResultExtensions // Extension methods that pick the best run out of binary classification results.
+ {
+ public static RunDetail Best(this IEnumerable> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy) // Returns the run BestResultUtil.GetBestRun selects for the given metric.
+ {
+ var metricsAgent = new BinaryMetricsAgent(null, metric); // no MLContext is supplied to the agent here
+ var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; // presumably true when larger metric values are better — confirm in OptimizingMetricInfo
+ return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
+ }
+
+ public static CrossValidationRunDetail Best(this IEnumerable> results, BinaryClassificationMetric metric = BinaryClassificationMetric.Accuracy) // Cross-validation counterpart of the overload above; same selection steps.
+ {
+ var metricsAgent = new BinaryMetricsAgent(null, metric); // no MLContext is supplied to the agent here
+ var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
+ return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/API/ColumnInference.cs b/src/Microsoft.ML.Auto/API/ColumnInference.cs
new file mode 100644
index 0000000000..588116897d
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/ColumnInference.cs
@@ -0,0 +1,27 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Collections.ObjectModel;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ public sealed class ColumnInferenceResults // Return type of AutoMLCatalog.InferColumns: loader options plus column information.
+ {
+ public TextLoader.Options TextLoaderOptions { get; internal set; } = new TextLoader.Options(); // never null by default; settable only from inside the assembly
+ public ColumnInformation ColumnInformation { get; internal set; } = new ColumnInformation(); // never null by default; settable only from inside the assembly
+ }
+
+ public sealed class ColumnInformation // Names of the special-purpose columns of a dataset, as supplied to InferColumns and Execute.
+ {
+ public string LabelColumnName { get; set; } = DefaultColumnNames.Label; // defaults to DefaultColumnNames.Label
+ public string ExampleWeightColumnName { get; set; } // no default assigned (null unless set)
+ public string SamplingKeyColumnName { get; set; } // no default assigned; ExperimentBase passes it to the SplitUtil split helpers
+ public ICollection CategoricalColumnNames { get; } = new Collection(); // starts empty; get-only, so callers add to the existing collection
+ public ICollection NumericColumnNames { get; } = new Collection(); // starts empty
+ public ICollection TextColumnNames { get; } = new Collection(); // starts empty
+ public ICollection IgnoredColumnNames { get; } = new Collection(); // starts empty
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/API/ExperimentBase.cs b/src/Microsoft.ML.Auto/API/ExperimentBase.cs
new file mode 100644
index 0000000000..381196c54c
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/ExperimentBase.cs
@@ -0,0 +1,202 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Auto
+{
+ public abstract class ExperimentBase where TMetrics : class // Shared Execute(...) entry points and plumbing for the task-specific experiment classes.
+ {
+ protected readonly MLContext Context; // supplied at construction; passed on to the split helpers, runners and Experiment
+
+ private readonly IMetricsAgent _metricsAgent;
+ private readonly OptimizingMetricInfo _optimizingMetricInfo;
+ private readonly ExperimentSettings _settings;
+ private readonly TaskKind _task;
+ private readonly IEnumerable _trainerWhitelist; // trainer names handed in by the derived experiment
+
+ internal ExperimentBase(MLContext context, // stores each argument in the matching field; no validation happens here
+ IMetricsAgent metricsAgent,
+ OptimizingMetricInfo optimizingMetricInfo,
+ ExperimentSettings settings,
+ TaskKind task,
+ IEnumerable trainerWhitelist)
+ {
+ Context = context;
+ _metricsAgent = metricsAgent;
+ _optimizingMetricInfo = optimizingMetricInfo;
+ _settings = settings;
+ _task = task;
+ _trainerWhitelist = trainerWhitelist;
+ }
+
+ public IEnumerable> Execute(IDataView trainData, string labelColumn = DefaultColumnNames.Label,
+ string samplingKeyColumn = null, IEstimator preFeaturizers = null, IProgress> progressHandler = null)
+ { // Convenience overload: wraps labelColumn and samplingKeyColumn in a ColumnInformation and delegates.
+ var columnInformation = new ColumnInformation()
+ {
+ LabelColumnName = labelColumn,
+ SamplingKeyColumnName = samplingKeyColumn
+ };
+ return Execute(trainData, columnInformation, preFeaturizers, progressHandler);
+ }
+
+ public IEnumerable> Execute(IDataView trainData, ColumnInformation columnInformation,
+ IEstimator preFeaturizer = null, IProgress> progressHandler = null)
+ { // Picks cross-validation or a train/validate split based on the row count of trainData.
+ // Cross val threshold for # of dataset rows --
+ // If dataset has < threshold # of rows, use cross val.
+ // Else, run experiment using train-validate split.
+ const int crossValRowCountThreshold = 15000;
+
+ var rowCount = DatasetDimensionsUtil.CountRows(trainData, crossValRowCountThreshold); // threshold is also passed to CountRows — presumably as a cap on counting; confirm in DatasetDimensionsUtil
+
+ if (rowCount < crossValRowCountThreshold)
+ {
+ const int numCrossValFolds = 10;
+ var splitResult = SplitUtil.CrossValSplit(Context, trainData, numCrossValFolds, columnInformation?.SamplingKeyColumnName); // columnInformation may be null at this point
+ return ExecuteCrossValSummary(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
+ }
+ else
+ {
+ var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
+ return ExecuteTrainValidate(splitResult.trainData, columnInformation, splitResult.validationData, preFeaturizer, progressHandler);
+ }
+ }
+
+ public IEnumerable> Execute(IDataView trainData, IDataView validationData, string labelColumn = DefaultColumnNames.Label, IEstimator preFeaturizer = null, IProgress> progressHandler = null)
+ { // Convenience overload: builds a ColumnInformation from labelColumn and delegates.
+ var columnInformation = new ColumnInformation() { LabelColumnName = labelColumn };
+ return Execute(trainData, validationData, columnInformation, preFeaturizer, progressHandler);
+ }
+
+ public IEnumerable> Execute(IDataView trainData, IDataView validationData, ColumnInformation columnInformation, IEstimator preFeaturizer = null, IProgress> progressHandler = null)
+ { // Runs against an explicit validation set; when validationData is null, one is split off from trainData.
+ if (validationData == null)
+ {
+ var splitResult = SplitUtil.TrainValidateSplit(Context, trainData, columnInformation?.SamplingKeyColumnName);
+ trainData = splitResult.trainData; // training continues on the reduced training portion
+ validationData = splitResult.validationData;
+ }
+ return ExecuteTrainValidate(trainData, columnInformation, validationData, preFeaturizer, progressHandler);
+ }
+
+ public IEnumerable> Execute(IDataView trainData, uint numberOfCVFolds, ColumnInformation columnInformation = null, IEstimator preFeaturizer = null, IProgress> progressHandler = null)
+ { // Cross-validation with a caller-chosen fold count, which is validated before splitting.
+ UserInputValidationUtil.ValidateNumberOfCVFoldsArg(numberOfCVFolds);
+ var splitResult = SplitUtil.CrossValSplit(Context, trainData, numberOfCVFolds, columnInformation?.SamplingKeyColumnName);
+ return ExecuteCrossVal(splitResult.trainDatasets, columnInformation, splitResult.validationDatasets, preFeaturizer, progressHandler);
+ }
+
+ public IEnumerable> Execute(IDataView trainData,
+ uint numberOfCVFolds, string labelColumn = DefaultColumnNames.Label,
+ string samplingKeyColumn = null, IEstimator preFeaturizer = null,
+ Progress> progressHandler = null)
+ { // Convenience overload for cross-validation. NOTE(review): progressHandler is typed Progress here but IProgress in the other overloads.
+ var columnInformation = new ColumnInformation()
+ {
+ LabelColumnName = labelColumn,
+ SamplingKeyColumnName = samplingKeyColumn
+ };
+ return Execute(trainData, numberOfCVFolds, columnInformation, preFeaturizer, progressHandler);
+ }
+
+ private IEnumerable> ExecuteTrainValidate(IDataView trainData,
+ ColumnInformation columnInfo,
+ IDataView validationData,
+ IEstimator preFeaturizer,
+ IProgress> progressHandler)
+ { // Validates inputs, applies the optional pre-featurizer, then runs the experiment with a TrainValidateRunner.
+ columnInfo = columnInfo ?? new ColumnInformation(); // null falls back to a default ColumnInformation
+ UserInputValidationUtil.ValidateExperimentExecuteArgs(trainData, columnInfo, validationData);
+
+ // Apply pre-featurizer
+ ITransformer preprocessorTransform = null;
+ if (preFeaturizer != null)
+ {
+ preprocessorTransform = preFeaturizer.Fit(trainData); // fitted on the training data only
+ trainData = preprocessorTransform.Transform(trainData);
+ validationData = preprocessorTransform.Transform(validationData); // same fitted transform reused for validation data
+ }
+
+ var runner = new TrainValidateRunner(Context, trainData, validationData, columnInfo.LabelColumnName, _metricsAgent,
+ preFeaturizer, preprocessorTransform, _settings.DebugLogger);
+ var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainData, columnInfo); // computed from the (possibly transformed) training data
+ return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
+ }
+
+ private IEnumerable> ExecuteCrossVal(IDataView[] trainDatasets,
+ ColumnInformation columnInfo,
+ IDataView[] validationDatasets,
+ IEstimator preFeaturizer,
+ IProgress> progressHandler)
+ { // Cross-validation run: validates, applies the optional pre-featurizer per fold, then uses a CrossValRunner.
+ columnInfo = columnInfo ?? new ColumnInformation();
+ UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]); // only the first fold is validated
+
+ // Apply pre-featurizer
+ ITransformer[] preprocessorTransforms = null;
+ (trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
+
+ var runner = new CrossValRunner(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
+ preprocessorTransforms, columnInfo.LabelColumnName, _settings.DebugLogger);
+ var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo); // column info taken from the first fold
+ return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
+ }
+
+ private IEnumerable> ExecuteCrossValSummary(IDataView[] trainDatasets,
+ ColumnInformation columnInfo,
+ IDataView[] validationDatasets,
+ IEstimator preFeaturizer,
+ IProgress> progressHandler)
+ { // Same steps as ExecuteCrossVal, but with a CrossValSummaryRunner, which additionally receives _optimizingMetricInfo.
+ columnInfo = columnInfo ?? new ColumnInformation();
+ UserInputValidationUtil.ValidateExperimentExecuteArgs(trainDatasets[0], columnInfo, validationDatasets[0]); // only the first fold is validated
+
+ // Apply pre-featurizer
+ ITransformer[] preprocessorTransforms = null;
+ (trainDatasets, validationDatasets, preprocessorTransforms) = ApplyPreFeaturizerCrossVal(trainDatasets, validationDatasets, preFeaturizer);
+
+ var runner = new CrossValSummaryRunner(Context, trainDatasets, validationDatasets, _metricsAgent, preFeaturizer,
+ preprocessorTransforms, columnInfo.LabelColumnName, _optimizingMetricInfo, _settings.DebugLogger);
+ var columns = DatasetColumnInfoUtil.GetDatasetColumnInfo(Context, trainDatasets[0], columnInfo);
+ return Execute(columnInfo, columns, preFeaturizer, progressHandler, runner);
+ }
+
+ private IEnumerable Execute(ColumnInformation columnInfo,
+ DatasetColumnInfo[] columns,
+ IEstimator preFeaturizer,
+ IProgress progressHandler,
+ IRunner runner)
+ where TRunDetail : RunDetail
+ { // Common tail of every Execute path. NOTE(review): columnInfo and preFeaturizer are not used in this body.
+ // Execute experiment & get all pipelines run
+ var experiment = new Experiment(Context, _task, _optimizingMetricInfo, progressHandler,
+ _settings, _metricsAgent, _trainerWhitelist, columns, runner);
+
+ return experiment.Execute();
+ }
+
+ private static (IDataView[] trainDatasets, IDataView[] validDatasets, ITransformer[] preprocessorTransforms)
+ ApplyPreFeaturizerCrossVal(IDataView[] trainDatasets, IDataView[] validDatasets, IEstimator preFeaturizer)
+ { // Fits preFeaturizer separately on each fold's training data and transforms that fold's train/validation data.
+ if (preFeaturizer == null)
+ {
+ return (trainDatasets, validDatasets, null); // nothing to apply: inputs returned as-is, no transforms
+ }
+
+ var preprocessorTransforms = new ITransformer[trainDatasets.Length];
+ for (var i = 0; i < trainDatasets.Length; i++)
+ {
+ // Preprocess train and validation data
+ preprocessorTransforms[i] = preFeaturizer.Fit(trainDatasets[i]);
+ trainDatasets[i] = preprocessorTransforms[i].Transform(trainDatasets[i]); // overwrites the element of the array that was passed in
+ validDatasets[i] = preprocessorTransforms[i].Transform(validDatasets[i]); // overwrites the element of the array that was passed in
+ }
+
+ return (trainDatasets, validDatasets, preprocessorTransforms);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/API/ExperimentSettings.cs b/src/Microsoft.ML.Auto/API/ExperimentSettings.cs
new file mode 100644
index 0000000000..43c6c8befe
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/ExperimentSettings.cs
@@ -0,0 +1,33 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+using System.Threading;
+
+namespace Microsoft.ML.Auto
+{
+ public class ExperimentSettings // Base settings inherited by the task-specific experiment settings classes.
+ {
+ public uint MaxExperimentTimeInSeconds { get; set; } = 24 * 60 * 60; // default: 86400 seconds (24 hours)
+ public CancellationToken CancellationToken { get; set; } = default;
+
+ /// <summary>
+ /// This is a pointer to a directory where all models trained during the AutoML experiment will be saved.
+ /// If null, models will be kept in memory instead of written to disk.
+ /// (Please note: for an experiment with high runtime operating on a large dataset, opting to keep models in
+ /// memory could cause a system to run out of memory.)
+ /// </summary>
+ public DirectoryInfo CacheDirectory { get; set; } = new DirectoryInfo(Path.Combine(Path.GetTempPath(), "Microsoft.ML.Auto")); // default: "Microsoft.ML.Auto" under the system temp path
+
+ /// <summary>
+ /// This setting controls whether or not an AutoML experiment will make use of ML.NET-provided caching.
+ /// If set to true, caching will be forced on for all pipelines. If set to false, caching will be forced off.
+ /// If set to null (default value), AutoML will decide whether to enable caching for each model.
+ /// </summary>
+ public bool? CacheBeforeTrainer = null; // NOTE(review): declared as a public field, unlike the auto-properties above
+
+ internal int MaxModels = int.MaxValue; // internal-only; defaults to int.MaxValue
+ internal IDebugLogger DebugLogger; // internal-only; passed to the runners by ExperimentBase
+ }
+}
diff --git a/src/Microsoft.ML.Auto/API/InferenceException.cs b/src/Microsoft.ML.Auto/API/InferenceException.cs
new file mode 100644
index 0000000000..423c4ae3ce
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/InferenceException.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Auto
+{
+ public enum InferenceType // Kind of inference an InferenceException relates to.
+ {
+ ColumnDataKind, // first member, so also the enum's default value
+ ColumnSplit,
+ Label,
+ }
+
+ /// <summary>
+ /// Exception carrying the <see cref="Microsoft.ML.Auto.InferenceType"/> that the failure relates to.
+ /// </summary>
+ public sealed class InferenceException : Exception
+ {
+ /// <summary>
+ /// The kind of inference that failed, as supplied to the constructor.
+ /// </summary>
+ public InferenceType InferenceType;
+
+ /// <summary>
+ /// Creates the exception for the given inference type with an error message.
+ /// </summary>
+ /// <param name="inferenceType">Kind of inference that failed.</param>
+ /// <param name="message">Message describing the failure.</param>
+ public InferenceException(InferenceType inferenceType, string message)
+ : base(message)
+ {
+ // Previously the argument was dropped, leaving the field at its default (ColumnDataKind).
+ InferenceType = inferenceType;
+ }
+
+ /// <summary>
+ /// Creates the exception for the given inference type with an error message and inner exception.
+ /// </summary>
+ /// <param name="inferenceType">Kind of inference that failed.</param>
+ /// <param name="message">Message describing the failure.</param>
+ /// <param name="inner">Exception that caused this one.</param>
+ public InferenceException(InferenceType inferenceType, string message, Exception inner)
+ : base(message, inner)
+ {
+ // Previously the argument was dropped, leaving the field at its default (ColumnDataKind).
+ InferenceType = inferenceType;
+ }
+ }
+
+}
diff --git a/src/Microsoft.ML.Auto/API/MLContextExtension.cs b/src/Microsoft.ML.Auto/API/MLContextExtension.cs
new file mode 100644
index 0000000000..9287fe827c
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/MLContextExtension.cs
@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Extension methods that expose AutoML functionality on <see cref="MLContext"/>.
+ /// </summary>
+ public static class MLContextExtension
+ {
+ /// <summary>
+ /// Creates a new <see cref="AutoMLCatalog"/> bound to <paramref name="mlContext"/>.
+ /// A fresh catalog instance is returned on every call.
+ /// </summary>
+ public static AutoMLCatalog Auto(this MLContext mlContext) => new AutoMLCatalog(mlContext);
+ }
+}
diff --git a/src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs b/src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs
new file mode 100644
index 0000000000..f7f5a856cb
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/MulticlassClassificationExperiment.cs
@@ -0,0 +1,71 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ public sealed class MulticlassExperimentSettings : ExperimentSettings // Settings specific to multiclass classification experiments.
+ {
+ public MulticlassClassificationMetric OptimizingMetric { get; set; } = MulticlassClassificationMetric.MicroAccuracy; // Metric handed to the experiment's metrics agent; defaults to MicroAccuracy.
+ public ICollection Trainers { get; } =
+ Enum.GetValues(typeof(MulticlassClassificationTrainer)).OfType().ToList(); // Starts out populated from every MulticlassClassificationTrainer value; the collection itself is get-only.
+ }
+
+ public enum MulticlassClassificationMetric // Metrics selectable through MulticlassExperimentSettings.OptimizingMetric and the Best(...) extension methods.
+ {
+ MicroAccuracy, // Default for MulticlassExperimentSettings.OptimizingMetric and for the Best(...) extensions.
+ MacroAccuracy,
+ LogLoss,
+ LogLossReduction,
+ TopKAccuracy,
+ }
+
+ public enum MulticlassClassificationTrainer // Trainer identifiers; every value is placed in MulticlassExperimentSettings.Trainers by default.
+ {
+ AveragedPerceptronOVA, // "OVA" suffix presumably means one-versus-all — confirm against the trainer implementations
+ FastForestOVA,
+ FastTreeOVA,
+ LightGbm,
+ LinearSupportVectorMachinesOVA,
+ LbfgsMaximumEntropy,
+ LbfgsLogisticRegressionOVA,
+ SdcaMaximumEntropy,
+ SgdCalibratedOVA,
+ SymbolicSgdLogisticRegressionOVA,
+ }
+
+ public sealed class MulticlassClassificationExperiment : ExperimentBase // Experiment type for TaskKind.MulticlassClassification; all behavior comes from ExperimentBase.
+ {
+ internal MulticlassClassificationExperiment(MLContext context, MulticlassExperimentSettings settings) // NOTE(review): settings is dereferenced below without a null check.
+ : base(context,
+ new MultiMetricsAgent(context, settings.OptimizingMetric), // metrics agent built for the metric being optimized
+ new OptimizingMetricInfo(settings.OptimizingMetric),
+ settings,
+ TaskKind.MulticlassClassification,
+ TrainerExtensionUtil.GetTrainerNames(settings.Trainers)) // stored by the base class as its trainer whitelist
+ {
+ }
+ }
+
+ public static class MulticlassExperimentResultExtensions // Extension methods that pick the best run out of multiclass classification results.
+ {
+ public static RunDetail Best(this IEnumerable> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.MicroAccuracy) // Returns the run BestResultUtil.GetBestRun selects for the given metric.
+ {
+ var metricsAgent = new MultiMetricsAgent(null, metric); // no MLContext is supplied to the agent here
+ var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing; // presumably true when larger metric values are better — confirm in OptimizingMetricInfo
+ return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
+ }
+
+ public static CrossValidationRunDetail Best(this IEnumerable> results, MulticlassClassificationMetric metric = MulticlassClassificationMetric.MicroAccuracy) // Cross-validation counterpart of the overload above; same selection steps.
+ {
+ var metricsAgent = new MultiMetricsAgent(null, metric); // no MLContext is supplied to the agent here
+ var isMetricMaximizing = new OptimizingMetricInfo(metric).IsMaximizing;
+ return BestResultUtil.GetBestRun(results, metricsAgent, isMetricMaximizing);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/API/Pipeline.cs b/src/Microsoft.ML.Auto/API/Pipeline.cs
new file mode 100644
index 0000000000..864d82037b
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/Pipeline.cs
@@ -0,0 +1,110 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Serializable description of a pipeline: an ordered array of nodes plus a caching flag.
+ /// </summary>
+ internal class Pipeline
+ {
+ public PipelineNode[] Nodes { get; set; }
+ public bool CacheBeforeTrainer { get; set; }
+
+ /// <summary>
+ /// Creates a pipeline from its nodes; caching before the trainer is off unless requested.
+ /// </summary>
+ public Pipeline(PipelineNode[] nodes, bool cacheBeforeTrainer = false)
+ {
+ CacheBeforeTrainer = cacheBeforeTrainer;
+ Nodes = nodes;
+ }
+
+ // Parameterless constructor (used by Newtonsoft)
+ internal Pipeline()
+ {
+ }
+
+ /// <summary>
+ /// Converts this description to an estimator by way of <c>SuggestedPipeline.FromPipeline</c>.
+ /// </summary>
+ public IEstimator ToEstimator(MLContext context)
+ => SuggestedPipeline.FromPipeline(context, this).ToEstimator();
+ }
+
+ internal class PipelineNode // One step of a Pipeline: a named transform or trainer with its input/output columns and properties.
+ {
+ public string Name { get; set; }
+ public PipelineNodeType NodeType { get; set; } // Transform or Trainer
+ public string[] InColumns { get; set; }
+ public string[] OutColumns { get; set; }
+ public IDictionary Properties { get; set; } // set to an empty dictionary by the public constructors when none is supplied
+
+ public PipelineNode(string name, PipelineNodeType nodeType, // Primary constructor; the overloads below funnel into it.
+ string[] inColumns, string[] outColumns,
+ IDictionary properties = null)
+ {
+ Name = name;
+ NodeType = nodeType;
+ InColumns = inColumns;
+ OutColumns = outColumns;
+ Properties = properties ?? new Dictionary(); // null is replaced by a new empty dictionary
+ }
+
+ public PipelineNode(string name, PipelineNodeType nodeType, // Overload for a single input column and a single output column.
+ string inColumn, string outColumn, IDictionary properties = null) :
+ this(name, nodeType, new string[] { inColumn }, new string[] { outColumn }, properties)
+ {
+ }
+
+ public PipelineNode(string name, PipelineNodeType nodeType, // Overload for several input columns and a single output column.
+ string[] inColumns, string outColumn, IDictionary properties = null) :
+ this(name, nodeType, inColumns, new string[] { outColumn }, properties)
+ {
+ }
+
+ // (used by Newtonsoft)
+ internal PipelineNode() // leaves every property at its default, including a null Properties
+ {
+ }
+ }
+
+ internal enum PipelineNodeType // Distinguishes the two kinds of PipelineNode.
+ {
+ Transform, // first member, so also the enum's default value
+ Trainer
+ }
+
+ internal class CustomProperty // A named bag of properties.
+ {
+ public string Name { get; set; }
+ public IDictionary Properties { get; set; } // stored as given; unlike PipelineNode, null is not replaced
+
+ public CustomProperty(string name, IDictionary properties)
+ {
+ Name = name;
+ Properties = properties;
+ }
+
+ internal CustomProperty() // parameterless constructor — presumably for deserialization, like the Newtonsoft ones on Pipeline; confirm
+ {
+ }
+ }
+
+ internal class PipelineScore // Immutable record of a pipeline, its score, and whether its run succeeded.
+ {
+ public readonly double Score;
+
+ /// <summary>
+ /// This setting is true if the pipeline run succeeded and ran to completion.
+ /// Else, it is false if some exception was thrown before the run could complete.
+ /// </summary>
+ public readonly bool RunSucceded; // NOTE(review): spelled "RunSucceded", while the constructor parameter is "runSucceeded"
+
+ internal readonly Pipeline Pipeline;
+
+ internal PipelineScore(Pipeline pipeline, double score, bool runSucceeded) // stores the three arguments unchanged
+ {
+ Pipeline = pipeline;
+ Score = score;
+ RunSucceded = runSucceeded;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/API/RegressionExperiment.cs b/src/Microsoft.ML.Auto/API/RegressionExperiment.cs
new file mode 100644
index 0000000000..51f5988f64
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/RegressionExperiment.cs
@@ -0,0 +1,68 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+    public sealed class RegressionExperimentSettings : ExperimentSettings // options specific to regression experiments
+    {
+        public RegressionMetric OptimizingMetric { get; set; } = RegressionMetric.RSquared; // metric to optimize; R-squared by default
+        public ICollection<RegressionTrainer> Trainers { get; } = // starts out containing every RegressionTrainer value
+            Enum.GetValues(typeof(RegressionTrainer)).OfType<RegressionTrainer>().ToList();
+    }
+
+    // Metrics a regression experiment can optimize (see RegressionExperimentSettings.OptimizingMetric).
+    public enum RegressionMetric
+    {
+        MeanAbsoluteError = 0, MeanSquaredError = 1,
+        RootMeanSquaredError = 2,
+        RSquared = 3
+    }
+
+    // Trainers a regression experiment may use (see RegressionExperimentSettings.Trainers).
+    public enum RegressionTrainer
+    {
+        FastForest = 0,
+        FastTree = 1,
+        FastTreeTweedie = 2,
+        LightGbm = 3,
+        OnlineGradientDescent = 4, Ols = 5,
+        LbfgsPoissonRegression = 6,
+        StochasticDualCoordinateAscent = 7,
+    }
+
+    // Regression experiment: hands ExperimentBase a regression metrics agent, the optimizing-metric
+    // info, the settings, the regression task kind and the names of the selected trainers.
+    public sealed class RegressionExperiment : ExperimentBase<RegressionMetrics>
+    {
+        internal RegressionExperiment(MLContext context, RegressionExperimentSettings settings)
+            : base(context,
+                  new RegressionMetricsAgent(context, settings.OptimizingMetric),
+                  new OptimizingMetricInfo(settings.OptimizingMetric),
+                  settings, TaskKind.Regression,
+                  TrainerExtensionUtil.GetTrainerNames(settings.Trainers))
+        { }
+    }
+
+    public static class RegressionExperimentResultExtensions
+    {
+        // Picks the best plain run under 'metric' (R-squared by default) via BestResultUtil.GetBestRun.
+        public static RunDetail<RegressionMetrics> Best(this IEnumerable<RunDetail<RegressionMetrics>> results, RegressionMetric metric = RegressionMetric.RSquared)
+        {
+            var agent = new RegressionMetricsAgent(null, metric); // no MLContext is passed to the agent here
+            return BestResultUtil.GetBestRun(results, agent, new OptimizingMetricInfo(metric).IsMaximizing);
+        }
+
+        // Same selection for cross-validated runs.
+        public static CrossValidationRunDetail<RegressionMetrics> Best(this IEnumerable<CrossValidationRunDetail<RegressionMetrics>> results, RegressionMetric metric = RegressionMetric.RSquared)
+        {
+            var agent = new RegressionMetricsAgent(null, metric); // no MLContext is passed to the agent here
+            return BestResultUtil.GetBestRun(results, agent, new OptimizingMetricInfo(metric).IsMaximizing);
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Auto/API/RunDetails/CrossValidationRunDetail.cs b/src/Microsoft.ML.Auto/API/RunDetails/CrossValidationRunDetail.cs
new file mode 100644
index 0000000000..713c820a99
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/RunDetails/CrossValidationRunDetail.cs
@@ -0,0 +1,41 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Auto
+{
+    // Cross-validation run detail: the base RunDetail data plus a sequence of TrainResult entries.
+    public sealed class CrossValidationRunDetail<TMetrics> : RunDetail
+    {
+        public IEnumerable<TrainResult<TMetrics>> Results { get; private set; }
+
+        internal CrossValidationRunDetail(string trainerName, IEstimator<ITransformer> estimator,
+            Pipeline pipeline, IEnumerable<TrainResult<TMetrics>> results)
+            : base(trainerName, estimator, pipeline)
+        {
+            Results = results;
+        }
+    }
+
+    // Result of training one model: the validation metrics, the model (read from its
+    // ModelContainer on every access) and the exception supplied by the creator of the result.
+    public sealed class TrainResult<TMetrics>
+    {
+        private readonly ModelContainer _modelContainer;
+
+        public TMetrics ValidationMetrics { get; private set; }
+        public ITransformer Model => _modelContainer.GetModel();
+        public Exception Exception { get; private set; }
+
+        internal TrainResult(ModelContainer modelContainer, TMetrics metrics, Exception exception)
+        {
+            _modelContainer = modelContainer;
+            ValidationMetrics = metrics;
+            Exception = exception;
+        }
+    }
+
+}
diff --git a/src/Microsoft.ML.Auto/API/RunDetails/RunDetail.cs b/src/Microsoft.ML.Auto/API/RunDetails/RunDetail.cs
new file mode 100644
index 0000000000..a83670986d
--- /dev/null
+++ b/src/Microsoft.ML.Auto/API/RunDetails/RunDetail.cs
@@ -0,0 +1,48 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Auto
+{
+    // Detail of a single run: the base RunDetail data plus validation metrics, the model (read from
+    // its ModelContainer on every access) and the exception supplied for the run.
+    public sealed class RunDetail<TMetrics> : RunDetail
+    {
+        private readonly ModelContainer _modelContainer;
+
+        public TMetrics ValidationMetrics { get; private set; }
+        public ITransformer Model => _modelContainer.GetModel();
+        public Exception Exception { get; private set; }
+
+        internal RunDetail(string trainerName, IEstimator<ITransformer> estimator,
+            Pipeline pipeline, ModelContainer modelContainer,
+            TMetrics metrics, Exception exception)
+            : base(trainerName, estimator, pipeline)
+        {
+            _modelContainer = modelContainer;
+            ValidationMetrics = metrics;
+            Exception = exception;
+        }
+    }
+
+    // Data shared by every run detail: trainer name, runtime and estimator (public), plus the
+    // pipeline and the pipeline inference time (internal).
+    public abstract class RunDetail
+    {
+        public string TrainerName { get; private set; }
+        public double RuntimeInSeconds { get; internal set; } // not set by the constructor; assigned internally afterwards
+        public IEstimator<ITransformer> Estimator { get; private set; }
+
+        internal Pipeline Pipeline { get; private set; }
+        internal double PipelineInferenceTimeInSeconds { get; set; }
+
+        internal RunDetail(string trainerName, IEstimator<ITransformer> estimator, Pipeline pipeline)
+        {
+            TrainerName = trainerName;
+            Estimator = estimator;
+            Pipeline = pipeline;
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Auto/Assembly.cs b/src/Microsoft.ML.Auto/Assembly.cs
new file mode 100644
index 0000000000..4a3999e45a
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Assembly.cs
@@ -0,0 +1,10 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Runtime.CompilerServices;
+
+[assembly: InternalsVisibleTo("Microsoft.ML.AutoML.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
+[assembly: InternalsVisibleTo("mlnet, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
+[assembly: InternalsVisibleTo("mlnet.Tests, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
+[assembly: InternalsVisibleTo("Benchmark, PublicKey=00240000048000009400000006020000002400005253413100040000010001004b86c4cb78549b34bab61a3b1800e23bfeb5b3ec390074041536a7e3cbd97f5f04cf0f857155a8928eaa29ebfd11cfbbad3ba70efea7bda3226c6a8d370a4cd303f714486b6ebc225985a638471e6ef571cc92a4613c00b8fa65d61ccee0cbe5f36330c9a01f4183559f1bef24cc2917c6d913e3a541333a1d05d9bed22b38cb")]
diff --git a/src/Microsoft.ML.Auto/AutoMlUtils.cs b/src/Microsoft.ML.Auto/AutoMlUtils.cs
new file mode 100644
index 0000000000..1ee5570ee3
--- /dev/null
+++ b/src/Microsoft.ML.Auto/AutoMlUtils.cs
@@ -0,0 +1,24 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Threading;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+    // Small helpers shared across the AutoML code.
+    internal static class AutoMlUtils
+    {
+        // One Random instance per thread, created by the factory on that thread's first access.
+        public static readonly ThreadLocal<Random> random = new ThreadLocal<Random>(() => new Random());
+
+        // Throws InvalidOperationException ('message', or "Assertion failed" if it is null) when 'boolVal' is false.
+        public static void Assert(bool boolVal, string message = null)
+        {
+            if (boolVal) { return; }
+            throw new InvalidOperationException(message ?? "Assertion failed");
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnGroupingInference.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnGroupingInference.cs
new file mode 100644
index 0000000000..c4535ba7d0
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnGroupingInference.cs
@@ -0,0 +1,151 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Microsoft.ML.Data;
+using static Microsoft.ML.Data.TextLoader;
+
+namespace Microsoft.ML.Auto
+{
+    /// <summary>
+    /// This class encapsulates logic for grouping together the inferred columns of the text file based on their type
+    /// and purpose, and generating column names.
+    /// </summary>
+ internal static class ColumnGroupingInference
+ {
+        /// <summary>
+        /// This is effectively a merger of <see cref="PurposeInference.Column"/> and a <see cref="ColumnTypeInference.Column"/>
+        /// with support for vector-value columns.
+        /// </summary>
+        public class GroupingColumn
+        {
+            // Name, item kind, purpose and source-column ranges of the group (plain mutable fields).
+            public string SuggestedName;
+            public DataKind ItemKind;
+            public ColumnPurpose Purpose;
+            public Range[] Ranges;
+
+            public GroupingColumn(string name, DataKind kind, ColumnPurpose purpose, Range[] ranges)
+            {
+                this.SuggestedName = name;
+                this.ItemKind = kind;
+                this.Purpose = purpose;
+                this.Ranges = ranges;
+            }
+
+            // Builds the TextLoader column described by this group's name, item kind and ranges.
+            public TextLoader.Column GenerateTextLoaderColumn()
+                => new TextLoader.Column(SuggestedName, ItemKind, Ranges);
+        }
+
+        /// <summary>
+        /// Group together the single-valued columns with the same type and purpose and generate column names.
+        /// </summary>
+        /// <param name="env">The host environment to use (not read by the current implementation).</param>
+        /// <param name="hasHeader">Whether the original file had a header.
+        /// If yes, the <see cref="ColumnTypeInference.Column.SuggestedName"/> fields are used to generate the column
+        /// names (only for groups of exactly one column), otherwise they are ignored.</param>
+        /// <param name="types">The (detected) column types.</param>
+        /// <param name="purposes">The (detected) column purposes. Must be parallel to <paramref name="types"/>.</param>
+        /// <returns>The array of grouped column specifications.</returns>
+        public static GroupingColumn[] InferGroupingAndNames(MLContext env, bool hasHeader, ColumnTypeInference.Column[] types, PurposeInference.Column[] purposes)
+        {
+            var result = new List<GroupingColumn>();
+            // Pair each type with its purpose, then bucket by (item type, purpose, purpose group id).
+            // The group id keeps every categorical, text and ignored column in a group of its own.
+            var grouped = types
+                .Zip(purposes, Tuple.Create)
+                .ToList()
+                .GroupBy(t => new
+                {
+                    t.Item1.ItemType,
+                    t.Item2.Purpose,
+                    purposeGroupId = GetPurposeGroupId(t.Item1.ColumnIndex, t.Item2.Purpose)
+                });
+
+            foreach (var g in grouped)
+            {
+                // a header name is reused only for a one-column group; otherwise a unique name is generated
+                string name = (hasHeader && g.Count() == 1)
+                    ? g.First().Item1.SuggestedName
+                    : GetName(g.Key.ItemType.GetRawKind(), g.Key.Purpose, result);
+
+                var ranges = GetRanges(g.Select(t => t.Item1.ColumnIndex).ToArray());
+                result.Add(new GroupingColumn(name, g.Key.ItemType.GetRawKind(), g.Key.Purpose, ranges));
+            }
+
+            return result.ToArray();
+        }
+
+        // Categorical, text and ignored columns get their own column index as group id;
+        // every other purpose shares group id 0.
+        private static int GetPurposeGroupId(int columnIndex, ColumnPurpose purpose)
+        {
+            bool keepSeparate = purpose == ColumnPurpose.CategoricalFeature
+                || purpose == ColumnPurpose.TextFeature || purpose == ColumnPurpose.Ignore;
+            return keepSeparate ? columnIndex : 0;
+        }
+
+        // Returns the purpose-based prefix itself when no previous column uses it; otherwise appends a
+        // two-digit counter ("01", "02", ...) until the name is not used by any previous column.
+        private static string GetName(DataKind itemKind, ColumnPurpose purpose, List<GroupingColumn> previousColumns)
+        {
+            string prefix = GetPurposeName(purpose, itemKind);
+            string candidate = prefix;
+            for (int suffix = 1; previousColumns.Any(c => c.SuggestedName == candidate); suffix++)
+            {
+                candidate = string.Format("{0}{1:00}", prefix, suffix);
+            }
+
+            return candidate;
+        }
+
+        // Maps a purpose (and, for numeric features, the item kind) to the prefix used for generated
+        // column names: "BooleanFeatures" or "Features" for numeric features, "Cat" for categorical
+        // features, and the enum member's own name for every other purpose.
+        private static string GetPurposeName(ColumnPurpose purpose, DataKind itemKind)
+        {
+            if (purpose == ColumnPurpose.NumericFeature)
+            {
+                return itemKind == DataKind.Boolean
+                    ? "BooleanFeatures" : "Features";
+            }
+
+            if (purpose == ColumnPurpose.CategoricalFeature)
+            {
+                return "Cat";
+            }
+
+            // Enum.GetName yields null for a value that is not a defined ColumnPurpose member.
+            return Enum.GetName(typeof(ColumnPurpose), purpose);
+        }
+
+        /// <summary>
+        /// Generates a collection of Ranges from indices: each run of consecutive indices becomes one Range.
+        /// Sorts <paramref name="indices"/> in place and reads its first element, so it must not be empty.
+        /// </summary>
+        private static Range[] GetRanges(int[] indices)
+        {
+            Array.Sort(indices);
+            var allRanges = new List<Range>();
+            var currRange = new Range(indices[0]);
+            foreach (int index in indices.Skip(1))
+            {
+                if (index == currRange.Max + 1)
+                {
+                    currRange.Max++;
+                    continue;
+                }
+                // gap (or repeated index): close the current range and open a new one at this index
+                allRanges.Add(currRange);
+                currRange = new Range(index);
+            }
+            allRanges.Add(currRange);
+            return allRanges.ToArray();
+        }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
new file mode 100644
index 0000000000..6db0fab782
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs
@@ -0,0 +1,150 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal static class ColumnInferenceApi
+ {
+        // Infers columns for a file whose label column is identified by its index.
+        public static ColumnInferenceResults InferColumns(MLContext context, string path, uint labelColumnIndex,
+            bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
+        {
+            var sample = TextFileSample.CreateFromFullFile(path);
+            var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
+            var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader, labelColumnIndex, null);
+
+            // when no column already carries the default ML.NET label name, give it to the column at labelColumnIndex
+            var labelCol = typeInference.Columns[labelColumnIndex];
+            if (typeInference.Columns.All(c => c.SuggestedName != DefaultColumnNames.Label))
+            {
+                labelCol.SuggestedName = DefaultColumnNames.Label;
+            }
+
+            var columnInfo = new ColumnInformation() { LabelColumnName = labelCol.SuggestedName };
+            return InferColumns(context, path, columnInfo, hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
+        }
+
+        // Label given by name: wraps it in a ColumnInformation and defers to the ColumnInformation overload.
+        public static ColumnInferenceResults InferColumns(MLContext context, string path, string labelColumn,
+            char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
+        {
+            return InferColumns(context, path, new ColumnInformation() { LabelColumnName = labelColumn }, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
+        }
+
+        // Header-based overload: the file is always read as having a header row, and the label column is looked up by columnInfo.LabelColumnName.
+        public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
+        {
+            var sample = TextFileSample.CreateFromFullFile(path);
+            var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
+            var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader: true, labelColumnIndex: null, label: columnInfo.LabelColumnName);
+            return InferColumns(context, path, columnInfo, true, splitInference, typeInference, trimWhitespace, groupColumns);
+        }
+
+        // Shared implementation: loads the file with the inferred per-column types, infers each column's
+        // purpose, optionally groups columns, and returns the loader options with the column information.
+        public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, bool hasHeader,
+            TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference,
+            bool trimWhitespace, bool groupColumns)
+        {
+            var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
+            var typedLoaderOptions = new TextLoader.Options
+            {
+                Columns = loaderColumns,
+                Separators = new[] { splitInference.Separator.Value },
+                AllowSparse = splitInference.AllowSparse,
+                AllowQuoting = splitInference.AllowQuote,
+                HasHeader = hasHeader,
+                TrimWhitespace = trimWhitespace
+            };
+            var dataView = context.Data.CreateTextLoader(typedLoaderOptions).Load(path);
+            var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo);
+
+            // start building result objects
+            IEnumerable<TextLoader.Column> columnResults;
+            IEnumerable<(string, ColumnPurpose)> purposeResults;
+
+            if (groupColumns)
+            {
+                // infer column grouping and generate column names
+                var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
+                    typeInference.Columns, purposeInferenceResult);
+                columnResults = groupingResult.Select(c => c.GenerateTextLoaderColumn());
+                purposeResults = groupingResult.Select(c => (c.SuggestedName, c.Purpose));
+            }
+            else
+            {
+                // ungrouped: keep one loader column per file column, named as in the loaded schema
+                columnResults = loaderColumns;
+                purposeResults = purposeInferenceResult.Select(p => (dataView.Schema[p.ColumnIndex].Name, p.Purpose));
+            }
+
+            var textLoaderOptions = new TextLoader.Options()
+            {
+                Columns = columnResults.ToArray(),
+                AllowQuoting = splitInference.AllowQuote,
+                AllowSparse = splitInference.AllowSparse,
+                Separators = new char[] { splitInference.Separator.Value },
+                HasHeader = hasHeader,
+                TrimWhitespace = trimWhitespace
+            };
+
+            return new ColumnInferenceResults()
+            {
+                TextLoaderOptions = textLoaderOptions,
+                ColumnInformation = ColumnInformationUtil.BuildColumnInfo(purposeResults)
+            };
+        }
+
+        // Detects the column split of the sample, applies the caller's overrides, and throws InferenceException if no split was found.
+        private static TextFileContents.ColumnSplitResult InferSplit(MLContext context, TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)
+        {
+            var separatorCandidates = separatorChar.HasValue ? new char[] { separatorChar.Value } : TextFileContents.DefaultSeparators;
+            var splitInference = TextFileContents.TrySplitColumns(context, sample, separatorCandidates);
+
+            // respect passed-in overrides
+            if (allowQuotedStrings.HasValue)
+            {
+                splitInference.AllowQuote = allowQuotedStrings.Value;
+            }
+            if (supportSparse.HasValue)
+            {
+                splitInference.AllowSparse = supportSparse.Value;
+            }
+
+            if (splitInference.IsSuccess)
+            {
+                return splitInference;
+            }
+            throw new InferenceException(InferenceType.ColumnSplit, "Unable to split the file provided into multiple, consistent columns.");
+        }
+
+        // Runs column type inference over the sample using the detected split settings; throws
+        // InferenceException when the column types cannot be inferred.
+        private static ColumnTypeInference.InferenceResult InferColumnTypes(MLContext context, TextFileSample sample,
+            TextFileContents.ColumnSplitResult splitInference, bool hasHeader, uint? labelColumnIndex, string label)
+        {
+            var inferenceArgs = new ColumnTypeInference.Arguments
+            {
+                ColumnCount = splitInference.ColumnCount,
+                Separator = splitInference.Separator.Value,
+                AllowSparse = splitInference.AllowSparse,
+                AllowQuote = splitInference.AllowQuote,
+                HasHeader = hasHeader,
+                LabelColumnIndex = labelColumnIndex,
+                Label = label
+            };
+            var typeInferenceResult = ColumnTypeInference.InferTextFileColumnTypes(context, sample, inferenceArgs);
+
+            if (typeInferenceResult.IsSuccess)
+            {
+                return typeInferenceResult;
+            }
+            throw new InferenceException(InferenceType.ColumnDataKind, "Unable to infer column types of the file provided.");
+        }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs
new file mode 100644
index 0000000000..c567730f6f
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnInformationUtil.cs
@@ -0,0 +1,93 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal static class ColumnInformationUtil
+ {
+        // Looks up the purpose assigned to 'columnName' in 'columnInfo'. The single-valued roles
+        // (label, weight, sampling key) are checked first, then the categorical, numeric, text and
+        // ignored name collections, in that order. Returns null when the column is not mentioned.
+        internal static ColumnPurpose? GetColumnPurpose(this ColumnInformation columnInfo, string columnName)
+        {
+            // single-valued roles
+            if (columnInfo.LabelColumnName == columnName)
+            {
+                return ColumnPurpose.Label;
+            }
+            else if (columnInfo.ExampleWeightColumnName == columnName)
+            {
+                return ColumnPurpose.Weight;
+            }
+            else if (columnInfo.SamplingKeyColumnName == columnName)
+            {
+                return ColumnPurpose.SamplingKey;
+            }
+
+            // roles held in name collections
+            if (columnInfo.CategoricalColumnNames.Contains(columnName))
+            {
+                return ColumnPurpose.CategoricalFeature;
+            }
+            else if (columnInfo.NumericColumnNames.Contains(columnName))
+            {
+                return ColumnPurpose.NumericFeature;
+            }
+            else if (columnInfo.TextColumnNames.Contains(columnName))
+            {
+                return ColumnPurpose.TextFeature;
+            }
+            else if (columnInfo.IgnoredColumnNames.Contains(columnName))
+            {
+                return ColumnPurpose.Ignore;
+            }
+
+            return null;
+        }
+
+        // Builds a ColumnInformation from (name, purpose) pairs. For label, weight and sampling key a later
+        // pair overwrites an earlier one; ColumnPurpose.ImagePath has no case, so such columns are left out.
+        internal static ColumnInformation BuildColumnInfo(IEnumerable<(string name, ColumnPurpose purpose)> columnPurposes)
+        {
+            var info = new ColumnInformation();
+            foreach (var (name, purpose) in columnPurposes)
+            {
+                switch (purpose)
+                {
+                    case ColumnPurpose.Label:
+                        info.LabelColumnName = name;
+                        break;
+                    case ColumnPurpose.Weight:
+                        info.ExampleWeightColumnName = name;
+                        break;
+                    case ColumnPurpose.SamplingKey:
+                        info.SamplingKeyColumnName = name;
+                        break;
+                    case ColumnPurpose.CategoricalFeature:
+                        info.CategoricalColumnNames.Add(name);
+                        break;
+                    case ColumnPurpose.Ignore:
+                        info.IgnoredColumnNames.Add(name);
+                        break;
+                    case ColumnPurpose.NumericFeature:
+                        info.NumericColumnNames.Add(name);
+                        break;
+                    case ColumnPurpose.TextFeature:
+                        info.TextColumnNames.Add(name);
+                        break;
+                }
+            }
+            return info;
+        }
+
+ public static ColumnInformation BuildColumnInfo(IEnumerable columns) // NOTE(review): the type argument of IEnumerable appears to be missing; the body needs elements exposing Name and Purpose -- confirm the element type against the original file
+ {
+ return BuildColumnInfo(columns.Select(c => (c.Name, c.Purpose))); // projects to (name, purpose) pairs and defers to the tuple-based overload
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs
new file mode 100644
index 0000000000..45bf787396
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnPurpose.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+    // Role assigned to a column. The numeric values are spelled out explicitly.
+    internal enum ColumnPurpose
+    {
+        Ignore = 0, Label = 1,
+        NumericFeature = 2,
+        CategoricalFeature = 3,
+        TextFeature = 4,
+        Weight = 5,
+        ImagePath = 6,
+        SamplingKey = 7
+    }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs b/src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs
new file mode 100644
index 0000000000..760ea0bc4a
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/ColumnTypeInference.cs
@@ -0,0 +1,413 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text.RegularExpressions;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+    /// <summary>
+    /// This class encapsulates logic for automatic inference of column types for the text file.
+    /// It also attempts to guess whether there is a header row.
+    /// </summary>
+ internal static class ColumnTypeInference
+ {
+ // Column-count limit for type inference: inference is refused when the column count is >= this value.
+ // REVIEW: revisit this requirement. Either work for arbitrary number of columns,
+ // or have a 'dumb' inference that would quickly figure everything out.
+ private const int SmartColumnsLim = 10000;
+
+        // Options for type inference: plain mutable fields, set by callers through an object initializer.
+        internal sealed class Arguments
+        {
+            // how the text is split into columns
+            public char Separator;
+            public bool AllowSparse;
+            public bool AllowQuote;
+            public int ColumnCount;
+
+            public bool HasHeader;
+            public int MaxRowsToRead = 10000; // at most this many rows are read for inference
+
+            // the label column is resolved by index when LabelColumnIndex is set, otherwise by the name in Label
+            public uint? LabelColumnIndex;
+            public string Label;
+        }
+
+        // Working state for one column during type inference: its raw text values plus the type and
+        // header verdicts assigned by the experts.
+        private class IntermediateColumn
+        {
+            private readonly ReadOnlyMemory<char>[] _data;
+            private readonly int _columnId;
+            private PrimitiveDataViewType _suggestedType;
+            private bool? _hasHeader;
+
+            public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
+            {
+                _data = data;
+                _columnId = columnId;
+            }
+
+            // Position of the column, as passed to the constructor.
+            public int ColumnId => _columnId;
+
+            // Raw text of the column, one entry per row read; entry 0 is the first row of the file,
+            // which may be a header.
+            public ReadOnlyMemory<char>[] RawData => _data;
+
+            // Type suggested by an expert; null while no expert has decided.
+            public PrimitiveDataViewType SuggestedType
+            {
+                get => _suggestedType;
+                set => _suggestedType = value;
+            }
+
+            // Header verdict set by an expert: true if the first row looks like a header, false if it
+            // looks like data, null if there is no opinion.
+            public bool? HasHeader
+            {
+                get => _hasHeader;
+                set => _hasHeader = value;
+            }
+
+            // Column name; assigned after the experts have run.
+            public string Name { get; set; }
+
+            // True when every value after the first row is non-empty and parses as a Boolean. Empty
+            // strings are rejected explicitly because Conversions.TryParse accepts them as Booleans.
+            public bool HasAllBooleanValues()
+            {
+                return RawData.Skip(1).All(x =>
+                {
+                    bool value;
+                    return !string.IsNullOrEmpty(x.ToString()) && Conversions.TryParse(in x, out value);
+                });
+            }
+        }
+
+        // Type inference result for one column. The index is fixed at construction; the item type
+        // and the suggested name are mutable fields.
+        public class Column
+        {
+            public readonly int ColumnIndex;
+
+            public PrimitiveDataViewType ItemType;
+            public string SuggestedName;
+
+            public Column(int columnIndex, string suggestedName, PrimitiveDataViewType itemType)
+            {
+                (ColumnIndex, SuggestedName, ItemType) = (columnIndex, suggestedName, itemType);
+            }
+        }
+
+        // Outcome of type inference. Create through Success(...) or Fail(); a failed result has
+        // IsSuccess == false, null Columns and Data, and HasHeader == false.
+        public readonly struct InferenceResult
+        {
+            public readonly Column[] Columns;
+            public readonly bool HasHeader;
+            public readonly bool IsSuccess;
+            public readonly ReadOnlyMemory<char>[][] Data;
+
+            private InferenceResult(bool isSuccess, Column[] columns, bool hasHeader, ReadOnlyMemory<char>[][] data)
+            {
+                IsSuccess = isSuccess;
+                Columns = columns;
+                HasHeader = hasHeader;
+                Data = data;
+            }
+
+            // Successful result carrying the inferred columns, the header flag and the raw data.
+            public static InferenceResult Success(Column[] columns, bool hasHeader, ReadOnlyMemory<char>[][] data)
+                => new InferenceResult(true, columns, hasHeader, data);
+
+            // Failed result: nothing but IsSuccess == false is populated.
+            public static InferenceResult Fail()
+                => new InferenceResult(false, null, false, null);
+        }
+
+ private interface ITypeInferenceExpert // implemented by the classes nested in Experts
+ {
+ void Apply(IntermediateColumn[] columns); // looks at all columns; implementations set SuggestedType / HasHeader on the ones they decide
+ }
+
+        /// <summary>
+        /// Current design is as follows: there's a sequence of 'experts' that each look at all the columns.
+        /// Every expert may or may not assign the 'answer' (suggested type) to a column. If the expert needs
+        /// some information about the column (for example, the column values), this information is lazily calculated
+        /// by the column object, not the expert itself, to allow the reuse of the same information by another
+        /// expert.
+        /// </summary>
+ private static class Experts
+ {
+            // Marks as Boolean every still-untyped column whose values (first row excluded) are all
+            // Boolean, and flags a header when the first row itself does not parse as a Boolean.
+            internal sealed class BooleanValues : ITypeInferenceExpert
+            {
+                public void Apply(IntermediateColumn[] columns)
+                {
+                    foreach (var col in columns)
+                    {
+                        // handle only columns that have no suggested type yet
+                        // and whose values are all Boolean
+                        bool eligible = col.SuggestedType == null && col.HasAllBooleanValues();
+                        if (eligible)
+                        {
+                            col.SuggestedType = BooleanDataViewType.Instance;
+
+                            // a first row that is not itself a Boolean is taken as a header
+                            col.HasHeader = !Conversions.TryParse(in col.RawData[0], out bool first);
+                        }
+                    }
+                }
+            }
+
+            // Marks as Single (float) every column whose values (first row excluded) all parse as
+            // floats, and flags a header when the first row itself is not a number. It is the first
+            // expert returned by GetExperts and does not check for an already suggested type.
+            internal sealed class AllNumericValues : ITypeInferenceExpert
+            {
+                public void Apply(IntermediateColumn[] columns)
+                {
+                    foreach (var col in columns)
+                    {
+                        bool allNumeric = col.RawData.Skip(1).All(x => Conversions.TryParse(in x, out float parsed));
+                        if (!allNumeric)
+                        {
+                            continue;
+                        }
+
+                        col.SuggestedType = NumberDataViewType.Single;
+                        // Parse the header candidate with the invariant culture (same number styles as the
+                        // default overload) so the verdict does not depend on the machine's current culture.
+                        col.HasHeader = !double.TryParse(col.RawData[0].ToString(),
+                            System.Globalization.NumberStyles.Float | System.Globalization.NumberStyles.AllowThousands,
+                            System.Globalization.CultureInfo.InvariantCulture, out _);
+                    }
+                }
+            }
+
+            // Fallback expert: every column still without a suggested type becomes text. The header
+            // verdict comes from IsLookLikeHeader applied to the first row's value.
+            internal sealed class EverythingText : ITypeInferenceExpert
+            {
+                public void Apply(IntermediateColumn[] columns)
+                {
+                    foreach (var col in columns.Where(c => c.SuggestedType == null))
+                    {
+                        col.SuggestedType = TextDataViewType.Instance;
+                        col.HasHeader = IsLookLikeHeader(col.RawData[0]);
+                    }
+                }
+
+                // false: value longer than 100 characters; true: value starts (ignoring case) with one
+                // of a few known header prefixes; null: no opinion.
+                private bool? IsLookLikeHeader(ReadOnlyMemory<char> value)
+                {
+                    var text = value.ToString();
+                    if (text.Length > 100) { return false; }
+
+                    var headerCandidates = new[] { "^Label", "^Feature", "^Market", "^m_", "^Weight" };
+                    if (headerCandidates.Any(p => Regex.IsMatch(text, p, RegexOptions.IgnoreCase)))
+                    {
+                        return true;
+                    }
+
+                    return null;
+                }
+            }
+ }
+
+        // Experts in application order: numeric first, then Boolean, then text as the fallback. The order
+        // matters: AllNumericValues ignores earlier suggestions, the two later experts skip typed columns.
+        private static IEnumerable<ITypeInferenceExpert> GetExperts()
+        {
+            yield return new Experts.AllNumericValues();
+            yield return new Experts.BooleanValues();
+            yield return new Experts.EverythingText();
+        }
+
+        /// <summary>
+        /// Auto-detect column types of the file.
+        /// </summary>
+        /// <param name="args">Split settings, header flag and label selection.</param>
+        /// <returns>The result of InferTextFileColumnTypesCore for the same arguments.</returns>
+        public static InferenceResult InferTextFileColumnTypes(MLContext context, IMultiStreamSource fileSource, Arguments args)
+            => InferTextFileColumnTypesCore(context, fileSource, args);
+
+        // Reads up to args.MaxRowsToRead rows as raw text, lets each expert suggest column types, derives
+        // unique column names and finally forces Boolean on a label column whose values are all Boolean.
+        // Returns Fail() for 0 columns, SmartColumnsLim or more columns, or fewer than 2 rows; a bad label throws.
+        private static InferenceResult InferTextFileColumnTypesCore(MLContext context, IMultiStreamSource fileSource, Arguments args)
+        {
+            if (args.ColumnCount == 0)
+            {
+                // no columns were detected, so there is nothing to infer
+                return InferenceResult.Fail();
+            }
+
+            if (args.ColumnCount >= SmartColumnsLim)
+            {
+                // too many columns for automatic inference
+                return InferenceResult.Fail();
+            }
+
+            // read the file as the specified number of text columns
+            var textLoaderOptions = new TextLoader.Options
+            {
+                Columns = new[] { new TextLoader.Column("C", DataKind.String, 0, args.ColumnCount - 1) },
+                Separators = new[] { args.Separator },
+                AllowSparse = args.AllowSparse,
+                AllowQuoting = args.AllowQuote,
+            };
+            var textLoader = context.Data.CreateTextLoader(textLoaderOptions);
+            var idv = textLoader.Load(fileSource);
+            idv = context.Data.TakeRows(idv, args.MaxRowsToRead);
+
+            // read all the data into memory.
+            // list items are rows of the dataset.
+            var data = new List<ReadOnlyMemory<char>[]>();
+            using (var cursor = idv.GetRowCursor(idv.Schema))
+            {
+                var column = cursor.Schema.GetColumnOrNull("C").Value;
+                bool isVector = column.Type.IsVector();
+                ValueGetter<VBuffer<ReadOnlyMemory<char>>> vecGetter = null;
+                ValueGetter<ReadOnlyMemory<char>> oneGetter = null;
+                if (isVector) { vecGetter = cursor.GetGetter<VBuffer<ReadOnlyMemory<char>>>(column); }
+                else { oneGetter = cursor.GetGetter<ReadOnlyMemory<char>>(column); }
+
+                VBuffer<ReadOnlyMemory<char>> line = default;
+                ReadOnlyMemory<char> tsValue = default;
+                while (cursor.MoveNext())
+                {
+                    if (isVector)
+                    {
+                        vecGetter(ref line);
+                        var values = new ReadOnlyMemory<char>[args.ColumnCount];
+                        line.CopyTo(values);
+                        data.Add(values);
+                    }
+                    else
+                    {
+                        oneGetter(ref tsValue);
+                        var values = new[] { tsValue };
+                        data.Add(values);
+                    }
+                }
+            }
+
+            if (data.Count < 2)
+            {
+                // too few rows for automatic inference
+                return InferenceResult.Fail();
+            }
+
+            // transpose: one IntermediateColumn per file column, holding that column's value from every row
+            var cols = new IntermediateColumn[args.ColumnCount];
+            for (int i = 0; i < args.ColumnCount; i++)
+            {
+                cols[i] = new IntermediateColumn(data.Select(x => x[i]).ToArray(), i);
+            }
+
+            foreach (var expert in GetExperts())
+            {
+                expert.Apply(cols);
+            }
+
+            // Aggregating header signals.
+            // NOTE(review): 'suspect' is only accumulated here and never read afterwards; the returned header flag is args.HasHeader.
+            int suspect = 0;
+            var usedNames = new HashSet<string>();
+            for (int i = 0; i < args.ColumnCount; i++)
+            {
+                if (cols[i].HasHeader == true)
+                {
+                    if (usedNames.Add(cols[i].RawData[0].ToString()))
+                        suspect++;
+                    else
+                    {
+                        // duplicate value in the first row is a strong signal that this is not a header
+                        suspect -= args.ColumnCount;
+                    }
+                }
+                else if (cols[i].HasHeader == false)
+                    suspect--;
+            }
+
+            // suggest names; a name already taken gets "_00", "_01", ... appended until it is unique
+            usedNames.Clear();
+            foreach (var col in cols)
+            {
+                string name0 = SuggestName(col, args.HasHeader);
+                string name = name0;
+                int i = 0;
+                while (!usedNames.Add(name))
+                {
+                    name = string.Format("{0}_{1:00}", name0, i++);
+                }
+                col.Name = name;
+            }
+
+            // validate & retrieve label column
+            var labelColumn = GetAndValidateLabelColumn(args, cols);
+
+            // if label column has all Boolean values, set its type as Boolean
+            if (labelColumn.HasAllBooleanValues())
+            {
+                labelColumn.SuggestedType = BooleanDataViewType.Instance;
+            }
+
+            var outCols = cols.Select(x => new Column(x.ColumnId, x.Name, x.SuggestedType)).ToArray();
+
+            return InferenceResult.Success(outCols, args.HasHeader, cols.Select(col => col.RawData).ToArray());
+        }
+
+        private static string SuggestName(IntermediateColumn column, bool hasHeader)
+        {
+            string firstRowValue = column.RawData[0].ToString(); // header candidate: the column's value in the first row
+            return hasHeader && !string.IsNullOrWhiteSpace(firstRowValue) ? firstRowValue : $"col{column.ColumnId}"; // fallback: "col<ColumnId>"
+        }
+
+        // Resolves the label column: by index when args.LabelColumnIndex is set (throwing
+        // ArgumentOutOfRangeException unless it is below the column count), otherwise by the name in
+        // args.Label (throwing ArgumentException when no column has that name).
+        private static IntermediateColumn GetAndValidateLabelColumn(Arguments args, IntermediateColumn[] cols)
+        {
+            if (args.LabelColumnIndex.HasValue)
+            {
+                // the index must be strictly below the number of inferred columns
+                if (args.LabelColumnIndex >= cols.Length)
+                {
+                    throw new ArgumentOutOfRangeException(nameof(args.LabelColumnIndex), $"Label column index ({args.LabelColumnIndex}) is >= than # of inferred columns ({cols.Length}).");
+                }
+
+                return cols[args.LabelColumnIndex.Value];
+            }
+
+            var byName = cols.FirstOrDefault(c => c.Name == args.Label);
+            if (byName == null)
+            {
+                throw new ArgumentException($"Specified label column '{args.Label}' was not found.");
+            }
+
+            return byName;
+        }
+
+        // One TextLoader column per inferred column, in the same order: suggested name, raw kind of the item type, source index.
+        public static TextLoader.Column[] GenerateLoaderColumns(Column[] columns)
+        {
+            var loaderColumns = new TextLoader.Column[columns.Length];
+            for (int i = 0; i < columns.Length; i++)
+            {
+                loaderColumns[i] = new TextLoader.Column(columns[i].SuggestedName, columns[i].ItemType.GetRawKind(), columns[i].ColumnIndex);
+            }
+            return loaderColumns;
+        }
+ }
+
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs b/src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs
new file mode 100644
index 0000000000..4cdf1e5411
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/PurposeInference.cs
@@ -0,0 +1,283 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Automatic inference of column purposes for the data view.
+ /// This is used in the context of text import wizard, but can be used outside as well.
+ /// </summary>
+ internal static class PurposeInference
+ {
+ public const int MaxRowsToRead = 1000; // InferPurposes only looks at this many leading rows
+
+ public class Column // inference result for one schema column: its index and the purpose chosen for it
+ {
+ public readonly int ColumnIndex;
+ public readonly ColumnPurpose Purpose;
+
+ public Column(int columnIndex, ColumnPurpose purpose)
+ {
+ ColumnIndex = columnIndex;
+ Purpose = purpose;
+ }
+ }
+
+ /// <summary>
+ /// The design is the same as for <see cref="ColumnTypeInference"/>: there's a sequence of 'experts'
+ /// that each look at all the columns. Every expert may or may not assign the 'answer' (suggested purpose)
+ /// to a column. If the expert needs some information about the column (for example, the column values),
+ /// this information is lazily calculated by the column object, not the expert itself, to allow the reuse
+ /// of the same information by another expert.
+ /// </summary>
+ private interface IPurposeInferenceExpert
+ {
+ void Apply(IntermediateColumn[] columns);
+ }
+
+ private class IntermediateColumn // working state for one column while the experts run
+ {
+ private readonly IDataView _data;
+ private readonly int _columnId;
+ private ColumnPurpose _suggestedPurpose;
+ private readonly Lazy _type;
+ private readonly Lazy _columnName;
+ private IReadOnlyList> _cachedData; // filled by the first GetColumnData() call
+
+ public bool IsPurposeSuggested { get; private set; } // becomes true once the SuggestedPurpose setter has run
+
+ public ColumnPurpose SuggestedPurpose
+ {
+ get { return _suggestedPurpose; }
+ set
+ {
+ _suggestedPurpose = value;
+ IsPurposeSuggested = true;
+ }
+ }
+
+ public DataViewType Type { get { return _type.Value; } }
+
+ public string ColumnName { get { return _columnName.Value; } }
+
+ public IntermediateColumn(IDataView data, int columnId, ColumnPurpose suggestedPurpose = ColumnPurpose.Ignore)
+ {
+ _data = data;
+ _columnId = columnId;
+ _type = new Lazy(() => _data.Schema[_columnId].Type);
+ _columnName = new Lazy(() => _data.Schema[_columnId].Name);
+ _suggestedPurpose = suggestedPurpose; // assigned to the field directly, so IsPurposeSuggested stays false
+ }
+
+ public Column GetColumn()
+ {
+ return new Column(_columnId, _suggestedPurpose);
+ }
+
+ public IReadOnlyList> GetColumnData() // reads every value of the column once and caches the list
+ {
+ if (_cachedData != null)
+ return _cachedData;
+
+ var results = new List>();
+ var column = _data.Schema[_columnId];
+
+ using (var cursor = _data.GetRowCursor(new[] { column }))
+ {
+ var getter = cursor.GetGetter>(column);
+ while (cursor.MoveNext())
+ {
+ var value = default(ReadOnlyMemory);
+ getter(ref value);
+
+ var copy = new ReadOnlyMemory(value.ToArray()); // copy the value into its own array before caching it
+
+ results.Add(copy);
+ }
+ }
+
+ _cachedData = results;
+
+ return results;
+ }
+ }
+
+ private static class Experts
+ {
+ internal sealed class TextClassification : IPurposeInferenceExpert
+ {
+ public void Apply(IntermediateColumn[] columns)
+ {
+ string[] commonImageExtensions = { ".bmp", ".dib", ".rle", ".jpg", ".jpeg", ".jpe", ".jfif", ".gif", ".tif", ".tiff", ".png" };
+ foreach (var column in columns)
+ {
+ if (column.IsPurposeSuggested || !column.Type.IsText()) // only undecided text columns are examined
+ continue;
+
+ var data = column.GetColumnData();
+
+ long sumLength = 0;
+ int sumSpaces = 0;
+ var seen = new HashSet(); // distinct values, used for the cardinality ratio below
+ int imagePathCount = 0;
+ foreach (var span in data)
+ {
+ sumLength += span.Length;
+ seen.Add(span.ToString());
+ string spanStr = span.ToString();
+ sumSpaces += spanStr.Count(x => x == ' ');
+
+ foreach (var ext in commonImageExtensions)
+ {
+ if (spanStr.EndsWith(ext, StringComparison.OrdinalIgnoreCase))
+ {
+ imagePathCount++;
+ break;
+ }
+ }
+ }
+
+ if (imagePathCount < data.Count - 1) // more than one value lacks an image extension: use the text heuristics
+ {
+ Double avgLength = 1.0 * sumLength / data.Count;
+ Double cardinalityRatio = 1.0 * seen.Count / data.Count; // distinct values / total values
+ Double avgSpaces = 1.0 * sumSpaces / data.Count;
+ if (cardinalityRatio < 0.7)
+ column.SuggestedPurpose = ColumnPurpose.CategoricalFeature;
+ // (note: the columns.Count() == 1 condition below, in case a dataset has only
+ // a 'name' and a 'label' column, forces what would be an 'ignore' column to become a text feature)
+ else if (cardinalityRatio >= 0.85 && (avgLength > 30 || avgSpaces >= 1 || columns.Count() == 1))
+ column.SuggestedPurpose = ColumnPurpose.TextFeature;
+ else if (cardinalityRatio >= 0.9)
+ column.SuggestedPurpose = ColumnPurpose.Ignore; // any other combination leaves the column undecided here
+ }
+ else
+ column.SuggestedPurpose = ColumnPurpose.ImagePath; // at most one value lacks a known image extension
+ }
+ }
+ }
+
+ internal sealed class NumericAreFeatures : IPurposeInferenceExpert
+ {
+ public void Apply(IntermediateColumn[] columns)
+ {
+ foreach (var column in columns)
+ {
+ if (column.IsPurposeSuggested)
+ continue;
+ if (column.Type.GetItemType().IsNumber()) // item type, so numeric vectors qualify as well
+ column.SuggestedPurpose = ColumnPurpose.NumericFeature;
+ }
+ }
+ }
+
+ internal sealed class BooleanProcessing : IPurposeInferenceExpert
+ {
+ public void Apply(IntermediateColumn[] columns)
+ {
+ foreach (var column in columns)
+ {
+ if (column.IsPurposeSuggested)
+ continue;
+ if (column.Type.GetItemType().IsBool())
+ column.SuggestedPurpose = ColumnPurpose.NumericFeature; // booleans are handled as numeric features
+ }
+ }
+ }
+
+ internal sealed class TextArraysAreText : IPurposeInferenceExpert
+ {
+ public void Apply(IntermediateColumn[] columns)
+ {
+ foreach (var column in columns)
+ {
+ if (column.IsPurposeSuggested)
+ continue;
+ if (column.Type.IsVector() && column.Type.GetItemType().IsText())
+ column.SuggestedPurpose = ColumnPurpose.TextFeature;
+ }
+ }
+ }
+
+ internal sealed class IgnoreEverythingElse : IPurposeInferenceExpert
+ {
+ public void Apply(IntermediateColumn[] columns)
+ {
+ foreach (var column in columns)
+ {
+ if (!column.IsPurposeSuggested)
+ column.SuggestedPurpose = ColumnPurpose.Ignore; // also marks the column as decided
+ }
+ }
+ }
+ }
+
+ private static IEnumerable GetExperts()
+ {
+ // Each of the experts respects the decisions of all the experts above.
+
+ // Single-value text columns may be category, name, text or ignore.
+ yield return new Experts.TextClassification();
+ // Vector-value text columns are always treated as text.
+ // REVIEW: could be improved.
+ yield return new Experts.TextArraysAreText();
+ // Check column on boolean only values.
+ yield return new Experts.BooleanProcessing();
+ // All numeric columns are features.
+ yield return new Experts.NumericAreFeatures();
+ // Everything else is ignored.
+ yield return new Experts.IgnoreEverythingElse();
+ }
+
+ /// <summary>
+ /// Auto-detect purpose for the data view columns.
+ /// </summary>
+ public static PurposeInference.Column[] InferPurposes(MLContext context, IDataView data,
+ ColumnInformation columnInfo)
+ {
+ data = context.Data.TakeRows(data, MaxRowsToRead);
+
+ var allColumns = new List(); // every schema column, in schema order; this is what gets returned
+ var columnsToInfer = new List(); // only the columns the experts are allowed to decide
+
+ for (var i = 0; i < data.Schema.Count; i++)
+ {
+ var column = data.Schema[i];
+ IntermediateColumn intermediateCol;
+
+ if (column.IsHidden) // hidden columns are reported as Ignore and never shown to the experts
+ {
+ intermediateCol = new IntermediateColumn(data, i, ColumnPurpose.Ignore);
+ allColumns.Add(intermediateCol);
+ continue;
+ }
+
+ var columnPurpose = columnInfo.GetColumnPurpose(column.Name);
+ if (columnPurpose == null) // no purpose supplied through columnInfo: leave it to the experts
+ {
+ intermediateCol = new IntermediateColumn(data, i);
+ columnsToInfer.Add(intermediateCol);
+ }
+ else
+ {
+ intermediateCol = new IntermediateColumn(data, i, columnPurpose.Value); // supplied purpose is kept as is
+ }
+
+ allColumns.Add(intermediateCol);
+ }
+
+ foreach (var expert in GetExperts())
+ {
+ expert.Apply(columnsToInfer.ToArray());
+ }
+
+ return allColumns.Select(c => c.GetColumn()).ToArray();
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs b/src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs
new file mode 100644
index 0000000000..9fed93d27a
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/TextFileContents.cs
@@ -0,0 +1,124 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Utilities for various heuristics against text files.
+ /// Currently, separator inference and column count detection.
+ /// </summary>
+ internal static class TextFileContents
+ {
+ public class ColumnSplitResult // outcome of one separator / quoting / sparsity detection attempt
+ {
+ public readonly bool IsSuccess;
+ public readonly char? Separator;
+ public readonly int ColumnCount;
+
+ public bool AllowQuote { get; set; }
+ public bool AllowSparse { get; set; }
+
+ public ColumnSplitResult(bool isSuccess, char? separator, bool allowQuote, bool allowSparse, int columnCount)
+ {
+ IsSuccess = isSuccess;
+ Separator = separator;
+ AllowQuote = allowQuote;
+ AllowSparse = allowSparse;
+ ColumnCount = columnCount;
+ }
+ }
+
+ // If the fraction of lines having the same number of columns exceeds this, we consider the column count to be known.
+ private const Double UniformColumnCountThreshold = 0.98;
+
+ public static readonly char[] DefaultSeparators = { '\t', ',', ' ', ';' };
+
+ /// <summary>
+ /// Attempt to detect text loader arguments.
+ /// The algorithm selects the first 'acceptable' set: the one that recognizes the same number of columns in at
+ /// least <see cref="UniformColumnCountThreshold"/> of the sample's lines,
+ /// and this number of columns is more than 1.
+ /// We sweep on separator, allow sparse and allow quote parameter.
+ /// </summary>
+ public static ColumnSplitResult TrySplitColumns(MLContext context, IMultiStreamSource source, char[] separatorCandidates)
+ {
+ var sparse = new[] { false, true }; // non-sparse parsing is tried first
+ var quote = new[] { true, false }; // quoting enabled is tried first
+ var foundAny = false;
+ var result = default(ColumnSplitResult);
+ foreach (var perm in (from _allowSparse in sparse
+ from _allowQuote in quote
+ from _sep in separatorCandidates // innermost, so the separator varies fastest
+ select new { _allowSparse, _allowQuote, _sep }))
+ {
+ var options = new TextLoader.Options
+ {
+ Columns = new[] { new TextLoader.Column() {
+ Name = "C",
+ DataKind = DataKind.String,
+ Source = new[] { new TextLoader.Range(0, null) } // starts at field 0, no explicit upper bound given
+ } },
+ Separators = new[] { perm._sep },
+ AllowQuoting = perm._allowQuote,
+ AllowSparse = perm._allowSparse
+ };
+
+ if (TryParseFile(context, options, source, out result))
+ {
+ foundAny = true;
+ break;
+ }
+ }
+ return foundAny ? result : new ColumnSplitResult(false, null, true, true, 0);
+ }
+
+ private static bool TryParseFile(MLContext context, TextLoader.Options options, IMultiStreamSource source,
+ out ColumnSplitResult result)
+ {
+ result = null;
+ // try to instantiate data view with swept arguments
+ try
+ {
+ var textLoader = context.Data.CreateTextLoader(options, source);
+ var idv = context.Data.TakeRows(textLoader.Load(source), 1000); // only the first 1000 rows are inspected
+ var columnCounts = new List();
+ var column = idv.Schema["C"];
+
+ using (var cursor = idv.GetRowCursor(new[] { column }))
+ {
+ var getter = cursor.GetGetter>>(column);
+
+ VBuffer> line = default;
+ while (cursor.MoveNext())
+ {
+ getter(ref line);
+ columnCounts.Add(line.Length); // length of the "C" vector for this row
+ }
+ }
+
+ var mostCommon = columnCounts.GroupBy(x => x).OrderByDescending(x => x.Count()).First(); // throws on an empty sample; handled by the catch below
+ if (mostCommon.Count() < UniformColumnCountThreshold * columnCounts.Count)
+ {
+ return false;
+ }
+
+ // disallow single-column case
+ if (mostCommon.Key <= 1) { return false; }
+
+ result = new ColumnSplitResult(true, options.Separators.First(), options.AllowQuoting, options.AllowSparse, mostCommon.Key);
+ return true;
+ }
+ // fail gracefully if unable to instantiate data view with swept arguments
+ catch(Exception)
+ {
+ return false;
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/ColumnInference/TextFileSample.cs b/src/Microsoft.ML.Auto/ColumnInference/TextFileSample.cs
new file mode 100644
index 0000000000..f528d5a8b8
--- /dev/null
+++ b/src/Microsoft.ML.Auto/ColumnInference/TextFileSample.cs
@@ -0,0 +1,304 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// This class holds an in-memory sample of the text file, and serves as a proxy to it.
+ /// </summary>
+ internal sealed class TextFileSample : IMultiStreamSource
+ {
+ // REVIEW: consider including multiple files via IMultiStreamSource.
+
+ // REVIEW: right now, it expects 0x0A being the trailing character of line break.
+ // Consider a more general implementation.
+
+        private const int BufferSizeMb = 4;          // size of the head buffer and of the stitched result's initial capacity, in MB
+        private const int FirstChunkSizeMb = 1;      // size of the leading chunk read by CreateFromFullStream, in MB
+        private const int LinesPerChunk = 20;        // multiplied by the average line length to size each later chunk
+        private const Double OversamplingRate = 1.1; // factor applied to BufferSizeMb when deciding how many chunks to read
+
+        private readonly byte[] _buffer;
+        private readonly long? _fullFileSize;
+        private readonly long? _approximateRowCount;
+
+        /// <summary>
+        /// Wraps a sample buffer together with what is known about the size and row count of the full file.
+        /// </summary>
+        private TextFileSample(byte[] buffer, long? fullFileSize, long? lineCount)
+        {
+            _approximateRowCount = lineCount;
+            _fullFileSize = fullFileSize;
+            _buffer = buffer;
+        }
+
+        /// <summary>Always 1.</summary>
+        public int Count => 1;
+
+        /// <summary>Full file size, if known, otherwise, null.</summary>
+        public long? FullFileSize => _fullFileSize;
+
+        /// <summary>Length of the sample buffer, in bytes.</summary>
+        public int SampleSize => _buffer.Length;
+
+        /// <summary>Row count estimate supplied at construction, or null when none was available.</summary>
+        public long? ApproximateRowCount => _approximateRowCount;
+
+        /// <summary>Always returns null.</summary>
+        public string GetPathOrNull(int index)
+        {
+            //Contracts.Check(index == 0, "Index must be 0");
+            return null;
+        }
+
+        /// <summary>Opens a new stream over the sample buffer.</summary>
+        public Stream Open(int index)
+        {
+            //Contracts.Check(index == 0, "Index must be 0");
+            return new MemoryStream(_buffer);
+        }
+
+        /// <summary>Opens a text reader over a new stream of the sample buffer.</summary>
+        public TextReader OpenTextReader(int index)
+        {
+            return new StreamReader(Open(index));
+        }
+
+        public static TextFileSample CreateFromFullFile(string path)
+        {
+            using (Stream source = File.Open(path, FileMode.Open, FileAccess.Read, FileShare.Read)) // disposed once the sample is built
+            {
+                return CreateFromFullStream(source);
+            }
+        }
+
+        /// <summary>
+        /// Create a <see cref="TextFileSample"/> by reading multiple chunks from the file (or other source) and
+        /// then stitching them together. The algorithm is as follows:
+        /// 0. If the source is not seekable, revert to <see cref="CreateFromHead"/>.
+        /// 1. If the file length is less than 2 * <see cref="BufferSizeMb"/>, revert to <see cref="CreateFromHead"/>.
+        /// 2. Read first <see cref="FirstChunkSizeMb"/> MB chunk. Determine average line length in the chunk.
+        /// 3. Determine how large one chunk should be, and how many chunks there should be, to end up
+        /// with <see cref="BufferSizeMb"/> * <see cref="OversamplingRate"/> MB worth of lines.
+        /// 4. Determine seek locations and read the chunks.
+        /// 5. Stitch and return a <see cref="TextFileSample"/>.
+        /// </summary>
+        public static TextFileSample CreateFromFullStream(Stream stream)
+        {
+            if (!stream.CanSeek)
+            {
+                return CreateFromHead(stream);
+            }
+            var fileSize = stream.Length;
+            if (fileSize <= 2 * BufferSizeMb * (1 << 20))
+            {
+                return CreateFromHead(stream);
+            }
+            long headPosition = stream.Position;
+            var firstChunk = new byte[FirstChunkSizeMb * (1 << 20)];
+            stream.Read(firstChunk, 0, firstChunk.Length);
+            if (!IsEncodingOkForSampling(firstChunk))
+            {
+                stream.Seek(headPosition, SeekOrigin.Begin); // the first chunk was consumed above: rewind so the head is re-read
+                return CreateFromHead(stream);
+            }
+            // REVIEW: CreateFromHead still truncates the file before the last 0x0A byte. For multi-byte encoding,
+            // this might cause an unfinished string to be present in the buffer. Right now this is considered an acceptable
+            // price to pay for parse-free processing.
+
+            var lineCount = firstChunk.Count(x => x == '\n');
+            if (lineCount == 0)
+            {
+                throw new ArgumentException("Counldn't identify line breaks. Provided file is not text?");
+            }
+
+            long approximateRowCount = (long)(lineCount * fileSize * 1.0 / firstChunk.Length);
+            var firstNewline = Array.FindIndex(firstChunk, x => x == '\n');
+
+            // First line may be header, so we exclude it. The remaining lineCount-1 line breaks are
+            // splitting the text into lineCount lines, and the last line is actually half-size.
+            Double averageLineLength = 2.0 * (firstChunk.Length - firstNewline) / (lineCount * 2 - 1);
+            averageLineLength = Math.Max(averageLineLength, 3);
+
+            int usefulChunkSize = (int)(averageLineLength * LinesPerChunk);
+            int chunkSize = (int)(usefulChunkSize + averageLineLength); // assuming that 1 line worth will be trimmed out
+
+            int chunkCount = (int)Math.Ceiling((BufferSizeMb * OversamplingRate - FirstChunkSizeMb) * (1 << 20) / usefulChunkSize);
+            int maxChunkCount = (int)Math.Floor((double)(fileSize - firstChunk.Length) / chunkSize);
+            chunkCount = Math.Min(chunkCount, maxChunkCount);
+
+            var chunks = new List<byte[]> { firstChunk };
+            // determine the start of each remaining chunk
+            long fileSizeRemaining = fileSize - firstChunk.Length - ((long)chunkSize) * chunkCount;
+
+            var chunkStartIndices = Enumerable.Range(0, chunkCount)
+                .Select(x => AutoMlUtils.random.Value.NextDouble() * fileSizeRemaining)
+                .OrderBy(x => x)
+                .Select((spot, i) => (long)(spot + firstChunk.Length + i * chunkSize))
+                .ToArray();
+
+            foreach (var chunkStartIndex in chunkStartIndices)
+            {
+                stream.Seek(chunkStartIndex, SeekOrigin.Begin);
+                byte[] chunk = new byte[chunkSize];
+                int readCount = stream.Read(chunk, 0, chunkSize);
+                Array.Resize(ref chunk, readCount); // keep only the bytes actually read, so a short read leaves no zero padding
+                chunks.Add(chunk);
+            }
+
+            return new TextFileSample(StitchChunks(false, chunks.ToArray()), fileSize, approximateRowCount);
+        }
+
+        /// <summary>
+        /// Create a <see cref="TextFileSample"/> by reading one chunk of at most <see cref="BufferSizeMb"/> MB from the stream's current position.
+        /// </summary>
+        private static TextFileSample CreateFromHead(Stream stream)
+        {
+            var buf = new byte[BufferSizeMb * (1 << 20)];
+            int readCount = stream.Read(buf, 0, buf.Length);
+            Array.Resize(ref buf, readCount);
+            // Length throws on non-seekable streams, so it is only queried when seeking is supported; without it, or with an empty sample, no row count is estimated.
+            long? fileSize = stream.CanSeek ? (long?)stream.Length : null;
+            long? rowCount = (fileSize.HasValue && readCount > 0) ? (long?)(buf.Count(x => x == '\n') * (fileSize.Value / readCount)) : null;
+            return new TextFileSample(StitchChunks(readCount == fileSize, buf), fileSize, rowCount);
+        }
+
+        /// <summary>
+        /// Given an array of chunks of the text file, of which the first chunk is the head,
+        /// this method trims incomplete lines from the beginning and end of each chunk
+        /// (except that it doesn't trim the beginning of the first chunk and end of last chunk if we read whole file),
+        /// then joins the rest together to form a final byte buffer and returns that buffer.
+        /// </summary>
+        /// <param name="wholeFile">did we read whole file</param>
+        /// <param name="chunks">chunks of data</param>
+        /// <returns>the stitched buffer</returns>
+        /// <exception cref="ArgumentException">nothing is left once the chunks have been trimmed</exception>
+        private static byte[] StitchChunks(bool wholeFile, params byte[][] chunks)
+        {
+            using (var stitched = new MemoryStream(BufferSizeMb * (1 << 20)))
+            {
+                int lastChunk = chunks.Length - 1;
+                for (int c = 0; c <= lastChunk; c++)
+                {
+                    byte[] chunk = chunks[c];
+                    // The head chunk is kept from its first byte; any other chunk starts right after its first line break.
+                    int start = (c == 0) ? 0 : Array.FindIndex(chunk, b => b == '\n') + 1;
+                    // End just past the last line break, unless this is the final chunk of a file that was read completely.
+                    int end = (wholeFile && c == lastChunk) ? chunk.Length : Array.FindLastIndex(chunk, b => b == '\n') + 1;
+                    if (end == 0)
+                    {
+                        //entire buffer is one string, skip
+                        continue;
+                    }
+
+                    stitched.Write(chunk, start, end - start);
+                }
+
+                byte[] resultBuffer = stitched.ToArray();
+                if (resultBuffer.Length != 0)
+                {
+                    return resultBuffer;
+                }
+                throw new ArgumentException("File is not text, or couldn't detect line breaks");
+            }
+        }
+
+        /// <summary>
+        /// Detect whether we can auto-detect EOL characters without parsing.
+        /// If we do, we can cheaply sample from different file locations and trim the partial strings.
+        /// The encodings that pass the test are UTF8 and all single-byte encodings.
+        /// </summary>
+        private static bool IsEncodingOkForSampling(byte[] buffer)
+        {
+            // First check if a BOM/signature exists (sourced from https://www.unicode.org/faq/utf_bom.html#bom4)
+            if (buffer.Length >= 4 && buffer[0] == 0x00 && buffer[1] == 0x00 && buffer[2] == 0xFE && buffer[3] == 0xFF)
+            {
+                // UTF-32, big-endian
+                return false;
+            }
+            if (buffer.Length >= 4 && buffer[0] == 0xFF && buffer[1] == 0xFE && buffer[2] == 0x00 && buffer[3] == 0x00)
+            {
+                // UTF-32, little-endian
+                return false;
+            }
+            if (buffer.Length >= 2 && buffer[0] == 0xFE && buffer[1] == 0xFF)
+            {
+                // UTF-16, big-endian
+                return false;
+            }
+            if (buffer.Length >= 2 && buffer[0] == 0xFF && buffer[1] == 0xFE)
+            {
+                // UTF-16, little-endian
+                return false;
+            }
+            if (buffer.Length >= 3 && buffer[0] == 0xEF && buffer[1] == 0xBB && buffer[2] == 0xBF)
+            {
+                // UTF-8
+                return true;
+            }
+            if (buffer.Length >= 3 && buffer[0] == 0x2b && buffer[1] == 0x2f && buffer[2] == 0x76)
+            {
+                // UTF-7
+                return true;
+            }
+
+            // No BOM/signature was found, so now we need to 'sniff' the file to see if can manually discover the encoding.
+            int sniffLim = Math.Min(1000, buffer.Length);
+
+            // Some text files are encoded in UTF8, but have no BOM/signature. Hence the below manually checks for a UTF8 pattern. This code is based off
+            // the top answer at: https://stackoverflow.com/questions/6555015/check-for-invalid-utf8 .
+            int i = 0;
+            bool utf8 = false;
+            while (i < sniffLim - 4)
+            {
+                if (buffer[i] <= 0x7F)
+                {
+                    i += 1;
+                    continue;
+                }
+                if (buffer[i] >= 0xC2 && buffer[i] <= 0xDF && buffer[i + 1] >= 0x80 && buffer[i + 1] < 0xC0)
+                {
+                    i += 2;
+                    utf8 = true;
+                    continue;
+                }
+                if (buffer[i] >= 0xE0 && buffer[i] <= 0xEF && buffer[i + 1] >= 0x80 && buffer[i + 1] < 0xC0 && // 3-byte leads stop at 0xEF; 0xF0 starts a 4-byte sequence
+                    buffer[i + 2] >= 0x80 && buffer[i + 2] < 0xC0)
+                {
+                    i += 3;
+                    utf8 = true;
+                    continue;
+                }
+                if (buffer[i] >= 0xF0 && buffer[i] <= 0xF4 && buffer[i + 1] >= 0x80 && buffer[i + 1] < 0xC0 &&
+                    buffer[i + 2] >= 0x80 && buffer[i + 2] < 0xC0 && buffer[i + 3] >= 0x80 && buffer[i + 3] < 0xC0)
+                {
+                    i += 4;
+                    utf8 = true;
+                    continue;
+                }
+                utf8 = false; // a byte that fits no UTF-8 pattern ends the sniff
+                break;
+            }
+            if (utf8)
+            {
+                return true;
+            }
+
+            if (buffer.Take(sniffLim).Any(x => x == 0))
+            {
+                // likely a UTF-16 or UTF-32 without a BOM.
+                return false;
+            }
+
+            // If all else failed, the file is likely in a local 1-byte encoding.
+            return true;
+        }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/DatasetDimensions/ColumnDimensions.cs b/src/Microsoft.ML.Auto/DatasetDimensions/ColumnDimensions.cs
new file mode 100644
index 0000000000..78283ac6c5
--- /dev/null
+++ b/src/Microsoft.ML.Auto/DatasetDimensions/ColumnDimensions.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+    internal class ColumnDimensions
+    {
+        public int? Cardinality; // null when no cardinality was supplied for the column
+        public bool? HasMissing; // null when no missing-value information was supplied for the column
+
+        public ColumnDimensions(int? cardinality, bool? hasMissing)
+        {
+            HasMissing = hasMissing;
+            Cardinality = cardinality;
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs
new file mode 100644
index 0000000000..8d18b5057b
--- /dev/null
+++ b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsApi.cs
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+    internal class DatasetDimensionsApi
+    {
+        private const long MaxRowsToRead = 1000;
+
+        /// <summary>
+        /// Computes per-column dimensions (text cardinality, presence of missing numeric values)
+        /// over at most <see cref="MaxRowsToRead"/> leading rows of <paramref name="data"/>.
+        /// </summary>
+        /// <returns>One entry per schema column; a member stays null when it was not computed for that column.</returns>
+        public static ColumnDimensions[] CalcColumnDimensions(MLContext context, IDataView data, PurposeInference.Column[] purposes)
+        {
+            var sample = context.Data.TakeRows(data, MaxRowsToRead);
+            var result = new ColumnDimensions[sample.Schema.Count];
+
+            for (var index = 0; index < sample.Schema.Count; index++)
+            {
+                var column = sample.Schema[index];
+                var purpose = purposes[index];
+                var itemType = column.Type.GetItemType();
+
+                // Cardinality is only measured for text columns whose purpose is CategoricalFeature.
+                int? cardinality = null;
+                if (itemType.IsText() && purpose.Purpose == ColumnPurpose.CategoricalFeature)
+                {
+                    cardinality = DatasetDimensionsUtil.GetTextColumnCardinality(sample, column);
+                }
+
+                // Missing values are only looked for in columns whose item type is Single, scalar or vector.
+                bool? hasMissing = null;
+                if (itemType == NumberDataViewType.Single)
+                {
+                    hasMissing = column.Type.IsVector()
+                        ? DatasetDimensionsUtil.HasMissingNumericVector(sample, column)
+                        : DatasetDimensionsUtil.HasMissingNumericSingleValue(sample, column);
+                }
+
+                result[index] = new ColumnDimensions(cardinality, hasMissing);
+            }
+            return result;
+        }
+    }
+}
diff --git a/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs
new file mode 100644
index 0000000000..c0dea14fbb
--- /dev/null
+++ b/src/Microsoft.ML.Auto/DatasetDimensions/DatasetDimensionsUtil.cs
@@ -0,0 +1,85 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal static class DatasetDimensionsUtil
+ {
+ public static int GetTextColumnCardinality(IDataView data, DataViewSchema.Column column) // number of distinct values in the column
+ {
+ var seen = new HashSet();
+ using (var cursor = data.GetRowCursor(new[] { column }))
+ {
+ var getter = cursor.GetGetter>(column);
+ while (cursor.MoveNext())
+ {
+ var value = default(ReadOnlyMemory);
+ getter(ref value);
+ var valueStr = value.ToString(); // values are compared by their string form
+ seen.Add(valueStr);
+ }
+ }
+ return seen.Count;
+ }
+
+ public static bool HasMissingNumericSingleValue(IDataView data, DataViewSchema.Column column) // true if any row holds NaN
+ {
+ using (var cursor = data.GetRowCursor(new[] { column }))
+ {
+ var getter = cursor.GetGetter(column);
+ var value = default(Single);
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ if (Single.IsNaN(value)) // NaN is what counts as a missing value here
+ {
+ return true; // stop at the first missing value
+ }
+ }
+ return false;
+ }
+ }
+
+ public static bool HasMissingNumericVector(IDataView data, DataViewSchema.Column column) // vector counterpart of HasMissingNumericSingleValue
+ {
+ using (var cursor = data.GetRowCursor(new[] { column }))
+ {
+ var getter = cursor.GetGetter>(column);
+ var value = default(VBuffer);
+ while (cursor.MoveNext())
+ {
+ getter(ref value);
+ if (VBufferUtils.HasNaNs(value)) // any NaN in the row's vector counts as missing
+ {
+ return true; // stop at the first row with a missing value
+ }
+ }
+ return false;
+ }
+ }
+
+        public static ulong CountRows(IDataView data, ulong maxRows)
+        {
+            // Counts rows by advancing a cursor over the first schema column; the cursor is disposed when counting ends.
+            using (var cursor = data.GetRowCursor(new[] { data.Schema[0] }))
+            {
+                ulong rowCount = 0;
+                while (cursor.MoveNext())
+                {
+                    if (++rowCount == maxRows) { break; } // stop early once maxRows rows have been seen
+                }
+                return rowCount;
+            }
+        }
+
+        public static bool IsDataViewEmpty(IDataView data)
+        {
+            return CountRows(data, maxRows: 1) < 1; // counting stops after the first row: only its existence matters
+        }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/DebugLogger.cs b/src/Microsoft.ML.Auto/DebugLogger.cs
new file mode 100644
index 0000000000..90ed9cfdd2
--- /dev/null
+++ b/src/Microsoft.ML.Auto/DebugLogger.cs
@@ -0,0 +1,17 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+ internal interface IDebugLogger // receives log messages, each tagged with a LogSeverity
+ {
+ void Log(LogSeverity logLevel, string message);
+ }
+
+    internal enum LogSeverity
+    {
+        Error = 0,
+        Debug = 1
+    }
+}
diff --git a/src/Microsoft.ML.Auto/EstimatorExtensions/EstimatorExtensionCatalog.cs b/src/Microsoft.ML.Auto/EstimatorExtensions/EstimatorExtensionCatalog.cs
new file mode 100644
index 0000000000..e37acd42ac
--- /dev/null
+++ b/src/Microsoft.ML.Auto/EstimatorExtensions/EstimatorExtensionCatalog.cs
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Auto
+{
+    internal enum EstimatorName // member names are used verbatim (via ToString()) as pipeline node names
+    {
+        ColumnConcatenating = 0,
+        ColumnCopying = 1,
+        KeyToValueMapping = 2,
+        MissingValueIndicating = 3,
+        MissingValueReplacing = 4,
+        Normalizing = 5,
+        OneHotEncoding = 6,
+        OneHotHashEncoding = 7,
+        TextFeaturizing = 8,
+        TypeConverting = 9,
+        ValueToKeyMapping = 10
+    }
+
+ internal class EstimatorExtensionCatalog // looks up and instantiates the extension class registered for an EstimatorName
+ {
+ private static readonly IDictionary _namesToExtensionTypes = new
+ Dictionary()
+ {
+ { EstimatorName.ColumnConcatenating, typeof(ColumnConcatenatingExtension) },
+ { EstimatorName.ColumnCopying, typeof(ColumnCopyingExtension) },
+ { EstimatorName.KeyToValueMapping, typeof(KeyToValueMappingExtension) },
+ { EstimatorName.MissingValueIndicating, typeof(MissingValueIndicatingExtension) },
+ { EstimatorName.MissingValueReplacing, typeof(MissingValueReplacingExtension) },
+ { EstimatorName.Normalizing, typeof(NormalizingExtension) },
+ { EstimatorName.OneHotEncoding, typeof(OneHotEncodingExtension) },
+ { EstimatorName.OneHotHashEncoding, typeof(OneHotHashEncodingExtension) },
+ { EstimatorName.TextFeaturizing, typeof(TextFeaturizingExtension) },
+ { EstimatorName.TypeConverting, typeof(TypeConvertingExtension) },
+ { EstimatorName.ValueToKeyMapping, typeof(ValueToKeyMappingExtension) },
+ };
+
+ public static IEstimatorExtension GetExtension(EstimatorName estimatorName)
+ {
+ var extType = _namesToExtensionTypes[estimatorName]; // the indexer throws for a name that is not registered above
+ return (IEstimatorExtension)Activator.CreateInstance(extType); // a new instance is created on every call
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/EstimatorExtensions/EstimatorExtensions.cs b/src/Microsoft.ML.Auto/EstimatorExtensions/EstimatorExtensions.cs
new file mode 100644
index 0000000000..573e24d932
--- /dev/null
+++ b/src/Microsoft.ML.Auto/EstimatorExtensions/EstimatorExtensions.cs
@@ -0,0 +1,272 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+using Microsoft.ML.Transforms;
+
+namespace Microsoft.ML.Auto
+{
+ internal class ColumnConcatenatingExtension : IEstimatorExtension // Builds the Concatenate transform estimator.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from a pipeline node; only the first output column is used.
+ {
+ return CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns[0]);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string outColumn) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.ColumnConcatenating.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumns, outColumn);
+ var estimator = CreateInstance(context, inColumns, outColumn);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string[] inColumns, string outColumn) // Concatenates inColumns into the single column outColumn.
+ {
+ return context.Transforms.Concatenate(outColumn, inColumns);
+ }
+ }
+
+ internal class ColumnCopyingExtension : IEstimatorExtension // Builds the CopyColumns transform estimator.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from a pipeline node, using its first input and first output column.
+ {
+ return CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.ColumnCopying.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumn, outColumn);
+ var estimator = CreateInstance(context, inColumn, outColumn);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string inColumn, string outColumn) // Output column name is passed first, then the input.
+ {
+ return context.Transforms.CopyColumns(outColumn, inColumn);
+ }
+ }
+
+ internal class KeyToValueMappingExtension : IEstimatorExtension // Builds the MapKeyToValue conversion estimator.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from a pipeline node, using its first input and first output column.
+ {
+ return CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.KeyToValueMapping.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumn, outColumn);
+ var estimator = CreateInstance(context, inColumn, outColumn);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string inColumn, string outColumn) // Output column name is passed first, then the input.
+ {
+ return context.Transforms.Conversion.MapKeyToValue(outColumn, inColumn);
+ }
+ }
+
+ internal class MissingValueIndicatingExtension : IEstimatorExtension // Builds the IndicateMissingValues transform estimator over one or more column pairs.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from all of the node's input and output columns.
+ {
+ return CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.MissingValueIndicating.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumns, outColumns);
+ var estimator = CreateInstance(context, inColumns, outColumns);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string[] inColumns, string[] outColumns) // Pairs columns by index; outColumns needs at least inColumns.Length entries.
+ {
+ var pairs = new InputOutputColumnPair[inColumns.Length];
+ for (var i = 0; i < inColumns.Length; i++)
+ {
+ var pair = new InputOutputColumnPair(outColumns[i], inColumns[i]); // Output name first, then input name.
+ pairs[i] = pair;
+ }
+ return context.Transforms.IndicateMissingValues(pairs);
+ }
+ }
+
+ internal class MissingValueReplacingExtension : IEstimatorExtension // Builds the ReplaceMissingValues transform estimator over one or more column pairs.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from all of the node's input and output columns.
+ {
+ return CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.MissingValueReplacing.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumns, outColumns);
+ var estimator = CreateInstance(context, inColumns, outColumns);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string[] inColumns, string[] outColumns) // Pairs columns by index; outColumns needs at least inColumns.Length entries.
+ {
+ var pairs = new InputOutputColumnPair[inColumns.Length];
+ for (var i = 0; i < inColumns.Length; i++)
+ {
+ var pair = new InputOutputColumnPair(outColumns[i], inColumns[i]); // Output name first, then input name.
+ pairs[i] = pair;
+ }
+ return context.Transforms.ReplaceMissingValues(pairs); // No replacement mode is passed, so the API's default applies.
+ }
+ }
+
+ internal class NormalizingExtension : IEstimatorExtension // Builds the NormalizeMinMax transform estimator.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from a pipeline node, using its first input and first output column.
+ {
+ return CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.Normalizing.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumn, outColumn);
+ var estimator = CreateInstance(context, inColumn, outColumn);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string inColumn, string outColumn) // Output column name is passed first, then the input.
+ {
+ return context.Transforms.NormalizeMinMax(outColumn, inColumn);
+ }
+ }
+
+ internal class OneHotEncodingExtension : IEstimatorExtension // Builds the categorical OneHotEncoding estimator over one or more column pairs.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from all of the node's input and output columns.
+ {
+ return CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.OneHotEncoding.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumns, outColumns);
+ var estimator = CreateInstance(context, inColumns, outColumns);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ public static IEstimator CreateInstance(MLContext context, string[] inColumns, string[] outColumns) // Public here, unlike the private helpers of the sibling extensions in this file.
+ {
+ var cols = new InputOutputColumnPair[inColumns.Length];
+ for (var i = 0; i < cols.Length; i++) // Pairs columns by index; outColumns needs at least inColumns.Length entries.
+ {
+ cols[i] = new InputOutputColumnPair(outColumns[i], inColumns[i]); // Output name first, then input name.
+ }
+ return context.Transforms.Categorical.OneHotEncoding(cols);
+ }
+ }
+
+ internal class OneHotHashEncodingExtension : IEstimatorExtension // Builds the categorical OneHotHashEncoding estimator over one or more column pairs.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from all of the node's input and output columns.
+ {
+ return CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn) // Single-column convenience overload.
+ {
+ return CreateSuggestedTransform(context, new[] { inColumn }, new[] { outColumn }); // Wraps each name in a one-element array and delegates.
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.OneHotHashEncoding.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumns, outColumns);
+ var estimator = CreateInstance(context, inColumns, outColumns);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string[] inColumns, string[] outColumns) // Pairs columns by index; outColumns needs at least inColumns.Length entries.
+ {
+ var cols = new InputOutputColumnPair[inColumns.Length];
+ for (var i = 0; i < cols.Length; i++)
+ {
+ cols[i] = new InputOutputColumnPair(outColumns[i], inColumns[i]); // Output name first, then input name.
+ }
+ return context.Transforms.Categorical.OneHotHashEncoding(cols);
+ }
+ }
+
+ internal class TextFeaturizingExtension : IEstimatorExtension // Builds the FeaturizeText transform estimator.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from a pipeline node, using its first input and first output column.
+ {
+ return CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.TextFeaturizing.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumn, outColumn);
+ var estimator = CreateInstance(context, inColumn, outColumn);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string inColumn, string outColumn) // Output column name is passed first, then the input.
+ {
+ return context.Transforms.Text.FeaturizeText(outColumn, inColumn);
+ }
+ }
+
+ internal class TypeConvertingExtension : IEstimatorExtension // Builds the ConvertType conversion estimator over one or more column pairs.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from all of the node's input and output columns.
+ {
+ return CreateInstance(context, pipelineNode.InColumns, pipelineNode.OutColumns);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string[] inColumns, string[] outColumns) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.TypeConverting.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumns, outColumns);
+ var estimator = CreateInstance(context, inColumns, outColumns);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string[] inColumns, string[] outColumns) // Pairs columns by index; outColumns needs at least inColumns.Length entries.
+ {
+ var cols = new InputOutputColumnPair[inColumns.Length];
+ for (var i = 0; i < cols.Length; i++)
+ {
+ cols[i] = new InputOutputColumnPair(outColumns[i], inColumns[i]); // Output name first, then input name.
+ }
+ return context.Transforms.Conversion.ConvertType(cols); // No target type is passed, so the API's default output kind applies.
+ }
+ }
+
+ internal class ValueToKeyMappingExtension : IEstimatorExtension // Builds the MapValueToKey conversion estimator.
+ {
+ public IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode) // Rebuilds the estimator from a pipeline node, using its first input and first output column.
+ {
+ return CreateInstance(context, pipelineNode.InColumns[0], pipelineNode.OutColumns[0]);
+ }
+
+ public static SuggestedTransform CreateSuggestedTransform(MLContext context, string inColumn, string outColumn) // Bundles a PipelineNode describing the transform with the estimator itself.
+ {
+ var pipelineNode = new PipelineNode(EstimatorName.ValueToKeyMapping.ToString(), // The node name is the enum member's name.
+ PipelineNodeType.Transform, inColumn, outColumn);
+ var estimator = CreateInstance(context, inColumn, outColumn);
+ return new SuggestedTransform(pipelineNode, estimator);
+ }
+
+ private static IEstimator CreateInstance(MLContext context, string inColumn, string outColumn) // Output column name is passed first, then the input.
+ {
+ return context.Transforms.Conversion.MapValueToKey(outColumn, inColumn);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/EstimatorExtensions/IEstimatorExtension.cs b/src/Microsoft.ML.Auto/EstimatorExtensions/IEstimatorExtension.cs
new file mode 100644
index 0000000000..9701fc5a15
--- /dev/null
+++ b/src/Microsoft.ML.Auto/EstimatorExtensions/IEstimatorExtension.cs
@@ -0,0 +1,11 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+ internal interface IEstimatorExtension // Contract for objects that can build an estimator from a pipeline node description.
+ {
+ IEstimator CreateInstance(MLContext context, PipelineNode pipelineNode); // Creates the estimator described by pipelineNode using the given context.
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/Experiment.cs b/src/Microsoft.ML.Auto/Experiment/Experiment.cs
new file mode 100644
index 0000000000..4eb389ab79
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/Experiment.cs
@@ -0,0 +1,150 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+
+namespace Microsoft.ML.Auto
+{
+ internal class Experiment where TRunDetail : RunDetail // Runs the pipeline search loop: suggest a pipeline, run it, record and report the result.
+ {
+ private readonly MLContext _context;
+ private readonly OptimizingMetricInfo _optimizingMetricInfo;
+ private readonly TaskKind _task;
+ private readonly IProgress _progressCallback;
+ private readonly ExperimentSettings _experimentSettings;
+ private readonly IMetricsAgent _metricsAgent;
+ private readonly IEnumerable _trainerWhitelist;
+ private readonly DirectoryInfo _modelDirectory; // Null when no cache directory is configured.
+ private readonly DatasetColumnInfo[] _datasetColumnInfo;
+ private readonly IRunner _runner;
+ private readonly IList _history = new List(); // One entry per pipeline run so far; passed back to the suggester.
+
+
+ public Experiment(MLContext context,
+ TaskKind task,
+ OptimizingMetricInfo metricInfo,
+ IProgress progressCallback,
+ ExperimentSettings experimentSettings,
+ IMetricsAgent metricsAgent,
+ IEnumerable trainerWhitelist,
+ DatasetColumnInfo[] datasetColumnInfo,
+ IRunner runner)
+ {
+ _context = context;
+ _optimizingMetricInfo = metricInfo;
+ _task = task;
+ _progressCallback = progressCallback;
+ _experimentSettings = experimentSettings;
+ _metricsAgent = metricsAgent;
+ _trainerWhitelist = trainerWhitelist;
+ _modelDirectory = GetModelDirectory(_experimentSettings.CacheDirectory); // May create a directory on disk as a side effect.
+ _datasetColumnInfo = datasetColumnInfo;
+ _runner = runner;
+ }
+
+ public IList Execute() // Iterates until a stop condition is met; returns one run detail per pipeline that was run.
+ {
+ var stopwatch = Stopwatch.StartNew();
+ var iterationResults = new List();
+
+ do // do-while: the first iteration always runs, even if a limit is already exceeded.
+ {
+ var iterationStopwatch = Stopwatch.StartNew();
+
+ // get next pipeline
+ var getPipelineStopwatch = Stopwatch.StartNew();
+ var pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, _datasetColumnInfo, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist, _experimentSettings.CacheBeforeTrainer);
+ var pipelineInferenceTimeInSeconds = getPipelineStopwatch.Elapsed.TotalSeconds; // Snapshot taken before the run starts.
+
+ // break if no candidates returned, means no valid pipeline available
+ if (pipeline == null)
+ {
+ break;
+ }
+
+ // evaluate pipeline
+ Log(LogSeverity.Debug, $"Evaluating pipeline {pipeline.ToString()}");
+ (SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
+ = _runner.Run(pipeline, _modelDirectory, _history.Count + 1); // Iteration numbers are 1-based.
+ _history.Add(suggestedPipelineRunDetail);
+ WriteIterationLog(pipeline, suggestedPipelineRunDetail, iterationStopwatch);
+
+ runDetail.RuntimeInSeconds = iterationStopwatch.Elapsed.TotalSeconds;
+ runDetail.PipelineInferenceTimeInSeconds = pipelineInferenceTimeInSeconds; // Use the snapshot: the stopwatch is still running and would also count the run above.
+
+ ReportProgress(runDetail);
+ iterationResults.Add(runDetail);
+
+ // if model is perfect, break
+ if (_metricsAgent.IsModelPerfect(suggestedPipelineRunDetail.Score))
+ {
+ break;
+ }
+
+ } while (_history.Count < _experimentSettings.MaxModels &&
+ !_experimentSettings.CancellationToken.IsCancellationRequested &&
+ stopwatch.Elapsed.TotalSeconds < _experimentSettings.MaxExperimentTimeInSeconds);
+
+ return iterationResults;
+ }
+
+ private static DirectoryInfo GetModelDirectory(DirectoryInfo rootDir) // Returns null for a null root; otherwise creates and returns the first unused "experimentN" subdirectory.
+ {
+ if (rootDir == null)
+ {
+ return null;
+ }
+ var subdirs = rootDir.Exists ?
+ new HashSet(rootDir.EnumerateDirectories().Select(d => d.Name)) :
+ new HashSet();
+ string experimentDir;
+ for (var i = 0; ; i++) // Probe experiment0, experiment1, ... until a name is free.
+ {
+ experimentDir = $"experiment{i}";
+ if (!subdirs.Contains(experimentDir))
+ {
+ break;
+ }
+ }
+ var experimentDirFullPath = Path.Combine(rootDir.FullName, experimentDir);
+ var experimentDirInfo = new DirectoryInfo(experimentDirFullPath);
+ if (!experimentDirInfo.Exists)
+ {
+ experimentDirInfo.Create();
+ }
+ return experimentDirInfo;
+ }
+
+ private void ReportProgress(TRunDetail iterationResult) // Invokes the optional progress callback without letting its failures escape.
+ {
+ try
+ {
+ _progressCallback?.Report(iterationResult);
+ }
+ catch (Exception ex) // Deliberate best-effort: a throwing callback is logged, not rethrown.
+ {
+ Log(LogSeverity.Error, $"Progress report callback reported exception {ex}");
+ }
+ }
+
+ private void WriteIterationLog(SuggestedPipeline pipeline, SuggestedPipelineRunDetail runResult, Stopwatch stopwatch) // Logs a tab-separated line: iteration count, score, elapsed time, pipeline.
+ {
+ Log(LogSeverity.Debug, $"{_history.Count}\t{runResult.Score}\t{stopwatch.Elapsed}\t{pipeline.ToString()}");
+ }
+
+ private void Log(LogSeverity severity, string message) // No-op when no debug logger is configured.
+ {
+ if(_experimentSettings?.DebugLogger == null)
+ {
+ return;
+ }
+
+ _experimentSettings.DebugLogger.Log(severity, message);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/MetricsAgents/BinaryMetricsAgent.cs b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/BinaryMetricsAgent.cs
new file mode 100644
index 0000000000..4b23abd0cd
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/BinaryMetricsAgent.cs
@@ -0,0 +1,86 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal class BinaryMetricsAgent : IMetricsAgent // Metrics agent for binary classification, bound to one optimizing metric.
+ {
+ private readonly MLContext _mlContext;
+ private readonly BinaryClassificationMetric _optimizingMetric;
+
+ public BinaryMetricsAgent(MLContext mlContext,
+ BinaryClassificationMetric optimizingMetric)
+ {
+ _mlContext = mlContext;
+ _optimizingMetric = optimizingMetric;
+ }
+
+ public double GetScore(BinaryClassificationMetrics metrics) // Returns the optimizing metric's value, or NaN when metrics is null.
+ {
+ if (metrics == null)
+ {
+ return double.NaN;
+ }
+
+ switch (_optimizingMetric)
+ {
+ case BinaryClassificationMetric.Accuracy:
+ return metrics.Accuracy;
+ case BinaryClassificationMetric.AreaUnderRocCurve:
+ return metrics.AreaUnderRocCurve;
+ case BinaryClassificationMetric.AreaUnderPrecisionRecallCurve:
+ return metrics.AreaUnderPrecisionRecallCurve;
+ case BinaryClassificationMetric.F1Score:
+ return metrics.F1Score;
+ case BinaryClassificationMetric.NegativePrecision:
+ return metrics.NegativePrecision;
+ case BinaryClassificationMetric.NegativeRecall:
+ return metrics.NegativeRecall;
+ case BinaryClassificationMetric.PositivePrecision:
+ return metrics.PositivePrecision;
+ case BinaryClassificationMetric.PositiveRecall:
+ return metrics.PositiveRecall;
+ default:
+ throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric); // NotSupportedException for any metric not listed above.
+ }
+ }
+
+ public bool IsModelPerfect(double score) // True only when score is exactly 1; NaN is never perfect.
+ {
+ if (double.IsNaN(score))
+ {
+ return false;
+ }
+
+ switch (_optimizingMetric)
+ {
+ case BinaryClassificationMetric.Accuracy:
+ return score == 1;
+ case BinaryClassificationMetric.AreaUnderRocCurve:
+ return score == 1;
+ case BinaryClassificationMetric.AreaUnderPrecisionRecallCurve:
+ return score == 1;
+ case BinaryClassificationMetric.F1Score:
+ return score == 1;
+ case BinaryClassificationMetric.NegativePrecision:
+ return score == 1;
+ case BinaryClassificationMetric.NegativeRecall:
+ return score == 1;
+ case BinaryClassificationMetric.PositivePrecision:
+ return score == 1;
+ case BinaryClassificationMetric.PositiveRecall:
+ return score == 1;
+ default:
+ throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric); // NotSupportedException for any metric not listed above.
+ }
+ }
+
+ public BinaryClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn) // Evaluates data with the non-calibrated binary evaluator.
+ {
+ return _mlContext.BinaryClassification.EvaluateNonCalibrated(data, labelColumn);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/MetricsAgents/IMetricsAgent.cs b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/IMetricsAgent.cs
new file mode 100644
index 0000000000..d1605aac18
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/IMetricsAgent.cs
@@ -0,0 +1,15 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+ internal interface IMetricsAgent // Task-specific adapter for computing metrics and scoring them on the optimizing metric.
+ {
+ double GetScore(T metrics); // Extracts the optimizing metric's value from a metrics object.
+
+ bool IsModelPerfect(double score); // Reports whether score is the best attainable value for the optimizing metric.
+
+ T EvaluateMetrics(IDataView data, string labelColumn); // Computes metrics for data using the named label column.
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/MetricsAgents/MetricsAgentUtil.cs b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/MetricsAgentUtil.cs
new file mode 100644
index 0000000000..80d292648f
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/MetricsAgentUtil.cs
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Auto
+{
+ internal static class MetricsAgentUtil // Shared helper for the metrics agents.
+ {
+ public static NotSupportedException BuildMetricNotSupportedException(T optimizingMetric) // Builds the exception but does not throw it; callers throw the returned value.
+ {
+ return new NotSupportedException($"{optimizingMetric} is not a supported sweep metric");
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/MetricsAgents/MultiMetricsAgent.cs b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/MultiMetricsAgent.cs
new file mode 100644
index 0000000000..eb625c02a0
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/MultiMetricsAgent.cs
@@ -0,0 +1,74 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal class MultiMetricsAgent : IMetricsAgent // Metrics agent for multiclass classification, bound to one optimizing metric.
+ {
+ private readonly MLContext _mlContext;
+ private readonly MulticlassClassificationMetric _optimizingMetric;
+
+ public MultiMetricsAgent(MLContext mlContext,
+ MulticlassClassificationMetric optimizingMetric)
+ {
+ _mlContext = mlContext;
+ _optimizingMetric = optimizingMetric;
+ }
+
+ public double GetScore(MulticlassClassificationMetrics metrics) // Returns the optimizing metric's value, or NaN when metrics is null.
+ {
+ if (metrics == null)
+ {
+ return double.NaN;
+ }
+
+ switch (_optimizingMetric)
+ {
+ case MulticlassClassificationMetric.MacroAccuracy:
+ return metrics.MacroAccuracy;
+ case MulticlassClassificationMetric.MicroAccuracy:
+ return metrics.MicroAccuracy;
+ case MulticlassClassificationMetric.LogLoss:
+ return metrics.LogLoss;
+ case MulticlassClassificationMetric.LogLossReduction:
+ return metrics.LogLossReduction;
+ case MulticlassClassificationMetric.TopKAccuracy:
+ return metrics.TopKAccuracy;
+ default:
+ throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric); // NotSupportedException for any metric not listed above.
+ }
+ }
+
+ public bool IsModelPerfect(double score) // Exact comparison against the metric's ideal value; NaN is never perfect.
+ {
+ if (double.IsNaN(score))
+ {
+ return false;
+ }
+
+ switch (_optimizingMetric)
+ {
+ case MulticlassClassificationMetric.MacroAccuracy:
+ return score == 1;
+ case MulticlassClassificationMetric.MicroAccuracy:
+ return score == 1;
+ case MulticlassClassificationMetric.LogLoss:
+ return score == 0; // The only metric here whose ideal value is 0 rather than 1.
+ case MulticlassClassificationMetric.LogLossReduction:
+ return score == 1;
+ case MulticlassClassificationMetric.TopKAccuracy:
+ return score == 1;
+ default:
+ throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric); // NotSupportedException for any metric not listed above.
+ }
+ }
+
+ public MulticlassClassificationMetrics EvaluateMetrics(IDataView data, string labelColumn) // Evaluates data with the multiclass evaluator.
+ {
+ return _mlContext.MulticlassClassification.Evaluate(data, labelColumn);
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/Experiment/MetricsAgents/RegressionMetricsAgent.cs b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/RegressionMetricsAgent.cs
new file mode 100644
index 0000000000..9350acd643
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/MetricsAgents/RegressionMetricsAgent.cs
@@ -0,0 +1,69 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal class RegressionMetricsAgent : IMetricsAgent // Metrics agent for regression, bound to one optimizing metric.
+ {
+ private readonly MLContext _mlContext;
+ private readonly RegressionMetric _optimizingMetric;
+
+ public RegressionMetricsAgent(MLContext mlContext, RegressionMetric optimizingMetric)
+ {
+ _mlContext = mlContext;
+ _optimizingMetric = optimizingMetric;
+ }
+
+ public double GetScore(RegressionMetrics metrics) // Returns the optimizing metric's value, or NaN when metrics is null.
+ {
+ if (metrics == null)
+ {
+ return double.NaN;
+ }
+
+ switch (_optimizingMetric)
+ {
+ case RegressionMetric.MeanAbsoluteError:
+ return metrics.MeanAbsoluteError;
+ case RegressionMetric.MeanSquaredError:
+ return metrics.MeanSquaredError;
+ case RegressionMetric.RootMeanSquaredError:
+ return metrics.RootMeanSquaredError;
+ case RegressionMetric.RSquared:
+ return metrics.RSquared;
+ default:
+ throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric); // NotSupportedException for any metric not listed above.
+ }
+ }
+
+ public bool IsModelPerfect(double score) // Exact comparison: 0 for the error metrics, 1 for RSquared; NaN is never perfect.
+ {
+ if (double.IsNaN(score))
+ {
+ return false;
+ }
+
+ switch (_optimizingMetric)
+ {
+ case RegressionMetric.MeanAbsoluteError:
+ return score == 0;
+ case RegressionMetric.MeanSquaredError:
+ return score == 0;
+ case RegressionMetric.RootMeanSquaredError:
+ return score == 0;
+ case RegressionMetric.RSquared:
+ return score == 1;
+ default:
+ throw MetricsAgentUtil.BuildMetricNotSupportedException(_optimizingMetric); // NotSupportedException for any metric not listed above.
+ }
+ }
+
+ public RegressionMetrics EvaluateMetrics(IDataView data, string labelColumn) // Evaluates data with the regression evaluator.
+ {
+ return _mlContext.Regression.Evaluate(data, labelColumn);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/ModelContainer.cs b/src/Microsoft.ML.Auto/Experiment/ModelContainer.cs
new file mode 100644
index 0000000000..775471fdfe
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/ModelContainer.cs
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+
+namespace Microsoft.ML.Auto
+{
+ internal class ModelContainer // Holds a trained model either in memory or as a file on disk.
+ {
+ private readonly MLContext _mlContext;
+ private readonly FileInfo _fileInfo; // Assigned only by the disk-backed constructor.
+ private readonly ITransformer _model; // Assigned only by the in-memory constructor.
+
+ internal ModelContainer(MLContext mlContext, ITransformer model) // In-memory variant: keeps a reference to the model.
+ {
+ _mlContext = mlContext;
+ _model = model;
+ }
+
+ internal ModelContainer(MLContext mlContext, FileInfo fileInfo, ITransformer model, DataViewSchema modelInputSchema) // Disk-backed variant: saves the model to fileInfo and does not retain it.
+ {
+ _mlContext = mlContext;
+ _fileInfo = fileInfo;
+
+ // Write model to disk
+ using (var fs = File.Create(fileInfo.FullName)) // File.Create overwrites an existing file at this path.
+ {
+ _mlContext.Model.Save(model, modelInputSchema, fs);
+ }
+ }
+
+ public ITransformer GetModel() // Returns the in-memory model when present; otherwise loads it from the saved file on each call.
+ {
+ // If model stored in memory, return it
+ if (_model != null)
+ {
+ return _model;
+ }
+
+ // Load model from disk
+ ITransformer model;
+ using (var stream = new FileStream(_fileInfo.FullName, FileMode.Open, FileAccess.Read, FileShare.Read))
+ {
+ model = _mlContext.Model.Load(stream, out var modelInputSchema); // The loaded input schema is not used.
+ }
+ return model;
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/OptimizingMetricInfo.cs b/src/Microsoft.ML.Auto/Experiment/OptimizingMetricInfo.cs
new file mode 100644
index 0000000000..54f23fe1be
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/OptimizingMetricInfo.cs
@@ -0,0 +1,44 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Linq;
+
+namespace Microsoft.ML.Auto
+{
+ internal sealed class OptimizingMetricInfo // Records whether the chosen optimizing metric is to be maximized or minimized.
+ {
+ public bool IsMaximizing { get; } // True unless the metric appears in the matching minimizing table below.
+
+ private static readonly RegressionMetric[] _minimizingRegressionMetrics = new RegressionMetric[]
+ {
+ RegressionMetric.MeanAbsoluteError,
+ RegressionMetric.MeanSquaredError,
+ RegressionMetric.RootMeanSquaredError
+ };
+
+ private static readonly BinaryClassificationMetric[] _minimizingBinaryMetrics = new BinaryClassificationMetric[] // Empty: every binary metric is treated as maximizing.
+ {
+ };
+
+ private static readonly MulticlassClassificationMetric[] _minimizingMulticlassMetrics = new MulticlassClassificationMetric[]
+ {
+ MulticlassClassificationMetric.LogLoss,
+ };
+
+ public OptimizingMetricInfo(RegressionMetric regressionMetric) // Regression: the error metrics are minimized, anything else is maximized.
+ {
+ IsMaximizing = !_minimizingRegressionMetrics.Contains(regressionMetric);
+ }
+
+ public OptimizingMetricInfo(BinaryClassificationMetric binaryMetric) // Binary classification: always maximizing with the current (empty) table.
+ {
+ IsMaximizing = !_minimizingBinaryMetrics.Contains(binaryMetric);
+ }
+
+ public OptimizingMetricInfo(MulticlassClassificationMetric multiMetric) // Multiclass: LogLoss is minimized, anything else is maximized.
+ {
+ IsMaximizing = !_minimizingMulticlassMetrics.Contains(multiMetric);
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/RecipeInference.cs b/src/Microsoft.ML.Auto/Experiment/RecipeInference.cs
new file mode 100644
index 0000000000..c680a4e96b
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/RecipeInference.cs
@@ -0,0 +1,29 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+
+namespace Microsoft.ML.Auto
+{
+ internal static class RecipeInference // Builds the list of candidate trainers for a task.
+ {
+ /// <summary>
+ /// Given a task kind, return a set of all permissible trainers (with their sweeper params, if defined).
+ /// </summary>
+ /// <returns>Array of viable learners.</returns>
+ public static IEnumerable AllowedTrainers(MLContext mlContext, TaskKind task,
+ ColumnInformation columnInfo, IEnumerable trainerWhitelist)
+ {
+ var trainerExtensions = TrainerExtensionCatalog.GetTrainers(task, trainerWhitelist); // The whitelist is passed straight through to the catalog.
+
+ var trainers = new List();
+ foreach (var trainerExtension in trainerExtensions)
+ {
+ var learner = new SuggestedTrainer(mlContext, trainerExtension, columnInfo); // One SuggestedTrainer per extension returned by the catalog.
+ trainers.Add(learner);
+ }
+ return trainers.ToArray();
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/CrossValRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/CrossValRunner.cs
new file mode 100644
index 0000000000..f211b3107f
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/Runners/CrossValRunner.cs
@@ -0,0 +1,74 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+namespace Microsoft.ML.Auto
+{
+ internal class CrossValRunner : IRunner>
+ where TMetrics : class // Runner that trains and scores a pipeline once per train/validation dataset pair and averages the scores.
+ {
+ private readonly MLContext _context;
+ private readonly IDataView[] _trainDatasets;
+ private readonly IDataView[] _validDatasets; // Indexed in parallel with _trainDatasets.
+ private readonly IMetricsAgent _metricsAgent;
+ private readonly IEstimator _preFeaturizer;
+ private readonly ITransformer[] _preprocessorTransforms; // May be null; when set, indexed in parallel with _trainDatasets.
+ private readonly string _labelColumn;
+ private readonly IDebugLogger _logger;
+ private readonly DataViewSchema _modelInputSchema;
+
+ public CrossValRunner(MLContext context,
+ IDataView[] trainDatasets,
+ IDataView[] validDatasets,
+ IMetricsAgent metricsAgent,
+ IEstimator preFeaturizer,
+ ITransformer[] preprocessorTransforms,
+ string labelColumn,
+ IDebugLogger logger)
+ {
+ _context = context;
+ _trainDatasets = trainDatasets;
+ _validDatasets = validDatasets;
+ _metricsAgent = metricsAgent;
+ _preFeaturizer = preFeaturizer;
+ _preprocessorTransforms = preprocessorTransforms;
+ _labelColumn = labelColumn;
+ _logger = logger;
+ _modelInputSchema = trainDatasets[0].Schema; // Schema comes from the first training dataset, so at least one is required.
+ }
+
+ public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, CrossValidationRunDetail runDetail)
+ Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum) // Runs the pipeline on every dataset pair and returns the aggregated run details.
+ {
+ var trainResults = new List>();
+
+ for (var i = 0; i < _trainDatasets.Length; i++)
+ {
+ var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1); // The dataset index is passed 1-based.
+ var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
+ _labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema, _logger); // Null-conditional index: passes null when there are no preprocessor transforms.
+ trainResults.Add(new SuggestedPipelineTrainResult(trainResult.model, trainResult.metrics, trainResult.exception, trainResult.score));
+ }
+
+ var avgScore = CalcAverageScore(trainResults.Select(r => r.Score));
+ var allRunsSucceeded = trainResults.All(r => r.Exception == null); // Success means no result recorded an exception.
+
+ var suggestedPipelineRunDetail = new SuggestedPipelineCrossValRunDetail(pipeline, avgScore, allRunsSucceeded, trainResults);
+ var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer);
+ return (suggestedPipelineRunDetail, runDetail);
+ }
+
+ private static double CalcAverageScore(IEnumerable scores) // Mean of scores, or NaN as soon as any single score is NaN.
+ {
+ if (scores.Any(s => double.IsNaN(s)))
+ {
+ return double.NaN;
+ }
+ return scores.Average(); // Enumerable.Average throws on an empty sequence; not guarded here.
+ }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/CrossValSummaryRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/CrossValSummaryRunner.cs
new file mode 100644
index 0000000000..701baa1663
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/Runners/CrossValSummaryRunner.cs
@@ -0,0 +1,101 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Cross-validation runner that condenses all folds into one run detail: the model from the
+ /// best-scoring fold, the metrics from the fold whose score is closest to the average, and
+ /// the average score itself.
+ /// </summary>
+ /// <typeparam name="TMetrics">Metrics type produced by the metrics agent.</typeparam>
+ internal class CrossValSummaryRunner<TMetrics> : IRunner<RunDetail<TMetrics>>
+     where TMetrics : class
+ {
+     private readonly MLContext _context;
+     private readonly IDataView[] _trainDatasets;
+     private readonly IDataView[] _validDatasets;
+     private readonly IMetricsAgent<TMetrics> _metricsAgent;
+     private readonly IEstimator<ITransformer> _preFeaturizer;
+     private readonly ITransformer[] _preprocessorTransforms;
+     private readonly string _labelColumn;
+     private readonly OptimizingMetricInfo _optimizingMetricInfo;
+     private readonly IDebugLogger _logger;
+     private readonly DataViewSchema _modelInputSchema;
+
+     /// <summary>
+     /// Stores the fold datasets and collaborators; the schema of the first training dataset is
+     /// kept as the model input schema.
+     /// </summary>
+     public CrossValSummaryRunner(MLContext context,
+         IDataView[] trainDatasets,
+         IDataView[] validDatasets,
+         IMetricsAgent<TMetrics> metricsAgent,
+         IEstimator<ITransformer> preFeaturizer,
+         ITransformer[] preprocessorTransforms,
+         string labelColumn,
+         OptimizingMetricInfo optimizingMetricInfo,
+         IDebugLogger logger)
+     {
+         _context = context;
+         _trainDatasets = trainDatasets;
+         _validDatasets = validDatasets;
+         _metricsAgent = metricsAgent;
+         _preFeaturizer = preFeaturizer;
+         _preprocessorTransforms = preprocessorTransforms;
+         _labelColumn = labelColumn;
+         _optimizingMetricInfo = optimizingMetricInfo;
+         _logger = logger;
+         _modelInputSchema = trainDatasets[0].Schema;
+     }
+
+     /// <summary>
+     /// Trains <paramref name="pipeline"/> on every fold and summarizes the folds into one result.
+     /// If any fold recorded an exception, a failed run detail carrying the first such exception
+     /// (score NaN, no metrics, no model) is returned instead.
+     /// </summary>
+     /// <param name="pipeline">Pipeline to train on each fold.</param>
+     /// <param name="modelDirectory">Forwarded to <c>RunnerUtil.GetModelFileInfo</c> to pick each fold's model file.</param>
+     /// <param name="iterationNum">Iteration number, also forwarded to <c>RunnerUtil.GetModelFileInfo</c>.</param>
+     /// <returns>The suggested-pipeline run detail and its <c>ToIterationResult</c> conversion.</returns>
+     public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, RunDetail<TMetrics> runDetail)
+         Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum)
+     {
+         var trainResults = new List<(ModelContainer model, TMetrics metrics, Exception exception, double score)>();
+
+         for (var i = 0; i < _trainDatasets.Length; i++)
+         {
+             // Fold numbers handed to GetModelFileInfo are 1-based.
+             var modelFileInfo = RunnerUtil.GetModelFileInfo(modelDirectory, iterationNum, i + 1);
+             var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainDatasets[i], _validDatasets[i],
+                 _labelColumn, _metricsAgent, _preprocessorTransforms?[i], modelFileInfo, _modelInputSchema,
+                 _logger);
+             trainResults.Add(trainResult);
+         }
+
+         var allRunsSucceeded = trainResults.All(r => r.exception == null);
+         if (!allRunsSucceeded)
+         {
+             var firstException = trainResults.First(r => r.exception != null).exception;
+             var errorRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, double.NaN, false, null, null, firstException);
+             return (errorRunDetail, errorRunDetail.ToIterationResult(_preFeaturizer));
+         }
+
+         // Materialize the fold scores once; they are consumed three times below.
+         var scores = trainResults.Select(r => r.score).ToList();
+
+         // Get the model from the best fold
+         var bestFoldIndex = BestResultUtil.GetIndexOfBestScore(scores, _optimizingMetricInfo.IsMaximizing);
+         var bestModel = trainResults[bestFoldIndex].model;
+
+         // Get the metrics from the fold whose score is closest to avg of all fold scores
+         var avgScore = scores.Average();
+         var indexClosestToAvg = GetIndexClosestToAverage(scores, avgScore);
+         var metricsClosestToAvg = trainResults[indexClosestToAvg].metrics;
+
+         // Build result objects
+         var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline, avgScore, allRunsSucceeded, metricsClosestToAvg, bestModel, null);
+         var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer);
+         return (suggestedPipelineRunDetail, runDetail);
+     }
+
+     /// <summary>
+     /// Finds the position of the value with the smallest absolute distance from
+     /// <paramref name="average"/>; the earliest position wins ties.
+     /// </summary>
+     /// <param name="values">Values to search.</param>
+     /// <param name="average">Reference value to measure distance from.</param>
+     /// <returns>
+     /// Zero-based index of the closest value, or -1 when no distance compares below positive
+     /// infinity (for example when <paramref name="values"/> is empty).
+     /// </returns>
+     private static int GetIndexClosestToAverage(IEnumerable<double> values, double average)
+     {
+         var avgFoldIndex = -1;
+         var smallestDistFromAvg = double.PositiveInfinity;
+         var index = 0;
+         // Single pass: avoids calling Count()/ElementAt() on the sequence for every element.
+         foreach (var value in values)
+         {
+             var distFromAvg = Math.Abs(value - average);
+             if (distFromAvg < smallestDistFromAvg)
+             {
+                 smallestDistFromAvg = distFromAvg;
+                 avgFoldIndex = index;
+             }
+             index++;
+         }
+         return avgFoldIndex;
+     }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/IRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/IRunner.cs
new file mode 100644
index 0000000000..8bb56fc9d0
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/Runners/IRunner.cs
@@ -0,0 +1,14 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Runs a suggested pipeline for one iteration and reports the outcome.
+ /// </summary>
+ /// <typeparam name="TRunDetail">Run detail type returned alongside the internal run detail.</typeparam>
+ internal interface IRunner<TRunDetail> where TRunDetail : RunDetail
+ {
+     /// <summary>
+     /// Runs <paramref name="pipeline"/> and returns both the internal run detail and the
+     /// <typeparamref name="TRunDetail"/> built from it.
+     /// </summary>
+     (SuggestedPipelineRunDetail suggestedPipelineRunDetail, TRunDetail runDetail)
+         Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum);
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/RunnerUtil.cs b/src/Microsoft.ML.Auto/Experiment/Runners/RunnerUtil.cs
new file mode 100644
index 0000000000..88575a2280
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/Runners/RunnerUtil.cs
@@ -0,0 +1,59 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.IO;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Helpers shared by the runners: training/scoring a pipeline and naming model files.
+ /// </summary>
+ internal static class RunnerUtil
+ {
+     /// <summary>
+     /// Fits <paramref name="pipeline"/> on <paramref name="trainData"/>, scores
+     /// <paramref name="validData"/> with the fitted model and evaluates metrics on the result.
+     /// Any exception is logged and returned in the tuple rather than thrown.
+     /// </summary>
+     /// <typeparam name="TMetrics">Metrics type produced by <paramref name="metricsAgent"/>.</typeparam>
+     /// <param name="context">Context handed to the model container.</param>
+     /// <param name="pipeline">Pipeline to convert to an estimator and fit.</param>
+     /// <param name="trainData">Data the estimator is fitted on.</param>
+     /// <param name="validData">Data the fitted model is scored on.</param>
+     /// <param name="labelColumn">Label column name passed to the metrics agent.</param>
+     /// <param name="metricsAgent">Evaluates metrics and derives the score from them.</param>
+     /// <param name="preprocessorTransform">Optional transform prepended to the model after evaluation; may be null.</param>
+     /// <param name="modelFileInfo">Optional model file; when null the container is built from the model alone.</param>
+     /// <param name="modelInputSchema">Schema passed to the container when a model file is given.</param>
+     /// <param name="logger">Optional logger for failures; may be null.</param>
+     /// <returns>
+     /// On success <c>(container, metrics, null, score)</c>; on failure
+     /// <c>(null, null, exception, double.NaN)</c>.
+     /// </returns>
+     public static (ModelContainer model, TMetrics metrics, Exception exception, double score)
+         TrainAndScorePipeline<TMetrics>(MLContext context,
+         SuggestedPipeline pipeline,
+         IDataView trainData,
+         IDataView validData,
+         string labelColumn,
+         IMetricsAgent<TMetrics> metricsAgent,
+         ITransformer preprocessorTransform,
+         FileInfo modelFileInfo,
+         DataViewSchema modelInputSchema,
+         IDebugLogger logger) where TMetrics : class
+     {
+         try
+         {
+             var estimator = pipeline.ToEstimator();
+             var model = estimator.Fit(trainData);
+
+             var scoredData = model.Transform(validData);
+             var metrics = metricsAgent.EvaluateMetrics(scoredData, labelColumn);
+             var score = metricsAgent.GetScore(metrics);
+
+             // The preprocessor is prepended only after scoring, so evaluation above used the bare model.
+             if (preprocessorTransform != null)
+             {
+                 model = preprocessorTransform.Append(model);
+             }
+
+             // Build container for model
+             var modelContainer = modelFileInfo == null ?
+                 new ModelContainer(context, model) :
+                 new ModelContainer(context, modelFileInfo, model, modelInputSchema);
+
+             return (modelContainer, metrics, null, score);
+         }
+         catch (Exception ex)
+         {
+             // Deliberate catch-all: the failure is reported through the returned tuple.
+             logger?.Log(LogSeverity.Error, $"Pipeline crashed: {pipeline.ToString()} . Exception: {ex}");
+             return (null, null, ex, double.NaN);
+         }
+     }
+
+     /// <summary>
+     /// Builds the per-fold model file location <c>Model{iterationNum}_{foldNum}.zip</c> inside
+     /// <paramref name="modelDirectory"/>, or returns null when no directory is given.
+     /// </summary>
+     public static FileInfo GetModelFileInfo(DirectoryInfo modelDirectory, int iterationNum, int foldNum)
+     {
+         return modelDirectory == null ?
+             null :
+             new FileInfo(Path.Combine(modelDirectory.FullName, $"Model{iterationNum}_{foldNum}.zip"));
+     }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/Experiment/Runners/TrainValidateRunner.cs b/src/Microsoft.ML.Auto/Experiment/Runners/TrainValidateRunner.cs
new file mode 100644
index 0000000000..9226dcbeb0
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/Runners/TrainValidateRunner.cs
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.IO;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Runner that trains a pipeline on one training dataset and scores it on one validation dataset.
+ /// </summary>
+ /// <typeparam name="TMetrics">Metrics type produced by the metrics agent.</typeparam>
+ internal class TrainValidateRunner<TMetrics> : IRunner<RunDetail<TMetrics>>
+     where TMetrics : class
+ {
+     private readonly MLContext _context;
+     private readonly IDataView _trainData;
+     private readonly IDataView _validData;
+     private readonly string _labelColumn;
+     private readonly IMetricsAgent<TMetrics> _metricsAgent;
+     private readonly IEstimator<ITransformer> _preFeaturizer;
+     private readonly ITransformer _preprocessorTransform;
+     private readonly IDebugLogger _logger;
+     private readonly DataViewSchema _modelInputSchema;
+
+     /// <summary>
+     /// Stores the datasets and collaborators; the training data's schema is kept as the model
+     /// input schema.
+     /// </summary>
+     public TrainValidateRunner(MLContext context,
+         IDataView trainData,
+         IDataView validData,
+         string labelColumn,
+         IMetricsAgent<TMetrics> metricsAgent,
+         IEstimator<ITransformer> preFeaturizer,
+         ITransformer preprocessorTransform,
+         IDebugLogger logger)
+     {
+         _context = context;
+         _trainData = trainData;
+         _validData = validData;
+         _labelColumn = labelColumn;
+         _metricsAgent = metricsAgent;
+         _preFeaturizer = preFeaturizer;
+         _preprocessorTransform = preprocessorTransform;
+         _logger = logger;
+         _modelInputSchema = trainData.Schema;
+     }
+
+     /// <summary>
+     /// Trains and scores <paramref name="pipeline"/> once; the run is marked successful when
+     /// the training result carries no exception.
+     /// </summary>
+     /// <param name="pipeline">Pipeline to train.</param>
+     /// <param name="modelDirectory">Directory for the model file; null means no model file.</param>
+     /// <param name="iterationNum">Iteration number used in the model file name.</param>
+     /// <returns>The suggested-pipeline run detail and its <c>ToIterationResult</c> conversion.</returns>
+     public (SuggestedPipelineRunDetail suggestedPipelineRunDetail, RunDetail<TMetrics> runDetail)
+         Run(SuggestedPipeline pipeline, DirectoryInfo modelDirectory, int iterationNum)
+     {
+         var modelFileInfo = GetModelFileInfo(modelDirectory, iterationNum);
+         var trainResult = RunnerUtil.TrainAndScorePipeline(_context, pipeline, _trainData, _validData,
+             _labelColumn, _metricsAgent, _preprocessorTransform, modelFileInfo, _modelInputSchema, _logger);
+         var suggestedPipelineRunDetail = new SuggestedPipelineRunDetail<TMetrics>(pipeline,
+             trainResult.score,
+             trainResult.exception == null,
+             trainResult.metrics,
+             trainResult.model,
+             trainResult.exception);
+         var runDetail = suggestedPipelineRunDetail.ToIterationResult(_preFeaturizer);
+         return (suggestedPipelineRunDetail, runDetail);
+     }
+
+     /// <summary>
+     /// Builds the model file location <c>Model{iterationNum}.zip</c> inside
+     /// <paramref name="modelDirectory"/>, or returns null when no directory is given.
+     /// </summary>
+     private static FileInfo GetModelFileInfo(DirectoryInfo modelDirectory, int iterationNum)
+     {
+         return modelDirectory == null ?
+             null :
+             new FileInfo(Path.Combine(modelDirectory.FullName, $"Model{iterationNum}.zip"));
+     }
+ }
+}
\ No newline at end of file
diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs
new file mode 100644
index 0000000000..6019c9c9ac
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs
@@ -0,0 +1,144 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// A runnable pipeline. Contains a learner and set of transforms,
+ /// along with a RunSummary if it has already been executed.
+ /// </summary>
+ internal class SuggestedPipeline
+ {
+     public readonly IList<SuggestedTransform> Transforms;
+     public readonly SuggestedTrainer Trainer;
+     public readonly IList<SuggestedTransform> TransformsPostTrainer;
+
+     private readonly MLContext _context;
+     private readonly bool _cacheBeforeTrainer;
+
+     /// <summary>
+     /// Creates a pipeline holding clones of the given transforms and trainer, so later changes
+     /// to the caller's instances do not affect this pipeline.
+     /// </summary>
+     public SuggestedPipeline(IEnumerable<SuggestedTransform> transforms,
+         IEnumerable<SuggestedTransform> transformsPostTrainer,
+         SuggestedTrainer trainer,
+         MLContext context,
+         bool cacheBeforeTrainer)
+     {
+         Transforms = transforms.Select(t => t.Clone()).ToList();
+         TransformsPostTrainer = transformsPostTrainer.Select(t => t.Clone()).ToList();
+         Trainer = trainer.Clone();
+         _context = context;
+         _cacheBeforeTrainer = cacheBeforeTrainer;
+     }
+
+     /// <summary>
+     /// Text form listing the transforms, trainer, post-trainer transforms and cache flag.
+     /// <see cref="Equals(object)"/> and <see cref="GetHashCode"/> are both defined on this string.
+     /// </summary>
+     public override string ToString() => $"{string.Join(" ", Transforms.Select(t => $"xf={t}"))} tr={this.Trainer} {string.Join(" ", TransformsPostTrainer.Select(t => $"xf={t}"))} cache={(_cacheBeforeTrainer ? "+" : "-")}";
+
+     /// <summary>
+     /// Two pipelines are equal when their <see cref="ToString"/> forms match.
+     /// </summary>
+     public override bool Equals(object obj)
+     {
+         return obj is SuggestedPipeline other && other.ToString() == ToString();
+     }
+
+     /// <summary>
+     /// Hash of the <see cref="ToString"/> form, consistent with <see cref="Equals(object)"/>.
+     /// </summary>
+     public override int GetHashCode()
+     {
+         return ToString().GetHashCode();
+     }
+
+     /// <summary>
+     /// Converts to a <c>Pipeline</c> whose nodes are the transforms, then the trainer, then the
+     /// post-trainer transforms.
+     /// </summary>
+     public Pipeline ToPipeline()
+     {
+         var pipelineElements = new List<PipelineNode>();
+         foreach (var transform in Transforms)
+         {
+             pipelineElements.Add(transform.PipelineNode);
+         }
+         pipelineElements.Add(Trainer.ToPipelineNode());
+         foreach (var transform in TransformsPostTrainer)
+         {
+             pipelineElements.Add(transform.PipelineNode);
+         }
+         return new Pipeline(pipelineElements.ToArray(), _cacheBeforeTrainer);
+     }
+
+     /// <summary>
+     /// Rebuilds a <see cref="SuggestedPipeline"/> from <paramref name="pipeline"/>. Transform
+     /// nodes seen before the trainer node become pre-trainer transforms; those seen after it
+     /// become post-trainer transforms.
+     /// </summary>
+     public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipeline)
+     {
+         var transforms = new List<SuggestedTransform>();
+         var transformsPostTrainer = new List<SuggestedTransform>();
+         SuggestedTrainer trainer = null;
+
+         var trainerEncountered = false;
+         foreach (var pipelineNode in pipeline.Nodes)
+         {
+             if (pipelineNode.NodeType == PipelineNodeType.Trainer)
+             {
+                 var trainerName = (TrainerName)Enum.Parse(typeof(TrainerName), pipelineNode.Name);
+                 var trainerExtension = TrainerExtensionCatalog.GetTrainerExtension(trainerName);
+                 var hyperParamSet = TrainerExtensionUtil.BuildParameterSet(trainerName, pipelineNode.Properties);
+                 var columnInfo = TrainerExtensionUtil.BuildColumnInfo(pipelineNode.Properties);
+                 trainer = new SuggestedTrainer(context, trainerExtension, columnInfo, hyperParamSet);
+                 trainerEncountered = true;
+             }
+             else if (pipelineNode.NodeType == PipelineNodeType.Transform)
+             {
+                 var estimatorName = (EstimatorName)Enum.Parse(typeof(EstimatorName), pipelineNode.Name);
+                 var estimatorExtension = EstimatorExtensionCatalog.GetExtension(estimatorName);
+                 var estimator = estimatorExtension.CreateInstance(context, pipelineNode);
+                 var transform = new SuggestedTransform(pipelineNode, estimator);
+                 if (!trainerEncountered)
+                 {
+                     transforms.Add(transform);
+                 }
+                 else
+                 {
+                     transformsPostTrainer.Add(transform);
+                 }
+             }
+         }
+
+         return new SuggestedPipeline(transforms, transformsPostTrainer, trainer, context, pipeline.CacheBeforeTrainer);
+     }
+
+     /// <summary>
+     /// Builds the estimator chain: pre-trainer transforms, an optional cache checkpoint, the
+     /// trainer, then post-trainer transforms. Transforms with a null estimator are skipped.
+     /// </summary>
+     public IEstimator<ITransformer> ToEstimator()
+     {
+         IEstimator<ITransformer> pipeline = new EstimatorChain<ITransformer>();
+
+         // Append each transformer to the pipeline
+         foreach (var transform in Transforms)
+         {
+             if (transform.Estimator != null)
+             {
+                 pipeline = pipeline.Append(transform.Estimator);
+             }
+         }
+
+         // Get learner
+         var learner = Trainer.BuildTrainer();
+
+         if (_cacheBeforeTrainer)
+         {
+             pipeline = pipeline.AppendCacheCheckpoint(_context);
+         }
+
+         // Append learner to pipeline
+         pipeline = pipeline.Append(learner);
+
+         // Append each post-trainer transformer to the pipeline
+         foreach (var transform in TransformsPostTrainer)
+         {
+             if (transform.Estimator != null)
+             {
+                 pipeline = pipeline.Append(transform.Estimator);
+             }
+         }
+
+         return pipeline;
+     }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineBuilder.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineBuilder.cs
new file mode 100644
index 0000000000..a3fad88e0b
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineBuilder.cs
@@ -0,0 +1,43 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Assembles a <see cref="SuggestedPipeline"/> from inferred transforms and a trainer, adding
+ /// normalization and caching according to the trainer's reported info.
+ /// </summary>
+ internal static class SuggestedPipelineBuilder
+ {
+     /// <summary>
+     /// Builds a pipeline for <paramref name="trainer"/>. A normalizing transform is appended
+     /// when the trainer needs normalization. The caller's <paramref name="transforms"/>
+     /// collection is not modified, so it can safely be reused across calls.
+     /// </summary>
+     /// <param name="context">Context used to create transforms and the pipeline.</param>
+     /// <param name="transforms">Transforms placed before the trainer.</param>
+     /// <param name="transformsPostTrainer">Transforms placed after the trainer.</param>
+     /// <param name="trainer">Trainer for the pipeline.</param>
+     /// <param name="enableCaching">
+     /// True forces caching before the trainer, false disables it, and null defers to the
+     /// trainer's <c>WantCaching</c> flag.
+     /// </param>
+     public static SuggestedPipeline Build(MLContext context,
+         ICollection<SuggestedTransform> transforms,
+         ICollection<SuggestedTransform> transformsPostTrainer,
+         SuggestedTrainer trainer,
+         bool? enableCaching)
+     {
+         var trainerInfo = trainer.BuildTrainer().Info;
+         // Work on a copy: appending the normalizer to the caller's collection would stack
+         // another normalizer onto it on every subsequent Build call with the same collection.
+         var pipelineTransforms = new List<SuggestedTransform>(transforms);
+         AddNormalizationTransforms(context, trainerInfo, pipelineTransforms);
+         var cacheBeforeTrainer = ShouldCacheBeforeTrainer(trainerInfo, enableCaching);
+         return new SuggestedPipeline(pipelineTransforms, transformsPostTrainer, trainer, context, cacheBeforeTrainer);
+     }
+
+     /// <summary>
+     /// Appends a normalizing transform over the default features column to
+     /// <paramref name="transforms"/> when <paramref name="trainerInfo"/> reports it is needed.
+     /// </summary>
+     private static void AddNormalizationTransforms(MLContext context,
+         TrainerInfo trainerInfo,
+         ICollection<SuggestedTransform> transforms)
+     {
+         // Only add normalization if trainer needs it
+         if (!trainerInfo.NeedNormalization)
+         {
+             return;
+         }
+
+         var transform = NormalizingExtension.CreateSuggestedTransform(context, DefaultColumnNames.Features, DefaultColumnNames.Features);
+         transforms.Add(transform);
+     }
+
+     /// <summary>
+     /// An explicit <paramref name="enableCaching"/> value wins; when it is null the trainer's
+     /// <c>WantCaching</c> flag decides.
+     /// </summary>
+     private static bool ShouldCacheBeforeTrainer(TrainerInfo trainerInfo, bool? enableCaching)
+     {
+         return enableCaching == true || (enableCaching == null && trainerInfo.WantCaching);
+     }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineCrossValRunDetail.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineCrossValRunDetail.cs
new file mode 100644
index 0000000000..e21b55428a
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineCrossValRunDetail.cs
@@ -0,0 +1,54 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Result of training a suggested pipeline once: the model container, validation metrics,
+ /// score, and the exception recorded for the run (if any).
+ /// </summary>
+ /// <typeparam name="TMetrics">Type of the validation metrics.</typeparam>
+ internal sealed class SuggestedPipelineTrainResult<TMetrics>
+ {
+     public readonly TMetrics ValidationMetrics;
+     public readonly ModelContainer ModelContainer;
+     public readonly Exception Exception;
+     public readonly double Score;
+
+     internal SuggestedPipelineTrainResult(ModelContainer modelContainer,
+         TMetrics metrics,
+         Exception exception,
+         double score)
+     {
+         ModelContainer = modelContainer;
+         ValidationMetrics = metrics;
+         Exception = exception;
+         Score = score;
+     }
+
+     /// <summary>
+     /// Converts to a <c>TrainResult</c> carrying the model container, metrics and exception;
+     /// the score is not passed along.
+     /// </summary>
+     public TrainResult<TMetrics> ToTrainResult()
+     {
+         return new TrainResult<TMetrics>(ModelContainer, ValidationMetrics, Exception);
+     }
+ }
+
+ /// <summary>
+ /// Run detail for a cross-validated pipeline: the base pipeline/score/success data plus the
+ /// individual training results.
+ /// </summary>
+ /// <typeparam name="TMetrics">Type of the validation metrics.</typeparam>
+ internal sealed class SuggestedPipelineCrossValRunDetail<TMetrics> : SuggestedPipelineRunDetail
+ {
+     public readonly IEnumerable<SuggestedPipelineTrainResult<TMetrics>> Results;
+
+     internal SuggestedPipelineCrossValRunDetail(SuggestedPipeline pipeline,
+         double score,
+         bool runSucceeded,
+         IEnumerable<SuggestedPipelineTrainResult<TMetrics>> results) : base(pipeline, score, runSucceeded)
+     {
+         Results = results;
+     }
+
+     /// <summary>
+     /// Builds the <c>CrossValidationRunDetail</c> for this run: trainer name, the pipeline's
+     /// estimator with <paramref name="preFeaturizer"/> prepended when it is non-null, the
+     /// pipeline description, and each result converted via <c>ToTrainResult</c>.
+     /// </summary>
+     public CrossValidationRunDetail<TMetrics> ToIterationResult(IEstimator<ITransformer> preFeaturizer)
+     {
+         var estimator = SuggestedPipelineRunDetailUtil.PrependPreFeaturizer(Pipeline.ToEstimator(), preFeaturizer);
+         return new CrossValidationRunDetail<TMetrics>(Pipeline.Trainer.TrainerName.ToString(), estimator,
+             Pipeline.ToPipeline(), Results.Select(r => r.ToTrainResult()));
+     }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetail.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetail.cs
new file mode 100644
index 0000000000..7cb76e1ab3
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetail.cs
@@ -0,0 +1,58 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Records a suggested pipeline together with the score it achieved and whether its run succeeded.
+ /// </summary>
+ internal class SuggestedPipelineRunDetail
+ {
+     public readonly SuggestedPipeline Pipeline;
+     public readonly bool RunSucceded;
+     public readonly double Score;
+
+     public SuggestedPipelineRunDetail(SuggestedPipeline pipeline, double score, bool runSucceeded)
+     {
+         Pipeline = pipeline;
+         Score = score;
+         RunSucceded = runSucceeded;
+     }
+
+     /// <summary>
+     /// Rebuilds a run detail from a <c>PipelineScore</c>, converting its pipeline with
+     /// <c>SuggestedPipeline.FromPipeline</c> and copying its score and success flag.
+     /// </summary>
+     public static SuggestedPipelineRunDetail FromPipelineRunResult(MLContext context, PipelineScore pipelineRunResult)
+     {
+         var suggestedPipeline = SuggestedPipeline.FromPipeline(context, pipelineRunResult.Pipeline);
+         return new SuggestedPipelineRunDetail(suggestedPipeline, pipelineRunResult.Score, pipelineRunResult.RunSucceded);
+     }
+
+     /// <summary>
+     /// Wraps the trainer's hyperparameter set and this run's score in a <c>RunResult</c>.
+     /// </summary>
+     public IRunResult ToRunResult(bool isMetricMaximizing)
+         => new RunResult(Pipeline.Trainer.HyperParamSet, Score, isMetricMaximizing);
+ }
+
+ /// <summary>
+ /// Run detail that additionally carries the validation metrics, the model container and the
+ /// exception recorded for the run (if any).
+ /// </summary>
+ /// <typeparam name="TMetrics">Type of the validation metrics.</typeparam>
+ internal class SuggestedPipelineRunDetail<TMetrics> : SuggestedPipelineRunDetail
+ {
+     public readonly TMetrics ValidationMetrics;
+     public readonly ModelContainer ModelContainer;
+     public readonly Exception Exception;
+
+     internal SuggestedPipelineRunDetail(SuggestedPipeline pipeline,
+         double score,
+         bool runSucceeded,
+         TMetrics validationMetrics,
+         ModelContainer modelContainer,
+         Exception ex) : base(pipeline, score, runSucceeded)
+     {
+         ValidationMetrics = validationMetrics;
+         ModelContainer = modelContainer;
+         Exception = ex;
+     }
+
+     /// <summary>
+     /// Builds the <c>RunDetail</c> for this run: trainer name, the pipeline's estimator with
+     /// <paramref name="preFeaturizer"/> prepended when it is non-null, the pipeline
+     /// description, model container, metrics and exception.
+     /// </summary>
+     public RunDetail<TMetrics> ToIterationResult(IEstimator<ITransformer> preFeaturizer)
+     {
+         var estimator = SuggestedPipelineRunDetailUtil.PrependPreFeaturizer(Pipeline.ToEstimator(), preFeaturizer);
+         return new RunDetail<TMetrics>(Pipeline.Trainer.TrainerName.ToString(), estimator,
+             Pipeline.ToPipeline(), ModelContainer, ValidationMetrics, Exception);
+     }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetailUtil.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetailUtil.cs
new file mode 100644
index 0000000000..8fcb59b4d5
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/SuggestedPipelineRunDetails/SuggestedPipelineRunDetailUtil.cs
@@ -0,0 +1,18 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// Helper shared by the run-detail classes for composing estimators.
+ /// </summary>
+ internal static class SuggestedPipelineRunDetailUtil
+ {
+     /// <summary>
+     /// Returns <paramref name="estimator"/> unchanged when <paramref name="preFeaturizer"/> is
+     /// null; otherwise returns the pre-featurizer with the estimator appended to it.
+     /// </summary>
+     public static IEstimator<ITransformer> PrependPreFeaturizer(IEstimator<ITransformer> estimator, IEstimator<ITransformer> preFeaturizer)
+     {
+         if (preFeaturizer == null)
+         {
+             return estimator;
+         }
+         return preFeaturizer.Append(estimator);
+     }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs b/src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs
new file mode 100644
index 0000000000..90fae04eb9
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Experiment/SuggestedTrainer.cs
@@ -0,0 +1,92 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Trainers;
+
+namespace Microsoft.ML.Auto
+{
+ /// <summary>
+ /// A trainer candidate: a trainer extension, its sweepable hyperparameters and the
+ /// hyperparameter set currently applied to them.
+ /// </summary>
+ internal class SuggestedTrainer
+ {
+     public IEnumerable<SweepableParam> SweepParams { get; }
+     public TrainerName TrainerName { get; }
+     public ParameterSet HyperParamSet { get; set; }
+
+     private readonly MLContext _mlContext;
+     private readonly ITrainerExtension _trainerExtension;
+     private readonly ColumnInformation _columnInfo;
+
+     /// <summary>
+     /// Reads the sweep ranges and trainer name from <paramref name="trainerExtension"/>, then
+     /// applies <paramref name="hyperParamSet"/> (which may be null) to the sweep parameters.
+     /// </summary>
+     internal SuggestedTrainer(MLContext mlContext, ITrainerExtension trainerExtension,
+         ColumnInformation columnInfo,
+         ParameterSet hyperParamSet = null)
+     {
+         _mlContext = mlContext;
+         _trainerExtension = trainerExtension;
+         _columnInfo = columnInfo;
+         SweepParams = _trainerExtension.GetHyperparamSweepRanges();
+         TrainerName = TrainerExtensionCatalog.GetTrainerName(_trainerExtension);
+         SetHyperparamValues(hyperParamSet);
+     }
+
+     /// <summary>
+     /// Replaces the hyperparameter set and copies its values onto the matching sweep parameters.
+     /// </summary>
+     public void SetHyperparamValues(ParameterSet hyperParamSet)
+     {
+         HyperParamSet = hyperParamSet;
+         PropagateParamSetValues();
+     }
+
+     /// <summary>
+     /// Creates a new trainer over the same extension and column info with a cloned
+     /// hyperparameter set (or null when none is set).
+     /// </summary>
+     public SuggestedTrainer Clone()
+     {
+         return new SuggestedTrainer(_mlContext, _trainerExtension, _columnInfo, HyperParamSet?.Clone());
+     }
+
+     /// <summary>
+     /// Creates the trainer estimator. The sweep parameters are passed to the extension only
+     /// when a hyperparameter set is present; otherwise null is passed.
+     /// </summary>
+     public ITrainerEstimator<ISingleFeaturePredictionTransformer<object>, object> BuildTrainer()
+     {
+         IEnumerable<SweepableParam> sweepParams = null;
+         if (HyperParamSet != null)
+         {
+             sweepParams = SweepParams;
+         }
+         return _trainerExtension.CreateInstance(_mlContext, sweepParams, _columnInfo);
+     }
+
+     /// <summary>
+     /// Formats as <c>TrainerName{name:value, ...}</c>, listing only sweep parameters that
+     /// have a raw value.
+     /// </summary>
+     public override string ToString()
+     {
+         var paramsStr = string.Empty;
+         if (SweepParams != null)
+         {
+             paramsStr = string.Join(", ", SweepParams.Where(p => p != null && p.RawValue != null).Select(p => $"{p.Name}:{p.ProcessedValue()}"));
+         }
+         return $"{TrainerName}{{{paramsStr}}}";
+     }
+
+     /// <summary>
+     /// Creates the pipeline node for this trainer from the sweep parameters that have a raw value.
+     /// </summary>
+     public PipelineNode ToPipelineNode()
+     {
+         var sweepParams = SweepParams.Where(p => p.RawValue != null);
+         return _trainerExtension.CreatePipelineNode(sweepParams, _columnInfo);
+     }
+
+     /// <summary>
+     /// make sure sweep params and param set are consistent
+     /// </summary>
+     private void PropagateParamSetValues()
+     {
+         if (HyperParamSet == null)
+         {
+             return;
+         }
+
+         var spMap = SweepParams.ToDictionary(sp => sp.Name);
+
+         foreach (var hp in HyperParamSet)
+         {
+             // Single lookup; hyperparameters without a matching sweep parameter are ignored.
+             if (spMap.TryGetValue(hp.Name, out var sp))
+             {
+                 sp.SetUsingValueText(hp.ValueText);
+             }
+         }
+     }
+ }
+}
diff --git a/src/Microsoft.ML.Auto/Microsoft.ML.Auto.csproj b/src/Microsoft.ML.Auto/Microsoft.ML.Auto.csproj
new file mode 100644
index 0000000000..cb8eb4fc72
--- /dev/null
+++ b/src/Microsoft.ML.Auto/Microsoft.ML.Auto.csproj
@@ -0,0 +1,34 @@
+
+
+ netstandard2.0
+ 7.3
+ Microsoft.ML.Auto
+
+ false
+ false
+
+
+
+
+
+
+
+
+
+
+
+ Microsoft
+ LICENSE
+ https://dot.net/ml
+ https://aka.ms/mlnetlogo
+ https://aka.ms/mlnetreleasenotes
+
+ ML.NET ML Machine Learning AutoML
+ Microsoft.ML.Auto
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs b/src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs
new file mode 100644
index 0000000000..ca834ce801
--- /dev/null
+++ b/src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs
@@ -0,0 +1,217 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Microsoft.ML.Data;
+
+namespace Microsoft.ML.Auto
+{
+ internal static class PipelineSuggester
+ {
+ private const int TopKTrainers = 3;
+
+ /// <summary>
+ /// Converts <paramref name="history"/> into suggested-pipeline run details and returns the
+ /// next pipeline to try, or null when <c>GetNextInferredPipeline</c> yields none.
+ /// </summary>
+ /// <param name="context">Context used to rebuild pipelines and infer the next one.</param>
+ /// <param name="history">Previously scored pipelines.</param>
+ /// <param name="columns">Dataset column information used for inference.</param>
+ /// <param name="task">Task kind the pipeline is suggested for.</param>
+ /// <param name="isMaximizingMetric">Whether a higher score is better.</param>
+ public static Pipeline GetNextPipeline(MLContext context,
+     IEnumerable<PipelineScore> history,
+     DatasetColumnInfo[] columns,
+     TaskKind task,
+     bool isMaximizingMetric = true)
+ {
+     // Materialize once: the history is enumerated several times downstream, and a deferred
+     // Select would rebuild every SuggestedPipeline on each enumeration.
+     var inferredHistory = history.Select(r => SuggestedPipelineRunDetail.FromPipelineRunResult(context, r)).ToList();
+     var nextInferredPipeline = GetNextInferredPipeline(context, inferredHistory, columns, task, isMaximizingMetric);
+     return nextInferredPipeline?.ToPipeline();
+ }
+
+ /// <summary>
+ /// Suggests the next pipeline to run. While fewer runs exist than available trainers, the
+ /// next first-stage pipeline is returned. After that, the top trainers are visited from
+ /// least-run to most-run, sampling new hyperparameters until a pipeline not present in
+ /// <paramref name="history"/> is found.
+ /// </summary>
+ /// <param name="context">Context used for inference and pipeline construction.</param>
+ /// <param name="history">Runs performed so far.</param>
+ /// <param name="columns">Dataset column information used for inference.</param>
+ /// <param name="task">Task kind the pipeline is suggested for.</param>
+ /// <param name="isMaximizingMetric">Whether a higher score is better.</param>
+ /// <param name="trainerWhitelist">Optional restriction on allowed trainers; may be null.</param>
+ /// <param name="_enableCaching">Caching preference forwarded to the pipeline builder; may be null.</param>
+ /// <returns>The next pipeline, or null when no unseen pipeline could be produced.</returns>
+ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
+     IEnumerable<SuggestedPipelineRunDetail> history,
+     DatasetColumnInfo[] columns,
+     TaskKind task,
+     bool isMaximizingMetric,
+     IEnumerable<TrainerName> trainerWhitelist = null,
+     bool? _enableCaching = null)
+ {
+     var availableTrainers = RecipeInference.AllowedTrainers(context, task,
+         ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist);
+     var transforms = TransformInferenceApi.InferTransforms(context, task, columns).ToList();
+     var transformsPostTrainer = TransformInferenceApi.InferTransformsPostTrainer(context, task, columns).ToList();
+
+     // if we haven't run all pipelines once
+     if (history.Count() < availableTrainers.Count())
+     {
+         return GetNextFirstStagePipeline(context, history, availableTrainers, transforms, transformsPostTrainer, _enableCaching);
+     }
+
+     // get top trainers from stage 1 runs
+     var topTrainers = GetTopTrainers(history, availableTrainers, isMaximizingMetric);
+
+     // sort top trainers by # of times they've been run, from lowest to highest
+     var orderedTopTrainers = OrderTrainersByNumTrials(history, topTrainers);
+
+     // keep as hashset of previously visited pipelines
+     var visitedPipelines = new HashSet<SuggestedPipeline>(history.Select(h => h.Pipeline));
+
+     // iterate over top trainers (from least run to most run),
+     // to find next pipeline
+     foreach (var trainer in orderedTopTrainers)
+     {
+         var newTrainer = trainer.Clone();
+
+         // repeat until passes or runs out of chances
+         // NOTE(review): with the pre-increment `<=` test below the body can execute
+         // maxNumberAttempts + 1 times; confirm whether that extra attempt is intended.
+         const int maxNumberAttempts = 10;
+         var count = 0;
+         do
+         {
+             // sample new hyperparameters for the learner
+             if (!SampleHyperparameters(context, newTrainer, history, isMaximizingMetric))
+             {
+                 // if unable to sample new hyperparameters for the learner
+                 // (ie SMAC returned 0 suggestions), break
+                 break;
+             }
+
+             var suggestedPipeline = SuggestedPipelineBuilder.Build(context, transforms, transformsPostTrainer, newTrainer, _enableCaching);
+
+             // make sure we have not seen pipeline before
+             if (!visitedPipelines.Contains(suggestedPipeline))
+             {
+                 return suggestedPipeline;
+             }
+         } while (++count <= maxNumberAttempts);
+     }
+
+     return null;
+ }
+
+ /// <summary>
+ /// Get top trainers from first stage
+ /// </summary>
+ private static IEnumerable GetTopTrainers(IEnumerable