Skip to content

Commit d715b71

Browse files
committed
Add a test
1 parent 1e98e85 commit d715b71

File tree

4 files changed

+90
-6
lines changed

4 files changed

+90
-6
lines changed

src/Microsoft.ML.FastTree/Representation/TreeRegressorCollection.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ namespace Microsoft.ML.FastTree
1313
/// <see cref="TreeEnsemble{T}"/>, we need to compute the output values of all trees in <see cref="Trees"/>,
1414
/// scale those values via <see cref="TreeWeights"/>, and finally sum the scaled values and <see cref="Bias"/> up.
1515
/// </summary>
16-
public class TreeEnsemble<T> where T : RegressionTreeBase
16+
public sealed class TreeEnsemble<T> where T : RegressionTreeBase
1717
{
1818
/// <summary>
1919
/// When doing prediction, this is a value added to the weighted sum of all <see cref="Trees"/>' outputs.

src/Microsoft.ML.FastTree/TreeEnsemble/QuantileRegressionTree.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,9 @@ internal void ExtractLeafSamplesAndTheirWeights(out double[][] leafSamples, out
117117
leafSamples[i][j] = _labelsDistribution[i * sampleCountPerLeaf + j];
118118
else
119119
// No training label is available, so the i-th leaf's value is used directly. Note that sampleCountPerLeaf must be 1 in this case.
120-
leafSampleWeights[i][j] = LeafValues[i];
120+
leafSamples[i][j] = LeafValues[i];
121121
if (_instanceWeights != null)
122-
leafSamples[i][j] = _instanceWeights[i * sampleCountPerLeaf + j];
122+
leafSampleWeights[i][j] = _instanceWeights[i * sampleCountPerLeaf + j];
123123
else
124124
leafSampleWeights[i][j] = 1.0;
125125
}

src/Microsoft.ML.SamplesUtils/SamplesDatasetUtils.cs

+31
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,37 @@ public static IEnumerable<BinaryLabelFloatFeatureVectorSample> GenerateBinaryLa
271271
return data;
272272
}
273273

274+
public class FloatLabelFloatFeatureVectorSample
275+
{
276+
public float Label;
277+
278+
[VectorType(_simpleBinaryClassSampleFeatureLength)]
279+
public float[] Features;
280+
}
281+
282+
public static IEnumerable<FloatLabelFloatFeatureVectorSample> GenerateFloatLabelFloatFeatureVectorSamples(int exampleCount)
283+
{
284+
var rnd = new Random(0);
285+
var data = new List<FloatLabelFloatFeatureVectorSample>();
286+
for (int i = 0; i < exampleCount; ++i)
287+
{
288+
// Initialize an example with a random label and an empty feature vector.
289+
var sample = new FloatLabelFloatFeatureVectorSample() { Label = rnd.Next() % 2, Features = new float[_simpleBinaryClassSampleFeatureLength] };
290+
// Fill feature vector according to the assigned label.
291+
for (int j = 0; j < _simpleBinaryClassSampleFeatureLength; ++j)
292+
{
293+
var value = (float)rnd.NextDouble();
294+
// Samples with label 0 get the larger feature value (see the condition below).
295+
if (sample.Label == 0)
296+
value += 0.2f;
297+
sample.Features[j] = value;
298+
}
299+
300+
data.Add(sample);
301+
}
302+
return data;
303+
}
304+
274305
public class FfmExample
275306
{
276307
public bool Label;

test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs

+56-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Microsoft.ML.Data;
66
using Microsoft.ML.Internal.Internallearn;
77
using Microsoft.ML.RunTests;
8+
using Microsoft.ML.SamplesUtils;
89
using Microsoft.ML.Trainers;
910
using Microsoft.ML.Trainers.FastTree;
1011
using Xunit;
@@ -56,15 +57,15 @@ public void FastTreeClassificationIntrospectiveTraining()
5657

5758
BinaryPredictionTransformer<IPredictorWithFeatureWeights<float>> pred = null;
5859

59-
var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features")
60+
var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText")
6061
.AppendCacheCheckpoint(ml)
6162
.Append(trainer.WithOnFitDelegate(p => pred = p));
6263

6364
// Train.
6465
var model = pipeline.Fit(data);
6566

6667
// Extract the learned GBDT model.
67-
var treeCollection = ((FastTreeBinaryModelParameters)((Internal.Calibration.FeatureWeightsCalibratedPredictor)pred.Model).SubPredictor).TrainedTreeEnsemble;
68+
var treeCollection = ((FastForestBinaryModelParameters)((Internal.Calibration.FeatureWeightsCalibratedPredictor)pred.Model).SubPredictor).TrainedTreeEnsemble;
6869

6970
// Inspect properties in the extracted model.
7071
Assert.Equal(3, treeCollection.Trees.Count);
@@ -80,11 +81,63 @@ public void FastTreeClassificationIntrospectiveTraining()
8081
Assert.Equal(tree.LteChild, new int[] { 2, -2, -1, -3 });
8182
Assert.Equal(tree.GtChild, new int[] { 1, 3, -4, -5 });
8283
Assert.Equal(tree.NumericalSplitFeatureIndexes, new int[] { 14, 294, 633, 266 });
83-
Assert.Equal(tree.NumericalSplitThresholds, new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f });
84+
var expectedThresholds = new float[] { 0.0911167f, 0.06509889f, 0.019873254f, 0.0361835f };
85+
for (int i = 0; i < tree.NumNodes; ++i)
86+
Assert.Equal(expectedThresholds[i], tree.NumericalSplitThresholds[i], 6);
8487
Assert.All(tree.CategoricalSplitFlags, flag => Assert.False(flag));
8588

8689
Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Count);
8790
Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Count);
8891
}
92+
93+
[Fact]
94+
public void FastForestRegressionIntrospectiveTraining()
95+
{
96+
var ml = new MLContext(seed: 1, conc: 1);
97+
var data = DatasetUtils.GenerateFloatLabelFloatFeatureVectorSamples(1000);
98+
var dataView = ml.Data.ReadFromEnumerable(data);
99+
100+
RegressionPredictionTransformer<FastForestRegressionModelParameters> pred = null;
101+
var trainer = ml.Regression.Trainers.FastForest(numLeaves: 5, numTrees: 3).WithOnFitDelegate(p => pred = p);
102+
103+
// Train.
104+
var model = trainer.Fit(dataView);
105+
106+
// Extract the learned RF model.
107+
var treeCollection = pred.Model.TrainedTreeEnsemble;
108+
109+
// Inspect properties in the extracted model.
110+
Assert.Equal(3, treeCollection.Trees.Count);
111+
Assert.Equal(3, treeCollection.TreeWeights.Count);
112+
Assert.Equal(0, treeCollection.Bias);
113+
Assert.All(treeCollection.TreeWeights, weight => Assert.Equal(1.0, weight));
114+
115+
// Inspect the last tree.
116+
var tree = treeCollection.Trees[2];
117+
118+
Assert.Equal(5, tree.NumLeaves);
119+
Assert.Equal(4, tree.NumNodes);
120+
Assert.Equal(tree.LteChild, new int[] { -1, -2, -3, -4 });
121+
Assert.Equal(tree.GtChild, new int[] { 1, 2, 3, -5 });
122+
Assert.Equal(tree.NumericalSplitFeatureIndexes, new int[] { 9, 0, 1, 8 });
123+
var expectedThresholds = new float[] { 0.208134219f, 0.198336035f, 0.202952743f, 0.205061346f };
124+
for (int i = 0; i < tree.NumNodes; ++i)
125+
Assert.Equal(expectedThresholds[i], tree.NumericalSplitThresholds[i], 6);
126+
Assert.All(tree.CategoricalSplitFlags, flag => Assert.False(flag));
127+
128+
Assert.Equal(0, tree.GetCategoricalSplitFeaturesAt(0).Count);
129+
Assert.Equal(0, tree.GetCategoricalCategoricalSplitFeatureRangeAt(0).Count);
130+
131+
var samples = new double[] { 0.97468354430379744, 1.0, 0.97727272727272729, 0.972972972972973, 0.26124197002141325 };
132+
for (int i = 0; i < tree.NumLeaves; ++i)
133+
{
134+
var sample = tree.GetLeafSamplesAt(i);
135+
Assert.Single(sample);
136+
Assert.Equal(samples[i], sample[0], 6);
137+
var weight = tree.GetLeafSampleWeightsAt(i);
138+
Assert.Single(weight);
139+
Assert.Equal(1, weight[0]);
140+
}
141+
}
89142
}
90143
}

0 commit comments

Comments
 (0)