Skip to content

Commit 0f6a07d

Browse files
srsaggam authored and Dmitry-A committed
Added more command line args implementation to CLI tool and refactoring (dotnet#110)
* Added sequential grouping of columns
* reverted the file
* Set up CI with Azure Pipelines
* Update azure-pipelines.yml for Azure Pipelines
* Update azure-pipelines.yml for Azure Pipelines
* added git status
* reverted change
* added codegen options and refactoring
* minor fixes
* renamed params, minor refactoring
* added tests for commandline and refactoring
* removed file
* added back the test case
* minor fixes
* Update src/mlnet.Test/CommandLineTests.cs
  Co-Authored-By: srsaggam <[email protected]>
* review comments
* capitalize the first character
* changed the name of test case
* remove unused directives
1 parent 4726a7a commit 0f6a07d

File tree

12 files changed: +331 -177 lines changed

src/Microsoft.ML.Auto/API/DataExtensions.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@ public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, Co
1818
return ColumnInferenceApi.InferColumns(mlContext, path, label, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
1919
}
2020

21-
public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) InferColumns(this DataOperationsCatalog catalog, string path, int labelColumnIndex,
22-
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null,
21+
public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) InferColumns(this DataOperationsCatalog catalog, string path, uint labelColumnIndex,
22+
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null,
2323
bool trimWhitespace = false, bool groupColumns = true)
2424
{
25-
UserInputValidationUtil.ValidateInferColumnsArgs(path, labelColumnIndex);
25+
UserInputValidationUtil.ValidateInferColumnsArgs(path);
2626
var mlContext = new MLContext();
2727
return ColumnInferenceApi.InferColumns(mlContext, path, labelColumnIndex, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
2828
}

src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ namespace Microsoft.ML.Auto
1111
{
1212
internal static class ColumnInferenceApi
1313
{
14-
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, int labelColumnIndex,
14+
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, uint labelColumnIndex,
1515
bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
1616
{
1717
var sample = TextFileSample.CreateFromFullFile(path);
@@ -31,7 +31,7 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
3131
typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label;
3232
}
3333

34-
return InferColumns(context, path, typeInference.Columns[labelColumnIndex].SuggestedName,
34+
return InferColumns(context, path, typeInference.Columns[labelColumnIndex].SuggestedName,
3535
hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
3636
}
3737

@@ -87,14 +87,14 @@ public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) Infer
8787
}
8888

8989
return (new TextLoader.Arguments()
90-
{
91-
Column = columnResults.ToArray(),
92-
AllowQuoting = splitInference.AllowQuote,
93-
AllowSparse = splitInference.AllowSparse,
94-
Separators = new char[] { splitInference.Separator.Value },
95-
HasHeader = hasHeader,
96-
TrimWhitespace = trimWhitespace
97-
}, purposeResults);
90+
{
91+
Column = columnResults.ToArray(),
92+
AllowQuoting = splitInference.AllowQuote,
93+
AllowSparse = splitInference.AllowSparse,
94+
Separators = new char[] { splitInference.Separator.Value },
95+
HasHeader = hasHeader,
96+
TrimWhitespace = trimWhitespace
97+
}, purposeResults);
9898
}
9999

100100
private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)

src/Microsoft.ML.Auto/Utils/UserInputValidationUtil.cs

+10-19
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,8 @@ public static void ValidateInferColumnsArgs(string path, string label)
2828
ValidatePath(path);
2929
}
3030

31-
public static void ValidateInferColumnsArgs(string path, int labelColumnIndex)
31+
public static void ValidateInferColumnsArgs(string path)
3232
{
33-
ValidateLabelColumnIndex(labelColumnIndex);
3433
ValidatePath(path);
3534
}
3635

@@ -42,7 +41,7 @@ public static void ValidateAutoReadArgs(string path, string label)
4241

4342
private static void ValidateTrainData(IDataView trainData)
4443
{
45-
if(trainData == null)
44+
if (trainData == null)
4645
{
4746
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
4847
}
@@ -52,7 +51,7 @@ private static void ValidateLabel(IDataView trainData, string label)
5251
{
5352
ValidateLabel(label);
5453

55-
if(trainData.Schema.GetColumnOrNull(label) == null)
54+
if (trainData.Schema.GetColumnOrNull(label) == null)
5655
{
5756
throw new ArgumentException($"Provided label column '{label}' not found in training data.", nameof(label));
5857
}
@@ -66,14 +65,6 @@ private static void ValidateLabel(string label)
6665
}
6766
}
6867

69-
private static void ValidateLabelColumnIndex(int labelColumnIndex)
70-
{
71-
if (labelColumnIndex < 0)
72-
{
73-
throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Provided label column index ({labelColumnIndex}) must be non-negative.");
74-
}
75-
}
76-
7768
private static void ValidatePath(string path)
7869
{
7970
if (path == null)
@@ -96,7 +87,7 @@ private static void ValidatePath(string path)
9687

9788
private static void ValidateValidationData(IDataView trainData, IDataView validationData)
9889
{
99-
if(validationData == null)
90+
if (validationData == null)
10091
{
10192
return;
10293
}
@@ -109,15 +100,15 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
109100
$"and validation data has '{validationData.Schema.Count}' columns.", nameof(validationData));
110101
}
111102

112-
foreach(var trainCol in trainData.Schema)
103+
foreach (var trainCol in trainData.Schema)
113104
{
114105
var validCol = validationData.Schema.GetColumnOrNull(trainCol.Name);
115-
if(validCol == null)
106+
if (validCol == null)
116107
{
117108
throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' exsits in train data, but not in validation data.", nameof(validationData));
118109
}
119110

120-
if(trainCol.Type != validCol.Value.Type)
111+
if (trainCol.Type != validCol.Value.Type)
121112
{
122113
throw new ArgumentException($"{schemaMismatchError} Column '{trainCol.Name}' is of type {trainCol.Type} in train data, and type " +
123114
$"{validCol.Value.Type} in validation data.", nameof(validationData));
@@ -127,12 +118,12 @@ private static void ValidateValidationData(IDataView trainData, IDataView valida
127118

128119
private static void ValidateSettings(AutoFitSettings settings)
129120
{
130-
if(settings?.StoppingCriteria == null)
121+
if (settings?.StoppingCriteria == null)
131122
{
132123
return;
133124
}
134125

135-
if(settings.StoppingCriteria.MaxIterations <= 0)
126+
if (settings.StoppingCriteria.MaxIterations <= 0)
136127
{
137128
throw new ArgumentOutOfRangeException(nameof(settings), "Max iterations must be > 0");
138129
}
@@ -162,7 +153,7 @@ private static void ValidatePurposeOverrides(IDataView trainData, IDataView vali
162153
}
163154

164155
// if column w/ purpose = 'Label' found, ensure it matches the passed-in label
165-
if(colPurpose == ColumnPurpose.Label && colName != label)
156+
if (colPurpose == ColumnPurpose.Label && colName != label)
166157
{
167158
throw new ArgumentException($"Label column name in provided list of purposes '{colName}' must match " +
168159
$"the label column name '{label}'", nameof(purposeOverrides));

src/Test/UserInputValidationTests.cs

+2-9
Original file line numberDiff line numberDiff line change
@@ -193,16 +193,9 @@ public void ValidateInferColumnsArgsEmptyFile()
193193
}
194194

195195
[TestMethod]
196-
public void ValidateOkayInferColsLabelIndex()
196+
public void ValidateInferColsPath()
197197
{
198-
UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.DownloadUciAdultDataset(), 0);
199-
}
200-
201-
[TestMethod]
202-
[ExpectedException(typeof(ArgumentOutOfRangeException))]
203-
public void ValidateInferColsNegativeLabelIndex()
204-
{
205-
UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.DownloadUciAdultDataset(), -1);
198+
UserInputValidationUtil.ValidateInferColumnsArgs(DatasetUtil.DownloadUciAdultDataset());
206199
}
207200
}
208201
}

0 commit comments

Comments
 (0)