Skip to content

Commit 41f56e9

Browse files
committed
Fix build and polish code
1 parent aea460e commit 41f56e9

File tree

1 file changed

+27
-21
lines changed

1 file changed

+27
-21
lines changed

src/Microsoft.ML.Data/Dirty/ChooseColumnsByIndexTransform.cs

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ private sealed class Bindings
3838

3939
// Input schema of this transform. It's useful when determining column dependencies and other
4040
// relations between input and output schemas.
41-
private readonly Schema _input;
41+
private readonly Schema _sourceSchema;
4242

4343
// This transform's output schema.
4444
internal Schema OutputSchema { get; }
@@ -48,32 +48,20 @@ internal Bindings(Arguments args, Schema sourceSchema)
4848
Contracts.AssertValue(args);
4949
Contracts.AssertValue(sourceSchema);
5050

51-
_input = sourceSchema;
51+
_sourceSchema = sourceSchema;
5252

5353
if (args.Drop)
5454
// Drop columns indexed by args.Index
55-
_sources = Enumerable.Range(0, _input.ColumnCount).Except(args.Index).ToArray();
55+
_sources = Enumerable.Range(0, _sourceSchema.ColumnCount).Except(args.Index).ToArray();
5656
else
5757
// Keep columns indexed by args.Index
5858
_sources = args.Index.ToArray();
5959

6060
// Make sure the output of this transform is meaningful.
6161
Contracts.Check(_sources.Length > 0, "Choose columns by index has no output column.");
6262

63-
var schemaBuilder = new SchemaBuilder();
64-
for (int i = 0; i < _sources.Length; ++i)
65-
{
66-
var selectedIndex = _sources[i];
67-
68-
// The dropped/kept columns are determined by user-specified arguments, so we throw if a bad configuration is provided.
69-
string fmt = string.Format("Column index {0} invalid for input with {1} columns", selectedIndex, _input.ColumnCount);
70-
Contracts.Check(selectedIndex > sourceSchema.ColumnCount, fmt);
71-
72-
// Put the selected column into output schema.
73-
var selectedColumn = sourceSchema[selectedIndex];
74-
schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata);
75-
}
76-
OutputSchema = schemaBuilder.GetSchema();
63+
// All necessary fields in this class are set, so we can compute output schema now.
64+
OutputSchema = ComputeOutputSchema();
7765
}
7866

7967
internal Bindings(ModelLoadContext ctx, Schema sourceSchema)
@@ -85,14 +73,32 @@ internal Bindings(ModelLoadContext ctx, Schema sourceSchema)
8573
// int[]: selected source column indices
8674
_sources = ctx.Reader.ReadIntArray();
8775

88-
_input = sourceSchema;
76+
_sourceSchema = sourceSchema;
77+
OutputSchema = ComputeOutputSchema();
78+
}
79+
80+
/// <summary>
81+
/// After _input and _sources are set, pick up selected columns from _input to create output schema.
82+
/// Note that _sources tells us what columns in _input are selected.
83+
/// </summary>
84+
private Schema ComputeOutputSchema()
85+
{
8986
var schemaBuilder = new SchemaBuilder();
9087
for (int i = 0; i < _sources.Length; ++i)
9188
{
92-
var selectedColumn = sourceSchema[_sources[i]];
89+
// selectedIndex is an column index of input schema. Note that the input column indexed by _sources[i] in _input is sent
90+
// to the i-th column in the output schema.
91+
var selectedIndex = _sources[i];
92+
93+
// The dropped/kept columns are determined by user-specified arguments, so we throw if a bad configuration is provided.
94+
string fmt = string.Format("Column index {0} invalid for input with {1} columns", selectedIndex, _sourceSchema.ColumnCount);
95+
Contracts.Check(selectedIndex < _sourceSchema.ColumnCount, fmt);
96+
97+
// Copy the selected column into output schema.
98+
var selectedColumn = _sourceSchema[selectedIndex];
9399
schemaBuilder.AddColumn(selectedColumn.Name, selectedColumn.Type, selectedColumn.Metadata);
94100
}
95-
OutputSchema = schemaBuilder.GetSchema();
101+
return schemaBuilder.GetSchema();
96102
}
97103

98104
internal void Save(ModelSaveContext ctx)
@@ -112,7 +118,7 @@ internal bool[] GetActive(Func<int, bool> predicate)
112118
internal Func<int, bool> GetDependencies(Func<int, bool> predicate)
113119
{
114120
Contracts.AssertValue(predicate);
115-
var active = new bool[_input.ColumnCount];
121+
var active = new bool[_sourceSchema.ColumnCount];
116122
for (int i = 0; i < _sources.Length; i++)
117123
{
118124
if (predicate(i))

0 commit comments

Comments
 (0)