-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Remove ISchema in TreeEnsembleFeaturizer #2132
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e3fad19
d2a62d7
10c49d7
b6fdbd5
1dee5e2
f5093e3
f6c91c1
49c59ba
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,118 +55,18 @@ public sealed class Arguments : ScorerArgumentsBase | |
|
||
private sealed class BoundMapper : ISchemaBoundRowMapper | ||
{ | ||
/// <summary> | |
/// Hand-written <see cref="ISchema"/> describing the three output columns of the bound mapper | |
/// (named by OutputColumnNames.Trees, OutputColumnNames.Leaves and OutputColumnNames.Paths). | |
/// Column positions are the TreeIdx, LeafIdx and PathIdx constants, which this class references | |
/// but does not declare. | |
/// </summary> | |
private sealed class SchemaImpl : ISchema | |
{ | |
// Context used for parameter checks and for building the "no such metadata" exception. | |
private readonly IExceptionContext _ectx; | |
// Column names, indexed by TreeIdx / LeafIdx / PathIdx. | |
private readonly string[] _names; | |
// Column types, indexed the same way as _names. | |
private readonly ColumnType[] _types; | |
|
||
// Source of the slot-name getters (GetTreeSlotNames, GetLeafSlotNames, GetPathSlotNames). | |
private readonly TreeEnsembleFeaturizerBindableMapper _parent; | |
|
||
/// <summary> | |
/// Number of columns in this schema; the length of the type array, which the constructor sizes to 3. | |
/// </summary> | |
public int ColumnCount { get { return _types.Length; } } | |
|
||
/// <summary> | |
/// Builds the schema from the three column types. | |
/// </summary> | |
/// <param name="ectx">Exception context; validated with CheckValueOrNull, so null is accepted here.</param> | |
/// <param name="parent">Mapper that supplies the slot-name getters.</param> | |
/// <param name="treeValueColType">Type stored at TreeIdx.</param> | |
/// <param name="leafIdColType">Type stored at LeafIdx.</param> | |
/// <param name="pathIdColType">Type stored at PathIdx.</param> | |
public SchemaImpl(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper parent, | |
ColumnType treeValueColType, ColumnType leafIdColType, ColumnType pathIdColType) | |
{ | |
Contracts.CheckValueOrNull(ectx); | |
_ectx = ectx; | |
// NOTE(review): _ectx may be null at this point; presumably the AssertValue/CheckParam | |
// helpers tolerate a null context - confirm against the Contracts implementation. | |
_ectx.AssertValue(parent); | |
_ectx.AssertValue(treeValueColType); | |
_ectx.AssertValue(leafIdColType); | |
_ectx.AssertValue(pathIdColType); | |
|
||
_parent = parent; | |
|
||
_names = new string[3]; | |
_names[TreeIdx] = OutputColumnNames.Trees; | |
_names[LeafIdx] = OutputColumnNames.Leaves; | |
_names[PathIdx] = OutputColumnNames.Paths; | |
|
||
_types = new ColumnType[3]; | |
_types[TreeIdx] = treeValueColType; | |
_types[LeafIdx] = leafIdColType; | |
_types[PathIdx] = pathIdColType; | |
} | |
|
||
/// <summary> | |
/// Looks up a column index by name. Returns true and sets <paramref name="col"/> when the name | |
/// is one of the three output column names; otherwise returns false with col set to -1. | |
/// </summary> | |
public bool TryGetColumnIndex(string name, out int col) | |
{ | |
col = -1; | |
if (name == OutputColumnNames.Trees) | |
col = TreeIdx; | |
else if (name == OutputColumnNames.Leaves) | |
col = LeafIdx; | |
else if (name == OutputColumnNames.Paths) | |
col = PathIdx; | |
return col >= 0; | |
} | |
|
||
/// <summary> | |
/// Returns the name of column <paramref name="col"/> after checking that it lies in [0, ColumnCount). | |
/// </summary> | |
public string GetColumnName(int col) | |
{ | |
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); | |
return _names[col]; | |
} | |
|
||
/// <summary> | |
/// Returns the type of column <paramref name="col"/> after checking that it lies in [0, ColumnCount). | |
/// </summary> | |
public ColumnType GetColumnType(int col) | |
{ | |
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); | |
return _types[col]; | |
} | |
|
||
/// <summary> | |
/// Enumerates the metadata kinds available on column <paramref name="col"/>: SlotNames for every | |
/// column, plus IsNormalized (bool) for the leaf and path columns. | |
/// </summary> | |
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col) | |
{ | |
// This is an iterator method, so the range check below runs when enumeration starts, not at call time. | |
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); | |
yield return | |
MetadataUtils.GetNamesType(_types[col].VectorSize).GetPair(MetadataUtils.Kinds.SlotNames); | |
if (col == PathIdx || col == LeafIdx) | |
yield return BoolType.Instance.GetPair(MetadataUtils.Kinds.IsNormalized); | |
} | |
|
||
/// <summary> | |
/// Returns the type of metadata <paramref name="kind"/> on column <paramref name="col"/>, or null | |
/// when that kind is not one of the kinds handled below for the column. | |
/// </summary> | |
public ColumnType GetMetadataTypeOrNull(string kind, int col) | |
{ | |
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); | |
|
||
if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized) | |
return BoolType.Instance; | |
if (kind == MetadataUtils.Kinds.SlotNames) | |
return MetadataUtils.GetNamesType(_types[col].VectorSize); | |
return null; | |
} | |
|
||
/// <summary> | |
/// Fetches metadata <paramref name="kind"/> of column <paramref name="col"/> into | |
/// <paramref name="value"/>. Handles IsNormalized (leaf and path columns only) and SlotNames | |
/// (all columns); any other request throws the exception built by ExceptGetMetadata. | |
/// </summary> | |
public void GetMetadata<TValue>(string kind, int col, ref TValue value) | |
{ | |
_ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); | |
|
||
if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized) | |
MetadataUtils.Marshal<bool, TValue>(IsNormalized, col, ref value); | |
else if (kind == MetadataUtils.Kinds.SlotNames) | |
{ | |
// Each column takes its slot names from the matching getter on the parent mapper. | |
switch (col) | |
{ | |
case TreeIdx: | |
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetTreeSlotNames, col, ref value); | |
break; | |
case LeafIdx: | |
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetLeafSlotNames, col, ref value); | |
break; | |
default: | |
// The range check above leaves PathIdx as the only remaining column. | |
Contracts.Assert(col == PathIdx); | |
MetadataUtils.Marshal<VBuffer<ReadOnlyMemory<char>>, TValue>(_parent.GetPathSlotNames, col, ref value); | |
break; | |
} | |
} | |
else | |
throw _ectx.ExceptGetMetadata(); | |
} | |
|
||
/// <summary> | |
/// Getter for the IsNormalized metadata: always writes true, ignoring <paramref name="iinfo"/>. | |
/// </summary> | |
private void IsNormalized(int iinfo, ref bool dst) | |
{ | |
dst = true; | |
} | |
} | |
|
||
private const int TreeIdx = 0; | ||
private const int LeafIdx = 1; | ||
private const int PathIdx = 2; | ||
/// <summary> | ||
/// Column index of values predicted by all trees in an ensemble in <see cref="OutputSchema"/>. | ||
/// </summary> | ||
private const int TreeValuesColumnId = 0; | ||
/// <summary> | ||
/// Column index of leaf IDs containing the considered example in <see cref="OutputSchema"/>. | ||
/// </summary> | ||
private const int LeafIdsColumnId = 1; | ||
/// <summary> | ||
/// Column index of path IDs, which specify the paths the considered example passes through in each tree, in <see cref="OutputSchema"/>. | |
/// </summary> | ||
private const int PathIdsColumnId = 2; | ||
|
||
private readonly TreeEnsembleFeaturizerBindableMapper _owner; | ||
private readonly IExceptionContext _ectx; | ||
|
@@ -193,19 +93,53 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper | |
InputRoleMappedSchema = schema; | ||
|
||
// A vector containing the output of each tree on a given example. | ||
var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.TrainedEnsemble.NumTrees); | ||
var treeValueType = new VectorType(NumberType.Float, owner._ensemble.TrainedEnsemble.NumTrees); | ||
// An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example | ||
// ends up in all the trees in the ensemble. | ||
var leafIdType = new VectorType(NumberType.Float, _owner._totalLeafCount); | ||
var leafIdType = new VectorType(NumberType.Float, owner._totalLeafCount); | ||
// An indicator vector with length = the total number of nodes in the ensemble, indicating the nodes on | ||
// the paths of the example in all the trees in the ensemble. | ||
// The total number of nodes in a binary tree is equal to the number of internal nodes + the number of leaf nodes, | ||
// and it is also equal to the number of children of internal nodes (which is 2 * the number of internal nodes) | ||
// plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1, | ||
// which means that #internal = #leaf - 1. | ||
// Therefore, the number of internal nodes in the ensemble is #leaf - #trees. | ||
var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees); | ||
OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); | ||
var pathIdType = new VectorType(NumberType.Float, owner._totalLeafCount - owner._ensemble.TrainedEnsemble.NumTrees); | ||
|
||
// Start creating output schema with types derived above. | ||
var schemaBuilder = new SchemaBuilder(); | ||
|
||
// Metadata of tree values. | ||
var treeIdMetadataBuilder = new MetadataBuilder(); | ||
treeIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treeValueType.VectorSize), | ||
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetTreeSlotNames); | ||
// Add the column of trees' output values | ||
schemaBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeIdMetadataBuilder.GetMetadata()); | ||
|
||
// Metadata of leaf IDs. | ||
var leafIdMetadataBuilder = new MetadataBuilder(); | ||
leafIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leafIdType.VectorSize), | ||
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetLeafSlotNames); | ||
leafIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true); | ||
// Add the column of leaves' IDs where the input example reaches. | ||
schemaBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafIdMetadataBuilder.GetMetadata()); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I don't see that in their present form these asserts are terribly useful, since the purpose of the check was to see that the schema's indices lined up with our expectations, and this check no longer does that. What would be useful, perhaps, is that once the schema is created at the end of this constructor, where we've assigned the output schema, we can do things like this:
Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId)
So three asserts along these lines might be useful. But the straight transliteration of the old asserts is no longer serving the purpose it once did.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
I see. We have the following now:
// Tree values must be the first output column.
Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId);
// Leaf IDs must be the second output column.
Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId);
// Path IDs must be the third output column.
Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId);
Many thanks! In reply to: 247681048 [](ancestors = 247681048) |
||
|
||
// Metadata of path IDs. | ||
var pathIdMetadataBuilder = new MetadataBuilder(); | ||
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathIdType.VectorSize), | ||
(ValueGetter<VBuffer<ReadOnlyMemory<char>>>)owner.GetPathSlotNames); | ||
pathIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true); | ||
// Add the column of encoded paths which the input example passes. | ||
schemaBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathIdMetadataBuilder.GetMetadata()); | ||
|
||
OutputSchema = schemaBuilder.GetSchema(); | ||
|
||
// Tree values must be the first output column. | ||
Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId); | ||
// Leaf IDs must be the second output column. | |
Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId); | ||
// Path IDs must be the third output column. | ||
Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId); | ||
} | ||
|
||
public Row GetRow(Row input, Func<int, bool> predicate) | ||
|
@@ -222,9 +156,9 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate) | |
|
||
var delegates = new Delegate[3]; | ||
|
||
var treeValueActive = predicate(TreeIdx); | ||
var leafIdActive = predicate(LeafIdx); | ||
var pathIdActive = predicate(PathIdx); | ||
var treeValueActive = predicate(TreeValuesColumnId); | ||
var leafIdActive = predicate(LeafIdsColumnId); | ||
var pathIdActive = predicate(PathIdsColumnId); | ||
|
||
if (!treeValueActive && !leafIdActive && !pathIdActive) | ||
return delegates; | ||
|
@@ -235,21 +169,21 @@ private Delegate[] CreateGetters(Row input, Func<int, bool> predicate) | |
if (treeValueActive) | ||
{ | ||
ValueGetter<VBuffer<float>> fn = state.GetTreeValues; | ||
delegates[TreeIdx] = fn; | ||
delegates[TreeValuesColumnId] = fn; | ||
} | ||
|
||
// Get the leaf indicator getter. | ||
if (leafIdActive) | ||
{ | ||
ValueGetter<VBuffer<float>> fn = state.GetLeafIds; | ||
delegates[LeafIdx] = fn; | ||
delegates[LeafIdsColumnId] = fn; | ||
} | ||
|
||
// Get the path indicators getter. | ||
if (pathIdActive) | ||
{ | ||
ValueGetter<VBuffer<float>> fn = state.GetPathIds; | ||
delegates[PathIdx] = fn; | ||
delegates[PathIdsColumnId] = fn; | ||
} | ||
|
||
return delegates; | ||
|
@@ -477,7 +411,7 @@ private static int CountLeaves(TreeEnsembleModelParameters ensemble) | |
return totalLeafCount; | ||
} | ||
|
||
private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst) | ||
private void GetTreeSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst) | ||
{ | ||
var numTrees = _ensemble.TrainedEnsemble.NumTrees; | ||
|
||
|
@@ -488,7 +422,7 @@ private void GetTreeSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst) | |
dst = editor.Commit(); | ||
} | ||
|
||
private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst) | ||
private void GetLeafSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
(Off topic for the current PR.) How deep are normal trees? Could we name the slots for paths/leaves as "Tree033Leaf021-MyFeatABC_MyFeatXYZ_MyFeatIOU..." to better help users understand the output of the feature importance? The downside is that the name needs to be unique, and it will be long, as it notes all the input features the leaf node uses. The current slot naming is a bit useless (though much better than nothing).
LeavesMainLabel.Tree033Leaf021 65.73584
LeavesMainLabel.Tree033Leaf023 -42.43543
LeavesMainLabel.Tree033Leaf019 -40.72021
LeavesMainLabel.Tree057Leaf020 -37.54552
LeavesMainLabel.Tree079Leaf007 -36.29255
LeavesMainLabel.Tree055Leaf019 -34.78884
LeavesMainLabel.Tree075Leaf009 34.58635
LeavesMainLabel.Tree020Leaf020 33.72996
LeavesMainLabel.Tree047Leaf022 31.86535
LeavesMainLabel.Tree074Leaf008 31.86535
LeavesMainLabel.Tree066Leaf008 31.74181
LeavesMainLabel.Tree040Leaf019 30.9242 |
||
{ | ||
var numTrees = _ensemble.TrainedEnsemble.NumTrees; | ||
|
||
|
@@ -505,7 +439,7 @@ private void GetLeafSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst) | |
dst = editor.Commit(); | ||
} | ||
|
||
private void GetPathSlotNames(int col, ref VBuffer<ReadOnlyMemory<char>> dst) | ||
private void GetPathSlotNames(ref VBuffer<ReadOnlyMemory<char>> dst) | ||
{ | ||
var numTrees = _ensemble.TrainedEnsemble.NumTrees; | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.