@@ -38,7 +38,7 @@ private sealed class Bindings
38
38
39
39
// Input schema of this transform. It's useful when determining column dependencies and other
40
40
// relations between input and output schemas.
41
- private readonly Schema _input ;
41
+ private readonly Schema _sourceSchema ;
42
42
43
43
// This transform's output schema.
44
44
internal Schema OutputSchema { get ; }
@@ -48,32 +48,20 @@ internal Bindings(Arguments args, Schema sourceSchema)
48
48
Contracts . AssertValue ( args ) ;
49
49
Contracts . AssertValue ( sourceSchema ) ;
50
50
51
- _input = sourceSchema ;
51
+ _sourceSchema = sourceSchema ;
52
52
53
53
if ( args . Drop )
54
54
// Drop columns indexed by args.Index
55
- _sources = Enumerable . Range ( 0 , _input . ColumnCount ) . Except ( args . Index ) . ToArray ( ) ;
55
+ _sources = Enumerable . Range ( 0 , _sourceSchema . ColumnCount ) . Except ( args . Index ) . ToArray ( ) ;
56
56
else
57
57
// Keep columns indexed by args.Index
58
58
_sources = args . Index . ToArray ( ) ;
59
59
60
60
// Make sure the output of this transform is meaningful.
61
61
Contracts . Check ( _sources . Length > 0 , "Choose columns by index has no output column." ) ;
62
62
63
- var schemaBuilder = new SchemaBuilder ( ) ;
64
- for ( int i = 0 ; i < _sources . Length ; ++ i )
65
- {
66
- var selectedIndex = _sources [ i ] ;
67
-
68
- // The dropped/kept columns are determined by user-specified arguments, so we throw if a bad configuration is provided.
69
- string fmt = string . Format ( "Column index {0} invalid for input with {1} columns" , selectedIndex , _input . ColumnCount ) ;
70
- Contracts . Check ( selectedIndex > sourceSchema . ColumnCount , fmt ) ;
71
-
72
- // Put the selected column into output schema.
73
- var selectedColumn = sourceSchema [ selectedIndex ] ;
74
- schemaBuilder . AddColumn ( selectedColumn . Name , selectedColumn . Type , selectedColumn . Metadata ) ;
75
- }
76
- OutputSchema = schemaBuilder . GetSchema ( ) ;
63
+ // All necessary fields in this class are set, so we can compute output schema now.
64
+ OutputSchema = ComputeOutputSchema ( ) ;
77
65
}
78
66
79
67
internal Bindings ( ModelLoadContext ctx , Schema sourceSchema )
@@ -85,14 +73,32 @@ internal Bindings(ModelLoadContext ctx, Schema sourceSchema)
85
73
// int[]: selected source column indices
86
74
_sources = ctx . Reader . ReadIntArray ( ) ;
87
75
88
- _input = sourceSchema ;
76
+ _sourceSchema = sourceSchema ;
77
+ OutputSchema = ComputeOutputSchema ( ) ;
78
+ }
79
+
80
+ /// <summary>
81
+ /// After _input and _sources are set, pick up selected columns from _input to create output schema.
82
+ /// Note that _sources tells us what columns in _input are selected.
83
+ /// </summary>
84
+ private Schema ComputeOutputSchema ( )
85
+ {
89
86
var schemaBuilder = new SchemaBuilder ( ) ;
90
87
for ( int i = 0 ; i < _sources . Length ; ++ i )
91
88
{
92
- var selectedColumn = sourceSchema [ _sources [ i ] ] ;
89
+ // selectedIndex is an column index of input schema. Note that the input column indexed by _sources[i] in _input is sent
90
+ // to the i-th column in the output schema.
91
+ var selectedIndex = _sources [ i ] ;
92
+
93
+ // The dropped/kept columns are determined by user-specified arguments, so we throw if a bad configuration is provided.
94
+ string fmt = string . Format ( "Column index {0} invalid for input with {1} columns" , selectedIndex , _sourceSchema . ColumnCount ) ;
95
+ Contracts . Check ( selectedIndex < _sourceSchema . ColumnCount , fmt ) ;
96
+
97
+ // Copy the selected column into output schema.
98
+ var selectedColumn = _sourceSchema [ selectedIndex ] ;
93
99
schemaBuilder . AddColumn ( selectedColumn . Name , selectedColumn . Type , selectedColumn . Metadata ) ;
94
100
}
95
- OutputSchema = schemaBuilder . GetSchema ( ) ;
101
+ return schemaBuilder . GetSchema ( ) ;
96
102
}
97
103
98
104
internal void Save ( ModelSaveContext ctx )
@@ -112,7 +118,7 @@ internal bool[] GetActive(Func<int, bool> predicate)
112
118
internal Func < int , bool > GetDependencies ( Func < int , bool > predicate )
113
119
{
114
120
Contracts . AssertValue ( predicate ) ;
115
- var active = new bool [ _input . ColumnCount ] ;
121
+ var active = new bool [ _sourceSchema . ColumnCount ] ;
116
122
for ( int i = 0 ; i < _sources . Length ; i ++ )
117
123
{
118
124
if ( predicate ( i ) )
0 commit comments