Skip to content

Remove ISchema in TreeEnsembleFeaturizer #2132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.FastTree" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)]

[assembly: WantsToBeBestFriends]
184 changes: 59 additions & 125 deletions src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,118 +55,18 @@ public sealed class Arguments : ScorerArgumentsBase

private sealed class BoundMapper : ISchemaBoundRowMapper
{
private sealed class SchemaImpl : ISchema
{
private readonly IExceptionContext _ectx;
private readonly string[] _names;
private readonly ColumnType[] _types;

private readonly TreeEnsembleFeaturizerBindableMapper _parent;

public int ColumnCount { get { return _types.Length; } }

public SchemaImpl(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper parent,
ColumnType treeValueColType, ColumnType leafIdColType, ColumnType pathIdColType)
{
Contracts.CheckValueOrNull(ectx);
_ectx = ectx;
_ectx.AssertValue(parent);
_ectx.AssertValue(treeValueColType);
_ectx.AssertValue(leafIdColType);
_ectx.AssertValue(pathIdColType);

_parent = parent;

_names = new string[3];
_names[TreeIdx] = OutputColumnNames.Trees;
_names[LeafIdx] = OutputColumnNames.Leaves;
_names[PathIdx] = OutputColumnNames.Paths;

_types = new ColumnType[3];
_types[TreeIdx] = treeValueColType;
_types[LeafIdx] = leafIdColType;
_types[PathIdx] = pathIdColType;
}

public bool TryGetColumnIndex(string name, out int col)
{
col = -1;
if (name == OutputColumnNames.Trees)
col = TreeIdx;
else if (name == OutputColumnNames.Leaves)
col = LeafIdx;
else if (name == OutputColumnNames.Paths)
col = PathIdx;
return col >= 0;
}

public string GetColumnName(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _names[col];
}

public ColumnType GetColumnType(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _types[col];
}

public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
yield return
MetadataUtils.GetNamesType(_types[col].VectorSize).GetPair(MetadataUtils.Kinds.SlotNames);
if (col == PathIdx || col == LeafIdx)
yield return BoolType.Instance.GetPair(MetadataUtils.Kinds.IsNormalized);
}

public ColumnType GetMetadataTypeOrNull(string kind, int col)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));

if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized)
return BoolType.Instance;
if (kind == MetadataUtils.Kinds.SlotNames)
return MetadataUtils.GetNamesType(_types[col].VectorSize);
return null;
}

public void GetMetadata<TValue>(string kind, int col, ref TValue value)
{
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));

if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized)
MetadataUtils.Marshal<bool, TValue>(IsNormalized, col, ref value);
else if (kind == MetadataUtils.Kinds.SlotNames)
{
switch (col)
{
case TreeIdx:
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetTreeSlotNames, col, ref value);
break;
case LeafIdx:
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetLeafSlotNames, col, ref value);
break;
default:
Contracts.Assert(col == PathIdx);
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetPathSlotNames, col, ref value);
break;
}
}
else
throw _ectx.ExceptGetMetadata();
}

private void IsNormalized(int iinfo, ref bool dst)
{
dst = true;
}
}

private const int TreeIdx = 0;
private const int LeafIdx = 1;
private const int PathIdx = 2;
/// <summary>
/// Column index of values predicted by all trees in an ensemble in <see cref="OutputSchema"/>.
/// </summary>
private const int TreeValuesColumnId = 0;
/// <summary>
/// Column index of leaf IDs containing the considered example in <see cref="OutputSchema"/>.
/// </summary>
private const int LeafIdsColumnId = 1;
/// <summary>
/// Column index of path IDs which specify the paths the considered example passing through per tree in <see cref="OutputSchema"/>.
/// </summary>
private const int PathIdsColumnId = 2;

private readonly TreeEnsembleFeaturizerBindableMapper _owner;
private readonly IExceptionContext _ectx;
Expand All @@ -193,19 +93,53 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper
InputRoleMappedSchema = schema;

// A vector containing the output of each tree on a given example.
var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.TrainedEnsemble.NumTrees);
var treeValueType = new VectorType(NumberType.Float, owner._ensemble.TrainedEnsemble.NumTrees);
// An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example
// ends up in all the trees in the ensemble.
var leafIdType = new VectorType(NumberType.Float, _owner._totalLeafCount);
var leafIdType = new VectorType(NumberType.Float, owner._totalLeafCount);
// An indicator vector with length = the total number of nodes in the ensemble, indicating the nodes on
// the paths of the example in all the trees in the ensemble.
// The total number of nodes in a binary tree is equal to the number of internal nodes + the number of leaf nodes,
// and it is also equal to the number of children of internal nodes (which is 2 * the number of internal nodes)
// plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1,
// which means that #internal = #leaf - 1.
// Therefore, the number of internal nodes in the ensemble is #leaf - #trees.
var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees);
OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType));
var pathIdType = new VectorType(NumberType.Float, owner._totalLeafCount - owner._ensemble.TrainedEnsemble.NumTrees);

// Start creating output schema with types derived above.
var schemaBuilder = new SchemaBuilder();

// Metadata of tree values.
var treeIdMetadataBuilder = new MetadataBuilder();
treeIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treeValueType.VectorSize),
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetTreeSlotNames);
// Add the column of trees' output values
schemaBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeIdMetadataBuilder.GetMetadata());

// Metadata of leaf IDs.
var leafIdMetadataBuilder = new MetadataBuilder();
leafIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leafIdType.VectorSize),
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetLeafSlotNames);
leafIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);
// Add the column of leaves' IDs where the input example reaches.
schemaBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafIdMetadataBuilder.GetMetadata());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see that in their present form these asserts are terribly useful, since the purpose of the check was to see that the schema's indices lined up with our expectations, and this check no longer does that. What would be useful perhaps is that once the schema is created at the end of this constructor where we've assigned output schema, we can do things like this:

So three asserts along these lines might be useful:

Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId)

But the straight transliteration of the old asserts is no longer serving the purpose it once did.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see. We have

                // Tree values must be the first output column.
                Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
                // leaf IDs must be the second output column.
                Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
                // Path IDs must be the third output column.
                Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);

now. Many thanks!


In reply to: 247681048 [](ancestors = 247681048)


// Metadata of path IDs.
var pathIdMetadataBuilder = new MetadataBuilder();
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathIdType.VectorSize),
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetPathSlotNames);
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true);
// Add the column of encoded paths which the input example passes.
schemaBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathIdMetadataBuilder.GetMetadata());

OutputSchema = schemaBuilder.GetSchema();

// Tree values must be the first output column.
Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
// leaf IDs must be the second output column.
Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
// Path IDs must be the third output column.
Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);
}

public Row GetRow(Row input, Func<int, bool> predicate)
Expand All @@ -222,9 +156,9 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)

var delegates = new Delegate[3];

var treeValueActive = predicate(TreeIdx);
var leafIdActive = predicate(LeafIdx);
var pathIdActive = predicate(PathIdx);
var treeValueActive = predicate(TreeValuesColumnId);
var leafIdActive = predicate(LeafIdsColumnId);
var pathIdActive = predicate(PathIdsColumnId);

if (!treeValueActive && !leafIdActive && !pathIdActive)
return delegates;
Expand All @@ -235,21 +169,21 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate)
if (treeValueActive)
{
ValueGetter<VBuffer<float>> fn = state.GetTreeValues;
delegates[TreeIdx] = fn;
delegates[TreeValuesColumnId] = fn;
}

// Get the leaf indicator getter.
if (leafIdActive)
{
ValueGetter<VBuffer<float>> fn = state.GetLeafIds;
delegates[LeafIdx] = fn;
delegates[LeafIdsColumnId] = fn;
}

// Get the path indicators getter.
if (pathIdActive)
{
ValueGetter<VBuffer<float>> fn = state.GetPathIds;
delegates[PathIdx] = fn;
delegates[PathIdsColumnId] = fn;
}

return delegates;
Expand Down Expand Up @@ -477,7 +411,7 @@ private static int CountLeaves(TreeEnsembleModelParameters ensemble)
return totalLeafCount;
}

private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
private void GetTreeSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
{
var numTrees = _ensemble.TrainedEnsemble.NumTrees;

Expand All @@ -488,7 +422,7 @@ private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
dst = editor.Commit();
}

private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
private void GetLeafSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(off topic for the current PR)

How deep are normal trees? Could we name the slots for paths/leaves as "Tree033Leaf021-MyFeatABC_MyFeatXYZ_MyFeatIOU..." to better help users understand the output of the feature importance? The downsize is the name needs to be unique, and will be long as it notes all the input features the leaf node uses.

The current slot naming is a bit useless (though much better than nothing).
Currently the slot names look like:

LeavesMainLabel.Tree033Leaf021	65.73584
LeavesMainLabel.Tree033Leaf023	-42.43543
LeavesMainLabel.Tree033Leaf019	-40.72021
LeavesMainLabel.Tree057Leaf020	-37.54552
LeavesMainLabel.Tree079Leaf007	-36.29255
LeavesMainLabel.Tree055Leaf019	-34.78884
LeavesMainLabel.Tree075Leaf009	34.58635
LeavesMainLabel.Tree020Leaf020	33.72996
LeavesMainLabel.Tree047Leaf022	31.86535
LeavesMainLabel.Tree074Leaf008	31.86535
LeavesMainLabel.Tree066Leaf008	31.74181
LeavesMainLabel.Tree040Leaf019	30.9242

{
var numTrees = _ensemble.TrainedEnsemble.NumTrees;

Expand All @@ -505,7 +439,7 @@ private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
dst = editor.Commit();
}

private void GetPathSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst)
private void GetPathSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst)
{
var numTrees = _ensemble.TrainedEnsemble.NumTrees;

Expand Down
33 changes: 33 additions & 0 deletions src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,39 @@ public static IEnumerable<SampleVectorOfNumbersData> GetVectorOfNumbersData()
return data;
}

private const int _simpleBinaryClassSampleFeatureLength = 10;

public class BinaryLabelFloatFeatureVectorSample
{
public bool Label;

[VectorType(_simpleBinaryClassSampleFeatureLength)]
public float[] Features;
}

public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount)
{
var rnd = new Random(0);
var data = new List<BinaryLabelFloatFeatureVectorSample>();
for (int i = 0; i < exampleCount; ++i)
{
// Initialize an example with a random label and an empty feature vector.
var sample = new BinaryLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2 == 0, Features = new float[_simpleBinaryClassSampleFeatureLength] };
// Fill feature vector according the assigned label.
for (int j = 0; j < 10; ++j)
{
var value = (float)rnd.NextDouble();
// Positive class gets larger feature value.
if (sample.Label)
value += 0.2f;
sample.Features[j] = value;
}

data.Add(sample);
}
return data;
}

/// <summary>
/// feature vector's length in <see cref="MulticlassClassificationExample"/>.
/// </summary>
Expand Down
5 changes: 3 additions & 2 deletions test/Microsoft.ML.TestFramework/Properties/AssemblyInfo.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Runtime.CompilerServices;
using Microsoft.ML;

[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)]
[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests" + PublicKey.TestValue)]
1 change: 1 addition & 0 deletions test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<ProjectReference Include="..\..\src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.KMeansClustering\Microsoft.ML.KMeansClustering.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.SamplesUtils\Microsoft.ML.SamplesUtils.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.LightGBM\Microsoft.ML.LightGBM.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.Maml\Microsoft.ML.Maml.csproj" />
<ProjectReference Include="..\..\src\Microsoft.ML.OnnxTransform\Microsoft.ML.OnnxTransform.csproj" />
Expand Down
Loading