Skip to content

Commit 284e715

Browse files
committed
Add a dynamic API test
1 parent dd6236a commit 284e715

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs

+43
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using Microsoft.ML.Data;
6+
using Microsoft.ML.Internal.Internallearn;
67
using Microsoft.ML.RunTests;
78
using Microsoft.ML.Trainers;
9+
using Microsoft.ML.Trainers.FastTree;
810
using Xunit;
911

1012
namespace Microsoft.ML.Tests.Scenarios.Api
@@ -43,5 +45,46 @@ public void IntrospectiveTraining()
4345
VBuffer<float> weights = default;
4446
model.LastTransformer.Model.GetFeatureWeights(ref weights);
4547
}
48+
49+
[Fact]
50+
public void FastTreeClassificationIntrospectiveTraining()
51+
{
52+
var ml = new MLContext(seed: 1, conc: 1);
53+
var data = ml.Data.ReadFromTextFile<SentimentData>(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true);
54+
55+
var trainer = ml.BinaryClassification.Trainers.FastTree(numLeaves: 5, numTrees: 3);
56+
57+
BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> pred = null;
58+
59+
var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features")
60+
.AppendCacheCheckpoint(ml)
61+
.Append(trainer.WithOnFitDelegate(p => pred = p));
62+
63+
// Train.
64+
var model = pipeline.Fit(data);
65+
66+
// Extract the learned GBDT model.
67+
var treeCollection = ((FastTreeBinaryModelParameters)((Internal.Calibration.FeatureWeightsCalibratedPredictor)pred.Model).SubPredictor).TrainedTreeCollection;
68+
69+
// Inspect properties in the extracted model.
70+
Assert.Equal(3, treeCollection.Trees.Count);
71+
Assert.Equal(3, treeCollection.TreeWeights.Count);
72+
Assert.Equal(0, treeCollection.Bias);
73+
Assert.All(treeCollection.TreeWeights, weight => Assert.Equal(1.0, weight));
74+
75+
// Inspect the last tree.
76+
var tree = treeCollection.Trees[2];
77+
78+
Assert.Equal(5, tree.NumLeaves);
79+
Assert.Equal(4, tree.NumNodes);
80+
Assert.Equal(tree.LteChild.ToArray(), new int[] { 2, -2, -1, -3 });
81+
Assert.Equal(tree.GtChild.ToArray(), new int[] { 1, 3, -4, -5 });
82+
Assert.Equal(tree.NumericalSplitFeatureIndexes.ToArray(), new int[] { 14, 294, 633, 266 });
83+
Assert.Equal(tree.NumericalSplitThresholds.ToArray(), new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f });
84+
Assert.All(tree.CategoricalSplitFlags.ToArray(), flag => Assert.False(flag));
85+
86+
Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Length);
87+
Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Length);
88+
}
4689
}
4790
}

0 commit comments

Comments
 (0)