Skip to content

Commit e5cc871

Browse files
daholsteDmitry-A
authored andcommitted
rev InferColumns to accept ColumnInfo input param (dotnet#186)
1 parent 33bc8f8 commit e5cc871

File tree

3 files changed

+38
-9
lines changed

3 files changed

+38
-9
lines changed

src/Microsoft.ML.Auto/API/AutoInferenceCatalog.cs

+11-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Data;
6+
57
namespace Microsoft.ML.Auto
68
{
79
public class AutoInferenceCatalog
@@ -52,11 +54,18 @@ public MulticlassClassificationExperiment CreateMulticlassClassificationExperime
5254
return new MulticlassClassificationExperiment(_context, experimentSettings);
5355
}
5456

55-
public ColumnInferenceResults InferColumns(string path, string label,char? separatorChar = null, bool? allowQuotedStrings = null,
57+
public ColumnInferenceResults InferColumns(string path, string labelColumn = DefaultColumnNames.Label, char? separatorChar = null, bool? allowQuotedStrings = null,
58+
bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
59+
{
60+
//UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
61+
return ColumnInferenceApi.InferColumns(_context, path, labelColumn, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
62+
}
63+
64+
public ColumnInferenceResults InferColumns(string path, ColumnInformation columnInformation, char? separatorChar = null, bool? allowQuotedStrings = null,
5665
bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
5766
{
5867
//UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
59-
return ColumnInferenceApi.InferColumns(_context, path, label, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
68+
return ColumnInferenceApi.InferColumns(_context, path, columnInformation, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
6069
}
6170

6271
public ColumnInferenceResults InferColumns(string path, uint labelColumnIndex, bool hasHeader = false, char? separatorChar = null,

src/Microsoft.ML.Auto/ColumnInference/ColumnInferenceApi.cs

+14-7
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,28 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
2424
typeInference.Columns[labelColumnIndex].SuggestedName = DefaultColumnNames.Label;
2525
}
2626

27-
return InferColumns(context, path, typeInference.Columns[labelColumnIndex].SuggestedName,
28-
hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
27+
var columnInfo = new ColumnInformation() { LabelColumn = typeInference.Columns[labelColumnIndex].SuggestedName };
28+
29+
return InferColumns(context, path, columnInfo, hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
30+
}
31+
32+
public static ColumnInferenceResults InferColumns(MLContext context, string path, string labelColumn,
33+
char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
34+
{
35+
var columnInfo = new ColumnInformation() { LabelColumn = labelColumn };
36+
return InferColumns(context, path, columnInfo, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
2937
}
3038

31-
public static ColumnInferenceResults InferColumns(MLContext context, string path, string label,
39+
public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo,
3240
char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
3341
{
3442
var sample = TextFileSample.CreateFromFullFile(path);
3543
var splitInference = InferSplit(context, sample, separatorChar, allowQuotedStrings, supportSparse);
36-
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, label);
37-
return InferColumns(context, path, label, true, splitInference, typeInference, trimWhitespace, groupColumns);
44+
var typeInference = InferColumnTypes(context, sample, splitInference, true, null, columnInfo.LabelColumn);
45+
return InferColumns(context, path, columnInfo, true, splitInference, typeInference, trimWhitespace, groupColumns);
3846
}
3947

40-
public static ColumnInferenceResults InferColumns(MLContext context, string path, string label, bool hasHeader,
48+
public static ColumnInferenceResults InferColumns(MLContext context, string path, ColumnInformation columnInfo, bool hasHeader,
4149
TextFileContents.ColumnSplitResult splitInference, ColumnTypeInference.InferenceResult typeInference,
4250
bool trimWhitespace, bool groupColumns)
4351
{
@@ -54,7 +62,6 @@ public static ColumnInferenceResults InferColumns(MLContext context, string path
5462
var textLoader = context.Data.CreateTextLoader(typedLoaderArgs);
5563
var dataView = textLoader.Read(path);
5664

57-
var columnInfo = new ColumnInformation() { LabelColumn = label };
5865
var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, columnInfo);
5966

6067
// start building result objects

src/Test/ColumnInferenceTests.cs

+13
Original file line numberDiff line numberDiff line change
@@ -112,5 +112,18 @@ public void DefaultColumnNamesInferredCorrectly()
112112
Assert.AreEqual(DefaultColumnNames.GroupId, result.ColumnInformation.GroupIdColumn);
113113
Assert.AreEqual(result.ColumnInformation.NumericColumns.Count(), 3);
114114
}
115+
116+
[TestMethod]
117+
public void InferColumnsColumnInfoParam()
118+
{
119+
var columnInfo = new ColumnInformation() { LabelColumn = DatasetUtil.MlNetGeneratedRegressionLabel };
120+
var result = new MLContext().AutoInference().InferColumns(DatasetUtil.DownloadMlNetGeneratedRegressionDataset(),
121+
columnInfo);
122+
var labelCol = result.TextLoaderArgs.Column.First(c => c.Name == DatasetUtil.MlNetGeneratedRegressionLabel);
123+
Assert.AreEqual(DataKind.R4, labelCol.Type);
124+
Assert.AreEqual(DatasetUtil.MlNetGeneratedRegressionLabel, result.ColumnInformation.LabelColumn);
125+
Assert.AreEqual(1, result.ColumnInformation.NumericColumns.Count());
126+
Assert.AreEqual(DefaultColumnNames.Features, result.ColumnInformation.NumericColumns.First());
127+
}
115128
}
116129
}

0 commit comments

Comments
 (0)