@@ -117,12 +117,12 @@ public bool Equals(ColumnRoutingStructure obj)
117
117
118
118
internal interface ITransformInferenceExpert
119
119
{
120
- IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns ) ;
120
+ IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task ) ;
121
121
}
122
122
123
123
public abstract class TransformInferenceExpertBase : ITransformInferenceExpert
124
124
{
125
- public abstract IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns ) ;
125
+ public abstract IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task ) ;
126
126
127
127
protected readonly MLContext Context ;
128
128
@@ -137,8 +137,8 @@ private static IEnumerable<ITransformInferenceExpert> GetExperts(MLContext conte
137
137
// The expert work independently of each other, the sequence is irrelevant
138
138
// (it only determines the sequence of resulting transforms).
139
139
140
- // For text labels , convert to categories.
141
- yield return new Experts . AutoLabel ( context ) ;
140
+ // For multiclass tasks , convert label column to key
141
+ yield return new Experts . Label ( context ) ;
142
142
143
143
// For boolean columns use convert transform
144
144
yield return new Experts . Boolean ( context ) ;
@@ -155,21 +155,26 @@ private static IEnumerable<ITransformInferenceExpert> GetExperts(MLContext conte
155
155
156
156
internal static class Experts
157
157
{
158
- internal sealed class AutoLabel : TransformInferenceExpertBase
158
+ internal sealed class Label : TransformInferenceExpertBase
159
159
{
160
- public AutoLabel ( MLContext context ) : base ( context )
160
+ public Label ( MLContext context ) : base ( context )
161
161
{
162
162
}
163
163
164
- public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns )
164
+ public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task )
165
165
{
166
+ if ( task != TaskKind . MulticlassClassification )
167
+ {
168
+ yield break ;
169
+ }
170
+
166
171
var lastLabelColId = Array . FindLastIndex ( columns , x => x . Purpose == ColumnPurpose . Label ) ;
167
172
if ( lastLabelColId < 0 )
168
173
yield break ;
169
174
170
175
var col = columns [ lastLabelColId ] ;
171
176
172
- if ( col . Type . IsText ( ) )
177
+ if ( ! col . Type . IsKey ( ) )
173
178
{
174
179
yield return ValueToKeyMappingExtension . CreateSuggestedTransform ( Context , col . ColumnName , col . ColumnName ) ;
175
180
}
@@ -182,7 +187,7 @@ public Categorical(MLContext context) : base(context)
182
187
{
183
188
}
184
189
185
- public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns )
190
+ public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task )
186
191
{
187
192
bool foundCat = false ;
188
193
bool foundCatHash = false ;
@@ -232,7 +237,7 @@ public Boolean(MLContext context) : base(context)
232
237
{
233
238
}
234
239
235
- public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns )
240
+ public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task )
236
241
{
237
242
var newColumns = new List < string > ( ) ;
238
243
@@ -260,7 +265,7 @@ public Text(MLContext context) : base(context)
260
265
{
261
266
}
262
267
263
- public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns )
268
+ public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task )
264
269
{
265
270
var featureCols = new List < string > ( ) ;
266
271
@@ -286,7 +291,7 @@ public NumericMissing(MLContext context) : base(context)
286
291
{
287
292
}
288
293
289
- public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns )
294
+ public override IEnumerable < SuggestedTransform > Apply ( IntermediateColumn [ ] columns , TaskKind task )
290
295
{
291
296
var columnsWithMissing = new List < string > ( ) ;
292
297
foreach ( var column in columns )
@@ -313,7 +318,7 @@ public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] colum
313
318
/// <summary>
314
319
/// Automatically infer transforms for the data view
315
320
/// </summary>
316
- public static SuggestedTransform [ ] InferTransforms ( MLContext context , ( string , DataViewType , ColumnPurpose , ColumnDimensions ) [ ] columns )
321
+ public static SuggestedTransform [ ] InferTransforms ( MLContext context , TaskKind task , ( string , DataViewType , ColumnPurpose , ColumnDimensions ) [ ] columns )
317
322
{
318
323
var intermediateCols = columns . Where ( c => c . Item3 != ColumnPurpose . Ignore )
319
324
. Select ( c => new IntermediateColumn ( c . Item1 , c . Item2 , c . Item3 , c . Item4 ) )
@@ -322,7 +327,7 @@ public static SuggestedTransform[] InferTransforms(MLContext context, (string, D
322
327
var suggestedTransforms = new List < SuggestedTransform > ( ) ;
323
328
foreach ( var expert in GetExperts ( context ) )
324
329
{
325
- SuggestedTransform [ ] suggestions = expert . Apply ( intermediateCols ) . ToArray ( ) ;
330
+ SuggestedTransform [ ] suggestions = expert . Apply ( intermediateCols , task ) . ToArray ( ) ;
326
331
suggestedTransforms . AddRange ( suggestions ) ;
327
332
}
328
333
0 commit comments