Skip to content

Commit df8f3b0

Browse files
daholste authored and Dmitry-A committed
Add caching (dotnet#249)
1 parent 2cf834e commit df8f3b0

File tree

6 files changed

+40
-24
lines changed

6 files changed

+40
-24
lines changed

src/Microsoft.ML.Auto/API/ExperimentSettings.cs

+7-1
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,13 @@ public class ExperimentSettings
1111
public uint MaxExperimentTimeInSeconds { get; set; } = 24 * 60 * 60;
1212
public CancellationToken CancellationToken { get; set; } = default;
1313

14-
internal bool EnableCaching;
14+
/// <summary>
15+
/// This setting controls whether or not an AutoML experiment will make use of ML.NET-provided caching.
16+
/// If set to true, caching will be forced on for all pipelines. If set to false, caching will be forced off.
17+
/// If set to null (default value), AutoML will decide whether to enable caching for each model.
18+
/// </summary>
19+
public bool? EnableCaching = null;
20+
1521
internal int MaxModels = int.MaxValue;
1622
internal IDebugLogger DebugLogger;
1723
}

src/Microsoft.ML.Auto/Experiment/Experiment.cs

+1-1
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,7 @@ public List<RunResult<T>> Execute()
8484
var getPiplelineStopwatch = Stopwatch.StartNew();
8585

8686
// get next pipeline
87-
pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);
87+
pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist, _experimentSettings.EnableCaching);
8888

8989
getPiplelineStopwatch.Stop();
9090

src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs

+13-5
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,6 @@
55
using System;
66
using System.Collections.Generic;
77
using System.Linq;
8-
using Microsoft.Data.DataView;
98
using Microsoft.ML.Data;
109

1110
namespace Microsoft.ML.Auto
@@ -16,20 +15,24 @@ namespace Microsoft.ML.Auto
1615
/// </summary>
1716
internal class SuggestedPipeline
1817
{
19-
private readonly MLContext _context;
2018
public readonly IList<SuggestedTransform> Transforms;
2119
public readonly SuggestedTrainer Trainer;
2220

21+
private readonly MLContext _context;
22+
private readonly bool? _enableCaching;
23+
2324
public SuggestedPipeline(IEnumerable<SuggestedTransform> transforms,
2425
SuggestedTrainer trainer,
2526
MLContext context,
27+
bool? enableCaching,
2628
bool autoNormalize = true)
2729
{
2830
Transforms = transforms.Select(t => t.Clone()).ToList();
2931
Trainer = trainer.Clone();
3032
_context = context;
33+
_enableCaching = enableCaching;
3134

32-
if(autoNormalize)
35+
if (autoNormalize)
3336
{
3437
AddNormalizationTransforms();
3538
}
@@ -88,7 +91,7 @@ public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipelin
8891
}
8992
}
9093

91-
return new SuggestedPipeline(transforms, trainer, context, false);
94+
return new SuggestedPipeline(transforms, trainer, context, null);
9295
}
9396

9497
public IEstimator<ITransformer> ToEstimator()
@@ -107,6 +110,11 @@ public IEstimator<ITransformer> ToEstimator()
107110
// Get learner
108111
var learner = Trainer.BuildTrainer();
109112

113+
if (_enableCaching == true || (_enableCaching == null && learner.Info.WantCaching))
114+
{
115+
pipeline = pipeline.AppendCacheCheckpoint(_context);
116+
}
117+
110118
// Append learner to pipeline
111119
pipeline = pipeline.Append(learner);
112120

@@ -128,4 +136,4 @@ private void AddNormalizationTransforms()
128136
Transforms.Add(transform);
129137
}
130138
}
131-
}
139+
}

src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs

+7-5
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,8 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
3030
(string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns,
3131
TaskKind task,
3232
bool isMaximizingMetric,
33-
IEnumerable<TrainerName> trainerWhitelist = null)
33+
IEnumerable<TrainerName> trainerWhitelist = null,
34+
bool? _enableCaching = null)
3435
{
3536
var availableTrainers = RecipeInference.AllowedTrainers(context, task,
3637
ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist);
@@ -40,7 +41,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
4041
// if we haven't run all pipelines once
4142
if (history.Count() < availableTrainers.Count())
4243
{
43-
return GetNextFirstStagePipeline(context, history, availableTrainers, transforms);
44+
return GetNextFirstStagePipeline(context, history, availableTrainers, transforms, _enableCaching);
4445
}
4546

4647
// get top trainers from stage 1 runs
@@ -71,7 +72,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
7172
break;
7273
}
7374

74-
var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer, context);
75+
var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer, context, _enableCaching);
7576

7677
// make sure we have not seen pipeline before
7778
if (!visitedPipelines.Contains(suggestedPipeline))
@@ -117,10 +118,11 @@ private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerabl
117118
private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context,
118119
IEnumerable<SuggestedPipelineResult> history,
119120
IEnumerable<SuggestedTrainer> availableTrainers,
120-
IEnumerable<SuggestedTransform> transforms)
121+
IEnumerable<SuggestedTransform> transforms,
122+
bool? _enableCaching)
121123
{
122124
var trainer = availableTrainers.ElementAt(history.Count());
123-
return new SuggestedPipeline(transforms, trainer, context);
125+
return new SuggestedPipeline(transforms, trainer, context, _enableCaching);
124126
}
125127

126128
private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableParam> hps)

src/Test/InferredPipelineTests.cs

+10-10
Original file line number | Diff line number | Diff line change
@@ -22,43 +22,43 @@ public void InferredPipelinesHashTest()
2222
var trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
2323
var transforms1 = new List<SuggestedTransform>();
2424
var transforms2 = new List<SuggestedTransform>();
25-
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
26-
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
25+
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
26+
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
2727
Assert.AreEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());
2828

2929
// test same learners with hyperparams set vs empty hyperparams have different hash codes
3030
var hyperparams1 = new ParameterSet(new List<IParameterValue>() { new LongParameterValue("NumLeaves", 2) });
3131
trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams1);
3232
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
33-
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
34-
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
33+
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
34+
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
3535
Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());
3636

3737
// same learners with different hyperparams
3838
hyperparams1 = new ParameterSet(new List<IParameterValue>() { new LongParameterValue("NumLeaves", 2) });
3939
var hyperparams2 = new ParameterSet(new List<IParameterValue>() { new LongParameterValue("NumLeaves", 6) });
4040
trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams1);
4141
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams2);
42-
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
43-
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
42+
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
43+
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
4444
Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());
4545

4646
// same learners with same transforms
4747
trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
4848
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
4949
transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
5050
transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
51-
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
52-
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
51+
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
52+
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
5353
Assert.AreEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());
5454

5555
// same transforms with different learners
5656
trainer1 = new SuggestedTrainer(context, new SdcaBinaryExtension(), columnInfo);
5757
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
5858
transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
5959
transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
60-
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
61-
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
60+
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
61+
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
6262
Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());
6363
}
6464
}

src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs

+2-2
Original file line number | Diff line number | Diff line change
@@ -104,8 +104,8 @@ public void GeneratedHelperCodeTest()
104104
var trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), new ColumnInformation(), hyperparams2);
105105
var transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
106106
var transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
107-
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
108-
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
107+
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
108+
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
109109

110110
this.pipeline = inferredPipeline1.ToPipeline();
111111
var textLoaderArgs = new TextLoader.Options()

0 commit comments

Comments
 (0)