|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
5 | 5 | using Microsoft.ML.Data;
|
| 6 | +using Microsoft.ML.Internal.Internallearn; |
6 | 7 | using Microsoft.ML.RunTests;
|
7 | 8 | using Microsoft.ML.Trainers;
|
| 9 | +using Microsoft.ML.Trainers.FastTree; |
8 | 10 | using Xunit;
|
9 | 11 |
|
10 | 12 | namespace Microsoft.ML.Tests.Scenarios.Api
|
@@ -43,5 +45,46 @@ public void IntrospectiveTraining()
|
43 | 45 | VBuffer<float> weights = default;
|
44 | 46 | model.LastTransformer.Model.GetFeatureWeights(ref weights);
|
45 | 47 | }
|
| 48 | + |
/// <summary>
/// Trains a small FastTree binary classifier (3 trees, 5 leaves each) on the sentiment
/// dataset, captures the fitted predictor through an on-fit delegate, and verifies the
/// structure of the learned tree ensemble and of its last tree.
/// </summary>
[Fact]
public void FastTreeClassificationIntrospectiveTraining()
{
    // NOTE(review): seed and concurrency are pinned to 1, presumably so the exact split
    // indexes and thresholds asserted below are reproducible -- confirm before changing.
    var ml = new MLContext(seed: 1, conc: 1);
    var data = ml.Data.ReadFromTextFile<SentimentData>(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);

    var trainer = ml.BinaryClassification.Trainers.FastTree(numLeaves: 5, numTrees: 3);

    // Populated by the on-fit delegate once the trainer has been fitted.
    BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> pred = null;

    var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features")
        .AppendCacheCheckpoint(ml)
        .Append(trainer.WithOnFitDelegate(p => pred = p));

    // Train.
    var model = pipeline.Fit(data);

    // Fail with a clear assertion (rather than a NullReferenceException below)
    // if the on-fit delegate was never invoked.
    Assert.NotNull(pred);

    // Extract the learned GBDT model: unwrap the calibrated predictor, then its FastTree sub-predictor.
    var calibratedPredictor = (Internal.Calibration.FeatureWeightsCalibratedPredictor)pred.Model;
    var fastTreeParameters = (FastTreeBinaryModelParameters)calibratedPredictor.SubPredictor;
    var treeCollection = fastTreeParameters.TrainedTreeCollection;

    // Inspect properties in the extracted model.
    Assert.Equal(3, treeCollection.Trees.Count);
    Assert.Equal(3, treeCollection.TreeWeights.Count);
    Assert.Equal(0, treeCollection.Bias);
    Assert.All(treeCollection.TreeWeights, weight => Assert.Equal(1.0, weight));

    // Inspect the last tree.
    var tree = treeCollection.Trees[2];

    Assert.Equal(5, tree.NumLeaves);
    Assert.Equal(4, tree.NumNodes);

    // xUnit's convention is Assert.Equal(expected, actual); keeping the expected literal first
    // makes failure messages label the two values correctly.
    Assert.Equal(new int[] { 2, -2, -1, -3 }, tree.LteChild.ToArray());
    Assert.Equal(new int[] { 1, 3, -4, -5 }, tree.GtChild.ToArray());
    Assert.Equal(new int[] { 14, 294, 633, 266 }, tree.NumericalSplitFeatureIndexes.ToArray());
    Assert.Equal(new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f }, tree.NumericalSplitThresholds.ToArray());
    Assert.All(tree.CategoricalSplitFlags.ToArray(), flag => Assert.False(flag));

    Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Length);
    Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Length);
}
46 | 89 | }
|
47 | 90 | }
|
0 commit comments