Skip to content

Commit c0af761

Browse files
authored
Remove ISchema in TreeEnsembleFeaturizer (dotnet#2132)
* Remove another ISchema * Add a test to tree featurization's output schema
1 parent 6b9f589 commit c0af761

File tree

6 files changed

+207
-127
lines changed

6 files changed

+207
-127
lines changed

src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,6 @@
1414

1515
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.FastTree" + InternalPublicKey.Value)]
1616
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
17+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]
1718

1819
[assembly: WantsToBeBestFriends]

src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs

Lines changed: 59 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -55,118 +55,18 @@ public sealed class Arguments : ScorerArgumentsBase
5555

5656
private sealed class BoundMapper : ISchemaBoundRowMapper
5757
{
58-
private sealed class SchemaImpl : ISchema
59-
{
60-
private readonly IExceptionContext _ectx;
61-
private readonly string[] _names;
62-
private readonly ColumnType[] _types;
63-
64-
private readonly TreeEnsembleFeaturizerBindableMapper _parent;
65-
66-
public int ColumnCount { get { return _types.Length; } }
67-
68-
public SchemaImpl(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper parent,
69-
ColumnType treeValueColType, ColumnType leafIdColType, ColumnType pathIdColType)
70-
{
71-
Contracts.CheckValueOrNull(ectx);
72-
_ectx = ectx;
73-
_ectx.AssertValue(parent);
74-
_ectx.AssertValue(treeValueColType);
75-
_ectx.AssertValue(leafIdColType);
76-
_ectx.AssertValue(pathIdColType);
77-
78-
_parent = parent;
79-
80-
_names = new string[3];
81-
_names[TreeIdx] = OutputColumnNames.Trees;
82-
_names[LeafIdx] = OutputColumnNames.Leaves;
83-
_names[PathIdx] = OutputColumnNames.Paths;
84-
85-
_types = new ColumnType[3];
86-
_types[TreeIdx] = treeValueColType;
87-
_types[LeafIdx] = leafIdColType;
88-
_types[PathIdx] = pathIdColType;
89-
}
90-
91-
public bool TryGetColumnIndex(string name, out int col)
92-
{
93-
col = -1;
94-
if (name == OutputColumnNames.Trees)
95-
col = TreeIdx;
96-
else if (name == OutputColumnNames.Leaves)
97-
col = LeafIdx;
98-
else if (name == OutputColumnNames.Paths)
99-
col = PathIdx;
100-
return col >= 0;
101-
}
102-
103-
public string GetColumnName(int col)
104-
{
105-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
106-
return _names[col];
107-
}
108-
109-
public ColumnType GetColumnType(int col)
110-
{
111-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
112-
return _types[col];
113-
}
114-
115-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
116-
{
117-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
118-
yield return
119-
MetadataUtils.GetNamesType(_types[col].VectorSize).GetPair(MetadataUtils.Kinds.SlotNames);
120-
if (col == PathIdx || col == LeafIdx)
121-
yield return BoolType.Instance.GetPair(MetadataUtils.Kinds.IsNormalized);
122-
}
123-
124-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
125-
{
126-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
127-
128-
if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized)
129-
return BoolType.Instance;
130-
if (kind == MetadataUtils.Kinds.SlotNames)
131-
return MetadataUtils.GetNamesType(_types[col].VectorSize);
132-
return null;
133-
}
134-
135-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
136-
{
137-
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
138-
139-
if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized)
140-
MetadataUtils.Marshal<bool, TValue>(IsNormalized, col, ref value);
141-
else if (kind == MetadataUtils.Kinds.SlotNames)
142-
{
143-
switch (col)
144-
{
145-
case TreeIdx:
146-
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetTreeSlotNames, col, ref value);
147-
break;
148-
case LeafIdx:
149-
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetLeafSlotNames, col, ref value);
150-
break;
151-
default:
152-
Contracts.Assert(col == PathIdx);
153-
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetPathSlotNames, col, ref value);
154-
break;
155-
}
156-
}
157-
else
158-
throw _ectx.ExceptGetMetadata();
159-
}
160-
161-
private void IsNormalized(int iinfo, ref bool dst)
162-
{
163-
dst = true;
164-
}
165-
}
166-
167-
private const int TreeIdx = 0;
168-
private const int LeafIdx = 1;
169-
private const int PathIdx = 2;
58+
/// <summary>
59+
/// Column index of values predicted by all trees in an ensemble in <see cref="OutputSchema"/>.
60+
/// </summary>
61+
private const int TreeValuesColumnId = 0;
62+
/// <summary>
63+
/// Column index of leaf IDs containing the considered example in <see cref="OutputSchema"/>.
64+
/// </summary>
65+
private const int LeafIdsColumnId = 1;
66+
/// <summary>
67+
/// Column index of path IDs which specify the paths the considered example passes through per tree in <see cref="OutputSchema"/>.
68+
/// </summary>
69+
private const int PathIdsColumnId = 2;
17070

17171
private readonly TreeEnsembleFeaturizerBindableMapper _owner;
17272
private readonly IExceptionContext _ectx;
@@ -193,19 +93,53 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper
19393
InputRoleMappedSchema = schema;
19494

19595
// A vector containing the output of each tree on a given example.
196-
var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.TrainedEnsemble.NumTrees);
96+
var treeValueType = new VectorType(NumberType.Float, owner._ensemble.TrainedEnsemble.NumTrees);
19797
// An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example
19898
// ends up in all the trees in the ensemble.
199-
var leafIdType = new VectorType(NumberType.Float, _owner._totalLeafCount);
99+
var leafIdType = new VectorType(NumberType.Float, owner._totalLeafCount);
200100
// An indicator vector with length = the total number of nodes in the ensemble, indicating the nodes on
201101
// the paths of the example in all the trees in the ensemble.
202102
// The total number of nodes in a binary tree is equal to the number of internal nodes + the number of leaf nodes,
203103
// and it is also equal to the number of children of internal nodes (which is 2 * the number of internal nodes)
204104
// plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1,
205105
// which means that #internal = #leaf - 1.
206106
// Therefore, the number of internal nodes in the ensemble is #leaf - #trees.
207-
var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees);
208-
OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType));
107+
var pathIdType = new VectorType(NumberType.Float, owner._totalLeafCount - owner._ensemble.TrainedEnsemble.NumTrees);
108+
109+
// Start creating output schema with types derived above.
110+
var schemaBuilder = new SchemaBuilder();
111+
112+
// Metadata of tree values.
113+
var treeIdMetadataBuilder = new MetadataBuilder();
114+
treeIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treeValueType.VectorSize),
115+
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetTreeSlotNames);
116+
// Add the column of trees' output values
117+
schemaBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeIdMetadataBuilder.GetMetadata());
118+
119+
// Metadata of leaf IDs.
120+
var leafIdMetadataBuilder = new MetadataBuilder();
121+
leafIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leafIdType.VectorSize),
122+
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetLeafSlotNames);
123+
leafIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);
124+
// Add the column of leaves' IDs where the input example reaches.
125+
schemaBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafIdMetadataBuilder.GetMetadata());
126+
127+
// Metadata of path IDs.
128+
var pathIdMetadataBuilder = new MetadataBuilder();
129+
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathIdType.VectorSize),
130+
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetPathSlotNames);
131+
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);
132+
// Add the column of encoded paths which the input example passes.
133+
schemaBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathIdMetadataBuilder.GetMetadata());
134+
135+
OutputSchema = schemaBuilder.GetSchema();
136+
137+
// Tree values must be the first output column.
138+
Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
139+
// Leaf IDs must be the second output column.
140+
Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
141+
// Path IDs must be the third output column.
142+
Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);
209143
}
210144

211145
public Row GetRow(Row input, Func<int, bool> predicate)
@@ -222,9 +156,9 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
222156

223157
var delegates = new Delegate[3];
224158

225-
var treeValueActive = predicate(TreeIdx);
226-
var leafIdActive = predicate(LeafIdx);
227-
var pathIdActive = predicate(PathIdx);
159+
var treeValueActive = predicate(TreeValuesColumnId);
160+
var leafIdActive = predicate(LeafIdsColumnId);
161+
var pathIdActive = predicate(PathIdsColumnId);
228162

229163
if (!treeValueActive && !leafIdActive && !pathIdActive)
230164
return delegates;
@@ -235,21 +169,21 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
235169
if (treeValueActive)
236170
{
237171
ValueGetter<VBuffer<float>> fn = state.GetTreeValues;
238-
delegates[TreeIdx] = fn;
172+
delegates[TreeValuesColumnId] = fn;
239173
}
240174

241175
// Get the leaf indicator getter.
242176
if (leafIdActive)
243177
{
244178
ValueGetter<VBuffer<float>> fn = state.GetLeafIds;
245-
delegates[LeafIdx] = fn;
179+
delegates[LeafIdsColumnId] = fn;
246180
}
247181

248182
// Get the path indicators getter.
249183
if (pathIdActive)
250184
{
251185
ValueGetter<VBuffer<float>> fn = state.GetPathIds;
252-
delegates[PathIdx] = fn;
186+
delegates[PathIdsColumnId] = fn;
253187
}
254188

255189
return delegates;
@@ -477,7 +411,7 @@ private static int CountLeaves(TreeEnsembleModelParameters ensemble)
477411
return totalLeafCount;
478412
}
479413

480-
private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
414+
private void GetTreeSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
481415
{
482416
var numTrees = _ensemble.TrainedEnsemble.NumTrees;
483417

@@ -488,7 +422,7 @@ private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
488422
dst = editor.Commit();
489423
}
490424

491-
private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
425+
private void GetLeafSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
492426
{
493427
var numTrees = _ensemble.TrainedEnsemble.NumTrees;
494428

@@ -505,7 +439,7 @@ private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
505439
dst = editor.Commit();
506440
}
507441

508-
private void GetPathSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
442+
private void GetPathSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
509443
{
510444
var numTrees = _ensemble.TrainedEnsemble.NumTrees;
511445

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,39 @@ public static IEnumerable<SampleVectorOfNumbersData> GetVectorOfNumbersData()
238238
return data;
239239
}
240240

241+
private const int _simpleBinaryClassSampleFeatureLength = 10;
242+
243+
public class BinaryLabelFloatFeatureVectorSample
244+
{
245+
public bool Label;
246+
247+
[VectorType(_simpleBinaryClassSampleFeatureLength)]
248+
public float[] Features;
249+
}
250+
251+
public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount)
252+
{
253+
var rnd = new Random(0);
254+
var data = new List<BinaryLabelFloatFeatureVectorSample>();
255+
for (int i = 0; i < exampleCount; ++i)
256+
{
257+
// Initialize an example with a random label and an empty feature vector.
258+
var sample = new BinaryLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2 == 0, Features = new float[_simpleBinaryClassSampleFeatureLength] };
259+
// Fill feature vector according to the assigned label.
260+
for (int j = 0; j < 10; ++j)
261+
{
262+
var value = (float)rnd.NextDouble();
263+
// Positive class gets larger feature value.
264+
if (sample.Label)
265+
value += 0.2f;
266+
sample.Features[j] = value;
267+
}
268+
269+
data.Add(sample);
270+
}
271+
return data;
272+
}
273+
241274
/// <summary>
242275
/// feature vector's length in <see cref="MulticlassClassificationExample"/>.
243276
/// </summary>

test/Microsoft.ML.TestFramework/Properties/AssemblyInfo.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44
using System.Runtime.CompilerServices;
5+
using Microsoft.ML;
56

6-
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
7-
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
7+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
8+
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests" + PublicKey.TestValue)]

test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
<ProjectReference Include="..\..\src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj" />
1313
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
1414
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
15+
<ProjectReference Include="..\..\src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj" />
1516
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
1617
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
1718
<ProjectReference Include="..\..\src\Microsoft.ML.OnnxTransform\Microsoft.ML.OnnxTransform.csproj" />

0 commit comments

Comments
 (0)