Skip to content

Commit f7a2fbe

Browse files
committed
Remove ISchema in tree featurizer
1 parent 88b0ec0 commit f7a2fbe

File tree

1 file changed

+56
-125
lines changed

1 file changed

+56
-125
lines changed

src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs

Lines changed: 56 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -55,118 +55,18 @@ public sealed class Arguments : ScorerArgumentsBase
5555

5656
private sealed class BoundMapper : ISchemaBoundRowMapper
5757
{
58-
private sealed class SchemaImpl : ISchema
59-
{
60-
private readonly IExceptionContext _ectx;
61-
private readonly string[] _names;
62-
private readonly ColumnType[] _types;
63-
64-
private readonly TreeEnsembleFeaturizerBindableMapper _parent;
65-
66-
public int ColumnCount { get { return _types.Length; } }
67-
68-
public SchemaImpl(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper parent,
69-
ColumnType treeValueColType, ColumnType leafIdColType, ColumnType pathIdColType)
70-
{
71-
Contracts.CheckValueOrNull(ectx);
72-
_ectx = ectx;
73-
_ectx.AssertValue(parent);
74-
_ectx.AssertValue(treeValueColType);
75-
_ectx.AssertValue(leafIdColType);
76-
_ectx.AssertValue(pathIdColType);
77-
78-
_parent = parent;
79-
80-
_names = new string[3];
81-
_names[TreeIdx] = OutputColumnNames.Trees;
82-
_names[LeafIdx] = OutputColumnNames.Leaves;
83-
_names[PathIdx] = OutputColumnNames.Paths;
84-
85-
_types = new ColumnType[3];
86-
_types[TreeIdx] = treeValueColType;
87-
_types[LeafIdx] = leafIdColType;
88-
_types[PathIdx] = pathIdColType;
89-
}
90-
91-
public bool TryGetColumnIndex(string name, out int col)
92-
{
93-
col = -1;
94-
if (name == OutputColumnNames.Trees)
95-
col = TreeIdx;
96-
else if (name == OutputColumnNames.Leaves)
97-
col = LeafIdx;
98-
else if (name == OutputColumnNames.Paths)
99-
col = PathIdx;
100-
return col >= 0;
101-
}
102-
103-
public string GetColumnName(int col)
104-
{
105-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
106-
return _names[col];
107-
}
108-
109-
public ColumnType GetColumnType(int col)
110-
{
111-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
112-
return _types[col];
113-
}
114-
115-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
116-
{
117-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
118-
yield return
119-
MetadataUtils.GetNamesType(_types[col].VectorSize).GetPair(MetadataUtils.Kinds.SlotNames);
120-
if (col == PathIdx || col == LeafIdx)
121-
yield return BoolType.Instance.GetPair(MetadataUtils.Kinds.IsNormalized);
122-
}
123-
124-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
125-
{
126-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
127-
128-
if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized)
129-
return BoolType.Instance;
130-
if (kind == MetadataUtils.Kinds.SlotNames)
131-
return MetadataUtils.GetNamesType(_types[col].VectorSize);
132-
return null;
133-
}
134-
135-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
136-
{
137-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
138-
139-
if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized)
140-
MetadataUtils.Marshal<bool, TValue>(IsNormalized, col, ref value);
141-
else if (kind == MetadataUtils.Kinds.SlotNames)
142-
{
143-
switch (col)
144-
{
145-
case TreeIdx:
146-
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetTreeSlotNames, col, ref value);
147-
break;
148-
case LeafIdx:
149-
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetLeafSlotNames, col, ref value);
150-
break;
151-
default:
152-
Contracts.Assert(col == PathIdx);
153-
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetPathSlotNames, col, ref value);
154-
break;
155-
}
156-
}
157-
else
158-
throw _ectx.ExceptGetMetadata();
159-
}
160-
161-
private void IsNormalized(int iinfo, ref bool dst)
162-
{
163-
dst = true;
164-
}
165-
}
166-
167-
private const int TreeIdx = 0;
168-
private const int LeafIdx = 1;
169-
private const int PathIdx = 2;
58+
/// <summary>
59+
/// Column index of values predicted by all trees in an ensemble in <see cref="OutputSchema"/>.
60+
/// </summary>
61+
private const int TreeValuesColumnId = 0;
62+
/// <summary>
63+
/// Column index of leaf IDs containing the considered example in <see cref="OutputSchema"/>.
64+
/// </summary>
65+
private const int LeafIdsColumnId = 1;
66+
/// <summary>
67+
/// Column index of path IDs which specify the paths the considered example passing through per tree in <see cref="OutputSchema"/>.
68+
/// </summary>
69+
private const int PathIdsColumnId = 2;
17070

17171
private readonly TreeEnsembleFeaturizerBindableMapper _owner;
17272
private readonly IExceptionContext _ectx;
@@ -193,19 +93,50 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper
19393
InputRoleMappedSchema = schema;
19494

19595
// A vector containing the output of each tree on a given example.
196-
var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.TrainedEnsemble.NumTrees);
96+
var treeValueType = new VectorType(NumberType.Float, owner._ensemble.TrainedEnsemble.NumTrees);
19797
// An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example
19898
// ends up in all the trees in the ensemble.
199-
var leafIdType = new VectorType(NumberType.Float, _owner._totalLeafCount);
99+
var leafIdType = new VectorType(NumberType.Float, owner._totalLeafCount);
200100
// An indicator vector with length = the total number of nodes in the ensemble, indicating the nodes on
201101
// the paths of the example in all the trees in the ensemble.
202102
// The total number of nodes in a binary tree is equal to the number of internal nodes + the number of leaf nodes,
203103
// and it is also equal to the number of children of internal nodes (which is 2 * the number of internal nodes)
204104
// plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1,
205105
// which means that #internal = #leaf - 1.
206106
// Therefore, the number of internal nodes in the ensemble is #leaf - #trees.
207-
var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees);
208-
OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType));
107+
var pathIdType = new VectorType(NumberType.Float, owner._totalLeafCount - owner._ensemble.TrainedEnsemble.NumTrees);
108+
109+
// Start creating output schema with types derived above.
110+
var schemaBuilder = new SchemaBuilder();
111+
112+
// Metadata of tree values.
113+
var treeIdMetadataBuilder = new MetadataBuilder();
114+
ValueGetter<VBuffer<ReadOnlyMemory<char>>> treeIdMetadataGetter = (ref VBuffer<ReadOnlyMemory<char>> value) => owner.GetTreeSlotNames(ref value);
115+
treeIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treeValueType.VectorSize), treeIdMetadataGetter);
116+
117+
// Tree values must be the first output column.
118+
Contracts.Assert(TreeValuesColumnId == 0);
119+
schemaBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeIdMetadataBuilder.GetMetadata());
120+
121+
// Metadata of leaf IDs.
122+
var leafIdMetadataBuilder = new MetadataBuilder();
123+
ValueGetter<VBuffer<ReadOnlyMemory<char>>> leafIdMetadataGetter = (ref VBuffer<ReadOnlyMemory<char>> value) => owner.GetLeafSlotNames(ref value);
124+
treeIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leafIdType.VectorSize), treeIdMetadataGetter);
125+
126+
// leaf IDs must be the second output column.
127+
Contracts.Assert(LeafIdsColumnId == 1);
128+
schemaBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafIdMetadataBuilder.GetMetadata());
129+
130+
// Metadata of path IDs.
131+
var pathIdMetadataBuilder = new MetadataBuilder();
132+
ValueGetter<VBuffer<ReadOnlyMemory<char>>> pathIdMetadataGetter = (ref VBuffer<ReadOnlyMemory<char>> value) => owner.GetPathSlotNames(ref value);
133+
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathIdType.VectorSize), pathIdMetadataGetter);
134+
135+
// Path IDs must be the third output column.
136+
Contracts.Assert(PathIdsColumnId == 2);
137+
schemaBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathIdMetadataBuilder.GetMetadata());
138+
139+
OutputSchema = schemaBuilder.GetSchema();
209140
}
210141

211142
public Row GetRow(Row input, Func<int, bool> predicate)
@@ -222,9 +153,9 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
222153

223154
var delegates = new Delegate[3];
224155

225-
var treeValueActive = predicate(TreeIdx);
226-
var leafIdActive = predicate(LeafIdx);
227-
var pathIdActive = predicate(PathIdx);
156+
var treeValueActive = predicate(TreeValuesColumnId);
157+
var leafIdActive = predicate(LeafIdsColumnId);
158+
var pathIdActive = predicate(PathIdsColumnId);
228159

229160
if (!treeValueActive && !leafIdActive && !pathIdActive)
230161
return delegates;
@@ -235,21 +166,21 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
235166
if (treeValueActive)
236167
{
237168
ValueGetter<VBuffer<float>> fn = state.GetTreeValues;
238-
delegates[TreeIdx] = fn;
169+
delegates[TreeValuesColumnId] = fn;
239170
}
240171

241172
// Get the leaf indicator getter.
242173
if (leafIdActive)
243174
{
244175
ValueGetter<VBuffer<float>> fn = state.GetLeafIds;
245-
delegates[LeafIdx] = fn;
176+
delegates[LeafIdsColumnId] = fn;
246177
}
247178

248179
// Get the path indicators getter.
249180
if (pathIdActive)
250181
{
251182
ValueGetter<VBuffer<float>> fn = state.GetPathIds;
252-
delegates[PathIdx] = fn;
183+
delegates[PathIdsColumnId] = fn;
253184
}
254185

255186
return delegates;
@@ -477,7 +408,7 @@ private static int CountLeaves(TreeEnsembleModelParameters ensemble)
477408
return totalLeafCount;
478409
}
479410

480-
private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
411+
private void GetTreeSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
481412
{
482413
var numTrees = _ensemble.TrainedEnsemble.NumTrees;
483414

@@ -488,7 +419,7 @@ private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
488419
dst = editor.Commit();
489420
}
490421

491-
private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
422+
private void GetLeafSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
492423
{
493424
var numTrees = _ensemble.TrainedEnsemble.NumTrees;
494425

@@ -505,7 +436,7 @@ private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
505436
dst = editor.Commit();
506437
}
507438

508-
private void GetPathSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
439+
private void GetPathSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
509440
{
510441
var numTrees = _ensemble.TrainedEnsemble.NumTrees;
511442

0 commit comments

Comments
 (0)