Skip to content

Commit f25e4f9

Browse files
authored
Remove duplicate value-to-key mapping transform for multiclass string labels (dotnet#283)
1 parent 3a4595d commit f25e4f9

File tree

5 files changed

+32
-38
lines changed

5 files changed

+32
-38
lines changed

src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs

+1-18
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
3535
{
3636
var availableTrainers = RecipeInference.AllowedTrainers(context, task,
3737
ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist);
38-
var transforms = CalculateTransforms(context, columns, task);
39-
//var transforms = TransformInferenceApi.InferTransforms(context, columns, task);
38+
var transforms = TransformInferenceApi.InferTransforms(context, task, columns);
4039

4140
// if we haven't run all pipelines once
4241
if (history.Count() < availableTrainers.Count())
@@ -213,21 +212,5 @@ private static bool SampleHyperparameters(MLContext context, SuggestedTrainer tr
213212

214213
return true;
215214
}
216-
217-
private static IEnumerable<SuggestedTransform> CalculateTransforms(
218-
MLContext context,
219-
(string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns,
220-
TaskKind task)
221-
{
222-
var transforms = TransformInferenceApi.InferTransforms(context, columns).ToList();
223-
// this is a work-around for ML.NET bug tracked by https://github.com/dotnet/machinelearning/issues/1969
224-
if (task == TaskKind.MulticlassClassification)
225-
{
226-
var labelColumn = columns.First(c => c.Item3 == ColumnPurpose.Label).Item1;
227-
var transform = ValueToKeyMappingExtension.CreateSuggestedTransform(context, labelColumn, labelColumn);
228-
transforms.Add(transform);
229-
}
230-
return transforms;
231-
}
232215
}
233216
}

src/Microsoft.ML.Auto/TransformInference/TransformInference.cs

+19-14
Original file line number | Diff line number | Diff line change
@@ -117,12 +117,12 @@ public bool Equals(ColumnRoutingStructure obj)
117117

118118
internal interface ITransformInferenceExpert
119119
{
120-
IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns);
120+
IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task);
121121
}
122122

123123
public abstract class TransformInferenceExpertBase : ITransformInferenceExpert
124124
{
125-
public abstract IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns);
125+
public abstract IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task);
126126

127127
protected readonly MLContext Context;
128128

@@ -137,8 +137,8 @@ private static IEnumerable<ITransformInferenceExpert> GetExperts(MLContext conte
137137
// The experts work independently of each other, so the sequence is irrelevant
138138
// (it only determines the sequence of resulting transforms).
139139

140-
// For text labels, convert to categories.
141-
yield return new Experts.AutoLabel(context);
140+
// For multiclass tasks, convert label column to key
141+
yield return new Experts.Label(context);
142142

143143
// For boolean columns use convert transform
144144
yield return new Experts.Boolean(context);
@@ -155,21 +155,26 @@ private static IEnumerable<ITransformInferenceExpert> GetExperts(MLContext conte
155155

156156
internal static class Experts
157157
{
158-
internal sealed class AutoLabel : TransformInferenceExpertBase
158+
internal sealed class Label : TransformInferenceExpertBase
159159
{
160-
public AutoLabel(MLContext context) : base(context)
160+
public Label(MLContext context) : base(context)
161161
{
162162
}
163163

164-
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
164+
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
165165
{
166+
if (task != TaskKind.MulticlassClassification)
167+
{
168+
yield break;
169+
}
170+
166171
var lastLabelColId = Array.FindLastIndex(columns, x => x.Purpose == ColumnPurpose.Label);
167172
if (lastLabelColId < 0)
168173
yield break;
169174

170175
var col = columns[lastLabelColId];
171176

172-
if (col.Type.IsText())
177+
if (!col.Type.IsKey())
173178
{
174179
yield return ValueToKeyMappingExtension.CreateSuggestedTransform(Context, col.ColumnName, col.ColumnName);
175180
}
@@ -182,7 +187,7 @@ public Categorical(MLContext context) : base(context)
182187
{
183188
}
184189

185-
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
190+
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
186191
{
187192
bool foundCat = false;
188193
bool foundCatHash = false;
@@ -232,7 +237,7 @@ public Boolean(MLContext context) : base(context)
232237
{
233238
}
234239

235-
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
240+
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
236241
{
237242
var newColumns = new List<string>();
238243

@@ -260,7 +265,7 @@ public Text(MLContext context) : base(context)
260265
{
261266
}
262267

263-
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
268+
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
264269
{
265270
var featureCols = new List<string>();
266271

@@ -286,7 +291,7 @@ public NumericMissing(MLContext context) : base(context)
286291
{
287292
}
288293

289-
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns)
294+
public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] columns, TaskKind task)
290295
{
291296
var columnsWithMissing = new List<string>();
292297
foreach (var column in columns)
@@ -313,7 +318,7 @@ public override IEnumerable<SuggestedTransform> Apply(IntermediateColumn[] colum
313318
/// <summary>
314319
/// Automatically infer transforms for the data view
315320
/// </summary>
316-
public static SuggestedTransform[] InferTransforms(MLContext context, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
321+
public static SuggestedTransform[] InferTransforms(MLContext context, TaskKind task, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
317322
{
318323
var intermediateCols = columns.Where(c => c.Item3 != ColumnPurpose.Ignore)
319324
.Select(c => new IntermediateColumn(c.Item1, c.Item2, c.Item3, c.Item4))
@@ -322,7 +327,7 @@ public static SuggestedTransform[] InferTransforms(MLContext context, (string, D
322327
var suggestedTransforms = new List<SuggestedTransform>();
323328
foreach (var expert in GetExperts(context))
324329
{
325-
SuggestedTransform[] suggestions = expert.Apply(intermediateCols).ToArray();
330+
SuggestedTransform[] suggestions = expert.Apply(intermediateCols, task).ToArray();
326331
suggestedTransforms.AddRange(suggestions);
327332
}
328333

src/Microsoft.ML.Auto/TransformInference/TransformInferenceApi.cs

+2-2
Original file line number | Diff line number | Diff line change
@@ -9,9 +9,9 @@ namespace Microsoft.ML.Auto
99
{
1010
internal static class TransformInferenceApi
1111
{
12-
public static IEnumerable<SuggestedTransform> InferTransforms(MLContext context, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
12+
public static IEnumerable<SuggestedTransform> InferTransforms(MLContext context, TaskKind task, (string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns)
1313
{
14-
return TransformInference.InferTransforms(context, columns);
14+
return TransformInference.InferTransforms(context, task, columns);
1515
}
1616
}
1717
}

src/Microsoft.ML.Auto/Utils/MLNetUtils/ColumnTypeExtensions.cs

+5
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,11 @@ public static bool IsVector(this DataViewType columnType)
2929
return columnType is VectorType;
3030
}
3131

32+
public static bool IsKey(this DataViewType columnType)
33+
{
34+
return columnType is KeyType;
35+
}
36+
3237
public static bool IsKnownSizeVector(this DataViewType columnType)
3338
{
3439
var vector = columnType as VectorType;

src/Test/TransformInferenceTests.cs

+5-4
Original file line number | Diff line number | Diff line change
@@ -645,7 +645,7 @@ public void TransformInferenceCustomLabelCol()
645645
}
646646

647647
[TestMethod]
648-
public void TransformInferenceCustomTextLabelCol()
648+
public void TransformInferenceCustomTextLabelColMulticlass()
649649
{
650650
TransformInferenceTestCore(new (string, DataViewType, ColumnPurpose, ColumnDimensions)[]
651651
{
@@ -663,7 +663,7 @@ public void TransformInferenceCustomTextLabelCol()
663663
],
664664
""Properties"": {}
665665
}
666-
]");
666+
]", TaskKind.MulticlassClassification);
667667
}
668668

669669
[TestMethod]
@@ -727,9 +727,10 @@ public void TransformInferenceMissingNameCollision()
727727

728728
private static void TransformInferenceTestCore(
729729
(string name, DataViewType type, ColumnPurpose purpose, ColumnDimensions dimensions)[] columns,
730-
string expectedJson)
730+
string expectedJson,
731+
TaskKind task = TaskKind.BinaryClassification)
731732
{
732-
var transforms = TransformInferenceApi.InferTransforms(new MLContext(), columns);
733+
var transforms = TransformInferenceApi.InferTransforms(new MLContext(), task, columns);
733734
TestApplyTransformsToRealDataView(transforms, columns);
734735
var pipelineNodes = transforms.Select(t => t.PipelineNode);
735736
Util.AssertObjectMatchesJson(expectedJson, pipelineNodes);

0 commit comments

Comments
 (0)