2
2
// The .NET Foundation licenses this file to you under the MIT license.
3
3
// See the LICENSE file in the project root for more information.
4
4
5
+ using System ;
6
+ using System . Collections . Generic ;
5
7
using System . Linq ;
6
8
using Microsoft . ML . Data ;
7
9
8
10
namespace Microsoft . ML . Auto
9
11
{
10
12
internal static class ColumnInferenceApi
11
13
{
12
- public static ColumnInferenceResult InferColumns ( MLContext context , string path , string label ,
14
/// <summary>
/// Infers text-loader arguments and per-column purposes for the file at <paramref name="path"/>,
/// identifying the label column by its zero-based position among the inferred columns.
/// </summary>
/// <param name="context">ML context passed to column-type inference and to the name-based overload.</param>
/// <param name="path">Path of the file that is sampled and analyzed.</param>
/// <param name="labelColumnIndex">Zero-based index of the label column within the inferred columns.</param>
/// <param name="hasHeader">Header flag passed to column-type inference and to the name-based overload.</param>
/// <param name="separatorChar">Optional separator hint passed to split inference.</param>
/// <param name="allowQuotedStrings">Optional quoting hint passed to split inference.</param>
/// <param name="supportSparse">Optional sparse-format hint passed to split inference.</param>
/// <param name="trimWhitespace">Passed through to the name-based overload.</param>
/// <param name="groupColumns">Passed through to the name-based overload.</param>
/// <returns>
/// The result of the name-based <c>InferColumns</c> overload, invoked with the suggested name
/// of the column at <paramref name="labelColumnIndex"/>.
/// </returns>
/// <exception cref="ArgumentOutOfRangeException">
/// <paramref name="labelColumnIndex"/> is negative, or is greater than or equal to the number of inferred columns.
/// </exception>
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, int labelColumnIndex,
    bool hasHeader, char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
{
    var sample = TextFileSample.CreateFromFullFile(path);
    var splitInference = InferSplit(sample, separatorChar, allowQuotedStrings, supportSparse);
    var typeInference = InferColumnTypes(context, sample, splitInference, hasHeader);

    // A negative index can never address a column. Reject it here with the same exception type
    // used for the upper bound, instead of letting the indexer below fail.
    if (labelColumnIndex < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) must not be negative.");
    }

    // If label column index >= inferred # of columns, throw error
    var inferredColumnCount = typeInference.Columns.Count();
    if (labelColumnIndex >= inferredColumnCount)
    {
        throw new ArgumentOutOfRangeException(nameof(labelColumnIndex), $"Label column index ({labelColumnIndex}) is >= than # of inferred columns ({inferredColumnCount}).");
    }

    var labelColumn = typeInference.Columns[labelColumnIndex];

    // If no column is already named like the default ML.NET label column,
    // rename the label column to that default name. Otherwise the label column keeps
    // its suggested name, so two columns never end up sharing the default label name.
    if (!typeInference.Columns.Any(c => c.SuggestedName == DefaultColumnNames.Label))
    {
        labelColumn.SuggestedName = DefaultColumnNames.Label;
    }

    return InferColumns(context, path, labelColumn.SuggestedName,
        hasHeader, splitInference, typeInference, trimWhitespace, groupColumns);
}
37
+
38
/// <summary>
/// Infers text-loader arguments and per-column purposes for the file at <paramref name="path"/>,
/// identifying the label column by name. This overload takes no header flag: it passes
/// <c>true</c> wherever the other overloads pass <c>hasHeader</c>.
/// </summary>
/// <param name="context">ML context passed to column-type inference and to the core overload.</param>
/// <param name="path">Path of the file that is sampled and analyzed.</param>
/// <param name="label">Name of the label column, passed to the core overload.</param>
/// <param name="separatorChar">Optional separator hint passed to split inference.</param>
/// <param name="allowQuotedStrings">Optional quoting hint passed to split inference.</param>
/// <param name="supportSparse">Optional sparse-format hint passed to split inference.</param>
/// <param name="trimWhitespace">Passed through to the core overload.</param>
/// <param name="groupColumns">Passed through to the core overload.</param>
/// <returns>The result of the core <c>InferColumns</c> overload for the inferred split and column types.</returns>
public static (TextLoader.Arguments, IEnumerable<(string, ColumnPurpose)>) InferColumns(MLContext context, string path, string label,
    char? separatorChar, bool? allowQuotedStrings, bool? supportSparse, bool trimWhitespace, bool groupColumns)
{
    // Single source for the header flag used by both calls below.
    const bool hasHeader = true;

    var fileSample = TextFileSample.CreateFromFullFile(path);
    var columnSplit = InferSplit(fileSample, separatorChar, allowQuotedStrings, supportSparse);
    var columnTypes = InferColumnTypes(context, fileSample, columnSplit, hasHeader);

    return InferColumns(context, path, label, hasHeader, columnSplit, columnTypes, trimWhitespace, groupColumns);
}
46
+
47
+ public static ( TextLoader . Arguments , IEnumerable < ( string , ColumnPurpose ) > ) InferColumns ( MLContext context , string path , string label , bool hasHeader ,
48
+ TextFileContents . ColumnSplitResult splitInference , ColumnTypeInference . InferenceResult typeInference ,
49
+ bool trimWhitespace , bool groupColumns )
50
+ {
18
51
var loaderColumns = ColumnTypeInference . GenerateLoaderColumns ( typeInference . Columns ) ;
19
52
if ( ! loaderColumns . Any ( t => label . Equals ( t . Name ) ) )
20
53
{
@@ -34,25 +67,34 @@ public static ColumnInferenceResult InferColumns(MLContext context, string path,
34
67
35
68
var purposeInferenceResult = PurposeInference . InferPurposes ( context , dataView , label ) ;
36
69
37
- ( TextLoader . Column , ColumnPurpose Purpose ) [ ] inferredColumns = null ;
70
+ // start building result objects
71
+ IEnumerable < TextLoader . Column > columnResults = null ;
72
+ IEnumerable < ( string , ColumnPurpose ) > purposeResults = null ;
73
+
38
74
// infer column grouping and generate column names
39
75
if ( groupColumns )
40
76
{
41
77
var groupingResult = ColumnGroupingInference . InferGroupingAndNames ( context , hasHeader ,
42
78
typeInference . Columns , purposeInferenceResult ) ;
43
79
44
- // build result objects & return
45
- inferredColumns = groupingResult . Select ( c => ( c . GenerateTextLoaderColumn ( ) , c . Purpose ) ) . ToArray ( ) ;
80
+ columnResults = groupingResult . Select ( c => c . GenerateTextLoaderColumn ( ) ) ;
81
+ purposeResults = groupingResult . Select ( c => ( c . SuggestedName , c . Purpose ) ) ;
46
82
}
47
83
else
48
84
{
49
- inferredColumns = new ( TextLoader . Column , ColumnPurpose Purpose ) [ loaderColumns . Length ] ;
50
- for ( int i = 0 ; i < loaderColumns . Length ; i ++ )
51
- {
52
- inferredColumns [ i ] = ( loaderColumns [ i ] , purposeInferenceResult [ i ] . Purpose ) ;
53
- }
85
+ columnResults = loaderColumns ;
86
+ purposeResults = purposeInferenceResult . Select ( p => ( dataView . Schema [ p . ColumnIndex ] . Name , p . Purpose ) ) ;
54
87
}
55
- return new ColumnInferenceResult ( inferredColumns , splitInference . AllowQuote , splitInference . AllowSparse , new char [ ] { splitInference . Separator . Value } , hasHeader , trimWhitespace ) ;
88
+
89
+ return ( new TextLoader . Arguments ( )
90
+ {
91
+ Column = columnResults . ToArray ( ) ,
92
+ AllowQuoting = splitInference . AllowQuote ,
93
+ AllowSparse = splitInference . AllowSparse ,
94
+ Separators = new char [ ] { splitInference . Separator . Value } ,
95
+ HasHeader = hasHeader ,
96
+ TrimWhitespace = trimWhitespace
97
+ } , purposeResults ) ;
56
98
}
57
99
58
100
private static TextFileContents . ColumnSplitResult InferSplit ( TextFileSample sample , char ? separatorChar , bool ? allowQuotedStrings , bool ? supportSparse )
0 commit comments