|
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System; |
| 6 | +using System.Collections.Generic; |
| 7 | +using System.Linq; |
| 8 | +using System.Threading; |
| 9 | +using Microsoft.Data.DataView; |
| 10 | +using Microsoft.ML.Core.Data; |
| 11 | +using Microsoft.ML.Data; |
| 12 | + |
| 13 | +namespace Microsoft.ML.Auto |
| 14 | +{ |
| 15 | + public class AutoFitBinaryClassificationOptions |
| 16 | + { |
| 17 | + public IDataView TrainData; |
| 18 | + public string LabelColumnName = DefaultColumnNames.Label; |
| 19 | + public IDataView ValidationData; |
| 20 | + public uint TimeoutInSeconds = AutoFitDefaults.TimeoutInSeconds; |
| 21 | + public CancellationToken CancellationToken = default; |
| 22 | + public IProgress<AutoFitRunResult<BinaryClassificationMetrics>> ProgressCallback; |
| 23 | + public IEstimator<ITransformer> PreFeaturizers; |
| 24 | + public IEnumerable<(string, ColumnPurpose)> ColumnPurposes; |
| 25 | + } |
| 26 | + |
| 27 | + public static class BinaryClassificationExtensions |
| 28 | + { |
| 29 | + public static List<AutoFitRunResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationCatalog catalog, |
| 30 | + IDataView trainData, |
| 31 | + string labelColumnName = DefaultColumnNames.Label, |
| 32 | + IDataView validationData = null, |
| 33 | + uint timeoutInSeconds = AutoFitDefaults.TimeoutInSeconds, |
| 34 | + CancellationToken cancellationToken = default, |
| 35 | + IProgress<AutoFitRunResult<BinaryClassificationMetrics>> progressCallback = null) |
| 36 | + { |
| 37 | + var settings = new AutoFitSettings(); |
| 38 | + settings.StoppingCriteria.TimeoutInSeconds = timeoutInSeconds; |
| 39 | + |
| 40 | + return AutoFit(catalog, trainData, labelColumnName, validationData, settings, |
| 41 | + null, null, cancellationToken, progressCallback, null); |
| 42 | + } |
| 43 | + |
| 44 | + public static List<AutoFitRunResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationCatalog catalog, |
| 45 | + AutoFitBinaryClassificationOptions options) |
| 46 | + { |
| 47 | + var settings = new AutoFitSettings(); |
| 48 | + settings.StoppingCriteria.TimeoutInSeconds = options.TimeoutInSeconds; |
| 49 | + |
| 50 | + return AutoFit(catalog, options.TrainData, options.LabelColumnName, options.ValidationData, settings, |
| 51 | + options.PreFeaturizers, options.ColumnPurposes, options.CancellationToken, options.ProgressCallback, null); |
| 52 | + } |
| 53 | + |
| 54 | + internal static List<AutoFitRunResult<BinaryClassificationMetrics>> AutoFit(this BinaryClassificationCatalog catalog, |
| 55 | + IDataView trainData, |
| 56 | + string labelColumnName = DefaultColumnNames.Label, |
| 57 | + IDataView validationData = null, |
| 58 | + AutoFitSettings settings = null, |
| 59 | + IEstimator<ITransformer> preFeaturizers = null, |
| 60 | + IEnumerable<(string, ColumnPurpose)> columnPurposes = null, |
| 61 | + CancellationToken cancellationToken = default, |
| 62 | + IProgress<AutoFitRunResult<BinaryClassificationMetrics>> progressCallback = null, |
| 63 | + IDebugLogger debugLogger = null) |
| 64 | + { |
| 65 | + UserInputValidationUtil.ValidateAutoFitArgs(trainData, labelColumnName, validationData, settings, columnPurposes); |
| 66 | + |
| 67 | + if (validationData == null) |
| 68 | + { |
| 69 | + (trainData, validationData) = catalog.TestValidateSplit(trainData); |
| 70 | + } |
| 71 | + |
| 72 | + // run autofit & get all pipelines run in that process |
| 73 | + var autoFitter = new AutoFitter<BinaryClassificationMetrics>(TaskKind.BinaryClassification, trainData, labelColumnName, validationData, |
| 74 | + settings, preFeaturizers, columnPurposes, |
| 75 | + OptimizingMetric.RSquared, cancellationToken, progressCallback, debugLogger); |
| 76 | + |
| 77 | + return autoFitter.Fit(); |
| 78 | + } |
| 79 | + |
| 80 | + public static AutoFitRunResult<BinaryClassificationMetrics> Best(this IEnumerable<AutoFitRunResult<BinaryClassificationMetrics>> results) |
| 81 | + { |
| 82 | + double maxScore = results.Select(r => r.Metrics.Accuracy).Max(); |
| 83 | + return results.First(r => r.Metrics.Accuracy == maxScore); |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | +} |
0 commit comments