Skip to content

Commit d08e9cf

Browse files
daholsteDmitry-A
authored andcommitted
fix multiclass with nonstandard label (dotnet#207)
1 parent 5db4b73 commit d08e9cf

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,8 @@ private static IEnumerable<SuggestedTransform> CalculateTransforms(
221221
// this is a work-around for ML.NET bug tracked by https://github.com/dotnet/machinelearning/issues/1969
222222
if (task == TaskKind.MulticlassClassification)
223223
{
224-
var transform = ValueToKeyMappingExtension.CreateSuggestedTransform(context, DefaultColumnNames.Label, DefaultColumnNames.Label);
224+
var labelColumn = columns.First(c => c.Item3 == ColumnPurpose.Label).Item1;
225+
var transform = ValueToKeyMappingExtension.CreateSuggestedTransform(context, labelColumn, labelColumn);
225226
transforms.Add(transform);
226227
}
227228
return transforms;

src/Microsoft.ML.Auto/TrainerExtensions/MultiTrainerExtensions.cs

+7-7
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
2424
ColumnInformation columnInfo)
2525
{
2626
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
27-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
27+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
2828
}
2929

3030
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
@@ -46,7 +46,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
4646
ColumnInformation columnInfo)
4747
{
4848
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
49-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
49+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
5050
}
5151

5252
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
@@ -89,7 +89,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
8989
ColumnInformation columnInfo)
9090
{
9191
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
92-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
92+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
9393
}
9494

9595
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
@@ -132,7 +132,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
132132
ColumnInformation columnInfo)
133133
{
134134
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
135-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
135+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
136136
}
137137

138138
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
@@ -154,7 +154,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
154154
ColumnInformation columnInfo)
155155
{
156156
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
157-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
157+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
158158
}
159159

160160
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
@@ -176,7 +176,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
176176
ColumnInformation columnInfo)
177177
{
178178
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
179-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
179+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
180180
}
181181

182182
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)
@@ -198,7 +198,7 @@ public ITrainerEstimator CreateInstance(MLContext mlContext, IEnumerable<Sweepab
198198
ColumnInformation columnInfo)
199199
{
200200
var binaryTrainer = _binaryLearnerCatalogItem.CreateInstance(mlContext, sweepParams, columnInfo) as ITrainerEstimatorProducingFloat;
201-
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer);
201+
return mlContext.MulticlassClassification.Trainers.OneVersusAll(binaryTrainer, labelColumn: columnInfo.LabelColumn);
202202
}
203203

204204
public PipelineNode CreatePipelineNode(IEnumerable<SweepableParam> sweepParams, ColumnInformation columnInfo)

0 commit comments

Comments
 (0)