Skip to content

Commit c2b5e76

Browse files
authored
Replaces ChooseColumnsTransform and DropColumnsTransform with SelectColumnsTransform (#1371)
* Removes ChooseColumnsTransform and DropColumnsTransform classes replacing them with SelectColumnsTransform. These changes include: * Updates to SelectColumnsTransform to respect ordering when keeping columns. For example, if the input is ABC and CB is selected, the output will be CB. * Updates to code that used Choose or Drop columns, replacing with SelectColumns. * Updates to baseline output for tests to pass * Re-enabled the SavePipeline tests This fixes #1342 These changes are also related to #754
1 parent 1bcb79d commit c2b5e76

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+657
-1612
lines changed

src/Microsoft.ML.Data/Commands/SaveDataCommand.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,10 @@ private void RunCore(IChannel ch)
130130

131131
if (!string.IsNullOrWhiteSpace(Args.Columns))
132132
{
133-
var args = new ChooseColumnsTransform.Arguments();
134-
args.Column = Args.Columns
135-
.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).Select(s => new ChooseColumnsTransform.Column() { Name = s }).ToArray();
136-
if (Utils.Size(args.Column) > 0)
137-
data = new ChooseColumnsTransform(Host, args, data);
133+
var keepColumns = Args.Columns
134+
.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries).ToArray();
135+
if (Utils.Size(keepColumns) > 0)
136+
data = SelectColumnsTransform.CreateKeep(Host, data, keepColumns);
138137
}
139138

140139
IDataSaver saver;

src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I
101101
}
102102

103103
var copyColumn = new CopyColumnsTransform(env, copyCols.ToArray()).Transform(input.Data);
104-
var dropColumn = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = copyCols.Select(c => c.Source).ToArray() }, copyColumn);
104+
var dropColumn = SelectColumnsTransform.CreateDrop(env, copyColumn, copyCols.Select(c => c.Source).ToArray());
105105
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn };
106106
}
107107
}

src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs

+29-46
Original file line numberDiff line numberDiff line change
@@ -703,59 +703,42 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
703703
}
704704
}
705705

706-
var args = new ChooseColumnsTransform.Arguments();
707-
var cols = new List<ChooseColumnsTransform.Column>()
708-
{
709-
new ChooseColumnsTransform.Column()
710-
{
711-
Name = string.Format(FoldDrAtKFormat, _k),
712-
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtK
713-
},
714-
new ChooseColumnsTransform.Column()
715-
{
716-
Name = string.Format(FoldDrAtPFormat, _p),
717-
Source = AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr
718-
},
719-
new ChooseColumnsTransform.Column()
720-
{
721-
Name = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies),
722-
Source=AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos
723-
},
724-
new ChooseColumnsTransform.Column()
725-
{
726-
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK
727-
},
728-
new ChooseColumnsTransform.Column()
729-
{
730-
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP
731-
},
732-
new ChooseColumnsTransform.Column()
733-
{
734-
Name=AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
735-
},
736-
new ChooseColumnsTransform.Column()
737-
{
738-
Name = BinaryClassifierEvaluator.Auc
739-
}
740-
};
706+
var kFormatName = string.Format(FoldDrAtKFormat, _k);
707+
var pFormatName = string.Format(FoldDrAtPFormat, _p);
708+
var numAnomName = string.Format(FoldDrAtNumAnomaliesFormat, numAnomalies);
709+
710+
(string Source, string Name)[] cols =
711+
{
712+
(AnomalyDetectionEvaluator.OverallMetrics.DrAtK, kFormatName),
713+
(AnomalyDetectionEvaluator.OverallMetrics.DrAtPFpr, pFormatName),
714+
(AnomalyDetectionEvaluator.OverallMetrics.DrAtNumPos, numAnomName)
715+
};
716+
717+
// List of columns to keep, note that the order specified determines the order of the output
718+
var colsToKeep = new List<string>();
719+
colsToKeep.Add(kFormatName);
720+
colsToKeep.Add(pFormatName);
721+
colsToKeep.Add(numAnomName);
722+
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK);
723+
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP);
724+
colsToKeep.Add(AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
725+
colsToKeep.Add(BinaryClassifierEvaluator.Auc);
726+
727+
overall = new CopyColumnsTransform(Host, cols).Transform(overall);
728+
IDataView fold = SelectColumnsTransform.CreateKeep(Host, overall, colsToKeep.ToArray());
741729

742-
args.Column = cols.ToArray();
743-
IDataView fold = new ChooseColumnsTransform(Host, args, overall);
744730
string weightedFold;
745731
ch.Info(MetricWriter.GetPerFoldResults(Host, fold, out weightedFold));
746732
}
747733

748734
protected override IDataView GetOverallResultsCore(IDataView overall)
749735
{
750-
var args = new DropColumnsTransform.Arguments();
751-
args.Column = new[]
752-
{
753-
AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
754-
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
755-
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
756-
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos
757-
};
758-
return new DropColumnsTransform(Host, args, overall);
736+
return SelectColumnsTransform.CreateDrop(Host,
737+
overall,
738+
AnomalyDetectionEvaluator.OverallMetrics.NumAnomalies,
739+
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtK,
740+
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtP,
741+
AnomalyDetectionEvaluator.OverallMetrics.ThreshAtNumPos);
759742
}
760743

761744
protected override IEnumerable<string> GetPerInstanceColumnsToSave(RoleMappedSchema schema)

src/Microsoft.ML.Data/Evaluators/BinaryClassifierEvaluator.cs

+23-35
Original file line numberDiff line numberDiff line change
@@ -1333,43 +1333,33 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
13331333
if (!metrics.TryGetValue(MetricKinds.ConfusionMatrix, out conf))
13341334
throw ch.Except("No overall metrics found");
13351335

1336-
var args = new ChooseColumnsTransform.Arguments();
1337-
var cols = new List<ChooseColumnsTransform.Column>()
1338-
{
1339-
new ChooseColumnsTransform.Column()
1340-
{
1341-
Name = FoldAccuracy,
1342-
Source = BinaryClassifierEvaluator.Accuracy
1343-
},
1344-
new ChooseColumnsTransform.Column()
1345-
{
1346-
Name = FoldLogLoss,
1347-
Source = BinaryClassifierEvaluator.LogLoss
1348-
},
1349-
new ChooseColumnsTransform.Column()
1350-
{
1351-
Name = BinaryClassifierEvaluator.Entropy
1352-
},
1353-
new ChooseColumnsTransform.Column()
1354-
{
1355-
Name = FoldLogLosRed,
1356-
Source = BinaryClassifierEvaluator.LogLossReduction
1357-
},
1358-
new ChooseColumnsTransform.Column()
1359-
{
1360-
Name = BinaryClassifierEvaluator.Auc
1361-
}
1362-
};
1336+
(string Source, string Name)[] cols =
1337+
{
1338+
(BinaryClassifierEvaluator.Accuracy, FoldAccuracy),
1339+
(BinaryClassifierEvaluator.LogLoss, FoldLogLoss),
1340+
(BinaryClassifierEvaluator.LogLossReduction, FoldLogLosRed)
1341+
};
1342+
1343+
var colsToKeep = new List<string>();
1344+
colsToKeep.Add(FoldAccuracy);
1345+
colsToKeep.Add(FoldLogLoss);
1346+
colsToKeep.Add(BinaryClassifierEvaluator.Entropy);
1347+
colsToKeep.Add(FoldLogLosRed);
1348+
colsToKeep.Add(BinaryClassifierEvaluator.Auc);
1349+
13631350
int index;
13641351
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.IsWeighted, out index))
1365-
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.IsWeighted });
1352+
colsToKeep.Add(MetricKinds.ColumnNames.IsWeighted);
13661353
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratCol, out index))
1367-
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratCol });
1354+
colsToKeep.Add(MetricKinds.ColumnNames.StratCol);
13681355
if (fold.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out index))
1369-
cols.Add(new ChooseColumnsTransform.Column() { Name = MetricKinds.ColumnNames.StratVal });
1356+
colsToKeep.Add(MetricKinds.ColumnNames.StratVal);
1357+
1358+
fold = new CopyColumnsTransform(Host, cols).Transform(fold);
1359+
1360+
// Select the columns that are specified in the Copy
1361+
fold = SelectColumnsTransform.CreateKeep(Host, fold, colsToKeep.ToArray());
13701362

1371-
args.Column = cols.ToArray();
1372-
fold = new ChooseColumnsTransform(Host, args, fold);
13731363
string weightedConf;
13741364
var unweightedConf = MetricWriter.GetConfusionTable(Host, conf, out weightedConf);
13751365
string weightedFold;
@@ -1386,9 +1376,7 @@ protected override void PrintFoldResultsCore(IChannel ch, Dictionary<string, IDa
13861376

13871377
protected override IDataView GetOverallResultsCore(IDataView overall)
13881378
{
1389-
var args = new DropColumnsTransform.Arguments();
1390-
args.Column = new[] { BinaryClassifierEvaluator.Entropy };
1391-
return new DropColumnsTransform(Host, args, overall);
1379+
return SelectColumnsTransform.CreateDrop(Host, overall, BinaryClassifierEvaluator.Entropy);
13921380
}
13931381

13941382
protected override void PrintAdditionalMetricsCore(IChannel ch, Dictionary<string, IDataView>[] metrics)

src/Microsoft.ML.Data/Evaluators/EvaluatorUtils.cs

+3-6
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ private static IDataView AppendPerInstanceDataViews(IHostEnvironment env, string
933933
variableSizeVectorColumnName, type);
934934

935935
// Drop the old column that does not have variable length.
936-
idv = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = new[] { variableSizeVectorColumnName } }, idv);
936+
idv = SelectColumnsTransform.CreateDrop(env, idv, variableSizeVectorColumnName);
937937
}
938938
return idv;
939939
};
@@ -1059,8 +1059,7 @@ internal static IDataView GetOverallMetricsData(IHostEnvironment env, IDataView
10591059
{
10601060
if (Utils.Size(nonAveragedCols) > 0)
10611061
{
1062-
var dropArgs = new DropColumnsTransform.Arguments() { Column = nonAveragedCols.ToArray() };
1063-
data = new DropColumnsTransform(env, dropArgs, data);
1062+
data = SelectColumnsTransform.CreateDrop(env, data, nonAveragedCols.ToArray());
10641063
}
10651064
idvList.Add(data);
10661065
}
@@ -1734,9 +1733,7 @@ public static IDataView GetNonStratifiedMetrics(IHostEnvironment env, IDataView
17341733
var found = data.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.StratVal, out stratVal);
17351734
env.Check(found, "If stratification column exist, data view must also contain a StratVal column");
17361735

1737-
var dropArgs = new DropColumnsTransform.Arguments();
1738-
dropArgs.Column = new[] { data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal) };
1739-
data = new DropColumnsTransform(env, dropArgs, data);
1736+
data = SelectColumnsTransform.CreateDrop(env, data, data.Schema.GetColumnName(stratCol), data.Schema.GetColumnName(stratVal));
17401737
return data;
17411738
}
17421739
}

src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs

+13-10
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,14 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
213213
var idv = perInst.Data;
214214

215215
// Make a list of column names that Maml outputs as part of the per-instance data view, and then wrap
216-
// the per-instance data computed by the evaluator in a ChooseColumnsTransform.
217-
var cols = new List<ChooseColumnsTransform.Column>();
216+
// the per-instance data computed by the evaluator in a SelectColumnsTransform.
217+
var cols = new List<(string Source, string Name)>();
218+
var colsToKeep = new List<string>();
218219

219220
// If perInst is the result of cross-validation and contains a fold Id column, include it.
220221
int foldCol;
221222
if (perInst.Schema.Schema.TryGetColumnIndex(MetricKinds.ColumnNames.FoldIndex, out foldCol))
222-
cols.Add(new ChooseColumnsTransform.Column() { Source = MetricKinds.ColumnNames.FoldIndex });
223+
colsToKeep.Add(MetricKinds.ColumnNames.FoldIndex);
223224

224225
// Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform.
225226
if (perInst.Schema.Name == null)
@@ -228,22 +229,24 @@ private IDataView WrapPerInstance(RoleMappedData perInst)
228229
args.Column = new[] { new GenerateNumberTransform.Column() { Name = "Instance" } };
229230
args.UseCounter = true;
230231
idv = new GenerateNumberTransform(Host, args, idv);
231-
cols.Add(new ChooseColumnsTransform.Column() { Name = "Instance" });
232+
colsToKeep.Add("Instance");
232233
}
233234
else
234-
cols.Add(new ChooseColumnsTransform.Column() { Source = perInst.Schema.Name.Name, Name = "Instance" });
235+
{
236+
cols.Add((perInst.Schema.Name.Name, "Instance"));
237+
colsToKeep.Add("Instance");
238+
}
235239

236240
// Maml outputs the weight column if it exists.
237241
if (perInst.Schema.Weight != null)
238-
cols.Add(new ChooseColumnsTransform.Column() { Name = perInst.Schema.Weight.Name });
242+
colsToKeep.Add(perInst.Schema.Weight.Name);
239243

240244
// Get the other columns from the evaluator.
241245
foreach (var col in GetPerInstanceColumnsToSave(perInst.Schema))
242-
cols.Add(new ChooseColumnsTransform.Column() { Name = col });
246+
colsToKeep.Add(col);
243247

244-
var chooseArgs = new ChooseColumnsTransform.Arguments();
245-
chooseArgs.Column = cols.ToArray();
246-
idv = new ChooseColumnsTransform(Host, chooseArgs, idv);
248+
idv = new CopyColumnsTransform(Host, cols.ToArray()).Transform(idv);
249+
idv = SelectColumnsTransform.CreateKeep(Host, idv, colsToKeep.ToArray());
247250
return GetPerInstanceMetricsCore(idv, perInst.Schema);
248251
}
249252

src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs

+2-10
Original file line numberDiff line numberDiff line change
@@ -1051,22 +1051,14 @@ protected override IDataView GetOverallResultsCore(IDataView overall)
10511051
private IDataView ChangeTopKAccColumnName(IDataView input)
10521052
{
10531053
input = new CopyColumnsTransform(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input);
1054-
var dropArgs = new DropColumnsTransform.Arguments
1055-
{
1056-
Column = new[] { MultiClassClassifierEvaluator.TopKAccuracy }
1057-
};
1058-
return new DropColumnsTransform(Host, dropArgs, input);
1054+
return SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.TopKAccuracy );
10591055
}
10601056

10611057
private IDataView DropPerClassColumn(IDataView input)
10621058
{
10631059
if (input.Schema.TryGetColumnIndex(MultiClassClassifierEvaluator.PerClassLogLoss, out int perClassCol))
10641060
{
1065-
var args = new DropColumnsTransform.Arguments
1066-
{
1067-
Column = new[] { MultiClassClassifierEvaluator.PerClassLogLoss }
1068-
};
1069-
input = new DropColumnsTransform(Host, args, input);
1061+
input = SelectColumnsTransform.CreateDrop(Host, input, MultiClassClassifierEvaluator.PerClassLogLoss);
10701062
}
10711063
return input;
10721064
}

0 commit comments

Comments
 (0)