@@ -55,118 +55,18 @@ public sealed class Arguments : ScorerArgumentsBase
55
55
56
56
private sealed class BoundMapper : ISchemaBoundRowMapper
57
57
{
58
- private sealed class SchemaImpl : ISchema
59
- {
60
- private readonly IExceptionContext _ectx ;
61
- private readonly string [ ] _names ;
62
- private readonly ColumnType [ ] _types ;
63
-
64
- private readonly TreeEnsembleFeaturizerBindableMapper _parent ;
65
-
66
- public int ColumnCount { get { return _types . Length ; } }
67
-
68
- public SchemaImpl ( IExceptionContext ectx , TreeEnsembleFeaturizerBindableMapper parent ,
69
- ColumnType treeValueColType , ColumnType leafIdColType , ColumnType pathIdColType )
70
- {
71
- Contracts . CheckValueOrNull ( ectx ) ;
72
- _ectx = ectx ;
73
- _ectx . AssertValue ( parent ) ;
74
- _ectx . AssertValue ( treeValueColType ) ;
75
- _ectx . AssertValue ( leafIdColType ) ;
76
- _ectx . AssertValue ( pathIdColType ) ;
77
-
78
- _parent = parent ;
79
-
80
- _names = new string [ 3 ] ;
81
- _names [ TreeIdx ] = OutputColumnNames . Trees ;
82
- _names [ LeafIdx ] = OutputColumnNames . Leaves ;
83
- _names [ PathIdx ] = OutputColumnNames . Paths ;
84
-
85
- _types = new ColumnType [ 3 ] ;
86
- _types [ TreeIdx ] = treeValueColType ;
87
- _types [ LeafIdx ] = leafIdColType ;
88
- _types [ PathIdx ] = pathIdColType ;
89
- }
90
-
91
- public bool TryGetColumnIndex ( string name , out int col )
92
- {
93
- col = - 1 ;
94
- if ( name == OutputColumnNames . Trees )
95
- col = TreeIdx ;
96
- else if ( name == OutputColumnNames . Leaves )
97
- col = LeafIdx ;
98
- else if ( name == OutputColumnNames . Paths )
99
- col = PathIdx ;
100
- return col >= 0 ;
101
- }
102
-
103
- public string GetColumnName ( int col )
104
- {
105
- _ectx . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
106
- return _names [ col ] ;
107
- }
108
-
109
- public ColumnType GetColumnType ( int col )
110
- {
111
- _ectx . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
112
- return _types [ col ] ;
113
- }
114
-
115
- public IEnumerable < KeyValuePair < string , ColumnType > > GetMetadataTypes ( int col )
116
- {
117
- _ectx . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
118
- yield return
119
- MetadataUtils . GetNamesType ( _types [ col ] . VectorSize ) . GetPair ( MetadataUtils . Kinds . SlotNames ) ;
120
- if ( col == PathIdx || col == LeafIdx )
121
- yield return BoolType . Instance . GetPair ( MetadataUtils . Kinds . IsNormalized ) ;
122
- }
123
-
124
- public ColumnType GetMetadataTypeOrNull ( string kind , int col )
125
- {
126
- _ectx . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
127
-
128
- if ( ( col == PathIdx || col == LeafIdx ) && kind == MetadataUtils . Kinds . IsNormalized )
129
- return BoolType . Instance ;
130
- if ( kind == MetadataUtils . Kinds . SlotNames )
131
- return MetadataUtils . GetNamesType ( _types [ col ] . VectorSize ) ;
132
- return null ;
133
- }
134
-
135
- public void GetMetadata < TValue > ( string kind , int col , ref TValue value )
136
- {
137
- _ectx . CheckParam ( 0 <= col && col < ColumnCount , nameof ( col ) ) ;
138
-
139
- if ( ( col == PathIdx || col == LeafIdx ) && kind == MetadataUtils . Kinds . IsNormalized )
140
- MetadataUtils . Marshal < bool , TValue > ( IsNormalized , col , ref value ) ;
141
- else if ( kind == MetadataUtils . Kinds . SlotNames )
142
- {
143
- switch ( col )
144
- {
145
- case TreeIdx :
146
- MetadataUtils . Marshal < VBuffer < ReadOnlyMemory < char > > , TValue > ( _parent . GetTreeSlotNames , col , ref value ) ;
147
- break ;
148
- case LeafIdx :
149
- MetadataUtils . Marshal < VBuffer < ReadOnlyMemory < char > > , TValue > ( _parent . GetLeafSlotNames , col , ref value ) ;
150
- break ;
151
- default :
152
- Contracts . Assert ( col == PathIdx ) ;
153
- MetadataUtils . Marshal < VBuffer < ReadOnlyMemory < char > > , TValue > ( _parent . GetPathSlotNames , col , ref value ) ;
154
- break ;
155
- }
156
- }
157
- else
158
- throw _ectx . ExceptGetMetadata ( ) ;
159
- }
160
-
161
- private void IsNormalized ( int iinfo , ref bool dst )
162
- {
163
- dst = true ;
164
- }
165
- }
166
-
167
- private const int TreeIdx = 0 ;
168
- private const int LeafIdx = 1 ;
169
- private const int PathIdx = 2 ;
58
+ /// <summary>
59
+ /// Column index of values predicted by all trees in an ensemble in <see cref="OutputSchema"/>.
60
+ /// </summary>
61
+ private const int TreeValuesColumnId = 0 ;
62
+ /// <summary>
63
+ /// Column index of leaf IDs containing the considered example in <see cref="OutputSchema"/>.
64
+ /// </summary>
65
+ private const int LeafIdsColumnId = 1 ;
66
+ /// <summary>
67
+ /// Column index of path IDs, which specify the paths the considered example passes through in each tree, in <see cref="OutputSchema"/>.
68
+ /// </summary>
69
+ private const int PathIdsColumnId = 2 ;
170
70
171
71
private readonly TreeEnsembleFeaturizerBindableMapper _owner ;
172
72
private readonly IExceptionContext _ectx ;
@@ -193,19 +93,53 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper
193
93
InputRoleMappedSchema = schema ;
194
94
195
95
// A vector containing the output of each tree on a given example.
196
- var treeValueType = new VectorType ( NumberType . Float , _owner . _ensemble . TrainedEnsemble . NumTrees ) ;
96
+ var treeValueType = new VectorType ( NumberType . Float , owner . _ensemble . TrainedEnsemble . NumTrees ) ;
197
97
// An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example
198
98
// ends up in all the trees in the ensemble.
199
- var leafIdType = new VectorType ( NumberType . Float , _owner . _totalLeafCount ) ;
99
+ var leafIdType = new VectorType ( NumberType . Float , owner . _totalLeafCount ) ;
200
100
// An indicator vector with length = the total number of nodes in the ensemble, indicating the nodes on
201
101
// the paths of the example in all the trees in the ensemble.
202
102
// The total number of nodes in a binary tree is equal to the number of internal nodes + the number of leaf nodes,
203
103
// and it is also equal to the number of children of internal nodes (which is 2 * the number of internal nodes)
204
104
// plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1,
205
105
// which means that #internal = #leaf - 1.
206
106
// Therefore, the number of internal nodes in the ensemble is #leaf - #trees.
207
- var pathIdType = new VectorType ( NumberType . Float , _owner . _totalLeafCount - _owner . _ensemble . TrainedEnsemble . NumTrees ) ;
208
- OutputSchema = Schema . Create ( new SchemaImpl ( ectx , owner , treeValueType , leafIdType , pathIdType ) ) ;
107
+ var pathIdType = new VectorType ( NumberType . Float , owner . _totalLeafCount - owner . _ensemble . TrainedEnsemble . NumTrees ) ;
108
+
109
+ // Start creating output schema with types derived above.
110
+ var schemaBuilder = new SchemaBuilder ( ) ;
111
+
112
+ // Metadata of tree values.
113
+ var treeIdMetadataBuilder = new MetadataBuilder ( ) ;
114
+ treeIdMetadataBuilder . Add ( MetadataUtils . Kinds . SlotNames , MetadataUtils . GetNamesType ( treeValueType . VectorSize ) ,
115
+ ( ValueGetter < VBuffer < ReadOnlyMemory < char > > > ) owner . GetTreeSlotNames ) ;
116
+ // Add the column of trees' output values
117
+ schemaBuilder . AddColumn ( OutputColumnNames . Trees , treeValueType , treeIdMetadataBuilder . GetMetadata ( ) ) ;
118
+
119
+ // Metadata of leaf IDs.
120
+ var leafIdMetadataBuilder = new MetadataBuilder ( ) ;
121
+ leafIdMetadataBuilder . Add ( MetadataUtils . Kinds . SlotNames , MetadataUtils . GetNamesType ( leafIdType . VectorSize ) ,
122
+ ( ValueGetter < VBuffer < ReadOnlyMemory < char > > > ) owner . GetLeafSlotNames ) ;
123
+ leafIdMetadataBuilder . Add ( MetadataUtils . Kinds . IsNormalized , BoolType . Instance , ( ref bool value ) => value = true ) ;
124
+ // Add the column of IDs of the leaves that the input example reaches.
125
+ schemaBuilder . AddColumn ( OutputColumnNames . Leaves , leafIdType , leafIdMetadataBuilder . GetMetadata ( ) ) ;
126
+
127
+ // Metadata of path IDs.
128
+ var pathIdMetadataBuilder = new MetadataBuilder ( ) ;
129
+ pathIdMetadataBuilder . Add ( MetadataUtils . Kinds . SlotNames , MetadataUtils . GetNamesType ( pathIdType . VectorSize ) ,
130
+ ( ValueGetter < VBuffer < ReadOnlyMemory < char > > > ) owner . GetPathSlotNames ) ;
131
+ pathIdMetadataBuilder . Add ( MetadataUtils . Kinds . IsNormalized , BoolType . Instance , ( ref bool value ) => value = true ) ;
132
+ // Add the column of encoded paths that the input example passes through.
133
+ schemaBuilder . AddColumn ( OutputColumnNames . Paths , pathIdType , pathIdMetadataBuilder . GetMetadata ( ) ) ;
134
+
135
+ OutputSchema = schemaBuilder . GetSchema ( ) ;
136
+
137
+ // Tree values must be the first output column.
138
+ Contracts . Assert ( OutputSchema [ OutputColumnNames . Trees ] . Index == TreeValuesColumnId ) ;
139
+ // Leaf IDs must be the second output column.
140
+ Contracts . Assert ( OutputSchema [ OutputColumnNames . Leaves ] . Index == LeafIdsColumnId ) ;
141
+ // Path IDs must be the third output column.
142
+ Contracts . Assert ( OutputSchema [ OutputColumnNames . Paths ] . Index == PathIdsColumnId ) ;
209
143
}
210
144
211
145
public Row GetRow ( Row input , Func < int , bool > predicate )
@@ -222,9 +156,9 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
222
156
223
157
var delegates = new Delegate [ 3 ] ;
224
158
225
- var treeValueActive = predicate ( TreeIdx ) ;
226
- var leafIdActive = predicate ( LeafIdx ) ;
227
- var pathIdActive = predicate ( PathIdx ) ;
159
+ var treeValueActive = predicate ( TreeValuesColumnId ) ;
160
+ var leafIdActive = predicate ( LeafIdsColumnId ) ;
161
+ var pathIdActive = predicate ( PathIdsColumnId ) ;
228
162
229
163
if ( ! treeValueActive && ! leafIdActive && ! pathIdActive )
230
164
return delegates ;
@@ -235,21 +169,21 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
235
169
if ( treeValueActive )
236
170
{
237
171
ValueGetter < VBuffer < float > > fn = state . GetTreeValues ;
238
- delegates [ TreeIdx ] = fn ;
172
+ delegates [ TreeValuesColumnId ] = fn ;
239
173
}
240
174
241
175
// Get the leaf indicator getter.
242
176
if ( leafIdActive )
243
177
{
244
178
ValueGetter < VBuffer < float > > fn = state . GetLeafIds ;
245
- delegates [ LeafIdx ] = fn ;
179
+ delegates [ LeafIdsColumnId ] = fn ;
246
180
}
247
181
248
182
// Get the path indicators getter.
249
183
if ( pathIdActive )
250
184
{
251
185
ValueGetter < VBuffer < float > > fn = state . GetPathIds ;
252
- delegates [ PathIdx ] = fn ;
186
+ delegates [ PathIdsColumnId ] = fn ;
253
187
}
254
188
255
189
return delegates ;
@@ -477,7 +411,7 @@ private static int CountLeaves(TreeEnsembleModelParameters ensemble)
477
411
return totalLeafCount ;
478
412
}
479
413
480
- private void GetTreeSlotNames ( int col , ref VBuffer < ReadOnlyMemory < char > > dst )
414
+ private void GetTreeSlotNames ( ref VBuffer < ReadOnlyMemory < char > > dst )
481
415
{
482
416
var numTrees = _ensemble . TrainedEnsemble . NumTrees ;
483
417
@@ -488,7 +422,7 @@ private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
488
422
dst = editor . Commit ( ) ;
489
423
}
490
424
491
- private void GetLeafSlotNames ( int col , ref VBuffer < ReadOnlyMemory < char > > dst )
425
+ private void GetLeafSlotNames ( ref VBuffer < ReadOnlyMemory < char > > dst )
492
426
{
493
427
var numTrees = _ensemble . TrainedEnsemble . NumTrees ;
494
428
@@ -505,7 +439,7 @@ private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
505
439
dst = editor . Commit ( ) ;
506
440
}
507
441
508
- private void GetPathSlotNames ( int col , ref VBuffer < ReadOnlyMemory < char > > dst )
442
+ private void GetPathSlotNames ( ref VBuffer < ReadOnlyMemory < char > > dst )
509
443
{
510
444
var numTrees = _ensemble . TrainedEnsemble . NumTrees ;
511
445
0 commit comments