Skip to content

Commit 390e9d7

Browse files
authored
Ungroup Columns in Column Inference (dotnet#40)
* Added sequential grouping of columns * added ungrouping of column option * reverted the file
1 parent 1c4886b commit 390e9d7

File tree

3 files changed

+51
-14
lines changed

3 files changed

+51
-14
lines changed

src/AutoML/API/MLContextDataExtensions.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,19 @@ public static class DataExtensions
1313
{
1414
// Delimiter, header, column datatype inference
1515
public static ColumnInferenceResult InferColumns(this DataOperations catalog, string path, string label,
16-
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false)
16+
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
1717
{
1818
UserInputValidationUtil.ValidateInferColumnsArgs(path, label);
1919
var mlContext = new MLContext();
20-
return ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace);
20+
return ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
2121
}
2222

2323
public static IDataView AutoRead(this DataOperations catalog, string path, string label,
24-
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false)
24+
bool hasHeader = false, char? separatorChar = null, bool? allowQuotedStrings = null, bool? supportSparse = null, bool trimWhitespace = false, bool groupColumns = true)
2525
{
2626
UserInputValidationUtil.ValidateAutoReadArgs(path, label);
2727
var mlContext = new MLContext();
28-
var columnInferenceResult = ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace);
28+
var columnInferenceResult = ColumnInferenceApi.InferColumns(mlContext, path, label, hasHeader, separatorChar, allowQuotedStrings, supportSparse, trimWhitespace, groupColumns);
2929
var textLoader = columnInferenceResult.BuildTextLoader();
3030
return textLoader.Read(path);
3131
}

src/AutoML/ColumnInference/ColumnInferenceApi.cs

+23-10
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,16 @@ namespace Microsoft.ML.Auto
99
{
1010
internal static class ColumnInferenceApi
1111
{
12-
public static ColumnInferenceResult InferColumns(MLContext context, string path, string label,
13-
bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace)
12+
public static ColumnInferenceResult InferColumns(MLContext context, string path, string label,
13+
bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
1414
{
1515
var sample = TextFileSample.CreateFromFullFile(path);
1616
var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
1717
var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);
18+
var loaderColumns = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns);
1819
var typedLoaderArgs = new TextLoader.Arguments
1920
{
20-
Column = ColumnTypeInference.GenerateLoaderColumns(typeInference.Columns),
21+
Column = loaderColumns,
2122
Separator = splitInference.Separator,
2223
AllowSparse = splitInference.AllowSparse,
2324
AllowQuoting = splitInference.AllowQuote,
@@ -29,12 +30,24 @@ public static ColumnInferenceResult InferColumns(MLContext context, string path,
2930

3031
var purposeInferenceResult = PurposeInference.InferPurposes(context, dataView, label);
3132

33+
(TextLoader.Column, ColumnPurpose Purpose)[] inferredColumns = null;
3234
// infer column grouping and generate column names
33-
var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
34-
typeInference.Columns, purposeInferenceResult);
35+
if (groupColumns)
36+
{
37+
var groupingResult = ColumnGroupingInference.InferGroupingAndNames(context, hasHeader,
38+
typeInference.Columns, purposeInferenceResult);
3539

36-
// build result objects & return
37-
var inferredColumns = groupingResult.Select(c => (c.GenerateTextLoaderColumn(), c.Purpose)).ToArray();
40+
// build result objects & return
41+
inferredColumns = groupingResult.Select(c => (c.GenerateTextLoaderColumn(), c.Purpose)).ToArray();
42+
}
43+
else
44+
{
45+
inferredColumns = new (TextLoader.Column, ColumnPurpose Purpose)[loaderColumns.Length];
46+
for (int i = 0; i < loaderColumns.Length; i++)
47+
{
48+
inferredColumns[i] = (loaderColumns[i], purposeInferenceResult[i].Purpose);
49+
}
50+
}
3851
return new ColumnInferenceResult(inferredColumns, splitInference.AllowQuote, splitInference.AllowSparse, splitInference.Separator, hasHeader, trimWhitespace);
3952
}
4053

@@ -44,15 +57,15 @@ private static TextFileContents.ColumnSplitResult InferSplit(TextFileSample samp
4457
var splitInference = TextFileContents.TrySplitColumns(sample, separatorCandidates);
4558

4659
// respect passed-in overrides
47-
if(allowQuotedStrings != null)
60+
if (allowQuotedStrings != null)
4861
{
4962
splitInference.AllowQuote = allowQuotedStrings.Value;
5063
}
51-
if(supportSparse != null)
64+
if (supportSparse != null)
5265
{
5366
splitInference.AllowSparse = supportSparse.Value;
5467
}
55-
68+
5669
if (!splitInference.IsSuccess)
5770
{
5871
throw new InferenceException(InferenceType.ColumnSplit, "Unable to split the file provided into multiple, consistent columns.");

src/Test/ColumnInferenceTests.cs

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System.Linq;
2+
using Microsoft.VisualStudio.TestTools.UnitTesting;
3+
4+
namespace Microsoft.ML.Auto.Test
5+
{
6+
[TestClass]
7+
public class ColumnInferenceTests
8+
{
9+
[TestMethod]
10+
public void UnGroupColumnsTest()
11+
{
12+
var dataPath = DatasetUtil.DownloadUciAdultDataset();
13+
var context = new MLContext();
14+
var columnInferenceWithoutGrouping = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel, true, groupColumns: false);
15+
foreach (var col in columnInferenceWithoutGrouping.Columns)
16+
{
17+
Assert.IsFalse(col.Item1.Source.Length > 1 || col.Item1.Source[0].Min != col.Item1.Source[0].Max);
18+
}
19+
20+
var columnInferenceWithGrouping = context.Data.InferColumns(dataPath, DatasetUtil.UciAdultLabel, true, groupColumns: true);
21+
Assert.IsTrue(columnInferenceWithGrouping.Columns.Count() < columnInferenceWithoutGrouping.Columns.Count());
22+
}
23+
}
24+
}

0 commit comments

Comments
 (0)