Skip to content

Commit c8f8e38

Browse files
authored
CLI tool - make validation dataset optional and support for crossvalidation in generated code (dotnet#83)
* Added sequential grouping of columns * reverted the file * bug fixes, more logic to templates to support cross-validate * formatting and fix type in consolehelper * Added logic in templates * revert settings
1 parent 3bcaaf8 commit c8f8e38

File tree

7 files changed

+221
-168
lines changed

7 files changed

+221
-168
lines changed

src/mlnet/CodeGenerator/TrainerGenerators.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ internal class LinearSvm : TrainerGeneratorBase
178178
internal override string MethodName => "LinearSupportVectorMachines";
179179

180180
//ClassName of the options to trainer
181-
internal override string OptionsName => "LinearSvm.Options";
181+
internal override string OptionsName => "LinearSvmTrainer.Options";
182182

183183
//The named parameters to the trainer.
184184
internal override IDictionary<string, string> NamedParameters

src/mlnet/Commands/CommandDefinitions.cs

+2-38
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,19 @@ public static System.CommandLine.Command New()
1818
{
1919
var newCommand = new System.CommandLine.Command("new", "ML.NET CLI tool for code generation",
2020

21-
handler: CommandHandler.Create</*FileInfo,*/ FileInfo,/* FileInfo,*/ FileInfo, TaskKind, string>((/*FileInfo dataset,*/ FileInfo trainDataset, /*FileInfo validationDataset,*/ FileInfo testDataset, TaskKind mlTask, string labelColumnName) =>
21+
handler: CommandHandler.Create<FileInfo, FileInfo, TaskKind, string>((FileInfo trainDataset, FileInfo testDataset, TaskKind mlTask, string labelColumnName) =>
2222
{
2323
NewCommand.Run(new Options()
2424
{
25-
/*Dataset = dataset,*/
2625
TrainDataset = trainDataset,
27-
/*ValidationDataset = validationDataset,*/
2826
TestDataset = testDataset,
2927
MlTask = mlTask,
3028
LabelName = labelColumnName
3129
});
3230

3331
}))
3432
{
35-
//Dataset(),
3633
TrainDataset(),
37-
//ValidationDataset(),
3834
TestDataset(),
3935
MlTask(),
4036
LabelName(),
@@ -51,10 +47,6 @@ public static System.CommandLine.Command New()
5147
{
5248
return "Option required : --train-dataset";
5349
}
54-
if (sym.Children["--test-dataset"] == null)
55-
{
56-
return "Option required : --test-dataset";
57-
}
5850
if (sym.Children["--ml-task"] == null)
5951
{
6052
return "Option required : --ml-task";
@@ -69,21 +61,14 @@ public static System.CommandLine.Command New()
6961

7062
return newCommand;
7163

72-
//Option Dataset() =>
73-
// new Option("--dataset", "Dataset file path.",
74-
// new Argument<FileInfo>().ExistingOnly());
7564

7665
Option TrainDataset() =>
7766
new Option("--train-dataset", "Train dataset file path.",
7867
new Argument<FileInfo>().ExistingOnly());
7968

80-
//Option ValidationDataset() =>
81-
// new Option("--validation-dataset", "Test dataset file path.",
82-
// new Argument<FileInfo>().ExistingOnly());
83-
8469
Option TestDataset() =>
8570
new Option("--test-dataset", "Test dataset file path.",
86-
new Argument<FileInfo>().ExistingOnly());
71+
new Argument<FileInfo>(defaultValue: default(FileInfo)).ExistingOnly());
8772

8873
Option MlTask() =>
8974
new Option("--ml-task", "Type of ML task.",
@@ -93,27 +78,6 @@ Option LabelName() =>
9378
new Option("--label-column-name", "Name of the label column.",
9479
new Argument<string>());
9580

96-
//Option ColumnSeperator() =>
97-
// new Option("--column-separator", "Column separator in dataset file.",
98-
// new Argument<string>(defaultValue: default(string)));
99-
100-
//Option ExplorationTimeout() =>
101-
// new Option("--exploration-timeout", "Timeout for exploring the best models.",
102-
// new Argument<int>(defaultValue: 10));
103-
104-
//Option Name() =>
105-
// new Option("--name", "Name of the project file.",
106-
// new Argument<string>(defaultValue: "SampleProject"));
107-
108-
//Option ShowOutput() =>
109-
// new Option("--show-output", "Show output on the console",
110-
// new Argument<bool>(defaultValue: true));
111-
112-
//Option LabelIndex() =>
113-
// new Option("--label-column-index", "Index of the label column.",
114-
// new Argument<int>(defaultValue: -1));
115-
116-
11781
}
11882

11983
private static string[] GetMlTaskSuggestions()

src/mlnet/Commands/NewCommand.cs

+11-7
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,21 @@ internal static void Run(Options options)
2424
// For Version 0.1 It is required that the data set has header.
2525
var columnInference = context.Data.InferColumns(options.TrainDataset.FullName, label, true, groupColumns: false);
2626
var textLoader = context.Data.CreateTextLoader(columnInference);
27-
var trainData = textLoader.Read(options.TrainDataset.FullName);
2827

29-
var validationData = textLoader.Read(options.TestDataset.FullName);
30-
Pipeline pipelineToDeconstruct = null;
28+
IDataView trainData = textLoader.Read(options.TrainDataset.FullName);
29+
IDataView validationData = options.TestDataset == null ? null : textLoader.Read(options.TestDataset.FullName);
3130

32-
var result = ExploreModels(options, context, label, trainData, validationData, pipelineToDeconstruct);
33-
pipelineToDeconstruct = result.Item1;
31+
//Explore the models
32+
Pipeline pipeline = null;
33+
var result = ExploreModels(options, context, label, trainData, validationData, pipeline);
34+
35+
//Get the best pipeline
36+
pipeline = result.Item1;
3437
var model = result.Item2;
38+
3539
//Path can be overriden from args
3640
GenerateModel(model, @"./BestModel", "model.zip", context);
37-
RunCodeGen(options, columnInference, pipelineToDeconstruct);
41+
RunCodeGen(options, columnInference, pipeline);
3842
}
3943

4044
private static void GenerateModel(ITransformer model, string ModelPath, string modelName, MLContext mlContext)
@@ -116,7 +120,7 @@ private static void RunCodeGen(Options options, ColumnInferenceResult columnInfe
116120
MLCodeGen codeGen = new MLCodeGen()
117121
{
118122
Path = options.TrainDataset.FullName,
119-
TestPath = options.TestDataset.FullName,
123+
TestPath = options.TestDataset?.FullName,
120124
Columns = columns,
121125
Transforms = transforms,
122126
HasHeader = columnInference.HasHeader,

0 commit comments

Comments
 (0)