Skip to content

Commit 4093f28

Browse files
committed
Add attributes for quantile regression tree
1 parent b27b171 commit 4093f28

File tree

2 files changed

+46
-0
lines changed

2 files changed

+46
-0
lines changed

src/Microsoft.ML.FastTree/TreeEnsemble/QuantileRegressionTree.cs

+36
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using System;
6+
using System.Collections.Generic;
57
using Microsoft.ML.Data;
68
using Microsoft.ML.Internal.Utilities;
79
using Microsoft.ML.Model;
@@ -91,6 +93,40 @@ public void SetLabelsDistribution(double[] labelsDistribution, double[] weights)
9193
_instanceWeights = weights;
9294
}
9395

96+
internal IReadOnlyList<IReadOnlyList<double>> ExtractLeafSamples()
97+
{
98+
var sampledLabelsBuffer = new List<List<double>>();
99+
var sampleCountPerLeaf = _labelsDistribution.Length / NumLeaves;
100+
for (int i = 0; i < NumLeaves; ++i)
101+
{
102+
var samplesPerLeaf = new List<double>();
103+
var weightsPerLeaf = new List<double>();
104+
for (int j = 0; j < sampleCountPerLeaf; ++j)
105+
samplesPerLeaf.Add(_labelsDistribution[i * sampleCountPerLeaf + j]);
106+
sampledLabelsBuffer.Add(samplesPerLeaf);
107+
}
108+
return sampledLabelsBuffer;
109+
}
110+
111+
internal IReadOnlyList<IReadOnlyList<double>> ExtractLeafSampleWeights()
112+
{
113+
var labelWeightsBuffer = new List<List<double>>();
114+
var sampleCountPerLeaf = _labelsDistribution.Length / NumLeaves;
115+
for (int i = 0; i < NumLeaves; ++i)
116+
{
117+
var weightsPerLeaf = new List<double>();
118+
for (int j = 0; j < sampleCountPerLeaf; ++j)
119+
{
120+
if (_instanceWeights != null)
121+
weightsPerLeaf.Add(_instanceWeights[i * sampleCountPerLeaf + j]);
122+
else
123+
weightsPerLeaf.Add(1.0);
124+
}
125+
labelWeightsBuffer.Add(weightsPerLeaf);
126+
}
127+
return labelWeightsBuffer;
128+
}
129+
94130
public override int SizeInBytes()
95131
{
96132
return base.SizeInBytes() + _labelsDistribution.SizeInBytes() + (_instanceWeights != null ? _instanceWeights.SizeInBytes() : 0);

src/Microsoft.ML.FastTree/TreeEnsemble/RegressionTree.cs

+10
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ public class RegressionTreeView
3535
public IReadOnlyList<IReadOnlyList<int>> CategoricalSplitFeatureRanges => _tree.CategoricalSplitFeatureRanges;
3636
public ReadOnlySpan<double> LeafValues => _tree.LeafValues;
3737

38+
public IReadOnlyList<IReadOnlyList<double>> LeafSamples { get; }
39+
public IReadOnlyList<IReadOnlyList<double>> LeafSampleWeights { get; }
40+
3841
/// <summary>
3942
/// Number of leaves in the tree. Note that <see cref="NumLeaves"/> does not take non-leaf nodes into account.
4043
/// </summary>
@@ -55,6 +58,13 @@ public class RegressionTreeView
5558
internal RegressionTreeView(RegressionTree tree)
5659
{
5760
_tree = tree;
61+
LeafSamples = null;
62+
LeafSampleWeights = null;
63+
if (tree is QuantileRegressionTree)
64+
{
65+
LeafSamples = ((QuantileRegressionTree)tree).ExtractLeafSamples();
66+
LeafSampleWeights = ((QuantileRegressionTree)tree).ExtractLeafSampleWeights();
67+
}
5868
}
5969
}
6070

0 commit comments

Comments
 (0)