diff --git a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs index b5261449b2..87ec58b231 100644 --- a/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs +++ b/src/Microsoft.ML.FastTree/Properties/AssemblyInfo.cs @@ -14,5 +14,6 @@ [assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Internal.FastTree" + InternalPublicKey.Value)] [assembly: InternalsVisibleTo(assemblyName: "RunTests" + InternalPublicKey.Value)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Tests" + PublicKey.TestValue)] [assembly: WantsToBeBestFriends] diff --git a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs index 8bdb9f88de..e6b651b967 100644 --- a/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs +++ b/src/Microsoft.ML.FastTree/TreeEnsembleFeaturizer.cs @@ -55,118 +55,18 @@ public sealed class Arguments : ScorerArgumentsBase private sealed class BoundMapper : ISchemaBoundRowMapper { - private sealed class SchemaImpl : ISchema - { - private readonly IExceptionContext _ectx; - private readonly string[] _names; - private readonly ColumnType[] _types; - - private readonly TreeEnsembleFeaturizerBindableMapper _parent; - - public int ColumnCount { get { return _types.Length; } } - - public SchemaImpl(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper parent, - ColumnType treeValueColType, ColumnType leafIdColType, ColumnType pathIdColType) - { - Contracts.CheckValueOrNull(ectx); - _ectx = ectx; - _ectx.AssertValue(parent); - _ectx.AssertValue(treeValueColType); - _ectx.AssertValue(leafIdColType); - _ectx.AssertValue(pathIdColType); - - _parent = parent; - - _names = new string[3]; - _names[TreeIdx] = OutputColumnNames.Trees; - _names[LeafIdx] = OutputColumnNames.Leaves; - _names[PathIdx] = OutputColumnNames.Paths; - - _types = new ColumnType[3]; - _types[TreeIdx] = treeValueColType; - _types[LeafIdx] = leafIdColType; - _types[PathIdx] = pathIdColType; - } - - public bool TryGetColumnIndex(string name, out int col) - { - col = -1; - if (name == OutputColumnNames.Trees) - col = TreeIdx; - else if (name == OutputColumnNames.Leaves) - col = LeafIdx; - else if (name == OutputColumnNames.Paths) - col = PathIdx; - return col >= 0; - } - - public string GetColumnName(int col) - { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _names[col]; - } - - public ColumnType GetColumnType(int col) - { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - return _types[col]; - } - - public IEnumerable> GetMetadataTypes(int col) - { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - yield return - MetadataUtils.GetNamesType(_types[col].VectorSize).GetPair(MetadataUtils.Kinds.SlotNames); - if (col == PathIdx || col == LeafIdx) - yield return BoolType.Instance.GetPair(MetadataUtils.Kinds.IsNormalized); - } - - public ColumnType GetMetadataTypeOrNull(string kind, int col) - { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - - if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized) - return BoolType.Instance; - if (kind == MetadataUtils.Kinds.SlotNames) - return MetadataUtils.GetNamesType(_types[col].VectorSize); - return null; - } - - public void GetMetadata(string kind, int col, ref TValue value) - { - _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col)); - - if ((col == PathIdx || col == LeafIdx) && kind == MetadataUtils.Kinds.IsNormalized) - MetadataUtils.Marshal(IsNormalized, col, ref value); - else if (kind == MetadataUtils.Kinds.SlotNames) - { - switch (col) - { - case TreeIdx: - MetadataUtils.Marshal>, TValue>(_parent.GetTreeSlotNames, col, ref value); - break; - case LeafIdx: - MetadataUtils.Marshal>, TValue>(_parent.GetLeafSlotNames, col, ref value); - break; - default: - Contracts.Assert(col == PathIdx); - MetadataUtils.Marshal>, TValue>(_parent.GetPathSlotNames, col, ref value); - break; - } - } - else - throw _ectx.ExceptGetMetadata(); - } - - private void IsNormalized(int iinfo, ref bool dst) - { - dst = true; - } - } - - private const int TreeIdx = 0; - private const int LeafIdx = 1; - private const int PathIdx = 2; + /// + /// Column index of values predicted by all trees in an ensemble in . + /// + private const int TreeValuesColumnId = 0; + /// + /// Column index of leaf IDs containing the considered example in . + /// + private const int LeafIdsColumnId = 1; + /// + /// Column index of path IDs which specify the paths the considered example passing through per tree in . + /// + private const int PathIdsColumnId = 2; private readonly TreeEnsembleFeaturizerBindableMapper _owner; private readonly IExceptionContext _ectx; @@ -193,10 +93,10 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper InputRoleMappedSchema = schema; // A vector containing the output of each tree on a given example. - var treeValueType = new VectorType(NumberType.Float, _owner._ensemble.TrainedEnsemble.NumTrees); + var treeValueType = new VectorType(NumberType.Float, owner._ensemble.TrainedEnsemble.NumTrees); // An indicator vector with length = the total number of leaves in the ensemble, indicating which leaf the example // ends up in all the trees in the ensemble. - var leafIdType = new VectorType(NumberType.Float, _owner._totalLeafCount); + var leafIdType = new VectorType(NumberType.Float, owner._totalLeafCount); // An indicator vector with length = the total number of nodes in the ensemble, indicating the nodes on // the paths of the example in all the trees in the ensemble. // The total number of nodes in a binary tree is equal to the number of internal nodes + the number of leaf nodes, @@ -204,8 +104,42 @@ public BoundMapper(IExceptionContext ectx, TreeEnsembleFeaturizerBindableMapper // plus one (since the root node is not a child of any node). So we have #internal + #leaf = 2*(#internal) + 1, // which means that #internal = #leaf - 1. // Therefore, the number of internal nodes in the ensemble is #leaf - #trees. - var pathIdType = new VectorType(NumberType.Float, _owner._totalLeafCount - _owner._ensemble.TrainedEnsemble.NumTrees); - OutputSchema = Schema.Create(new SchemaImpl(ectx, owner, treeValueType, leafIdType, pathIdType)); + var pathIdType = new VectorType(NumberType.Float, owner._totalLeafCount - owner._ensemble.TrainedEnsemble.NumTrees); + + // Start creating output schema with types derived above. + var schemaBuilder = new SchemaBuilder(); + + // Metadata of tree values. + var treeIdMetadataBuilder = new MetadataBuilder(); + treeIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(treeValueType.VectorSize), + (ValueGetter>>)owner.GetTreeSlotNames); + // Add the column of trees' output values + schemaBuilder.AddColumn(OutputColumnNames.Trees, treeValueType, treeIdMetadataBuilder.GetMetadata()); + + // Metadata of leaf IDs. + var leafIdMetadataBuilder = new MetadataBuilder(); + leafIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(leafIdType.VectorSize), + (ValueGetter>>)owner.GetLeafSlotNames); + leafIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true); + // Add the column of leaves' IDs where the input example reaches. + schemaBuilder.AddColumn(OutputColumnNames.Leaves, leafIdType, leafIdMetadataBuilder.GetMetadata()); + + // Metadata of path IDs. + var pathIdMetadataBuilder = new MetadataBuilder(); + pathIdMetadataBuilder.Add(MetadataUtils.Kinds.SlotNames, MetadataUtils.GetNamesType(pathIdType.VectorSize), + (ValueGetter>>)owner.GetPathSlotNames); + pathIdMetadataBuilder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, (ref bool value) => value = true); + // Add the column of encoded paths which the input example passes. + schemaBuilder.AddColumn(OutputColumnNames.Paths, pathIdType, pathIdMetadataBuilder.GetMetadata()); + + OutputSchema = schemaBuilder.GetSchema(); + + // Tree values must be the first output column. + Contracts.Assert(OutputSchema[OutputColumnNames.Trees].Index == TreeValuesColumnId); + // leaf IDs must be the second output column. + Contracts.Assert(OutputSchema[OutputColumnNames.Leaves].Index == LeafIdsColumnId); + // Path IDs must be the third output column. + Contracts.Assert(OutputSchema[OutputColumnNames.Paths].Index == PathIdsColumnId); } public Row GetRow(Row input, Func predicate) @@ -222,9 +156,9 @@ private Delegate[] CreateGetters(Row input, Func predicate) var delegates = new Delegate[3]; - var treeValueActive = predicate(TreeIdx); - var leafIdActive = predicate(LeafIdx); - var pathIdActive = predicate(PathIdx); + var treeValueActive = predicate(TreeValuesColumnId); + var leafIdActive = predicate(LeafIdsColumnId); + var pathIdActive = predicate(PathIdsColumnId); if (!treeValueActive && !leafIdActive && !pathIdActive) return delegates; @@ -235,21 +169,21 @@ private Delegate[] CreateGetters(Row input, Func predicate) if (treeValueActive) { ValueGetter> fn = state.GetTreeValues; - delegates[TreeIdx] = fn; + delegates[TreeValuesColumnId] = fn; } // Get the leaf indicator getter. if (leafIdActive) { ValueGetter> fn = state.GetLeafIds; - delegates[LeafIdx] = fn; + delegates[LeafIdsColumnId] = fn; } // Get the path indicators getter. if (pathIdActive) { ValueGetter> fn = state.GetPathIds; - delegates[PathIdx] = fn; + delegates[PathIdsColumnId] = fn; } return delegates; @@ -477,7 +411,7 @@ private static int CountLeaves(TreeEnsembleModelParameters ensemble) return totalLeafCount; } - private void GetTreeSlotNames(int col, ref VBuffer> dst) + private void GetTreeSlotNames(ref VBuffer> dst) { var numTrees = _ensemble.TrainedEnsemble.NumTrees; @@ -488,7 +422,7 @@ private void GetTreeSlotNames(int col, ref VBuffer> dst) dst = editor.Commit(); } - private void GetLeafSlotNames(int col, ref VBuffer> dst) + private void GetLeafSlotNames(ref VBuffer> dst) { var numTrees = _ensemble.TrainedEnsemble.NumTrees; @@ -505,7 +439,7 @@ private void GetLeafSlotNames(int col, ref VBuffer> dst) dst = editor.Commit(); } - private void GetPathSlotNames(int col, ref VBuffer> dst) + private void GetPathSlotNames(ref VBuffer> dst) { var numTrees = _ensemble.TrainedEnsemble.NumTrees; diff --git a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs index 8e86e8f1f4..60a9a12ee5 100644 --- a/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs +++ b/src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs @@ -238,6 +238,39 @@ public static IEnumerable GetVectorOfNumbersData() return data; } + private const int _simpleBinaryClassSampleFeatureLength = 10; + + public class BinaryLabelFloatFeatureVectorSample + { + public bool Label; + + [VectorType(_simpleBinaryClassSampleFeatureLength)] + public float[] Features; + } + + public static IEnumerable GenerateBinaryLabelFloatFeatureVectorSamples(int exampleCount) + { + var rnd = new Random(0); + var data = new List(); + for (int i = 0; i < exampleCount; ++i) + { + // Initialize an example with a random label and an empty feature vector. + var sample = new BinaryLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2 == 0, Features = new float[_simpleBinaryClassSampleFeatureLength] }; + // Fill feature vector according the assigned label. + for (int j = 0; j < 10; ++j) + { + var value = (float)rnd.NextDouble(); + // Positive class gets larger feature value. + if (sample.Label) + value += 0.2f; + sample.Features[j] = value; + } + + data.Add(sample); + } + return data; + } + /// /// feature vector's length in . /// diff --git a/test/Microsoft.ML.TestFramework/Properties/AssemblyInfo.cs b/test/Microsoft.ML.TestFramework/Properties/AssemblyInfo.cs index 7b003b45d1..e25660db84 100644 --- a/test/Microsoft.ML.TestFramework/Properties/AssemblyInfo.cs +++ b/test/Microsoft.ML.TestFramework/Properties/AssemblyInfo.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. using System.Runtime.CompilerServices; +using Microsoft.ML; -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] -[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests, PublicKey=002400000480000094000000060200000024000052534131000400000100010015c01ae1f50e8cc09ba9eac9147cf8fd9fce2cfe9f8dce4f7301c4132ca9fb50ce8cbf1df4dc18dd4d210e4345c744ecb3365ed327efdbc52603faa5e21daa11234c8c4a73e51f03bf192544581ebe107adee3a34928e39d04e524a9ce729d5090bfd7dad9d10c722c0def9ccc08ff0a03790e48bcd1f9b6c476063e1966a1c4")] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.Predictor.Tests" + PublicKey.TestValue)] +[assembly: InternalsVisibleTo(assemblyName: "Microsoft.ML.TimeSeries.Tests" + PublicKey.TestValue)] diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 52b5283a52..c56cfd67f1 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -12,6 +12,7 @@ + diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs new file mode 100644 index 0000000000..fe0e8d3898 --- /dev/null +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TreeEnsembleFeaturizerTest.cs @@ -0,0 +1,110 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Data; +using Microsoft.ML.SamplesUtils; +using Microsoft.ML.Trainers.FastTree; +using System; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tests.TrainerEstimators +{ + public partial class TrainerEstimators + { + [Fact] + public void TreeEnsembleFeaturizerOutputSchemaTest() + { + // Create data set + var data = DatasetUtils.GenerateBinaryLabelFloatFeatureVectorSamples(1000).ToList(); + var dataView = ComponentCreation.CreateDataView(Env, data); + + // Define a tree model whose trees will be extracted to construct a tree featurizer. + var trainer = ML.BinaryClassification.Trainers.FastTree( + new FastTreeBinaryClassificationTrainer.Options + { + NumThreads = 1, + NumTrees = 10, + NumLeaves = 5, + }); + + // Train the defined tree model. + var model = trainer.Fit(dataView); + + // From the trained tree model, a mapper of tree featurizer is created. + var treeFeaturizer = new TreeEnsembleFeaturizerBindableMapper(Env, new TreeEnsembleFeaturizerBindableMapper.Arguments(), model.Model); + + // To get output schema, we need to create RoleMappedSchema for calling Bind(...). + var roleMappedSchema = new RoleMappedSchema(dataView.Schema, + label: nameof(DatasetUtils.BinaryLabelFloatFeatureVectorSample.Label), + feature: nameof(DatasetUtils.BinaryLabelFloatFeatureVectorSample.Features)); + + // Retrieve output schema. + var boundMapper = (treeFeaturizer as ISchemaBindableMapper).Bind(Env, roleMappedSchema); + var outputSchema = boundMapper.OutputSchema; + + { + // Check if output schema is correct. + var treeValuesColumn = outputSchema[0]; + Assert.Equal("Trees", treeValuesColumn.Name); + Assert.True(treeValuesColumn.Type is VectorType); + Assert.Equal(NumberType.R4, treeValuesColumn.Type.ItemType); + Assert.Equal(10, treeValuesColumn.Type.VectorSize); + // Below we check the only metadata field. + Assert.Single(treeValuesColumn.Metadata.Schema); + VBuffer> slotNames = default; + treeValuesColumn.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref slotNames); + Assert.Equal(10, slotNames.Length); + // Just check the head and the tail of the extracted vector. + Assert.Equal("Tree000", slotNames.GetItemOrDefault(0).ToString()); + Assert.Equal("Tree009", slotNames.GetItemOrDefault(9).ToString()); + } + + { + var treeLeafIdsColumn = outputSchema[1]; + // Check column of tree leaf IDs. + Assert.Equal("Leaves", treeLeafIdsColumn.Name); + Assert.True(treeLeafIdsColumn.Type is VectorType); + Assert.Equal(NumberType.R4, treeLeafIdsColumn.Type.ItemType); + Assert.Equal(50, treeLeafIdsColumn.Type.VectorSize); + // Below we check the two leaf-IDs column's metadata fields. + Assert.Equal(2, treeLeafIdsColumn.Metadata.Schema.Count); + // Check metadata field IsNormalized's content. + bool leafIdsNormalizedFlag = false; + treeLeafIdsColumn.Metadata.GetValue(MetadataUtils.Kinds.IsNormalized, ref leafIdsNormalizedFlag); + Assert.True(leafIdsNormalizedFlag); + // Check metadata field SlotNames's content. + VBuffer> leafIdsSlotNames = default; + treeLeafIdsColumn.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref leafIdsSlotNames); + Assert.Equal(50, leafIdsSlotNames.Length); + // Just check the head and the tail of the extracted vector. + Assert.Equal("Tree000Leaf000", leafIdsSlotNames.GetItemOrDefault(0).ToString()); + Assert.Equal("Tree009Leaf004", leafIdsSlotNames.GetItemOrDefault(49).ToString()); + } + + { + var treePathIdsColumn = outputSchema[2]; + // Check column of path IDs. + Assert.Equal("Paths", treePathIdsColumn.Name); + Assert.True(treePathIdsColumn.Type is VectorType); + Assert.Equal(NumberType.R4, treePathIdsColumn.Type.ItemType); + Assert.Equal(40, treePathIdsColumn.Type.VectorSize); + // Below we check the two path-IDs column's metadata fields. + Assert.Equal(2, treePathIdsColumn.Metadata.Schema.Count); + // Check metadata field IsNormalized's content. + bool pathIdsNormalizedFlag = false; + treePathIdsColumn.Metadata.GetValue(MetadataUtils.Kinds.IsNormalized, ref pathIdsNormalizedFlag); + Assert.True(pathIdsNormalizedFlag); + // Check metadata field SlotNames's content. + VBuffer> pathIdsSlotNames = default; + treePathIdsColumn.Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref pathIdsSlotNames); + Assert.Equal(40, pathIdsSlotNames.Length); + // Just check the head and the tail of the extracted vector. + Assert.Equal("Tree000Node000", pathIdsSlotNames.GetItemOrDefault(0).ToString()); + Assert.Equal("Tree009Node003", pathIdsSlotNames.GetItemOrDefault(39).ToString()); + } + + } + } +}