Skip to content

Commit 4ba15b8

Browse files
authored
Implementing copy column estimator (#706)
1 parent fe71bb8 commit 4ba15b8

File tree

7 files changed

+537
-110
lines changed

7 files changed

+537
-110
lines changed

src/Microsoft.ML.Data/EntryPoints/SchemaManipulation.cs

+1 -2
Original file line number | Diff line number | Diff line change
@@ -48,8 +48,7 @@ public static CommonOutputs.TransformOutput CopyColumns(IHostEnvironment env, Co
4848
var host = env.Register("CopyColumns");
4949
host.CheckValue(input, nameof(input));
5050
EntryPointUtils.CheckInputArgs(host, input);
51-
52-
var xf = new CopyColumnsTransform(env, input, input.Data);
51+
var xf = CopyColumnsTransform.Create(env, input, input.Data);
5352
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, xf, input.Data), OutputData = xf };
5453
}
5554

src/Microsoft.ML.Data/EntryPoints/ScoreColumnSelector.cs

+4 -5
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,8 @@ public static CommonOutputs.TransformOutput SelectColumns(IHostEnvironment env,
2626
Contracts.CheckValue(env, nameof(env));
2727
env.CheckValue(input, nameof(input));
2828
EntryPointUtils.CheckInputArgs(env, input);
29-
int colMax;
3029
var view = input.Data;
31-
var maxScoreId = view.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId);
30+
var maxScoreId = view.Schema.GetMaxMetadataKind(out int colMax, MetadataUtils.Kinds.ScoreColumnSetId);
3231
List<int> indices = new List<int>();
3332
for (int i = 0; i < view.Schema.ColumnCount; i++)
3433
{
@@ -82,7 +81,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I
8281
// Rename all the score columns.
8382
int colMax;
8483
var maxScoreId = input.Data.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId);
85-
var copyCols = new List<CopyColumnsTransform.Column>();
84+
var copyCols = new List<(string Source, string Name)>();
8685
for (int i = 0; i < input.Data.Schema.ColumnCount; i++)
8786
{
8887
if (input.Data.Schema.IsHidden(i))
@@ -99,10 +98,10 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I
9998
}
10099
var source = input.Data.Schema.GetColumnName(i);
101100
var name = source + "." + positiveClass;
102-
copyCols.Add(new CopyColumnsTransform.Column() { Name = name, Source = source });
101+
copyCols.Add((source, name));
103102
}
104103

105-
var copyColumn = new CopyColumnsTransform(env, new CopyColumnsTransform.Arguments() { Column = copyCols.ToArray() }, input.Data);
104+
var copyColumn = new CopyColumnsTransform(env, copyCols.ToArray()).Transform(input.Data);
106105
var dropColumn = new DropColumnsTransform(env, new DropColumnsTransform.Arguments() { Column = copyCols.Select(c => c.Source).ToArray() }, copyColumn);
107106
return new CommonOutputs.TransformOutput { Model = new TransformModel(env, dropColumn, input.Data), OutputData = dropColumn };
108107
}

src/Microsoft.ML.Data/Evaluators/MulticlassClassifierEvaluator.cs

+1 -12
Original file line number | Diff line number | Diff line change
@@ -925,18 +925,7 @@ protected override IDataView GetOverallResultsCore(IDataView overall)
925925

926926
private IDataView ChangeTopKAccColumnName(IDataView input)
927927
{
928-
var cpyArgs = new CopyColumnsTransform.Arguments
929-
{
930-
Column = new[]
931-
{
932-
new CopyColumnsTransform.Column()
933-
{
934-
Name=string.Format(TopKAccuracyFormat, _outputTopKAcc),
935-
Source=MultiClassClassifierEvaluator.TopKAccuracy
936-
}
937-
}
938-
};
939-
input = new CopyColumnsTransform(Host, cpyArgs, input);
928+
input = new CopyColumnsTransform(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input);
940929
var dropArgs = new DropColumnsTransform.Arguments
941930
{
942931
Column = new[] { MultiClassClassifierEvaluator.TopKAccuracy }

0 commit comments

Comments
 (0)