Skip to content

Commit 4a6921d

Browse files
authored
Implement ignore columns command line arg (dotnet#290)
* normalize line endings * added --ignore-columns * null checks * unit tests
1 parent 6e5a5d7 commit 4a6921d

File tree

10 files changed

+887
-518
lines changed

10 files changed

+887
-518
lines changed

src/mlnet.Test/CommandLineTests.cs

+70-9
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
57
using System.CommandLine.Builder;
68
using System.CommandLine.Invocation;
79
using System.IO;
10+
using System.Linq;
811
using Microsoft.ML.CLI.Commands;
912
using Microsoft.ML.CLI.Data;
1013
using Microsoft.VisualStudio.TestTools.UnitTesting;
@@ -30,6 +33,10 @@ public void TestMinimumCommandLineArgs()
3033
// Parser
3134
.AddCommand(CommandDefinitions.New(handler))
3235
.UseDefaults()
36+
.UseExceptionHandler((e, ctx) =>
37+
{
38+
Console.WriteLine(e.ToString());
39+
})
3340
.Build();
3441

3542
var trainDataset = Path.GetTempFileName();
@@ -58,6 +65,10 @@ public void TestCommandLineArgsFailTest()
5865
// parser
5966
.AddCommand(CommandDefinitions.New(handler))
6067
.UseDefaults()
68+
.UseExceptionHandler((e, ctx) =>
69+
{
70+
Console.WriteLine(e.ToString());
71+
})
6172
.Build();
6273

6374
// Incorrect mltask test
@@ -96,29 +107,33 @@ public void TestCommandLineArgsValuesTest()
96107
var validDataset = Path.GetTempFileName();
97108
var labelName = "Label";
98109
var name = "testname";
99-
var outputPath = ".";
110+
var outputPath = "x:\\mlnet";
100111
var falseString = "false";
101112

102113
// Create handler outside so that commandline and the handler is decoupled and testable.
103114
var handler = CommandHandler.Create<NewCommandSettings>(
104115
(opt) =>
105116
{
106-
parsingSuccessful = true;
107117
Assert.AreEqual(opt.MlTask, "binary-classification");
108-
Assert.AreEqual(opt.Dataset, trainDataset);
109-
Assert.AreEqual(opt.TestDataset, testDataset);
110-
Assert.AreEqual(opt.ValidationDataset, validDataset);
118+
Assert.AreEqual(opt.Dataset.FullName, trainDataset);
119+
Assert.AreEqual(opt.TestDataset.FullName, testDataset);
120+
Assert.AreEqual(opt.ValidationDataset.FullName, validDataset);
111121
Assert.AreEqual(opt.LabelColumnName, labelName);
112-
Assert.AreEqual(opt.MaxExplorationTime, 5);
122+
Assert.AreEqual(opt.MaxExplorationTime, (uint)5);
113123
Assert.AreEqual(opt.Name, name);
114-
Assert.AreEqual(opt.OutputPath, outputPath);
124+
Assert.AreEqual(opt.OutputPath.FullName, outputPath);
115125
Assert.AreEqual(opt.HasHeader, bool.Parse(falseString));
126+
parsingSuccessful = true;
116127
});
117128

118129
var parser = new CommandLineBuilder()
119130
// Parser
120131
.AddCommand(CommandDefinitions.New(handler))
121132
.UseDefaults()
133+
.UseExceptionHandler((e, ctx) =>
134+
{
135+
Console.WriteLine(e.ToString());
136+
})
122137
.Build();
123138

124139
// Incorrect mltask test
@@ -151,6 +166,10 @@ public void TestCommandLineArgsMutuallyExclusiveArgsTest()
151166
// Parser
152167
.AddCommand(CommandDefinitions.New(handler))
153168
.UseDefaults()
169+
.UseExceptionHandler((e, ctx) =>
170+
{
171+
Console.WriteLine(e.ToString());
172+
})
154173
.Build();
155174

156175
// Incorrect arguments : specifying dataset and train-dataset
@@ -186,17 +205,21 @@ public void CacheArgumentTest()
186205
var handler = CommandHandler.Create<NewCommandSettings>(
187206
(opt) =>
188207
{
189-
parsingSuccessful = true;
190208
Assert.AreEqual(opt.MlTask, "binary-classification");
191-
Assert.AreEqual(opt.Dataset, trainDataset);
209+
Assert.AreEqual(opt.Dataset.FullName, trainDataset);
192210
Assert.AreEqual(opt.LabelColumnName, labelName);
193211
Assert.AreEqual(opt.Cache, cache);
212+
parsingSuccessful = true;
194213
});
195214

196215
var parser = new CommandLineBuilder()
197216
// Parser
198217
.AddCommand(CommandDefinitions.New(handler))
199218
.UseDefaults()
219+
.UseExceptionHandler((e, ctx) =>
220+
{
221+
Console.WriteLine(e.ToString());
222+
})
200223
.Build();
201224

202225
// valid cache test
@@ -230,5 +253,43 @@ public void CacheArgumentTest()
230253
File.Delete(trainDataset);
231254
File.Delete(testDataset);
232255
}
256+
257+
[TestMethod]
258+
public void IgnoreColumnsArgumentTest()
259+
{
260+
bool parsingSuccessful = false;
261+
var trainDataset = Path.GetTempFileName();
262+
var testDataset = Path.GetTempFileName();
263+
var labelName = "Label";
264+
265+
// Create handler outside so that commandline and the handler is decoupled and testable.
266+
var handler = CommandHandler.Create<NewCommandSettings>(
267+
(opt) =>
268+
{
269+
Assert.AreEqual(opt.MlTask, "binary-classification");
270+
Assert.AreEqual(opt.Dataset.FullName, trainDataset);
271+
Assert.AreEqual(opt.LabelColumnName, labelName);
272+
Assert.IsTrue(opt.IgnoreColumns.SequenceEqual(new List<string>() { "a", "b", "c" }));
273+
parsingSuccessful = true;
274+
});
275+
276+
var parser = new CommandLineBuilder()
277+
// Parser
278+
.AddCommand(CommandDefinitions.New(handler))
279+
.UseDefaults()
280+
.UseExceptionHandler((e, ctx) =>
281+
{
282+
Console.WriteLine(e.ToString());
283+
})
284+
.Build();
285+
286+
// valid cache test
287+
string[] args = new[] { "new", "--ml-task", "binary-classification", "--dataset", trainDataset, "--label-column-name", labelName, "--ignore-columns", "a", "b", "c" };
288+
parser.InvokeAsync(args).Wait();
289+
Assert.IsTrue(parsingSuccessful);
290+
291+
File.Delete(trainDataset);
292+
File.Delete(testDataset);
293+
}
233294
}
234295
}

src/mlnet/AutoML/AutoMLEngine.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,15 @@ public AutoMLEngine(NewCommandSettings settings)
2424
this.enableCaching = Utils.GetCacheSettings(settings.Cache);
2525
}
2626

27-
public ColumnInferenceResults InferColumns(MLContext context)
27+
public ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation)
2828
{
2929
//Check what overload method of InferColumns needs to be called.
3030
logger.Log(LogLevel.Info, Strings.InferColumns);
3131
ColumnInferenceResults columnInference = null;
3232
var dataset = settings.Dataset.FullName;
33-
if (settings.LabelColumnName != null)
33+
if (columnInformation.LabelColumn != null)
3434
{
35-
columnInference = context.Auto().InferColumns(dataset, settings.LabelColumnName, groupColumns: false);
35+
columnInference = context.Auto().InferColumns(dataset, columnInformation, groupColumns: false);
3636
}
3737
else
3838
{
@@ -42,7 +42,7 @@ public ColumnInferenceResults InferColumns(MLContext context)
4242
return columnInference;
4343
}
4444

45-
(Pipeline, ITransformer) IAutoMLEngine.ExploreModels(MLContext context, IDataView trainData, IDataView validationData, string labelName)
45+
(Pipeline, ITransformer) IAutoMLEngine.ExploreModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation)
4646
{
4747
ITransformer model = null;
4848

@@ -58,7 +58,7 @@ public ColumnInferenceResults InferColumns(MLContext context)
5858
ProgressHandler = progressReporter,
5959
EnableCaching = this.enableCaching
6060
})
61-
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
61+
.Execute(trainData, validationData, columnInformation);
6262
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
6363
var bestIteration = result.Best();
6464
pipeline = bestIteration.Pipeline;
@@ -74,7 +74,7 @@ public ColumnInferenceResults InferColumns(MLContext context)
7474
MaxExperimentTimeInSeconds = settings.MaxExplorationTime,
7575
ProgressHandler = progressReporter,
7676
EnableCaching = this.enableCaching
77-
}).Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
77+
}).Execute(trainData, validationData, columnInformation);
7878
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
7979
var bestIteration = result.Best();
8080
pipeline = bestIteration.Pipeline;
@@ -100,7 +100,7 @@ public ColumnInferenceResults InferColumns(MLContext context)
100100

101101
var result = context.Auto()
102102
.CreateMulticlassClassificationExperiment(experimentSettings)
103-
.Execute(trainData, validationData, new ColumnInformation() { LabelColumn = labelName });
103+
.Execute(trainData, validationData, columnInformation);
104104
logger.Log(LogLevel.Info, Strings.RetrieveBestPipeline);
105105
var bestIteration = result.Best();
106106
pipeline = bestIteration.Pipeline;

src/mlnet/AutoML/IAutoMLEngine.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ namespace Microsoft.ML.CLI.CodeGenerator
1010
{
1111
internal interface IAutoMLEngine
1212
{
13-
ColumnInferenceResults InferColumns(MLContext context);
13+
ColumnInferenceResults InferColumns(MLContext context, ColumnInformation columnInformation);
1414

15-
(Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData, string labelName);
15+
(Pipeline, ITransformer) ExploreModels(MLContext context, IDataView trainData, IDataView validationData, ColumnInformation columnInformation);
1616

1717
}
1818
}

src/mlnet/CodeGenerator/CodeGenerationHelper.cs

+15-7
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,13 @@ public void GenerateCode()
3636
ColumnInferenceResults columnInference = null;
3737
try
3838
{
39-
columnInference = automlEngine.InferColumns(context);
39+
var inputColumnInformation = new ColumnInformation();
40+
inputColumnInformation.LabelColumn = settings.LabelColumnName;
41+
foreach (var value in settings.IgnoreColumns)
42+
{
43+
inputColumnInformation.IgnoredColumns.Add(value);
44+
}
45+
columnInference = automlEngine.InferColumns(context, inputColumnInformation);
4046
}
4147
catch (Exception e)
4248
{
@@ -47,20 +53,22 @@ public void GenerateCode()
4753
return;
4854
}
4955

50-
// Sanitize columns
51-
Array.ForEach(columnInference.TextLoaderOptions.Columns, t => t.Name = Utils.Sanitize(t.Name));
56+
var textLoaderOptions = columnInference.TextLoaderOptions;
57+
var columnInformation = columnInference.ColumnInformation;
5258

53-
var sanitizedLabelName = Utils.Sanitize(columnInference.ColumnInformation.LabelColumn);
59+
// Sanitization of input data.
60+
Array.ForEach(textLoaderOptions.Columns, t => t.Name = Utils.Sanitize(t.Name));
61+
columnInformation = Utils.GetSanitizedColumnInformation(columnInformation);
5462

5563
// Load data
56-
(IDataView trainData, IDataView validationData) = LoadData(context, columnInference.TextLoaderOptions);
64+
(IDataView trainData, IDataView validationData) = LoadData(context, textLoaderOptions);
5765

5866
// Explore the models
5967
(Pipeline, ITransformer) result = default;
6068
Console.WriteLine($"{Strings.ExplorePipeline}: {settings.MlTask}");
6169
try
6270
{
63-
result = automlEngine.ExploreModels(context, trainData, validationData, sanitizedLabelName);
71+
result = automlEngine.ExploreModels(context, trainData, validationData, columnInformation);
6472
}
6573
catch (Exception e)
6674
{
@@ -82,7 +90,7 @@ public void GenerateCode()
8290
Utils.SaveModel(model, modelPath, context);
8391

8492
// Generate the Project
85-
GenerateProject(columnInference, pipeline, sanitizedLabelName, modelPath);
93+
GenerateProject(columnInference, pipeline, columnInformation.LabelColumn, modelPath);
8694
}
8795

8896
internal void GenerateProject(ColumnInferenceResults columnInference, Pipeline pipeline, string labelName, FileInfo modelPath)

src/mlnet/Commands/CommandDefinitions.cs

+11-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ internal static System.CommandLine.Command New(ICommandHandler handler)
3030
Name(),
3131
OutputPath(),
3232
HasHeader(),
33-
Cache()
33+
Cache(),
34+
IgnoreColumns()
3435
};
3536

3637
newCommand.Argument.AddValidator((sym) =>
@@ -51,6 +52,11 @@ internal static System.CommandLine.Command New(ICommandHandler handler)
5152
{
5253
return "The following options are mutually exclusive please provide only one : --label-column-name, --label-column-index";
5354
}
55+
if (sym.Children["--label-column-index"] != null && sym.Children["--ignore-columns"] != null)
56+
{
57+
return "Currently we don't support specifying --ignore-columns in conjunction with --label-column-index";
58+
}
59+
5460
return null;
5561
});
5662

@@ -104,6 +110,10 @@ Option Cache() =>
104110
new Option(new List<string>() { "--cache" }, "Specify on/off/auto if you want cache to be turned on, off or auto determined.",
105111
new Argument<string>(defaultValue: "auto").FromAmong(GetCacheSuggestions()));
106112

113+
Option IgnoreColumns() =>
114+
new Option(new List<string>() { "--ignore-columns" }, "Specify the columns that needs to be ignored in the given dataset.",
115+
new Argument<List<string>>());
116+
107117
}
108118

109119
private static string[] GetMlTaskSuggestions()

src/mlnet/Commands/New/NewCommandSettings.cs

+3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System.Collections.Generic;
56
using System.IO;
67

78
namespace Microsoft.ML.CLI.Data
@@ -32,5 +33,7 @@ public class NewCommandSettings
3233

3334
public string Cache { get; set; }
3435

36+
public List<string> IgnoreColumns { get; set; }
37+
3538
}
3639
}

src/mlnet/Program.cs

-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@ public static void Main(string[] args)
6262
.UseDefaults()
6363
.Build();
6464

65-
6665
parser.InvokeAsync(args).Wait();
6766
}
6867
}

0 commit comments

Comments
 (0)