Skip to content

Commit f8bac23

Browse files
daholsteDmitry-A
authored andcommitted
rev ColumnInference API: can take label index; rev output object types; add tests (dotnet#89)
1 parent b5697b4 commit f8bac23

14 files changed

+221
-220
lines changed

src/AutoML/API/InferenceException.cs

+2-9
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,9 @@ namespace Microsoft.ML.Auto
88
{
99
public enum InferenceType
1010
{
11-
Seperator,
12-
Header,
13-
Label,
14-
Task,
1511
ColumnDataKind,
16-
ColumnPurpose,
17-
Tranform,
18-
Trainer,
19-
Hyperparams,
20-
ColumnSplit
12+
ColumnSplit,
13+
Label,
2114
}
2215

2316
public class InferenceException : Exception

src/AutoML/API/MLContextDataExtensions.cs

+8-66
Original file line numberDiff line numberDiff line change
@@ -2,87 +2,29 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5-
using System;
65
using System.Collections.Generic;
7-
using System.Linq;
8-
using Microsoft.Data.DataView;
96
using Microsoft.ML.Data;
107

118
namespace Microsoft.ML.Auto
129
{
1310
public static class DataExtensions
1411
{
1512
// Delimiter, header, column datatype inference
16-
public static ColumnInferenceResult InferColumns(this DataOperationsCatalog catalog, string path, string label,
17-
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
13+
public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) InferColumns(this DataOperationsCatalog catalog, string path, string label,
14+
char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
1815
{
1916
UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
2017
var mlContext = new MLContext();
21-
return ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
18+
return ColumnInferenceApi.InferColumns(mlContext, path, label, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
2219
}
2320

24-
public static IDataView AutoRead(this DataOperationsCatalog catalog, string path, string label,
25-
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
21+
public static (TextLoader.Arguments TextLoaderArgs, IEnumerable<(string Name, ColumnPurpose Purpose)> ColumnPurpopses) InferColumns(this DataOperationsCatalog catalog, string path, int labelColumnIndex,
22+
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null,
23+
bool trimWhitespace = false, bool groupColumns = true)
2624
{
27-
UserInputValidationUtil.ValidateAutoReadArgs(path, label);
25+
UserInputValidationUtil.ValidateInferColumnsArgs(path, labelColumnIndex);
2826
var mlContext = new MLContext();
29-
var columnInferenceResult = ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
30-
var textLoader = columnInferenceResult.BuildTextLoader();
31-
return textLoader.Read(path);
32-
}
33-
34-
public static TextLoader CreateTextLoader(this DataOperationsCatalog catalog, ColumnInferenceResult columnInferenceResult)
35-
{
36-
UserInputValidationUtil.ValidateCreateTextReaderArgs(columnInferenceResult);
37-
return columnInferenceResult.BuildTextLoader();
38-
}
39-
40-
// Task inference
41-
public static MachineLearningTaskType InferTask(this DataOperationsCatalog catalog, IDataView dataView)
42-
{
43-
throw new NotImplementedException();
44-
}
45-
46-
public enum MachineLearningTaskType
47-
{
48-
Regression,
49-
BinaryClassification,
50-
MultiClassClassification
51-
}
52-
}
53-
54-
public class ColumnInferenceResult
55-
{
56-
public readonly IEnumerable<(TextLoader.Column, ColumnPurpose)> Columns;
57-
public readonly bool AllowQuotedStrings;
58-
public readonly bool SupportSparse;
59-
public readonly char[] Separators;
60-
public readonly bool HasHeader;
61-
public readonly bool TrimWhitespace;
62-
63-
public ColumnInferenceResult(IEnumerable<(TextLoader.Column, ColumnPurpose)> columns,
64-
bool allowQuotedStrings, bool supportSparse, char[] separators, bool hasHeader, bool trimWhitespace)
65-
{
66-
Columns = columns;
67-
AllowQuotedStrings = allowQuotedStrings;
68-
SupportSparse = supportSparse;
69-
Separators = separators;
70-
HasHeader = hasHeader;
71-
TrimWhitespace = trimWhitespace;
72-
}
73-
74-
internal TextLoader BuildTextLoader()
75-
{
76-
var context = new MLContext();
77-
return new TextLoader(context, new TextLoader.Arguments()
78-
{
79-
AllowQuoting = AllowQuotedStrings,
80-
AllowSparse = SupportSparse,
81-
Column = Columns.Select(c => c.Item1).ToArray(),
82-
Separators = Separators,
83-
HasHeader = HasHeader,
84-
TrimWhitespace = TrimWhitespace
85-
});
27+
return ColumnInferenceApi.InferColumns(mlContext, path, labelColumnIndex, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
8628
}
8729
}
8830
}

src/AutoML/ColumnInference/ColumnInferenceApi.cs

+52-10
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,52 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
57
using System.Linq;
68
using Microsoft.ML.Data;
79

810
namespace Microsoft.ML.Auto
911
{
1012
internal static class ColumnInferenceApi
1113
{
12-
public static ColumnInferenceResult InferColumns(MLContext context, string path, string label,
14+
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, int labelColumnIndex,
1315
bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
1416
{
1517
var sample = TextFileSample.CreateFromFullFile(path);
1618
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
1719
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);
20+
21+
// If label column index > inferred # of columns, throw error
22+
if (labelColumnIndex >= typeInference.Columns.Count())
23+
{
24+
throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) is >= than # of inferred columns ({typeInference.Columns.Count()}).");
25+
}
26+
27+
// if no column is named label,
28+
// rename label column to default ML.NET label column name
29+
if (!typeInference.Columns.Any(c => c.SuggestedName == DefaultColumnNames.Label))
30+
{
31+
typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label;
32+
}
33+
34+
return InferColumns(context, path, typeInference.Columns[labelColumnIndex].SuggestedName,
35+
hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
36+
}
37+
38+
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, string label,
39+
char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
40+
{
41+
var sample = TextFileSample.CreateFromFullFile(path);
42+
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
43+
var typeInference = InferColumnTypes(context, sample, splitInference, true);
44+
return InferColumns(context, path, label, true, splitInference, typeInference, trimWhitespace, groupColumns);
45+
}
46+
47+
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, string label, bool hasHeader,
48+
TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference,
49+
bool trimWhitespace, bool groupColumns)
50+
{
1851
var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
1952
if (!loaderColumns.Any(t => label.Equals(t.Name)))
2053
{
@@ -34,25 +67,34 @@ public static ColumnInferenceResult InferColumns(MLContext context, string path,
3467

3568
var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, label);
3669

37-
(TextLoader.Column, ColumnPurpose Purpose)[] inferredColumns = null;
70+
// start building result objects
71+
IEnumerable<TextLoader.Column> columnResults = null;
72+
IEnumerable<(string, ColumnPurpose)> purposeResults = null;
73+
3874
// infer column grouping and generate column names
3975
if (groupColumns)
4076
{
4177
var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
4278
typeInference.Columns, purposeInferenceResult);
4379

44-
// build result objects & return
45-
inferredColumns = groupingResult.Select(c => (c.GenerateTextLoaderColumn(), c.Purpose)).ToArray();
80+
columnResults = groupingResult.Select(c => c.GenerateTextLoaderColumn());
81+
purposeResults = groupingResult.Select(c => (c.SuggestedName, c.Purpose));
4682
}
4783
else
4884
{
49-
inferredColumns = new (TextLoader.Column, ColumnPurpose Purpose)[loaderColumns.Length];
50-
for (int i = 0; i < loaderColumns.Length; i++)
51-
{
52-
inferredColumns[i] = (loaderColumns[i], purposeInferenceResult[i].Purpose);
53-
}
85+
columnResults = loaderColumns;
86+
purposeResults = purposeInferenceResult.Select(p => (dataView.Schema[p.ColumnIndex].Name, p.Purpose));
5487
}
55-
return new ColumnInferenceResult(inferredColumns, splitInference.AllowQuote, splitInference.AllowSparse, new char[] { splitInference.Separator.Value }, hasHeader, trimWhitespace);
88+
89+
return (new TextLoader.Arguments()
90+
{
91+
Column = columnResults.ToArray(),
92+
AllowQuoting = splitInference.AllowQuote,
93+
AllowSparse = splitInference.AllowSparse,
94+
Separators = new char[] { splitInference.Separator.Value },
95+
HasHeader = hasHeader,
96+
TrimWhitespace = trimWhitespace
97+
}, purposeResults);
5698
}
5799

58100
private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample sample, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse)

src/AutoML/ColumnInference/ColumnTypeInference.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,13 @@ public IntermediateColumn(ReadOnlyMemory<char>[] data, int columnId)
7070
public ReadOnlyMemory<char>[] RawData { get { return _data; } }
7171
}
7272

73-
public readonly struct Column
73+
public struct Column
7474
{
7575
public readonly int ColumnIndex;
76-
public readonly string SuggestedName;
7776
public readonly PrimitiveType ItemType;
7877

78+
public string SuggestedName;
79+
7980
public Column(int columnIndex, string suggestedName, PrimitiveType itemType)
8081
{
8182
ColumnIndex = columnIndex;

src/AutoML/Utils/UserInputValidationUtil.cs

+19-33
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public static void ValidateAutoFitArgs(IDataView trainData, string label, IDataV
1717
{
1818
ValidateTrainData(trainData);
1919
ValidateValidationData(trainData, validationData);
20-
ValidateLabel(trainData, validationData, label);
20+
ValidateLabel(trainData, label);
2121
ValidateSettings(settings);
2222
ValidatePurposeOverrides(trainData, validationData, label, purposeOverrides);
2323
}
@@ -28,49 +28,27 @@ public static void ValidateInferColumnsArgs(string path, string label)
2828
ValidatePath(path);
2929
}
3030

31-
public static void ValidateAutoReadArgs(string path, string label)
31+
public static void ValidateInferColumnsArgs(string path, int labelColumnIndex)
3232
{
33-
ValidateLabel(label);
33+
ValidateLabelColumnIndex(labelColumnIndex);
3434
ValidatePath(path);
3535
}
3636

37-
public static void ValidateCreateTextReaderArgs(ColumnInferenceResult columnInferenceResult)
37+
public static void ValidateAutoReadArgs(string path, string label)
3838
{
39-
if(columnInferenceResult == null)
40-
{
41-
throw new ArgumentNullException($"Column inference result cannot be null", nameof(columnInferenceResult));
42-
}
43-
44-
if (columnInferenceResult.Separators == null || !columnInferenceResult.Separators.Any())
45-
{
46-
throw new ArgumentException($"Column inference result cannot have null or empty separators", nameof(columnInferenceResult));
47-
}
48-
49-
if (columnInferenceResult.Columns == null || !columnInferenceResult.Columns.Any())
50-
{
51-
throw new ArgumentException($"Column inference result must contain at least one column", nameof(columnInferenceResult));
52-
}
53-
54-
if(columnInferenceResult.Columns.Any(c => c.Item1 == null))
55-
{
56-
throw new ArgumentException($"Column inference result cannot contain null columns", nameof(columnInferenceResult));
57-
}
58-
59-
if (columnInferenceResult.Columns.Any(c => c.Item1.Name == null || c.Item1.Type == null || c.Item1.Source == null))
60-
{
61-
throw new ArgumentException($"Column inference result cannot contain a column that has a null name, type, or source", nameof(columnInferenceResult));
62-
}
39+
ValidateLabel(label);
40+
ValidatePath(path);
6341
}
6442

6543
private static void ValidateTrainData(IDataView trainData)
6644
{
6745
if(trainData == null)
6846
{
69-
throw new ArgumentNullException("Training data cannot be null", nameof(trainData));
47+
throw new ArgumentNullException(nameof(trainData), "Training data cannot be null");
7048
}
7149
}
7250

73-
private static void ValidateLabel(IDataView trainData, IDataView validationData, string label)
51+
private static void ValidateLabel(IDataView trainData, string label)
7452
{
7553
ValidateLabel(label);
7654

@@ -84,15 +62,23 @@ private static void ValidateLabel(string label)
8462
{
8563
if (label == null)
8664
{
87-
throw new ArgumentNullException("Provided label cannot be null", nameof(label));
65+
throw new ArgumentNullException(nameof(label), "Provided label cannot be null");
66+
}
67+
}
68+
69+
private static void ValidateLabelColumnIndex(int labelColumnIndex)
70+
{
71+
if (labelColumnIndex < 0)
72+
{
73+
throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Provided label column index ({labelColumnIndex}) must be non-negative.");
8874
}
8975
}
9076

9177
private static void ValidatePath(string path)
9278
{
9379
if (path == null)
9480
{
95-
throw new ArgumentNullException("Provided path cannot be null", nameof(path));
81+
throw new ArgumentNullException(nameof(path), "Provided path cannot be null");
9682
}
9783

9884
var fileInfo = new FileInfo(path);
@@ -148,7 +134,7 @@ private static void ValidateSettings(AutoFitSettings settings)
148134

149135
if(settings.StoppingCriteria.MaxIterations <= 0)
150136
{
151-
throw new ArgumentOutOfRangeException("Max iterations must be > 0", nameof(settings));
137+
throw new ArgumentOutOfRangeException(nameof(settings), "Max iterations must be > 0");
152138
}
153139
}
154140

src/Test/AutoFitTests.cs

+6-6
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ public void AutoFitBinaryTest()
1515
{
1616
var context = new MLContext();
1717
var dataPath = DatasetUtil.DownloadUciAdultDataset();
18-
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel, true);
19-
var textLoader = context.Data.CreateTextLoader(columnInference);
18+
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel);
19+
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderArgs);
2020
var trainData = textLoader.Read(dataPath);
2121
var validationData = trainData.Take(100);
2222
trainData = trainData.Skip(100);
@@ -38,8 +38,8 @@ public void AutoFitMultiTest()
3838
{
3939
var context = new MLContext();
4040
var dataPath = DatasetUtil.DownloadTrivialDataset();
41-
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.TrivialDatasetLabel, true);
42-
var textLoader = context.Data.CreateTextLoader(columnInference);
41+
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.TrivialDatasetLabel);
42+
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderArgs);
4343
var trainData = textLoader.Read(dataPath);
4444
var validationData = trainData.Take(20);
4545
trainData = trainData.Skip(20);
@@ -61,8 +61,8 @@ public void AutoFitRegressionTest()
6161
{
6262
var context = new MLContext();
6363
var dataPath = DatasetUtil.DownloadMlNetGeneratedRegressionDataset();
64-
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.MlNetGeneratedRegressionLabel, true);
65-
var textLoader = context.Data.CreateTextLoader(columnInference);
64+
var columnInference = context.Data.InferColumns(dataPath, DatasetUtil.MlNetGeneratedRegressionLabel);
65+
var textLoader = context.Data.CreateTextLoader(columnInference.TextLoaderArgs);
6666
var trainData = textLoader.Read(dataPath);
6767
var validationData = trainData.Take(20);
6868
trainData = trainData.Skip(20);

0 commit comments

Comments
 (0)