|
1 |
| -using System.Collections.Generic; |
| 1 | +// Licensed to the .NET Foundation under one or more agreements. |
| 2 | +// The .NET Foundation licenses this file to you under the MIT license. |
| 3 | +// See the LICENSE file in the project root for more information. |
| 4 | + |
| 5 | +using System.Collections.Generic; |
2 | 6 | using System.Collections.Immutable;
|
3 |
| -using System.Linq; |
4 | 7 | using Microsoft.ML.Trainers.FastTree.Internal;
|
5 | 8 |
|
6 | 9 | namespace Microsoft.ML.FastTree
|
7 | 10 | {
|
8 | 11 | /// <summary>
|
9 |
| - /// A container class for exposing <see cref="InternalRegressionTree"/>'s attributes to users. |
| 12 | + /// A container base class for exposing <see cref="InternalRegressionTree"/>'s and |
| 13 | + /// <see cref="InternalQuantileRegressionTree"/>'s attributes to users. |
10 | 14 | /// This class should not be mutable, so it contains a lot of read-only members.
|
11 | 15 | /// </summary>
|
12 |
| - public class RegressionTree |
| 16 | + public abstract class RegressionTreeBase |
13 | 17 | {
|
14 | 18 | /// <summary>
|
15 |
| - /// <see cref="RegressionTree"/> is an immutable wrapper over <see cref="_tree"/> for exposing some tree's |
| 19 | + /// <see cref="RegressionTreeBase"/> is an immutable wrapper over <see cref="_tree"/> for exposing some tree's |
16 | 20 | /// attribute to users.
|
17 | 21 | /// </summary>
|
18 | 22 | private readonly InternalRegressionTree _tree;
|
19 | 23 |
|
20 |
| - /// <summary> |
21 |
| - /// Sample labels from training data. <see cref="_leafSamples"/>[i] stores the labels falling into the |
22 |
| - /// i-th leaf. |
23 |
| - /// </summary> |
24 |
| - private readonly double[][] _leafSamples; |
25 |
| - /// <summary> |
26 |
| - /// Sample labels' weights from training data. <see cref="_leafSampleWeights"/>[i] stores the weights for |
27 |
| - /// labels falling into the i-th leaf. <see cref="_leafSampleWeights"/>[i][j] is the weight of |
28 |
| - /// <see cref="_leafSamples"/>[i][j]. |
29 |
| - /// </summary> |
30 |
| - private readonly double[][] _leafSampleWeights; |
31 |
| - |
32 | 24 | /// <summary>
|
33 | 25 | /// See <see cref="LteChild"/>.
|
34 | 26 | /// </summary>
|
@@ -133,6 +125,72 @@ public IReadOnlyList<int> GetCategoricalCategoricalSplitFeatureRangeAt(int nodeI
|
133 | 125 | return _tree.CategoricalSplitFeatureRanges[nodeIndex];
|
134 | 126 | }
|
135 | 127 |
|
| 128 | + /// <summary> |
| 129 | + /// Number of leaves in the tree. Note that <see cref="NumLeaves"/> does not take non-leaf nodes into account. |
| 130 | + /// </summary> |
| 131 | + public int NumLeaves => _tree.NumLeaves; |
| 132 | + |
| 133 | + /// <summary> |
| 134 | + /// Number of nodes in the tree. This doesn't include any leaves. For example, for a tree with node0->node1, |
| 135 | + /// node0->leaf3, node1->leaf1, and node1->leaf2, <see cref="NumNodes"/> and <see cref="NumLeaves"/> are |
| 136 | + /// 2 and 3, respectively. |
| 137 | + /// </summary> |
| 138 | + // A visualization of the example mentioned in this doc string. |
| 139 | + // node0 |
| 140 | + // / \ |
| 141 | + // node1 leaf3 |
| 142 | + // / \ |
| 143 | + // leaf1 leaf2 |
| 144 | + // The index of leaf starts with 1 because internally we use "-1" as the 1st leaf's index, "-2" for the 2nd leaf's index, and so on. |
| 145 | + public int NumNodes => _tree.NumNodes; |
| 146 | + |
| 147 | + internal RegressionTreeBase(InternalRegressionTree tree) |
| 148 | + { |
| 149 | + _tree = tree; |
| 150 | + |
| 151 | + _lteChild = ImmutableArray.Create(_tree.LteChild, 0, _tree.NumNodes); |
| 152 | + _gtChild = ImmutableArray.Create(_tree.GtChild, 0, _tree.NumNodes); |
| 153 | + |
| 154 | + _numericalSplitFeatureIndexes = ImmutableArray.Create(_tree.SplitFeatures, 0, _tree.NumNodes); |
| 155 | + _numericalSplitThresholds = ImmutableArray.Create(_tree.RawThresholds, 0, _tree.NumNodes); |
| 156 | + _categoricalSplitFlags = ImmutableArray.Create(_tree.CategoricalSplit, 0, _tree.NumNodes); |
| 157 | + _leafValues = ImmutableArray.Create(_tree.LeafValues, 0, _tree.NumLeaves); |
| 158 | + } |
| 159 | + } |
| 160 | + |
| 161 | + /// <summary> |
| 162 | + /// A container class for exposing <see cref="InternalRegressionTree"/>'s attributes to users. |
| 163 | + /// This class should not be mutable, so it contains a lot of read-only members. Note that |
| 164 | + /// <see cref="RegressionTree"/> is identical to <see cref="RegressionTreeBase"/> but in |
| 165 | + /// another derived class <see cref="QuantileRegressionTree"/> some attributes are added. |
| 166 | + /// </summary> |
| 167 | + public sealed class RegressionTree : RegressionTreeBase |
| 168 | + { |
| 169 | + internal RegressionTree(InternalRegressionTree tree) : base(tree) { } |
| 170 | + } |
| 171 | + |
| 172 | + /// <summary> |
| 173 | + /// A container class for exposing <see cref="InternalQuantileRegressionTree"/>'s attributes to users. |
| 174 | + /// This class should not be mutable, so it contains a lot of read-only members. In addition to |
| 175 | + /// things inherited from <see cref="RegressionTreeBase"/>, we add <see cref="GetLeafSamplesAt(int)"/> |
| 176 | + /// and <see cref="GetLeafSampleWeightsAt(int)"/> to expose (sub-sampled) training labels falling into |
| 177 | + /// the leafIndex-th leaf and their weights. |
| 178 | + /// </summary> |
| 179 | + public sealed class QuantileRegressionTree : RegressionTreeBase |
| 180 | + { |
| 181 | + /// <summary> |
| 182 | + /// Sample labels from training data. <see cref="_leafSamples"/>[i] stores the labels falling into the |
| 183 | + /// i-th leaf. |
| 184 | + /// </summary> |
| 185 | + private readonly double[][] _leafSamples; |
| 186 | + |
| 187 | + /// <summary> |
| 188 | + /// Sample labels' weights from training data. <see cref="_leafSampleWeights"/>[i] stores the weights for |
| 189 | + /// labels falling into the i-th leaf. <see cref="_leafSampleWeights"/>[i][j] is the weight of |
| 190 | + /// <see cref="_leafSamples"/>[i][j]. |
| 191 | + /// </summary> |
| 192 | + private readonly double[][] _leafSampleWeights; |
| 193 | + |
136 | 194 | /// <summary>
|
137 | 195 | /// Return the training labels falling into the specified leaf.
|
138 | 196 | /// </summary>
|
@@ -163,47 +221,9 @@ public IReadOnlyList<double> GetLeafSampleWeightsAt(int leafIndex)
|
163 | 221 | return _leafSampleWeights[leafIndex];
|
164 | 222 | }
|
165 | 223 |
|
166 |
| - /// <summary> |
167 |
| - /// Number of leaves in the tree. Note that <see cref="NumLeaves"/> does not take non-leaf nodes into account. |
168 |
| - /// </summary> |
169 |
| - public int NumLeaves => _tree.NumLeaves; |
170 |
| - |
171 |
| - /// <summary> |
172 |
| - /// Number of nodes in the tree. This doesn't include any leaves. For example, a tree with node0->node1, |
173 |
| - /// node0->leaf3, node1->leaf1, node1->leaf2, <see cref="NumNodes"/> and <see cref="NumLeaves"/> should |
174 |
| - /// be 2 and 3, respectively. |
175 |
| - /// </summary> |
176 |
| - // A visualization of the example mentioned in this doc string. |
177 |
| - // node0 |
178 |
| - // / \ |
179 |
| - // node1 leaf3 |
180 |
| - // / \ |
181 |
| - // leaf1 leaf2 |
182 |
| - // The index of leaf starts with 1 because internally we use "-1" as the 1st leaf's index, "-2" for the 2nd leaf's index, and so on. |
183 |
| - public int NumNodes => _tree.NumNodes; |
184 |
| - |
185 |
| - internal RegressionTree(InternalRegressionTree tree) |
| 224 | + internal QuantileRegressionTree(InternalQuantileRegressionTree tree) : base(tree) |
186 | 225 | {
|
187 |
| - _tree = tree; |
188 |
| - _leafSamples = null; |
189 |
| - _leafSampleWeights = null; |
190 |
| - |
191 |
| - _lteChild = ImmutableArray.Create(_tree.LteChild, 0, _tree.NumNodes); |
192 |
| - _gtChild = ImmutableArray.Create(_tree.GtChild, 0, _tree.NumNodes); |
193 |
| - |
194 |
| - _numericalSplitFeatureIndexes = ImmutableArray.Create(_tree.SplitFeatures, 0, _tree.NumNodes); |
195 |
| - _numericalSplitThresholds = ImmutableArray.Create(_tree.RawThresholds, 0, _tree.NumNodes); |
196 |
| - _categoricalSplitFlags = ImmutableArray.Create(_tree.CategoricalSplit, 0, _tree.NumNodes); |
197 |
| - _leafValues = ImmutableArray.Create(_tree.LeafValues, 0, _tree.NumLeaves); |
198 |
| - |
199 |
| - if (tree is QuantileRegressionTree) |
200 |
| - ((QuantileRegressionTree)tree).ExtractLeafSamplesAndTheirWeights(out _leafSamples, out _leafSampleWeights); |
201 |
| - else |
202 |
| - { |
203 |
| - _leafSamples = tree.LeafValues.Select(value => new double[] { value }).ToArray(); |
204 |
| - _leafSampleWeights = tree.LeafValues.Select(value => new double[] { 1.0 }).ToArray(); |
205 |
| - } |
| 226 | + tree.ExtractLeafSamplesAndTheirWeights(out _leafSamples, out _leafSampleWeights); |
206 | 227 | }
|
207 | 228 | }
|
208 |
| - |
209 | 229 | }
|
0 commit comments