Skip to content

Commit 3a4595d

Browse files
authored
cache arg implementation in CLI (dotnet#280)
* cache implementation * corrected the null case * added tests for all cases
1 parent 0bb8951 commit 3a4595d

File tree

5 files changed

+90
-3
lines changed

5 files changed

+90
-3
lines changed

src/mlnet.Test/CommandLineTests.cs

+58
Original file line numberDiff line numberDiff line change
@@ -172,5 +172,63 @@ public void TestCommandLineArgsMutuallyExclusiveArgsTest()
172172
Assert.IsFalse(parsingSuccessful);
173173

174174
}
175+
176+
[TestMethod]
177+
public void CacheArgumentTest()
178+
{
179+
bool parsingSuccessful = false;
180+
var trainDataset = Path.GetTempFileName();
181+
var testDataset = Path.GetTempFileName();
182+
var labelName = "Label";
183+
var cache = "on";
184+
185+
// Create handler outside so that commandline and the handler is decoupled and testable.
186+
var handler = CommandHandler.Create<NewCommandSettings>(
187+
(opt) =>
188+
{
189+
parsingSuccessful = true;
190+
Assert.AreEqual(opt.MlTask, "binary-classification");
191+
Assert.AreEqual(opt.Dataset, trainDataset);
192+
Assert.AreEqual(opt.LabelColumnName, labelName);
193+
Assert.AreEqual(opt.Cache, cache);
194+
});
195+
196+
var parser = new CommandLineBuilder()
197+
// Parser
198+
.AddCommand(CommandDefinitions.New(handler))
199+
.UseDefaults()
200+
.Build();
201+
202+
// valid cache test
203+
string[] args = new[] { "new", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", cache };
204+
parser.InvokeAsync(args).Wait();
205+
Assert.IsTrue(parsingSuccessful);
206+
207+
parsingSuccessful = false;
208+
209+
cache = "off";
210+
// valid cache test
211+
args = new[] { "new", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", cache };
212+
parser.InvokeAsync(args).Wait();
213+
Assert.IsTrue(parsingSuccessful);
214+
215+
parsingSuccessful = false;
216+
217+
cache = "auto";
218+
// valid cache test
219+
args = new[] { "new", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", cache };
220+
parser.InvokeAsync(args).Wait();
221+
Assert.IsTrue(parsingSuccessful);
222+
223+
parsingSuccessful = false;
224+
225+
// invalid cache test
226+
args = new[] { "new", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--cache", "blah" };
227+
parser.InvokeAsync(args).Wait();
228+
Assert.IsFalse(parsingSuccessful);
229+
230+
File.Delete(trainDataset);
231+
File.Delete(testDataset);
232+
}
175233
}
176234
}

src/mlnet/AutoML/AutoMLEngine.cs

+8-3
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@ internal class AutoMLEngine : IAutoMLEngine
1414
{
1515
private NewCommandSettings settings;
1616
private TaskKind taskKind;
17+
private bool? enableCaching;
1718
private static Logger logger = LogManager.GetCurrentClassLogger();
1819

1920
public AutoMLEngine(NewCommandSettings settings)
2021
{
2122
this.settings = settings;
2223
this.taskKind = Utils.GetTaskKind(settings.MlTask);
24+
this.enableCaching = Utils.GetCacheSettings(settings.Cache);
2325
}
2426

2527
public ColumnInferenceResults InferColumns(MLContext context)
@@ -53,7 +55,8 @@ public ColumnInferenceResults InferColumns(MLContext context)
5355
.CreateBinaryClassificationExperiment(new BinaryExperimentSettings()
5456
{
5557
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
56-
ProgressHandler = progressReporter
58+
ProgressHandler = progressReporter,
59+
EnableCaching = this.enableCaching
5760
})
5861
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
5962
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
@@ -69,7 +72,8 @@ public ColumnInferenceResults InferColumns(MLContext context)
6972
.CreateRegressionExperiment(new RegressionExperimentSettings()
7073
{
7174
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
72-
ProgressHandler = progressReporter
75+
ProgressHandler = progressReporter,
76+
EnableCaching = this.enableCaching
7377
}).Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
7478
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
7579
var bestIteration = result.Best();
@@ -84,7 +88,8 @@ public ColumnInferenceResults InferColumns(MLContext context)
8488
var experimentSettings = new MulticlassExperimentSettings()
8589
{
8690
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
87-
ProgressHandler = progressReporter
91+
ProgressHandler = progressReporter,
92+
EnableCaching = this.enableCaching
8893
};
8994

9095
// Inclusion list for currently supported learners. Need to remove once we have codegen support for all other learners.

src/mlnet/Commands/CommandDefinitions.cs

+10
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ internal static System.CommandLine.Command New(ICommandHandler handler)
3030
Name(),
3131
OutputPath(),
3232
HasHeader(),
33+
Cache()
3334
};
3435

3536
newCommand.Argument.AddValidator((sym) =>
@@ -99,6 +100,10 @@ Option HasHeader() =>
99100
new Option(new List<string>() { "--has-header" }, "Specify true/false depending if the dataset file(s) have a header row.",
100101
new Argument<bool>(defaultValue: true));
101102

103+
Option Cache() =>
104+
new Option(new List<string>() { "--cache" }, "Specify on/off/auto if you want cache to be turned on, off or auto determined.",
105+
new Argument<string>(defaultValue: "auto").FromAmong(GetCacheSuggestions()));
106+
102107
}
103108

104109
private static string[] GetMlTaskSuggestions()
@@ -110,5 +115,10 @@ private static string[] GetVerbositySuggestions()
110115
{
111116
return new[] { "q", "m", "diag" };
112117
}
118+
119+
private static string[] GetCacheSuggestions()
120+
{
121+
return new[] { "on", "off", "auto" };
122+
}
113123
}
114124
}

src/mlnet/Commands/New/NewCommandSettings.cs

+2
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,7 @@ public class NewCommandSettings
3030

3131
public bool HasHeader { get; set; }
3232

33+
public string Cache { get; set; }
34+
3335
}
3436
}

src/mlnet/Utilities/Utils.cs

+12
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,17 @@ internal static string Normalize(string input)
7979
}
8080
}
8181

82+
internal static bool? GetCacheSettings(string input)
83+
{
84+
switch (input)
85+
{
86+
case "on": return true;
87+
case "off": return false;
88+
case "auto": return null;
89+
default:
90+
throw new ArgumentException($"{nameof(input)} is invalid", nameof(input));
91+
}
92+
}
93+
8294
}
8395
}

0 commit comments

Comments
 (0)